# Trust Region Policy Optimization (TRPO)

In [None]:
!apt-get install -y xvfb

!pip install gym==0.23.1 \
    pytorch-lightning==1.6 \
    pyvirtualdisplay

!pip install git+https://github.com/google/brax.git@main

#### Setup virtual display

In [None]:
from pyvirtualdisplay import Display
# Start an invisible 1400x900 virtual display (pyvirtualdisplay; the xvfb
# package is installed in the first cell). Presumably needed so rendering can
# run on a headless notebook host — confirm.
Display(visible=False, size=(1400, 900)).start()

#### Import the necessary code libraries

In [None]:
import copy
import torch
import random
import gym
import matplotlib
import functools
import itertools
import math

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW, Optimizer

from torch.distributions import Normal, kl_divergence

from pytorch_lightning import LightningModule, Trainer

import brax
from brax import envs
from brax.envs import to_torch
from brax.io import html

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

# Allocate a dummy tensor on the selected device.
# NOTE(review): this looks like a workaround to initialise the CUDA context
# before JAX/Brax grabs GPU memory — confirm. Using `device` instead of a
# hard-coded 'cuda' keeps the notebook runnable on CPU-only machines, where
# torch.ones(..., device='cuda') raises.
v = torch.ones(1, device=device)

In [None]:
@torch.no_grad()
def create_video(env, episode_length, policy=None):
  """Simulate ``episode_length`` steps and render them as an HTML animation.

  The rollout does not stop when the environment reports ``done``.

  Args:
    env: environment whose ``unwrapped`` object exposes ``_state.qp`` (the
      physics state recorded after every step) and ``_env.sys`` (passed to
      ``brax.io.html.render``).
    episode_length: number of steps to simulate.
    policy: optional callable mapping an observation to ``(loc, scale)``;
      actions are ``tanh`` of a Normal(loc, scale) sample. When ``None``,
      actions come from ``env.action_space.sample()``.

  Returns:
    An ``IPython.display.HTML`` object containing the rendered trajectory.
  """
  qp_array = []
  state = env.reset()
  for _ in range(episode_length):
    # Compare against None explicitly: a policy module's truthiness is not
    # a reliable "was a policy given" test (e.g. containers define __len__).
    if policy is not None:
      loc, scale = policy(state)
      action = torch.tanh(torch.normal(loc, scale))
    else:
      action = env.action_space.sample()
    state, _, _, _ = env.step(action)
    qp_array.append(env.unwrapped._state.qp)
  return HTML(html.render(env.unwrapped._env.sys, qp_array))


@torch.no_grad()
def test_agent(env, episode_length, policy, episodes=10):

  ep_returns = []
  for ep in range(episodes):
    state = env.reset()
    done = False
    ep_ret = 0.0

    while not done:
      loc, scale = policy(state)
      sample = torch.normal(loc, scale)
      action = torch.tanh(sample)   
      state, reward, done, info = env.step(action)
      ep_ret += reward.item()

    ep_returns.append(ep_ret)

  return sum(ep_returns) / episodes  

#### Create the policy

#### Create the value network

#### Create the environment

In [None]:
class RunningMeanStd:
    """Running per-feature mean and variance over a stream of batches.

    Batches are merged with the parallel algorithm of Chan et al.:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    ``mean`` and ``var`` are float32 tensors on the module-level ``device``.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = torch.zeros(shape, dtype=torch.float32).to(device)
        self.var = torch.ones(shape, dtype=torch.float32).to(device)
        # Tiny pseudo-count attached to the initial (mean 0, var 1) state.
        self.count = epsilon

    def update(self, x):
        """Fold the batch ``x`` (leading dimension = samples) into the statistics.

        A single unbatched sample, i.e. a tensor with the same number of
        dimensions as ``mean``, is treated as a batch of one rather than being
        averaged across its features.
        """
        if x.dim() == self.mean.dim():
            x = x.unsqueeze(0)
        batch_mean = torch.mean(x, dim=0)
        # Population variance (no Bessel correction): the merge step rebuilds
        # the batch's sum of squared deviations as batch_var * batch_count, which
        # only holds for the biased estimator. The unbiased one would also be
        # NaN for a batch of a single sample.
        batch_var = torch.var(x, dim=0, unbiased=False)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into the running statistics."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )


def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Combine running moments with the moments of a new batch.

    This is the pairwise merge step of Chan et al.'s parallel variance
    algorithm. Both ``var`` and ``batch_var`` are interpreted as
    sum-of-squared-deviations divided by their respective counts.

    Args:
        mean, var, count: running mean, variance and sample count.
        batch_mean, batch_var, batch_count: the same statistics for the batch.

    Returns:
        Tuple ``(new_mean, new_var, new_count)``.
    """
    shift = batch_mean - mean
    total = count + batch_count

    merged_mean = mean + shift * batch_count / total
    # Correction for the two groups having different means.
    cross_term = torch.square(shift) * count * batch_count / total
    merged_m2 = var * count + batch_var * batch_count + cross_term

    return merged_mean, merged_m2 / total, total


