In [31]:
# default_exp metrics
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
# Skip the checks when the IN_TEST environment variable is set to a non-empty value.
# Otherwise assert the runtime flags (presumably supplied by the nbdev star imports
# above — TODO confirm): inside a notebook, inside IPython, and not on Colab.
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Metrics

> Metrics for reinforcement learning

In [33]:
# export
from fastai.callback import *
from fastai.basic_train import *
from fastai.core import *
from fastai.torch_core import *
from dataclasses import dataclass
import torch.multiprocessing as mp
import logging

# Root logging layout: timestamp, process id, source line number, level, message.
logging.basicConfig(
    datefmt='%m-%d %H:%M:%S',
    format='[%(asctime)s] p%(process)s line:%(lineno)d %(levelname)s - %(message)s',
)
# Only CRITICAL records from the data block logger are let through.
logging.getLogger('fastrl.data_block').setLevel(logging.CRITICAL)
# Logger used by this module.
_logger = logging.getLogger(__name__)

In [34]:
# hide
from fastrl.data_block import *
from fastrl.basic_agents import *
from fastrl.basic_train import *
import sys

# Show INFO-level (and above) records from this notebook's logger while experimenting.
_logger.setLevel(logging.INFO)

In [35]:
# export
@dataclass
class TotalRewards:
    "Message object carrying one `rewards` value; instances are what gets put on / read from a dataset `metric_queue`."
    rewards: float

class RewardMetric(LearnerCallback):
    """Recorder metric reporting the mean reward gathered during each epoch.

    A `train_reward` column is always registered; a `valid_reward` column is
    registered only when the recorder tracks validation (`recorder.no_val` is
    false) and the data bunch has a validation set (`data.empty_val` is false).
    `on_epoch_end` reports exactly the columns registered in `on_train_begin`.
    """
    _order=-20

    def on_train_begin(self, **kwargs):
        "Register the metric names and create a metric queue on each dataset whose `metric_queue` attribute is still `None`."
        metric_names = ['train_reward'] if self.learn.recorder.no_val or self.learn.data.empty_val else ['train_reward', 'valid_reward']
        self.learn.recorder.add_metric_names(metric_names)
        for ds in [self.learn.data.train_ds,None if self.learn.data.empty_val else self.learn.data.valid_ds]:
            # `hasattr(None, ...)` is False, so a missing validation set is skipped here.
            if hasattr(ds,'metric_queue') and ds.metric_queue is None:
                ds.metric_queue=mp.JoinableQueue(ds.queue_sz*len(ds)) # Make sure this queue has more space to prevent locking

    @staticmethod
    def _mean_reward(ds):
        "Collect the rewards currently available from `ds` and return their mean (NaN when there are none)."
        if hasattr(ds,'metric_queue'):
            rs=[]
            q=ds.metric_queue
            if q is not None:
                # A JoinableQueue expects one `task_done()` per `get()`; without it a
                # `join()` on the queue never returns. Plain queues lack the method.
                task_done=getattr(q,'task_done',None)
                while not q.empty():
                    rs.append(q.get().rewards)
                    if task_done is not None:task_done()
        else:rs=ds.pop_total_rewards()
        # `np.mean` of an empty sequence is NaN plus a RuntimeWarning; return NaN quietly instead.
        return np.mean(rs) if len(rs)!=0 else np.nan

    def on_epoch_end(self,last_metrics,**kwargs: Any):
        """Append the mean epoch reward(s) to `last_metrics`.

        Returns the result of `add_metrics(last_metrics, rewards)`, where `rewards`
        holds one value per column registered in `on_train_begin`.
        """
        data=self.learn.data
        rewards=[]
        if data.train_ds is not None:
            rewards.append(self._mean_reward(data.train_ds))
        if not data.empty_val and data.valid_ds is not None:
            # Always drain the validation source, but only report it when a
            # `valid_reward` column was registered (i.e. the recorder tracks validation).
            valid_reward=self._mean_reward(data.valid_ds)
            if not self.learn.recorder.no_val:rewards.append(valid_reward)
        return add_metrics(last_metrics,rewards)

In [36]:
# @safe_fit
def dqn_grad_fitter_2(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,grad_queue:mp.JoinableQueue,loss_queue:mp.JoinableQueue,
                      pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    """Stand-in fitter that streams batches to the queues without doing any training.

    `ds` is called once to build the dataset. Until `cancel_event` is set, the dataset
    is iterated repeatedly: the first element of every item goes onto `grad_queue` and
    the constant 0.5 onto `loss_queue`. After each pass, when `metric_queue` is given
    and the dataset has total rewards to pop, their mean is put on it wrapped in
    `TotalRewards`. `model`, `agent` and the second element of each item are not used.
    """
    experience=ds()
    while not cancel_event.is_set():
        for batch,_ in experience:
            sys.stdout.flush()
            grad_queue.put(batch)
            loss_queue.put(0.5)
            # While paused, wait up to 0.1 on the cancel event before moving on.
            if pause_event.is_set():
                cancel_event.wait(0.1)
            if cancel_event.is_set():
                break
        # Report the rewards gathered during this pass, if anyone is listening.
        if metric_queue is not None:
            episode_rewards=experience.pop_total_rewards()
            if len(episode_rewards)!=0:
                metric_queue.put(TotalRewards(np.mean(episode_rewards)))
        if cancel_event.is_set():
            break

In [37]:
# Smoke run: CartPole data bunch without a validation set, a tiny MLP, and the
# stand-in fitter above, so only the `train_reward` column is expected in the output.
data=AsyncExperienceSourceDataBunch.from_env(
    'CartPole-v1',data_exp=False,display=False,firstlast=True,
    add_valid=False,n_processes=2,n_envs=2,queue_sz=400)
model=nn.Sequential(
    nn.Linear(4,5),
    nn.ReLU(),
    nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback,RewardMetric])
# Swap in the fake fitter before training.
learn.fitter=dqn_grad_fitter_2
learn.fit(10,lr=0.01,wd=1)

epoch,train_loss,valid_loss,train_reward,time
0,0.5,#na#,21.609756,00:00
1,0.499999,#na#,20.36,00:00
2,0.499999,#na#,20.416667,00:00
3,0.499999,#na#,23.190476,00:00
4,0.499999,#na#,26.1,00:00
5,0.499999,#na#,19.8,00:00
6,0.499999,#na#,18.481481,00:00
7,0.499999,#na#,23.5,00:00
8,0.499999,#na#,26.363636,00:00
9,0.499999,#na#,21.391304,00:00




In [38]:
# hide
from nbdev.export import *
# Run the nbdev script export; the recorded output lists each notebook as "Converted".
notebook2script()
# Then run the HTML conversion (recorded output: "converting: <notebook path>").
# n_workers=0 presumably disables parallel workers — TODO confirm against the nbdev docs.
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 02_callbacks.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_metrics.ipynb.
Converted 05_data_block.ipynb.
Converted 06_basic_train.ipynb.
Converted 12_a3c.a3c_data.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/04_metrics.ipynb
converting: /opt/project/fastrl/nbs/05_data_block.ipynb
