In [None]:
#|hide
# Notebook setup helper from fastrl's test utilities. Presumably it prepares the notebook
# environment before the remaining cells run — TODO confirm against `fastrl.test_utils`.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp pipes.iter.nskip

In [None]:
#|export
# Python native modules
import os
import warnings
from typing import Callable, Dict, Iterable, Optional, TypeVar, Type
# Third party libs
import torchdata.datapipes as dp
from torchdata.datapipes.iter import IterDataPipe
from fastcore.all import add_docs
# Local modules
from fastrl.core import StepTypes
# from fastrl.pipes.core import *
from fastrl.pipes.iter.nstep import NStepper
# from fastrl.data.block import *

# NSkip
> DataPipe for skipping env steps env-wise.

In [None]:
#|export
# Error text raised by `NSkipper.__init__` when its source datapipe is an `NStepper`.
# It shows the supported ordering: `NSkipper` first, then `NStepper`.
_msg = """
NSkipper should not go after NStepper. Please make the order:

```python
...
pipe = NSkipper(pipe,n=3)
pipe = NStepper(pipe,n=3)
...
```

"""

class NSkipper(IterDataPipe[StepTypes.types]):
    def __init__(
            self,
            # The datapipe we are extracting from must produce `StepType`
            source_datapipe:IterDataPipe[StepTypes.types],
            # Keep every `n`th step per env. Default (`n=1`) will not skip at all.
            n:int=1
        ) -> None:
        # Skipping must happen before n-step grouping, so fail fast with guidance.
        if isinstance(source_datapipe,NStepper): raise Exception(_msg)
        # `n` is used as a modulus in `__iter__`. Reject invalid values here instead of
        # failing lazily (e.g. ZeroDivisionError for `n=0`) in the middle of iteration.
        if n<1: raise ValueError(f'`n` must be a positive integer, got {n}')
        self.source_datapipe = source_datapipe
        self.n = n
        # Maps env_id -> number of steps seen so far in that env's current episode.
        self.env_buffer = {}

    def __iter__(self) -> StepTypes.types:
        # Every pass over the pipe starts with fresh per-env counters.
        self.env_buffer = {}
        for step in self.source_datapipe:
            if not issubclass(step.__class__,StepTypes.types):
                raise Exception(f'Expected {StepTypes.types} object got {type(step)}\n{step}')

            env_id,terminated,step_n = int(step.env_id),bool(step.terminated),int(step.step_n)

            # Count this step against its own env only, so envs are skipped independently.
            self.env_buffer[env_id] = self.env_buffer.get(env_id,0) + 1

            # Keep every `n`th step, and never drop an episode's 1st or terminated step.
            if terminated or step_n==1 or self.env_buffer[env_id]%self.n==0: yield step

            # Restart the count at 0 so the next episode's 1st step is counted as 1,
            # exactly like the 1st step of the very first episode. Resetting to 1 would
            # shift the skip phase of every episode after the first.
            if terminated: self.env_buffer[env_id] = 0

add_docs(
NSkipper,
"""Accepts a `source_datapipe` or iterable whose `next()` produces a `StepType` that
skips N steps for individual environments *while always producing 1st steps and terminated steps.*
"""
)

In [None]:
#|hide
# Suppresses the UserWarnings that gym emits about the bounding box / action space format.
# This looks like a bug in the CartPole-v1 env. A message regex could not be worked out,
# so the filter matches on the line number raising the warning (line 98) instead.
# NOTE: no `module` is given, so any UserWarning raised from line 98 of any module is ignored.
warnings.filterwarnings("ignore",category=UserWarning,lineno=98)

Below we skip every other step given 3 envs while always keeping the 1st and terminated steps.

In [None]:
import pandas as pd
import gymnasium as gym
from fastrl.envs.gym import GymStepper
from fastrl.pipes.iter.nstep import NStepper

In [None]:
def n_skip_test(envs,total_steps,n=1,seed=0):
    """Run `envs` through a `GymStepper` wrapped in `NSkipper(n=n)` and collect the steps.

    `envs` is a sequence of gym env ids, `total_steps` caps how many steps are collected,
    and `seed` is forwarded to `GymStepper`. Returns the collected steps as a list.
    """
    # env ids -> env instances on a map-style pipe, then converted to an iter-style pipe.
    env_pipe = dp.iter.MapToIterConverter(dp.map.Mapper(envs).map(gym.make))
    # Hold the envs in memory and loop over them indefinitely.
    env_pipe = dp.iter.InMemoryCacheHolder(env_pipe).cycle()
    skipper = NSkipper(GymStepper(env_pipe,seed=seed),n=n)

    # Zipping against a bounded range stops the otherwise endless pipe.
    collected = []
    for step,_ in zip(skipper,range(total_steps)):
        collected.append(step)
    return collected

# 3 CartPole envs, keeping every 2nd step per env; show the first 10 collected steps.
steps = n_skip_test(envs=['CartPole-v1']*3,total_steps=200,n=2,seed=0)
pd.DataFrame(steps)[['state','next_state','env_id','terminated']].head(10)

Here is a simple 1-env result...

In [None]:
# Same demo with a single env; `step_n` makes the skipped steps visible.
steps = n_skip_test(envs=['CartPole-v1']*1,total_steps=200,n=2,seed=0)
pd.DataFrame(steps)[['state','next_state','step_n','terminated']].head(10)

#|hide
## NSkipper Tests

