In [None]:
#|hide
#|eval: false
# Only runs where `/content` exists (used here as the Colab check): installs fastrl plus the
# virtual-display dependencies (pyvirtualdisplay, xvfb) that the next cell starts on Colab.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
# `os` is used below; import it explicitly instead of relying on a star import to provide it
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment, except when the `IN_TEST` env var is set
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
#|default_exp envs.gym

In [None]:
#|export
# Python native modules
import os
import warnings
from typing import Callable,Any
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
import gym
import torch
from fastrl.torch_core import *
from torchdata.dataloader2.graph import find_dps,traverse
from fastrl.data.dataloader2 import *
from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator
from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,IterDataPipe,MapDataPipe
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *
from fastrl.pipes.iter.nskip import *
from fastrl.pipes.iter.nstep import *
from fastrl.pipes.iter.firstlast import *
from fastrl.pipes.iter.transforms import *
from fastrl.pipes.map.transforms import *
from fastrl.data.block import *

# Envs Gym
> Fastrl API for working with OpenAI Gyms

### Pipes

In [None]:
#|export
class GymTypeTransform(Transform):
    "Turns a registered gym env id (e.g. `'CartPole-v1'`) into a `gym.Env` created with `render_mode='rgb_array'`"
    def encodes(self,o):
        # `o` is the env id string handed straight to `gym.make`
        env = gym.make(o,render_mode='rgb_array')
        return env

In [None]:
#|export
class GymStepper(dp.iter.IterDataPipe):
    def __init__(self,
        source_datapipe:Union[Iterable,dp.iter.IterDataPipe], # Calling `next()` should produce a `gym.Env`
        agent=None, # Optional `Agent` that accepts a `SimpleStep` to produce a list of actions.
        seed:Optional[int]=None, # Optional seed to set the env to and also random action samples if `agent==None`
        synchronized_reset:bool=False, # If True, finished envs wait until *all* envs finish, then every env is reset before any env steps.
        include_images:bool=False, # Whether to store `env.render()` in each step's `image` field
    ):
        self.source_datapipe = source_datapipe
        self.agent = agent
        self.seed = seed
        self.include_images = include_images
        self.synchronized_reset = synchronized_reset
        # The most recent step produced for each env, keyed by `id(env)`
        self._env_ids = {}
        # True while a synchronized reset pass is underway (some envs reset, others still finished)
        self._resetting_all = False

    @staticmethod
    def _episode_over(step) -> bool:
        "Whether `step` ended its episode, by termination *or* truncation (e.g. a `TimeLimit`)"
        return bool(step.terminated) or bool(step.truncated)

    def _create_step(self,**kwargs):
        "Builds a step via the agent's `create_step`, or `no_agent_create_step` when there is no agent"
        return (self.no_agent_create_step if self.agent is None else self.agent.create_step)(**kwargs)

    def env_reset(self,
      env:gym.Env, # The env to reset
      env_id:int # The key of `env` in `self._env_ids`, i.e. `id(env)`
    ) -> SimpleStep:
        state, _ = env.reset(seed=self.seed)
        env.action_space.seed(seed=self.seed)
        # Continue the episode count of an env seen before, otherwise start at 1
        episode_n = self._env_ids[env_id].episode_n+1 if env_id in self._env_ids else tensor(1)

        step = self._create_step(
            state=tensor(state),
            next_state=tensor(state),
            terminated=tensor(False),
            truncated=tensor(False),
            reward=tensor(0),
            total_reward=tensor(0.),
            env_id=tensor(env_id),
            proc_id=tensor(os.getpid()),
            step_n=tensor(0),
            episode_n=episode_n,
            image=env.render() if self.include_images else torch.FloatTensor([0])
        )
        self._env_ids[env_id] = step
        return step

    def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs)

    def __iter__(self) -> SimpleStep:
        for env in self.source_datapipe:
            assert issubclass(env.__class__,gym.Env),f'Expected subclass of gym.Env, but got {env.__class__}'
            env_id = id(env)
            last_step = self._env_ids.get(env_id)

            if last_step is None or self._episode_over(last_step):
                if self.synchronized_reset:
                    # A synchronized reset never yields a step. This env is reset only if it is
                    # brand new, a reset pass is already underway, or every env has finished;
                    # otherwise it is skipped until the other envs finish their episodes.
                    if last_step is None or self._resetting_all \
                       or all(self._episode_over(s) for s in self._env_ids.values()):
                        self.env_reset(env,env_id)
                        # Keep resetting (and skipping stepping) while any env is still finished.
                        self._resetting_all = any(self._episode_over(s) for s in self._env_ids.values())
                    continue
                step = self.env_reset(env,env_id)
            else:
                step = last_step

            action = None
            actions = self.agent([step]) if self.agent is not None else [env.action_space.sample()]
            for action in actions:
                next_state,reward,terminated,truncated,_ = env.step(
                    self.agent.augment_actions(action) if self.agent is not None else action
                )

                step = self._create_step(
                    state=tensor(step.next_state),
                    next_state=tensor(next_state),
                    action=tensor(action).float(),
                    terminated=tensor(terminated),
                    truncated=tensor(truncated),
                    reward=tensor(reward),
                    total_reward=step.total_reward+reward,
                    env_id=tensor(env_id),
                    proc_id=tensor(os.getpid()),
                    step_n=step.step_n+1,
                    episode_n=step.episode_n,
                    image=env.render() if self.include_images else torch.FloatTensor([0])
                )
                self._env_ids[env_id] = step
                yield step
                # The episode is over either way; further actions would step a finished env.
                if terminated or truncated: break
            if action is None:
                raise Exception('The agent produced no actions. This should never occur.')

