In [None]:
#|hide
# Run fastrl's notebook initialization helper before anything else; this cell is hidden from the rendered docs.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp pipes.iter.nstep

In [None]:
#|export
# Python native modules
import os
from typing import Type, Dict, Union, Tuple
import typing
import warnings
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,DataPipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
# Local modules
from fastrl.core import StepTypes

# NStep
> DataPipes for grouping consecutive steps from each environment into n-step tuples, and for flattening those tuples back into individual steps.

In [None]:
#|export
class NStepper(IterDataPipe):
    def __init__(
            self, 
            # The datapipe we are extracting from must produce `StepTypes.types`
            source_datapipe:IterDataPipe[Union[StepTypes.types]], 
            # Maximum number of steps to produce per yield as a tuple. This is the *max* number
            # and may be less if for example we are yielding terminal states.
            # Default produces single steps
            n:int=1
        ) -> None:
        self.source_datapipe:IterDataPipe[Union[StepTypes.types]] = source_datapipe
        self.n:int = n
        # Per-env sliding window of the most recent steps, keyed by `int(step.env_id)`.
        self.env_buffer:Dict[int,list] = {}
        
    def __iter__(self) -> typing.Iterator[Tuple[Union[StepTypes.types],...]]:
        # Reset on every new iteration so steps buffered by a previous pass are not reused.
        self.env_buffer = {}
        for step in self.source_datapipe:
            if not issubclass(step.__class__,StepTypes.types):
                raise Exception(f'Expected typing.NamedTuple object got {type(step)}\n{step}')
    
            env_id,terminated = int(step.env_id),bool(step.terminated)
            window = self.env_buffer.setdefault(env_id,[])
            window.append(step)

            if terminated:
                # Unwind the finished episode: yield the whole window, then drop the oldest
                # step and yield again, until only `(terminal_step,)` has been emitted.
                # NOTE(review): only `terminated` flushes the window; if the source signals
                # episode ends some other way (e.g. truncation), those steps would carry over
                # into the next episode's tuples -- confirm against the producer.
                while window:
                    yield tuple(window)
                    window.pop(0)
            elif len(window)>=self.n:
                # Window is full: emit it, then slide forward by one step.
                yield tuple(window)
                window.pop(0)
            # Otherwise the window is still filling up; keep accumulating for this env.
add_docs(
NStepper,
"""Accepts a `source_datapipe` or iterable whose `next()` produces `StepTypes.types` and yields
tuples of at most `n` consecutive steps from a single environment. Only the `terminated` and
`env_id` fields (a subset of the fields of `SimpleStep`) are read. When a step is terminated,
the buffered steps for that env are yielded as progressively shorter tuples ending in the terminal step.""",
)

In [None]:
#|export
class NStepFlattener(IterDataPipe):
    def __init__(
            self, 
            # The datapipe we are extracting from must produce `StepTypes.types` or `Tuple[StepTypes.types]`
            source_datapipe:IterDataPipe[Union[StepTypes.types]], 
        ) -> None:
        self.source_datapipe:IterDataPipe[Union[StepTypes.types]] = source_datapipe
        
    def __iter__(self) -> typing.Iterator[Union[StepTypes.types]]:
        for step in self.source_datapipe:
            # The step-type check must come first: step types are (presumably) NamedTuples,
            # i.e. tuple subclasses, so the `tuple` branch would otherwise unpack their fields.
            if issubclass(step.__class__,StepTypes.types):
                yield step
            elif isinstance(step,tuple):
                # A group of steps (e.g. from `NStepper`): emit each element individually.
                yield from step 
            else:
                raise Exception(f'Expected {StepTypes.types} or tuple object got {type(step)}\n{step}')

            
add_docs(
NStepFlattener,
"""Handles unwrapping `StepTypes.types` in tuples better than `dp.iter.UnBatcher` and `dp.iter.Flattener`""",
)

In [None]:
#|hide
# Used here to avoid UserWarnings related to gym complaining about bounding box / action space format.
# There must be a bug in the CartPole-v1 env that is causing this to show. Also couldnt figure out the 
# regex, so instead we filter on the lineno, which is line 98.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

Below is an example where we collect 2 steps for each env, **then** yield them. This is useful for
training models on larger chunks of env step output.

In [None]:
import pandas as pd
import gymnasium as gym
from fastrl.envs.gym import GymStepper

In [None]:
def n_step_test(envs,total_steps,n=1,seed=0):
    "Run `envs` through `GymStepper` -> `NStepper(n)` -> `NStepFlattener` and return the first `total_steps` flattened steps."
    # gym env names -> env instances, cached and cycled so stepping never runs out of envs
    env_source = dp.iter.MapToIterConverter(dp.map.Mapper(envs).map(gym.make))
    env_source = dp.iter.InMemoryCacheHolder(env_source).cycle()
    flat_steps = NStepFlattener(NStepper(GymStepper(env_source,seed=seed),n=n))
    return list(flat_steps.header(total_steps))

