In [1]:
# default_exp data_block
%load_ext autoreload
%autoreload 2

In [2]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Data Block

> Primary handlers for interfacing the openai gym envs

In [32]:
# export
from fastai.basic_data import *
from fastai.basic_train import *
from fastai.torch_core import *
from fastai.callbacks import *
from fastrl.wrappers import *
from fastrl.basic_agents import *
from fastrl.metrics import *
from dataclasses import asdict,dataclass,field
from functools import partial
from fastprogress.fastprogress import IN_NOTEBOOK
from fastcore.utils import *
from fastcore.foundation import *
from torch.utils.data.dataset import IterableDataset
import torch.multiprocessing as mp
from functools import wraps
from queue import Empty
import textwrap
import logging
import gym

logging.basicConfig(format='[%(asctime)s] p%(process)s line:%(lineno)d %(levelname)s - %(message)s',
                    datefmt='%m-%d %H:%M:%S')
_logger=logging.getLogger(__name__)

In [4]:
# hide
_logger.setLevel('INFO')
import sys

## Dataset

`Dataset` instances are going to be a little different from the typically classification dataset that you might use in pytorch. Commonly, datasets have:
- A known size to iter through
- Maintain their state during the training sequence
- Randomly sample their dataset
- Have a common `x`/`y` or `input`/`target` data format

For our `ExperienceSourceDataset`, most of this is going to be different. 
- We can have multiple sources (envs)

You could think of a traditional dataset approach as being a mix of a `ExperienceSourceDataset` and a form of `ExperienceReplay`.

In [5]:
# export
@dataclass
class Experience(object):
    s:np.array
    sp:np.array
    a:np.array
    r:np.array
    d:np.array
    agent_s:np.array
        
    @property
    def x(self):return self.s.copy()
    @x.setter
    def x(self,v):self.s=v.copy()

add_docs(Experience,x='Should return the field for `xb` in the training loop. It must be copied on return or'
         ' else there will be strange multiple reference errors.'
         'This is intended to be fed directly into a model. The `self.s.copy()` is a single tensor to be directly fed into a regular nn.')

In [6]:
exp=Experience(s=np.random.randint(1,3,(5,5)),
               sp=np.random.randint(1,3,(5,5)),
               a=np.random.randint(1,3,(1,2)),
               r=np.random.randint(1,3,(1,20)),
               d=np.random.randint(0,1,(5,5)),
               agent_s=np.random.randint(0,6,(1,5)))
asdict(exp)

{'s': array([[2, 1, 1, 2, 2],
        [2, 1, 1, 1, 2],
        [2, 2, 2, 2, 2],
        [1, 2, 1, 1, 2],
        [2, 2, 2, 1, 2]]),
 'sp': array([[2, 2, 1, 1, 2],
        [2, 1, 1, 1, 1],
        [1, 2, 1, 1, 2],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 1]]),
 'a': array([[1, 2]]),
 'r': array([[1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2]]),
 'd': array([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]),
 'agent_s': array([[5, 4, 4, 3, 3]])}

## Datasets

We have a couple constraints on what the `ExperienceSourceDataset` is going to look like. For `__getitem__` it should return a tuple of 2 elements:

- `xb` which are the elements that are going to be directly fed into the model for getting actions.
- `yb` which are all of the elements provided by the environment that we can use for training.

Some semantics:
- `iter` means what is returned from `next(dataset)` for farthur what is returned from `[o for o in dataset]`.
- `step` means an individual step in the env. You can have multiple steps per `iter`.
- `batch_size` is the same thing as `step`.

In [73]:
# export
class ExperienceSourceCallback(LearnerCallback):
    def on_train_begin(self,*args,**kwargs):
        self.learn.data.train_dl.dataset.learn=self.learn
        if not self.learn.data.empty_val:
            self.learn.data.valid_dl.dataset.learn=self.learn

@dataclass
class ExperienceSourceDataset(IterableDataset):
    "Similar to fastai's `LabelList`, iters in-order samples from `1->len(envs)` `envs`."
    env:str
    n_envs:int=1
    learn:Optional[Learner]=None
    callback_fns:List[LearnerCallback]=field(default_factory=lambda:[ExperienceSourceCallback])
    # Behavior modification fields / param fields.
    skip_n_steps:int=1
    max_steps:Optional[int]=None
    pixels:bool=False
    # Env tracking fields
    per_env_d:np.array=None
    per_env_s:np.array=None
    per_env_steps:np.array=None
    per_env_rewards:np.array=None
    # Metric fields
    total_rewards:List=field(default_factory=list)
    total_steps:List=field(default_factory=list)
    # Private funcational fields.
    _env_idx:int=0
    _warned:bool=False
    _end_interation:bool=False
    _getitem_as_exp:bool=False
    
    
