In [None]:
#|hide
#|eval: false
# The `[ -e /content ] &&` guard makes the whole install chain a no-op unless the path
# `/content` exists (presumably only on Colab — TODO confirm).
# The output of `apt-get` (stdout and stderr) is discarded.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: If you see version errors, it is safe to ignore them. Colab is behind in some of the package versions.

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Unless the `IN_TEST` environment variable is set, check that this is running
    # in a notebook / IPython session and not on Colab.
    # NOTE(review): `os`, `IN_NOTEBOOK`, `IN_COLAB` and `IN_IPYTHON` are never imported
    # explicitly; presumably they come from the star imports above — confirm.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # NOTE(review): this rebinds the name `display`, which presumably shadows IPython's
    # `display` helper for the rest of the notebook — confirm nothing relies on it.
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
#|default_exp agents.trpo

In [None]:
#|export
# Python native modules
from typing import *
from typing_extensions import Literal
import typing 
# Third party libs
import numpy as np
import torch
import torchdata.datapipes as dp 
from torchdata.dataloader2.graph import DataPipe,traverse,replace_dp
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *
from fastrl.torch_core import *
from fastrl.layers import *
from fastrl.data.block import *
from fastrl.envs.gym import *

# TRPO
> Trust Region Policy Optimization via online-learning for continuous action domains

[(Schulman et al., 2015) [TRPO] Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477).

## Core

In [None]:
#|export
class AdvantageStep(typing.NamedTuple):
    # One environment step whose fields are all tensors, including an `advantage` field.
    # Each helper below returns a NEW `AdvantageStep`; instances are never modified in place.
    state:       torch.FloatTensor=torch.FloatTensor([0])
    action:      torch.FloatTensor=torch.FloatTensor([0])
    next_state:  torch.FloatTensor=torch.FloatTensor([0])
    terminated:  torch.BoolTensor=torch.BoolTensor([1])
    truncated:   torch.BoolTensor=torch.BoolTensor([1])
    # NOTE(review): annotated as FloatTensor but the default is a LongTensor. Left
    # unchanged so existing callers see the same default — confirm whether intended.
    reward:      torch.FloatTensor=torch.LongTensor([0])
    total_reward:torch.FloatTensor=torch.FloatTensor([0])
    advantage:   torch.FloatTensor=torch.FloatTensor([0])
    env_id:      torch.LongTensor=torch.LongTensor([0])
    proc_id:     torch.LongTensor=torch.LongTensor([0])
    step_n:      torch.LongTensor=torch.LongTensor([0])
    episode_n:   torch.LongTensor=torch.LongTensor([0])
    image:       torch.FloatTensor=torch.FloatTensor([0])
    
    def clone(self):
        "Return a new step in which every field tensor has been `clone()`d."
        return self.__class__(
            **{fld:getattr(self,fld).clone() for fld in self.__class__._fields}
        )
    
    def detach(self):
        "Return a new step in which every field tensor has been `detach()`ed."
        return self.__class__(
            **{fld:getattr(self,fld).detach() for fld in self.__class__._fields}
        )
    
    def device(self,device='cpu'):
        "Return a new step with every field moved to `device` (shorthand for `to(device=device)`)."
        return self.to(device=device)

    def to(self,*args,**kwargs):
        "Return a new step with `Tensor.to(*args,**kwargs)` applied to every field."
        return self.__class__(
            **{fld:getattr(self,fld).to(*args,**kwargs) for fld in self.__class__._fields}
        )
    
    @classmethod
    def random(cls,seed=None,**flds):
        """Build a step filled with random single-element tensors.

        `seed`: when not `None`, values are drawn from a private `torch.Generator`
        seeded with it, so equal seeds produce equal steps and the global torch RNG
        is left untouched. When `None`, the global torch RNG is used.
        `flds`: `field_name=tensor` overrides used instead of the random value;
        names that are not fields are ignored.

        Bool fields are drawn from {0,1}; every other field from the integers [0,100).
        """
        generator = None
        if seed is not None:
            generator = torch.Generator()
            generator.manual_seed(seed)

        _flds,_annos = cls._fields,cls.__annotations__

        def _random_annos(anno):
            # `anno` is the field's tensor class; `anno(1)` allocates one element
            # which is then overwritten in place by `random_`.
            t = anno(1)
            upper = 2 if anno==torch.BoolTensor else 100
            t.random_(upper,generator=generator)
            return t

        # A value is drawn for every field, overridden or not, so the values of the
        # remaining fields do not depend on which overrides were passed.
        return cls(
            *(flds.get(
                f,_random_annos(_annos[f])
            ) for f in _flds)
        )

