In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastrl['dev']  # upgrade fastrl on colab

In [None]:
# hide
# Interactive-environment sanity checks; nbdev is only imported when not running on Colab.
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os`, `IN_NOTEBOOK`, `IN_COLAB` and `IN_IPYTHON` are not imported explicitly in this
    # cell; presumably they come from the nbdev star-imports above (`import os` only appears in a later cell).
    # The assertions are skipped when the IN_TEST environment variable is set (presumably by the test runner).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [None]:
# default_exp data.block

In [None]:
# export
# Python native modules
import os
import math
from collections import deque
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from torch.utils.data import Dataset,IterableDataset
from torch import nn


import numpy as np
import gym
import time,sys
import torch.multiprocessing as mp
import pandas as pd

# Local modules

# Data Block
> Fastrl transforms for iterating through environments

In [None]:
# export
class DQN(Module):
    "Small MLP that maps observations to the index of the highest-scoring action (defaults: 4 inputs, 2 actions)."
    def __init__(self,n_obs:int=4,n_actions:int=2,hidden:int=50):
        # NOTE(review): no explicit super().__init__() -- as in the original, fastai's `Module` is assumed to
        # run `nn.Module.__init__` itself.
        self.policy=nn.Sequential(
            nn.Linear(n_obs,hidden),
            nn.ReLU(),
            nn.Linear(hidden,n_actions),
            # No activation on the output layer: action-value estimates may be negative, and a final ReLU
            # would clamp them to 0, producing ties that argmax always resolves to action 0.
            # (Removing it keeps the state_dict keys `policy.0.*` / `policy.2.*` unchanged.)
        )

    def forward(self,x):
        "Return the argmax over the last (action) axis, so a single state `(n_obs,)` and a batch `(bs,n_obs)` both work."
        return torch.argmax(self.policy(x),dim=-1)

The multi-worker handling follows the [IterableDataset documentation on using multiple workers](https://github.com/pytorch/pytorch/blob/4949eea0ffb60dc81a0a78402fa59fdf68206718/torch/utils/data/dataset.py#L64).

This code is heavily modified from [ptan](https://github.com/Shmuma/ptan).

Reference for the [semantics of vectorized environments](https://github.com/openai/universe/blob/master/doc/env_semantics.rst).

In [None]:
# export
class ExperienceSource(object):
    """Iterate over windows of consecutive transitions gathered by running `agent` in `n_envs` copies of gym env `env`.

    Each yielded item is a tuple of up to `steps_count` dicts with keys `state`, `next_state`, `action`, `reward`,
    `done`, `steps`, `episode_reward` and `env` (index of the env in `pool`). Full-length windows are emitted on
    iterations divisible by `steps_delta`; when an episode ends, its shrinking tail windows are emitted as well.
    The iterator never stops on its own -- bound it with `zip`/`islice`.

    `agent` is called with a list of states and must return one action per state.
    `vectorized=True` treats each env as a batch whose `reset`/`step` return per-slot lists (assumed to reset
    themselves on done). If `seed` is not None, env `i` is seeded once with `seed+i` before its first reset.
    Extra `**kwargs` are forwarded to `gym.make`.
    """
    def __init__(self,env:str,agent,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 vectorized:bool=False,seed:int=None,**kwargs):
        store_attr()
        # Forward the extra keyword arguments to the env constructor instead of silently dropping them.
        self.pool:List[gym.Env]=[gym.make(self.env,**kwargs) for _ in range(self.n_envs)]

    def init_env(self,env,states,histories,cur_rewards,cur_steps,env_lens,seed_offset:int=0):
        "Seed and reset `env`, then append its initial per-slot bookkeeping entries to the given lists."
        # Distinct seeds per env so `n_envs` copies do not replay identical trajectories.
        env.seed(None if self.seed is None else self.seed+seed_offset)
        obs=env.reset()
        if self.vectorized:
            obs_len=len(obs)
            states.extend(obs)
        else:
            obs_len=1
            states.append(obs)
        env_lens.append(obs_len)
        for _ in range(obs_len):
            histories.append(deque(maxlen=self.steps_count))
            cur_rewards.append(0.0)
            cur_steps.append(0)

    def _select_actions(self,states):
        "Ask `agent` for actions on all non-None states; sample a random action for slots whose state is None."
        actions=[None]*len(states)
        live=[]
        for i,s in enumerate(states):
            if s is None: actions[i]=self.pool[0].action_space.sample()  # assume all envs share one action space
            else:         live.append(i)
        if live:
            for i,a in zip(live,self.agent([states[i] for i in live])): actions[i]=a
        return actions

    @staticmethod
    def _group_actions(actions,env_lens):
        "Split the flat `actions` list into consecutive chunks, chunk `i` holding `env_lens[i]` actions."
        # `np.split` expects split *positions*, not chunk sizes, so convert sizes into cumulative boundaries.
        return np.split(actions,np.cumsum(env_lens)[:-1])

    def _step_env(self,env,action_n):
        "Step `env` with its action chunk and return per-slot lists `(next_states, rewards, dones)`."
        if self.vectorized:
            next_state_n,r_n,is_done_n,_=env.step(action_n)
            return next_state_n,r_n,is_done_n
        next_state,r,is_done,_=env.step(action_n[0])
        return [next_state],[r],[is_done]

    def __iter__(self):
        states,histories,cur_rewards,cur_steps,env_lens=[],[],[],[],[]
        for env_idx,env in enumerate(self.pool):
            self.init_env(env,states,histories,cur_rewards,cur_steps,env_lens,seed_offset=env_idx)

        iter_idx=0
        while True:
            grouped_actions=self._group_actions(self._select_actions(states),env_lens)
            global_ofs=0
            for env_idx,(env,action_n) in enumerate(zip(self.pool,grouped_actions)):
                next_state_n,r_n,is_done_n=self._step_env(env,action_n)
                for ofs,(action,next_state,r,is_done) in enumerate(zip(action_n,next_state_n,r_n,is_done_n)):
                    idx=global_ofs+ofs
                    state,history=states[idx],histories[idx]
                    cur_rewards[idx]+=r
                    cur_steps[idx]+=1
                    if state is not None:
                        history.append(dict(state=state,next_state=next_state,action=action,reward=r,done=is_done,
                                            steps=cur_steps[idx],episode_reward=cur_rewards[idx],env=env_idx))
                    if len(history)==self.steps_count and iter_idx%self.steps_delta==0:
                        yield tuple(history)
                    states[idx]=next_state
                    if is_done:
                        # Episode shorter than `steps_count`: emit whatever was gathered.
                        if 0<len(history)<self.steps_count: yield tuple(history)
                        # Emit the shrinking tail windows so the final transitions are not lost.
                        while len(history)>1:
                            history.popleft()
                            yield tuple(history)
                        cur_rewards[idx]=0.0
                        cur_steps[idx]=0
                        # No re-seeding here: re-applying the same seed on every reset made each episode identical.
                        # Vectorized envs are assumed to reset themselves, so their slot restarts as None.
                        states[idx]=env.reset() if not self.vectorized else None
                        history.clear()
                global_ofs+=len(action_n)
            iter_idx+=1
    

In [None]:

import gym
from queue import deque

class ExperienceSource(object):
    """Iterate over windows of consecutive transitions gathered by running `agent` in `n_envs` copies of gym env `env`.

    Non-exported working copy of the class; keep it in sync with the `# export` definition.

    Each yielded item is a tuple of up to `steps_count` dicts with keys `state`, `next_state`, `action`, `reward`,
    `done`, `steps`, `episode_reward` and `env` (index of the env in `pool`). Full-length windows are emitted on
    iterations divisible by `steps_delta`; when an episode ends, its shrinking tail windows are emitted as well.
    The iterator never stops on its own -- bound it with `zip`/`islice`.

    `agent` is called with a list of states and must return one action per state.
    `vectorized=True` treats each env as a batch whose `reset`/`step` return per-slot lists (assumed to reset
    themselves on done). If `seed` is not None, env `i` is seeded once with `seed+i` before its first reset.
    """
    def __init__(self, env:str,agent,n_envs:int=1,steps_count:int=1,steps_delta:int=1,
                 vectorized:bool=False,seed:int=None):
        store_attr()
        self.pool:List[gym.Env]=[gym.make(self.env) for _ in range(self.n_envs)]

    def init_env(self,env,states,histories,cur_rewards,cur_steps,env_lens,seed_offset:int=0):
        "Seed and reset `env`, then append its initial per-slot bookkeeping entries to the given lists."
        # Distinct seeds per env so `n_envs` copies do not replay identical trajectories.
        env.seed(None if self.seed is None else self.seed+seed_offset)
        obs=env.reset()
        if self.vectorized:
            obs_len=len(obs)
            states.extend(obs)
        else:
            obs_len=1
            states.append(obs)
        env_lens.append(obs_len)
        for _ in range(obs_len):
            histories.append(deque(maxlen=self.steps_count))
            cur_rewards.append(0.0)
            cur_steps.append(0)

    def _select_actions(self,states):
        "Ask `agent` for actions on all non-None states; sample a random action for slots whose state is None."
        actions=[None]*len(states)
        live=[]
        for i,s in enumerate(states):
            if s is None: actions[i]=self.pool[0].action_space.sample()  # assume all envs share one action space
            else:         live.append(i)
        if live:
            for i,a in zip(live,self.agent([states[i] for i in live])): actions[i]=a
        return actions

    @staticmethod
    def _group_actions(actions,env_lens):
        "Split the flat `actions` list into consecutive chunks, chunk `i` holding `env_lens[i]` actions."
        # `np.split` expects split *positions*, not chunk sizes, so convert sizes into cumulative boundaries.
        return np.split(actions,np.cumsum(env_lens)[:-1])

    def _step_env(self,env,action_n):
        "Step `env` with its action chunk and return per-slot lists `(next_states, rewards, dones)`."
        if self.vectorized:
            next_state_n,r_n,is_done_n,_=env.step(action_n)
            return next_state_n,r_n,is_done_n
        next_state,r,is_done,_=env.step(action_n[0])
        return [next_state],[r],[is_done]

    def __iter__(self):
        states,histories,cur_rewards,cur_steps,env_lens=[],[],[],[],[]
        for env_idx,env in enumerate(self.pool):
            self.init_env(env,states,histories,cur_rewards,cur_steps,env_lens,seed_offset=env_idx)

        iter_idx=0
        while True:
            grouped_actions=self._group_actions(self._select_actions(states),env_lens)
            global_ofs=0
            for env_idx,(env,action_n) in enumerate(zip(self.pool,grouped_actions)):
                next_state_n,r_n,is_done_n=self._step_env(env,action_n)
                for ofs,(action,next_state,r,is_done) in enumerate(zip(action_n,next_state_n,r_n,is_done_n)):
                    idx=global_ofs+ofs
                    state,history=states[idx],histories[idx]
                    cur_rewards[idx]+=r
                    cur_steps[idx]+=1
                    if state is not None:
                        history.append(dict(state=state,next_state=next_state,action=action,reward=r,done=is_done,
                                            steps=cur_steps[idx],episode_reward=cur_rewards[idx],env=env_idx))
                    if len(history)==self.steps_count and iter_idx%self.steps_delta==0:
                        yield tuple(history)
                    states[idx]=next_state
                    if is_done:
                        # Episode shorter than `steps_count`: emit whatever was gathered.
                        if 0<len(history)<self.steps_count: yield tuple(history)
                        # Emit the shrinking tail windows so the final transitions are not lost.
                        while len(history)>1:
                            history.popleft()
                            yield tuple(history)
                        cur_rewards[idx]=0.0
                        cur_steps[idx]=0
                        # No re-seeding here: re-applying the same seed on every reset made each episode identical.
                        # Vectorized envs are assumed to reset themselves, so their slot restarts as None.
                        states[idx]=env.reset() if not self.vectorized else None
                        history.clear()
                global_ofs+=len(action_n)
            iter_idx+=1

In [None]:
# Run a constant policy (always action 0) through CartPole and tabulate the first 70 yielded windows.
agent=lambda s: [0]*len(s)  # one action per incoming state, so it also works with n_envs>1

frames=[pd.DataFrame(e) for e,_ in zip(ExperienceSource('CartPole-v1',agent,seed=0,steps_count=3),range(70))]
# `DataFrame.append` is deprecated (removed in pandas 2.0) and copies on every call; concatenate once instead.
data=pd.concat(frames,ignore_index=True)
data

Unnamed: 0,state,next_state,action,reward,done,steps,episode_reward,env
0,"[-0.04456399437492168, 0.04653909372423204, 0.013269094558410327, -0.020998265615229175]","[-0.04363321250043704, -0.14877061164085567, 0.012849129246105744, 0.27584150116514483]",0,1.0,False,1,1.0,0
1,"[-0.04363321250043704, -0.14877061164085567, 0.012849129246105744, 0.27584150116514483]","[-0.04660862473325415, -0.3440735048698027, 0.018365959269408642, 0.572549197993572]",0,1.0,False,2,2.0,0
2,"[-0.04660862473325415, -0.3440735048698027, 0.018365959269408642, 0.572549197993572]","[-0.05349009483065021, -0.5394480966584043, 0.02981694322928008, 0.8709609494144344]",0,1.0,False,3,3.0,0
3,"[-0.04363321250043704, -0.14877061164085567, 0.012849129246105744, 0.27584150116514483]","[-0.04660862473325415, -0.3440735048698027, 0.018365959269408642, 0.572549197993572]",0,1.0,False,2,2.0,0
4,"[-0.04660862473325415, -0.3440735048698027, 0.018365959269408642, 0.572549197993572]","[-0.05349009483065021, -0.5394480966584043, 0.02981694322928008, 0.8709609494144344]",0,1.0,False,3,3.0,0
...,...,...,...,...,...,...,...,...
184,"[-0.09759162379326478, -1.1265756817485255, 0.10029304257894092, 1.7938742678977837]","[-0.12012313742823529, -1.3226681727103522, 0.1361705279368966, 2.1159716654401786]",0,1.0,False,7,7.0,0
185,"[-0.12012313742823529, -1.3226681727103522, 0.1361705279368966, 2.1159716654401786]","[-0.14657650088244234, -1.5188614363481507, 0.17848996124570016, 2.4474478802002593]",0,1.0,False,8,8.0,0
186,"[-0.09759162379326478, -1.1265756817485255, 0.10029304257894092, 1.7938742678977837]","[-0.12012313742823529, -1.3226681727103522, 0.1361705279368966, 2.1159716654401786]",0,1.0,False,7,7.0,0
187,"[-0.12012313742823529, -1.3226681727103522, 0.1361705279368966, 2.1159716654401786]","[-0.14657650088244234, -1.5188614363481507, 0.17848996124570016, 2.4474478802002593]",0,1.0,False,8,8.0,0


In [None]:
# export
class TestDataset(IterableDataset):
    "Toy `IterableDataset` yielding `range(start, end)`, sharded across DataLoader workers, with `n_envs` CartPole envs split between them."
    def __init__(self,start=1,end=10,policy=None,device='cpu',n_envs=1):
        store_attr('start,end,policy,device,n_envs')

    def init_envs(self,n):
        "Create `n` fresh `CartPole-v1` environments in `self.envs`."
        self.envs=[gym.make('CartPole-v1') for i in range(n)]

    @staticmethod
    def _shard(total,worker_id,num_workers):
        "Return `(lo, hi)` offsets of worker `worker_id`'s contiguous chunk of `total` items (the last chunk may be short or empty)."
        total=max(total,0)
        per_worker=int(math.ceil(total/num_workers))
        lo=min(worker_id*per_worker,total)
        return lo,min(lo+per_worker,total)

    def __iter__(self):
        worker_info=torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading: this process owns everything
            self.init_envs(self.n_envs)
            iter_start,iter_end=self.start,self.end
        else:  # inside a DataLoader worker: take only this worker's share of envs and items
            env_lo,env_hi=self._shard(self.n_envs,worker_info.id,worker_info.num_workers)
            self.init_envs(env_hi-env_lo)
            lo,hi=self._shard(self.end-self.start,worker_info.id,worker_info.num_workers)
            iter_start,iter_end=self.start+lo,self.start+hi
        return iter(range(iter_start,iter_end))

In [None]:
ds=TestDataset()
items=list(ds)
# Outside a DataLoader (no worker info) the dataset yields the whole `range(start, end)`.
test_eq(items,list(range(1,10)))
items

In [None]:
# hide
# Build step, skipped on Colab. It calls nbdev's `make_readme`, `notebook2script` and `notebook2html`
# (by their names: README generation, library export of `# export` cells, HTML docs). Order is kept as-is.
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():   
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    make_readme()
    notebook2script()
    notebook2html()