#     def __init__(self,env:str,n_envs=1,skip_n_steps=1,max_steps=None,pixels=False):
    def __post_init__(self):
        def make_env():
            env_o=gym.make(self.env)
            if self.pixels:env_o.reset()
            return env_o
            
        self.envs=[make_env() for _ in range(self.n_envs)]
        if self.pixels:self.envs=[PixelObservationWrapper(e) for e in self.envs]
            
        self.per_env_d:np.array=np.zeros((len(self.envs),))+1
        self.per_env_s:np.array=np.zeros((len(self.envs),*self.envs[0].observation_space.sample().shape))
        self.per_env_steps:np.array=np.zeros((len(self.envs)))
        self.per_env_rewards:np.array=np.zeros((len(self.envs)))
        
    def __len__(self): return ifnone(self.max_steps,self.envs[0].spec.max_episode_steps)
    @log_args
    def is_last_step(self,d,steps,idx,max_steps)->bool:return bool(d) or steps>=max_steps
    @log_args
    def at_start(self,idx,steps):return idx==0 and steps==0
    @log_args
    def skip_loop(self,steps,skip_steps,d):return steps%skip_steps!=0 and not d
    @log_args
    def cycle_env(self,idx,n_envs):self._env_idx=0 if idx==n_envs-1 else idx+1
    def stop_loop(self):
        self._end_interation=False
        raise StopIteration()
    @log_args
    def reset_envs(self,idx):
        for i,e in enumerate(self.envs):self.per_env_s[i]=e.reset() # TODO: There is the possiblity this will equal None (and crash)?
        self.per_env_steps=np.zeros((len(self.envs)),dtype=int)
        self.per_env_rewards=np.zeros((len(self.envs)))
    @log_args
    def pick_action(self,idx):
        if self.learn is None:
            if not self._warned:_logger.warning('`self.learn` is None. will use random actions instead.')
            self._warned=True
            return self.envs[0].action_space.sample(),np.zeros((1,1))
        return self.learn.predict(self.per_env_s[idx])
    
    def pop_total_rewards(self):
        total_rewards=self.total_rewards
        if total_rewards:self.total_rewards,self.total_steps=[],[]
        return total_rewards

    def pop_reward_steps(self):
        total_rewards,total_steps=self.total_rewards,self.total_steps
        res=list(zip(total_rewards,total_steps))
        if res:self.total_rewards,self.total_steps=[],[]
        return res
    
    def __iter__(self):
        while True:
            try:
                o=self.__getitem__(None)
                yield o
            except StopIteration:
                return
    
    @log_args
    def __getitem__(self,_):
        idx=self._env_idx  # This is the current env that we are running on
        if self._end_interation:self.stop_loop()
        if self.at_start(idx,self.per_env_steps[idx]):self.reset_envs(idx)
        
        while True:
            a,agent_s=self.pick_action(idx)
            sp,r,d,_=self.envs[idx].step(a)
            
            exp=Experience(self.per_env_s[idx],sp,a,r,d,agent_s=ifnone(agent_s,[]))
            
            self.per_env_rewards[idx]+=r
            self.per_env_s[idx]=sp
            self.per_env_steps[idx]+=1
            self.per_env_d[idx]=d
            
            if self.skip_loop(self.per_env_steps[idx],self.skip_n_steps,d):continue
            
            if self.is_last_step(self.per_env_d[idx],self.per_env_steps[idx],idx,len(self)):
                self.total_rewards.append(self.per_env_rewards[idx])
                self.total_steps.append(self.per_env_steps[idx])
                self.per_env_steps[idx]=0
                
                self.cycle_env(idx,len(self.envs))
                self._end_interation=True
                
            if self._getitem_as_exp:return exp
            return exp.x,asdict(exp) # Must copy or else torch keeps only last tensor