# Attach documentation to `AdvantageStep`: the class description and the `advantage`
# field are documented here; for every field name in `SimpleStep._fields` the doc is
# copied from the `SimpleStep` attribute of the same name.
add_namedtuple_doc(
AdvantageStep,
"""Represents a single step in an environment similar to `SimpleStep` however has
an addition field called `advantage`.""",
advantage="""Generally characterized as $A(s,a) = Q(s,a) - V(s)$""",
**{f:getattr(SimpleStep,f).__doc__ for f in SimpleStep._fields}
)

In [None]:
# Display the documentation of `AdvantageStep`.
# NOTE(review): `show_doc` is not imported explicitly; presumably it comes from the
# `nbdev.showdoc` star import, which only runs outside Colab — confirm.
show_doc(AdvantageStep)

---

### AdvantageStep

>      AdvantageStep (state:torch.FloatTensor=tensor([0.]),
>                     action:torch.FloatTensor=tensor([0.]),
>                     next_state:torch.FloatTensor=tensor([0.]),
>                     terminated:torch.BoolTensor=tensor([True]),
>                     truncated:torch.BoolTensor=tensor([True]),
>                     reward:torch.FloatTensor=tensor([0]),
>                     total_reward:torch.FloatTensor=tensor([0.]),
>                     advantage:torch.FloatTensor=tensor([0.]),
>                     env_id:torch.LongTensor=tensor([0]),
>                     proc_id:torch.LongTensor=tensor([0]),
>                     step_n:torch.LongTensor=tensor([0]),
>                     episode_n:torch.LongTensor=tensor([0]),
>                     image:torch.FloatTensor=tensor([0.]))

Represents a single step in an environment similar to `SimpleStep` however has
an addition field called `advantage`.

Parameters:

 - **state**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Both the initial state of the environment and the previous state.
 - **action**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`The action that was taken to transition from `state` to `next_state`
 - **next_state**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Both the next state, and the last state in the environment
 - **terminated**:`<class 'torch.BoolTensor'>`  = `tensor([True])`Represents an ending condition for an environment such as reaching a goal or 'living long enough' as 
                    described by the MDP.
                    Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155
 - **truncated**:`<class 'torch.BoolTensor'>`  = `tensor([True])`Represents an ending condition for an environment that can be seen as an out of bounds condition either
                   literally going out of bounds, breaking rules, or exceeding the timelimit allowed by the MDP.
                   Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155'
 - **reward**:`<class 'torch.FloatTensor'>`  = `tensor([0])`The single reward for this step.
 - **total_reward**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`The total accumulated reward for this episode up to this step.
 - **advantage**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Generally characterized as $A(s,a) = Q(s,a) - V(s)$
 - **env_id**:`<class 'torch.LongTensor'>`  = `tensor([0])`The environment this step came from (useful for debugging)
 - **proc_id**:`<class 'torch.LongTensor'>`  = `tensor([0])`The process this step came from (useful for debugging)
 - **step_n**:`<class 'torch.LongTensor'>`  = `tensor([0])`The step number in a given episode.
 - **episode_n**:`<class 'torch.LongTensor'>`  = `tensor([0])`The episode this environment is currently running through.
 - **image**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Intended for display and logging only. If the intention is to use images for training an
               agent, then use a env wrapper instead.

## Memory
> Policy gradient online models use short-term trajectory samples instead of
experience replay (ER) / i.i.d. memory

