In [1]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
#|default_exp envs.gym

In [4]:
#|export
# Python native modules
import os
import warnings
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
import gym
from fastai.torch_basics import *
from fastai.torch_core import *
from torchdata.dataloader2.graph import find_dps,traverse
from fastrl.data.dataloader2 import *
from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator
from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,IterDataPipe,MapDataPipe
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *
from fastrl.pipes.iter.nskip import *
from fastrl.pipes.iter.nstep import *
from fastrl.pipes.iter.firstlast import *
from fastrl.pipes.iter.transforms import *
from fastrl.pipes.map.transforms import *
from fastrl.data.block import *

# Envs Gym
> Fastrl API for working with OpenAI Gyms

### Pipes

In [5]:
#|export
class GymTypeTransform(Transform):
    "Creates an gym.env"
    def encodes(self,o): return gym.make(o,new_step_api=True)

In [6]:
#|export
class GymStepper(dp.iter.IterDataPipe):
    def __init__(self,
        source_datapipe:Union[Iterable,dp.iter.IterDataPipe], # Calling `next()` should produce a `gym.Env`
        agent=None, # Optional `Agent` that accepts a `SimpleStep` to produce a list of actions.
        seed:int=None, # Optional seed to set the env to and also random action sames if `agent==None`
        synchronized_reset:bool=False, # Some `gym.Envs` require reset to be terminated on *all* envs before proceeding to step.
        include_images:bool=False,
    ):
        self.source_datapipe = source_datapipe
        self.agent = agent
        self.seed = seed
        self.include_images = include_images
        self.synchronized_reset = synchronized_reset
        self._env_ids = {}
        
    def env_reset(self,
      env:gym.Env, # The env to rest along with its numeric object id
      env_id:int # Resets env in `self._env_ids[env_id]`
    ) -> SimpleStep:
        state = env.reset(seed=self.seed)
        env.action_space.seed(seed=self.seed)
        episode_n = self._env_ids[env_id].episode_n+1 if env_id in self._env_ids else tensor(1)

        step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
            state=tensor(state),
            next_state=tensor(state),
            terminated=tensor(False),
            truncated=tensor(False),
            reward=tensor(0),
            total_reward=tensor(0.),
            env_id=tensor(env_id),
            proc_id=tensor(os.getpid()),
            step_n=tensor(0),
            episode_n=episode_n,
            image=env.render(mode='rgb_array') if self.include_images else torch.FloatTensor([0])
        )
        self._env_ids[env_id] = step
        return step
    
    def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs)

    def __iter__(self) -> SimpleStep:
        for env in self.source_datapipe:
            assert issubclass(env.__class__,gym.Env),f'Expected subclass of gym.Env, but got {env.__class__}'    
            env_id = id(env)
            
            if env_id not in self._env_ids or self._env_ids[env_id].terminated:
                if self.synchronized_reset:
                    if env_id in self._env_ids \
                       and not self._env_ids[env_id].terminated \
                       and self._resetting_all:
                        # If this env has already been reset, and we are currently in the 
                        # self._resetting_all phase, then skip this so we can reset all remaining envs
                        continue
                    elif env_id not in self._env_ids \
                       or all([self._env_ids[s].terminated for s in self._env_ids])\
                       or self._resetting_all:
                        # If the id is not in the _env_ids, we can assume this is a fresh start.
                        # OR 
                        # If all the envs are terminated, then we can start doing a reset operation.
                        # OR
                        # If we are currently resetting all the envs anyways
                        # This means we want to reset ALL the envs before doing any steps.
                        self.env_reset(env,env_id)
                        # Move to the next env, eventually we will reset all the envs in sync.
                        # then we will be able to start calling `step` for each of them.
                        # _resetting_all is True when there are envs still "terminated".
                        self._resetting_all = any([self._env_ids[s].terminated for s in self._env_ids])
                        continue 
                    elif self._env_ids[env_id].terminated:
                        continue
                    else:
                        raise ValueError('This else should never happen.')
                else:
                    step = self.env_reset(env,env_id)
            else:
                step = self._env_ids[env_id]

            action = None
            for action in (self.agent([step]) if self.agent is not None else [env.action_space.sample()]):
                next_state,reward,terminated,truncated,_ = env.step(
                    self.agent.augment_actions(action) if self.agent is not None else action
                )

                step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
                    state=tensor(step.next_state),
                    next_state=tensor(next_state),
                    action=tensor(action).float(),
                    terminated=tensor(terminated),
                    truncated=tensor(truncated),
                    reward=tensor(reward),
                    total_reward=step.total_reward+reward,
                    env_id=tensor(env_id),
                    proc_id=tensor(os.getpid()),
                    step_n=step.step_n+1,
                    episode_n=step.episode_n,
                    image=env.render(mode='rgb_array') if self.include_images else torch.FloatTensor([0])
                )
                self._env_ids[env_id] = step
                yield step
                if terminated: break
            if action is None: 
                raise Exception('The agent produced no actions. This should never occur.')
                