add_docs(
    GymStepper,
    """Accepts a `source_datapipe` or iterable whose `next()` produces a single `gym.Env`.
       Tracks multiple envs using `id(env)`. An env is reset once its episode ends, whether
       by termination or by truncation.""",
    env_reset="Resets an env given the env_id.",
    no_agent_create_step="If there is no agent for creating the step output, then `GymStepper` will create its own"
)

In [None]:
#|hide
# Silence the UserWarnings gym emits about the bounding box / action space format (they seem to
# come from the CartPole-v1 env). The warning is matched by its line number (98) rather than by a
# message regex.
# NOTE(review): with no `module=` filter this ignores *every* UserWarning raised at line 98 of any
# module, which may hide unrelated warnings.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

## Iteration Examples

In [None]:
import pandas as pd
from fastrl.agents.core import * 

In [None]:
#|eval: False
class ConstantRunner(dp.iter.IterDataPipe):
    "Stand-in agent model: for every input yields `constant`, or a list of `array_nestings` copies of it"
    def __init__(self,source_datapipe,constant=1,array_nestings=0): 
        self.source_datapipe = source_datapipe
        self.agent_base = find_dp(traverse(self.source_datapipe),AgentBase)
        self.constant = constant
        self.array_nestings = array_nestings

    def __iter__(self):
        for item in self.source_datapipe:
            try:
                yield self.constant if self.array_nestings==0 else [self.constant]*self.array_nestings
            except Exception:
                # Report which input was being processed, then propagate the error
                print('Failed on ',item)
                raise

# Agent graph: AgentBase -> ConstantRunner (always picks action 1) -> AgentHead
agent = AgentBase(None,[])
agent = ConstantRunner(agent)
agent = AgentHead(agent)

pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,agent=agent,seed=0)

pd.DataFrame([step for step,_ in zip(pipe,range(10))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.7603), tensor(-0.0866), tensor(-1.2844)]",tensor(1.),tensor(False)


In [None]:
# Same 3-env CartPole pipeline, but with no agent: `GymStepper` samples a random action from each
# env's action space, which `env_reset` seeds with `seed=0`.
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,seed=0)

# NOTE: `zip` pulls an 11th step from `pipe` before noticing `range(10)` is exhausted, so `pipe`
# has advanced 11 steps by the time the next cell reuses it.
pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",tensor(0.),tensor(False)


In [None]:
from torch.utils.data.dataloader_experimental import DataLoader2