In [None]:
#|export
class AdvantageBuffer(dp.iter.IterDataPipe):
    # Collects the steps coming from `source_datapipe` into one list per `env_id` and
    # emits once a step is truncated/terminated or that list holds `bs` steps.
    # Set `debug=True` to print every step as it is buffered.
    debug=False
    def __init__(self,
            source_datapipe:DataPipe,
            # Will accumulate up to `bs` or when the episode has terminated.
            bs=1000,
            # If the `self.device` is not cpu, and `store_as_cpu=True`, then
            # calls to `sample()` will dynamically move them to `self.device`, and
            # next `sample()` will move them back to cpu before producing new samples.
            # This can be slower, but can save vram.
            # If `store_as_cpu=False`, then samples stay on `self.device`
            #
            # If being run with n_workers>0, shared_memory, and fork, this MUST be true. This is needed because
            # otherwise the tensors in the memory will remain shared with the tensors created in the 
            # dataloader.
            #
            # NOTE(review): this class currently only stores the flag; nothing here reads it yet.
            store_as_cpu:bool=True
        ):
        self.source_datapipe = source_datapipe
        self.bs = bs
        self.store_as_cpu = store_as_cpu
        self.device = None
        # env id -> steps buffered for that env. Created here so `len()`/`repr()` work
        # before the first iteration; reset at the start of every `__iter__`.
        self.env_advantage_buffer:Dict[int,list] = {}

    def to(self,*args,**kwargs):
        # Only the device is recorded. It may be passed as `device=` or as the first
        # positional argument; without either, `self.device` is reset to `None`.
        # Returns `self` so that calls can be chained.
        device = kwargs.get('device',None)
        if device is None and args and isinstance(args[0],(str,torch.device)):
            device = args[0]
        self.device = device
        return self

    def __repr__(self):
        # Summarise the buffer as an element count instead of dumping every stored step.
        return str({
            k:v if k!='env_advantage_buffer' else f'{len(self)} elements'
            for k,v in self.__dict__.items()
        })

    def __len__(self):
        "Number of steps currently buffered, summed over all environments."
        return sum(len(steps) for steps in self.env_advantage_buffer.values())
    
    def __iter__(self) -> Iterator[AdvantageStep]:
        "Buffer incoming steps per environment and yield once per flushed trajectory."
        self.env_advantage_buffer = {}
        for step in self.source_datapipe:
            if self.debug: print('Adding to advantage buffer: ',step)
            env_id = int(step.env_id.detach().cpu())
            trajectory = self.env_advantage_buffer.setdefault(env_id,[])
            trajectory.append(step)

            # `bs` limits the length of THIS environment's trajectory, not the
            # number of environments being tracked.
            if any((step.truncated,step.terminated,len(trajectory)>=self.bs)):
                # Drop the flushed trajectory so the next episode of this env starts
                # empty and the buffer cannot grow without bound.
                del self.env_advantage_buffer[env_id]
                # TODO: compute advantages over `trajectory` and yield the resulting
                # steps; until then a default `AdvantageStep` is yielded as a placeholder.
                yield AdvantageStep()

    @classmethod
    def insert_dp(cls,old_dp=GymStepper) -> Callable[[DataPipe],DataPipe]:
        """Return a function that inserts this buffer into an existing pipeline.

        The returned function looks up `old_dp` in the pipeline's graph via `find_dp`
        (presumably returning the matching datapipe instance — confirm), replaces it
        with `cls(<that datapipe>)`, and returns the first datapipe of the updated graph.
        """
        def _insert_dp(pipe):
            # The graph is traversed once and reused for the lookup and the replacement.
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            v = replace_dp(graph,target,cls(target))
            return list(v.values())[0][0]
        return _insert_dp

# Attach the class description to `AdvantageBuffer`, and reuse the docstring of
# `torch.Tensor.to` as the documentation for `AdvantageBuffer.to`.
add_docs(
AdvantageBuffer,
"""Samples entire trajectories instead of individual time steps.""",
to=torch.Tensor.to.__doc__
)

In [None]:
# Build a pipeline over 'Pendulum-v1' with no agent and a fixed seed.
# `AdvantageBuffer.insert_dp()` is passed as an augmentation function, so the buffer is
# spliced into the pipeline (by default around the `GymStepper` stage).
gym_pipe = GymTransformBlock(
    agent=None,seed=0,
    dp_augmentation_fns=[AdvantageBuffer.insert_dp()]
)(['Pendulum-v1'])

In [None]:
# list(gym_pipe.header(100))

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Import lazily and run nbdev's export step; skipped entirely on Colab.
    from nbdev import nbdev_export
    nbdev_export()