In [None]:
# default_exp data_block
%load_ext autoreload
%autoreload 2

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Data Block

> Primary handlers for interfacing the openai gym envs

In [None]:
# export
from fastai.basic_data import *
from fastai.basic_train import *
from fastai.torch_core import *
from fastai.callbacks import *
from fastrl.wrappers import *
from fastrl.basic_agents import *
from dataclasses import asdict
from functools import partial
from fastprogress.fastprogress import IN_NOTEBOOK
from fastcore.utils import *
import torch.multiprocessing as mp
from functools import wraps
from queue import Empty
import textwrap
import logging
import gym

logging.basicConfig(format='[%(asctime)s] p%(process)s line:%(lineno)d %(levelname)s - %(message)s',
                    datefmt='%m-%d %H:%M:%S')
_logger=logging.getLogger(__name__)

In [None]:
# hide
_logger.setLevel('INFO')
from fastcore.foundation import *
import sys

## Dataset

`Dataset` instances are going to be a little different from the typical classification dataset that you might use in PyTorch. Commonly, datasets have:
- A known size to iter through
- Maintain their state during the training sequence
- Randomly sample their dataset
- Have a common `x`/`y` or `input`/`target` data format

For our `ExperienceSourceDataset`, most of this is going to be different. 
- We can have multiple sources (envs)

You could think of a traditional dataset approach as being a mix of an `ExperienceSourceDataset` and a form of `ExperienceReplay`.

In [None]:
# export
def fix_s(x):
    "Give `x` a leading batch dim of `1`: 2-d inputs are flattened to `(1,-1)`, anything else (e.g. a (W,H,C) image) only gains a new first axis. Inputs whose first dim is already `1` are returned untouched."
    # Already batched (first dim is 1): hand the very same object back.
    if x.shape[0]==1: return x
    # Exactly two dims: collapse everything into a single row.
    if len(x.shape)==2: return x.reshape(1,-1)
    # Any other rank: keep the layout and prepend the batch axis.
    return np.expand_dims(x,0)

@dataclass
class Experience(object):
    "One environment transition, as assembled by `ExperienceSourceDataset.__getitem__`."
    s:np.array        # state the action was taken from
    sp:np.array       # state returned by `env.step`
    a:np.array        # action passed to `env.step`
    r:np.array        # reward returned by `env.step`
    d:np.array        # done flag returned by `env.step`
    agent_s:np.array  # second value returned by the dataset's `get_action` (`[]` when that is `None`)
        
    # TODO possibly have x,y for more generic experience integration with datasets.
        
#     def __post_init__(self):
#         for k,v in asdict(self).items():
#             setattr(self,k,fix_s(v) if k in ['s','sp'] else np.array(v,dtype=float).reshape(1,-1))

In [None]:
# Build a throw-away `Experience` from random integer arrays and show it as a plain dict.
# NOTE: `np.random.randint(0,1,...)` excludes the upper bound, so `d` is always all zeros here.
exp=Experience(s=np.random.randint(1,3,(5,5)),
               sp=np.random.randint(1,3,(5,5)),
               a=np.random.randint(1,3,(1,2)),
               r=np.random.randint(1,3,(1,20)),
               d=np.random.randint(0,1,(5,5)),
               agent_s=np.random.randint(0,6,(1,5)))
asdict(exp)

{'s': array([[2, 2, 2, 1, 1],
        [2, 1, 1, 2, 1],
        [1, 1, 2, 1, 1],
        [2, 1, 1, 1, 1],
        [1, 2, 1, 1, 1]]),
 'sp': array([[2, 2, 1, 2, 1],
        [2, 2, 2, 1, 2],
        [1, 2, 1, 1, 1],
        [1, 2, 2, 2, 2],
        [2, 1, 2, 1, 1]]),
 'a': array([[1, 2]]),
 'r': array([[2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1]]),
 'd': array([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]),
 'agent_s': array([[4, 5, 0, 0, 4]])}

## Datasets

In [None]:
# export
class ExperienceSourceCallback(LearnerCallback):
    "Hands the `Learner` to the train (and, when present, valid) dataset at the start of training."
    def on_train_begin(self,*args,**kwargs):
        "Set `dataset.learn` on the train dataloader's dataset, and on the valid one unless `empty_val`."
        learn=self.learn
        loaders=[learn.data.train_dl]
        if not learn.data.empty_val: loaders.append(learn.data.valid_dl)
        for dl in loaders: dl.dataset.learn=learn

