In [None]:
#|hide
# Notebook setup helper from fastrl's test utilities (what it configures is defined in `fastrl.test_utils`).
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp envs.gym

In [None]:
#|export
# Python native modules
import os
import warnings
from functools import partial
from typing import Callable, Any, Union, Iterable, Optional
# Third party libs
import gymnasium as gym
import torch
# from fastrl.torch_core import *
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe,traverse_dps
from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
# Local modules
from fastrl.core import StepTypes,SimpleStep
from fastrl.pipes.core import find_dps
from fastrl.pipes.iter.nskip import NSkipper
from fastrl.pipes.iter.nstep import NStepper,NStepFlattener
from fastrl.pipes.iter.firstlast import FirstLastMerger
import fastrl.pipes.iter.cacheholder

# Envs Gym
> Fastrl API for working with Gymnasium (formerly OpenAI Gym) environments

### Pipes

In [None]:
#|export
class GymStepper(dp.iter.IterDataPipe):
    def __init__(self,
        source_datapipe:Union[Iterable,dp.iter.IterDataPipe], # Calling `next()` should produce a `gym.Env`
        agent=None, # Optional `Agent` that accepts a `SimpleStep` to produce a list of actions.
        seed:Optional[int]=None, # Optional seed to set the env to and also random action samples if `agent==None`
        synchronized_reset:bool=False, # Some `gym.Envs` require reset to be terminated on *all* envs before proceeding to step.
        include_images:bool=False, # Render images from the environment
        terminate_on_truncation:bool=True # If an env truncates, mark the step as terminated as well
    ):
        self.source_datapipe = source_datapipe
        self.agent = agent
        # Iterator over the agent's actions for the current step (`None` when not in use)
        self._agent_iter = None
        self.seed = seed
        self.include_images = include_images
        self.synchronized_reset = synchronized_reset
        self.terminate_on_truncation = terminate_on_truncation
        # Maps `id(env)` to the most recent step produced for that env
        self._env_ids = {}
        # True while a synchronized reset is in progress, i.e. some envs have been reset but
        # others are still terminated. Defined up front so `__iter__` can always read it.
        self._resetting_all = False

    def env_reset(self,
      env:gym.Env, # The env to reset
      env_id:int # Resets env in `self._env_ids[env_id]`
    ) -> StepTypes.types:
        state, info = env.reset(seed=self.seed)
        # Seed the action space too so `action_space.sample()` (used when `agent is None`) is reproducible
        env.action_space.seed(seed=self.seed)
        # A known env continues its episode count; a brand new env starts at episode 1
        episode_n = self._env_ids[env_id].episode_n+1 if env_id in self._env_ids else torch.tensor(1)

        step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
            state=torch.tensor(state),
            next_state=torch.tensor(state),
            terminated=torch.tensor(False),
            truncated=torch.tensor(False),
            reward=torch.tensor(0),
            total_reward=torch.tensor(0.),
            env_id=torch.tensor(env_id),
            proc_id=torch.tensor(os.getpid()),
            step_n=torch.tensor(0),
            episode_n=episode_n,
            image=torch.tensor(env.render()) if self.include_images else torch.FloatTensor([0]),
            raw_action=torch.FloatTensor([0])
        )
        self._env_ids[env_id] = step
        return step

    def no_agent_create_step(self,**kwargs): return SimpleStep(**kwargs,batch_size=[])

    def __iter__(self) -> SimpleStep:
        for env in self.source_datapipe:
            assert issubclass(env.__class__,gym.Env),f'Expected subclass of gym.Env, but got {env.__class__}'
            env_id = id(env)

            if self.synchronized_reset and self._resetting_all \
            and env_id in self._env_ids \
            and not self._env_ids[env_id].terminated:
                # A synchronized reset is in progress and this env has already been reset.
                # Skip it until the remaining (still terminated) envs are reset as well, so
                # no env gets stepped before all of them have been reset.
                continue

            if env_id not in self._env_ids or self._env_ids[env_id].terminated:
                if self.synchronized_reset:
                    if env_id not in self._env_ids \
                    or all([self._env_ids[s].terminated for s in self._env_ids])\
                    or self._resetting_all:
                        # If the id is not in the _env_ids, we can assume this is a fresh start.
                        # OR
                        # If all the envs are terminated, then we can start doing a reset operation.
                        # OR
                        # If we are currently resetting all the envs anyways.
                        # This means we want to reset ALL the envs before doing any steps.
                        self.env_reset(env,env_id)
                        # _resetting_all stays True while there are envs still "terminated".
                        self._resetting_all = any([self._env_ids[s].terminated for s in self._env_ids])
                    # Either this env was just reset (its first step happens on its next visit, once
                    # all envs are reset), or it is terminated while others are still running and
                    # must wait for them.
                    continue
                step = self.env_reset(env,env_id)
            else:
                step = self._env_ids[env_id]

            action = None
            raw_action = None
            self._agent_iter = iter((self.agent([step]) if self.agent is not None else [env.action_space.sample()]))
            while True:
                try:
                    action = next(self._agent_iter)
                    # Agents may return `(action, raw_action)` pairs
                    if isinstance(action,tuple):
                        action, raw_action = action
                    next_state,reward,terminated,truncated,_ = env.step(
                        self.agent.augment_actions(action) if self.agent is not None else action
                    )

                    if self.terminate_on_truncation and truncated: terminated = True

                    step = (self.no_agent_create_step if self.agent is None else self.agent.create_step)(
                        state=step.next_state.clone().detach(),
                        next_state=torch.tensor(next_state),
                        action=torch.tensor(action).float(),
                        terminated=torch.tensor(terminated),
                        truncated=torch.tensor(truncated),
                        reward=torch.tensor(reward),
                        total_reward=step.total_reward+reward,
                        env_id=torch.tensor(env_id),
                        proc_id=torch.tensor(os.getpid()),
                        step_n=step.step_n+1,
                        episode_n=step.episode_n,
                        image=torch.tensor(env.render()) if self.include_images else torch.FloatTensor([0]),
                        raw_action=raw_action if raw_action is not None else torch.FloatTensor([0])
                    )
                    self._env_ids[env_id] = step
                    yield step
                    if terminated: break
                except StopIteration:
                    self._agent_iter = None
                    break
                finally:
                    # Always exhaust the agent's iterator (also when the consumer stops early and
                    # the generator is closed at `yield`) so the agent pipeline is left drained.
                    if self._agent_iter is not None:
                        while True:
                            try: next(self._agent_iter)
                            except StopIteration:break
            if action is None:
                raise Exception('The agent produced no actions. This should never occur.')