There are a couple properties that we expect from `NSkipper`:

    - The 1st step should always be returned.
    - The terminated step should always be returned.
    - Every env should have its own steps skipped/kept
    
First, `NSkipper(pipe,n=1)` should be identical to a pipeline that never used it.

In [None]:
#|hide
# Reference pipeline: the same stages `n_skip_test` builds, but without `NSkipper`.
pipe = dp.iter.MapToIterConverter(dp.map.Mapper(['CartPole-v1']*3).map(gym.make))
pipe = dp.iter.InMemoryCacheHolder(pipe).cycle()
pipe = GymStepper(pipe,seed=0)

# Collect 60 steps from the reference pipeline and 60 from the `n=1` skipper pipeline.
no_n_skips = [step for step,_ in zip(pipe,range(60))]
steps = n_skip_test(envs=['CartPole-v1']*3,total_steps=60,n=1,seed=0)

#|hide
If `n=1` we should expect that, regardless of the number of envs, the n-skip pipeline and the
simple environment pipeline produce identical steps.

In [None]:
from fastcore.all import test_eq
from fastrl.core import test_len

In [None]:
#|hide
# With `n=1` nothing is skipped, so both runs must match step for step.
test_len(steps,no_n_skips)
compared_fields = ('next_state','state','terminated')
for field in compared_fields:
    for skipped,plain in zip(steps,no_n_skips):
        test_eq(getattr(skipped,field),getattr(plain,field))

In [None]:
#|hide
# pd.set_option('display.max_rows', 500)
# pd.DataFrame(steps)[['state','next_state','env_id','done']]
# pd.DataFrame(no_n_skips)[['state','next_state','env_id','done']]

In [None]:
#|export
def n_skips_expected(
    default_steps:int, # The number of steps the episode would run without n_skips
    n:int # The n-skip value that we are planning to use
):
    # `n` is a divisor below; reject invalid values explicitly instead of ZeroDivisionError.
    if n<1: raise ValueError(f'`n` must be a positive integer, got {n}')
    # An episode with no steps has nothing to keep.
    if default_steps<=0: return 0
    # All the steps will be retained including the 1st step. No offset needed.
    if n==1: return default_steps
    # A single-step episode: that step is both the 1st and the terminated step, count it once.
    if default_steps==1: return 1
    # Every `n`th step is kept, plus the 1st step (never a multiple of `n` when n>1).
    kept = default_steps//n + 1
    # If `n` does not divide `default_steps`, the terminated step would have been skipped,
    # but it is always kept, so it counts as one extra step.
    if default_steps%n!=0: kept += 1
    return kept

n_skips_expected.__doc__=r"""
Produces the expected number of steps kept for a single episode, assuming a fully
deterministic episode based on `default_steps` and `n`.

Mainly used for testing.

Given `n=2`, given 1 envs, knowing that `CartPole-v1` when `seed=0` will always run 18 steps, the total 
steps will be:

$$
18 // 2 + 1 = 10
$$

The `+ 1` is the always-kept 1st step. The terminated step is already counted because 2 divides 18;
when `n` does not divide the episode length, the terminated step adds one more.
"""

In [None]:
import torch

In [None]:
#|hide
expected_n_skips = n_skips_expected(default_steps=18,n=1)
print('Given the above values, we expect a single episode to be ',expected_n_skips,' steps long')
steps = n_skip_test(['CartPole-v1']*1,expected_n_skips+1,1,0)
# One step more than a full episode is collected: the second-to-last step must be the
# terminated 18th step of episode 1, and the last step must be the 1st step of episode 2.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
#|hide
expected_n_skips = n_skips_expected(default_steps=18,n=2)
print('Given the above values, we expect a single episode to be ',expected_n_skips,' steps long')
steps = n_skip_test(['CartPole-v1']*1,expected_n_skips+1,2,0)
# One step more than the kept steps of a full episode is collected: the second-to-last step
# must be the terminated 18th step of episode 1, and the last the 1st step of episode 2.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
#|hide
expected_n_skips = n_skips_expected(default_steps=18,n=4)
print('Given the above values, we expect a single episode to be ',expected_n_skips,' steps long')
steps = n_skip_test(['CartPole-v1']*1,expected_n_skips+1,4,0)
# 4 does not divide 18, so the terminated step is kept as an extra step. The second-to-last
# step must be the terminated 18th step of episode 1, and the last the 1st step of episode 2.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
#|hide
expected_n_skips = n_skips_expected(default_steps=18,n=2)
print('Given the above values, we expect a single episode to be ',expected_n_skips,' steps long')
steps = n_skip_test(['CartPole-v1']*3,expected_n_skips*3+1,2,0)
# Three envs, so three episodes' worth of kept steps plus one are collected. The second-to-last
# step must be a terminated 18th step of episode 1, and the last a 1st step of episode 2.
# NOTE(review): assumes the 3 identically seeded envs are stepped in turn — confirm in GymStepper.
test_eq(steps[-2].terminated,torch.tensor([True]))
test_eq(steps[-2].episode_n,torch.tensor([1]))
test_eq(steps[-2].step_n,torch.tensor([18]))
test_eq(steps[-1].terminated,torch.tensor([False]))
test_eq(steps[-1].episode_n,torch.tensor([2]))
test_eq(steps[-1].step_n,torch.tensor([1]))

In [None]:
#|hide
#|eval: false
# Runs nbdev's export command; presumably writes the `#|export` cells to the module
# named by `#|default_exp` — TODO confirm against the project's nbdev settings.
!nbdev_export