## Outlook

In this colab we document in detail a version of the A2C algorithm implemented with SaLinA, so as to better understand its inner mechanisms.

### Installation

The SaLinA library is [here](https://github.com/facebookresearch/salina).

In [4]:
import functools
import time

%pip install gym==0.21.0
%pip install git+https://github.com/facebookresearch/salina.git@main
%pip install pygame




Note: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/facebookresearch/salina.git@main
  Cloning https://github.com/facebookresearch/salina.git (to revision main) to /tmp/pip-req-build-aa2nevyc
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/salina.git /tmp/pip-req-build-aa2nevyc
  Resolved https://github.com/facebookresearch/salina.git to commit ec9d114a5a0d79157430f59ac89a9171f5c67f35
  Preparing metadata (setup.py) ... done


### Install our package:
Our github repo is [here](https://github.com/Anidwyd/pandroide-svpg)

In [5]:
%pip install git+https://github.com/Anidwyd/pandroide-svpg.git@main

Collecting git+https://github.com/Anidwyd/pandroide-svpg.git@main
  Cloning https://github.com/Anidwyd/pandroide-svpg.git (to revision main) to /tmp/pip-req-build-mvgdx3fv
  Running command git clone --filter=blob:none --quiet https://github.com/Anidwyd/pandroide-svpg.git /tmp/pip-req-build-mvgdx3fv
  Resolved https://github.com/Anidwyd/pandroide-svpg.git to commit 181a857e687de70f958f90dee07f2035165b9d0f
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: svpg
  Building wheel for svpg (setup.py) ... done
  Created wheel for svpg: filename=svpg-1.0.dev0+181a857e687de70f958f90dee07f2035165b9d0f-py3-none-any.whl size=18945 sha256=94d82ffd01a73f00b60420c25f224fffd36af41f857bbfc2f8a21ee51adb7696
  Stored in directory: /tmp/pip-ephem-wheel-cache-oda9vdc6/wheels/51/50/c8/de3e463680b37a45484181746d89953c15bfa7a0c55976d26a
Successfully built svpg
Installing collected packages: svpg
  Attempting uninstall: svpg
    Found existing installatio

In [1]:
from svpg.algos.a2c.mono.main_loop import run_a2c

### Helper function

[This function](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/helpers/utils.py) is used later, when computing the A2C loss, in the following piece of code:

`# Compute A2C loss`

`action_logp = _index(action_probs, action).log()`

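It takes the TxBxA tensor of action probabilities (time steps x batch x actions) and a TxB tensor of action indices, and returns a TxB tensor containing, for each time step and each batch element, the probability of the action taken by the agent. Calling `.log()` on the result then gives the log probability of that action.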
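The selection described above can be sketched with `torch.gather`. This is only an illustration: `select_taken_action` is a hypothetical name, and the actual `_index` helper in the repository may be written differently.

```python
import torch


def select_taken_action(values: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """Return the T x B tensor whose entry (t, b) is values[t, b, action[t, b]]."""
    # gather needs an index with as many dimensions as the input,
    # so add a trailing axis of size 1 and drop it afterwards.
    return values.gather(dim=-1, index=action.unsqueeze(-1)).squeeze(-1)


T, B, A = 2, 3, 4
logits = torch.arange(T * B * A, dtype=torch.float32).reshape(T, B, A)
action_probs = torch.softmax(logits, dim=-1)   # T x B x A probabilities
action = torch.tensor([[0, 1, 2], [3, 0, 1]])  # T x B indices of the actions taken

action_logp = select_taken_action(action_probs, action).log()  # T x B
```

Each entry of `action_logp` is the log probability of the action stored at the same position in `action`.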

### Definition of agents

All agents are defined [here](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/algos/a2c/mono/agents.py)

### The Logger class

We use the logger provided by SaLinA, which relies on TensorBoard under the hood. The code is [here](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/helpers/logger.py)

### Setup the optimizers

We use a single optimizer to tune the parameters of both the actor (in the prob_agent part) and the critic (in the critic_agent part). It would be possible to have two optimizers, each working separately on the parameters of one component agent, but this would be more complicated because the actor's update depends on the critic's output (the temporal difference error).

The code is [here](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/algos/a2c/mono/optimizer.py)
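As a rough sketch of the single-optimizer setup, assuming two small stand-in networks (the `nn.Linear` modules below are placeholders for illustration, not the agents defined in the repository):

```python
import itertools

import torch
import torch.nn as nn

actor = nn.Linear(4, 2)   # placeholder for the network inside prob_agent
critic = nn.Linear(4, 1)  # placeholder for the network inside critic_agent

# A single optimizer over the union of both parameter sets.
parameters = itertools.chain(actor.parameters(), critic.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.01)

# Each Linear layer contributes a weight and a bias tensor.
n_tensors = sum(len(group["params"]) for group in optimizer.param_groups)
```

With this setup a single `optimizer.step()` updates every tensor that received a gradient, whichever network it belongs to.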

### Compute critic loss and a2c loss

Note the `critic[1:].detach()` in the computation of the temporal difference target. The target is computed as a function of $V(s_{t+1})$, but we do not want gradient descent to modify $V(s_{t+1})$: only $V(s_t)$ should be moved towards this target value.

In practice, `x.detach()` returns a tensor that shares the values of `x` but is cut off from the computation graph, so no gradient flows back through it.

Note also the trick to deal with terminal states. If the state is terminal, $V(s_{t+1})$ does not make sense, so this term must be ignored. To do so, we multiply it by `(1 - done)`: if `done` is False (= 0), the term is kept; if `done` is True (= 1), we are at a terminal state, `(1 - done)` = 0, and the term vanishes. This trick is used in many RL libraries, e.g. SB3.

The code is [here](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/algos/a2c/mono/loss.py)
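The two points above (detaching the target and masking terminal states) can be illustrated with the sketch below. It is not the repository's `loss.py`: the function name `td_critic_loss` is hypothetical, and the indexing convention (reward and `done` flag for the transition into step t+1 stored at index t+1, all tensors of shape TxB) is an assumption.

```python
import torch


def td_critic_loss(reward, done, critic, discount_factor):
    """Return the TD target and the critic loss for T x B tensors."""
    # V(s_{t+1}) is detached: it acts as a constant inside the target.
    # (1 - done) removes the bootstrap term when s_{t+1} is terminal.
    target = reward[1:] + discount_factor * critic[1:].detach() * (1 - done[1:].float())
    td = target - critic[:-1]       # gradient flows only through V(s_t)
    critic_loss = (td ** 2).mean()
    return target, critic_loss


T, B = 3, 2
critic = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
reward = torch.ones(T, B)
done = torch.tensor([[False, False], [False, True], [False, False]])

target, critic_loss = td_critic_loss(reward, done, critic, discount_factor=0.95)
critic_loss.backward()
```

In this toy run the last row of `critic` only appears inside the detached target, so it receives a zero gradient, and the target of the transition that ends an episode reduces to the reward alone.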

## Main training loop

Note that everything about the workspace shared between the agents is completely hidden under the hood. This results in a gain of productivity, at the expense of having to dig into the SaLinA code if you want to understand the details, change the multiprocessing model, etc.

Note the `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()` lines. Several things need to be explained here.
- `optimizer.zero_grad()` is necessary to reset the gradients computed at the previous iterations, which would otherwise accumulate.
- We sum all the losses, both for the critic and the actor, before applying back-propagation with `loss.backward()`. At first glance, summing these losses may look weird, as the actor and the critic receive different updates from different parts of the loss. This works because of a central property of tensor manipulation libraries like TensorFlow and PyTorch: each loss tensor comes with its own computation graph, so when the summed loss is back-propagated, each parameter only receives the gradient of the loss terms that depend on it. These mechanisms are partly explained [here](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html).
- Since the optimizer has been set to work with both the actor and critic parameters, `optimizer.step()` will optimize both agents, and PyTorch ensures that each receives its own part of the gradient.
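The gradient routing described in the list above can be checked on a toy example. The networks and losses below are placeholders chosen for illustration (each loss depends on one network only); they are not the A2C losses of the repository.

```python
import itertools

import torch
import torch.nn as nn

torch.manual_seed(0)
actor = nn.Linear(4, 2)   # placeholder actor network
critic = nn.Linear(4, 1)  # placeholder critic network
optimizer = torch.optim.Adam(
    itertools.chain(actor.parameters(), critic.parameters()), lr=0.01
)

obs = torch.randn(8, 4)
critic_loss = critic(obs).pow(2).mean()                          # depends on the critic only
actor_loss = torch.log_softmax(actor(obs), dim=-1)[:, 0].mean()  # depends on the actor only

# Reference gradients, computed from each loss taken separately.
critic_grad_alone = torch.autograd.grad(critic_loss, critic.weight, retain_graph=True)[0]
actor_grad_alone = torch.autograd.grad(actor_loss, actor.weight, retain_graph=True)[0]

optimizer.zero_grad()             # reset gradients left by a previous iteration
loss = critic_loss + actor_loss   # sum the losses
loss.backward()                   # one backward pass through both graphs
optimizer.step()                  # one update for both networks
```

After `loss.backward()`, `critic.weight.grad` matches `critic_grad_alone` and `actor.weight.grad` matches `actor_grad_alone`: summing the losses did not mix the gradients.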

The code for the main training loop is [here](https://github.com/Anidwyd/pandroide-svpg/blob/main/svpg/algos/a2c/mono/run_a2c.py)

## Definition of the parameters

The logger is defined as `salina.logger.TFLogger` so as to use a tensorboard visualisation.

In [2]:
params={
  "logger":{
    "classname": "salina.logger.TFLogger",
    "log_dir": "./tmp",
    "cache_size": 10000,
    "every_n_seconds": 10,
    "verbose": False,    
    },

  "algorithm":{
    "env_seed": 432,
    "n_envs": 8,
    "n_timesteps": 16,
    "max_epochs": 10000,
    "discount_factor": 0.95,
    "entropy_coef": 0.001,
    "critic_coef": 1.0,
    "a2c_coef": 0.1,
    "architecture":{"hidden_size": 32},
    "env":{
      "classname": "svpg.algos.a2c.mono.agents.make_env",
      "env_name": "CartPole-v1",
      "max_episode_steps": 500,
    },
    "optimizer":
    {
      "classname": "torch.optim.Adam",
      "lr": 0.01,
    }
  }
}
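The `classname` entries in this dictionary suggest that each component is built by importing a dotted path and passing the remaining keys as keyword arguments. The sketch below illustrates that pattern with the standard library only. `build_from_config` is a hypothetical helper written for this example, not SaLinA's own function, and `fractions.Fraction` merely stands in for a real component class.

```python
import importlib


def build_from_config(cfg: dict, **extra):
    """Hypothetical helper: import cfg["classname"] and call it with the other keys."""
    kwargs = {key: value for key, value in cfg.items() if key != "classname"}
    module_name, _, attr_name = cfg["classname"].rpartition(".")
    factory = getattr(importlib.import_module(module_name), attr_name)
    return factory(**kwargs, **extra)


# Stand-in for an entry such as the "optimizer" block above.
example_cfg = {"classname": "fractions.Fraction", "numerator": 3, "denominator": 4}
example = build_from_config(example_cfg)
```

The `extra` keyword arguments cover values that cannot live in the configuration, such as the network parameters that an optimizer needs in addition to `lr`.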

### Launching tensorboard to visualize the results

In [3]:
%load_ext tensorboard
%tensorboard --logdir ./tmp
from omegaconf import OmegaConf
config = OmegaConf.create(params)
run_a2c(config)

Reusing TensorBoard on port 6006 (pid 15024), started 0:01:21 ago. (Use '!kill 15024' to kill it.)