In [1]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [2]:
#|default_exp envs.gym

In [3]:
#|export
# Python native modules
import os
import warnings
from functools import partial
from typing import Callable, Any, Union, Iterable, Optional
# Third party libs
import gymnasium as gym
import torch
# from fastrl.torch_core import *
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe,traverse_dps
from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
# Local modules
from fastrl.core import StepType,SimpleStep
from fastrl.pipes.core import find_dps
from fastrl.pipes.iter.nskip import NSkipper
from fastrl.pipes.iter.nstep import NStepper,NStepFlattener
from fastrl.pipes.iter.firstlast import FirstLastMerger
import fastrl.pipes.iter.cacheholder

# Envs Gym
> Fastrl API for working with OpenAI Gyms

### Pipes

In [4]:
#|export
class GymStepper(dp.iter.IterDataPipe):
    def __init__(self,
        source_datapipe:Union[Iterable,dp.iter.IterDataPipe], # Calling `next()` should produce a `gym.Env`
        agent=None, # Optional `Agent` that accepts a `SimpleStep` to produce a list of actions.
        seed:int=None, # Optional seed to set the env to and also random action sames if `agent==None`
        synchronized_reset:bool=False, # Some `gym.Envs` require reset to be terminated on *all* envs before proceeding to step.
        include_images:bool=False, # Render images from the environment
        terminate_on_truncation:bool=True
    ):
        self.source_datapipe = source_datapipe
        self.agent = agent
        self.seed = seed
        self.include_images = include_images
        self.synchronized_reset = synchronized_reset
        self.terminate_on_truncation = terminate_on_truncation
        self._env_ids = {}
        
    def env_reset(self,
      env:gym.Env, # The env to rest along with its numeric object id
      env_id:int # Resets env in `self._env_ids[env_id]`
    ) -> StepType:
        state, info = env.reset(seed=self.seed)
        env.action_space.seed(seed=self.seed)
        episode_n = self._env_ids[env_id].episode_n+1 if env_id in self._env_ids else torch.tensor(1)

        step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
            state=torch.tensor(state),
            next_state=torch.tensor(state),
            terminated=torch.tensor(False),
            truncated=torch.tensor(False),
            reward=torch.tensor(0),
            total_reward=torch.tensor(0.),
            env_id=torch.tensor(env_id),
            proc_id=torch.tensor(os.getpid()),
            step_n=torch.tensor(0),
            episode_n=episode_n,
            # image=env.render(mode='rgb_array') if self.include_images else torch.FloatTensor([0])
            image=torch.tensor(env.render()) if self.include_images else torch.FloatTensor([0]),
            raw_action=torch.FloatTensor([0])
        )
        self._env_ids[env_id] = step
        return step
    
    def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs)

    def __iter__(self) -> SimpleStep:
        for env in self.source_datapipe:
            assert issubclass(env.__class__,gym.Env),f'Expected subclass of gym.Env, but got {env.__class__}'    
            env_id = id(env)
            
            if env_id not in self._env_ids or self._env_ids[env_id].terminated:
                if self.synchronized_reset:
                    if env_id in self._env_ids \
                       and not self._env_ids[env_id].terminated \
                       and self._resetting_all:
                        # If this env has already been reset, and we are currently in the 
                        # self._resetting_all phase, then skip this so we can reset all remaining envs
                        continue
                    elif env_id not in self._env_ids \
                       or all([self._env_ids[s].terminated for s in self._env_ids])\
                       or self._resetting_all:
                        # If the id is not in the _env_ids, we can assume this is a fresh start.
                        # OR 
                        # If all the envs are terminated, then we can start doing a reset operation.
                        # OR
                        # If we are currently resetting all the envs anyways
                        # This means we want to reset ALL the envs before doing any steps.
                        self.env_reset(env,env_id)
                        # Move to the next env, eventually we will reset all the envs in sync.
                        # then we will be able to start calling `step` for each of them.
                        # _resetting_all is True when there are envs still "terminated".
                        self._resetting_all = any([self._env_ids[s].terminated for s in self._env_ids])
                        continue 
                    elif self._env_ids[env_id].terminated:
                        continue
                    else:
                        raise ValueError('This else should never happen.')
                else:
                    step = self.env_reset(env,env_id)
            else:
                step = self._env_ids[env_id]

            action = None
            raw_action = None
            for action in (self.agent([step]) if self.agent is not None else [env.action_space.sample()]):
                if isinstance(action,tuple):
                    action, raw_action = action
                next_state,reward,terminated,truncated,_ = env.step(
                    self.agent.augment_actions(action) if self.agent is not None else action
                )

                if self.terminate_on_truncation and truncated: terminated = True

                step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
                    state=step.next_state.clone().detach(),
                    next_state=torch.tensor(next_state),
                    action=torch.tensor(action).float(),
                    terminated=torch.tensor(terminated),
                    truncated=torch.tensor(truncated),
                    reward=torch.tensor(reward),
                    total_reward=step.total_reward+reward,
                    env_id=torch.tensor(env_id),
                    proc_id=torch.tensor(os.getpid()),
                    step_n=step.step_n+1,
                    episode_n=step.episode_n,
                    # image=env.render(mode='rgb_array') if self.include_images else torch.FloatTensor([0])
                    image=torch.tensor(env.render()) if self.include_images else torch.FloatTensor([0]),
                    raw_action=raw_action
                )
                self._env_ids[env_id] = step
                yield step
                if terminated: break
            if action is None: 
                raise Exception('The agent produced no actions. This should never occur.')
                