add_docs(
    GymStepper,
    """Accepts a `source_datapipe` or iterable whose `next()` produces a single `gym.Env`.
       Tracks multiple envs using `id(env)`.""",
    env_reset="Resets a env given the env_id.",
    no_agent_create_step="If there is no agent for creating the step output, then `GymStepper` will create its own"
)

In [7]:
#|hide
# Used here to avoid UserWarnings related to gym complaining about bounding box / action space format.
# There must be a bug in the CartPole-v1 env that is causing this to show. Also couldnt figure out the 
# regex, so instead we filter on the lineno, which is line 98.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

## Iteration Examples

In [8]:
#|eval: False
import pandas as pd
from fastrl.agents.core import * 

class ConstantRunner(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe,constant=1,array_nestings=0): 
        self.source_datapipe = source_datapipe
        self.agent_base = find_dp(traverse(self.source_datapipe),AgentBase)
        self.constant = constant
        self.array_nestings = array_nestings
    
    def __iter__(self):
        for o in self.source_datapipe: 
            try: 
                if self.array_nestings==0: yield self.constant
                else:
                    yield [self.constant]*self.array_nestings
            except Exception:
                print('Failed on ',o)
                raise

agent = AgentBase(None,[])
agent = ConstantRunner(agent)
agent = AgentHead(agent)

pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,agent=agent,seed=0)

pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.7603), tensor(-0.0866), tensor(-1.2844)]",tensor(1.),tensor(False)


In [9]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,seed=0)

pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",tensor(0.),tensor(False)


In [10]:
from torch.utils.data.dataloader_experimental import DataLoader2

def seed_worker(worker_id): torch.manual_seed(0)

dl = DataLoader2(pipe,num_workers=2,worker_init_fn=seed_worker)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]


In [11]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]


## Tests

We create 3 envs and put a max iteration count at 180. Each env will run for 18 steps before ending, which means
we expect there to be 10 total episodes.

In [12]:
import pandas as pd

envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
pipe = pipe.cycle(count=(18*len(envs))) 
pipe = GymStepper(pipe,seed=0)

All the of the environments should reach max 18 steps given a seed of 0...\
The total number of iterations should be `( 18 * n_envs) * n_episodes_per_env = 162`...

In [13]:
steps = list(pipe)
gsteps = groupby(steps,lambda o:int(o.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140658056028048),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140658055948816),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140658056533712),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140658056028048),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140658055948816),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140658056533712),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140658056028048),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140658055948816),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140658056533712),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140658056028048),tensor(2),tensor(13)


All of the step groups should be the same length...