add_docs(ExperienceSourceDataset,
         __init__='Provides an easy interface to iterate through an env or list of envs.\n Some importants notes:\n'
                  'an iteration through a loop. This is the maximum steps, and may be less due to the environment ending early.\n'
                  '- `skip_n_steps` are the number of steps to skip in the returned elements. This can be seen as frame skipping.',
        pick_action='Returns the action and learner\'s `self.learn.model`\'s state as determined by the learner for a state `self.s`'  
                    'belonging to env `idx`. If the `self.learn` is None, a random action using the `action_space` from `self.envs[0]` is used.\n\n'
                    'For the sake of clarity, the return type is Tuple[Tensor,Tensor] which can be understood as [a,agent_s] or [action, agent state].\n\n'
                    'While the returned action like `1`,`2` if discrete and `[0.2,0.5,0.2,0.1]` is a continuous action output, the agent state\n'
                    'is simply the raw values are were used to get the action. This information can be useful where for example we have an agent playing\n'
                    'the cartpole game. The action can either be `0` or `1`. This is considered the "action".\n'
                    'The agent state is the result of `self.learn.model(self.s)`. This is the raw `nn.Module` output and in the cartpole env case, is likely a\n'
                    'tensor of `Tensor([0.35,0.75])`, which for a discrete agent would an action of `1`. (we do an argmax, thus the action is which ever \n'
                    'neuron has the highest expected reward. This can be used to determine how confident the agent was when taking an action.',
        is_last_step='An env has reached it\'s last step when it is either `d` is true or `steps` is more than or equal to `self.max_steps-1`. '
                     'The reason this method has so many parameters is due to the use of `log_args` decorator. Having these all these parameters passed on makes '
                     'debugging much easier.',
        cycle_env='Increments the private `self._step` field to cyucle through the envs. Requires passing in parameters for easy debugging using `log_args`.',
        stop_loop='Raises the `StopIteration` exception and resets the `self._end_interation=False`. Used when the current env is done.',
        reset_envs='Performs a very interesting function for reseting all the envs. One might wonder "why not reset the envs individually?". The reason\n '
                   'is that performing a reset within a single process for most `gym` envs actually resets the overall renderer. This means that if you reset\n '
                   'one env, it affects all the others that haven\'t finished yet. In order to avoid this issue,\n '
                   'we only reset when all the envs are ready to be reset. The `idx` is here only for debugging.',
        at_start='Similar to `is_last_step`. We determine that the dataset as looped through all the envs and needs to reset them.\n '
                 'We have all the important params passed in for debugging via `log_args`.',
        pop_total_rewards='Returns and clears the `total_rewards` fields. Each element of total reward is a single episode or full iteration through an env.',
        pop_reward_steps='Returns and clears the `total_rewards` and `total_steps` fields. Each element in each represents data over a single episode.',
        skip_loop='If the current step should be skipped per `self.skip-step` and the current step is not done then we have the loop do a pass over.')

In [46]:
ds=ExperienceSourceDataset("CartPole-v1",n_envs=15,max_steps=50,skip_n_steps=1)
bs=10

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    final_yb=yb
    data.append(xb)
assert bool(final_yb['d'][-1])
if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])



In [47]:
ds=ExperienceSourceDataset("MountainCar-v0",n_envs=15,skip_n_steps=1)
bs=10
mountain_car_steps=200

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    final_yb=yb
    data.append(xb)

test_eq(sum([len(o) for o in data]),mountain_car_steps)
# assert bool(final_yb['d'][-1]),final_yb['d'][-1] MountainCar will likely not be done without some intervention.
if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])
data.clear()

r=ds.pop_total_rewards()[0]
test_eq(r,-200)



In [48]:
ds=ExperienceSourceDataset("MountainCar-v0",n_envs=15,skip_n_steps=2)
bs=10
mountain_car_steps=200//2 # If we skip 2 steps, this should be 100

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    final_yb=yb
    data.append(xb)

test_eq(sum([len(o) for o in data]),mountain_car_steps)
# assert bool(final_yb['d'][-1]),final_yb['d'][-1] MountainCar will likely not be done without some intervention.
if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])
data.clear()

r=ds.pop_total_rewards()[0]
test_eq(r,-200) # r should still be -200



In [49]:
ds=ExperienceSourceDataset("MountainCar-v0",n_envs=15,skip_n_steps=1,pixels=True)
bs=10
mountain_car_steps=200

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    final_yb=yb
    data.append(xb)

    
test_eq(data[0].shape,(bs,400,600,3))
test_eq(sum([len(o) for o in data]),mountain_car_steps)
# assert bool(final_yb['d'][-1]),final_yb['d'][-1] MountainCar will likely not be done without some intervention.
if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])
data.clear()

r,steps=ds.pop_reward_steps()[0]
test_eq(r,-200)
test_eq(steps,200)
test_eq(len(ds.pop_reward_steps()),0)



In [50]:
# export
class FirstLastExperienceSourceDataset(ExperienceSourceDataset):
    "Similar to `ExperienceSourceDataset` but only keeps the first and last parts of a step. Can be seen as frame skipping."
    def __init__(self,*args,discount=0.99,skip_n_steps=1,**kwargs):
        super(FirstLastExperienceSourceDataset,self).__init__(*args,skip_n_steps=1,**kwargs)
        self.discount=discount
        self.fst_lst_steps=skip_n_steps
        self._getitem_as_exp=True
    
    @log_args
    def __getitem__(self,_):
        exp_ls=[]
        while True:
            exp=super(FirstLastExperienceSourceDataset,self).__getitem__(_)
            exp_ls.append(exp)
            if len(exp_ls)>=self.fst_lst_steps or self._end_interation:break

        exp=exp_ls[-1]
        exp.x=exp_ls[0].x

        total_reward=0.0
        for e in reversed(exp_ls):
            total_reward*=self.discount
            total_reward+=e.r
        exp.r=total_reward

        return exp.x,asdict(exp)

In [51]:
import pytest
ds=FirstLastExperienceSourceDataset("CartPole-v1",n_envs=15,skip_n_steps=4)
bs=10

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    test_eq(yb['r'][0],pytest.approx(3.99,0.1))
    final_yb=yb
    data.append(xb)