In [None]:
def seed_worker(worker_id):
    "Worker init hook: reseeds torch's global RNG with 0 in every worker (`worker_id` is ignored)"
    torch.manual_seed(0)

dl = DataLoader2(pipe,num_workers=2,worker_init_fn=seed_worker)

pd.DataFrame([step for step,_ in zip(dl,range(10))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]


In [None]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)

# Iterate `pipe` itself so the table reflects `synchronized_reset=True`: every env is reset once
# (yielding nothing) before any env takes its first step.
pd.DataFrame([step for step,_ in zip(pipe,range(10))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[tensor(0.)],[tensor(False)]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[tensor(0.)],[tensor(False)]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[tensor(0.)],[tensor(False)]


## Tests

We create 3 envs and cycle through them 54 times, for 162 iterations in total. Each env will run for 18 steps before terminating, which means
we expect there to be 9 total episodes (3 per env).

In [None]:
envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# Each cycle visits every env once, and with seed=0 every episode lasts 18 steps, so
# 18*n_episodes cycles give each env exactly `n_episodes` episodes (9 in total).
pipe = pipe.cycle(count=18*n_episodes)
pipe = GymStepper(pipe,seed=0)

All of the environments should reach a max of 18 steps given a seed of 0...\
The total number of iterations should be `(18 * n_envs) * n_episodes_per_env = 162`...

In [None]:
steps = list(pipe)
# Bucket the steps by their index within an episode; seed=0 episodes last 18 steps
gsteps = groupby(steps,lambda step:int(step.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame(steps)[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140532810719952),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140532809147408),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140532716950928),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140532810719952),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140532809147408),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140532716950928),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140532810719952),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140532809147408),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140532716950928),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140532810719952),tensor(2),tensor(13)


All of the step groups should be the same length...

In [None]:
group_sz = None
for name,group in gsteps.items():
    # The first group fixes the expected size; every later group must match it
    if group_sz is None:
        group_sz = len(group)
        continue
    assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [None]:
group_sz = None
for name,group in gsteps.items():
    # Every step with the same step_n must match the group's first step
    first,rest = group[0],group[1:]
    for other in rest: test_eq(first.state,other.state)
    for other in rest: test_eq(first.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [None]:
# Compare each group's representative against *every* step of every other group
# (previous check guarantees all steps within a group are equal).
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group: test_ne(e1.state,other.state)
        for other in other_group: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, episodes of 18 steps in length, and 3 episodes each, run for 162 iterations, we should
expect there to be 9 terminated steps.

In [None]:
# One terminated step per episode: 3 envs x 3 episodes
test_eq(sum(o.terminated for o in steps),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [None]:
gsteps = groupby(steps,lambda step:int(step.env_id))
test_len(gsteps.keys(),3)
# Episode numbers seen by each env; their maxima should sum to 9 (3 episodes per env)
env1,env2,env3 = [[int(step.episode_n) for step in group] for group in gsteps.values()]
test_eq(max(env1)+max(env2)+max(env3),9)

### Test the `synchronized_reset` param...
> In this case, we will have to iterate through the 3 envs once during warmup, resetting each of them without producing a step.

In [None]:
envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# Each cycle visits every env once, and with seed=0 every episode lasts 18 steps. With
# `synchronized_reset`, each episode also needs one extra cycle in which all envs are reset
# without yielding a step, hence (18+1) cycles per episode.
pipe = pipe.cycle(count=(18+1)*n_episodes)
pipe = GymStepper(pipe,seed=0,synchronized_reset=True)

In [None]:
steps = list(pipe)
# Bucket the steps by their index within an episode; seed=0 episodes last 18 steps
gsteps = groupby(steps,lambda step:int(step.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame(steps)[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140532715706768),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140532715706256),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140532828444688),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140532715706768),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140532715706256),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140532828444688),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140532715706768),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140532715706256),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140532828444688),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140532715706768),tensor(2),tensor(13)


All of the step groups should be the same length...

In [None]:
group_sz = None
for name,group in gsteps.items():
    # The first group fixes the expected size; every later group must match it
    if group_sz is None:
        group_sz = len(group)
        continue
    assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [None]:
