# Outlook

In this colab, we document in detail a version of the A2C algorithm implemented with SaLinA, so as to better understand its inner mechanisms.

### Installation

The SaLinA library is [here](https://github.com/facebookresearch/salina).

The cell below installs a pinned version of gym, the SaLinA GitHub repository and pygame with `%pip`. You may need to restart the kernel afterwards to use the updated packages.

In [1]:
import functools
import time

%pip install gym==0.21.0
%pip install git+https://github.com/facebookresearch/salina.git@main
%pip install pygame




Note: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/facebookresearch/salina.git@main
  Cloning https://github.com/facebookresearch/salina.git (to revision main) to /tmp/pip-req-build-v93tlwek
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/salina.git /tmp/pip-req-build-v93tlwek
  Resolved https://github.com/facebookresearch/salina.git to commit 748b11563e5bea2c4a50d1043b6cfdf238d49664
  Preparing metadata (setup.py) ... done


In [2]:
%pip install git+https://github.com/Anidwyd/pandroide-svpg.git@main

Collecting git+https://github.com/Anidwyd/pandroide-svpg.git@main
  Cloning https://github.com/Anidwyd/pandroide-svpg.git (to revision main) to /tmp/pip-req-build-vexyb648
  Running command git clone --filter=blob:none --quiet https://github.com/Anidwyd/pandroide-svpg.git /tmp/pip-req-build-vexyb648
  Resolved https://github.com/Anidwyd/pandroide-svpg.git to commit 99cd1106bf51f00764e5b6399984e01032c79395
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [8 lines of output]
      Traceback (most recent call last):
        File "<string>", line 36, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "/tmp/pip-req-build-vexyb648/setup.py", line 17, in <module>
          long_description=read('README')
        Fil

## Imports

Below, we import standard Python packages, PyTorch packages, Hydra and gym environments.

According to [the documentation](https://hydra.cc/docs/intro/) "Hydra is an open-source Python framework (from Facebook Research, editor's note) that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads."

Hydra's configuration machinery is what lets us simply define the `run_a2c(cfg)` function and then, at the bottom of this colab, build a configuration from a long `params = {...}` dictionary and run the code with these parameters, without an explicit main.

More precisely, the code is run by calling

`from omegaconf import DictConfig, OmegaConf`

`config=OmegaConf.create(params)`

`run_a2c(config)`

at the very bottom of the colab, after starting tensorboard.


In fact, Hydra can do many more things for you, such as launching many jobs on a cluster each with its own configuration (agent, environment, CPU or GPU, etc.). It also provides a mechanism to instantiate classes and functions as parameters, which makes your program more flexible. 

[OpenAI gym](https://gym.openai.com/) is a collection of benchmark environments to evaluate RL algorithms.

In [None]:
import copy
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import hydra

import gym
# The TimeLimit wrapper is useful to specify a max number of steps for an RL episode
from gym.wrappers import TimeLimit

### SaLinA imports

As explained in [the white paper](https://arxiv.org/pdf/2110.07910.pdf), everything in SaLinA is an Agent.

This construct is defined in [the salina/agent.py](https://github.com/facebookresearch/salina/blob/main/salina/agent.py) file as the Agent class.

In principle, RL code should use `TAgent`s, that is, agents that take a time index in their `__call__` function. But `TAgent` is an abstract class that adds an abstraction layer for little benefit, so Ludovic Denoyer advises using the `Agent` class directly.

Some of the comments below are just copy-pasted from the paper or from the code.

In [None]:
import salina

# Following Ludovic Denoyer's advice, we use Agent rather than TAgent
# `TAgent` is used as a convention 
# to represent agents that use a time index in their `__call__` function (not mandatory)
from salina import Agent, get_arguments, get_class, instantiate_class

# Agents(agent1, agent2, agent3, ...) executes the given agents one after the other
# TemporalAgent(agent) executes an agent (e.g. a TAgent) over multiple timesteps in the workspace,
# or until a given condition is reached
from salina.agents import Agents, RemoteAgent, TemporalAgent

# GymAgent (resp. AutoResetGymAgent) are agents able to execute a batch of gym environments
# without (resp. with) auto-resetting. These agents produce multiple variables in the workspace:
# 'env/env_obs', 'env/reward', 'env/timestep', 'env/done', 'env/initial_state', 'env/cumulated_reward',
# ... When called at timestep t=0, the environments are automatically reset.
# At timestep t>0, these agents read the 'action' variable in the workspace at time t - 1
from salina.agents.gyma import AutoResetGymAgent, GymAgent

# Not present in the A2C version...
from salina.logger import TFLogger

### Helper function

The function below is used in the following piece of code, near the bottom of this colab:

`# Compute A2C loss`

`action_logp = _index(action_probs, action).log()`

It uses a TxB tensor of action indexes to select, from the TxBxA tensor of action probabilities, the probability of the action taken by the agent at each timestep and in each environment. The result is a TxB tensor, to which the log is applied afterwards.

In [None]:
def _index(tensor_3d, tensor_2d):
    """This function is used to index a 3d tensors using a 2d tensor"""
    x, y, z = tensor_3d.size()
    t = tensor_3d.reshape(x * y, z)
    tt = tensor_2d.reshape(x * y)
    v = t[torch.arange(x * y), tt]
    v = v.reshape(x, y)
    return v
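
To make the indexing concrete, here is a small sketch on a hand-written tensor. It re-declares `_index` so that the snippet is self-contained; the probability values and the comparison with `torch.gather` are illustrative additions, not part of the original colab.

```python
import torch

def _index(tensor_3d, tensor_2d):
    """Select tensor_3d[t, b, tensor_2d[t, b]] for every (t, b)."""
    x, y, z = tensor_3d.size()
    t = tensor_3d.reshape(x * y, z)
    tt = tensor_2d.reshape(x * y)
    v = t[torch.arange(x * y), tt]
    return v.reshape(x, y)

# T=2 timesteps, B=2 environments, A=3 actions (made-up probabilities)
action_probs = torch.tensor(
    [[[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]],
     [[0.6, 0.3, 0.1], [0.25, 0.25, 0.5]]]
)
action = torch.tensor([[2, 0], [1, 2]])  # action taken at each (t, b)

selected = _index(action_probs, action)  # shape (T, B)
# The same selection expressed with gather along the action dimension
gathered = action_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
action_logp = selected.log()
```

Both `selected` and `gathered` contain the probability of the chosen action for each timestep and environment, so `gather` would be a possible alternative to the reshape-based helper.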

## Definition of agents

The [A2C](http://proceedings.mlr.press/v48/mniha16.pdf) algorithm is an actor-critic algorithm. Thus we need an Actor agent and a Critic agent. 
The actor agent is built on an intermediate ProbAgent.

As explained above, in principle all agents should be built on the TAgent class, which itself inherits from the Agent class. In practice, TAgent is still in SaLinA for historical reasons, and just using Agent makes the code easier to understand. A TAgent is just an Agent with a `t` argument in the forward function.

### ProbAgent

A ProbAgent is a neural network with one hidden layer which takes an observation as input and whose output is a probability distribution over actions, obtained by applying a final softmax to the scores of the network.

Note that to get the input observation from the environment we call

`observation = self.get(("env/env_obs", t))`

and that to write the resulting action probabilities into the workspace (the action itself is chosen later by the ActionAgent) we call

`self.set(("action_probs", t), probs)`

In [None]:
class ProbAgent(Agent):
    def __init__(self, observation_size, hidden_size, n_actions):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(observation_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, t, **kwargs):
        observation = self.get(("env/env_obs", t))
        scores = self.model(observation)
        probs = torch.softmax(scores, dim=-1)
        self.set(("action_probs", t), probs)
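
The `get` and `set` calls read and write variables of the shared workspace, indexed by a name and a timestep. The sketch below mimics that idea with a plain dictionary. `ToyWorkspace`, `ToyAgent` and `DoubleObs` are hypothetical stand-ins written for this illustration, not SaLinA classes, and the real workspace stores batched tensors rather than lists.

```python
class ToyWorkspace:
    """Hypothetical stand-in: maps (variable name, timestep) to a value."""
    def __init__(self):
        self.data = {}

class ToyAgent:
    """Hypothetical agent that reads and writes through a workspace."""
    def __call__(self, workspace, t):
        self.workspace = workspace
        self.forward(t)

    def get(self, key):
        return self.workspace.data[key]

    def set(self, key, value):
        self.workspace.data[key] = value

class DoubleObs(ToyAgent):
    def forward(self, t):
        obs = self.get(("env/env_obs", t))               # read what the env wrote
        self.set(("doubled", t), [2 * o for o in obs])   # write a new variable

ws = ToyWorkspace()
ws.data[("env/env_obs", 0)] = [1.0, -0.5]
DoubleObs()(ws, t=0)
print(ws.data[("doubled", 0)])
```

The point of the design is that agents never call each other directly: each one only reads the variables it needs and writes the ones it produces.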

### ActionAgent

The ActionAgent takes action probabilities as input (coming from the ProbAgent) and outputs an action. In the deterministic case it takes the argmax, in the stochastic case it samples from the Categorical distribution.

In [None]:
class ActionAgent(Agent):
    def __init__(self):
        super().__init__()

    def forward(self, t, stochastic, **kwargs):
        probs = self.get(("action_probs", t))
        if stochastic:
            action = torch.distributions.Categorical(probs).sample()
        else:
            action = probs.argmax(1)

        self.set(("action", t), action)
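
The difference between the two branches can be checked on a small batch. This is a minimal sketch with made-up probabilities for two environments and three actions; the seed and the number of draws are arbitrary choices.

```python
import torch

torch.manual_seed(0)
# B=2 environments, A=3 actions (made-up probabilities)
probs = torch.tensor([[0.1, 0.8, 0.1], [0.3, 0.3, 0.4]])

greedy = probs.argmax(dim=1)  # deterministic: most probable action per row
sampled = torch.distributions.Categorical(probs).sample()  # stochastic, shape (B,)

# Over many draws, empirical frequencies approach the probabilities
draws = torch.distributions.Categorical(probs).sample((10000,))  # shape (10000, B)
freq_action1_env0 = (draws[:, 0] == 1).float().mean()
```

The deterministic branch always returns the same action for a given observation, which is typically what one wants at evaluation time, while the stochastic branch keeps exploring during training.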

### CriticAgent

A CriticAgent is a one hidden layer neural network which takes an observation as input and whose output is the value of this observation. It thus implements a $V(s)$ function. It would be straightforward to define another CriticAgent (call it a CriticQAgent by contrast to a CriticVAgent) that would take an observation and an action as input.

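The `squeeze(-1)` removes the trailing dimension of size 1: the network outputs a tensor of shape B x 1 (one value per environment), and squeezing turns it into a tensor of shape B, so that once stacked over time the `critic` variable is T x B, like `env/reward` and `env/done`.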

In [None]:
class CriticAgent(Agent):
    def __init__(self, observation_size, hidden_size, n_actions):
        super().__init__()
        self.critic_model = nn.Sequential(
            nn.Linear(observation_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, t, **kwargs):
        observation = self.get(("env/env_obs", t))
        critic = self.critic_model(observation).squeeze(-1)
        self.set(("critic", t), critic)
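
A quick shape check illustrates what `squeeze(-1)` changes. The sizes below (3 environments, observations of size 4, hidden size 8) are arbitrary values chosen for the illustration.

```python
import torch
import torch.nn as nn

critic_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
observation = torch.zeros(3, 4)  # B=3 environments, observation size 4

raw = critic_model(observation)  # shape (3, 1)
value = raw.squeeze(-1)          # shape (3,)
reward = torch.ones(3)           # per-environment reward, shape (3,)

# Without the squeeze, broadcasting silently produces a (3, 3) tensor
broadcast = reward + raw
# With the squeeze, values and rewards line up element by element
aligned = reward + value
```

Since the mismatch does not raise an error, forgetting the squeeze would give a loss of the wrong shape rather than a crash.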

## Create the environment agent

### Using a gym environment

The function below is used in the params section at the bottom of the colab with `"env":{
      "classname": "__main__.make_env",
      "env_name": "CartPole-v1",
      "max_episode_steps": 100,
    },`

Instantiating from a function in this way is useful if, for instance, you define a new environment: you just change the `classname`, put the arguments of the constructor directly in the parameters, and everything works. This may not feel natural at first sight, but once you start using it, you will never go back :)

In [None]:
def make_env(env_name, max_episode_steps):
    return TimeLimit(gym.make(env_name), max_episode_steps=max_episode_steps)

The `instantiate_class`, `get_class` and `get_arguments` functions are available in the [main/salina/__init__.py file](https://github.com/facebookresearch/salina/blob/main/salina/__init__.py). The `get_class` function reads the `classname` in the parameters to find the appropriate class or function, and the `get_arguments` function reads the other local parameters and their values so that they can be passed to it.

Note that in practice Hydra provides the same mechanisms, so the Hydra `instantiate` function could have been used instead.
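
The mechanism can be sketched in a few lines with `importlib`. The `toy_*` functions below are hypothetical re-implementations written for this illustration; they are not the SaLinA code, which may differ in its details.

```python
from importlib import import_module

def toy_get_class(arguments):
    """Hypothetical helper: resolve the dotted path stored under 'classname'."""
    module_path, name = arguments["classname"].rsplit(".", 1)
    return getattr(import_module(module_path), name)

def toy_get_arguments(arguments):
    """Hypothetical helper: every entry except 'classname' is a keyword argument."""
    return {k: v for k, v in arguments.items() if k != "classname"}

def toy_instantiate_class(arguments):
    """Hypothetical helper: call the resolved class with the remaining arguments."""
    return toy_get_class(arguments)(**toy_get_arguments(arguments))

# Any importable callable works; here a class from the standard library
cfg = {"classname": "fractions.Fraction", "numerator": 3, "denominator": 4}
obj = toy_instantiate_class(cfg)
print(obj)
```

With this pattern, switching to another environment or optimizer only requires changing the configuration dictionary, not the code.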

In [None]:
class EnvAgent(GymAgent):
  # Create the environment agent
  # This agent runs N gym environments. Note that GymAgent itself does not
  # auto-reset; inherit from AutoResetGymAgent to get auto-resetting environments
  def __init__(self, cfg):
    super().__init__(
      get_class(cfg.algorithm.env),
      get_arguments(cfg.algorithm.env),
      n_envs=cfg.algorithm.n_envs,
    )
    # Keep one instance of the environment to inspect its spaces
    self.env = instantiate_class(cfg.algorithm.env)

  # TODO: replace the code below by a unique context-sensitive function 
  # that returns self.action_space.shape[0] or self.action_space.n
  # depending on whether the action space is a Box or a Discrete space


  # This is necessary to create the corresponding RL agent
  def get_obs_and_actions_sizes(self):
    # Return the size of the observation and action spaces of the environment.
    # The spaces are read from self.env, the instance created in __init__
    observation_size = self.env.observation_space.shape[0]
    action_space = self.env.action_space
    if isinstance(action_space, gym.spaces.Box):
      # Continuous actions: the size is the dimension of the action vector
      return observation_size, action_space.shape[0]
    elif isinstance(action_space, gym.spaces.Discrete):
      # Discrete actions: the size is the number of actions
      return observation_size, action_space.n
    else:
      raise ValueError(f"unknown type of action space: {action_space}")

### Create the A2C agent

The code below is rather straightforward. Note that we have not defined anything about data collection, using a RolloutBuffer or something to store the n_step return so far. This will come inside the training loop below.

Interestingly, the loop between the policy and the environment is first defined as a collection of agents, and then embedded into a single TemporalAgent.

We delete the environment (not the environment agent) with `del env_agent.env` once we do not need it anymore just to avoid mistakes afterwards.

In [None]:
# Create the A2C Agent
def create_a2c_agent(cfg, env_agent):
  observation_size,  n_actions = env_agent.get_obs_and_actions_sizes()
  del env_agent.env
  prob_agent = ProbAgent(
      observation_size, cfg.algorithm.architecture.hidden_size, n_actions
  )
  action_agent = ActionAgent()
  critic_agent = CriticAgent(
    observation_size, cfg.algorithm.architecture.hidden_size, n_actions
  )

  # Combine env and policy agents
  agent = Agents(env_agent, prob_agent, action_agent)
  # Get an agent that is executed on a complete workspace
  agent = TemporalAgent(agent)
  agent.seed(cfg.algorithm.env_seed)
  return agent, prob_agent, critic_agent

### The Logger class

The logger class below is not generic, it is specifically designed in the context of this A2C colab.

The logger parameters are defined below in `params = { "logger":{ ...`

In this colab, the logger is defined as `salina.logger.TFLogger` so as to use a tensorboard visualisation (see the parameters part below).
Note that the salina Logger is also saving the log in a readable format such that you can use `Logger.read_directories(...)` to read multiple logs, create a dataframe, and analyze many experiments afterward in a notebook for instance. 

The code for the different kinds of loggers is available in the [main/salina/logger.py file](https://github.com/facebookresearch/salina/blob/main/salina/logger.py).

Having logging provided under the hood is one of the features where using RL libraries like SaLinA will allow you to save time.

`instantiate_class` is an inner SaLinA mechanism. The `instantiate_class` function is available in the [main/salina/__init__.py file](https://github.com/facebookresearch/salina/blob/main/salina/__init__.py).

In [None]:
class Logger():

  def __init__(self, cfg):
    self.logger = instantiate_class(cfg.logger)

  def add_log(self, log_string, loss, epoch):
    self.logger.add_scalar(log_string, loss.item(), epoch)

  # Log losses
  def log_losses(self, cfg, epoch, critic_loss, entropy_loss, a2c_loss):
    self.add_log("critic_loss", critic_loss, epoch)
    self.add_log("entropy_loss", entropy_loss, epoch)
    self.add_log("a2c_loss", a2c_loss, epoch)


### Setup the optimizers

We use a single optimizer to tune the parameters of the actor (in the prob_agent part) and the critic (in the critic_agent part). It would be possible to use two optimizers, each working separately on the parameters of one component agent, but it would be more complicated because the actor update depends on the temporal difference error computed from the critic.

In [None]:
# Configure the optimizer over the a2c agent
def setup_optimizers(cfg, prob_agent, critic_agent):
  optimizer_args = get_arguments(cfg.algorithm.optimizer)
  parameters = nn.Sequential(prob_agent, critic_agent).parameters()
  optimizer = get_class(cfg.algorithm.optimizer)(parameters, **optimizer_args)
  return optimizer

### Execute agent

This is the tricky part with SaLinA, the one we need to understand in detail. The difficulty lies in the copy of the last step and the way to deal with the n_steps return.

The call to `agent(workspace, t=1, n_steps=cfg.algorithm.n_timesteps - 1, stochastic=True)` makes the agent run a number of steps in the workspace. In practice, it calls [this function](https://github.com/facebookresearch/salina/blob/47bea8b980ca3ce2461ada82a94c2e4cc59f125d/salina/agent.py#L58) which makes a forward pass of the agent network using the workspace data and updates the workspace accordingly.

At the first epoch (`epoch=0`), we start from the first step (`t=0`). But in the subsequent epochs (`epoch>0`), there is a risk of missing the transition at the border between the previous epoch and the current one. To avoid this, we copy the last step of the previous epoch into the first time index of the workspace with `workspace.copy_n_last_steps(1)` and shift the time indexes, hence the `t=1` and the `cfg.algorithm.n_timesteps - 1`.

In [None]:
def execute_agent(cfg, epoch, workspace, agent):
  if epoch > 0:
    workspace.zero_grad()
    workspace.copy_n_last_steps(1)
    agent(
      workspace, t=1, n_steps=cfg.algorithm.n_timesteps - 1, stochastic=True
    )
  else:
    agent(workspace, t=0, n_steps=cfg.algorithm.n_timesteps, stochastic=True)
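
The effect of the time shift can be simulated without SaLinA. In the sketch below, a plain list stands in for the time axis of the workspace and each environment step is represented by an integer id; `run_epoch` is a hypothetical stand-in for `execute_agent`, written only for this illustration.

```python
def run_epoch(buffer, epoch, n_timesteps, next_step):
    """Hypothetical stand-in for execute_agent on a list-based 'workspace'.

    buffer holds one global step id per time index; next_step is the id of
    the next environment step to generate.
    """
    if epoch > 0:
        buffer[0] = buffer[-1]  # like workspace.copy_n_last_steps(1)
        start = 1               # t=1, n_steps=n_timesteps - 1
    else:
        start = 0               # t=0, n_steps=n_timesteps
    for t in range(start, n_timesteps):
        buffer[t] = next_step
        next_step += 1
    return next_step

n_timesteps = 4
buffer = [None] * n_timesteps
next_step, transitions = 0, []
for epoch in range(3):
    next_step = run_epoch(buffer, epoch, n_timesteps, next_step)
    # Transitions available to the losses in this epoch: (t, t+1) pairs
    transitions += list(zip(buffer[:-1], buffer[1:]))

print(transitions)
```

Every pair of consecutive steps appears exactly once across epochs, including the pairs that straddle two epochs. Without the copy, the transition between the last step of one epoch and the first step of the next would never be seen by the losses.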

### Compute critic loss

Note the `critic[1:].detach()` in the computation of the temporal difference target. The idea is that we compute this target as a function of $V(s_{t+1})$, but we do not want to apply gradient descent on this $V(s_{t+1})$, we will only apply gradient descent to the $V(s_t)$ according to this target value.

In practice, `x.detach()` returns a tensor that is detached from the computation graph, so no gradient is back-propagated through it.

Note also the trick to deal with terminal states. If the state is terminal, $V(s_{t+1})$ does not make sense, so we need to ignore this term. To do so, we multiply it by (1 - done): if done is False (=0), we keep the term; if done is True (=1), we are at a terminal state and (1 - done) = 0, so the term vanishes. This trick is used in many RL libraries, e.g. SB3.

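Note also the `.float()` conversion: `done` is a boolean tensor, and PyTorch does not support subtraction on boolean tensors, so `1 - done[1:]` would raise an error. An integer conversion would also work here; converting to float simply matches the type of the critic values it multiplies.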

In [None]:
def compute_critic_loss(cfg, reward, done, critic):
  # Compute temporal difference
  target = reward[1:] + cfg.algorithm.discount_factor * critic[1:].detach() * (1 - done[1:].float())
  td = target - critic[:-1]

  # Compute critic loss
  td_error = td ** 2
  critic_loss = td_error.mean()
  return critic_loss, td
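
The arithmetic can be followed by hand on one environment and three time indices. The numbers below are made up for the illustration, and plain Python lists replace the T x B tensors of the colab.

```python
gamma = 0.95
# One environment, three consecutive time indices (made-up numbers)
reward = [0.0, 1.0, 1.0]     # reward[t + 1] is received when reaching step t + 1
critic = [0.5, 0.4, 0.3]     # V(s_0), V(s_1), V(s_2)
done = [False, False, True]  # step 2 is terminal

td = []
for t in range(len(critic) - 1):
    not_done = 1.0 - float(done[t + 1])  # 0.0 when the next step is terminal
    target = reward[t + 1] + gamma * critic[t + 1] * not_done
    td.append(target - critic[t])

critic_loss = sum(d ** 2 for d in td) / len(td)
print(td, critic_loss)
```

For `t=0` the target bootstraps on $V(s_1)$, whereas for `t=1` the next step is terminal, so the target reduces to the reward alone.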

### Compute A2C loss

In [None]:
def compute_a2c_loss(action_probs, action, td):
  action_logp = _index(action_probs, action).log()
  a2c_loss = action_logp[:-1] * td.detach()
  return a2c_loss.mean()
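
Written out, the two functions above compute the following quantities. The symbols $\delta_t$ (the temporal difference `td`), $\gamma$ (the `discount_factor`), $\pi$ (the policy given by `action_probs`), and the indices $T$ (timesteps) and $B$ (environments) are notations introduced here for the illustration.

$$\delta_t = r_{t+1} + \gamma \, (1 - \text{done}_{t+1}) \, V(s_{t+1}) - V(s_t)$$

$$\text{a2c\_loss} = \frac{1}{(T-1)\,B} \sum_{t=0}^{T-2} \sum_{b=1}^{B} \log \pi(a_t^b \mid s_t^b) \, \delta_t^b$$

In the second expression $\delta_t$ is treated as a constant because of `td.detach()`, so this term only produces gradients for the actor. It enters the total loss with a negative sign, so minimizing the total loss increases the log-probability of actions with a positive temporal difference.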

## Main training loop

Note that everything about the shared workspace between all the agents is completely hidden under the hood. This results in a gain of productivity, at the expense of having to dig into the salina code if you want to understand the details, change the multiprocessing model, etc.

Note the `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()` lines. Several things need to be explained here.
- `optimizer.zero_grad()` is necessary to reset the gradients computed at the previous iteration, since PyTorch accumulates gradients by default.
- We sum all the losses, for both the critic and the actor, before applying back-propagation with `loss.backward()`. At first glance, summing these losses may look weird, as the actor and the critic receive different updates from different parts of the loss. This relies on a central property of tensor manipulation libraries like TensorFlow and PyTorch: each loss tensor comes with its own computation graph for back-propagating the gradient, so that when you back-propagate the summed loss, each parameter only receives the gradient of the terms it contributed to. These mechanisms are partly explained [here](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html).
- Since the optimizer has been set up to work with both the actor and critic parameters, `optimizer.step()` updates both agents, and PyTorch ensures that each receives its own part of the gradient.
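
This property can be verified on a toy example. In the sketch below, two small linear layers stand in for the actor and critic networks, and the TD error is a made-up quantity; the check compares the gradients obtained from each loss taken separately with those obtained after back-propagating their combination.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
actor = nn.Linear(4, 2)   # stand-in for the ProbAgent network
critic = nn.Linear(4, 1)  # stand-in for the CriticAgent network
obs = torch.randn(5, 4)

value = critic(obs).squeeze(-1)
td = 1.0 - value                          # toy TD error
critic_loss = (td ** 2).mean()            # depends on the critic only
logp = torch.log_softmax(actor(obs), dim=-1)[:, 0]
actor_loss = (logp * td.detach()).mean()  # detach: no gradient flows to the critic

# Gradients of each loss taken separately (these calls do not fill .grad)
g_actor = torch.autograd.grad(actor_loss, actor.weight, retain_graph=True)[0]
g_critic = torch.autograd.grad(critic_loss, critic.weight, retain_graph=True)[0]

# Gradients after back-propagating the combined loss, with the same signs as in the colab
(critic_loss - actor_loss).backward()
```

After the backward pass, `actor.weight.grad` equals `-g_actor` and `critic.weight.grad` equals `g_critic`: summing the losses does not mix the updates of the two networks.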

In [None]:
def run_a2c(cfg):
  # 1)  Build the  logger
  logger = Logger(cfg)
  
  # 2) Create the environment agent
  env_agent = EnvAgent(cfg)

  # 3) Create the A2C Agent
  a2c_agent, prob_agent, critic_agent = create_a2c_agent(cfg, env_agent)

  # 4) Create the temporal critic agent to compute critic values over the workspace
  tcritic_agent = TemporalAgent(critic_agent)

  # 5) Configure the workspace to the right dimension
  # Note that no parameter is needed to create the workspace. 
  # In the training loop, calling the agent() and critic_agent() 
  # will take the workspace as parameter
  workspace = salina.Workspace()

  # 6) Configure the optimizer over the a2c agent
  optimizer = setup_optimizers(cfg, prob_agent, critic_agent)
  
  # 7) Training loop
  epoch = 0
  for epoch in range(cfg.algorithm.max_epochs):
    # Execute the agent in the workspace
    execute_agent(cfg, epoch, workspace, a2c_agent)

    # Compute the critic value over the whole workspace
    tcritic_agent(workspace, n_steps=cfg.algorithm.n_timesteps)

    # Get relevant tensors (size are timestep x n_envs x ....)
    critic, done, action_probs, reward, action = workspace[
        "critic", "env/done", "action_probs", "env/reward", "action"
      ]

    # Compute critic loss
    critic_loss, td = compute_critic_loss(cfg, reward, done, critic)

    # Compute entropy loss
    entropy_loss = torch.distributions.Categorical(action_probs).entropy().mean()

    # Compute A2C loss
    a2c_loss = compute_a2c_loss(action_probs, action, td)

    # Store the losses for tensorboard display
    logger.log_losses(cfg, epoch, critic_loss, entropy_loss, a2c_loss)

    # Compute the total loss
    loss = (
      -cfg.algorithm.entropy_coef * entropy_loss
      + cfg.algorithm.critic_coef * critic_loss
      - cfg.algorithm.a2c_coef * a2c_loss
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Compute the cumulated reward on final_state
    creward = workspace["env/cumulated_reward"]
    creward = creward[done]
    if creward.size()[0] > 0:
      logger.add_log("reward", creward.mean(), epoch)



## Definition of the parameters

The logger is defined as `salina.logger.TFLogger` so as to use a tensorboard visualisation.

In [None]:
params={
  "logger":{
    "classname": "salina.logger.TFLogger",
    "log_dir": "./tmp",
    "cache_size": 10000,
    "every_n_seconds": 10,
    "verbose": False,    
    },

  "algorithm":{
    "env_seed": 432,
    "n_envs": 8,
    "n_timesteps": 16,
    "max_epochs": 10000,
    "discount_factor": 0.95,
    "entropy_coef": 0.001,
    "critic_coef": 1.0,
    "a2c_coef": 0.1,
    "architecture":{"hidden_size": 32},
    "env":{
      "classname": "__main__.make_env",
      "env_name": "CartPole-v1",
      "max_episode_steps": 100,
    },
    "optimizer":
    {
      "classname": "torch.optim.Adam",
      "lr": 0.01,
    }
  }
}

### Launching tensorboard to visualize the results

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir ./tmp
from omegaconf import DictConfig, OmegaConf
config=OmegaConf.create(params)
run_a2c(config)