if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])
data.clear()



In [52]:
ds=FirstLastExperienceSourceDataset("MountainCar-v0",n_envs=15,skip_n_steps=1)
bs=10
mountain_car_steps=200

dl=DataLoader(ds,batch_size=bs,num_workers=0)
final_yb=None
data=[]
for xb,yb in dl:
    assert len(xb)<=bs
    assert len(yb['s'])<=bs
    assert len(yb['d'])<=bs
    final_yb=yb
    data.append(xb)

test_eq(sum([len(o) for o in data]),mountain_car_steps) # Should be half since first last by default merges 2 steps
# assert bool(final_yb['d'][-1]),final_yb['d'][-1] MountainCar will likely not be done without some intervention.
if len(data)>1: # If there is more that 1 attempted batches, all the previous ones should be the correct size. The last one might be shorter
    assert all(len(arr)==bs for arr in data[:-1])
data.clear()

r=ds.pop_total_rewards()[0]
test_eq(r,-200)



## ExperienceSourceDataBunch

In [53]:
# export
class ExperienceSourceDataBunch(DataBunch):
    @classmethod
    def from_env(cls,env:str,n_envs=1,firstlast=False,display=False,max_steps=None,skip_n_steps=1,path:PathOrStr='.',add_valid=True,
                 cols=1,rows=1,max_w=800,bs=1):
        def create_ds(make_empty=False):
            _ds_cls=FirstLastExperienceSourceDataset if firstlast else ExperienceSourceDataset
            _ds=_ds_cls(env,max_steps=0 if make_empty else max_steps,skip_n_steps=skip_n_steps)
            if display:_ds=DatasetDisplayWrapper(_ds,cols=cols,rows=rows,max_w=max_w)
            return _ds
            
        dss=(create_ds(),create_ds(not add_valid))
        return cls.create(*dss,bs=bs,num_workers=0)
    
    @classmethod
    def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
               val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
               device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, **dl_kwargs)->'DataBunch':
        "Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        dls = [DataLoader(d, b, shuffle=s, drop_last=s, num_workers=num_workers, **dl_kwargs) for d,b,s in
               zip(datasets, (bs,val_bs,val_bs,val_bs), (False,False,False,False)) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)

In [54]:
data=ExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,firstlast=False,add_valid=False,bs=5)
for xb,yb in data.train_dl:
    assert xb.shape[0]<=5



In [55]:
data=ExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,firstlast=True,add_valid=False,bs=5)
for xb,yb in data.train_dl:
    assert len(xb)<=5



# Async ExperienceSources

Async Experience sources have the challenge of running single process ExperienceSources in separate threads. Some questions are how rigid we want to make the actual fit look.

There are currently 2 ways to setup an Async dataset:
- Agent gets the data collected from the child processes. The model gets updated on the main thread which intern reflects in the child processes.
- A sub learner runs inside each process and fits up until the back prop. Instead of doing back prop in the process, we collect the gradients and update them in the main thread. 

The first one is farely straight forward and only requires the model, the agent, and an uninstantiated Dataset.
The second is more complex. We basically have 2 fit procedures shared between the main process and the child processes. 

In [103]:
getattr??