steps = n_step_test(['CartPole-v1']*3,200,2,0)
pd.DataFrame(steps)[['state','next_state','env_id','terminated']].head(10)

## NStepper Tests

There are a couple of properties that we expect from n-step output:
- tuples should have at most `n` steps, but can be smaller.
- a `terminated` n-step unravels into multiple tuples that are yielded individually.

    - In other words, if `n=3`, meaning we want to group up to 3 steps per env, then if we have
      [step5,step6,step7] where step7 is `terminated`, we will get individual tuples in the order:
      
          1. [step5,step6,step7]
          2. [step6,step7]
          3. [step7]

First, `NStepper(pipe,n=1)`, when flattened, should be identical to a pipeline that never used it.

In [None]:
import pandas as pd
from fastcore.all import test_eq
from fastrl.core import test_len,SimpleStep

In [None]:
# Reference pipeline: the same 3 CartPole envs stepped by `GymStepper`, with no n-step stages.
baseline_pipe = dp.iter.MapToIterConverter(dp.map.Mapper(['CartPole-v1']*3).map(gym.make))
baseline_pipe = GymStepper(dp.iter.InMemoryCacheHolder(baseline_pipe).cycle(),seed=0)

no_n_steps = list(baseline_pipe.header(10))
steps = n_step_test(['CartPole-v1']*3,10,1,0)

If `n=1`, then regardless of the number of envs, the n-step and plain environment
pipelines should produce identical steps.

In [None]:
# Same number of steps, and every compared field matches step-for-step.
test_len(steps,no_n_steps)
for field in ['next_state','state','terminated']:
    for step,no_n_step in zip(steps,no_n_steps):
        test_eq(getattr(step,field),getattr(no_n_step,field))

For `n` from 1 to 3, every flattened step should have the same basic shape: a `SimpleStep` with 12 fields.

In [None]:
# One single-env run per n in (1, 2, 3), each limited to 30 flattened steps.
steps1,steps2,steps3 = (n_step_test(['CartPole-v1'],30,n,0) for n in (1,2,3))

In [None]:
import itertools

In [None]:
for steps_for_n in (steps1,steps2,steps3):
    for flat_step in steps_for_n:
        # Flattening must always give back plain `SimpleStep`s with all 12 fields.
        test_eq(len(flat_step),12)
        test_eq(isinstance(flat_step,SimpleStep),True)

In [None]:
#|export
def n_steps_expected(
    default_steps:int, # The number of steps the episode would run without n_steps
    n:int # The n-step value that we are planning to use
)->int:
    # `NStepper` treats `n < 1` exactly like `n == 1`, and a window can never hold more steps
    # than the episode has, so clamp `n` to the window size that is actually reached.
    effective_n = max(1,min(n,default_steps))
    # Step t of the episode (1-indexed) appears in min(t, effective_n) yielded tuples;
    # summing over t = 1..default_steps gives the closed form below.
    return (default_steps * effective_n) - sum(range(effective_n))
    
n_steps_expected.__doc__=r"""
Produces the expected number of flattened steps for one episode, assuming a fully deterministic episode of `default_steps` steps and an n-step size of `n`.

For example, with `n=2`, a single env, and `CartPole-v1` seeded with `seed=0` (which always runs 18 steps), the total 
number of steps will be:

$$
18 \cdot n - \sum_{i=0}^{n-1} i
$$

If `n` exceeds `default_steps` the n-step window never fills, so `n` is effectively `default_steps`; `n < 1` behaves like `n = 1`.
"""    

In [None]:
import torch

In [None]:
n_step_size = 2
expected_n_steps = n_steps_expected(default_steps=18,n=n_step_size)
print('Given the above values, we expect a single episode to be ',expected_n_steps,' steps long')
# Request one step more than a full episode so the run crosses the episode boundary.
steps = n_step_test(['CartPole-v1']*1,expected_n_steps+1,n_step_size,0)
# The first episode should end on row 34, being 35 steps long. The 36th row should start a new episode.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
n_step_size = 4
expected_n_steps = n_steps_expected(default_steps=18,n=n_step_size)
print('Given the above values, we expect a single episode to be ',expected_n_steps,' steps long')
# Request one step more than a full episode so the run crosses the episode boundary.
steps = n_step_test(['CartPole-v1']*1,expected_n_steps+1,n_step_size,0)
# The first episode should end on row 65, being 66 steps long. The 67th row should start a new episode.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
n_step_size = 2
n_envs = 3
expected_n_steps = n_steps_expected(default_steps=18,n=n_step_size)
print('Given the above values, we expect a single episode to be ',expected_n_steps,' steps long')
# Each env should contribute `expected_n_steps` rows for its first episode; request one extra row.
steps = n_step_test(['CartPole-v1']*n_envs,expected_n_steps*n_envs+1,n_step_size,0)
# The second-to-last row should be a terminal step of episode 1, and the extra last row
# should be the first step of episode 2.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
#|hide
#|eval: false
# Export this notebook's `#|export` cells to the library module named by `#|default_exp`.
!nbdev_export