group_sz = None
for name,group in gsteps.items():
    # Every step with the same step_n must match the group's first step
    first,rest = group[0],group[1:]
    for other in rest: test_eq(first.state,other.state)
    for other in rest: test_eq(first.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [None]:
# Compare each group's representative against *every* step of every other group
# (previous check guarantees all steps within a group are equal).
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group: test_ne(e1.state,other.state)
        for other in other_group: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, episodes of 18 steps in length, and 3 episodes each, producing 162 steps, we should
expect there to be 9 terminated steps.

In [None]:
# One terminated step per episode: 3 envs x 3 episodes
test_eq(sum(o.terminated for o in steps),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [None]:
gsteps = groupby(steps,lambda step:int(step.env_id))
test_len(gsteps.keys(),3)
# Episode numbers seen by each env; their maxima should sum to 9 (3 episodes per env)
env1,env2,env3 = [[int(step.episode_n) for step in group] for group in gsteps.values()]
test_eq(max(env1)+max(env2)+max(env3),9)

In [None]:
# 10 unseeded envs with synchronized resets: episode lengths can now differ between envs, so envs
# that finish early must wait for the rest before all of them are reset.
envs = ['CartPole-v1']*10

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe,synchronized_reset=True)
steps = list(pipe)

Since the seed is turned off, the only property we can expect is:

- If an env finishes, no steps from that env should be seen until all 9 of the other envs finish.

In [None]:
def synchronized_reset_checker(
    steps, # Steps in the order produced; each needs `terminated` and `env_id`
    n_envs:int=None # Number of envs in the pipeline; defaults to `len(envs)` (notebook global)
):
    "Raise if an env is stepped after finishing but before all `n_envs` envs have finished"
    if n_envs is None: n_envs = len(envs)
    env_id_done_tracker = {} # env_id -> index of its terminated step, since the last sync
    did_syncs_happen = False
    for idx,o in enumerate(steps):
        d,env_id = bool(o.terminated),int(o.env_id)
        if d:
            env_id_done_tracker[env_id] = idx
            continue

        all_done = len(env_id_done_tracker)==n_envs
        # A finished env may only produce new steps once every env has finished
        if env_id in env_id_done_tracker and not all_done:
            raise Exception(f'env_id {env_id} was iterated through when it should not have been! idx: {idx}')
        if all_done:
            did_syncs_happen = True
            env_id_done_tracker = {}

    if not did_syncs_happen:
        raise Exception('There should have at least been 1 time where all the envs had to reset, which did not happen.')
# Check the synchronized, unseeded 10-env rollout collected above
synchronized_reset_checker(steps)

As a sanity check: without `synchronized_reset`, envs will be reset and stepped through before the other
envs are reset, so `synchronized_reset_checker` should fail.

In [None]:
# Same 10-env unseeded pipeline but without `synchronized_reset`, so each env is reset (and
# stepped again) as soon as its own episode ends.
pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe)
steps = list(pipe)

In [None]:
# Without synchronized resets the checker must catch an env stepping right after it finished
with ExceptionExpected(ex=Exception,regex='was iterated through when it should not have been'):
    synchronized_reset_checker(steps)

