In [None]:
#hide
#skip
# Turn off jedi-based autocompletion in the notebook UI.
%config Completer.use_jedi = False
# upgrade fastrl on colab
# `[ -e /content ]` only succeeds when the /content directory exists (i.e. on Colab), so this is a no-op elsewhere.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
# hide
from fastcore.imports import in_colab
# NOTE(review): `synth_learner` does not appear to be used anywhere in this notebook.
from fastai.test_utils import synth_learner
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is not imported in this cell; presumably it is provided by `from nbdev.imports import *` — confirm.
    # Sanity-check the runtime, except when the IN_TEST environment variable is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # On Colab, start an invisible 400x300 virtual display (presumably so gym envs can render headless).
    # NOTE(review): this rebinds `display`, shadowing IPython's `display` function for the rest of the notebook.
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
# default_exp data.block

In [None]:
# export
# Python native modules
import os
from collections import deque
from time import sleep
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset
from torch import nn
import torch
import gym
import numpy as np
# Local modules
from fastrl.core import *

# Data Block
> Fastrl transforms for iterating through environments

In [None]:
# export
class DQN(Module):
    "Small MLP (4 -> 50 -> 2) whose `forward` returns the index of the largest output, i.e. the greedy action."
    def __init__(self):
        self.policy=nn.Sequential(
            nn.Linear(4,50),
            nn.ReLU(),
            nn.Linear(50,2),
            # NOTE(review): a final ReLU clamps the action scores to >= 0, which is unusual for Q-values.
            nn.ReLU()
        )

    def forward(self,x):
        # Reduce over the last (action) dim. For a single 1-D state this equals the old `dim=0`,
        # but for a batch of shape (bs,4) it now picks one action per row instead of reducing over the batch.
        return torch.argmax(self.policy(x),dim=-1)

