In [1]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
# default_exp envs.gym

In [4]:
# export
# Python native modules
import os
import warnings
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
import gym
from fastai.torch_basics import *
from fastai.torch_core import *
# Local modules
from fastrl.core import *
from fastrl.fastai.data.pipes.core import *
from fastrl.fastai.data.load import *
from fastrl.fastai.data.block import *
from fastrl.envs.core import *

# Envs Gym
> Fastrl API for working with OpenAI Gyms

### Pipes

In [5]:
# export
class GymTypeTransform(Transform):
    "Creates an gym.env"
    def encodes(self,o): return gym.make(o)


In [248]:
# export
class GymStepper(dp.iter.IterDataPipe):
    def __init__(self,
                 source_datapipe:Union[Iterable,dp.iter.IterDataPipe], # Calling `next()` should produce a `gym.Env`
                 agent=None, # Optional `Agent` that accepts a `SimpleStep` to produce a list of actions.
                 seed=None, # Optional seed to set the env to and also random action sames if `agent==None`
                 synchronized_reset=False # Some `gym.Envs` require reset to be done on *all* envs before proceeding to step.
                ):
        self.source_datapipe = source_datapipe
        self.agent = agent
        self.seed = seed
        self.synchronized_reset = synchronized_reset
        self._env_ids = {}
        
    def env_reset(self,env,env_id) -> SimpleStep:
        state = env.reset(seed=self.seed)
        env.action_space.seed(seed=self.seed)
        episode_n = self._env_ids[env_id].episode_n+1 if env_id in self._env_ids else tensor([1])

        step = SimpleStep(
            state=tensor(state),
            next_state=tensor(state),
            done=tensor(False),
            reward=tensor([0]),
            total_reward=tensor([0.]),
            env_id=tensor([env_id]),
            proc_id=tensor([os.getpid()]),
            step_n=tensor([0]),
            episode_n=episode_n,
        )
        self._env_ids[env_id] = step
        return step

    def __iter__(self) -> SimpleStep:
        for env in self.source_datapipe:
            assert issubclass(env.__class__,gym.Env),f'Expected subclass of gym.Env, but got {env.__class__}'    
            env_id = id(env)
            
            if env_id not in self._env_ids or self._env_ids[env_id].done:
                if self.synchronized_reset:
                    if env_id in self._env_ids \
                       and not self._env_ids[env_id].done \
                       and self._resetting_all:
                        # If this env has already been reset, and we are currently in the 
                        # self._resetting_all phase, then skip this so we can reset all remaining envs
                        continue
                    elif env_id not in self._env_ids \
                       or all([self._env_ids[s].done for s in self._env_ids])\
                       or self._resetting_all:
                        # If the id is not in the _env_ids, we can assume this is a fresh start.
                        # OR 
                        # If all the envs are done, then we can start doing a reset operation.
                        # OR
                        # If we are currently resetting all the envs anyways
                        # This means we want to reset ALL the envs before doing any steps.
                        self.env_reset(env,env_id)
                        # Move to the next env, eventually we will reset all the envs in sync.
                        # then we will be able to start calling `step` for each of them.
                        # _resetting_all is True when there are envs still "done".
                        self._resetting_all = any([self._env_ids[s].done for s in self._env_ids])
                        continue 
                    elif self._env_ids[env_id].done:
                        continue
                    else:
                        raise ValueError('This else should never happen.')
                else:
                    step = self.env_reset(env,env_id)
            else:
                step = self._env_ids[env_id]

            if self.agent is not None: self.agent.agent_base.iterator.append(step)
            
            action = None
            for action in ifnone(self.agent,[env.action_space.sample()]):
                next_state,reward,done,_ = env.step(action)

                step = SimpleStep(
                    state=tensor(step.next_state),
                    next_state=tensor(next_state),
                    action=tensor([action]).float(),
                    done=tensor([done]),
                    reward=tensor([reward]),
                    total_reward=step.total_reward+reward,
                    env_id=tensor([env_id]),
                    proc_id=tensor([os.getpid()]),
                    step_n=step.step_n+1,
                    episode_n=step.episode_n,
                )
                self._env_ids[env_id] = step
                yield step
                if done: break
            if action is None: 
                raise Exception('The agent produced no actions. This should never occur.')
                
add_docs(
    GymStepper,
    """Accepts a `source_datapipe` or iterable whose `next()` produces a single `gym.Env`.
       Tracks multiple envs using `id(env)`.""",
    env_reset="Resets a env given the env_id."
)

In [164]:
# hide
# Used here to avoid UserWarnings related to gym complaining about bounding box / action space format.
# There must be a bug in the CartPole-v1 env that is causing this to show. Also couldnt figure out the 
# regex, so instead we filter on the lineno, which is line 98.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