In [None]:
#|export
class GymTransformBlock():

    def __init__(self,
        agent:DataPipe, # An AgentHead
        seed:Optional[int]=None, # The seed for the gym to use
        # Used by `NStepper`, outputs tuples / chunks of associated steps
        nsteps:int=1, 
        # Used by `NSkipper` to skip a certain number of steps (agent still gets called for each)
        nskips:int=1,
        # Whether when nsteps>1 to merge it into a single `StepType`
        firstlast:bool=False,
        # Functions to run once, at the beginning of the pipeline
        type_tfms:Optional[List[Callable]]=None,
        # Functions to run over individual steps before batching
        item_tfms:Optional[List[Callable]]=None,
        # Functions to run over batches (as specified by `bs`)
        batch_tfms:Optional[List[Callable]]=None,
        # The batch size, which is different from `nsteps` in that firstlast will be 
        # run prior to batching, and a batch of steps might come from multiple envs,
        # where nstep is associated with a single env
        bs:int=1,
        # The preferred default is for the pipeline to be infinite, and the learner
        # decides how much to iterate. If this is not None, then the pipeline will
        # stop after `n` items
        n:Optional[int]=None,
        # Whether to reset all the envs at the same time as opposed to resetting them 
        # the moment an episode ends. 
        synchronized_reset:bool=False,
        # Should be used only for validation / logging, will grab a render of the gym
        # and assign to the `StepType` image field. This data should not be used for training.
        # If images are needed for training, then you should wrap the env instead. 
        include_images:bool=False,
        # Additional pipelines to insert, replace, remove
        dp_augmentation_fns:Optional[Tuple[DataPipeAugmentationFn]]=None
    ) -> None:
        "Basic OpenAi gym `DataPipeGraph` with first-last, nstep, and nskip capability"
        store_attr() # Stores every argument above (including `agent`) as an attribute

    def __call__(
        self,
        # `source` likely will be an iterable that gets pushed into the pipeline when an 
        # experiment is actually being run.
        source:Any,
        # Any parameters needed for the dataloader
        num_workers:int=0,
        # This param must exist: as_dataloader for the datablock to create dataloaders
        as_dataloader:bool=False
    ) -> DataPipeOrDataLoader:
        "This is the function that is actually run by `DataBlock`"
        _type_tfms = ifnone(self.type_tfms,[GymTypeTransform])
        pipe = dp.map.Mapper(source)
        pipe = TypeTransformer(pipe,_type_tfms)
        pipe = dp.iter.MapToIterConverter(pipe)
        pipe = dp.iter.InMemoryCacheHolder(pipe)
        pipe = pipe.cycle() # Cycle through the envs indefinitely
        pipe = GymStepper(pipe,agent=self.agent,seed=self.seed,
                          include_images=self.include_images,
                          synchronized_reset=self.synchronized_reset)
        if self.nskips!=1: pipe = NSkipper(pipe,n=self.nskips)
        if self.nsteps!=1:
            pipe = NStepper(pipe,n=self.nsteps)
            # FirstLastMerger consumes the nstep chunks directly, so only flatten when not merging
            pipe = FirstLastMerger(pipe) if self.firstlast else NStepFlattener(pipe)
        if self.n is not None: pipe = pipe.header(limit=self.n)
        pipe = ItemTransformer(pipe,self.item_tfms)
        pipe = pipe.batch(batch_size=self.bs)
        pipe = BatchTransformer(pipe,self.batch_tfms)

        pipe = apply_dp_augmentation_fns(pipe,ifnone(self.dp_augmentation_fns,()))

        if as_dataloader:
            # NOTE(review): expects torchdata's `DataLoader2`; in this notebook a later cell rebinds
            # `DataLoader2` to `torch.utils.data.dataloader_experimental.DataLoader2`, which may not
            # accept `reading_service`.
            pipe = DataLoader2(
                datapipe=pipe,
                reading_service=PrototypeMultiProcessingReadingService(
                    num_workers = num_workers,
                    protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
                    protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
                    pipe_type = item_input_pipe_type,
                    eventloop = SpawnProcessForDataPipeline
                ) if num_workers>0 else None
            )
        return pipe