Development of this was helped by the [IterableDataset documentation on multiple workers](https://github.com/pytorch/pytorch/blob/4949eea0ffb60dc81a0a78402fa59fdf68206718/torch/utils/data/dataset.py#L64)

This code is heavily modified from https://github.com/Shmuma/ptan

Reference for env [semantics related to vectorized environments](https://github.com/openai/universe/blob/master/doc/env_semantics.rst)

Useful links:
- [torch multiprocessing](https://github.com/pytorch/pytorch/blob/a61a8d059efa0fb139a09e479b1a2c8dd1cf1a44/torch/utils/data/dataloader.py#L564)
- [torch worker](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/worker.py)

In [None]:
# export
def init_experience(but='',**kwargs):
    """Returns dictionary with default values that can be overridden.

    `but` is a comma-separated list of default keys to drop (surrounding spaces are ignored,
    so `'image, step'` works like `'image,step'`). Remaining `kwargs` override or extend the
    defaults via `BD.merge`."""
    experience=D(
        state=0,action=0,next_state=0,reward=0,done=False,
        step=0,env=0,image=0
    )
    # Strip each name so lists written with spaces after the commas still match the keys.
    for key in (s.strip() for s in but.split(',')):
        if key in experience: del experience[key]
    return BD.merge(experience,kwargs)

In [None]:
# Default experience: each field becomes a (1,1) `TensorBatch`, as the output below shows.
init_experience()

{'state': TensorBatch([[0]]),
 'action': TensorBatch([[0]]),
 'next_state': TensorBatch([[0]]),
 'reward': TensorBatch([[0]]),
 'done': TensorBatch([[False]]),
 'step': TensorBatch([[0]]),
 'env': TensorBatch([[0]]),
 'image': TensorBatch([[0]])}

In [None]:
# NOTE(review): only the last expression is displayed; the first two results are discarded.
init_experience()
init_experience(but='image,step')
# `sum` with an experience as the start value stacks the three experiences into 3 rows (see output).
sum([init_experience(),init_experience()],init_experience())

{'state': TensorBatch([[0],
         [0],
         [0]]),
 'action': TensorBatch([[0],
         [0],
         [0]]),
 'next_state': TensorBatch([[0],
         [0],
         [0]]),
 'reward': TensorBatch([[0],
         [0],
         [0]]),
 'done': TensorBatch([[False],
         [False],
         [False]]),
 'step': TensorBatch([[0],
         [0],
         [0]]),
 'env': TensorBatch([[0],
         [0],
         [0]]),
 'image': TensorBatch([[0],
         [0],
         [0]])}

This is an outline of a better way to set up a source. It will largely follow the pattern of a `Learner`.

In [None]:
# export
_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \
    before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
    after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
    after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")

# Per-batch events: `run_train`/`run_valid` only gate these, every other event always runs.
# Defined here because underscore-prefixed names (such as fastai's private `_inner_loop`) are not
# brought in by `from ... import *`, so `SourceCallback.__call__` would otherwise raise NameError.
_inner_loop = "before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()

@funcs_kwargs(as_method=True)
class SourceCallback(Stateful,GetAttr):
    "Basic class handling tweaks of a source's loop in various events; unknown attributes are looked up on `self.source`"
    order,_default,source,run,run_train,run_valid = 0,'source',None,True,True,True
    _methods = _events

    def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        # Outer-loop events always run; per-batch events run only when the train/valid flag for the current phase is on.
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
               (self.run_valid and not getattr(self, 'training', False)))
        res = None
        if self.run and _run: res = getattr(self, event_name, noop)()
        if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
        return res

    def __setattr__(self, name, value):
        # Warn when an attribute set on the callback would hide one of the same name on the source.
        if hasattr(self.source,name):
            warn(f"You are shadowing an attribute ({name}) that exists in the source. Use `self.source.{name}` to avoid this")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`, camel-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')

In [None]:
# export
# Ordered outline of the loop: 'Start*'/'End*' entries open/close a nesting level, the rest are event names.
_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',
         'Start Batch Loop', 'before_batch', 'after_pred', 'after_loss', 'before_backward', 'before_step',
         'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',
         'after_cancel_train', 'after_train', 'Start Valid', 'before_validate','Start Batch Loop',
         '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',
         'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',
         'after_cancel_fit', 'after_fit']

class SourcePrototype(Stateful,Learner):
    "Prototype source that reuses `Learner`'s callback helpers (`add_cbs`, `ordered_cbs`) without running `Learner.__init__`."
    _stateattrs=('pool',)
    def __init__(self,agent=None,cbs=None,seed:int=None,render=None,num_workers=0,but='',**kwargs):
        store_attr(but='cbs')
        self.env_kwargs,self.pool,self.cbs=kwargs,L(),L()
        # TODO: decide whether `defaults.callbacks` should also be registered here.
        self.add_cbs(L(cbs))

    def show_training_loop(self):
        "Print the loop outline, listing the callbacks registered for each event."
        depth = 0
        for step in _loop:
            if step.startswith('Start'):
                print(' '*depth + step)
                depth += 2
            elif step.startswith('End'):
                depth -= 2
                print(' '*depth + step)
            else:
                print(' '*depth + f' - {step:15}:', self.ordered_cbs(step))

In [None]:
# No callbacks are registered yet, so every event lists `[]`.
SourcePrototype().show_training_loop()

Start Fit
   - before_fit     : []
  Start Epoch Loop
     - before_epoch   : []
    Start Train
       - before_train   : []
      Start Batch Loop
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : []
      End Batch Loop
    End Train
     - after_cancel_train: []
     - after_train    : []
    Start Valid
       - before_validate: []
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: []
     - after_validate : []
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : []
End Fit
 - after_cancel_fit: []
 - after_fit      : []


In [None]:
# export
def _state2experience(s,**kwargs):   return init_experience(state=s,next_state=s,step=torch.zeros((1,1)),**kwargs)
def _env_reset(o):                   return o.reset()
def _env_seed(o,seed):               return o.seed(seed)
def _env_render(o,mode='rgb_array'): return TensorBatch(o.render(mode=mode).copy())
def _env_step(o,*args,**kwargs):     return o.step(*args,**kwargs)

def cast_dtype(t,dtype):
    "Return `t` converted to `dtype`."
    # `Tensor.to` handles every dtype. The previous float/double/long-only branches fell through
    # and returned `None` for anything else (e.g. `torch.bool`, `torch.uint8`).
    return t.to(dtype)

class FakeAgent:
    "Stand-in agent: uses only `state.shape[0]` and samples one random action per row from `action_space`."
    def __init__(self,action_space): store_attr()
    def __call__(self,state,**kwargs):
        n_rows=state.shape[0]
        actions=L([self.action_space.sample() for _ in range(n_rows)])
        # NOTE(review): `random_action` is drawn from {0,1,2} regardless of `action_space`.
        extras=merge(kwargs,{'random_action':np.random.randint(0,3,(n_rows,1))})
        return actions,D(extras)

class ExperienceSource(Stateful):
    _stateattrs=('pool',)
    def __init__(self,env:str,agent=None,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 seed:int=None,render=None,num_workers=0,but='',**kwargs):
        store_attr()
        # Remaining keyword arguments are forwarded to `gym.make` in `_init_state`.
        self.env_kwargs=kwargs
        self.pool=L()
        # With no render mode there is no frame to record, so leave `image` out of every experience.
        if self.render is None: self.but+=',image'

    def _init_state(self):
        "Inits the histories, experiences, and the environment pool when sent to a `Process`"
        # Build one (history deque, env) pair per env, then unzip into two parallel `L`s.
        self.history,self.pool=L((deque(maxlen=self.steps_count),
                                  gym.make(self.env,**self.env_kwargs))
                                  for _ in range(self.n_envs)).zip().map(L)
        self.pool.map(_env_seed,seed=self.seed)
        # No agent supplied: fall back to random actions from the first env's action space.
        if self.agent is None: self.agent=FakeAgent(self.pool[0].action_space)
        self.reset_all()

    def reset_all(self):
        # Reset every env and stack the initial observations into one batched experience.
        self.experiences=self.pool.map(_env_reset)
        self.experiences=self.experiences.map(_state2experience,but=self.but)
        self.experiences=sum(self.experiences)
        self.attempt_render(self.experiences)

    def attempt_render(self,experiences,indexes=None):
        if self.render is not None:
            pool=self.pool if indexes is None else self.pool[indexes]
            renders=pool.map(_env_render,mode=self.render)
            # No idea why we have to do this, but multiprocessing hangs forever otherwise
            if self.num_workers>0:sleep(0.1)
            # NOTE(review): the stacked frames get a leading dim of 1, and when `indexes` is a subset they
            # only cover those envs; `__iter__` passes `self.experiences` here (not the experiences it yields),
            # so verify the resulting `image` shape is what downstream code expects when `n_envs>1`.
            experiences['image']=torch.stack(tuple(renders)).unsqueeze(0)

    def __iter__(self):
        "Iterates through a list of environments."
        if not self.pool:self._init_state()
        while True:
            # Only work on envs that are not done
            not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
            if len(not_done_idxs)==0:
                self.reset_all()
                not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
            not_done_idxs=not_done_idxs.reshape(-1,)
            not_done_experiences=self.experiences[not_done_idxs]
            # Pass current experiences into agent
            actions,experiences=self.agent(**not_done_experiences)
            # Step through all envs.
            step_res=self.pool[not_done_idxs].zipwith(actions).starmap(_env_step)
            next_states,rewards,dones=step_res.zip()[:3].map(TensorBatch)
            rewards,dones=(v.reshape(len(not_done_idxs),-1) for v in (rewards,dones))
            # Add the image field if available
            self.attempt_render(self.experiences,not_done_idxs)
            new_exp=BD(next_state=next_states,reward=rewards,done=dones,
                       env=not_done_idxs.reshape(not_done_experiences.bs,-1),
                       step=not_done_experiences['step']+1)

            experiences=BD.merge(not_done_experiences,experiences,new_exp)
            for i,idx in enumerate(not_done_idxs):
                self.history[idx].append(experiences[i])
                # Emit a full `steps_count` window, but only on steps that are multiples of `steps_delta`.
                if len(self.history[idx])==self.steps_count and \
                       int(experiences[i]['step'][0])%self.steps_delta==0:
                    yield sum(self.history[idx])

                if bool(experiences[i]['done'][0]):
                    # Episode ended: flush the shorter trailing windows.
                    if 0<len(self.history[idx])<self.steps_count:
                        yield sum(self.history[idx])
                    while len(self.history[idx])>1:
                        self.history[idx].popleft()
                        yield sum(self.history[idx])
                    # Drop the leftover terminal experience so it is not prepended to the first
                    # window of the next episode after `reset_all`.
                    self.history[idx].clear()
            experiences['state']=experiences['next_state']

            # Write the stepped experiences back into the persistent batch, adding new keys and matching dtypes.
            for k in experiences:
                dtype=experiences[k].dtype
                if k not in self.experiences:
                    self.experiences[k]=TensorBatch(torch.zeros(self.experiences.bs,
                                                    *experiences[k].shape[1:]))
                if self.experiences[k].dtype!=dtype:
                    self.experiences[k]=cast_dtype(self.experiences[k],dtype)
                self.experiences[k][not_done_idxs]=experiences[k]

add_docs(ExperienceSource,
        """Iterates through `n_envs` of `env` feeding experience or states into `agent`.
           If `agent` is None, then random actions will be taken instead.
           It will return `steps_count` experiences every `steps_delta`.
           At the end of an env, it will return `steps_count-1` experiences per next. """,
        reset_all="resets the envs and experience",
        attempt_render="Updates `experiences` with images if `render is not None`. Optionally indexes can be passed.")

In [None]:
# export
class SourceDataset(IterableDataset):
    "`IterableDataset` over a `source`; `wif` re-initializes the source's state so each worker builds its own envs."
    def __init__(self,source=None):
        store_attr('source')

    def __iter__(self):
        # Iteration is delegated entirely to the wrapped source.
        return iter(self.source)

    def wif(self):
        # Worker-init hook: rebuild the source's envs inside the worker process.
        self.source._init_state()

In [None]:
def train_loop(source,num_workers=2):
    "Run `source` through a fastai `DataLoader` (`n=50`) and concatenate all batches into one."
    dataset=SourceDataset(source)
    loader=DataLoader(dataset,num_workers=num_workers,n=50,persistent_workers=True,wif=dataset.wif)
    data=None
    for batch in loader:
        if data is None: data=batch
        else:            data=data+batch
    return data

In [None]:
# One env, 3-experience windows, collected in-process (`num_workers=0`).
source=ExperienceSource('CartPole-v1',None,steps_count=3,n_envs=1)
data=train_loop(source,num_workers=0)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,1.0,0,0
1,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,0,0
2,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,0,2
3,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,0,0
4,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,0,2
5,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,4.0,0,2
6,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,0,2
7,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,4.0,0,2
8,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,5.0,0,0
9,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,4.0,0,2


In [None]:
# Two envs, 3-experience windows: rows from both envs appear (see the `env` column).
source=ExperienceSource('CartPole-v1',None,steps_count=3,n_envs=2)
data=train_loop(source,num_workers=0)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,1.0,0,1
1,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,0,0
2,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,0,0
3,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,1.0,1,0
4,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,1,2
5,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,1,0
6,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,0,0
7,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,3.0,0,0
8,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,4.0,0,2
9,"torch.Size([144, 4])",0,"torch.Size([144, 4])",1.0,False,2.0,1,2


In [None]:
# Three envs with `train_loop`'s default `num_workers=2`, i.e. iterated inside worker processes.
source=ExperienceSource('CartPole-v1',None,n_envs=3)
data=train_loop(source)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,0,0
1,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,0,0
2,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,1,2
3,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,1,0
4,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,2,1
5,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,2,2
6,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,0,1
7,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,0,1
8,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,1,0
9,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,1,2


In [None]:
# `steps_delta=3`: experiences are only emitted when their `step` is a multiple of 3.
source=ExperienceSource('CartPole-v1',None,n_envs=1,steps_delta=3)
data=train_loop(source)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,3.0,0,0
1,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,3.0,0,1
2,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,6.0,0,0
3,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,6.0,0,0
4,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,9.0,0,2
5,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,9.0,0,2
6,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,12.0,0,2
7,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,12.0,0,2
8,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,15.0,0,2
9,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,15.0,0,1


In [None]:
# Baseline: one env, single-experience windows, collected in-process.
source=ExperienceSource('CartPole-v1',None,n_envs=1)
data=train_loop(source,num_workers=0)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,0,2
1,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,0,1
2,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,3.0,0,2
3,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,4.0,0,2
4,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,5.0,0,0
5,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,6.0,0,1
6,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,7.0,0,0
7,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,8.0,0,0
8,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,9.0,0,0
9,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,10.0,0,2


> Note: I wonder if `BD` would benefit from using `default_collate`?

In [None]:
# With `render` set, every experience also carries an `image` field.
source=ExperienceSource('CartPole-v1',None,n_envs=1,render='rgb_array')
data=train_loop(source,num_workers=0)
data.pandas()

Unnamed: 0,state,action,next_state,reward,done,step,env,image,random_action
0,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,1.0,0,"torch.Size([50, 1, 400, 600, 3])",0
1,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,2.0,0,"torch.Size([50, 1, 400, 600, 3])",1
2,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,3.0,0,"torch.Size([50, 1, 400, 600, 3])",1
3,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,4.0,0,"torch.Size([50, 1, 400, 600, 3])",0
4,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,5.0,0,"torch.Size([50, 1, 400, 600, 3])",1
5,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,6.0,0,"torch.Size([50, 1, 400, 600, 3])",2
6,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,7.0,0,"torch.Size([50, 1, 400, 600, 3])",0
7,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,8.0,0,"torch.Size([50, 1, 400, 600, 3])",1
8,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,9.0,0,"torch.Size([50, 1, 400, 600, 3])",0
9,"torch.Size([50, 4])",0,"torch.Size([50, 4])",1.0,False,10.0,0,"torch.Size([50, 1, 400, 600, 3])",1


`ExperienceSource` is designed for iterating through `n_envs` environments.

A single experience is a `BD`, which is a subclass of `dict` that also supports batch indexing:

In [None]:
# Drop the bulky `image` field (mutates `data` in place), then index a single experience.
del data['image']
data[0]

{'state': TensorBatch([[-0.0447, -0.0040, -0.0132,  0.0191]]),
 'action': TensorBatch([[0]]),
 'next_state': TensorBatch([[-0.0447, -0.1989, -0.0128,  0.3076]]),
 'reward': TensorBatch([[1.]]),
 'done': TensorBatch([[False]]),
 'step': TensorBatch([[1.]]),
 'env': TensorBatch([[0]]),
 'random_action': TensorBatch([[0]])}

However, an agent has full power to add fields to this dict while running (e.g. the default random agent adds `random_action`).

In [None]:
# `show_doc` comes from nbdev, which the setup cell only imports outside Colab.
if not in_colab(): show_doc(ExperienceSource._init_state)

<h4 id="ExperienceSource._init_state" class="doc_header"><code>ExperienceSource._init_state</code><a href="__main__.py#L28" class="source_link" style="float:right">[source]</a></h4>

> <code>ExperienceSource._init_state</code>()

Inits the histories, experiences, and the environment pool when sent to a `Process`

In [None]:
# `show_doc` comes from nbdev, which the setup cell only imports outside Colab.
if not in_colab(): show_doc(ExperienceSource.__iter__)

<h4 id="ExperienceSource.__iter__" class="doc_header"><code>ExperienceSource.__iter__</code><a href="__main__.py#L51" class="source_link" style="float:right">[source]</a></h4>

> <code>ExperienceSource.__iter__</code>()

Iterates through a list of environments.

If the `self.pool` field is empty, it will call `_init_state` to reinitialize everything.

In [None]:
# export
class FirstLastExperienceSource(ExperienceSource):
    "`ExperienceSource` that collapses each window into its first experience, with `reward` set to the discounted window return."
    gamma=0.99
    def __iter__(self):
        for res in super().__iter__():
            element=res[0]
            # Discounted return r_0 + gamma*r_1 + ... + gamma^(n-1)*r_(n-1), folded back-to-front over
            # the whole window. (The previous fold started from r_0, which weighted the rewards in reverse.)
            # Computed out-of-place so the reward tensors inside `res` are left untouched.
            reward=torch.zeros_like(element['reward'])
            for r in reversed(res['reward']):
                reward=reward*self.gamma+r
            element['reward']=reward
            element.bs=1
            yield element

In [None]:
# 3-experience windows collapsed to one discounted row, emitted on even steps, from 2 envs in 2 worker processes.
source=FirstLastExperienceSource('CartPole-v1',None,steps_count=3,steps_delta=2,n_envs=2)
data=train_loop(source,num_workers=2)
data.pandas().head(10)

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,2.0,0,1
1,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,2.0,0,1
2,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,2.0,1,2
3,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,2.0,1,0
4,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,4.0,0,1
5,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,4.0,0,1
6,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,4.0,1,2
7,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,4.0,1,1
8,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,6.0,0,2
9,"torch.Size([50, 4])",0,"torch.Size([50, 4])",2.9701,False,6.0,0,0


## Fastai integration

In [None]:
# export
class IterableTfmdLists(TfmdLists):
    "`TfmdLists` whose stored items are iterators: `_after_item` advances the iterator and transforms the value it yields."
    def _after_item(self, o):
        item = next(o)
        return self.tfms(item)

In [None]:
# export
class IterableDataBlock(DataBlock):
    "`DataBlock` whose `datasets` builds its per-transform lists with `IterableTfmdLists`."
    @delegates(DataBlock)
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.tl_type=IterableTfmdLists

    def datasets(self, source, verbose=False):
        self.source = source
        pv(f"Collecting items from {source}", verbose)
        # Without `get_items`, `noop` hands `source` back unchanged as the item list.
        get_items = self.get_items or noop
        items = get_items(source)
        pv(f"Found {len(items)} items", verbose)
        splitter = self.splitter or RandomSplitter()
        splits = splitter(items)
        sizes = ','.join([str(len(s)) for s in splits])
        pv(f"{len(splits)} datasets of sizes {sizes}", verbose)
        return Datasets(items, tfms=self._combine_type_tfms(), splits=splits, dl_type=self.dl_type,
                        n_inp=self.n_inp, verbose=verbose, tl_type=self.tl_type)

In [None]:
# export
@patch
def __iter__(self:Datasets):
    "Yield `self[i]` for `i` in `range(len(self))`, repeating endlessly (stops at once if empty)."
    # NOTE(review): `cycle` is not imported in this notebook's visible cells; presumably it arrives via a star import — confirm.
    for i in cycle(range(len(self))): yield self[i]

@patch
def __init__(self:Datasets, items=None, tfms=None, tls=None, n_inp=None, dl_type=None,tl_type=TfmdLists, **kwargs):
    "Replacement `Datasets.__init__` that adds `tl_type`, the class used to build one list per `tfms` entry."
    # NOTE(review): `@patch` replaces `__init__` on `Datasets` itself, so this affects every `Datasets` in the process.
    super(Datasets,self).__init__(dl_type=dl_type)
    self.tls = L(tls if tls else [tl_type(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])
    self.n_inp = ifnone(n_inp, max(1, len(self.tls)-1))

In [None]:
dqn=DQN().share_memory()
opt=Adam(dqn.parameters(),lr=0.01)

source=ExperienceSource('CartPole-v1',None,steps_count=1,n_envs=1)
# dataset=SourceDataset(source)
gym_block=IterableDataBlock(
    splitter=lambda o:[[0]],
)

# `wif` must be the bound method itself, not the result of calling it (which is `None` and would crash
# worker processes). With `num_workers=0` the generator initializes the envs lazily on first `next`.
dls=gym_block.dataloaders([iter(source)],n=50,bs=5,indexed=False,shuffle=False,num_workers=0,wif=source._init_state,verbose=True)

for x in dls[0]:pass

Collecting items from [<generator object ExperienceSource.__iter__ at 0x7f1ceb4ec8d0>]
Found 1 items
1 datasets of sizes 1
Setting up Pipeline: 
Setting up Pipeline: 
Setting up after_item: Pipeline: ToTensor
Setting up before_batch: Pipeline: 
Setting up after_batch: Pipeline: 


In [None]:
# `loss_func` must be a callable loss, so instantiate `nn.MSELoss` rather than passing the class itself.
learn=Learner(model=dqn,dls=dls,loss_func=nn.MSELoss())

In [None]:
# learn.fit(2,lr=0.01)

## Transforms
> Transforms that correct the batch dimensions of collated experiences

In [None]:
# export
class BatchRepair(Transform):
    "Squeeze the extra singleton dim that batching `BD` experiences can leave on every value."
    def encodes(self,d:(dict,D,BD)):
        # If the bs is 1 and every value has shape [1,1,...] (rank > 2), drop the leading dim.
        # Compare the first two dims explicitly: the old `sum(shape[:2])==2` also matched [0,2,...] and [2,0,...].
        if all(len(o.shape)>2 and tuple(o.shape[:2])==(1,1) for o in d.values()):
            d=BD(d).mapv(Self.squeeze(0))
        # If the bs is not 1 but every value has shape [bs,1,...] (rank > 2), drop dim 1.
        # e.g.: shape [5,1,...] -> [5,...]
        if all(len(o.shape)>2 and o.shape[0]!=1 and o.shape[1]==1 for o in d.values()):
            d=BD(d).mapv(Self.squeeze(1))
        return BD(d)

In [None]:
dqn=DQN().share_memory()
opt=Adam(dqn.parameters(),lr=0.01)

source=ExperienceSource('CartPole-v1',None,steps_count=1,n_envs=1,render='rgb_array')
# dataset=SourceDataset(source)
gym_block=IterableDataBlock(
    blocks=(TransformBlock(batch_tfms=[BatchRepair]),),
    splitter=lambda o:[[0]]
)

# `wif` must be the bound method itself, not the result of calling it (which is `None` and would crash
# worker processes). With `num_workers=0` the generator initializes the envs lazily on first `next`.
dls=gym_block.dataloaders([iter(source)],n=90,bs=64,indexed=False,shuffle=False,num_workers=0,wif=source._init_state,verbose=True)

data=None
for x in dls[0]: data=x[0] if data is None else data+x[0]
data.pandas()

Collecting items from [<generator object ExperienceSource.__iter__ at 0x7f1c46c978d0>]
Found 1 items
1 datasets of sizes 1
Setting up Pipeline: 
Setting up after_item: Pipeline: ToTensor
Setting up before_batch: Pipeline: 
Setting up after_batch: Pipeline: BatchRepair


Unnamed: 0,state,action,next_state,reward,done,step,env,image,random_action
0,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,1.0,0,"torch.Size([90, 1, 400, 600, 3])",0
1,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,2.0,0,"torch.Size([90, 1, 400, 600, 3])",2
2,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,3.0,0,"torch.Size([90, 1, 400, 600, 3])",2
3,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,4.0,0,"torch.Size([90, 1, 400, 600, 3])",1
4,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,5.0,0,"torch.Size([90, 1, 400, 600, 3])",2
...,...,...,...,...,...,...,...,...,...
85,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,10.0,0,"torch.Size([90, 1, 400, 600, 3])",0
86,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,11.0,0,"torch.Size([90, 1, 400, 600, 3])",0
87,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,12.0,0,"torch.Size([90, 1, 400, 600, 3])",2
88,"torch.Size([90, 4])",0,"torch.Size([90, 4])",1.0,False,13.0,0,"torch.Size([90, 1, 400, 600, 3])",1


In [None]:
# hide
# Build step (skipped on Colab): regenerate README.md, export the library modules, and build the HTML docs.
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():   
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 04_callback.core.ipynb.
Converted 05_data.block.ipynb.
Converted 05_data.test_async.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
No notebooks were modified