class ExperienceSourceDataset(Dataset):
    "Similar to fastai's `LabelList`, iters in-order samples from `1->len(envs)` `envs`."
    def __init__(self,env:str,n_envs=1,steps=1,max_episode_step=None,pixels=False):
        "Create `n_envs` gym envs from the id `env` (or adopt env instance(s) passed in its place) plus the per-env bookkeeping."
        def make_env():
            e=gym.make(env)   # honour the requested id (this used to be hard-coded to "CartPole-v1")
            if pixels:e.reset()
            return e

        if isinstance(env,str): self.envs=[make_env() for _ in range(n_envs)]
        else:
            # Backward compatibility: some callers pass already-built env instance(s) instead of an id.
            # In that case the instances are used as-is and `n_envs` is ignored.
            self.envs=list(env) if isinstance(env,(list,tuple)) else [env]
            if pixels:
                for e in self.envs: e.reset()
        if pixels:self.envs=[PixelObservationWrapper(e) for e in self.envs]
        self.steps=steps                         # max number of env steps gathered per `__getitem__` call
        self.max_episode_step=max_episode_step   # falls back to the env spec's `max_episode_steps` in `__len__`
        self.d=np.zeros((len(self.envs),))+1     # per-env done flag, initialised to 1
        self.s=np.zeros((len(self.envs),*self.envs[0].observation_space.sample().shape))  # per-env current state
        self.iterations=np.zeros((len(self.envs)))   # per-env step counter
        self.total_r=[]                          # every reward seen so far
        self.total_steps=[]                      # step counter value recorded alongside each reward
        self.learn=None                          # injected by `ExperienceSourceCallback.on_train_begin`
        self._warned=False
        self.callback_fns=[ExperienceSourceCallback]
        self.inc=0                               # index of the env currently being stepped

    def __len__(self): return ifnone(self.max_episode_step,self.envs[0].spec.max_episode_steps)*len(self.envs)

    def get_action(self,idx):
        "Return `(action, agent_state)` for env `idx`; a random action and a zero array when no `learn` is attached."
        if self.learn is None:
            if not self._warned:_logger.warning('`self.learn` is None. will use random actions instead.')
            self._warned=True
            return self.envs[0].action_space.sample(),np.zeros((1,1))
        return self.learn.predict(self.s[idx])

    @log_args
    def __getitem__(self,_):
        "Step the current env up to `self.steps` times and return `([states],[experience dicts])`. The index argument is ignored."
        idx=self.inc
        _logger.debug('Idx:%s,Iter:%s',idx,self.iterations[idx])
        # Start of a pass over the envs: reset all of them and the step counters.
        if idx==0 and self.iterations[idx]==0:
            for i,e in enumerate(self.envs): self.s[i]=e.reset() # There is the possiblity this will equal None (and crash)?
            self.iterations=np.zeros((len(self.envs)),dtype=int)

        exps:List[Experience]=[]
        while True:
            a,agent_s=self.get_action(idx)
            sp,r,self.d[idx],_=self.envs[idx].step(a)
            self.total_r.append(r)
            self.total_steps.append(self.iterations[idx])
            # `self.s[idx]` is a view into `self.s`; copy it, otherwise the in-place `self.s[idx]=sp`
            # below rewrites the state stored in this (and every earlier) experience.
            exps.append(Experience(self.s[idx].copy(),sp,a,r,self.d[idx],agent_s=ifnone(agent_s,[])))
            self.s[idx]=sp
            self.iterations[idx]+=1
            # NOTE(review): the per-env counter is compared with `len(self)`, which is already
            # multiplied by the number of envs -- confirm this is the intended episode limit.
            if self.d[idx] or self.iterations[idx]>=len(self)-1:
                self.iterations[idx]=0
                # Move on to the next env, wrapping around after the last one.
                if self.inc>=len(self.envs)-1:self.inc=0
                else:                       self.inc+=1
                # The last env finishing its episode ends the whole iteration.
                if idx>=len(self.envs)-1:raise StopIteration()
                break
            if len(exps)%self.steps==0:
                break
        return [e.s for e in exps],[asdict(e) for e in exps]

In [None]:
import sys
# Smoke test: 3 envs, at most 5 steps per item, batches of 3 items -> at most 15 states per batch.
# make_env = lambda: gym.make("Acrobot-v1")
ds=ExperienceSourceDataset("CartPole-v1",n_envs=3,max_episode_step=50,steps=5)
dl=DataLoader(ds,batch_size=3,num_workers=0)
for xb,yb in dl:
    _logger.critical('xb: %s, yb: %s',torch.cat(xb).shape,yb[0].keys())
    assert torch.cat(xb).shape[0]<=5*3
    sys.stdout.flush()
    
# Same check with 5 envs and batches of 5 items.
# NOTE: `make_env` and `envs` below are not passed to the dataset, which is built from the env id.
# make_env = lambda: gym.make("Acrobot-v1")
make_env = lambda: gym.make("CartPole-v1")
envs=[make_env() for _ in range(5)]
ds=ExperienceSourceDataset("CartPole-v1",n_envs=5,max_episode_step=50,steps=5)
dl=DataLoader(ds,batch_size=5,num_workers=0)
for xb,yb in dl:
    _logger.critical('xb: %s, yb: %s',torch.cat(xb).shape,yb[0].keys())
    assert torch.cat(xb).shape[0]<=5*5
    sys.stdout.flush()