class NormalizeObservation(gym.core.Wrapper):
    """Gym wrapper that standardises observations with running statistics.

    Every observation returned by ``reset`` or ``step`` first updates a
    ``RunningMeanStd`` sized to the last observation dimension. It is then
    shifted by the running mean and divided by ``sqrt(var + epsilon)``.
    """

    def __init__(self, env, epsilon=1e-8):
        super().__init__(env)
        # Fall back to 1 when the wrapped env has no ``num_envs`` attribute.
        self.num_envs = getattr(env, "num_envs", 1)
        feature_dim = self.observation_space.shape[-1]
        self.obs_rms = RunningMeanStd(shape=feature_dim)
        self.epsilon = epsilon

    def step(self, action):
        next_obs, rewards, dones, infos = self.env.step(action)
        return self.normalize(next_obs), rewards, dones, infos

    def reset(self, **kwargs):
        # ``return_info`` is forwarded untouched; it only decides whether the
        # wrapped env's result is an ``(obs, info)`` pair.
        with_info = kwargs.get("return_info", False)
        outcome = self.env.reset(**kwargs)
        if not with_info:
            return self.normalize(outcome)
        raw_obs, info = outcome
        return self.normalize(raw_obs), info

    def normalize(self, obs):
        """Update the running statistics with ``obs`` and return it standardised."""
        self.obs_rms.update(obs)
        scale = torch.sqrt(self.obs_rms.var + self.epsilon)
        return (obs - self.obs_rms.mean) / scale


In [None]:
# Register Brax's "ant" task with gym under the id "brax-ant-v0". Keyword
# arguments given to gym.make (e.g. episode_length) are passed on to this
# partial, i.e. to envs.create_gym_env(env_name='ant', ...).
entry_point = functools.partial(envs.create_gym_env, env_name='ant')
gym.register('brax-ant-v0', entry_point=entry_point)

In [None]:
# Build the ant environment with episode_length=1000.
env = gym.make("brax-ant-v0", episode_length=1000)
# Wrap it for PyTorch on `device` (presumably: torch actions in, torch
# observations/rewards out — per to_torch.JaxToTorchWrapper; confirm).
env = to_torch.JaxToTorchWrapper(env, device=device)
# Preview 1000 steps driven by env.action_space.sample() (no policy given).
create_video(env, 1000)

#### TRPO optimizer

In [None]:
# Adapted from: https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py


# Copyright (c) 2019 Reinforcement Learning Working Group

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


def unflatten_tensors(flattened, tensor_shapes):
  """Split a flat 1-D tensor into consecutive pieces with the given shapes.

  This is the inverse of ``torch.cat([t.reshape(-1) for t in tensors])``.
  The work stays on ``flattened``'s device (no CPU/NumPy round trip), which
  matters because it runs on every Hessian-vector product.

  Args:
    flattened: 1-D tensor whose length equals the total number of elements
      described by ``tensor_shapes``.
    tensor_shapes: sequence of shapes (``torch.Size`` or tuples of ints).

  Returns:
    A list of tensors, one per shape, with ``flattened``'s device and dtype.
    The pieces may be views of ``flattened``.
  """
  shapes = [torch.Size(shape) for shape in tensor_shapes]
  pieces = torch.split(flattened, [shape.numel() for shape in shapes])
  return [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]


def _build_hessian_vector_product(func, params, reg_coeff=1e-5):
    """Build a closure computing ``(H + reg_coeff * I) @ v``.

    ``H`` is the Hessian of ``func()`` with respect to ``params``. The
    gradient of ``func`` is computed once with ``create_graph=True``. Each call
    then differentiates the scalar ``<grad f, v>`` a second time, which yields
    ``H @ v`` without materialising ``H``.

    Args:
        func: zero-argument callable returning a scalar tensor that depends on
            ``params`` (the optimizer passes its constraint closure here).
        params: sequence of tensors to differentiate with respect to.
        reg_coeff: damping added to the product for numerical stability.

    Returns:
        A function mapping a flat vector with one entry per parameter element
        to a flat tensor of the same length.
    """
    # Scalar (0-d) parameters are treated as shape [1] when unflattening.
    param_shapes = [p.shape or torch.Size([1]) for p in params]
    f = func()
    f_grads = torch.autograd.grad(f, params, create_graph=True)

    def _eval(vector):
        unflatten_vector = unflatten_tensors(vector, param_shapes)

        assert len(f_grads) == len(unflatten_vector)
        grad_vector_product = torch.sum(
            torch.stack(
                [torch.sum(g * x) for g, x in zip(f_grads, unflatten_vector)]))

        # allow_unused=True makes autograd return None (instead of raising)
        # for parameters the gradient does not depend on; those entries are
        # zero-filled below so the output keeps one slot per parameter.
        hvp = list(
            torch.autograd.grad(grad_vector_product, params,
                                retain_graph=True, allow_unused=True))
        for i, (hx, p) in enumerate(zip(hvp, params)):
            if hx is None:
                hvp[i] = torch.zeros_like(p)

        flat_output = torch.cat([h.reshape(-1) for h in hvp])
        return flat_output + reg_coeff * vector

    return _eval