# Only document methods `GymStepper` itself defines: `add_docs` sets `__doc__` on whatever
# `getattr` returns, so naming an inherited method (e.g. `IterDataPipe.reset`) would
# overwrite the base class's docstring for every datapipe.
add_docs(
GymStepper,
"""Accepts a `source_datapipe` or iterable whose `next()` produces a single `gym.Env`.
    Tracks multiple envs using `id(env)`.""",
env_reset="Resets an env given the env_id.",
no_agent_create_step="If there is no agent for creating the step output, then `GymStepper` will create its own",
)

In [None]:
#|hide
# Used here to avoid UserWarnings related to gym complaining about bounding box / action space format.
# There must be a bug in the CartPole-v1 env that is causing this to show. Also couldnt figure out the 
# regex, so instead we filter on the lineno, which is line 98.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

## Iteration Examples

In [None]:
import pandas as pd
from fastrl.agents.core import * 

In [None]:
#|eval: False
class ConstantRunner(dp.iter.IterDataPipe):
    "Test agent pipe: for every input item yields `constant`, or a list of `array_nestings` copies of it."
    def __init__(self,source_datapipe,constant=1,array_nestings=0):
        self.source_datapipe = source_datapipe
        # Looked up from the pipe graph and stored; not read anywhere else in this class
        self.agent_base = find_dps(traverse_dps(self.source_datapipe),AgentBase)
        self.constant = constant
        self.array_nestings = array_nestings

    def __iter__(self):
        for item in self.source_datapipe:
            try:
                yield self.constant if self.array_nestings==0 else [self.constant]*self.array_nestings
            except Exception:
                print('Failed on ',item)
                raise