add_docs(
GymStepper,
"""Accepts a `source_datapipe` or iterable whose `next()` produces a single `gym.Env`.
    Tracks multiple envs using `id(env)`.""",
env_reset="Resets a env given the env_id.",
no_agent_create_step="If there is no agent for creating the step output, then `GymStepper` will create its own",
reset="Resets the env's back to original str types to avoid pickling issues."
)

In [5]:
#|hide
# Used here to avoid UserWarnings related to gym complaining about bounding box / action space format.
# There must be a bug in the CartPole-v1 env that is causing this to show. Also couldnt figure out the 
# regex, so instead we filter on the lineno, which is line 98.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

## Iteration Examples

In [6]:
import pandas as pd
from fastrl.agents.core import * 

In [7]:
#|eval: False
class ConstantRunner(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe,constant=1,array_nestings=0): 
        self.source_datapipe = source_datapipe
        self.agent_base = find_dps(traverse_dps(self.source_datapipe),AgentBase)
        self.constant = constant
        self.array_nestings = array_nestings
    
    def __iter__(self):
        for o in self.source_datapipe: 
            try: 
                if self.array_nestings==0: yield self.constant
                else:
                    yield [self.constant]*self.array_nestings
            except Exception:
                print('Failed on ',o)
                raise

agent = AgentBase(None,[])
agent = ConstantRunner(agent)
agent = AgentHead(agent)

pipe = dp.iter.IterableWrapper(['CartPole-v1']*3)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
# pipe = TypeTransformer(pipe,[GymTypeTransform])
# pipe = dp.iter.MapToIterConverter(pipe)
pipe = pipe.pickleable_in_memory_cache()
pipe = pipe.cycle()
pipe = GymStepper(pipe,agent=agent,seed=0)

pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.7603), tensor(-0.0866), tensor(-1.2844)]",tensor(1.),tensor(False)


In [8]:
from functools import partial

In [9]:
pipe = dp.iter.IterableWrapper(['CartPole-v1']*3)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
# pipe = TypeTransformer(pipe,[GymTypeTransform])
# pipe = dp.iter.MapToIterConverter(pipe)
pipe = pipe.pickleable_in_memory_cache()
pipe = pipe.cycle()
pipe = GymStepper(pipe,seed=0,include_images=True)

pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','terminated','env_id','step_n']]

Unnamed: 0,state,next_state,action,terminated,env_id,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516764166176),tensor(1)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516764761824),tensor(1)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516715729248),tensor(1)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516764166176),tensor(2)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516764761824),tensor(2)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516715729248),tensor(2)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516764166176),tensor(3)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516764761824),tensor(3)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516715729248),tensor(3)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",tensor(0.),tensor(False),tensor(140516764166176),tensor(4)


In [10]:
from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService

In [11]:
def seed_worker(pipe,worker_info): 
    torch.manual_seed(0)
    return pipe

dl = DataLoader2(pipe,reading_service=MultiProcessingReadingService(
        num_workers = 0,
        worker_init_fn=seed_worker
    )
)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','terminated','env_id']]



Unnamed: 0,state,next_state,action,terminated,env_id
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516715918912)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516715919632)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(140516715935344)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516715918912)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516715919632)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(1.),tensor(False),tensor(140516715935344)
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516715918912)
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516715919632)
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(1.),tensor(False),tensor(140516715935344)
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",tensor(0.),tensor(False),tensor(140516715918912)