def _conjugate_gradient(f_Ax, b, cg_iters, residual_tol=1e-10):
    p = b.clone()
    r = b.clone()
    x = torch.zeros_like(b)
    rdotr = torch.dot(r, r)

    for _ in range(cg_iters):
        z = f_Ax(p)
        v = rdotr / torch.dot(p, z)
        x += v * p
        r -= v * z
        newrdotr = torch.dot(r, r)
        mu = newrdotr / rdotr
        p = r + mu * p

        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x


class ConjugateGradientOptimizer(Optimizer):
    """Constrained natural-gradient optimizer (the TRPO update).

    Each ``step`` works as follows:

    1. Solve ``H s = g`` with conjugate gradient. ``g`` is the loss gradient
       already stored in ``p.grad`` and ``H`` is the Hessian of the constraint.
    2. Scale ``s`` so that ``0.5 * d^T H d == max_constraint_value``.
    3. Backtrack along ``d`` until the loss decreases and the constraint holds.

    Args:
        params: iterable of parameters or parameter groups.
        max_constraint_value: upper bound on the constraint value.
        cg_iters: number of conjugate-gradient iterations.
        max_backtracks: maximum number of step sizes tried in the line search.
        backtrack_ratio: factor by which the step shrinks on each backtrack.
        hvp_reg_coeff: damping added to the Hessian-vector product.
        accept_violation: if False (default), a step for which no tried step
            size satisfies both line-search conditions is rejected and the
            parameters are restored. If True, the last (smallest) tried step
            is kept.
    """

    def __init__(self, params, max_constraint_value, cg_iters=10, max_backtracks=15,
                 backtrack_ratio=0.8, hvp_reg_coeff=1e-5, accept_violation=False):
        super().__init__(params, {})
        self._max_constraint_value = max_constraint_value
        self._cg_iters = cg_iters
        self._max_backtracks = max_backtracks
        self._backtrack_ratio = backtrack_ratio
        self._hvp_reg_coeff = hvp_reg_coeff
        self._accept_violation = accept_violation

    def step(self, closure):
        """Perform one constrained update.

        Args:
            closure: zero-argument callable returning ``(f_loss, f_constraint)``.
                These are two zero-argument callables that re-evaluate the
                scalar loss and constraint at the current parameters. The loss
                gradients must already be populated in ``p.grad``.
        """
        f_loss, f_constraint = closure()
        params = []
        grads = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    params.append(p)
                    grads.append(p.grad.reshape(-1))

        # Nothing to optimise (torch.cat would raise on an empty list).
        if not params:
            return

        flat_loss_grads = torch.cat(grads)
        f_Ax = _build_hessian_vector_product(f_constraint, params, self._hvp_reg_coeff)
        step_dir = _conjugate_gradient(f_Ax, flat_loss_grads, self._cg_iters)

        # A failed solve must not poison the parameters.
        step_dir[torch.isnan(step_dir)] = 0.

        # Choose step_size so the quadratic model of the constraint,
        # 0.5 * d^T H d with d = step_size * step_dir, equals the bound.
        step_size = torch.sqrt(
            2.0 * self._max_constraint_value
            * (1. / (torch.dot(step_dir, f_Ax(step_dir)) + 1e-8)))
        if torch.isnan(step_size):
            step_size = 1.

        descent_step = step_size * step_dir
        self._backtracking_line_search(params, descent_step, f_loss, f_constraint)

    def _backtracking_line_search(self, params, descent_step, f_loss, f_constraint):
        """Try ``params - ratio * descent_step`` for geometrically shrinking ratios.

        The first candidate whose loss is below the starting loss and whose
        constraint is at most ``max_constraint_value`` is accepted. If none is
        accepted and ``accept_violation`` is False, the original parameters
        are restored.
        """
        prev_params = [p.detach().clone() for p in params]
        ratio_list = self._backtrack_ratio**np.arange(self._max_backtracks)
        loss_before = f_loss()

        param_shapes = [p.shape or torch.Size([1]) for p in params]
        descent_step = unflatten_tensors(descent_step, param_shapes)
        assert len(descent_step) == len(params)

        accepted = False
        for ratio in ratio_list:
            for step, prev_param, param in zip(descent_step, prev_params, params):
                # reshape keeps 0-d parameters 0-d (their step was unflattened as [1]).
                param.data = prev_param - float(ratio) * step.reshape(prev_param.shape)

            loss = f_loss()
            constraint_val = f_constraint()
            # A NaN loss or constraint compares False and is never accepted.
            if loss < loss_before and constraint_val <= self._max_constraint_value:
                accepted = True
                break

        if not accepted and not self._accept_violation:
            # Reject the step: roll every parameter back to its pre-step value.
            for prev_param, param in zip(prev_params, params):
                param.data = prev_param

#### Create the dataset

#### Create TRPO with generalized advantage estimation (GAE)

#### Purge logs and run the visualization tool (Tensorboard)

In [None]:
# Start tensorboard.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

#### Train the policy