In [14]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [15]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [16]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [17]:
test_eq(sum([o.terminated for o in steps]),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [18]:
gsteps = groupby(steps,lambda o:int(o.env_id))
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

### Test the `synchronized_reset` param...
> In this case, we will have iterate through the 3 envs without producing a step on warmup.

In [19]:
import pandas as pd

envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
# We add an additional +3 cycles since `synchronized_reset` cycles through the envs additional times
# to make sure they are all reset prior to stepping
pipe = pipe.cycle(count=(18*len(envs))+3) 
pipe = GymStepper(pipe,seed=0,synchronized_reset=True)

In [20]:
steps = list(pipe)
gsteps = groupby(steps,lambda o:int(o.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140658715693264),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140658057039312),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140658715694416),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140658715693264),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140658057039312),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140658715694416),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140658715693264),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140658057039312),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140658715694416),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140658715693264),tensor(2),tensor(13)


All of the step groups should be the same length...

In [21]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [22]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [23]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [24]:
test_eq(sum([o.terminated for o in steps]),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [25]:
gsteps = groupby(steps,lambda o:int(o.env_id))
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

In [26]:
import pandas as pd

envs = ['CartPole-v1']*10

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe,synchronized_reset=True)
steps = list(pipe)

Since the seed is turned off the only properties we are to expect are:
    
    - If an env finishes, no steps from that env should be seen until all 9 of the other envs finish

In [27]:
def synchronized_reset_checker(steps):
    env_id_done_tracker = {}
    did_syncs_happen = False
    for d,env_id,idx in [(bool(o.terminated),int(o.env_id),i) for i,o in enumerate(steps)]:

        if d: 
            env_id_done_tracker[env_id] = idx
            continue

        if env_id in env_id_done_tracker:
            if len(env_id_done_tracker)!=len(envs):
                raise Exception(f'env_id {env_id} was iterated through when it should not have been! idx: {idx}')
        if len(env_id_done_tracker)==len(envs):
            did_syncs_happen = True
            env_id_done_tracker = {}

    if not did_syncs_happen: 
        raise Exception('There should have at least been 1 time where all the envs had to reset, which did not happen.')
synchronized_reset_checker(steps)

For sanity, we should expect that without `synchronized_reset` envs will be reset and stepped through before other 
envs are reset, `synchronized_reset_checker` should fail.

In [28]:
pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe)
steps = list(pipe)

In [29]:
with ExceptionExpected(regex='was iterated through when it should not have been'):
    synchronized_reset_checker(steps)