## Iteration Examples

In [165]:
import pandas as pd

pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,seed=0)

pd.DataFrame([step for step,_ in zip(*(pipe,range(10)))])[['state','next_state','action','done']]

Unnamed: 0,state,next_state,action,done
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",[tensor(1.)],[tensor(False)]
1,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",[tensor(1.)],[tensor(False)]
2,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]","[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]",[tensor(1.)],[tensor(False)]
3,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",[tensor(1.)],[tensor(False)]
4,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",[tensor(1.)],[tensor(False)]
5,"[tensor(0.0132), tensor(0.1727), tensor(-0.0469), tensor(-0.3552)]","[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",[tensor(1.)],[tensor(False)]
6,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",[tensor(1.)],[tensor(False)]
7,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",[tensor(1.)],[tensor(False)]
8,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]","[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",[tensor(1.)],[tensor(False)]
9,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]","[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]",[tensor(0.)],[tensor(False)]


In [166]:
from torch.utils.data.dataloader_experimental import DataLoader2

def seed_worker(worker_id): torch.manual_seed(0)

dl = DataLoader2(pipe,num_workers=2,worker_init_fn=seed_worker)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','done']]

Unnamed: 0,state,next_state,action,done
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[[tensor(0.)]],[[tensor(False)]]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[[tensor(0.)]],[[tensor(False)]]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]


In [167]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)

pd.DataFrame([step for step,_ in zip(*(dl,range(10)))])[['state','next_state','action','done']]

Unnamed: 0,state,next_state,action,done
0,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
1,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
2,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
3,"[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]","[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]",[[tensor(0.)]],[[tensor(False)]]
4,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[[tensor(0.)]],[[tensor(False)]]
5,"[[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]]","[[tensor(0.0353), tensor(0.3702), tensor(-0.0866), tensor(-0.7006)]]",[[tensor(0.)]],[[tensor(False)]]
6,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
7,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
8,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]
9,"[[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]]","[[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]]",[[tensor(0.)]],[[tensor(False)]]


## Tests

We create 3 envs and put a max iteration count at 180. Each env will run for 18 steps before ending, which means
we expect there to be 10 total episodes.

In [185]:
import pandas as pd

envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
pipe = pipe.cycle(count=(18*len(envs))) 
pipe = GymStepper(pipe,seed=0)

All the of the environments should reach max 18 steps given a seed of 0...\
The total number of iterations should be `( 18 * n_envs) * n_episodes_per_env = 162`...