In [12]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','terminated']]

Unnamed: 0,state,next_state,action,terminated
0,"[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]","[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(0.),tensor(False)
1,"[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]","[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(0.),tensor(False)
2,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",tensor(0.),tensor(False)
3,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]","[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(0.),tensor(False)
4,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]","[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(0.),tensor(False)
5,"[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]","[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(0.),tensor(False)
6,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]","[tensor(0.0459), tensor(-0.2106), tensor(-0.1129), tensor(0.0792)]",tensor(0.),tensor(False)
7,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]","[tensor(0.0459), tensor(-0.2106), tensor(-0.1129), tensor(0.0792)]",tensor(0.),tensor(False)
8,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]","[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(0.),tensor(False)
9,"[tensor(0.0459), tensor(-0.2106), tensor(-0.1129), tensor(0.0792)]","[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(0.),tensor(False)


## Tests

We create 3 envs and put a max iteration count at 180. Each env will run for 18 steps before ending, which means
we expect there to be 10 total episodes.

In [13]:
envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
pipe = pipe.cycle(count=(18*len(envs))) 
pipe = GymStepper(pipe,seed=0)

All the of the environments should reach max 18 steps given a seed of 0...\
The total number of iterations should be `( 18 * n_envs) * n_episodes_per_env = 162`...

In [14]:
from fastrl.core import test_len
from fastcore.all import test_eq,test_ne
from itertools import groupby

In [15]:
steps = list(pipe)
gsteps = dict(groupby(steps,lambda o:int(o.step_n)))
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140519661545936),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140516715889760),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140519661546080),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140519661545936),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140516715889760),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140519661546080),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140519661545936),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140516715889760),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140519661546080),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140519661545936),tensor(2),tensor(13)


All of the step groups should be the same length...

In [16]:
group_sz = None
gsteps = {k:list(v) for k,v in groupby(steps,lambda o:int(o.step_n))}
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [17]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [18]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [19]:
import torch
from fastcore.all import L,Self

In [20]:
test_eq(sum([o.terminated for o in steps]),torch.tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [21]:
gsteps = {k:list(v) for k,v in groupby(steps,lambda o:int(o.env_id))}
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

### Test the `synchronized_reset` param...
> In this case, we will have iterate through the 3 envs without producing a step on warmup.

In [22]:
envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
# We add an additional +3 cycles since `synchronized_reset` cycles through the envs additional times
# to make sure they are all reset prior to stepping
pipe = pipe.cycle(count=(18*len(envs))+3) 
pipe = GymStepper(pipe,seed=0,synchronized_reset=True)

In [23]:
steps = list(pipe)
gsteps = {k:list(v) for k,v in groupby(steps,lambda o:int(o.step_n))}
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','terminated','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,terminated,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(False),tensor(140519661356656),tensor(1),tensor(1)
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(False),tensor(140516764166320),tensor(1),tensor(4)
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",tensor(False),tensor(140516764166752),tensor(1),tensor(7)
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",tensor(False),tensor(140519661356656),tensor(1),tensor(11)
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",tensor(False),tensor(140516764166320),tensor(1),tensor(14)
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",tensor(False),tensor(140516764166752),tensor(1),tensor(17)
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",tensor(False),tensor(140519661356656),tensor(2),tensor(3)
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(False),tensor(140516764166320),tensor(2),tensor(6)
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",tensor(False),tensor(140516764166752),tensor(2),tensor(9)
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",tensor(False),tensor(140519661356656),tensor(2),tensor(13)


All of the step groups should be the same length...

In [24]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [25]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [26]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [27]:
test_eq(sum([o.terminated for o in steps]),torch.tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [28]:
# gsteps = groupby(steps,lambda o:int(o.env_id))
gsteps = {k:list(v) for k,v in groupby(steps,lambda o:int(o.env_id))}
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

In [29]:
envs = ['CartPole-v1']*10

pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe,synchronized_reset=True)
steps = list(pipe)

Since the seed is turned off the only properties we are to expect are:
    
    - If an env finishes, no steps from that env should be seen until all 9 of the other envs finish

In [30]:
def synchronized_reset_checker(steps):
    env_id_done_tracker = {}
    did_syncs_happen = False
    for d,env_id,idx in [(bool(o.terminated),int(o.env_id),i) for i,o in enumerate(steps)]:

        if d: 
            env_id_done_tracker[env_id] = idx
            continue

        if env_id in env_id_done_tracker:
            if len(env_id_done_tracker)!=len(envs):
                raise Exception(f'env_id {env_id} was iterated through when it should not have been! idx: {idx}')
        if len(env_id_done_tracker)==len(envs):
            did_syncs_happen = True
            env_id_done_tracker = {}

    if not did_syncs_happen: 
        raise Exception('There should have at least been 1 time where all the envs had to reset, which did not happen.')
synchronized_reset_checker(steps)

For sanity, we should expect that without `synchronized_reset` envs will be reset and stepped through before other 
envs are reset, `synchronized_reset_checker` should fail.

In [31]:
pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe)
steps = list(pipe)