[07-17 20:14:31] p260 line:6 CRITICAL - xb: torch.Size([15, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:6 CRITICAL - xb: torch.Size([9, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:6 CRITICAL - xb: torch.Size([3, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:16 CRITICAL - xb: torch.Size([15, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:16 CRITICAL - xb: torch.Size([25, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:16 CRITICAL - xb: torch.Size([10, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:31] p260 line:16 CRITICAL - xb: torch.Size([20, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])


In [None]:
# Same smoke test with `pixels=True`, i.e. every env wrapped in `PixelObservationWrapper`.
ds=ExperienceSourceDataset('CartPole-v1',n_envs=3,max_episode_step=50,steps=5,pixels=True)
dl=DataLoader(ds,batch_size=3,num_workers=0)
for xb,yb in dl:
    _logger.critical('xb: %s, yb: %s',torch.cat(xb).shape,yb[0].keys())
    assert torch.cat(xb).shape[0]<=5*3
    sys.stdout.flush()

[07-17 20:14:32] p260 line:4 CRITICAL - xb: torch.Size([15, 400, 600, 3]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:32] p260 line:4 CRITICAL - xb: torch.Size([9, 400, 600, 3]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:32] p260 line:4 CRITICAL - xb: torch.Size([15, 400, 600, 3]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])


In [None]:
# export
class FirstLastExperienceSourceDataset(ExperienceSourceDataset):
    "Variant of `ExperienceSourceDataset` that collapses each multi-step item into a single experience: the last step's dict with its `s` replaced by the first step's `s` (a form of frame skipping)."
    @log_args
    def __getitem__(self,idx):
        "Return `([first state],[merged experience dict])` for the steps gathered by the parent class."
        _,steps=super().__getitem__(idx)
        first_state=steps[0]['s']
        merged=steps[-1]
        # Keep everything from the final step but start the transition at the first state.
        merged['s']=first_state
        return [first_state],[merged]

In [None]:
# `FirstLastExperienceSourceDataset` yields one state per item, so a batch of 3 items holds at most 3 states.
ds=FirstLastExperienceSourceDataset("CartPole-v1",n_envs=3,max_episode_step=50,steps=5)
dl=DataLoader(ds,batch_size=3,num_workers=0)
index=0
for xb,yb in dl:
    _logger.critical('xb: %s, yb: %s',torch.cat(xb).shape,yb[0].keys())
    assert torch.cat(xb).shape[0]<=3
    sys.stdout.flush()
    index+=1

[07-17 20:14:32] p260 line:5 CRITICAL - xb: torch.Size([3, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:32] p260 line:5 CRITICAL - xb: torch.Size([3, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])
[07-17 20:14:32] p260 line:5 CRITICAL - xb: torch.Size([3, 4]), yb: dict_keys(['s', 'sp', 'a', 'r', 'd', 'agent_s'])


In [None]:
# export
class AsyncExperienceSourceCallback(LearnerCallback):
    "Base callback for the async datasets: starts their worker processes and drives their `pause_event`/`cancel_event`. Subclasses provide `load_process` and `empty_queues`."
    _order = -11
    
    def on_epoch_begin(self,**kwargs):
        "Un-pause the active dataset and, if it has no worker processes yet, start `ds.n_processes` of them."
        # Active dataset: train_ds while the model is in training mode or there is no validation set, else valid_ds.
        ds=(self.learn.data.train_ds if self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        if not self.learn.data.empty_val:ds.pause_event.clear()

        # Workers are only created once per dataset (while `data_proc_list` is still empty).
        if len(ds.data_proc_list)==0:
            if not hasattr(self.learn,'fitter'):
                _logger.warning('Using the default fitter function which will likely not work. Make sure your `AgentLearner` has a `fitter` attribute to actually run/train.')
            
            for proc_idx in range(ds.n_processes):
                _logger.info('Starting Process')
                data_proc=self.load_process()
                data_proc.start()
                ds.data_proc_list.append(data_proc)
                
    def load_process(self):raise NotImplementedError()  # subclasses return the (not yet started) worker process
    def empty_queues(self):raise NotImplementedError()  # subclasses drain the dataset's queue(s)

    def on_batch_end(self,**kwargs):
        "Set `pause_event` on the dataset that is not currently in use (only when a validation set exists)."
        # If not training, pause train ds, otherwise pause valid ds
        ds=(self.learn.data.train_ds if not self.learn.model.training or self.learn.data.empty_val else
            self.learn.data.valid_ds)
        if not self.learn.data.empty_val:ds.pause_event.set()
    
    def on_train_end(self,**kwargs):
        "Signal every dataset's workers to stop, drain the queues and join the processes without blocking."
        for ds in [self.learn.data.train_ds,None if self.learn.data.empty_val else self.learn.data.valid_ds]:
            if ds is None: continue
            ds.cancel_event.set()
            # NOTE(review): `empty_queues` selects its dataset from `model.training`, not from this loop's `ds` -- confirm.
            for _ in range(5):
                self.empty_queues()
            # `timeout=0` makes each join return immediately instead of waiting for the worker to exit.
            for proc in ds.data_proc_list: proc.join(timeout=0)
                
class AsyncGradExperienceSourceCallback(AsyncExperienceSourceCallback):
    "Async callback for datasets whose workers feed a `grad_queue` and a `loss_queue`."
    def load_process(self):
        "Build (without starting) a worker process running `learn.fitter`, or `grad_fitter` when the learner has none."
        learn=self.learn
        use_train=learn.model.training or learn.data.empty_val
        ds=learn.data.train_ds if use_train else learn.data.valid_ds
        worker_kwargs={'grad_queue':ds.grad_queue,'loss_queue':ds.loss_queue,
                       'pause_event':ds.pause_event,'cancel_event':ds.cancel_event}
        return mp.Process(target=getattr(learn,'fitter',grad_fitter),
                          args=(learn.model,learn.agent,ds.ds_cls),
                          kwargs=worker_kwargs)
    def empty_queues(self):
        "Drain `grad_queue`, then `loss_queue`, of the currently active dataset."
        learn=self.learn
        use_train=learn.model.training or learn.data.empty_val
        ds=learn.data.train_ds if use_train else learn.data.valid_ds
        for q in (ds.grad_queue,ds.loss_queue):
            while not q.empty(): q.get()

class AsyncDataExperienceSourceCallback(AsyncExperienceSourceCallback):
    "Async callback for datasets whose workers feed a single `data_queue`."
    def load_process(self):
        "Build (without starting) a worker process running `learn.fitter`, or `data_fitter` when the learner has none."
        learn=self.learn
        use_train=learn.model.training or learn.data.empty_val
        ds=learn.data.train_ds if use_train else learn.data.valid_ds
        worker_kwargs={'data_queue':ds.data_queue,'pause_event':ds.pause_event,'cancel_event':ds.cancel_event}
        return mp.Process(target=getattr(learn,'fitter',data_fitter),
                          args=(learn.model,learn.agent,ds.ds_cls),
                          kwargs=worker_kwargs)
    def empty_queues(self):
        "Drain `data_queue` of the currently active dataset."
        learn=self.learn
        use_train=learn.model.training or learn.data.empty_val
        ds=learn.data.train_ds if use_train else learn.data.valid_ds
        q=ds.data_queue
        while not q.empty(): q.get()

In [None]:
# export
def safe_fit(f):
    "Decorator for fitter functions: whatever happens in `f`, set `cancel_event` and put `None` on every `*queue*` keyword argument so the consumer stops waiting."
    @wraps(f)
    def wrap(*args,cancel_event,**kwargs):
        try:
            return f(*args,cancel_event=cancel_event,**kwargs)
        except Exception:
            # Best effort: a crashing fitter must not take its process down with a raw traceback,
            # but the failure is logged instead of vanishing. (A `return` inside `finally` used to
            # discard every exception, including KeyboardInterrupt/SystemExit, and `f`'s result.)
            logging.getLogger(__name__).exception('Fitter `%s` raised; signalling the consumer to stop.',
                                                  getattr(f,'__name__',f))
            return None
        finally:
            # Always runs: on success, on a logged failure, and while a BaseException propagates.
            cancel_event.set()
            for k,v in kwargs.items():
                if 'queue' in k:v.put(None)
    return wrap

@safe_fit
def grad_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,grad_queue:mp.JoinableQueue,
                loss_queue:mp.JoinableQueue,pause_event:mp.Event,cancel_event:mp.Event):
    "Example/placeholder fitter: puts `None` on `grad_queue` and on `loss_queue`, then returns. A real fitter is expected to push gradients and losses instead."
    while not cancel_event.is_set(): # We are expecting the  grad_fitter to loop unless cancel_event is set
        cancel_event.wait(0.1)
        grad_queue.put(None)         # A `None` entry makes the consuming dataset raise `StopIteration`, i.e. ends training
        loss_queue.put(None)
        if pause_event.is_set():     # There needs to be the ability for the grad_fitter to pause e.g. if waiting for validation to end.
            cancel_event.wait(0.1)   # Using cancel_event to wait allows the main process to end this Process.
        # NOTE(review): this unconditional `break` means the loop body runs at most once -- confirm it is intended.
        break

@safe_fit
def data_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
                pause_event:mp.Event,cancel_event:mp.Event):
    "Placeholder fitter used when the learner has no `fitter` attribute: it only keeps putting `None` on `data_queue` until `cancel_event` is set."
    # The callback looks up `learn.fitter`, so that is the attribute the message must point to.
    _logger.warning('Using the default `data_fitter` function. Make sure your `AgentLearner` has a `fitter` attribute to actually run/train.')
    while not cancel_event.is_set(): # Keep looping until `cancel_event` is set
        cancel_event.wait(0.1)
        data_queue.put(None)         # A `None` entry makes the consuming dataset raise `StopIteration`
        if pause_event.is_set():     # While paused, back off for another 0.1s ...
            cancel_event.wait(0.1)   # ... waiting on `cancel_event` so a cancel wakes this process immediately
            
def _soft_queue_get(q:mp.Queue,e:mp.Event):
    entry=None
    while not e.is_set():
        try:
            entry=q.get_nowait()
            break
        except Empty:pass
    return entry
            
class AsyncGradExperienceSourceDataset(ExperienceSourceDataset):
    "Dataset fed by worker processes: each item pairs an entry taken from `grad_queue` (`xb`) with an entry taken from `loss_queue` (`yb`)."
    def __init__(self,env_name:str,n_envs=1,ds_cls=ExperienceSourceDataset,max_episode_step=None,n_processes=1,*args,**kwargs):
        # Extra `*args`/`**kwargs` are accepted but not used; the parent `__init__` is not called.
        self.env_name,self.n_envs,self.n_processes=env_name,n_envs,n_processes
        self.ds_cls,self.max_episode_step=ds_cls,max_episode_step
        # Events handed to the worker processes together with the queues.
        self.pause_event=mp.Event()
        self.cancel_event=mp.Event()
        # Both queues hold at most one entry per worker process.
        self.grad_queue=mp.JoinableQueue(maxsize=self.n_processes)
        self.loss_queue=mp.JoinableQueue(maxsize=self.n_processes)
        self.data_proc_list=[]
        self.callback_fns=[AsyncGradExperienceSourceCallback]
        # Local env instance, only consulted for its spec in `__len__`.
        self._env=gym.make(self.env_name)

    def __len__(self):
        per_env=ifnone(self.max_episode_step,self._env.spec.max_episode_steps)
        return per_env*self.n_envs

    def __getitem__(self,idx):
        "Return `(grad entry,[loss entry])`; raises `StopIteration` when no workers exist or the grad entry is `None`."
        if not len(self.data_proc_list): raise StopIteration()
        grads=_soft_queue_get(self.grad_queue,self.cancel_event)
        if grads is None: raise StopIteration()
        loss=_soft_queue_get(self.loss_queue,self.cancel_event)
        return grads,[loss]
    
class AsyncDataExperienceSourceDataset(ExperienceSourceDataset):
    "Dataset fed by worker processes: each item is a list of experience dicts taken from `data_queue` (`yb`), with their `'s'` entries as `xb`."
    def __init__(self,env_name:str,n_envs=1,ds_cls=ExperienceSourceDataset,max_episode_step=None,n_processes=1,*args,**kwargs):
        # Extra `*args`/`**kwargs` are accepted but not used; the parent `__init__` is not called.
        self.env_name,self.n_envs,self.n_processes=env_name,n_envs,n_processes
        self.ds_cls,self.max_episode_step=ds_cls,max_episode_step
        # Events handed to the worker processes together with the queue.
        self.pause_event=mp.Event()
        self.cancel_event=mp.Event()
        # The queue holds at most one entry per worker process.
        self.data_queue=mp.JoinableQueue(maxsize=self.n_processes)
        self.data_proc_list=[]
        self.callback_fns=[AsyncDataExperienceSourceCallback]
        # Local env instance, only consulted for its spec in `__len__`.
        self._env=gym.make(self.env_name)

    def __len__(self):
        per_env=ifnone(self.max_episode_step,self._env.spec.max_episode_steps)
        return per_env*self.n_envs

    def __getitem__(self,idx):
        "Return `([states],[experience dicts])`; raises `StopIteration` when no workers exist or the queue entry is `None`."
        if not len(self.data_proc_list): raise StopIteration()
        experiences=_soft_queue_get(self.data_queue,self.cancel_event)
        if experiences is None: raise StopIteration()
        states=[exp['s'] for exp in experiences]
        return states,experiences

In [None]:
# Attach documentation for `AsyncGradExperienceSourceDataset.__init__` (`add_docs` presumably comes from the fastcore star-import).
add_docs(AsyncGradExperienceSourceDataset,
         __init__=textwrap.fill("""The `AsyncGradExperienceSourceDataset` class is instantiated via passing 
         the `env_name` that we want to train on, and a `partial` class of a `ExperienceSourceDataset` called `ds_cls`."""))

## Dataset Shower
We can define a wrapper around a dataset which will show up to `rows * cols` environments.

In [None]:
# export
if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [None]:
%matplotlib inline

In [None]:
# export
class DatasetDisplayWrapper(object):
    "Wraps a dataset instance so that every `__getitem__` also refreshes an image grid of its envs in the notebook."
    def __init__(self,ds,rows=2,cols=2,max_w=800):
        "Wraps a ExperienceSourceDataset instance showing multiple envs in a `rows` by `cols` grid in a Jupyter notebook."
        # Ref: https://stackoverflow.com/questions/1443129/completely-wrap-an-object-in-python
        # We are basically Wrapping any instance of ExperienceSourceDataset (kind of cool right?)
        clss=(ExperienceSourceDataset,FirstLastExperienceSourceDataset,
              AsyncGradExperienceSourceDataset,AsyncDataExperienceSourceDataset)
        assert issubclass(ds.__class__,clss),'Currently this only works with the ExperienceSourceDataset and Async*ExperienceSourceDataset class only.'
        # Become a dynamic subclass of both this wrapper and the wrapped class, sharing the wrapped instance's state.
        self.__class__ = type(ds.__class__.__name__,(self.__class__, ds.__class__),{})
        self.__dict__=ds.__dict__
        self.rows,self.cols,self.max_w=rows,cols,max_w
        self.max_displays=self.cols*self.rows
        # `current_display` stays None whenever nothing can be drawn; `__getitem__` then only delegates.
        self.current_display=None
        if not IN_NOTEBOOK: 
            _logger.warning('It seems you are not running in a notebook. Nothing is going to be displayed.')
            return
        if not getattr(self,'envs',None):
            # e.g. the Async* datasets, which never create an `envs` list of their own.
            _logger.warning('The wrapped dataset has no `envs` to render. Nothing is going to be displayed.')
            return

        if self.envs[0].render('rgb_array') is None: self.envs[0].reset()
        rdr=self.envs[0].render('rgb_array')
        if rdr.shape[1]*self.cols>max_w:
            # Number of frames that fit side by side is a floor division (the old `%` gave e.g. 800%600=200 columns).
            fit_cols=max(1,max_w//rdr.shape[1])
            # Add as many rows as needed to keep at least the requested number of cells (ceil division).
            fit_rows=-(-(self.rows*self.cols)//fit_cols)
            _logger.warning('Max Width is %s but %s*%s is greater than. Decreasing the number of cols to %s, rows increase by %s',
                            max_w,rdr.shape[1],self.cols,fit_cols,fit_rows-self.rows)
            self.cols,self.rows=fit_cols,fit_rows
        self.max_displays=self.cols*self.rows
        self.current_display=np.zeros(shape=(self.rows*rdr.shape[0],self.cols*rdr.shape[1],rdr.shape[2])).astype('uint8')
        _logger.info('%s, %s, %s, %s, %s',0,0//self.cols,0%self.cols,rdr.shape,self.current_display.shape)

    def __getitem__(self,idx):
        "Delegate to the wrapped dataset, then redraw grid cell `idx % max_displays` from the env with that index."
        o=super(DatasetDisplayWrapper,self).__getitem__(idx)
        # Nothing to draw on (not in a notebook / no envs): behave exactly like the wrapped dataset.
        if self.current_display is None: return o
        idx=idx%self.max_displays
        if idx<len(self.envs):
            display.clear_output(wait=True)
            im=self.envs[idx].render(mode='rgb_array')
            top,left=(idx//self.cols)*im.shape[0],(idx%self.cols)*im.shape[1]
            self.current_display[top:top+im.shape[0],left:left+im.shape[1],:]=im
        # Grid cells without a matching env are left as they are; the grid is shown again unchanged.
        display.display(PIL.Image.fromarray(self.current_display))
        return o

## DataBunch

In [None]:
# export
class ExperienceSourceDataBunch(DataBunch):
    "`DataBunch` whose train/valid datasets are `ExperienceSourceDataset`s (or the first/last variant) built from a gym env id."
    @classmethod
    def from_env(cls,env:str,n_envs=1,firstlast=False,display=False,max_steps=None,skip_step=1,path:PathOrStr='.',add_valid=True,
                 cols=1,rows=1,max_w=800):
        "Build train and valid datasets for the env id `env` with `n_envs` envs each; the valid one is empty when `add_valid` is False."
        def create_ds(make_empty=False):
            _ds_cls=FirstLastExperienceSourceDataset if firstlast else ExperienceSourceDataset
            # Pass the env id and `n_envs` straight through. This used to build `n_envs` gym envs here,
            # hand that list over as `env` and never forward `n_envs`, so the dataset always had a single env.
            # `max_episode_step=0` gives a dataset of length 0.
            _ds=_ds_cls(env,n_envs=n_envs,max_episode_step=0 if make_empty else max_steps,steps=skip_step)
            if display:_ds=DatasetDisplayWrapper(_ds,cols=cols,rows=rows,max_w=max_w)
            return _ds
            
        dss=(create_ds(),create_ds(not add_valid))
        # One item per env in every batch; `path` is now forwarded instead of being dropped.
        return cls.create(*dss,path=path,bs=n_envs,num_workers=0)
    
    @classmethod
    def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
               val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
               device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, **dl_kwargs)->'DataBunch':
        "Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        # Every loader is built with `shuffle=False` and `drop_last=False`.
        dls = [DataLoader(d, b, shuffle=s, drop_last=s, num_workers=num_workers, **dl_kwargs) for d,b,s in
               zip(datasets, (bs,val_bs,val_bs,val_bs), (False,False,False,False)) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)

In [None]:
# With the default `skip_step=1` every item holds one state, so each batch is a single (5,4) tensor.
data=ExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,firstlast=False,add_valid=False)
for xb,yb in data.train_dl:
    test_eq(len(xb),1)
    test_eq(tuple(xb[0].shape),(5,4))



In [None]:
# Same check using `FirstLastExperienceSourceDataset` (`firstlast=True`).
data=ExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,firstlast=True,add_valid=False)
for xb,yb in data.train_dl:
    test_eq(len(xb),1)
    test_eq(tuple(xb[0].shape),(5,4))



In [None]:
# export
class AsyncExperienceSourceDataBunch(ExperienceSourceDataBunch):
    "`DataBunch` whose datasets are the async (multi-process) experience datasets."
    @classmethod
    def from_env(cls,env:str,n_envs=1,data_exp=True,firstlast=False,display=False,max_steps=None,skip_step=1,path:PathOrStr='.',add_valid=True,
                 cols=1,rows=1,max_w=800,n_processes=1):
        "Build train/valid async datasets for `env`; `data_exp` picks the data-queue dataset over the gradient-queue one."
        def create_ds(make_empty=False):
            step_limit=0 if make_empty else max_steps
            # Dataset class each worker process instantiates: a single-env source, pre-configured via `partial`.
            worker_cls=partial(FirstLastExperienceSourceDataset if firstlast else ExperienceSourceDataset,
                               env=env,n_envs=1,max_episode_step=step_limit,steps=skip_step)
            async_cls=AsyncDataExperienceSourceDataset if data_exp else AsyncGradExperienceSourceDataset
            async_ds=async_cls(env,max_episode_step=step_limit,steps=skip_step,ds_cls=worker_cls,n_processes=n_processes)
            if not display: return async_ds
            return DatasetDisplayWrapper(async_ds,cols=cols,rows=rows,max_w=max_w)

        train_ds=create_ds()
        # The validation dataset gets a step limit of 0 when `add_valid` is False.
        valid_ds=create_ds(not add_valid)
        return cls.create(train_ds,valid_ds,bs=n_envs,num_workers=0)

In [None]:
# export
@safe_fit
def dqn_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
                pause_event:mp.Event,cancel_event:mp.Event):
    "Worker loop: instantiate `ds` and keep pushing each item's `yb` onto `data_queue` until `cancel_event` is set."
    dataset=ds()
    cancelled=cancel_event.is_set
    while not cancelled():
        for _,yb in dataset:
            data_queue.put(yb)
            # While paused, wait 0.1s on `cancel_event` before carrying on.
            if pause_event.is_set():
                cancel_event.wait(0.1)
            if cancelled():
                break
        if cancelled():
            break
                
@safe_fit
def dqn_grad_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,grad_queue:mp.JoinableQueue,loss_queue:mp.JoinableQueue,
                    pause_event:mp.Event,cancel_event:mp.Event):
    "Worker loop: instantiate `ds` and, for every item, push its `xb` onto `grad_queue` and a constant `0.5` onto `loss_queue` until `cancel_event` is set."
    dataset=ds()
    cancelled=cancel_event.is_set
    while not cancelled():
        for xb,_ in dataset:
            sys.stdout.flush()
            grad_queue.put(xb)
            loss_queue.put(0.5)
            # While paused, wait 0.1s on `cancel_event` before carrying on.
            if pause_event.is_set():
                cancel_event.wait(0.1)
            if cancelled():
                break
        if cancelled():
            break

@safe_fit
def buggy_dqn_fitter(model:nn.Module,agent:BaseAgent,ds:ExperienceSourceDataset,data_queue:mp.JoinableQueue,
                pause_event:mp.Event,cancel_event:mp.Event):
    "Like `dqn_fitter`, but deliberately raises after the first item unless cancelled first."
    dataset=ds()
    cancelled=cancel_event.is_set
    while not cancelled():
        for _,yb in dataset:
            data_queue.put(yb)
            # While paused, wait 0.1s on `cancel_event` before carrying on.
            if pause_event.is_set():
                cancel_event.wait(0.1)
            if cancelled():
                break
            raise Exception('Crashing on purpose')
        if cancelled():
            break

In [None]:
from fastrl.basic_train import *

If no fitter is added, then calling the fit function will likely result in a `TypeError` where the `smooth_loss` is missing.

In [None]:
# `pytest` is needed here; without this import the cell raises `NameError` on a fresh kernel.
import pytest
# No `fitter` attribute is set on the learner, so the default `data_fitter` is used as the worker target.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=True,display=False,max_steps=50,firstlast=True,add_valid=False,n_processes=1,n_envs=1)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback],loss_func=fake_loss)
with pytest.raises(TypeError):
    learn.fit(1,lr=0.01,wd=1)

epoch,train_loss,valid_loss,time


[07-17 23:07:49] p260 line:15 INFO - Starting Process


In [None]:
# `data_exp=False` selects `AsyncGradExperienceSourceDataset`; `dqn_grad_fitter` is installed as the worker target via `learn.fitter`.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',data_exp=False,display=False,firstlast=True,add_valid=False,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback],loss_func=fake_loss)
setattr(learn,'fitter',dqn_grad_fitter)
learn.fit(1,lr=0.01,wd=1)

epoch,train_loss,valid_loss,time
0,0.5,#na#,00:00


[07-17 23:07:55] p260 line:15 INFO - Starting Process
[07-17 23:07:55] p260 line:15 INFO - Starting Process


In [None]:
# Default `data_exp=True` selects `AsyncDataExperienceSourceDataset`; `dqn_fitter` is installed as the worker target.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback],loss_func=fake_loss)
setattr(learn,'fitter',dqn_fitter)
learn.fit(1,lr=0.01,wd=1)

epoch,train_loss,valid_loss,time
0,0.5,#na#,00:00


[07-17 22:14:22] p260 line:15 INFO - Starting Process
[07-17 22:14:22] p260 line:15 INFO - Starting Process


In [None]:
import pytest
# `buggy_dqn_fitter` raises inside the worker process; the main-process `fit` is expected to end in a `TypeError`.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback],loss_func=fake_loss)
setattr(learn,'fitter',buggy_dqn_fitter)
with pytest.raises(TypeError):
    learn.fit(1,lr=0.01,wd=1)

epoch,train_loss,valid_loss,time


[07-17 22:14:23] p260 line:15 INFO - Starting Process
[07-17 22:14:23] p260 line:15 INFO - Starting Process


In [None]:
# Repeat of the `dqn_fitter` run -- presumably to confirm that training still works after the crashing fitter above.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',display=False,firstlast=True,add_valid=False,max_steps=50,n_processes=2,n_envs=2)
model=nn.Sequential(nn.Linear(4,5),nn.ReLU(),nn.Linear(5,2))
agent=DQNAgent(model=model)
learn=AgentLearner(data,model,agent=agent,callback_fns=[FakeRunCallback],loss_func=fake_loss)
setattr(learn,'fitter',dqn_fitter)
learn.fit(1,lr=0.01,wd=1)

epoch,train_loss,valid_loss,time
0,0.5,#na#,00:00


[07-17 22:14:24] p260 line:15 INFO - Starting Process
[07-17 22:14:24] p260 line:15 INFO - Starting Process


In [None]:
# NOTE(review): worker processes are only started in `AsyncExperienceSourceCallback.on_epoch_begin`, so when
# `train_dl` is iterated directly `data_proc_list` is empty and `__getitem__` raises `StopIteration` at once;
# the loop body (and its `test_eq` checks) presumably never runs -- confirm.
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',n_envs=5,display=False,max_steps=50,firstlast=False,add_valid=False)
for xb,yb in data.train_dl:
    test_eq(len(xb),1)
    test_eq(tuple(xb[0].shape),(5,4))

In [None]:
# hide
# nbdev housekeeping: convert the notebooks to library modules, then to HTML (see the "Converted ..." output).
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 02_callbacks.ipynb.
Converted 03_basic_agents.ipynb.
Converted 05_data_block.ipynb.
Converted 06_basic_train.ipynb.
Converted 12_a3c.a3c_data.ipynb.
Converted index.ipynb.
Converted notes.ipynb.
converting: /opt/project/fastrl/nbs/12_a3c.a3c_data.ipynb
converting: /opt/project/fastrl/nbs/05_data_block.ipynb