[0;31mDocstring:[0m
getattr(object, name[, default]) -> value

Get a named attribute from an object; getattr(x, 'y') is equivalent to x.y.
When a default argument is given, it is returned when the attribute doesn't
exist; without it, an exception is raised in that case.
[0;31mType:[0m      builtin_function_or_method


In [104]:
# export
def getattrsoftly(o:Optional[object],name:str,default):return getattr(o,name,default) if o is not None else default

In [115]:
# export
class AsyncExperienceSourceCallback(LearnerCallback):pass

def _fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset):pass

@dataclass
class AsyncExperienceSourceDataset(Dataset):
    env:str
    ds_cls:ExperienceSourceDataset.__class__
    n_envs:int=1
    learn:Optional[Learner]=None
    callback_fns:List[LearnerCallback]=field(default_factory=lambda:[AsyncExperienceSourceCallback])
    n_processes:int=1
    bs:int=1
    pause_event:mp.Event=mp.Event()
    cancel_event:mp.Event=mp.Event()   
    queue_sz:Optional[int]=None
    metric_queue:Optional[mp.JoinableQueue]=None
    fitter_fn:Optional[Callable]=_fitter
    fitter_kwargs:Dict=None
    _proc_list:List[mp.Process]=field(default_factory=list)
    
    def __post_init__(self):
        self.queue_sz=ifnone(self.queue_sz,self.n_processes)
        
    def __enter__(self):
        self.start_procs()
        return self
    
    def __len__(self):return self.bs
    
    def __exit__(self,*exc):
        return False
    
    def start_procs(self):
        if self.learn is None:                   _logger.warning('`learn` is None. Will be using random actions for env iteration.')
        elif not hasattr(self.learn,'fitter_fn'):_logger.warning('`learn` does not have a `fitter_fn`. Using default.')
        self.populate_proc_list(self.fitter_fn,self.fitter_kwargs)
            
    def populate_proc_list(self,fitter_fn,kwargs):
        for i in range(self.n_processes):
            self._proc_list.append(mp.Process(target=fitter_fn,kwargs=kwargs))
        
    def end_procs(self):pass
    def pause(self):self.pause_event.set()
    
    def __getitem__(self,_):raise NotImplimentedError()

In [114]:
ds=AsyncExperienceSourceDataset('CartPole-v1',ExperienceSourceDataset,n_envs=2,n_processes=2,bs=128)
with ds:
    for xb,yb in ds:
        print(xb)



NameError: name 'NotImplimentedError' is not defined

In [90]:
# export
add_docs(AsyncExperienceSourceDataset,
        __init__='Asynchronous form of ExperienceSourceDataset. Requires using `with AsyncExperienceSourceDataset():` or calling `start_procs()`'
        ' before looping and `end_procs()` after looking. Using the context manager would be better due to exception handling so we can avoid any'
        ' hanging processes. Returns the data samples from all processes, however can be modfied to behave differently.',
        start_procs='Start the background processes for the environments. Also called in the context manager `__enter__` method.',
        end_procs='End the processes cleanly. Also called in the context manager `__exit__` method.',
        pause='Will cause the environments to stop iterating through episodes. Used for switching between test and train datasets.')

In [None]:
# export
class AsyncExperienceSourceCallback(LearnerCallback):
    _order = -11
    
    def on_epoch_begin(self,**kwargs):
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        if not self.learn.data.empty_val:ds.pause_event.clear()

        if len(ds.data_proc_list)==0:
            if not hasattr(self.learn,'fitter'):
                _logger.warning('Using the default fitter function which will likely not work. Make sure your `AgentLearner` has a `fitter` attribute to actually run/train.')
            
            for proc_idx in range(ds.n_processes):
                _logger.info('Starting Process')
                data_proc=self.load_process()
                data_proc.start()
                ds.data_proc_list.append(data_proc)
                
    def load_process(self):raise NotImplementedError()
    def empty_queues(self):raise NotImplementedError()
        
    def on_batch_begin(self,last_target,last_input,**kwargs):
        return {'last_input':[last_input]}


    def on_batch_end(self,**kwargs):
        # If not training, pause train ds, otherwise pause valid ds
        ds=(self.learn.data.train_ds if not self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        if not self.learn.data.empty_val:ds.pause_event.set()
    
    def on_train_end(self,**kwargs):
        for ds in [self.learn.data.train_ds,None if self.learn.data.empty_val else self.learn.data.valid_ds]:
            if ds is None: continue
            ds.cancel_event.set()
            for _ in range(5):
                self.empty_queues()
            for proc in ds.data_proc_list: proc.join(timeout=0)
                
class AsyncGradExperienceSourceCallback(AsyncExperienceSourceCallback):
    def load_process(self):
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        return mp.Process(target=getattr(self.learn,'fitter',grad_fitter), 
                          args=(self.learn.model,self.learn.agent,ds.ds_cls),
                          kwargs={'grad_queue':ds.grad_queue,'loss_queue':ds.loss_queue,'pause_event':ds.pause_event,'cancel_event':ds.cancel_event,
                                  'metric_queue':ds.metric_queue})
    def empty_queues(self):
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        while not ds.grad_queue.empty(): ds.grad_queue.get()
        while not ds.loss_queue.empty(): ds.loss_queue.get()

class AsyncDataExperienceSourceCallback(AsyncExperienceSourceCallback):
    def load_process(self):
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        return mp.Process(target=getattr(self.learn,'fitter',data_fitter), 
                          args=(self.learn.model,self.learn.agent,ds.ds_cls),
                          kwargs={'data_queue':ds.data_queue,'pause_event':ds.pause_event,'cancel_event':ds.cancel_event,
                                  'metric_queue':ds.metric_queue})
    def empty_queues(self):
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        while not ds.data_queue.empty():ds.data_queue.get()

In [None]:
# export
def safe_fit(f):
    @wraps(f)
    def wrap(*args,cancel_event,**kwargs):
        try:
            return f(*args,cancel_event=cancel_event,**kwargs)
        finally:
            cancel_event.set()
            for k,v in kwargs.items():
                if k.__contains__('queue') and v is not None:v.put(None)
            return None
    return wrap

@safe_fit
def grad_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,grad_queue:mp.JoinableQueue,
                loss_queue:mp.JoinableQueue,pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    "Updates a `train_queue` with `model.parameters()` and `loss_queue` with the loss. Note that this is only an example grad_fitter."
    while not cancel_event.is_set(): # We are expecting the  grad_fitter to loop unless cancel_event is set
        cancel_event.wait(0.1)
        grad_queue.put(None)         # Adding `None` to `train_queue` will trigger an eventual ending of training
        loss_queue.put(None)
        if pause_event.is_set():     # There needs to be the ability for the grad_fitter to pause e.g. if waiting for validation to end.
            cancel_event.wait(0.1)   # Using cancel_event to wait allows the main process to end this Process.
        break

@safe_fit
def data_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
                pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    _logger.warning('Using the `test_fitter` function. Make sure your `AgentLearner` has a `data_fitter` to actually run/train.')
    while not cancel_event.is_set(): # We are expecting the  grad_fitter to loop unless cancel_event is set
        cancel_event.wait(0.1)
        data_queue.put(None)         # Adding `None` to `train_queue` will trigger an eventual ending of training
        if pause_event.is_set():     # There needs to be the ability for the grad_fitter to pause e.g. if waiting for validation to end.
            cancel_event.wait(0.1)   # Using cancel_event to wait allows the main process to end this Process.     
            
def _soft_queue_get(q:mp.Queue,e:mp.Event):
    entry=None
    while not e.is_set():
        try:
            entry=q.get_nowait()
            break
        except Empty:pass
    return entry
            
class AsyncGradExperienceSourceDataset(ExperienceSourceDataset):
    "Contains dataloaders of multiple sub-datasets and executes them using `n_processes`. `xb` is the gradients from the agents, `yb` is the loss."
    def __init__(self,env_name:str,n_envs=1,ds_cls=ExperienceSourceDataset,max_steps=None,bs=None,n_processes=1,queue_sz=None,*args,**kwargs):
        self.n_processes=n_processes
        self.n_envs=n_envs
        self.env_name=env_name
        self.ds_cls=ds_cls
        self.pause_event=mp.Event()                               # If the event is set, then the Process will freeze.
        self.cancel_event=mp.Event()                              # If the event is set, then the Process will freeze.
        self.max_steps=max_steps
        self.queue_sz=ifnone(queue_sz,self.n_processes)
        self.grad_queue=mp.JoinableQueue(maxsize=self.queue_sz)
        self.loss_queue=mp.JoinableQueue(maxsize=self.queue_sz)
        self.metric_queue:mp.JoinableQueue=None
        self.data_proc_list=[]
        self.callback_fns=[AsyncGradExperienceSourceCallback]
        self._env=gym.make(self.env_name)
        self.bs=ifnone(bs,ifnone(self.max_steps,self._env.spec.max_episode_steps))
    
    def __len__(self): return ifnone(self.max_steps,self._env.spec.max_episode_steps)*self.n_envs
#     def __len__(self): return self.bs
    
        
    def __getitem__(self,idx):
        if len(self.data_proc_list)==0: raise StopIteration()
        train_entry=_soft_queue_get(self.grad_queue,self.cancel_event)

        if train_entry is None:
            raise StopIteration()
        
        train_loss_entry=_soft_queue_get(self.loss_queue,self.cancel_event)
        return train_entry,[train_loss_entry]     
    
class AsyncDataExperienceSourceDataset(ExperienceSourceDataset):
    "Contains dataloaders of multiple sub-datasets and executes them using `n_processes`. `xb` is the gradients from the agents, `yb` is the loss."
    def __init__(self,env_name:str,n_envs=1,ds_cls=ExperienceSourceDataset,max_steps=None,bs=None,n_processes=1,queue_sz=None,**kwargs):
        self.n_processes=n_processes
        self.n_envs=n_envs
        self.ds_cls=ds_cls
        self.env_name=env_name
        self.pause_event=mp.Event()                               # If the event is set, then the Process will freeze.
        self.cancel_event=mp.Event()                              # If the event is set, then the Process will freeze.
        self.max_steps=max_steps
        self.queue_sz=ifnone(queue_sz,self.n_processes)
        self.data_queue=mp.JoinableQueue(maxsize=self.queue_sz)
        self.metric_queue:mp.JoinableQueue=None
        self.data_proc_list=[]
        self.callback_fns=[AsyncDataExperienceSourceCallback]
        self._env=gym.make(self.env_name)
        self.bs=ifnone(bs,ifnone(self.max_steps,self._env.spec.max_episode_steps))
        
#     def __len__(self): return self.bs
    def __len__(self): return ifnone(self.max_steps,self._env.spec.max_episode_steps)*self.n_envs
        
    def __getitem__(self,_):
        if len(self.data_proc_list)==0: raise StopIteration()
        train_entry=_soft_queue_get(self.data_queue,self.cancel_event)
        if train_entry is None:raise StopIteration()
        return Experience(**train_entry).x,train_entry

In [None]:
add_docs(AsyncGradExperienceSourceDataset,
         __init__=textwrap.fill("""The `AsyncGradExperienceSourceDataset` class is instantiated via passing 
         the `env_name` that we want to train on, and a `partial` class of a `ExperienceSourceDataset` called `ds_cls`.
         `bs` is a field here since the length of the dataset is not necessarily the length of an episode. If `None` it will be the length
         of a single episode of the environment. Agents such as A3C will likely change this."""))

In [None]:
# export
class AsyncExperienceSourceDataBunch(ExperienceSourceDataBunch):
    @classmethod
    def from_env(cls,env:str,n_envs=1,data_exp=True,firstlast=False,display=False,max_steps=None,skip_n_steps=1,path:PathOrStr='.',add_valid=True,
                 cols=1,rows=1,max_w=800,n_processes=1,queue_sz=None,bs=1):
        def create_ds(make_empty=False):
            _sub_ds_cls=FirstLastExperienceSourceDataset if firstlast else ExperienceSourceDataset
            _sub_ds_cls=partial(_sub_ds_cls,env=env,n_envs=n_envs,max_steps=0 if make_empty else max_steps,skip_n_steps=skip_n_steps)
            _ds_cls=AsyncDataExperienceSourceDataset if data_exp else AsyncGradExperienceSourceDataset
            _ds=_ds_cls(env,max_steps=0 if make_empty else max_steps,skip_n_steps=skip_n_steps,ds_cls=_sub_ds_cls,n_processes=n_processes,queue_sz=queue_sz)
            for k,v in {'env':env,'n_envs':n_envs,'skip_n_steps':skip_n_steps}.items():
                if not hasattr(_ds,k):setattr(_ds,k,v)
            if display:_ds=DatasetDisplayWrapper(_ds,cols=cols,rows=rows,max_w=max_w)
            return _ds
            
#         dss=(create_ds(),create_ds() if add_valid else None)
        return cls.create(create_ds(),create_ds(not add_valid),bs=bs,num_workers=0)

In [None]:
# export
@safe_fit
def dqn_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
               pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    dataset=ds()
    while not cancel_event.is_set(): 
        for xb,yb in dataset:
            data_queue.put(yb)
            if pause_event.is_set():cancel_event.wait(0.1) 
            if cancel_event.is_set():break
        if metric_queue is not None:metric_queue.put(TotalRewards(np.mean(dataset.pop_total_rewards())))
        if cancel_event.is_set():break
                
@safe_fit
def dqn_grad_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,grad_queue:mp.JoinableQueue,loss_queue:mp.JoinableQueue,
                    pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    dataset=ds()
    while not cancel_event.is_set(): 
        for xb,yb in dataset:
            sys.stdout.flush()
            grad_queue.put(xb)
            loss_queue.put(0.5)
            if pause_event.is_set():cancel_event.wait(0.1) 
            if cancel_event.is_set():break
        if metric_queue is not None:metric_queue.put(TotalRewards(np.mean(dataset.pop_total_rewards())))
        if cancel_event.is_set():break

@safe_fit
def buggy_dqn_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
                pause_event:mp.Event,cancel_event:mp.Event,metric_queue:mp.JoinableQueue=None):
    dataset=ds()
    while not cancel_event.is_set(): 
        for xb,yb in dataset:
            data_queue.put(yb)
            if pause_event.is_set():cancel_event.wait(0.1) 
            if cancel_event.is_set():break
            raise Exception('Crashing on purpose')
        if cancel_event.is_set():break

In [None]:
from fastrl.basic_train import *
import pytest

If no fitter is added, then calling the fit function will likelly result in a `TypeError` where the `smooth_loss` is missing.

In [None]:
class FakeRunCallback2(LearnerCallback):
    def on_backward_begin(self,*args,**kwargs): return {'skip_bwd':True,'skip_validate':True}
#     def on_batch_begin(self,last_target,last_input,**kwargs):
#         print(*last_input)
#         print('hello')
#         sys.stdout.flush()
#         return {'last_input':last_input.unsqueeze(0)}

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=True,display=False,max_steps=50,firstlast=False,add_valid=False,n_processes=1,n_envs=1,bs=4)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback2])
setattr(learn,'fitter',dqn_fitter)
learn.fit(2,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=True,display=False,max_steps=50,firstlast=True,add_valid=False,n_processes=4,n_envs=2,bs=4)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback2])
setattr(learn,'fitter',dqn_fitter)
learn.fit(1,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=True,display=False,max_steps=50,firstlast=True,add_valid=False,n_processes=1,n_envs=1)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback])
with pytest.raises(TypeError):
    learn.fit(1,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=False,display=False,firstlast=True,add_valid=False,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback])