# Fork the agent into two branches and zip them back together: this checks that the agent's
# fork buffers are correctly exhausted / reset.
split_1,split_2 = ConstantRunner(AgentBase(None,[])).fork(2)
agent = AgentHead(split_1.zip(split_2))

pipe = GymStepper(
    dp.iter.IterableWrapper(['CartPole-v1']*3)
        .map(partial(gym.make,render_mode='rgb_array'))
        .pickleable_in_memory_cache()
        .cycle(),
    agent=agent,seed=0
)

There should be no buffer-reset warnings: if iteration over the pipe ends early, it should still reset correctly.

In [None]:
#|eval: False
with warnings.catch_warnings(record=True) as caught:
    # Record every warning, even ones the active filters would otherwise hide
    warnings.simplefilter("always")
    # Abandon two separate iterations of the pipe right after their first step;
    # stopping early must not leave anything behind that triggers warnings.
    for _attempt in range(2):
        for step in pipe:
            break

    assert not caught,f'There should be no warnings, but got: {[o.message for o in caught]}'

In [None]:
from functools import partial

In [None]:
# Three seeded CartPole envs cycled forever, rendering an image for every step
pipe = GymStepper(
    dp.iter.IterableWrapper(['CartPole-v1']*3)
        .map(partial(gym.make,render_mode='rgb_array'))
        .pickleable_in_memory_cache()
        .cycle(),
    seed=0,include_images=True
)

pd.DataFrame([step for step,_ in zip(pipe,range(10))])[['state','next_state','action','terminated','env_id','step_n']]

In [None]:
from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService

In [None]:
def seed_worker(pipe,worker_info):
    "Worker init hook: reseed torch's global RNG with 0 and hand back `pipe` unchanged."
    torch.manual_seed(0)
    return pipe

reading_service = MultiProcessingReadingService(num_workers=0,worker_init_fn=seed_worker)
dl = DataLoader2(pipe,reading_service=reading_service)

pd.DataFrame([step for step,_ in zip(dl,range(10))])[['state','next_state','action','terminated','env_id']]

In [None]:
pipe = dp.map.Mapper(['CartPole-v1']*3)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)

# Display steps from the synchronized-reset `pipe` built in this cell (not from `dl`,
# which wraps a different, non-synchronized pipe).
pd.DataFrame([step for step,_ in zip(pipe,range(10))])[['state','next_state','action','terminated']]

## Tests

We create 3 envs and cycle through them 54 times, for a total of 162 iterations. Each env runs for 18 steps before its episode ends, which means
we expect there to be 9 total episodes (3 per env).

In [None]:
envs = ['CartPole-v1']*3
n_episodes = 3
# CartPole episode length expected with seed=0 (checked by the tests below)
steps_per_episode = 18

pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# Each cycle visits every env once, so `steps_per_episode*n_episodes` cycles give each env
# `n_episodes` full episodes: 9 episodes / 162 steps in total across the 3 envs.
pipe = pipe.cycle(count=steps_per_episode*n_episodes)
pipe = GymStepper(pipe,seed=0)

All of the environments should reach a max of 18 steps given a seed of 0...\
The total number of iterations should be `(18 * n_envs) * n_episodes_per_env = 162`...

In [None]:
from fastrl.core import test_len
from fastcore.all import test_eq,test_ne
from itertools import groupby

In [None]:
steps = [step for step in pipe]
# Every episode should run for 18 steps, so exactly 18 distinct step numbers are expected
test_len({int(o.step_n) for o in steps},18)
pd.DataFrame(steps)[['state','terminated','env_id','episode_n','step_n']].iloc[::10]

All of the step groups should be the same length...

In [None]:
group_sz = None
# Sort by step number first: `groupby` only merges *consecutive* items, and steps from different
# envs/episodes are interleaved, so grouping the raw list would keep only each key's last run.
gsteps = {k:list(v) for k,v in groupby(sorted(steps,key=lambda o:int(o.step_n)),lambda o:int(o.step_n))}
for name,group in gsteps.items():
    if group_sz is None: group_sz = len(group)
    else:                assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [None]:
# Within each step group, every step should share the first step's state and next_state
for name,group in gsteps.items():
    first = group[0]
    for other in group[1:]: test_eq(first.state,other.state)
    for other in group[1:]: test_eq(first.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [None]:
# A group's representative state must not reappear in any other step group
for name,group in gsteps.items():
    first = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(first.state,other.state)
        for other in other_group[1:]: test_ne(first.next_state,other.next_state)

Given 3 envs, single steps, episodes of 18 steps in length, and 3 episodes per env, run for 162 iterations, we should
expect there to be 9 terminations.

In [None]:
import torch
from fastcore.all import L,Self

In [None]:
n_terminated = sum(o.terminated for o in steps); test_eq(n_terminated,torch.tensor([9]))

The max episode numbers of the envs should sum to 9, since each env should reach and finish 3 episodes...

In [None]:
# Sort by env_id first: `groupby` only merges *consecutive* items and the envs' steps are
# interleaved, so grouping the raw list would keep just one step per env.
gsteps = {k:list(v) for k,v in groupby(sorted(steps,key=lambda o:int(o.env_id)),lambda o:int(o.env_id))}
test_len(gsteps.keys(),3)
max_episode_ns = [max(int(o.episode_n) for o in group) for group in gsteps.values()]
test_eq(sum(max_episode_ns),9)

### Test the `synchronized_reset` param...
> In this case, we iterate through the 3 envs once during warmup without producing any steps.

In [None]:
envs = ['CartPole-v1']*3
n_episodes = 3
# CartPole episode length expected with seed=0 (checked by the tests below)
steps_per_episode = 18

pipe = dp.map.Mapper(envs)
pipe = pipe.map(partial(gym.make,render_mode='rgb_array'))
pipe = dp.iter.MapToIterConverter(pipe)
pipe = dp.iter.InMemoryCacheHolder(pipe)
# Each cycle visits every env once: `steps_per_episode*n_episodes` cycles give each env
# `n_episodes` full episodes. With `synchronized_reset`, every episode also starts with one
# full cycle that only resets the envs (no steps produced), hence the extra `n_episodes` cycles.
pipe = pipe.cycle(count=(steps_per_episode+1)*n_episodes)
pipe = GymStepper(pipe,seed=0,synchronized_reset=True)

In [None]:
steps = list(pipe)
# Sort by step number first: `groupby` only merges *consecutive* items, and steps from different
# envs/episodes are interleaved, so grouping the raw list would keep only each key's last run.
gsteps = {k:list(v) for k,v in groupby(sorted(steps,key=lambda o:int(o.step_n)),lambda o:int(o.step_n))}
test_len(gsteps.keys(),18)
pd.DataFrame([step for step in steps])[['state','terminated','env_id','episode_n','step_n']][::10]

All of the step groups should be the same length...

In [None]:
group_sz = None
for name,group in gsteps.items():
    if group_sz is None:
        # The first group fixes the expected size for all the others
        group_sz = len(group)
        continue
    assert len(group)==group_sz,f' Got lengths {len(group)} and {group_sz} for {name}.\n\n{group}'

Each step group's state and next_states should match across envs...

In [None]:
# Within each step group, every step should share the first step's state and next_state
for name,group in gsteps.items():
    first = group[0]
    for other in group[1:]: test_eq(first.state,other.state)
    for other in group[1:]: test_eq(first.next_state,other.next_state)

Each step group value should not show up/be duplicated in any other step groups...

In [None]:
# A group's representative state must not reappear in any other step group
for name,group in gsteps.items():
    first = group[0]
    for other_name,other_group in gsteps.items():
        if other_name==name: continue
        for other in other_group[1:]: test_ne(first.state,other.state)
        for other in other_group[1:]: test_ne(first.next_state,other.next_state)

Given 3 envs, single steps, episodes of 18 steps in length, and 3 episodes per env, run for 162 iterations, we should
expect there to be 9 terminations.

In [None]:
n_terminated = sum(o.terminated for o in steps); test_eq(n_terminated,torch.tensor([9]))

The max episode numbers of the envs should sum to 9, since each env should reach and finish 3 episodes...

In [None]:
# Sort by env_id first: `groupby` only merges *consecutive* items and the envs' steps are
# interleaved, so grouping the raw list would keep just one step per env.
gsteps = {k:list(v) for k,v in groupby(sorted(steps,key=lambda o:int(o.env_id)),lambda o:int(o.env_id))}
test_len(gsteps.keys(),3)
max_episode_ns = [max(int(o.episode_n) for o in group) for group in gsteps.values()]
test_eq(sum(max_episode_ns),9)

In [None]:
envs = ['CartPole-v1']*10

# No seed is passed, so different envs end their episodes at different times...
pipe = GymStepper(
    dp.iter.InMemoryCacheHolder(
        dp.iter.MapToIterConverter(
            dp.map.Mapper(envs).map(partial(gym.make,render_mode='rgb_array'))
        )
    ).cycle(count=(18*len(envs))),
    synchronized_reset=True
)
steps = list(pipe)

Since the seed is turned off, the only property we can expect is:

- If an env finishes, no steps from that env should be seen until all 9 of the other envs finish.

In [None]:
def synchronized_reset_checker(steps, n_envs:Optional[int]=None):
    """Raise if `steps` violates synchronized resetting.

    Once an env yields a terminated step, no further non-terminated step from that env may
    appear until all `n_envs` envs have yielded a terminated step. Also raises if that
    "all envs terminated" point is never reached. `n_envs` defaults to `len(envs)`, the
    notebook-level `envs` list, so existing calls keep working.
    """
    if n_envs is None: n_envs = len(envs)
    env_id_done_tracker = {}
    did_syncs_happen = False
    for idx,o in enumerate(steps):
        done,env_id = bool(o.terminated),int(o.env_id)
        if done:
            env_id_done_tracker[env_id] = idx
            continue
        # A finished env stepping again is only legal once every env has finished
        if env_id in env_id_done_tracker and len(env_id_done_tracker)!=n_envs:
            raise Exception(f'env_id {env_id} was iterated through when it should not have been! idx: {idx}')
        if len(env_id_done_tracker)==n_envs:
            did_syncs_happen = True
            env_id_done_tracker = {}

    if not did_syncs_happen:
        raise Exception('There should have at least been 1 time where all the envs had to reset, which did not happen.')
synchronized_reset_checker(steps=steps)

As a sanity check: without `synchronized_reset`, envs are reset and stepped through before the other
envs have finished, so `synchronized_reset_checker` should fail.

In [None]:
# Unseeded and *without* `synchronized_reset`: each env is reset as soon as its own
# episode ends, independently of the others.
pipe = GymStepper(
    dp.iter.InMemoryCacheHolder(
        dp.iter.MapToIterConverter(
            dp.map.Mapper(envs).map(partial(gym.make,render_mode='rgb_array'))
        )
    ).cycle(count=(18*len(envs)))
)
steps = list(pipe)

In [None]:
from fastcore.test import ExceptionExpected

In [None]:
with ExceptionExpected(ex=Exception,regex='was iterated through when it should not have been'):
    synchronized_reset_checker(steps)

In [None]:
#|export
def GymDataPipe(
    source, # Env id(s) passed to `gym.make`: an iterable of ids, or a single id string
    agent:Optional[DataPipe]=None, # An AgentHead
    seed:Optional[int]=None, # The seed for the gym to use
    # Used by `NStepper`, outputs tuples / chunks of associated steps
    nsteps:int=1,
    # Used by `NSkipper` to skip a certain number of steps (agent still gets called for each)
    nskips:int=1,
    # Whether when nsteps>1 to merge it into a single `StepType` (ignored when nsteps==1)
    firstlast:bool=False,
    # The batch size, which is different from `nsteps` in that firstlast will be
    # run prior to batching, and a batch of steps might come from multiple envs,
    # where nstep is associated with a single env
    bs:int=1,
    # The preferred default is for the pipeline to be infinite, and the learner
    # decides how much to iter. If this is not None, then the pipeline will run for
    # that number of `n`
    n:Optional[int]=None,
    # Whether to reset all the envs at the same time as opposed to resetting them
    # the moment an episode ends.
    synchronized_reset:bool=False,
    # Should be used only for validation / logging, will grab a render of the gym
    # and assign to the `StepType` image field. This data should not be used for training.
    # If images are needed for training, then you should wrap the env instead.
    include_images:bool=False,
    # If an environment truncates, terminate it.
    terminate_on_truncation:bool=True
) -> dp.iter.IterDataPipe:
    "Basic `gymnasium` `DataPipeGraph` with first-last, nstep, and nskip capability"
    # A bare string would otherwise be iterated character by character
    if isinstance(source,str): source = [source]
    pipe = dp.iter.IterableWrapper(source)
    # Only ask the envs for `rgb_array` rendering when images are requested
    pipe = pipe.map(partial(gym.make,render_mode='rgb_array') if include_images else gym.make)
    pipe = pipe.pickleable_in_memory_cache()
    pipe = pipe.cycle() # Cycle through the envs infinitely
    pipe = GymStepper(pipe,agent=agent,seed=seed,
                        include_images=include_images,
                        terminate_on_truncation=terminate_on_truncation,
                        synchronized_reset=synchronized_reset)
    if nskips!=1: pipe = NSkipper(pipe,n=nskips)
    if nsteps!=1:
        pipe = NStepper(pipe,n=nsteps)
        if firstlast:
            pipe = FirstLastMerger(pipe)
        else:
            pipe = NStepFlattener(pipe) # We dont want to flatten if using FirstLastMerger
    # `n` limits the number of items *before* they are batched
    if n is not None: pipe = pipe.header(limit=n)
    pipe = pipe.batch(batch_size=bs)
    return pipe

In [None]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,agent=None,seed=0,nsteps=2,nskips=2,firstlast=True,bs=1,n=100)
# bs=1, so each batch holds exactly one step
first_of_batches = [batch[0] for batch in pipes]
pd.DataFrame(first_of_batches)[['state','action','terminated','reward','step_n']].head(50)

