In [None]:
#hide
#skip
# upgrade fastrl on colab
# Runs only when `/content` exists (i.e. on Colab): installs fastrl's dev extras plus
# pyvirtualdisplay/xvfb/python-opengl, presumably so gym envs can render without a
# physical display (see the `Display` setup in the next cell).
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 

In [None]:
# hide
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment unless running under the test runner (`IN_TEST` set).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    from pyvirtualdisplay import Display
    # Headless virtual display on Colab. Deliberately NOT named `display`: that name
    # would shadow IPython's built-in `display()` for the rest of the notebook.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [None]:
# default_exp data.block

In [None]:
# export
# Python native modules
import os
from collections import deque
from time import sleep
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset
from torch import nn
import torch

# Local modules
from fastrl.core import *

In [None]:
import numpy as np
import gym
import time,sys
import torch.multiprocessing as mp
import pandas as pd


# Data Block
> Fastrl transforms for iterating through environments

In [None]:
# export
class DQN(Module):
    "Small MLP Q-network: maps a state vector (or a batch of them) to the index of the highest-valued action."
    def __init__(self,n_inputs:int=4,n_hidden:int=50,n_actions:int=2):
        # Defaults reproduce the original hard-coded sizes (4 inputs, 50 hidden, 2 actions).
        # NOTE: fastai's `Module` calls `nn.Module.__init__` for us, so no `super().__init__()`.
        self.policy=nn.Sequential(
            nn.Linear(n_inputs,n_hidden),
            nn.ReLU(),
            # No activation on the output layer: Q-values may be negative, and a trailing ReLU
            # clips them to 0, making `argmax` tie-break toward action 0. Removing it keeps the
            # `state_dict` keys (`policy.0.*`, `policy.2.*`) unchanged.
            nn.Linear(n_hidden,n_actions),
        )

    def forward(self,x):
        # Reduce over the last (action) dim so a single state `(n_inputs,)` and a batch
        # `(bs,n_inputs)` both work; `dim=0` reduced over the batch for 2-D input.
        return torch.argmax(self.policy(x),dim=-1)