In [30]:
#|export
def GymTransformBlock(
    agent:DataPipe, # An AgentHead
    seed:Optional[int]=None, # The seed for the gym to use
    # Used by `NStepper`, outputs tuples / chunks of assiciated steps
    nsteps:int=1, 
    # Used by `NSkipper` to skip a certain number of steps (agent still gets called for each)
    nskips:int=1,
    # Whether when nsteps>1 to merge it into a single `StepType`
    firstlast:bool=False,
    # Functions to run once, at the beginning of the pipeline
    type_tfms:Optional[List[Callable]]=None,
    # Functions to run over individual steps before batching
    item_tfms:Optional[List[Callable]]=None,
    # Functions to run over batches (as specified by `bs`)
    batch_tfms:Optional[List[Callable]]=None,
    # The batch size, which is different from `nsteps` in that firstlast will be 
    # run prior to batching, and a batch of steps might come from multiple envs,
    # where nstep is associated with a single env
    bs:int=1,
    # The prefered default is for the pipeline to be infinate, and the learner
    # decides how much to iter. If this is not None, then the pipeline will run for 
    # that number of `n`
    n:Optional[int]=None,
    # Whether to reset all the envs at the same time as opposed to reseting them 
    # the moment an episode ends. 
    synchronized_reset:bool=False,
    # Should be used only for validation / logging, will grab a render of the gym
    # and assign to the `StepType` image field. This data should not be used for training.
    # If it images are needed for training, then you should wrap the env instead. 
    include_images:bool=False,
    # Additional pipelines to insert, replace, remove
    dp_augmentation_fns:Tuple[DataPipeAugmentationFn]=None
) -> TransformBlock:
    "Basic OpenAi gym `DataPipeGraph` with first-last, nstep, and nskip capability"
    def _GymTransformBlock(
        # `source` likely will be an iterable that gets pushed into the pipeline when an 
        # experiment is actually being run.
        source:Any,
        # Any parameters needed for the dataloader
        num_workers:int=0,
        # This param must exist: as_dataloader for the datablock to create dataloaders
        as_dataloader:bool=False
    ) -> DataPipeOrDataLoader:
        _type_tfms = ifnone(type_tfms,GymTypeTransform)
        "This is the function that is actually run by `DataBlock`"
        pipe = dp.map.Mapper(source)
        pipe = TypeTransformer(pipe,_type_tfms)
        pipe = dp.iter.MapToIterConverter(pipe)
        pipe = dp.iter.InMemoryCacheHolder(pipe)
        pipe = pipe.cycle() # Cycle through the envs inf
        pipe = GymStepper(pipe,agent=agent,seed=seed,
                          include_images=include_images,synchronized_reset=synchronized_reset)
        if nskips!=1: pipe = NSkipper(pipe,n=nskips)
        if nsteps!=1:
            pipe = NStepper(pipe,n=nsteps)
            if firstlast:
                pipe = FirstLastMerger(pipe)
            else:
                pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger
        if n is not None: pipe = pipe.header(limit=n)
        pipe = ItemTransformer(pipe,item_tfms)
        pipe  = pipe.batch(batch_size=bs)
        pipe = BatchTransformer(pipe,batch_tfms)
        
        pipe = apply_dp_augmentation_fns(pipe,ifnone(dp_augmentation_fns,()))
        
        if as_dataloader:
            pipe = DataLoader2(
                datapipe=pipe,
                reading_service=PrototypeMultiProcessingReadingService(
                    num_workers = num_workers,
                    protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
                    protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
                    pipe_type = item_input_pipe_type,
                    eventloop = SpawnProcessForDataPipeline
                ) if num_workers>0 else None
            )
        return pipe 
    return _GymTransformBlock

In [34]:
import pandas as pd
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
block = GymTransformBlock(None,nsteps=2,nskips=2,firstlast=True,bs=1,n=100)
pipes = block(envs)
pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(-0.0063), tensor(0.0144), tensor(0.0117), tensor(-0.0126)]",tensor(1.),tensor(False),tensor(1.9900)
1,"[tensor(-0.0009), tensor(0.0114), tensor(0.0244), tensor(-0.0227)]",tensor(0.),tensor(False),tensor(1.9900)
2,"[tensor(0.0410), tensor(-0.0036), tensor(-0.0285), tensor(0.0234)]",tensor(0.),tensor(False),tensor(1.9900)
3,"[tensor(-0.0060), tensor(0.2094), tensor(0.0115), tensor(-0.3016)]",tensor(0.),tensor(False),tensor(1.9900)
4,"[tensor(-0.0006), tensor(-0.1840), tensor(0.0239), tensor(0.2775)]",tensor(1.),tensor(False),tensor(1.9900)
5,"[tensor(0.0409), tensor(-0.1983), tensor(-0.0280), tensor(0.3070)]",tensor(1.),tensor(False),tensor(1.9900)
6,"[tensor(-0.0015), tensor(0.2091), tensor(0.0054), tensor(-0.2962)]",tensor(0.),tensor(False),tensor(1.9900)
7,"[tensor(-0.0041), tensor(0.2054), tensor(0.0293), tensor(-0.2907)]",tensor(1.),tensor(False),tensor(1.9900)
8,"[tensor(0.0369), tensor(0.1926), tensor(-0.0218), tensor(-0.2939)]",tensor(1.),tensor(False),tensor(1.9900)
9,"[tensor(0.0029), tensor(-0.1812), tensor(-0.0006), tensor(0.2907)]",tensor(1.),tensor(False),tensor(1.9900)


