In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastrl['dev']  # upgrade fastrl on colab

In [None]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` and the `IN_*` flags are not imported explicitly in this cell;
    # presumably they arrive through the star imports above — confirm.
    # The sanity checks are skipped when the `IN_TEST` environment variable is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [None]:
# default_exp data.block

In [None]:
# export
# Python native modules
import os
from collections import deque
from time import sleep
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset
from torch import nn
import torch

# Local modules
from fastrl.core import *

In [None]:
import numpy as np
import gym
import time,sys
import torch.multiprocessing as mp
import pandas as pd


# Data Block
> Fastrl transforms for iterating through environments

In [None]:
# export
class DQN(Module):
    "Small MLP that maps a 4-feature state to the index of the highest-scoring of 2 actions."
    def __init__(self):
        # No explicit `super().__init__()` here; presumably fastai's `Module` takes care of it — confirm.
        # NOTE(review): the trailing ReLU clamps both action scores to >=0 — confirm this is intended.
        self.policy=nn.Sequential(
            nn.Linear(4,50),
            nn.ReLU(),
            nn.Linear(50,2),
            nn.ReLU()
        )

    def forward(self,x):
        "Return the arg-max action index for `x`; works for a single state `(4,)` or a batch `(n,4)`."
        # Reduce over the last (action) axis. Reducing over axis 0 is only correct for a
        # single un-batched state; on a batch it would pick a row per action instead.
        return torch.argmax(self.policy(x),dim=-1)

Development of this was helped by the [`IterableDataset` documentation on multiple workers](https://github.com/pytorch/pytorch/blob/4949eea0ffb60dc81a0a78402fa59fdf68206718/torch/utils/data/dataset.py#L64)

This code is heavily modified from https://github.com/Shmuma/ptan

Reference for env [semantics related to vectorized environments](https://github.com/openai/universe/blob/master/doc/env_semantics.rst)

Useful links:
- [torch multiprocessing](https://github.com/pytorch/pytorch/blob/a61a8d059efa0fb139a09e479b1a2c8dd1cf1a44/torch/utils/data/dataloader.py#L564)
- [torch worker](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/worker.py)

In [None]:
# export
def o2tensor_batch(o):
    """Convert `o` to a `Tensor` that has a leading batch dimension.

    Non-tensors are converted first: lists/arrays as-is, anything else wrapped in a
    one-element list. A tensor of rank >1 whose first dimension is already 1 is
    returned unchanged; every other tensor (including 0-dim) gets a new leading axis.
    """
    if not isinstance(o,Tensor): o=Tensor(o if is_listy(o) or isinstance(o,np.ndarray) else [o])
    # Check the rank before touching `shape[0]`, otherwise a 0-dim tensor raises IndexError.
    if o.dim()>1 and o.shape[0]==1: return o
    return o.unsqueeze(0)

def init_experience(but='',**kwargs):
    """Returns dictionary with default values that can be overridden.

    `but` is a comma separated list of default fields to leave out, e.g. `'image,step'`.
    `kwargs` are merged over the remaining defaults, and every value is then passed
    through `o2tensor_batch`.
    """
    experience=D(
        state=0,action=0,next_state=0,reward=0,done=False,
        step=0,env=0,image=0
    )
    for s in but.split(','):
        # Tolerate spaces around names ('image, step'); empty pieces from stray
        # commas (',image') simply match nothing.
        s=s.strip()
        if s in experience:del experience[s]
    return D(merge(experience,kwargs)).mapv(o2tensor_batch)

In [None]:
init_experience()

{'state': tensor([[0.]]),
 'action': tensor([[0.]]),
 'next_state': tensor([[0.]]),
 'reward': tensor([[0.]]),
 'done': tensor([[0.]]),
 'step': tensor([[0.]]),
 'env': tensor([[0.]]),
 'image': tensor([[0.]])}

In [None]:
init_experience(but='image,step')

{'state': tensor([[0.]]),
 'action': tensor([[0.]]),
 'next_state': tensor([[0.]]),
 'reward': tensor([[0.]]),
 'done': tensor([[0.]]),
 'env': tensor([[0.]])}

In [None]:
sum([init_experience(),init_experience()],init_experience())

{'state': tensor([[0.],
         [0.],
         [0.]]),
 'action': tensor([[0.],
         [0.],
         [0.]]),
 'next_state': tensor([[0.],
         [0.],
         [0.]]),
 'reward': tensor([[0.],
         [0.],
         [0.]]),
 'done': tensor([[0.],
         [0.],
         [0.]]),
 'step': tensor([[0.],
         [0.],
         [0.]]),
 'env': tensor([[0.],
         [0.],
         [0.]]),
 'image': tensor([[0.],
         [0.],
         [0.]])}

In [None]:
# export
def _state2experience(s,**kwargs):
    "Build an experience from state `s` with `step` set to a `(1,1)` zero tensor."
    zero_step=torch.zeros((1,1))
    return init_experience(state=s,step=zero_step,**kwargs)
def _env_reset(o):                   return o.reset()
def _env_seed(o,seed):               return o.seed(seed)
def _env_render(o,mode='rgb_array'): return Tensor(o.render(mode=mode).copy())
def _env_step(o,*args,**kwargs):     return o.step(*args,**kwargs)

class FakeAgent:
    "Agent stand-in that ignores the content of the state and samples from `action_space`."
    def __init__(self,action_space):
        store_attr()
    def __call__(self,state,**kwargs):
        "Return one sampled action per row of `state`, plus `kwargs` wrapped in a `D`."
        n_rows=state.shape[0]
        sampled=[]
        for _ in range(n_rows):
            sampled.append(self.action_space.sample())
        return L(sampled),D(kwargs)

class ExperienceSource(Stateful):
    # `pool` is declared as a `Stateful` state attribute; presumably it is dropped on
    # pickling and rebuilt through `_init_state` in the receiving process — see fastcore docs.
    _stateattrs=('pool',)
    def __init__(self,env:str,agent=None,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 seed:int=None,render=None,num_workers=0,but='',**kwargs):
        store_attr()
        # Anything not named in the signature is forwarded to `gym.make` in `_init_state`.
        self.env_kwargs=kwargs
        # Starts empty: `__iter__` builds the envs lazily the first time it runs.
        self.pool=L()
        # Without a render mode there is nothing to store in the `image` field.
        if self.render is None: self.but+=',image'

    def _init_state(self):
        "Inits the histories, experiences, and the environment pool when sent to a `Process`"
        # One bounded history and one env per slot; `zip` turns the pairs into two parallel lists.
        self.history,self.pool=L((deque(maxlen=self.steps_count),
                                  gym.make(self.env,**self.env_kwargs))
                                  for _ in range(self.n_envs)).zip().map(L)
        self.pool.map(_env_seed,seed=self.seed)
        if self.agent is None: self.agent=FakeAgent(self.pool[0].action_space)
        self.reset_all()

    def reset_all(self):
        # One experience per env, then concatenated into a single batched experience.
        self.experiences=self.pool.map(_env_reset)
        self.experiences=self.experiences.map(_state2experience,but=self.but)
        self.experiences=sum(self.experiences[1:],self.experiences[0])
        self.attempt_render(self.experiences)

    def attempt_render(self,experiences,indexes=None):
        if self.render is None: return
        pool=self.pool if indexes is None else self.pool[indexes]
        renders=pool.map(_env_render,mode=self.render)
        # No idea why we have to do this, but multiprocessing hangs forever otherwise
        if self.num_workers>0:sleep(0.1)
        if len(renders)==0: return
        images=torch.stack(tuple(renders))
        if indexes is None:
            experiences['image']=images
            return
        # Only the envs in `indexes` were rendered, so only their rows may change.
        # Replacing the whole field with the shorter stack would leave `image` out of
        # step with every other field as soon as one env is done.
        if 'image' in experiences and experiences['image'].shape[1:]==images.shape[1:]:
            # Work on a copy so tensors already handed out are not modified in place.
            merged=experiences['image'].clone()
        else:
            merged=torch.zeros((len(self.pool),*images.shape[1:]))
        merged[indexes]=images
        experiences['image']=merged

    def __iter__(self):
        "Iterates through a list of environments."
        if not self.pool:self._init_state()
        while True:
            # Only work on envs that are not done
            not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
            if len(not_done_idxs)==0:
                self.reset_all()
                not_done_idxs=(self.experiences['done']==False).nonzero()[:,0]
            not_done_experiences=self.experiences.filter(indexes=not_done_idxs)
            # Pass current experiences into agent
            # NOTE(review): `actions` is not written into the experience here; any
            # `action` field comes from the dict the agent returns.
            actions,experiences=self.agent(**not_done_experiences)
            # Step through all envs.
            step_res=self.pool[not_done_idxs].zipwith(actions).starmap(_env_step)
            next_states,rewards,dones=step_res.zip()[:3].map(Tensor)

            new_exp=D(next_state=next_states,reward=rewards,done=dones,
                      env=not_done_idxs,step=not_done_experiences['step']+1)
            experiences=D(merge(not_done_experiences,experiences,new_exp))
            # TODO: Ugly, I shouldn't have to do this
            for k in experiences:
                # Compare per-env shapes only. Including the batch dimension would
                # re-zero every field (and so forget `done`) whenever only some envs stepped.
                if k not in self.experiences or \
                       self.experiences[k].shape[1:]!=experiences[k].shape[1:]:
                    self.experiences[k]=torch.zeros((len(self.pool),
                                                     *experiences[k].shape[1:]))
                if torch.is_floating_point(experiences[k]):
                    self.experiences[k]=self.experiences[k].float()
                else:
                    self.experiences[k]=self.experiences[k].long()
                self.experiences[k][not_done_idxs]=experiences[k]

            # The envs have moved on: the next step must start from the state just
            # observed, otherwise the agent keeps seeing the state from the last reset.
            if 'state' in self.experiences:
                state=self.experiences['state'].clone()
                state[not_done_idxs]=next_states.reshape(len(not_done_idxs),*state.shape[1:])
                self.experiences['state']=state
            # Render after the write-back above so the fresh frames are not overwritten
            # by the pre-step images carried in `experiences`.
            self.attempt_render(self.experiences,not_done_idxs)

            # `experiences` only holds the envs that stepped, so its rows are addressed
            # by position in `not_done_idxs`, not by env index.
            positions=torch.arange(len(not_done_idxs))
            if self.n_envs>1:
                experiences=parallel(experiences.subset,positions,
                                       threadpool=True,n_workers=0 if self.num_workers>0 else 2,progress=False)
            else:
                experiences=[experiences.subset(positions)]
            # TODO: Ugly, I shouldn't have to do this
            if 'image' in experiences[0]:
                experiences=[merge(d,{'image':d['image'].unsqueeze(0)}) for d in experiences]

            # Give every field a leading dimension of size 1.
            experiences=[{k:(e[k].unsqueeze(0) if not e[k].shape or e[k].shape[0]!=1 else e[k]) for k in e} for e in experiences]

            for pos,idx in enumerate(not_done_idxs):
                exp,history=experiences[pos],self.history[idx]
                history.append(exp)
                if len(history)==self.steps_count and \
                       int(exp['step'])%self.steps_delta==0:
                    yield tuple(history)

                if bool(exp['done']):
                    # Episode shorter than `steps_count`: send what was gathered.
                    if 0<len(history)<self.steps_count:
                        yield tuple(history)
                    # Emit the shrinking tail of the episode.
                    while len(history)>1:
                        history.popleft()
                        yield tuple(history)
                    # Start the next episode with an empty history so steps from this
                    # one do not leak into its first windows.
                    history.clear()

add_docs(ExperienceSource,
        """Iterates through `n_envs` of `env` feeding experience or states into `agent`.
           If `agent` is None, then random actions will be taken instead.
           It will return `steps_count` experiences every `steps_delta`.
           At the end of an env, it will return `steps_count-1` experiences per next. """,
        reset_all="resets the envs and experience",
        attempt_render="""Updates `experiences` with images if `render is not None`. Optionally indexes can be passed,
           in which case only the rows of `image` for those envs are updated.""")

In [None]:
# export
class SourceDataset(IterableDataset):
    "Iterates through a `source` object. Allows for re-initing source connections when `num_workers>0`"
    def __init__(self,source=None):
        store_attr('source')
    def __iter__(self):
        "Delegate iteration to the wrapped `source`."
        source_iter=iter(self.source)
        return source_iter
    def wif(self):
        "Worker init hook: asks the wrapped `source` to rebuild its state."
        self.source._init_state()

In [None]:
# Two CartPole envs with no agent (random actions are used instead) and frames
# rendered in 'rgb_array' mode into the `image` field.
# NOTE: the saved output of this cell is an IndexError, not a table.
source=ExperienceSource('CartPole-v1',None,n_envs=2,render='rgb_array')
dataset=SourceDataset(source)

# Accumulate every batch into one `D`; the first batch seeds the accumulator.
data=None
for x in DataLoader(dataset,num_workers=0,n=50,persistent_workers=True,wif=dataset.wif):
    data=D(x) if data is None else data+D(x)
data.pandas()

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [None]:
# Same walkthrough as the previous cell, but with a single env and `n=10`.
source=ExperienceSource('CartPole-v1',None,n_envs=1,render='rgb_array')
dataset=SourceDataset(source)

# Accumulate every batch into one `D`; the first batch seeds the accumulator.
data=None
for x in DataLoader(dataset,num_workers=0,n=10,persistent_workers=True,wif=dataset.wif):
    data=D(x) if data is None else data+D(x)
data.pandas()

`ExperienceSource` is designed for iterating through `n_envs` environments.

A single experience is a `dict`:

In [None]:
D(state=None,action=None,next_state=None,reward=None,rewards=None,
             step=None,steps=None)

However, an agent has full power to add fields to this dict while running.

In [None]:
show_doc(ExperienceSource._init_state)

In [None]:
show_doc(ExperienceSource.__iter__)

If the `self.pool` field is empty, it will call `_init_state` to reinitialize everything.

In [None]:

import gym
from queue import deque

class ExperienceSource(object):
    """Steps a pool of gym envs with `agent` and yields tuples of recent experience dicts.

    Each yielded tuple holds up to `steps_count` consecutive steps of one env slot.
    With `vectorized=True` an env is expected to return one observation per sub-env.
    """
    def __init__(self, env:str,agent,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 vectorized:bool=False,seed:int=None):
        store_attr()
        self.pool:List[gym.Env]=[gym.make(self.env) for _ in range(self.n_envs)]

    def init_env(self,env,states,histories,cur_rewards,cur_steps,env_lens):
        "Seed and reset `env`, appending its initial per-slot bookkeeping to the given lists."
        env.seed(self.seed)
        obs=env.reset()
        if self.vectorized:
            obs_len = len(obs)
            states.extend(obs)
        else:
            obs_len = 1
            states.append(obs)
        env_lens.append(obs_len)

        for _ in range(obs_len):
            histories.append(deque(maxlen=self.steps_count))
            cur_rewards.append(0.0)
            cur_steps.append(0)

    def __iter__(self):
        "Yield history tuples forever, resetting each env when its episode ends."
        states,histories,cur_rewards,cur_steps,env_lens=[],[],[],[],[]

        for env in self.pool: self.init_env(env,states,histories,cur_rewards,cur_steps,env_lens)

        iter_idx = 0
        while True:
            actions = [None] * len(states)
            states_input = []
            states_indices = []
            for idx, state in enumerate(states):
                if state is None:
                    actions[idx] = self.pool[0].action_space.sample()  # assume that all envs are from the same family
                else:
                    states_input.append(state)
                    states_indices.append(idx)
            if states_input:
                states_actions = self.agent(states_input)
                for idx, action in enumerate(states_actions):
                    g_idx = states_indices[idx]
                    actions[g_idx] = action
            # Group the flat action list per env. `np.split` takes cut positions rather
            # than group sizes, so convert the per-env lengths to cumulative offsets
            # (the last offset is the end of the list and is not a cut).
            grouped_actions=np.split(actions,np.cumsum(env_lens)[:-1])

            global_ofs = 0
            for env_idx, (env, action_n) in enumerate(zip(self.pool, grouped_actions)):
                if self.vectorized:
                    next_state_n, r_n, is_done_n, _ = env.step(action_n)
                else:
                    next_state, r, is_done, _ = env.step(action_n[0])
                    next_state_n, r_n, is_done_n = [next_state], [r], [is_done]

                for ofs, (action, next_state, r, is_done) in enumerate(zip(action_n, next_state_n, r_n, is_done_n)):
                    idx = global_ofs + ofs
                    state = states[idx]
                    history = histories[idx]

                    cur_rewards[idx] += r
                    cur_steps[idx] += 1
                    if state is not None:
                        history.append(dict(state=state,next_state=next_state, action=action, reward=r, done=is_done,steps=cur_steps[idx],episode_reward=cur_rewards[idx],env=env_idx))
                    if len(history) == self.steps_count and iter_idx % self.steps_delta == 0:
                        yield tuple(history)
                    states[idx] = next_state
                    if is_done:
                        # in case of very short episode (shorter than our steps count), send gathered history
                        if 0 < len(history) < self.steps_count:
                            yield tuple(history)
                        # generate tail of history
                        while len(history) > 1:
                            history.popleft()
                            yield tuple(history)
                        cur_rewards[idx] = 0.0
                        cur_steps[idx] = 0
                        # vectorized envs are reset automatically
                        env.seed(self.seed)
                        states[idx]=env.reset() if not self.vectorized else None
                        history.clear()
                global_ofs += len(action_n)
            iter_idx += 1

In [None]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():   
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    # Build steps: regenerate the README, then convert the notebooks
    # (the saved output below lists the converted files).
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted .data.block_old.ipynb.
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 05_data.block.ipynb.
Converted 05_data.test_async.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/05_data.block.ipynb