In [None]:
pd.set_option('display.max_rows', 50)

envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,agent=None,nsteps=1,nskips=1,firstlast=True,bs=1,n=100)

# bs=1, so each batch holds exactly one step
first_of_batches = [batch[0] for batch in pipes]
pd.DataFrame(first_of_batches)[['state','action','terminated','reward']].head(50)

In [None]:
envs = ['CartPole-v1']*3
pipes = GymDataPipe(envs,agent=None,nsteps=2,nskips=1,firstlast=True,bs=1,n=100)
first_of_batches = [batch[0] for batch in pipes]
pd.DataFrame(first_of_batches)[['state','action','terminated','reward']].head(50)

## Multi Processing

In [None]:
from torchdata.dataloader2 import DataLoader2

In [None]:
%%writefile ../external_run_scripts/spawn_multiproc.py
import torch
import torchdata.datapipes as dp
from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService
       
class PointlessLoop(dp.iter.IterDataPipe):
    def __init__(self,datapipe=None):
        self.datapipe = datapipe
    
    def __iter__(self):
        while True:
            yield torch.LongTensor(4).detach().clone()
            

if __name__=='__main__':
    from torch.multiprocessing import Pool, Process, set_start_method
    try:
         set_start_method('spawn')
    except RuntimeError:
        pass


    pipe = PointlessLoop()
    pipe = pipe.header(limit=10)
    dls = [DataLoader2(pipe,
            reading_service=MultiProcessingReadingService(
                num_workers = 2
            ))]
    # Setup the Learner
    print('type: ',type(dls[0]))
    for o in dls[0]:
        print(o)

In [None]:
%%python ../external_run_scripts/spawn_multiproc.py
# NOTE(review): this runs the script written above in a separate `python` process; the cell
# body is presumably passed to that process's stdin and ignored, since a script path is given.
pass

In [None]:
#|hide
#|eval: false
# Export this notebook's `#|export` cells to the library module named by `#|default_exp` (envs.gym)
!nbdev_export