In [None]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
# nsteps=2 merged first-last, with nskips=2; no seed is passed, so the table changes between runs
block = GymTransformBlock(None,nsteps=2,nskips=2,firstlast=True,bs=1,n=100)
pipes = block(envs)
# `bs=1`, so every batch holds a single step
first_steps = [batch[0] for batch in pipes]
pd.DataFrame(first_steps)[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(0.0281), tensor(0.0010), tensor(0.0271), tensor(0.0137)]",tensor(1.),tensor(False),tensor(1.9900)
1,"[tensor(0.0087), tensor(0.0049), tensor(0.0105), tensor(0.0164)]",tensor(1.),tensor(False),tensor(1.9900)
2,"[tensor(0.0029), tensor(-0.0071), tensor(-0.0149), tensor(-0.0203)]",tensor(1.),tensor(False),tensor(1.9900)
3,"[tensor(0.0281), tensor(0.1957), tensor(0.0274), tensor(-0.2703)]",tensor(0.),tensor(False),tensor(1.9900)
4,"[tensor(0.0088), tensor(0.1999), tensor(0.0108), tensor(-0.2730)]",tensor(1.),tensor(False),tensor(1.9900)
5,"[tensor(0.0028), tensor(0.1882), tensor(-0.0153), tensor(-0.3176)]",tensor(0.),tensor(False),tensor(1.9900)
6,"[tensor(0.0320), tensor(-0.1952), tensor(0.0226), tensor(0.3305)]",tensor(1.),tensor(False),tensor(1.9900)
7,"[tensor(0.0207), tensor(0.5899), tensor(-0.0059), tensor(-0.8532)]",tensor(1.),tensor(False),tensor(1.9900)
8,"[tensor(0.0064), tensor(0.1888), tensor(-0.0223), tensor(-0.3293)]",tensor(0.),tensor(False),tensor(1.9900)
9,"[tensor(0.0281), tensor(0.1943), tensor(0.0301), tensor(-0.2383)]",tensor(1.),tensor(False),tensor(1.9900)


In [None]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
# With nsteps=1 the NStepper/FirstLastMerger stage is skipped, so `firstlast=True` has no effect here
block = GymTransformBlock(None,nsteps=1,nskips=1,firstlast=True,bs=1,n=100)
pipes = block(envs)

# `bs=1`, so every batch holds a single step
first_steps = [batch[0] for batch in pipes]
pd.DataFrame(first_steps)[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(-0.0469), tensor(-0.0038), tensor(0.0296), tensor(0.0157)]",tensor(0.),tensor(False),tensor(1.)
1,"[tensor(0.0044), tensor(-0.0130), tensor(0.0411), tensor(0.0286)]",tensor(0.),tensor(False),tensor(1.)
2,"[tensor(0.0370), tensor(0.0482), tensor(0.0339), tensor(-0.0419)]",tensor(1.),tensor(False),tensor(1.)
3,"[tensor(-0.0470), tensor(-0.1994), tensor(0.0299), tensor(0.3176)]",tensor(1.),tensor(False),tensor(1.)
4,"[tensor(0.0042), tensor(-0.2087), tensor(0.0417), tensor(0.3340)]",tensor(1.),tensor(False),tensor(1.)
5,"[tensor(0.0379), tensor(0.2428), tensor(0.0331), tensor(-0.3237)]",tensor(0.),tensor(False),tensor(1.)
6,"[tensor(-0.0510), tensor(-0.0047), tensor(0.0362), tensor(0.0345)]",tensor(0.),tensor(False),tensor(1.)
7,"[tensor(9.4943e-06), tensor(-0.0142), tensor(0.0483), tensor(0.0547)]",tensor(0.),tensor(False),tensor(1.)
8,"[tensor(0.0428), tensor(0.0473), tensor(0.0266), tensor(-0.0208)]",tensor(0.),tensor(False),tensor(1.)
9,"[tensor(-0.0511), tensor(-0.2003), tensor(0.0369), tensor(0.3384)]",tensor(0.),tensor(False),tensor(1.)