Development of this was guided by the [`IterableDataset` documentation on multiple workers](https://github.com/pytorch/pytorch/blob/4949eea0ffb60dc81a0a78402fa59fdf68206718/torch/utils/data/dataset.py#L64)

This code is heavily modified from https://github.com/Shmuma/ptan

Reference for env [semantics related to vectorized environments](https://github.com/openai/universe/blob/master/doc/env_semantics.rst)

Useful links:
- [torch multiprocessing](https://github.com/pytorch/pytorch/blob/a61a8d059efa0fb139a09e479b1a2c8dd1cf1a44/torch/utils/data/dataloader.py#L564)
- [torch worker](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/worker.py)

In [None]:
# export
def o2tensor_batch(o):
    """Convert `o` to a `Tensor` with a leading batch dim of size 1.

    Non-tensor inputs go through `Tensor(...)` (scalars are wrapped in a list first).
    Tensors with 2+ dims whose first dim is already 1 are returned unchanged;
    everything else gets `unsqueeze(0)`."""
    if not isinstance(o,Tensor): o=Tensor(o if is_listy(o) or isinstance(o,np.ndarray) else [o])
    # A 0-d tensor is a scalar: give it one dim, like the wrapped-scalar path above, instead
    # of crashing on `size()[0]` (IndexError for 0-d tensors).
    if o.dim()==0: o=o.reshape(1)
    # Already batched: at least 2 dims and a leading dim of 1.
    if o.dim()>1 and o.shape[0]==1: return o
    return o.unsqueeze(0)

def init_experience(but='',**kwargs):
    """Returns dictionary with default values that can be overridden.

    `but` is a comma-separated list of default fields to drop (e.g. 'image,step').
    `kwargs` are merged in afterwards and override/add fields. Every value is passed
    through `o2tensor_batch`, and the result is wrapped in a `BD`."""
    experience=D(
        state=0,action=0,next_state=0,reward=0,done=False,
        step=0,env=0,image=0
    )
    # Strip whitespace so 'image, step' drops `step` too; empty entries (e.g. from a
    # leading ',' as in ',image') never match a key and are ignored.
    for s in (name.strip() for name in but.split(',')):
        if s in experience: del experience[s]
    # kwargs are merged after the drop, so a field named in `but` can still be supplied explicitly.
    return BD(D(merge(experience,kwargs)).mapv(o2tensor_batch))

In [None]:
# With no arguments every default field becomes a (1,1) float tensor.
init_experience()

{'state': tensor([[0.]]),
 'action': tensor([[0.]]),
 'next_state': tensor([[0.]]),
 'reward': tensor([[0.]]),
 'done': tensor([[0.]]),
 'step': tensor([[0.]]),
 'env': tensor([[0.]]),
 'image': tensor([[0.]])}

In [None]:
# `but` drops the listed default fields from the result.
init_experience(but='image,step')

{'state': tensor([[0.]]),
 'action': tensor([[0.]]),
 'next_state': tensor([[0.]]),
 'reward': tensor([[0.]]),
 'done': tensor([[0.]]),
 'env': tensor([[0.]])}

In [None]:
# Adding experiences concatenates each field along the batch dim (3 rows here).
sum([init_experience(),init_experience()],init_experience())

{'state': tensor([[0.],
         [0.],
         [0.]]),
 'action': tensor([[0.],
         [0.],
         [0.]]),
 'next_state': tensor([[0.],
         [0.],
         [0.]]),
 'reward': tensor([[0.],
         [0.],
         [0.]]),
 'done': tensor([[0.],
         [0.],
         [0.]]),
 'step': tensor([[0.],
         [0.],
         [0.]]),
 'env': tensor([[0.],
         [0.],
         [0.]]),
 'image': tensor([[0.],
         [0.],
         [0.]])}

In [None]:
# export
def _state2experience(s,**kwargs):   return init_experience(state=s,next_state=s,step=torch.zeros((1,1)),**kwargs)
def _env_reset(o):                   return o.reset()
def _env_seed(o,seed):               return o.seed(seed)
def _env_render(o,mode='rgb_array'): return Tensor(o.render(mode=mode).copy())
def _env_step(o,*args,**kwargs):     return o.step(*args,**kwargs)

def cast_dtype(t,dtype):
    """Return tensor `t` converted to `dtype` (returns `t` itself if it already has that dtype).

    `Tensor.to` handles every torch dtype. The previous if/elif handled only float, double
    and long, and silently returned `None` for anything else (e.g. `torch.bool`,
    `torch.int32`). That `None` then broke the caller's item assignment with a confusing error."""
    return t.to(dtype)

class FakeAgent:
    "Stand-in agent: samples one random action per state row and adds a dummy `random_action` field."
    def __init__(self,action_space): store_attr()

    def __call__(self,state,**kwargs):
        n_rows=state.shape[0]
        sampled=L([self.action_space.sample() for _ in range(n_rows)])
        # Extra per-row field, showing that an agent may add its own keys to the experience.
        extra={'random_action':np.random.randint(0,3,(n_rows,1))}
        return sampled,D(merge(kwargs,extra))

class ExperienceSource(Stateful):
    # NOTE(review): `pool` holds the live gym envs. Listing it in `_stateattrs` presumably makes
    # fastai's `Stateful` leave it out of pickled state so it is rebuilt by `_init_state` — confirm.
    _stateattrs=('pool',)
    def __init__(self,env:str,agent=None,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 seed:int=None,render=None,num_workers=0,but='',**kwargs):
        store_attr()
        self.env_kwargs=kwargs  # extra kwargs are forwarded to `gym.make`
        self.pool=L()           # stays empty until `_init_state` builds the envs
        # Without rendering there is no frame to store, so drop the default `image` field.
        if self.render is None: self.but+=',image'

    def _init_state(self):
        "Inits the histories, experiences, and the environment pool when sent to a `Process`"
        # One (history deque, env) pair per env, split into two parallel `L`s.
        self.history,self.pool=L((deque(maxlen=self.steps_count),
                                  gym.make(self.env,**self.env_kwargs))
                                  for _ in range(self.n_envs)).zip().map(L)
        # Offset the seed per env: seeding every env identically makes them replay the same
        # trajectory. With one env this is the same seed as before, and `None` stays `None`.
        for i,env in enumerate(self.pool):
            _env_seed(env,None if self.seed is None else self.seed+i)
        if self.agent is None: self.agent=FakeAgent(self.pool[0].action_space)
        self.reset_all()

    def reset_all(self):
        # Reset every env and rebuild `self.experiences` as one batched `BD` (one row per env).
        self.experiences=self.pool.map(_env_reset)
        self.experiences=self.experiences.map(_state2experience,but=self.but)
        self.experiences=sum(self.experiences[1:],self.experiences[0])
        self.attempt_render(self.experiences)

    def attempt_render(self,experiences,indexes=None):
        if self.render is not None:
            pool=self.pool if indexes is None else self.pool[indexes]
            renders=pool.map(_env_render,mode=self.render)
            # No idea why we have to do this, but multiprocessing hangs forever otherwise
            if self.num_workers>0:sleep(0.1)
            # NOTE(review): stores frames only for `indexes`, with an extra leading dim
            # (`unsqueeze(0)`); the row layout for `n_envs>1` looks inconsistent — verify.
            experiences['image']=torch.stack(tuple(renders)).unsqueeze(0)

    def __iter__(self):
        "Iterates through a list of environments."
        if not self.pool:self._init_state()
        while True:
            try:
                # Only work on envs that are not done
                not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
                if len(not_done_idxs)==0:
                    self.reset_all()
                    not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
                not_done_experiences=self.experiences[not_done_idxs]
                # Pass current experiences into agent
                actions,experiences=self.agent(**not_done_experiences)
                # Step through all envs.
                step_res=self.pool[not_done_idxs].zipwith(actions).starmap(_env_step)
                next_states,rewards,dones=step_res.zip()[:3].map(Tensor)
                rewards,dones=(v.reshape(len(not_done_idxs),-1) for v in(rewards,dones))
                # Add the image field if available
                # NOTE(review): this writes the new frame into `self.experiences`, but the
                # write-back loop below then copies the pre-step `image` (taken from
                # `not_done_experiences`) over it — confirm frames actually advance.
                self.attempt_render(self.experiences,not_done_idxs)
                new_exp=BD(next_state=next_states,reward=rewards,done=dones,
                           env=not_done_idxs.reshape(len(not_done_idxs),-1),
                           step=not_done_experiences['step']+1,
                           bd_batch_size=len(not_done_idxs))

                # Row `pos` of `experiences` belongs to env `not_done_idxs[pos]`.
                experiences=BD(merge(not_done_experiences,experiences,new_exp),
                               bd_batch_size=len(not_done_idxs))

                for pos,idx in enumerate(not_done_idxs):
                    # Index rows by position: `experiences` holds only the not-done envs, so
                    # indexing by the global env id picks the wrong (or a missing) row
                    # as soon as any env is done.
                    exp,history=experiences[pos],self.history[idx]
                    history.append(exp)
                    if len(history)==self.steps_count and \
                           int(exp['step'][0])%self.steps_delta==0:
                        yield tuple(history)

                    if bool(exp['done'][0]):
                        # Flush the episode tail as progressively shorter windows.
                        if 0<len(history)<self.steps_count:
                            yield tuple(history)
                        while len(history)>1:
                            history.popleft()
                            yield tuple(history)
                        # Start the next episode with an empty window so steps from two
                        # different episodes are never yielded together.
                        history.clear()

                # Write the updated rows back into the full per-env batch.
                for k in experiences:
                    v=experiences[k]
                    if k not in self.experiences:
                        self.experiences[k]=torch.zeros(self.experiences.bs(),*v.shape[1:])
                    if self.experiences[k].dtype!=v.dtype:
                        self.experiences[k]=cast_dtype(self.experiences[k],v.dtype)
                    self.experiences[k][not_done_idxs]=v
            except ValueError:
                # NOTE(review): any ValueError from the agent, the envs or the bookkeeping above
                # resets every env and keeps going; this can mask real bugs.
                self.reset_all()
            
# Attach the class docstring and the public-method docs in one call.
add_docs(
    ExperienceSource,
    """Iterates through `n_envs` of `env` feeding experience or states into `agent`.
           If `agent` is None, then random actions will be taken instead.
           It will return `steps_count` experiences every `steps_delta`.
           At the end of an env, it will return `steps_count-1` experiences per next. """,
    attempt_render="Updates `experiences` with images if `render is not None`. Optionally indexes can be passed.",
    reset_all="resets the envs and experience",
)

In [None]:
# export
class SourceDataset(IterableDataset):
    "Iterates through a `source` object. Allows for re-initing source connections when `num_workers>0`"
    def __init__(self,source=None):
        store_attr('source')

    def __iter__(self):
        # Delegate iteration entirely to the wrapped source.
        it=iter(self.source)
        return it

    def wif(self):
        # Worker-init hook: rebuild the source's per-process state (e.g. env pool) inside a worker.
        self.source._init_state()

In [None]:
# One CartPole env with random actions (`agent=None`) read through 2 DataLoader workers.
# `wif=dataset.wif` calls `source._init_state()` — presumably once per worker — so each
# worker builds its own env pool. Batches are concatenated into a single `BD` with `+`.
source=ExperienceSource('CartPole-v1',None,n_envs=1)
dataset=SourceDataset(source)

data=None
for x in DataLoader(dataset,num_workers=2,n=50,persistent_workers=True,wif=dataset.wif):
    data=BD(*x) if data is None else data+BD(*x)
data.pandas()

Unnamed: 0,state,action,next_state,reward,done,step,env,random_action
0,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,1.0,0,1
1,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,1.0,0,1
2,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,2.0,0,1
3,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,2.0,0,1
4,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,3.0,0,2
5,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,3.0,0,0
6,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,4.0,0,2
7,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,4.0,0,1
8,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,5.0,0,2
9,"torch.Size([50, 4])",0.0,"torch.Size([50, 4])",1.0,0.0,5.0,0,2


> Note: I wonder if `BD` would benefit from using `default_collate`?

In [None]:
# NOTE(review): exploratory source dump supporting the `default_collate` note; consider removing before publishing.
default_collate??

[0;31mSignature:[0m [0mdefault_collate[0m[0;34m([0m[0mbatch[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;32mdef[0m [0mdefault_collate[0m[0;34m([0m[0mbatch[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""Puts each data field into a tensor with outer dimension batch size"""[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m    [0melem[0m [0;34m=[0m [0mbatch[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m
[0;34m[0m    [0melem_type[0m [0;34m=[0m [0mtype[0m[0;34m([0m[0melem[0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0;32mif[0m [0misinstance[0m[0;34m([0m[0melem[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mout[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m
[0;34m[0m        [0;32mif[0m [0mtorch[0m[0;34m.[0m[0mutils[0m[0;34m.[0m[0mdata[0m[0;34m.[0m[0mget_worker_info[0m[0;34m([0m[0;34m)[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0

In [None]:
# Same DataLoader setup, but with `render='rgb_array'`: `attempt_render` adds an `image` field
# holding the rendered frames. `data` is reused by the cells below.
source=ExperienceSource('CartPole-v1',None,n_envs=1,render='rgb_array')
dataset=SourceDataset(source)

data=None
for x in DataLoader(dataset,num_workers=2,n=10,persistent_workers=True,wif=dataset.wif):
    data=BD(*x) if data is None else data+BD(*x)
data.pandas()

Unnamed: 0,state,action,next_state,reward,done,step,env,image,random_action
0,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,1.0,0,"torch.Size([10, 400, 600, 3])",2
1,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,1.0,0,"torch.Size([10, 400, 600, 3])",0
2,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,2.0,0,"torch.Size([10, 400, 600, 3])",1
3,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,2.0,0,"torch.Size([10, 400, 600, 3])",0
4,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,3.0,0,"torch.Size([10, 400, 600, 3])",1
5,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,3.0,0,"torch.Size([10, 400, 600, 3])",0
6,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,4.0,0,"torch.Size([10, 400, 600, 3])",2
7,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,4.0,0,"torch.Size([10, 400, 600, 3])",1
8,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,5.0,0,"torch.Size([10, 400, 600, 3])",1
9,"torch.Size([10, 4])",0.0,"torch.Size([10, 4])",1.0,0.0,5.0,0,"torch.Size([10, 400, 600, 3])",0


`ExperienceSource` is designed for iterating through `n_envs` environments.

A single experience is a `BD`, a subclass of `dict` that also supports batch indexing:

In [None]:
# Drop the large `image` field, then take the first row of the batch.
del data['image']
data[0]

{'state': tensor([[-0.0213, -0.0315, -0.0271,  0.0376]]),
 'action': tensor([[0.]]),
 'next_state': tensor([[-0.0220,  0.1640, -0.0263, -0.2635]]),
 'reward': tensor([[1.]]),
 'done': tensor([[0.]]),
 'step': tensor([[1.]]),
 'env': tensor([[0]]),
 'random_action': tensor([[2]])}

An agent can also add its own fields to this dict while running; for example, `FakeAgent` adds the `random_action` field shown above.

In [None]:
# Documentation for the per-process (re)initialization hook.
show_doc(ExperienceSource._init_state)

<h4 id="ExperienceSource._init_state" class="doc_header"><code>ExperienceSource._init_state</code><a href="__main__.py#L28" class="source_link" style="float:right">[source]</a></h4>

> <code>ExperienceSource._init_state</code>()

Inits the histories, experiences, and the environment pool when sent to a `Process`

In [None]:
# Documentation for the main iteration loop.
show_doc(ExperienceSource.__iter__)

<h4 id="ExperienceSource.__iter__" class="doc_header"><code>ExperienceSource.__iter__</code><a href="__main__.py#L51" class="source_link" style="float:right">[source]</a></h4>

> <code>ExperienceSource.__iter__</code>()

Iterates through a list of environments.

If the `self.pool` field is empty, it will call `_init_state` to reinitialize everything.

In [None]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():   
    # Rebuild README.md (`make_readme`), the exported library modules (`notebook2script`)
    # and the HTML docs (`notebook2html`) from the notebooks.
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted .data.block_old.ipynb.
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 05_data.block.ipynb.
Converted 05_data.test_async.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/05_data.block.ipynb