In [186]:
steps = list(pipe)
gsteps = groupby(steps,lambda o:int(o.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','done','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,done,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",[tensor(False)],[tensor(140235555030864)],[tensor(1)],[tensor(1)]
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",[tensor(False)],[tensor(140235556861136)],[tensor(1)],[tensor(4)]
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",[tensor(False)],[tensor(140235558087696)],[tensor(1)],[tensor(7)]
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",[tensor(False)],[tensor(140235555030864)],[tensor(1)],[tensor(11)]
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",[tensor(False)],[tensor(140235556861136)],[tensor(1)],[tensor(14)]
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",[tensor(False)],[tensor(140235558087696)],[tensor(1)],[tensor(17)]
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",[tensor(False)],[tensor(140235555030864)],[tensor(2)],[tensor(3)]
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",[tensor(False)],[tensor(140235556861136)],[tensor(2)],[tensor(6)]
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",[tensor(False)],[tensor(140235558087696)],[tensor(2)],[tensor(9)]
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",[tensor(False)],[tensor(140235555030864)],[tensor(2)],[tensor(13)]


All of the step groups should be the same length...

In [187]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [189]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [195]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [198]:
test_eq(sum([o.done for o in steps]),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [178]:
gsteps = groupby(steps,lambda o:int(o.env_id))
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

### Test the `synchronized_reset` param...
> In this case, we will have iterate through the 3 envs without producing a step on warmup.

In [292]:
import pandas as pd

envs = ['CartPole-v1']*3
n_episodes = 3

pipe = dp.map.Mapper(envs)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# We want to cycle through the envs enough times that their epsiode sum to 9, 3 episodes each
# We add an additional +3 cycles since `synchronized_reset` cycles through the envs additional times
# to make sure they are all reset prior to stepping
pipe = pipe.cycle(count=(18*len(envs))+3) 
pipe = GymStepper(pipe,seed=0,synchronized_reset=True)

In [293]:
steps = list(pipe)
gsteps = groupby(steps,lambda o:int(o.step_n))
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','done','env_id','episode_n','step_n']][::10]

Unnamed: 0,state,done,env_id,episode_n,step_n
0,"[tensor(0.0137), tensor(-0.0230), tensor(-0.0459), tensor(-0.0483)]",[tensor(False)],[tensor(140235564835600)],[tensor(1)],[tensor(1)]
10,"[tensor(0.0241), tensor(0.5643), tensor(-0.0672), tensor(-0.9714)]",[tensor(False)],[tensor(140235564837840)],[tensor(1)],[tensor(4)]
20,"[tensor(0.0463), tensor(-0.0172), tensor(-0.1094), tensor(-0.1771)]",[tensor(False)],[tensor(140235556184784)],[tensor(1)],[tensor(7)]
30,"[tensor(0.0217), tensor(-0.4009), tensor(-0.0929), tensor(0.2661)]",[tensor(False)],[tensor(140235564835600)],[tensor(1)],[tensor(11)]
40,"[tensor(0.0094), tensor(0.1879), tensor(-0.0961), tensor(-0.6926)]",[tensor(False)],[tensor(140235564837840)],[tensor(1)],[tensor(14)]
50,"[tensor(0.0325), tensor(0.7771), tensor(-0.1570), tensor(-1.6694)]",[tensor(False)],[tensor(140235556184784)],[tensor(1)],[tensor(17)]
60,"[tensor(0.0167), tensor(0.3685), tensor(-0.0540), tensor(-0.6622)]",[tensor(False)],[tensor(140235564835600)],[tensor(2)],[tensor(3)]
70,"[tensor(0.0427), tensor(0.1763), tensor(-0.1007), tensor(-0.4364)]",[tensor(False)],[tensor(140235564837840)],[tensor(2)],[tensor(6)]
80,"[tensor(0.0417), tensor(-0.4040), tensor(-0.1113), tensor(0.3342)]",[tensor(False)],[tensor(140235556184784)],[tensor(2)],[tensor(9)]
90,"[tensor(0.0096), tensor(-0.0083), tensor(-0.0886), tensor(-0.3733)]",[tensor(False)],[tensor(140235564835600)],[tensor(2)],[tensor(13)]


All of the step groups should be the same length...

In [254]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [255]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other in group[1:]: test_eq(e1.state,other.state)
    for other in group[1:]: test_eq(e1.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [256]:
group_sz = None
for name,group in gsteps.items():
    e1 = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(e1.state,other.state)
        for other in other_group[1:]: test_ne(e1.next_state,other.next_state)

Given 3 envs, single steps, epsiodes of 18 steps in len, 3 episodes each, run for 162 iterations, we should
expect there to be 9 dones.

In [257]:
test_eq(sum([o.done for o in steps]),tensor([9]))

The max episode numbers for each env should sum to 9 where for each env, it should reach and finish 3 episodes...

In [258]:
gsteps = groupby(steps,lambda o:int(o.env_id))
test_len(gsteps.keys(),3)
env1,env2,env3 = L(gsteps.values()).map(L).map(Self.map(Self.episode_n()).map(int))
test_eq(max(env1)+max(env2)+max(env3),9)

In [284]:
import pandas as pd

envs = ['CartPole-v1']*10

pipe = dp.map.Mapper(envs)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe,synchronized_reset=True)
steps = list(pipe)

Since the seed is turned off the only properties we are to expect are:
    
    - If an env finishes, no steps from that env should be seen until all 9 of the other envs finish

In [286]:
def synchronized_reset_checker(steps):
    env_id_done_tracker = {}
    did_syncs_happen = False
    for d,env_id,idx in [(bool(o.done),int(o.env_id),i) for i,o in enumerate(steps)]:

        if d: 
            env_id_done_tracker[env_id] = idx
            continue

        if env_id in env_id_done_tracker:
            if len(env_id_done_tracker)!=len(envs):
                raise Exception(f'env_id {env_id} was iterated through when it should not have been! idx: {idx}')
        if len(env_id_done_tracker)==len(envs):
            did_syncs_happen = True
            env_id_done_tracker = {}

    if not did_syncs_happen: 
        raise Exception('There should have at least been 1 time where all the envs had to reset, which did not happen.')
synchronized_reset_checker(steps)

For sanity, we should expect that without `synchronized_reset` envs will be reset and stepped through before other 
envs are reset, `synchronized_reset_checker` should fail.

In [288]:
pipe = dp.map.Mapper(envs)
pipe = TypeTransformLoop(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle(count=(18*len(envs))) 
# Turn off the seed so that some envs end before others...
pipe = GymStepper(pipe)
steps = list(pipe)

In [290]:
with ExceptionExpected(regex='was iterated through when it should not have been'):
    synchronized_reset_checker(steps)

In [294]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    make_readme()
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