In [None]:
envs = ['CartPole-v1']*3
# nsteps=2 merged first-last, no skipping
block = GymTransformBlock(None,nsteps=2,nskips=1,firstlast=True,bs=1,n=100)
pipes = block(envs)
# `bs=1`, so every batch holds a single step
first_steps = [batch[0] for batch in pipes]
pd.DataFrame(first_steps)[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(0.0226), tensor(-0.0454), tensor(-0.0444), tensor(-0.0308)]",tensor(1.),tensor(False),tensor(1.9900)
1,"[tensor(-0.0258), tensor(-0.0329), tensor(-0.0368), tensor(-0.0369)]",tensor(1.),tensor(False),tensor(1.9900)
2,"[tensor(0.0124), tensor(0.0284), tensor(0.0162), tensor(-0.0412)]",tensor(0.),tensor(False),tensor(1.9900)
3,"[tensor(0.0217), tensor(0.1504), tensor(-0.0450), tensor(-0.3371)]",tensor(0.),tensor(False),tensor(1.9900)
4,"[tensor(-0.0264), tensor(0.1627), tensor(-0.0375), tensor(-0.3410)]",tensor(0.),tensor(False),tensor(1.9900)
5,"[tensor(0.0129), tensor(-0.1669), tensor(0.0153), tensor(0.2566)]",tensor(0.),tensor(False),tensor(1.9900)
6,"[tensor(0.0247), tensor(-0.0441), tensor(-0.0517), tensor(-0.0589)]",tensor(0.),tensor(False),tensor(1.9900)
7,"[tensor(-0.0232), tensor(-0.0318), tensor(-0.0443), tensor(-0.0604)]",tensor(1.),tensor(False),tensor(1.9900)
8,"[tensor(0.0096), tensor(-0.3623), tensor(0.0205), tensor(0.5541)]",tensor(0.),tensor(False),tensor(1.9900)
9,"[tensor(0.0239), tensor(-0.2384), tensor(-0.0529), tensor(0.2170)]",tensor(0.),tensor(False),tensor(1.9900)


## Multi Processing

In [None]:
%%writefile ../external_run_scripts/spawn_multiproc.py
"""Minimal check that an infinite datapipe can be consumed through `DataLoader2`
using a worker process started with the `spawn` method.

Run as a script (``python spawn_multiproc.py``). The ``__main__`` guard is
required because `spawn` re-imports this module in each worker process.
"""
import torch
import torchdata.datapipes as dp
# `torch.utils.data.dataloader_experimental` no longer exists in torch;
# DataLoader2 lives in torchdata (the same import the notebook itself uses).
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService


class PointlessLoop(dp.iter.IterDataPipe):
    "Endlessly yields a small LongTensor; exists only to exercise worker processes."
    def __init__(self, datapipe=None):
        # Stored but never read; kept so the signature mirrors a typical datapipe.
        self.datapipe = datapipe

    def __iter__(self):
        while True:
            # `torch.LongTensor(4)` allocates *uninitialized* memory, which leaked
            # arbitrary heap contents into the output; yield a deterministic value.
            yield torch.zeros(4, dtype=torch.long)


if __name__ == '__main__':
    from torch.multiprocessing import set_start_method
    try:
        set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per process; keep the existing one.
        pass

    # Cap the otherwise infinite stream at 10 items.
    pipe = PointlessLoop().header(limit=10)
    dl = DataLoader2(pipe, reading_service=MultiProcessingReadingService(num_workers=1))
    print('type: ', type(dl))
    try:
        for o in dl:
            print(o)
    finally:
        # Stop the worker process started by the reading service.
        dl.shutdown()

Overwriting ../external_run_scripts/spawn_multiproc.py


In [None]:
%%python ../external_run_scripts/spawn_multiproc.py
print('hi')

type:  <class 'torch.utils.data.dataloader.DataLoader'>
tensor([[140433870888096, 140433870888096,               0,               0]])
tensor([[140433870888272, 140433870888272, 140433881812912, 140433882656048]])
tensor([[4816656997978587317,     140431723606784,     140431722231568,
         5543175674065308325]])
tensor([[7810759217884700704, 8313414454589354784,  751086762754076261,
         2314885530818453536]])
tensor([[7380385434628089195, 7308049669733556268, 2314885530451061113,
          731596513467179040]])
tensor([[2314885530449815335, 8463501003136704544, 7811887317905337970,
         7598244868600897637]])
tensor([[4212112949293312288, 2314885530818453514, 2334669043537027104,
         4404647383869256819]])
tensor([[6877116887384392576, 5767527136280273261, 8319104453167704431,
         2954485645205008204]])
tensor([[ 94183108464848,  94183087009168, 140431950437552,               2]])
tensor([[ 94183108468672,  94183087009168, 140431718197584,               2]])


In [None]:
#|hide
#|eval: false
# Regenerate the library modules from this notebook's `#|export` cells.
# Colab still pins tornado<6, so nbdev is only imported and run off-Colab.
from fastcore.imports import in_colab

if not in_colab():
    from nbdev import nbdev_export

    nbdev_export()