In [None]:
#|hide
#|eval: false
# Meant for Colab: the `[ -e /content ]` guard skips the installs when there is no /content directory.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check the runtime environment, except when the IN_TEST env var is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab (a 400x300 display that is never shown on screen)
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
#|default_exp agents.trpo

In [None]:
#|export
# Python native modules
from typing import *
from typing_extensions import Literal
import typing 
# Third party libs
import numpy as np
import torch
from torch import nn
from torch.distributions import *
import torchdata.datapipes as dp 
from torchdata.dataloader2.graph import DataPipe,traverse,replace_dp
from fastcore.all import test_eq,test_ne
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *
from fastrl.torch_core import *
from fastrl.layers import *
from fastrl.data.block import *
from fastrl.envs.gym import *

# TRPO
> Trust Region Policy Optimization via online-learning for continuous action domains

[(Schulman et al., 2015) [TRPO] Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477).

Directly based on [`ikostrikov`'s implementation](https://github.com/ikostrikov/pytorch-trpo) and the
code / explanations in [Shewchuk, 1994, "An Introduction to the Conjugate Gradient Method Without the Agonizing Pain" (accessed 19 Nov 2022)](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf).

## Core

In [None]:
#|export
class AdvantageStep(typing.NamedTuple):
    # Every field is a tensor so a whole step can be cloned / detached / moved
    # between devices as a unit (see the helper methods below).
    state:       torch.FloatTensor=torch.FloatTensor([0])
    action:      torch.FloatTensor=torch.FloatTensor([0])
    next_state:  torch.FloatTensor=torch.FloatTensor([0])
    terminated:  torch.BoolTensor=torch.BoolTensor([1])
    truncated:   torch.BoolTensor=torch.BoolTensor([1])
    # Float default so it agrees with the annotation and with `total_reward`;
    # a Long default mixes dtypes when rewards are stacked together.
    reward:      torch.FloatTensor=torch.FloatTensor([0])
    total_reward:torch.FloatTensor=torch.FloatTensor([0])
    advantage:   torch.FloatTensor=torch.FloatTensor([0])
    env_id:      torch.LongTensor=torch.LongTensor([0])
    proc_id:     torch.LongTensor=torch.LongTensor([0])
    step_n:      torch.LongTensor=torch.LongTensor([0])
    episode_n:   torch.LongTensor=torch.LongTensor([0])
    image:       torch.FloatTensor=torch.FloatTensor([0])

    def clone(self):
        "Returns a new step where every field tensor has been `clone`d."
        return self.__class__(*(v.clone() for v in self))

    def detach(self):
        "Returns a new step where every field tensor has been `detach`ed."
        return self.__class__(*(v.detach() for v in self))

    def device(self,device='cpu'):
        "Returns a new step with every field tensor moved to `device`."
        return self.__class__(*(v.to(device=device) for v in self))

    def to(self,*args,**kwargs):
        "Returns a new step with `Tensor.to(*args,**kwargs)` applied to every field."
        return self.__class__(*(v.to(*args,**kwargs) for v in self))

    @classmethod
    def random(cls,seed=None,**flds):
        """Creates a step whose fields are random 1-element tensors.

        Any field passed in `flds` is used as-is. Bool fields are drawn from {0,1},
        all other fields from [0,100). If `seed` is not None, a dedicated
        `torch.Generator` seeded with it is used, so the result is reproducible and
        the global RNG state is left untouched."""
        gen = None if seed is None else torch.Generator().manual_seed(seed)

        def _random_anno(anno):
            # `anno` is the legacy tensor type from the annotation, e.g. torch.LongTensor.
            upper = 2 if anno==torch.BoolTensor else 100
            return anno(1).random_(0,upper,generator=gen)

        # Only generate random values for fields the caller did not supply.
        return cls(*(
            flds[f] if f in flds else _random_anno(cls.__annotations__[f])
            for f in cls._fields
        ))

# Attach per-field documentation: `advantage` gets its own description and the
# fields shared with `SimpleStep` reuse the docstrings of `SimpleStep`'s fields.
add_namedtuple_doc(
AdvantageStep,
"""Represents a single step in an environment similar to `SimpleStep` however has
an addition field called `advantage`.""",
advantage="""Generally characterized as $A(s,a) = Q(s,a) - V(s)$""",
**{f:getattr(SimpleStep,f).__doc__ for f in SimpleStep._fields}
)

In [None]:
show_doc(AdvantageStep)

---

### AdvantageStep

>      AdvantageStep (state:torch.FloatTensor=tensor([0.]),
>                     action:torch.FloatTensor=tensor([0.]),
>                     next_state:torch.FloatTensor=tensor([0.]),
>                     terminated:torch.BoolTensor=tensor([True]),
>                     truncated:torch.BoolTensor=tensor([True]),
>                     reward:torch.FloatTensor=tensor([0]),
>                     total_reward:torch.FloatTensor=tensor([0.]),
>                     advantage:torch.FloatTensor=tensor([0.]),
>                     env_id:torch.LongTensor=tensor([0]),
>                     proc_id:torch.LongTensor=tensor([0]),
>                     step_n:torch.LongTensor=tensor([0]),
>                     episode_n:torch.LongTensor=tensor([0]),
>                     image:torch.FloatTensor=tensor([0.]))

Represents a single step in an environment similar to `SimpleStep` however has
an addition field called `advantage`.

Parameters:

 - **state**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Both the initial state of the environment and the previous state.
 - **action**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`The action that was taken to transition from `state` to `next_state`
 - **next_state**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Both the next state, and the last state in the environment
 - **terminated**:`<class 'torch.BoolTensor'>`  = `tensor([True])`Represents an ending condition for an environment such as reaching a goal or 'living long enough' as 
                    described by the MDP.
                    Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155
 - **truncated**:`<class 'torch.BoolTensor'>`  = `tensor([True])`Represents an ending condition for an environment that can be seen as an out of bounds condition either
                   literally going out of bounds, breaking rules, or exceeding the timelimit allowed by the MDP.
                   Good reference is: https://github.com/openai/gym/blob/39b8661cb09f19cb8c8d2f59b57417517de89cb0/gym/core.py#L151-L155'
 - **reward**:`<class 'torch.FloatTensor'>`  = `tensor([0])`The single reward for this step.
 - **total_reward**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`The total accumulated reward for this episode up to this step.
 - **advantage**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Generally characterized as $A(s,a) = Q(s,a) - V(s)$
 - **env_id**:`<class 'torch.LongTensor'>`  = `tensor([0])`The environment this step came from (useful for debugging)
 - **proc_id**:`<class 'torch.LongTensor'>`  = `tensor([0])`The process this step came from (useful for debugging)
 - **step_n**:`<class 'torch.LongTensor'>`  = `tensor([0])`The step number in a given episode.
 - **episode_n**:`<class 'torch.LongTensor'>`  = `tensor([0])`The episode this environment is currently running through.
 - **image**:`<class 'torch.FloatTensor'>`  = `tensor([0.])`Intended for display and logging only. If the intention is to use images for training an
               agent, then use a env wrapper instead.

## Memory
> Online policy-gradient models learn from short-term trajectory samples instead of an
> experience replay (ER) / i.i.d. memory.

In [None]:
#|export
@torch.jit.script
def discounted_cumsum_(t:torch.Tensor,gamma:float,reverse:bool=False):
    """In-place discounted cumulative sum over the first dimension of `t`.

    Forward:  t[i] = t[i] + t[i-1] * gamma   for i = 1 .. n-1
    Reverse:  t[i] = t[i] + t[i+1] * gamma   for i = n-2 .. 0
    Nothing is returned; `t` itself is modified."""
    n = t.size(0)
    if not reverse:
        for i in range(1,n):
            t[i] = t[i] + t[i-1] * gamma
    else:
        # Walk backwards from the second-to-last row; the last row has no successor.
        for k in range(n-1):
            i = n-2-k
            t[i] = t[i] + t[i+1] * gamma

In [None]:
#|hide
# https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964/2
# with torch.autograd.profiler.profile(use_cuda=True) as prof:
#     discounted_cumsum(torch.ones(500),0.99)
# print(prof)

In [None]:
#|export
class AdvantageBuffer(dp.iter.IterDataPipe):
    debug=False
    def __init__(self,
            # A datapipe that produces `StepType`s.
            source_datapipe:DataPipe,
            # A model that takes in a `state` and outputs a single value 
            # representing $V$, whereas $Q$ is $V + reward$
            critic:nn.Module,
            # Will accumulate up to `bs` steps per environment, or until the
            # episode has terminated / been truncated, before yielding them.
            bs=1000,
            # The discount factor, otherwise known as $\gamma$, is defined in 
            # (Schulman et al., 2016) as '... $\gamma$ introduces bias into
            # the policy gradient estimate...'.
            discount:float=0.99,
            # NOTE: despite its name, this is GAE's $\lambda$ (not $\gamma$).
            # $\lambda$ is unique to GAE and controls how much the estimate relies
            # on the critic; (Schulman et al., 2016) note '... $\lambda$ < 1
            # introduces bias only when the value function is inaccurate....'.
            gamma:float=0.99
        ):
        self.source_datapipe = source_datapipe
        self.bs = bs
        self.critic = critic
        self.device = None
        self.discount = discount
        self.gamma = gamma
        # Maps env_id -> steps from that environment that have not been yielded yet.
        self.env_advantage_buffer:Dict[int,list] = {}

    def to(self,*args,**kwargs):
        self.device = kwargs.get('device',None)

    def __repr__(self):
        return str({k:v if k!='env_advantage_buffer' else f'{len(self)} elements' 
                    for k,v in self.__dict__.items()})

    def __len__(self):
        # Total number of steps currently buffered across all environments.
        return sum(len(steps) for steps in self.env_advantage_buffer.values())

    def update_advantage_buffer(self,step:StepType) -> int:
        if self.debug: 
            print('Adding to advantage buffer: ',step)
        env_id = int(step.env_id.detach().cpu())
        self.env_advantage_buffer.setdefault(env_id,[]).append(step)
        return env_id
        
    def zip_steps(
            self,
            steps:List[StepType]
        ) -> Tuple[torch.FloatTensor,torch.FloatTensor,torch.BoolTensor]:
        step_subset = [(o.reward,o.state,o.truncated or o.terminated) for o in steps]
        zipped_fields = zip(*step_subset)
        return L(zipped_fields).map(torch.vstack)

    def delta_calc(self,reward,v,v_next,done):
        # TD residual r_t + discount * V(s_{t+1}) - V(s_t). The bootstrap term is
        # dropped on steps flagged `done`.
        # NOTE(review): `done` includes truncation; arguably truncated steps should
        # still bootstrap V(s_{t+1}) - confirm the intended semantics.
        return reward + self.discount * v_next * torch.logical_not(done) - v

    def __iter__(self) -> AdvantageStep:
        self.env_advantage_buffer:Dict[int,list] = {}
        for step in self.source_datapipe:
            env_id = self.update_advantage_buffer(step)
            done = step.truncated or step.terminated
            steps = self.env_advantage_buffer[env_id]
            if not (done or len(steps)>=self.bs):
                continue
            # Start a fresh buffer for this env so these steps are never yielded twice.
            self.env_advantage_buffer[env_id] = []
            rewards,states,dones = self.zip_steps(steps)
            # Advantages are targets, so the critic is evaluated without building a graph.
            # We vstack the final next_state so we have a complete picture
            # of the state transitions and matching reward/done shapes.
            with torch.no_grad():
                values = self.critic(torch.vstack((states,steps[-1].next_state)))
            delta = self.delta_calc(rewards,values[:-1],values[1:],dones)
            # GAE: sum_l (discount * lambda)^l * delta_{t+l}, computed in place.
            discounted_cumsum_(delta,self.discount*self.gamma,reverse=True)

            for _step,gae_advantage in zip(steps,delta):
                fields = {f:getattr(_step,f) for f in _step._fields}
                # Overwrite rather than duplicate if the source step already has one.
                fields['advantage'] = gae_advantage
                yield AdvantageStep(**fields)
        # Steps still buffered when the source is exhausted (unfinished episodes
        # shorter than `bs`) are not yielded.

    @classmethod
    def insert_dp(cls,critic,old_dp=GymStepper,**kwargs) -> Callable[[DataPipe],DataPipe]:
        def _insert_dp(pipe):
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            # `kwargs` (e.g. bs, discount, gamma) are forwarded to the new AdvantageBuffer.
            v = replace_dp(graph,target,cls(target,critic=critic,**kwargs))
            return list(v.values())[0][0]
        return _insert_dp

add_docs(
AdvantageBuffer,
r"""Collects an entire episode (or up to `bs` steps of it), calculates the advantage
for each step, then yields those `AdvantageStep`s.

This is described in the original paper `(Schulman et al., 2016) High-Dimensional 
Continuous Control Using Generalized Advantage Estimation`.

This algorithm is based on the concept of advantage:

$A_{\pi}(s,a) = Q_{\pi}(s,a) - V_{\pi}(s)$

Where (Schulman et al., 2016) pg 5 calculates it as:

$\hat{A}_{t}^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty}(\gamma\lambda)^l\delta_{t+l}^V$

Where (Schulman et al., 2016) pg 4 defines $\delta$ as:

$\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_{t})$

Here $\gamma$ is `discount` and $\lambda$ is `gamma`.
""",
to="Records the `device` keyword argument on `self.device`. The critic is not moved.",
update_advantage_buffer="Adds `step` to `env_advantage_buffer` based on the environment id.",
zip_steps="""Given `steps`, strip out the `Tuple[reward,state,truncated or terminated]` fields,
and `torch.vstack` them.""",
delta_calc=r"""Calculates $\delta_t^V = r_t + \gamma V(s_{t+1}) - V(s_{t})$ which 
is the advantage difference between state transitions. $\gamma$ is `discount`, and the
$V(s_{t+1})$ term is dropped for steps where `done` is set.""",
insert_dp="""Returns a function that replaces the `old_dp` datapipe found by `find_dp` in a
pipeline with an `AdvantageBuffer` wrapping it (extra `kwargs` go to `AdvantageBuffer`)."""
)

In [None]:
from fastrl.layers import Critic

In [None]:
# Wrap the Pendulum pipeline's stepper in an `AdvantageBuffer`, then check that
# every yielded step is an `AdvantageStep` with a non-zero advantage.
critic = Critic(3,0)
insert_advantage = AdvantageBuffer.insert_dp(critic=critic)
make_pipe = GymTransformBlock(agent=None,seed=0,dp_augmentation_fns=[insert_advantage])
gym_pipe = make_pipe(['Pendulum-v1'])

for step in (s for chunk in gym_pipe.header(5) for s in chunk):
    test_eq(type(step),AdvantageStep)
    assert step.advantage!=0

## Actor

In [None]:
#|export
class OptionalClampLinear(Module):
    def __init__(self,num_inputs,state_dims,fix_variance:bool=False,
                 clip_min=0.3,clip_max=10.0):
        """Linear layer or constant block used for std.

        `num_inputs` is the input feature size and `state_dims` the number of std values
        produced per row. With `fix_variance` every std is 1.0; otherwise it is
        `softplus(fc(x))` clamped to [`clip_min`, `clip_max`]."""
        store_attr()
        if not self.fix_variance: 
            self.fc=nn.Linear(self.num_inputs,self.state_dims)
    
    def forward(self,x):
        if self.fix_variance:
            # Built from `x` so the constant lives on x's device/dtype
            # (`torch.full` would always allocate on the default CPU device).
            return x.new_full((x.shape[0],self.state_dims),1.0)
        return torch.clamp(nn.functional.softplus(self.fc(x)),self.clip_min,self.clip_max)

# TODO(josiahls): This is probably a highly generic SimpleGMM tbh. Once we know this
# works, we should just rename this to SimpleGMM
class Actor(Module):
    def __init__(            
            self,
            state_sz:int,   # The input dim of the state / flattened conv output
            action_sz:int,  # The output dim of the actions
            hidden:int=400, # Number of neurons connected between the 2 input/output layers
            fix_variance:bool=False, # If True, the std is a constant 1.0 instead of learned
            clip_min:float=0.3,  # Lower clamp applied to the learned std
            clip_max:float=10.0  # Upper clamp applied to the learned std
        ):
        "Single-component GMM parameterized by a fully connected layer with optional std layer."
        store_attr()
        self.mu = nn.Sequential(
            nn.Linear(state_sz, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, action_sz),
            nn.Tanh(),
        )
        # NOTE: unless `fix_variance` is set, the std is a function of the state, whereas
        # (Schulman et al., 2015) describe a state-independent diagonal covariance.
        self.std = OptionalClampLinear(state_sz,action_sz,fix_variance,clip_min,clip_max)
        
    def forward(self,x): return Independent(Normal(self.mu(x),self.std(x)),1)


add_docs(
Actor,
"""Produces continuous outputs from mean of a Gaussian distribution.""",
forward="""Returns a diagonal Gaussian over actions, `Normal(mu(x), std(x))` wrapped in
`Independent(..., 1)` so that `log_prob` sums over the action dimension."""
)

The `Actor` is developed from the description found in `(Schulman et al., 2015)`: 

    ...we used a Gaussian distribution, where the covariance matrix was diagonal 
    and independent of the state. A neural network with several fully-connected (dense) 
    layers maps from the input features to the mean of a Gaussian distribution.

In [None]:
# Map a 4-d state to a 2-d action distribution, then inspect its mean, std and a log-probability.
actor = Actor(state_sz=4,action_sz=2)
dist = actor(torch.randn(1,4))
(dist.mean, dist.stddev, dist.log_prob(torch.randn(1,2)))

(tensor([[-0.1615, -0.0329]], grad_fn=<TanhBackward0>),
 tensor([[1.0072, 0.5305]], grad_fn=<SqrtBackward0>),
 tensor([-1.6457], grad_fn=<SumBackward1>))

## Learning

![](../../images/(Schulman%20et%20al.%2C%202017)%20%5BTRPO%5D%20Trust%20Region%20Policy%20Optimization%20Algorithm%201.png)

We start by finding the search direction with the conjugate gradient method.

In [None]:
# Small symmetric positive-definite example system (Shewchuk, 1994) for `conjugate_gradients`.
A = torch.tensor([[3., 2.],
                  [2., 6.]])
b = torch.tensor([[2.], [-8.]])

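Following [Shewchuk, 1994](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf), we solve $Ax = b$ for $x$.
In TRPO, $A$ is the Fisher information matrix (the Hessian of the KL divergence between the old and new policies), which is only ever applied to a vector, and $b$ is the gradient of the loss.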

The function below closely follows the `B2` (Conjugate Gradients) pseudocode on pg 50 of Shewchuk, 1994.

In [None]:
#|export
def conjugate_gradients(
    # A function that takes the direction `d` and returns `A @ d`.
    # The simplest example of this found would be:
    # `lambda d:A@d`
    Ad_f:Callable[[torch.Tensor],torch.Tensor],  
    # The right-hand side of `Ax = b`; in TRPO's case the (gradient of the) loss.
    b:torch.Tensor, 
    # Maximum number of iterations; fewer are used once the residual is below `residual_tol`.
    nsteps:int, 
    # Stop once the squared residual norm r.r is less than this.
    # Note that (Shewchuk, 1994) suggest a relative test: E^2 * rdotr_0
    residual_tol=1e-10, 
    # Device for the returned `x`. `None` (the default) uses `b`'s device.
    device=None
):
    """Approximately solves `A x = b` with the conjugate gradient method
    (Shewchuk, 1994, `B2` pg 50), accessing `A` only through `Ad_f`.

    `b` may be 1-D or a column vector; the returned `x` has `b`'s shape and dtype."""
    def _dot(u,v):
        # Inner product that works for both 1-D and (n,1) tensors.
        return (u*v).sum()

    # The solution estimate, starting at 0.
    x = torch.zeros_like(b,device=device)
    # Since x starts at 0, the initial residual b - A x is simply b.
    r = b.clone()
    # The first search direction is the first residual.
    d = b.clone()
    rdotr = _dot(r,r) # \sigma_{new} pg50
    for _ in range(nsteps):
        # Same as \sigma_{new} < E^2\sigma. Checking before stepping also covers an
        # already-solved system (e.g. b == 0), where alpha would be 0/0 = nan.
        if rdotr < residual_tol:
            break
        _Ad = Ad_f(d) # _Ad is also considered `q`
        # Step size along the direction `d`.
        alpha = rdotr / _dot(d,_Ad)
        x += alpha * d
        # Shewchuk (1994) pg 49 recomputes the exact residual `b - A x` every 50
        # iterations (or about every sqrt(n) for large n) to remove accumulated
        # floating point error. Only the fast recursive update is used here.
        r -= alpha * _Ad
        new_rdotr = _dot(r,r)
        beta = new_rdotr / rdotr
        d = r + beta * d
        rdotr = new_rdotr
    return x

add_docs(
conjugate_gradients,
"""Conjugate Gradients builds on the idea of Conjugate Directions.

As noted in:
[Shewchuk, 1994](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf)

We want "everytime we take a step, we got it right the first time" pg 21. 

In other words, we want to solve `Ax = b` (in TRPO, `A` is only available through
the matrix-vector product `Ad_f`, and `b` is the gradient of the loss) without ever
undoing progress made along a previous search direction.

`Conjugation` is what makes this possible: the (at most `nsteps`) search directions
are chosen to be mutually conjugate, i.e. A-orthogonal (d_i^T A d_j = 0 for i != j),
so minimizing along one direction never spoils the others.

Put differently: at each step, which direction is most useful, and how far along it
must `x` move, so that `Ax - b` is driven to 0?
"""
)

In [None]:
# Measure the residual from x0 = [50, 50]: CG then returns the correction x* - x0.
x0 = torch.tensor([[50.],[50.]])
conjugate_gradients(Ad_f=lambda d: A@d, b=b - A@x0, nsteps=10)

tensor([[-48.],
        [-52.]])

In [None]:
#|export
def backtrack_line_search(
    x:torch.Tensor,
    r:torch.Tensor,
    error_f:Callable,
    expected_improvement_rate:torch.Tensor,
    accaptance_tolerance:float,
    n_max_backtracks:int=10
):
    """Backtracking line search from `x` along the direction `r`.

    Tries step sizes `alpha = 1, .5, .25, ...` (`n_max_backtracks` of them) and accepts
    the first `x + alpha * r` that lowers `error_f` and whose actual / expected
    improvement ratio exceeds `accaptance_tolerance`. The expected improvement for a
    step is `expected_improvement_rate * alpha`.

    Returns `(True, x_new)` on success, otherwise `(False, x)` with `x` unchanged.
    """
    e = error_f(x)
    for n_back in range(n_max_backtracks):
        # Halve the step each time. Plain Python floats keep the step size
        # independent of any tensor dtype/device.
        alpha = .5**n_back
        x_new = x + alpha * r 
        e_new = error_f(x_new)
        improvement = e - e_new
        expected_improvement = expected_improvement_rate * alpha 

        ratio = improvement / expected_improvement
        # float() accepts both 1-element tensors and plain numbers.
        if float(ratio) > accaptance_tolerance and float(improvement) > 0:
            return True, x_new
    return False, x

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    # Regenerate the library modules from the notebooks' `#|export` cells.
    nbdev_export()