setattr(learn,'fitter',dqn_grad_fitter)
learn.fit(1,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback])
setattr(learn,'fitter',dqn_fitter)
learn.fit(3,lr=0.01,wd=1)

In [None]:
import pytest
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback])
setattr(learn,'fitter',buggy_dqn_fitter)
with pytest.raises(TypeError):
    learn.fit(1,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2,bs=128)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback])
setattr(learn,'fitter',dqn_fitter)
learn.fit(1,lr=0.01,wd=1)

In [None]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,max_steps=50,firstlast=False,add_valid=False)
for xb,yb in data.train_dl:
    test_eq(len(xb),1)
    test_eq(tuple(xb[0].shape),(5,4))

## Dataset Shower
We can define a wrapper around a dataset which will show up to `rows * cols` environments.

In [20]:
# export
if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [21]:
%matplotlib inline

In [22]:
# export
class DatasetDisplayWrapper(object):
    def __init__(self,ds,rows=2,cols=2,max_w=800):
        "Wraps a ExperienceSourceDataset instance showing multiple envs in a `rows` by `cols` grid in a Jupyter notebook."
        # Ref: https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python
        # We are basically Wrapping any instance of ExperienceSourceDataset (kind of cool right?)
        clss=(ExperienceSourceDataset,FirstLastExperienceSourceDataset,
              AsyncGradExperienceSourceDataset,AsyncDataExperienceSourceDataset)
        assert issubclass(ds.__class__,clss),'Currently this only works with the ExperienceSourceDataset and Async*ExperienceSourceDataset class only.'
        self.__class__ = type(ds.__class__.__name__,(self.__class__, ds.__class__),{})
        self.__dict__=ds.__dict__
        self.rows,self.cols,self.max_w=rows,cols,max_w
        self.current_display=None
        if not IN_NOTEBOOK: 
            _logger.warning('It seems you are not running in a notebook. Nothing is going to be displayed.')
            return
        
        if self.envs[0].render('rgb_array') is None: self.envs[0].reset()
        rdr=self.envs[0].render('rgb_array')
        if rdr.shape[1]*self.cols>max_w:
            _logger.warning('Max Width is %s but %s*%s is greater than. Decreasing the number of cols to %s, rows increase by %s',
                            max_w,rdr.shape[1],self.cols,max_w%rdr.shape[1],max_w%rdr.shape[1])
            self.cols=max_w%rdr.shape[1]
            self.rows+=max_w%rdr.shape[1]
        self.max_displays=self.cols*self.rows
        self.current_display=np.zeros(shape=(self.rows*rdr.shape[0],self.cols*rdr.shape[1],rdr.shape[2])).astype('uint8')
        _logger.info('%s, %s, %s, %s, %s',0,0//self.cols,0%self.cols,rdr.shape,self.current_display.shape)

    def __getitem__(self,_):
        idx=self._env_idx
        o=super(DatasetDisplayWrapper,self).__getitem__(idx)
        idx=idx%self.max_displays
        if self.current_display is not None and idx<self.rows*self.cols:
            display.clear_output(wait=True)
            im=self.envs[idx].render(mode='rgb_array')
            self.current_display[(idx//self.cols)*im.shape[0]:(idx//self.cols)*im.shape[0]+im.shape[0],
                                 (idx%self.cols)*im.shape[1]:(idx%self.cols)*im.shape[1]+im.shape[1],:]=im
            new_im=PIL.Image.fromarray(self.current_display)
            display.display(new_im)
        else:
            display.display(PIL.Image.fromarray(self.current_display))
        return o

In [116]:
ds=ExperienceSourceDataset("MountainCar-v0",n_envs=15,skip_n_steps=1,pixels=True)
ds=DatasetDisplayWrapper(ds,1,1,800)
dl=DataLoader(ds,batch_size=1,num_workers=0)
for xb,yb in dl:
    pass

NameError: name 'AsyncGradExperienceSourceDataset' is not defined

In [120]:
# # hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 02_callbacks.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_metrics.ipynb.
Converted 05_data_block.ipynb.
Converted 06_basic_train.ipynb.
Converted 12_a3c.a3c_data.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/05_data_block.ipynb
An error occurred while executing the following cell:
------------------
from nbdev.showdoc import show_doc
from fastrl.data_block import *
------------------

[0;31m---------------------------------------------------------------------------[0m
[0;31mAssertionError[0m                            Traceback (most recent call last)
[0;32m<ipython-input-1-b022c5ef8e74>[0m in [0;36m<module>[0;34m[0m
[1;32m      1[0m [0;32mfrom[0m [0mnbdev[0m[0;34m.[0m[0mshowdoc[0m [0;32mimport[0m [0mshow_doc[0m[0;34m[0m[0;34m[0m[0m
[0;32m----> 2[0;31m [0;32mfrom[0m [0mfastrl[0m[0;34m.[0m[0mdata_block[0m [0;32mimport[0m [0;34m*[0m[0;34m[0m[0;34m[0m[0m
[0m
[0;32m/opt/project/fastrl/fastrl/data_block.py[0m in [0;36m<module>[0;34m[0m
[1;32m    313[0m         [0mstart_procs[0m[0;34m=[0m[0;34m'Start the background processes for the environments. Also called in the context manager `__enter__` method.'

Exception: Conversion failed on the following:
05_data_block.ipynb