In [32]:
from fastcore.test import ExceptionExpected

In [33]:
with ExceptionExpected(regex='was iterated through when it should not have been'):
    synchronized_reset_checker(steps)

In [34]:
#|export
def GymDataPipe(
    source,
    agent:DataPipe=None, # An AgentHead
    seed:Optional[int]=None, # The seed for the gym to use
    # Used by `NStepper`, outputs tuples / chunks of assiciated steps
    nsteps:int=1, 
    # Used by `NSkipper` to skip a certain number of steps (agent still gets called for each)
    nskips:int=1,
    # Whether when nsteps>1 to merge it into a single `StepType`
    firstlast:bool=False,
    # The batch size, which is different from `nsteps` in that firstlast will be 
    # run prior to batching, and a batch of steps might come from multiple envs,
    # where nstep is associated with a single env
    bs:int=1,
    # The prefered default is for the pipeline to be infinate, and the learner
    # decides how much to iter. If this is not None, then the pipeline will run for 
    # that number of `n`
    n:Optional[int]=None,
    # Whether to reset all the envs at the same time as opposed to reseting them 
    # the moment an episode ends. 
    synchronized_reset:bool=False,
    # Should be used only for validation / logging, will grab a render of the gym
    # and assign to the `StepType` image field. This data should not be used for training.
    # If it images are needed for training, then you should wrap the env instead. 
    include_images:bool=False,
    # If an environment truncates, terminate it.
    terminate_on_truncation:bool=True
) -> Callable:
    "Basic `gymnasium` `DataPipeGraph` with first-last, nstep, and nskip capability"
    pipe = dp.iter.IterableWrapper(source)
    if include_images:
        pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
    else:
        pipe = pipe.map(gym.make)
    # pipe = dp.iter.InMemoryCacheHolder(pipe)
    pipe = pipe.pickleable_in_memory_cache()
    pipe = pipe.cycle() # Cycle through the envs inf
    pipe = GymStepper(pipe,agent=agent,seed=seed,
                        include_images=include_images,
                        terminate_on_truncation=terminate_on_truncation,
                        synchronized_reset=synchronized_reset)
    if nskips!=1: pipe = NSkipper(pipe,n=nskips)
    if nsteps!=1:
        pipe = NStepper(pipe,n=nsteps)
        if firstlast:
            pipe = FirstLastMerger(pipe)
        else:
            pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger
    if n is not None: pipe = pipe.header(limit=n)
    pipe  = pipe.batch(batch_size=bs)
    return pipe

In [35]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,None,seed=0,nsteps=2,nskips=2,firstlast=True,bs=1,n=100)
pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward','step_n']][:50]

Unnamed: 0,state,action,terminated,reward,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(1.),tensor(False),tensor(1.9900),tensor(2)
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(1.),tensor(False),tensor(1.9900),tensor(2)
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",tensor(1.),tensor(False),tensor(1.9900),tensor(2)
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(1.9900),tensor(4)
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(1.9900),tensor(4)
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",tensor(1.),tensor(False),tensor(1.9900),tensor(4)
6,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(0.),tensor(False),tensor(1.9900),tensor(6)
7,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(0.),tensor(False),tensor(1.9900),tensor(6)
8,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",tensor(0.),tensor(False),tensor(1.9900),tensor(6)
9,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",tensor(0.),tensor(False),tensor(1.9900),tensor(8)


In [36]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,None,nsteps=1,nskips=1,firstlast=True,bs=1,n=100)


pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(0.0024), tensor(-0.0011), tensor(-0.0332), tensor(-0.0257)]",tensor(0.),tensor(False),tensor(1.)
1,"[tensor(0.0215), tensor(0.0402), tensor(-0.0449), tensor(0.0047)]",tensor(1.),tensor(False),tensor(1.)
2,"[tensor(0.0430), tensor(-0.0106), tensor(0.0240), tensor(-0.0464)]",tensor(0.),tensor(False),tensor(1.)
3,"[tensor(0.0024), tensor(-0.1958), tensor(-0.0337), tensor(0.2564)]",tensor(0.),tensor(False),tensor(1.)
4,"[tensor(0.0223), tensor(0.2359), tensor(-0.0448), tensor(-0.3018)]",tensor(0.),tensor(False),tensor(1.)
5,"[tensor(0.0428), tensor(-0.2061), tensor(0.0231), tensor(0.2538)]",tensor(1.),tensor(False),tensor(1.)
6,"[tensor(-0.0015), tensor(-0.3904), tensor(-0.0286), tensor(0.5382)]",tensor(0.),tensor(False),tensor(1.)
7,"[tensor(0.0270), tensor(0.0414), tensor(-0.0508), tensor(-0.0236)]",tensor(1.),tensor(False),tensor(1.)
8,"[tensor(0.0387), tensor(-0.0113), tensor(0.0281), tensor(-0.0316)]",tensor(1.),tensor(False),tensor(1.)
9,"[tensor(-0.0093), tensor(-0.5851), tensor(-0.0178), tensor(0.8218)]",tensor(0.),tensor(False),tensor(1.)


In [37]:
envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,None,nsteps=2,nskips=1,firstlast=True,bs=1,n=100)

pd.DataFrame([o[0] for o in pipes])[['state','action','terminated','reward']][:50]

Unnamed: 0,state,action,terminated,reward
0,"[tensor(-0.0069), tensor(0.0185), tensor(-0.0397), tensor(0.0332)]",tensor(0.),tensor(False),tensor(1.9900)
1,"[tensor(0.0189), tensor(-0.0227), tensor(0.0293), tensor(-0.0298)]",tensor(1.),tensor(False),tensor(1.9900)
2,"[tensor(0.0235), tensor(-0.0195), tensor(-0.0209), tensor(0.0408)]",tensor(1.),tensor(False),tensor(1.9900)
3,"[tensor(-0.0066), tensor(-0.1760), tensor(-0.0391), tensor(0.3130)]",tensor(1.),tensor(False),tensor(1.9900)
4,"[tensor(0.0184), tensor(0.1720), tensor(0.0288), tensor(-0.3131)]",tensor(1.),tensor(False),tensor(1.9900)
5,"[tensor(0.0232), tensor(0.1760), tensor(-0.0201), tensor(-0.2584)]",tensor(1.),tensor(False),tensor(1.9900)
6,"[tensor(-0.0101), tensor(0.0197), tensor(-0.0328), tensor(0.0083)]",tensor(0.),tensor(False),tensor(1.9900)
7,"[tensor(0.0219), tensor(0.3667), tensor(0.0225), tensor(-0.5966)]",tensor(1.),tensor(False),tensor(1.9900)
8,"[tensor(0.0267), tensor(0.3714), tensor(-0.0252), tensor(-0.5573)]",tensor(1.),tensor(False),tensor(1.9900)
9,"[tensor(-0.0097), tensor(-0.1750), tensor(-0.0327), tensor(0.2905)]",tensor(0.),tensor(False),tensor(1.9900)


## Multi Processing

In [38]:
from torchdata.dataloader2 import DataLoader2

In [39]:
%%writefile ../external_run_scripts/spawn_multiproc.py
import torch
import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService
       
class PointlessLoop(dp.iter.IterDataPipe):
    def __init__(self,datapipe=None):
        self.datapipe = datapipe
    
    def __iter__(self):
        while True:
            yield torch.LongTensor(4).detach().clone()
            

if __name__=='__main__':
    from torch.multiprocessing import Pool, Process, set_start_method
    try:
         set_start_method('spawn')
    except RuntimeError:
        pass


    pipe = PointlessLoop()
    pipe = pipe.header(limit=10)
    dls = [DataLoader2(pipe,
            reading_service=MultiProcessingReadingService(
                num_workers = 2
            ))]
    # Setup the Learner
    print('type: ',type(dls[0]))
    for o in dls[0]:
        print(o)

Overwriting ../external_run_scripts/spawn_multiproc.py


In [40]:
%%python ../external_run_scripts/spawn_multiproc.py
pass

type:  <class 'torchdata.dataloader2.dataloader2.DataLoader2'>
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])
tensor([0, 0, 0, 0])


In [41]:
#|hide
#|eval: false
!nbdev_export

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
