# JAX Mava Quickstart Notebook
<img src="https://raw.githubusercontent.com/instadeepai/Mava/develop/docs/images/mava.png" />

### This notebook provides an easy introduction to the [Mava](https://github.com/instadeepai/Mava) framework by showing how to construct a multi-agent system and train it from scratch in a simple environment.

<a href="https://colab.research.google.com/github/instadeepai/Mava/blob/develop/examples/quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### 1. Installing packages

We start by installing the necessary packages.

In [None]:
%%capture
#@title Install required packages. (Run Cell)
# Fix common bug in installing 'box-2d'
!apt-get install swig

# Get Mava
%pip install "id-mava[reverb,jax,envs] @ git+https://github.com/instadeepai/mava.git"

# Visualisation dependencies
!apt-get update -y && apt-get install -y xvfb python-opengl ffmpeg && pip install pyvirtualdisplay

# Colab has an old version of cloudpickle
%pip install cloudpickle -U

# Agent visualisation dependencies (the apt packages are already installed above)
%pip install pyglet==1.3.2

At this point, it is necessary to restart the runtime. On the top menu, click `Runtime -> Restart runtime`. Alternatively, use the shortcut `⌘/Ctrl + M .`

### 2. Import the necessary modules

We first import the necessary modules that will be used to construct the multi-agent system and visualise the agents.

In [None]:
#@title Import required packages. (Run Cell)
import functools
from datetime import datetime
from typing import Any

import optax
from absl import app, flags

from mava.systems import ippo
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils
from mava.components.building.environments import MonitorExecutorEnvironmentLoop

# Imports for agent visualisation
import os
from IPython.display import HTML
from pyvirtualdisplay import Display
import pyglet

display = Display(visible=0, size=(1024, 768))
display.start()
config = pyglet.gl.Config(double_buffer=True)
os.environ["DISPLAY"] = ":" + str(display.display)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]= "false"
os.environ['PYGLET_SHADOW_WINDOW'] = '0'


### 3. Train an `IPPO` system

For this example, we will train IPPO on the simple-spread environment. This environment has 3 agents that all need to move to specific locations without bumping into each other. We start by defining the network architecture and network optimisers for our agents. The default network file for IPPO can be found [here](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/ippo/networks.py).

In [None]:
# Create the network factory.
def network_factory(*args: Any, **kwargs: Any) -> Any:
    return ippo.make_default_networks(  # type: ignore
        policy_layer_sizes=(64, 64),
        critic_layer_sizes=(64, 64, 64),
        *args,
        **kwargs,
    )
  
# Optimisers.
policy_optimiser = optax.chain(
    optax.clip_by_global_norm(40.0), optax.scale_by_adam(), optax.scale(-1e-4)
)

critic_optimiser = optax.chain(
    optax.clip_by_global_norm(40.0), optax.scale_by_adam(), optax.scale(-1e-4)
)
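
To make the first link of each chain concrete, the sketch below reimplements global-norm clipping in plain Python: when the L2 norm taken over all gradient entries exceeds the threshold, every entry is scaled by the same factor so that the joint norm equals the threshold. The function name `clip_by_global_norm_sketch` and the nested-list gradient layout are assumptions made for illustration, and the Adam rescaling step in the middle of the chain is left out. The final `optax.scale(-1e-4)` is negative because Optax updates are added to the parameters.

```python
import math
from typing import List


def clip_by_global_norm_sketch(
    grads: List[List[float]], max_norm: float
) -> List[List[float]]:
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for layer in grads for g in layer))
    if global_norm <= max_norm:
        # Already within the limit: return an unchanged copy.
        return [list(layer) for layer in grads]
    scale = max_norm / global_norm
    return [[g * scale for g in layer] for layer in grads]


# Two hypothetical "layers" of gradients with joint norm sqrt(30^2 + 40^2) = 50.
example_grads = [[30.0], [40.0]]
clipped = clip_by_global_norm_sketch(example_grads, max_norm=40.0)

# Final link of the chain: multiply by a negative learning rate so that
# adding the update to the parameters performs gradient descent.
learning_rate = 1e-4
updates = [[-learning_rate * g for g in layer] for layer in clipped]
```

Clipping on the global norm keeps the direction of the full gradient intact, which per-entry clipping would not.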

We now select the environment we want to train on. We will use the [simple spread](https://github.com/instadeepai/Mava#debugging) environment.

In [None]:
env_name = "simple_spread"
action_space = "discrete"

environment_factory = functools.partial(
    debugging_utils.make_environment,
    env_name=env_name,
    action_space=action_space,
    )
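
The reason for passing a factory rather than an environment is that every process in the system can then build its own copy on demand. The sketch below shows that pattern with `functools.partial`; `make_environment_stub` and its `evaluation` parameter are hypothetical stand-ins, not the signature of `debugging_utils.make_environment`.

```python
import functools
from typing import Any, Dict


def make_environment_stub(
    env_name: str, action_space: str, evaluation: bool = False
) -> Dict[str, Any]:
    """Hypothetical stand-in that returns a description instead of a real environment."""
    return {
        "env_name": env_name,
        "action_space": action_space,
        "evaluation": evaluation,
    }


# Fix the arguments we already know. Nothing is constructed at this point.
environment_factory_stub = functools.partial(
    make_environment_stub,
    env_name="simple_spread",
    action_space="discrete",
)

# Each caller later builds its own instance and supplies any remaining arguments.
train_env = environment_factory_stub()
eval_env = environment_factory_stub(evaluation=True)
```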

Next, we specify the logging and checkpointing configuration for our system. 

In [None]:
# Directory to store checkpoints and log data. 
base_dir = "~/mava"

# File name 
mava_id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Log every [log_every] seconds
log_every = 15
logger_factory = functools.partial(
    logger_utils.make_logger,
    directory=base_dir,
    to_terminal=True,
    to_tensorboard=True,
    time_stamp=mava_id,
    time_delta=log_every,
)

# Checkpointer appends "Checkpoints" to experiment_path
experiment_path = f"{base_dir}/{mava_id}"
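
Note that `~` is a shell convention: Python file functions such as `glob.glob` and `open` do not expand it on their own. Whether Mava expands it internally is not shown in this notebook, so the sketch below expands the path explicitly before using it from Python. The `example_`-prefixed names are hypothetical and chosen so that they do not overwrite the variables defined above.

```python
import os
from datetime import datetime

example_base_dir = "~/mava"
example_id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
example_experiment_path = f"{example_base_dir}/{example_id}"

# Expand "~" to the current user's home directory.
expanded_path = os.path.expanduser(example_experiment_path)

# A glob pattern for recordings stored under the experiment directory.
recordings_pattern = os.path.join(expanded_path, "recordings", "*.html")
```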

Finally, we construct our multi-agent PPO system. A system in Mava represents the complete multi-agent setup and when it is launched, it runs executors (experience generators), a trainer (the network updater), a data server (buffer for experience) and a parameter server (to store network parameters).

In [None]:
# Create the system.
system = ippo.IPPOSystem()

# Add the gameplay monitor component
system.update(MonitorExecutorEnvironmentLoop)

# Build the system.
system.build(
    environment_factory=environment_factory,
    network_factory=network_factory,
    logger_factory=logger_factory,
    experiment_path=experiment_path,
    policy_optimiser=policy_optimiser,
    critic_optimiser=critic_optimiser,
    run_evaluator=True,
    epoch_batch_size=5,
    num_epochs=15,
    num_executors=2,
    multi_process=True,
    record_every=1,
)

In this example, we use IPPO which is already implemented inside Mava. However, we add the `MonitorExecutorEnvironmentLoop` component to record episodes for later inspection.  We use `system.update` because we already have an environment loop component inside Mava and we are simply updating it. If we want to add an entirely new component to the system, we use `system.add`. 

Mava is designed with flexibility in mind and allows users to easily build on top of existing systems. A system is just a collection of components, which can be thought of as self-contained pieces of code (building blocks) that add functionality to a system. The IPPO system's components can be found [here](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/ippo/system.py).
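
One way to picture the difference between `system.update` and `system.add` is a registry of components keyed by name, where `update` replaces an existing entry and `add` creates a new one. The toy model below is a sketch under that assumption and is not Mava's implementation; `ToySystem`, `DefaultLoop`, `MonitorLoop` and `RewardLogger` are hypothetical names.

```python
from typing import Dict, Type


class Component:
    """Minimal base class: each component reports the slot it occupies."""

    @staticmethod
    def name() -> str:
        raise NotImplementedError


class DefaultLoop(Component):
    @staticmethod
    def name() -> str:
        return "environment_loop"


class MonitorLoop(Component):
    @staticmethod
    def name() -> str:
        return "environment_loop"  # same slot as DefaultLoop


class RewardLogger(Component):
    @staticmethod
    def name() -> str:
        return "reward_logger"  # a slot the toy system does not have yet


class ToySystem:
    def __init__(self) -> None:
        self.components: Dict[str, Type[Component]] = {
            DefaultLoop.name(): DefaultLoop
        }

    def update(self, component: Type[Component]) -> None:
        """Replace the component that currently occupies this name."""
        if component.name() not in self.components:
            raise ValueError(f"Nothing to update for '{component.name()}'.")
        self.components[component.name()] = component

    def add(self, component: Type[Component]) -> None:
        """Register a component under a name that is not taken yet."""
        if component.name() in self.components:
            raise ValueError(f"'{component.name()}' already exists.")
        self.components[component.name()] = component


toy = ToySystem()
toy.update(MonitorLoop)  # swaps out DefaultLoop
toy.add(RewardLogger)    # creates a new entry
```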

The arguments we provide to our system build overwrite the default config values from the existing components in the system:

`environment_factory` - used to construct the environment.

`network_factory` - used to construct the agent networks.

`logger_factory` - used to construct the loggers.

`experiment_path` - the destination to save the experiment results.

`policy_optimiser` - the optimiser used by the trainer to update the policy weights.

`critic_optimiser` - the optimiser used by the trainer to update the critic weights.

`run_evaluator` - a flag indicating whether a separate environment process should be run that tracks the system's performance using Tensorboard and possibly gameplay recordings.

`epoch_batch_size` - the batch size to use in the trainer when updating the agent networks.

`num_epochs` - the number of epochs to train on sampled data before discarding it.

`num_executors` - the number of experience generators (workers) to run in parallel.

`multi_process` - determines whether the code is run using multiple processes or a single process. We use the multi-process setup for faster training.

`record_every` - determines how often the evaluator should record a gameplay video. 
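
The idea of build arguments overwriting default config values can be sketched with a dataclass of defaults and a helper that replaces selected fields. This is only an illustration of the pattern: `ToyTrainerConfig`, its default values and `apply_overrides` are hypothetical and do not reflect Mava's actual defaults or internals.

```python
from dataclasses import dataclass, fields, replace
from typing import Any


@dataclass(frozen=True)
class ToyTrainerConfig:
    # Hypothetical defaults, chosen only for this example.
    epoch_batch_size: int = 32
    num_epochs: int = 5


def apply_overrides(config: ToyTrainerConfig, **overrides: Any) -> ToyTrainerConfig:
    """Return a copy of config with the given fields replaced."""
    known = {f.name for f in fields(config)}
    unknown = set(overrides) - known
    if unknown:
        # Reject misspelt keys instead of silently ignoring them.
        raise TypeError(f"Unknown config keys: {sorted(unknown)}")
    return replace(config, **overrides)


defaults = ToyTrainerConfig()
toy_config = apply_overrides(defaults, epoch_batch_size=5, num_epochs=15)
```

Rejecting unknown keys is a design choice in this sketch: a typo in an argument name then fails loudly rather than leaving a default in place.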


We now run the system.

In [None]:
# Launch the system.
system.launch()

### 4. Visualise our training results using TensorBoard


Load the TensorBoard extension.

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

To view training results, start TensorBoard and point it to the Mava `experiment_path` (where logs are saved). You might have to wait a few seconds and then refresh TensorBoard to see the logs.

A good score is an `evaluator/RawEpisodeReturn` between 30 and 40 after about 5 minutes of training (remember to refresh TensorBoard). Although this system is stochastic, it should reach that score after a few minutes.

In [None]:
%tensorboard --logdir ~/mava/$mava_id

### 5. View agent recordings
Once a good score is reached, you can view the learned multi-agent behaviour by viewing the agent recordings.

Check if any agent recordings are available. 

In [None]:
! ls ~/mava/$mava_id/recordings

View the latest agent recording. 

In [None]:
import glob
import os 
import IPython

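# Recordings. glob does not expand "~", so expand the path explicitly
# instead of hard-coding the home directory.
recordings_dir = os.path.expanduser(f"~/mava/{mava_id}/recordings")
list_of_files = glob.glob(os.path.join(recordings_dir, "*.html"))

if not list_of_files:
    print("No recordings are available yet. Please wait, or run the `system.launch()` cell if you haven't already done so.")
else:
    # Pick the most recently created recording.
    latest_file = max(list_of_files, key=os.path.getctime)
    print("Run the next cell to visualise your agents!")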

If the agents are trained (*usually in about 5 minutes...*), they should move to the assigned landmarks.

<img src="https://raw.githubusercontent.com/instadeepai/Mava/develop/docs/images/simple_spread.png" width="250" height="250" />

In [None]:
# Latest file needs to point to the latest recording
IPython.display.HTML(filename=latest_file)

That's it! You have successfully trained a multi-agent system in Mava.

## What's next?

Now that you have an appreciation for the flexibility in Mava's design, let us build our own custom component and add it to the existing IPPO system. Each component places self-contained pieces of code at specific points in a system's run, referred to as "hooks", and each hook function name starts with the "on" prefix. For example, when the system is building, it calls [these](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/builder.py) hooks defined in the builder, and when it is executing (generating training experience), it can call the hooks defined in the [executor](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/executor.py), [trainer](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/trainer.py) or [parameter_server](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/parameter_server.py).

A component must "subscribe" to the relevant hooks, depending on when its code needs to be executed: for example, after a training step (`on_training_step_end()`) or when selecting actions for each agent in order to generate training experience (`on_execution_select_actions()`).

To illustrate this point, we will update the [advantage estimation component](https://github.com/instadeepai/Mava/blob/develop/mava/components/jax/training/advantage_estimation.py) with a simpler advantage estimate and override the `on_training_utility_fns()` hook defined in the trainer. This component defines a function called `gae_fn`, which is called by the system to calculate an advantage estimate when training. Notice below that we store this new advantage function inside `trainer.store.gae_fn`. The store is a place where components can save variables for other components to access.
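
The hook and store mechanism can be pictured with a small toy model: a trainer calls a named hook on every component, and components communicate by writing to and reading from a shared store. The classes below (`ToyComponent`, `MeanBaseline`, `StepCounter`, `ToyTrainer`) are hypothetical and only mimic the pattern; they are not Mava's classes, and the mean-baseline advantage is just a placeholder function.

```python
from types import SimpleNamespace
from typing import List


class ToyComponent:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_training_utility_fns(self, trainer: "ToyTrainer") -> None:
        pass

    def on_training_step_end(self, trainer: "ToyTrainer") -> None:
        pass


class MeanBaseline(ToyComponent):
    def on_training_utility_fns(self, trainer: "ToyTrainer") -> None:
        def advantage_fn(rewards: List[float]) -> List[float]:
            mean = sum(rewards) / len(rewards)
            return [r - mean for r in rewards]

        # Save the function in the store so the trainer can use it later.
        trainer.store.advantage_fn = advantage_fn


class StepCounter(ToyComponent):
    def on_training_step_end(self, trainer: "ToyTrainer") -> None:
        trainer.store.steps = getattr(trainer.store, "steps", 0) + 1


class ToyTrainer:
    def __init__(self, components: List[ToyComponent]) -> None:
        self.components = components
        self.store = SimpleNamespace()
        self._call_hook("on_training_utility_fns")

    def _call_hook(self, hook_name: str) -> None:
        # Call the same hook on every component, in registration order.
        for component in self.components:
            getattr(component, hook_name)(self)

    def step(self, rewards: List[float]) -> List[float]:
        advantages = self.store.advantage_fn(rewards)
        self._call_hook("on_training_step_end")
        return advantages


toy_trainer = ToyTrainer([MeanBaseline(), StepCounter()])
toy_advantages = toy_trainer.step([1.0, 2.0, 3.0])
```

Swapping `MeanBaseline` for another component that writes a different `advantage_fn` changes the trainer's behaviour without touching the trainer itself, which is the property the custom component in this section relies on.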

In [None]:
%%capture
#@title Kill old runs. (Run Cell)
!ps aux  |  grep -i launchpad  |  awk '{print $2}'  |  xargs sudo kill -9

We start by defining our new component below.

In [None]:
"""Trainer components for advantage calculations."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
from mava.components.training.advantage_estimation import GAE

import jax
import jax.numpy as jnp
import numpy as np
import rlax

from mava.callbacks import Callback
from mava.components.training.base import Utility
from mava.core_jax import SystemTrainer

@dataclass
class AConfig:
    a_lambda: float = 0.95


class simpler_advantage_estimate(GAE):
    def __init__(
        self,
        config: AConfig = AConfig(),
    ):
        self.config = config

    def on_training_utility_fns(self, trainer: SystemTrainer) -> None:
        def simpler_advantage_estimate(
            rewards: jnp.ndarray, discounts: jnp.ndarray, values: jnp.ndarray
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:


            # Instead of GAE, use a simpler estimate: the discounted sum of
            # future rewards, with a_lambda acting as the discount factor.
            # Note that the `discounts` argument is not used here.

            # Pad the rewards with zeros so that the shifted slices below
            # stay in range for time steps near the end of the sequence.
            zeros_mask = jnp.zeros(shape=rewards.shape)
            padded_rewards = jnp.concatenate([rewards, zeros_mask], axis=0)
            cum_rewards = rewards.copy()
            seq_len = len(rewards)
            for i in range(1, seq_len):
                cum_rewards += padded_rewards[i : i + seq_len] * jnp.power(
                    self.config.a_lambda, i
                )

            # Calculate the advantage estimate.
            advantages = cum_rewards[:-1] - values[:-1]
            
            # Stop gradients from flowing through the advantage estimate.
            advantages = jax.lax.stop_gradient(advantages)

            # Set the target values and stop gradients from flowing backwards
            # through the target values.
            target_values = cum_rewards[:-1]
            target_values = jax.lax.stop_gradient(target_values)

            return advantages, target_values

        trainer.store.gae_fn = simpler_advantage_estimate

    @staticmethod
    def name() -> str:
        return "gae_fn"
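
To see what the loop inside the component computes, the plain-Python sketch below evaluates the same quantity, the discounted sum of future rewards G_t = r_t + λ·r_{t+1} + λ²·r_{t+2} + ..., in two ways: a direct translation of the shifted-slice loop, and a single backward pass using G_t = r_t + λ·G_{t+1}. The function names and the toy numbers are assumptions for illustration, and the sketch works on Python lists rather than JAX arrays.

```python
from typing import List


def discounted_returns_forward(rewards: List[float], discount: float) -> List[float]:
    """Translation of the component's loop: add shifted, discounted rewards."""
    seq_len = len(rewards)
    padded = list(rewards) + [0.0] * seq_len
    returns = list(rewards)
    for i in range(1, seq_len):
        for t in range(seq_len):
            returns[t] += padded[t + i] * discount**i
    return returns


def discounted_returns_backward(rewards: List[float], discount: float) -> List[float]:
    """Same quantity in one pass, using G_t = r_t + discount * G_{t+1}."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + discount * running
        returns[t] = running
    return returns


toy_rewards = [1.0, 1.0, 1.0]
toy_values = [0.5, 0.5, 0.5]
toy_returns = discounted_returns_backward(toy_rewards, discount=0.5)

# As in the component, the final time step is dropped before subtracting values.
toy_step_advantages = [g - v for g, v in zip(toy_returns[:-1], toy_values[:-1])]
```

With these numbers the returns are 1.75, 1.5 and 1.0. The backward pass needs one sweep over the sequence, whereas the shifted-slice version does work proportional to the square of the sequence length.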

Now that we have this new component we can add it to the PPO system using `system.update`. You can retrain the system by executing the cell below and again run steps 4 and 5 (above) to view the training results and gameplay footage. Does this system perform better or worse than the previous system? 

Feel free to update the advantage component to see if you can achieve better training results.

In [None]:
# File name 
mava_id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Log every [log_every] seconds
log_every = 15
logger_factory = functools.partial(
    logger_utils.make_logger,
    directory=base_dir,
    to_terminal=True,
    to_tensorboard=True,
    time_stamp=mava_id,
    time_delta=log_every,
)

# Checkpointer appends "Checkpoints" to experiment_path
experiment_path = f"{base_dir}/{mava_id}"

# Create the system.
system = ippo.IPPOSystem()

# Add the gameplay monitor component
system.update(MonitorExecutorEnvironmentLoop)

# Update the system with our custom component.
system.update(simpler_advantage_estimate)

# Build the system.
system.build(
    environment_factory=environment_factory,
    network_factory=network_factory,
    logger_factory=logger_factory,
    experiment_path=experiment_path,
    policy_optimiser=policy_optimiser,
    critic_optimiser=critic_optimiser,
    run_evaluator=True,
    epoch_batch_size=5,
    num_epochs=15,
    num_executors=1,
    multi_process=True,
    record_every=10,
 )

# Launch the system.
system.launch()

Congratulations! You have created your own custom system. We hope that this tutorial has given you a taste of what Mava is capable of. We are excited to see how you use the repo. For more examples using different systems, environments and architectures, visit our [GitHub page](https://github.com/instadeepai/Mava/tree/develop/examples).