In [32]:
import pandas as pd
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
block = GymTransformBlock(None,nsteps=1,nskips=1,firstlast=True,bs=1,n=100)
pipes = block(envs)

pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(-0.0287), tensor(0.0127), tensor(-0.0329), tensor(-0.0367)]",tensor(1.),tensor(False),tensor(1.)
1,"[tensor(-0.0218), tensor(0.0449), tensor(-0.0045), tensor(-0.0490)]",tensor(0.),tensor(False),tensor(1.)
2,"[tensor(0.0354), tensor(-0.0155), tensor(-0.0413), tensor(-0.0165)]",tensor(0.),tensor(False),tensor(1.)
3,"[tensor(-0.0284), tensor(0.2083), tensor(-0.0336), tensor(-0.3396)]",tensor(1.),tensor(False),tensor(1.)
4,"[tensor(-0.0209), tensor(-0.1502), tensor(-0.0055), tensor(0.2423)]",tensor(0.),tensor(False),tensor(1.)
5,"[tensor(0.0351), tensor(-0.2100), tensor(-0.0416), tensor(0.2629)]",tensor(0.),tensor(False),tensor(1.)
6,"[tensor(-0.0243), tensor(0.4039), tensor(-0.0404), tensor(-0.6427)]",tensor(0.),tensor(False),tensor(1.)
7,"[tensor(-0.0239), tensor(-0.3452), tensor(-0.0007), tensor(0.5332)]",tensor(1.),tensor(False),tensor(1.)
8,"[tensor(0.0309), tensor(-0.4045), tensor(-0.0364), tensor(0.5422)]",tensor(1.),tensor(False),tensor(1.)
9,"[tensor(-0.0162), tensor(0.2093), tensor(-0.0532), tensor(-0.3630)]",tensor(1.),tensor(False),tensor(1.)


In [33]:
import pandas as pd
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
block = GymTransformBlock(None,nsteps=2,nskips=1,firstlast=True)
pipes = block.pipe_fn(envs,bs=1,n=100,**block.pipe_fn_kwargs)
pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward']][:50]

AttributeError: 'function' object has no attribute 'pipe_fn'

## Multi Processing

In [None]:
%%writefile ../external_run_scripts/spawn_multiproc.py
import torch
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
       
class PointlessLoop(dp.iter.IterDataPipe):
    def __init__(self,datapipe=None):
        self.datapipe = datapipe
    
    def __iter__(self):
        while True:
            yield torch.LongTensor(4).detach().clone()
            

if __name__=='__main__':
    from torch.multiprocessing import Pool, Process, set_start_method
    try:
         set_start_method('spawn')
    except RuntimeError:
        pass


    pipe = PointlessLoop()
    pipe = pipe.header(limit=10)
    dls = [DataLoader2(pipe,num_workers=1)]
    # Setup the Learner
    print('type: ',type(dls[0]))
    for o in dls[0]:
        print(o)

Overwriting ../external_run_scripts/spawn_multiproc.py


In [None]:
%%python ../external_run_scripts/spawn_multiproc.py
print('hi')

type:  <class 'torch.utils.data.dataloader.DataLoader'>
tensor([[ 732160890645865842, 2314885530818453514, 2336931105441411616,
         7306086968447301733]])
tensor([[8602280404079371623, 8101817809722304878, 2314885530817031781,
         7306844544654385184]])
tensor([[2965946305821897833, 7379464786773281338, 3691363335768469024,
         2830311512315084125]])
tensor([[0, 0, 0, 0]])
tensor([[0, 0, 0, 0]])
tensor([[0, 0, 0, 0]])
tensor([[0, 0, 0, 0]])
tensor([[0, 0, 0, 0]])
tensor([[139733290683744, 139733290683744,               1,  94711839977328]])
tensor([[139733290683744, 139733290683744,               1,  94711839978096]])


In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()