# JAX Mava Quickstart Notebook
<img src="https://raw.githubusercontent.com/instadeepai/Mava/develop/docs/images/mava.png" />

### This notebook provides an easy introduction to the [Mava](https://github.com/instadeepai/Mava) framework by showing how to construct a multi-agent system and train it from scratch in a simple environment.

<a href="https://colab.research.google.com/github/instadeepai/Mava/blob/develop/examples/jax/quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### 1. Installing packages

We start by installing the necessary packages.

In [None]:
%%capture
#@title Install required packages. (Run Cell)
! rm -rf ./Mava
! git clone https://github.com/instadeepai/Mava.git
!pip install ./Mava[reverb,jax,launchpad,envs]

# Installs for agent visualisation.
!pip install ./Mava[record_episode]
! apt-get update -y && apt-get install -y xvfb && apt-get install -y python-opengl && apt-get install -y ffmpeg && pip install pyvirtualdisplay
# Google Colab ships an old version of cloudpickle, which causes issues with Launchpad (lp).
!pip install cloudpickle -U

### 2. Import the necessary modules

Import the necessary modules that will be used to construct the multi-agent system and visualise the agents.

In [None]:
#@title Import required packages. (Run Cell)
import functools
from datetime import datetime
from typing import Any

import optax
from absl import app, flags

from mava.systems.jax import mappo
from mava.utils.environments import debugging_utils
from mava.utils.loggers import logger_utils
from mava.components.jax.building.environments import MonitorExecutorEnvironmentLoop

# Imports for agent visualisation
import os
from IPython.display import HTML
from pyvirtualdisplay import Display

display = Display(visible=0, size=(1024, 768))
display.start()
os.environ["DISPLAY"] = ":" + str(display.display)

### 3. Train a multi-agent `PPO` system

We start by defining the networks and network optimizer for our agents.

In [None]:
# Create the network factory.
def network_factory(*args: Any, **kwargs: Any) -> Any:
    return mappo.make_default_networks(  # type: ignore
        policy_layer_sizes=(254, 254, 254),
        critic_layer_sizes=(512, 512, 256),
        *args,
        **kwargs,
    )
  
# Optimizer.
optimizer = optax.chain(
    optax.clip_by_global_norm(40.0), optax.scale_by_adam(), optax.scale(-1e-4)
)
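
To make the optimizer chain above more concrete, here is a minimal pure-Python sketch of what its first and last stages do mathematically. `clip_by_global_norm(40.0)` rescales all gradients jointly so that their combined L2 norm does not exceed 40, and `scale(-1e-4)` multiplies the result by a negative step size, so that adding the update to the parameters performs gradient descent. The Adam stage is left out for brevity. The helper functions below are illustrative stand-ins that operate on plain lists; they are not the optax implementation.

```python
import math
from typing import List


def global_norm(grads: List[float]) -> float:
    """L2 norm taken over all gradient entries together."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale the gradients jointly if their global norm exceeds max_norm."""
    norm = global_norm(grads)
    if norm <= max_norm:
        return list(grads)
    return [g * max_norm / norm for g in grads]


def scale(updates: List[float], step_size: float) -> List[float]:
    """Multiply every update by a constant (negative for gradient descent)."""
    return [u * step_size for u in updates]


# Hypothetical flattened gradients whose global norm is 50.
grads = [30.0, 40.0]
clipped = clip_by_global_norm(grads, max_norm=40.0)  # global norm is now 40
updates = scale(clipped, step_size=-1e-4)  # small step against the gradient
print(clipped, updates)
```

Clipping keeps the direction of the gradient and only shrinks its length, which is why it is applied before the Adam and step-size stages.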

We now select the environment we want to train on. We will use the [simple spread](https://github.com/instadeepai/Mava#debugging) environment.

In [None]:
env_name = "simple_spread"
action_space = "discrete"

environment_factory = functools.partial(
    debugging_utils.make_environment,
    env_name=env_name,
    action_space=action_space,
    )
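
The factory above relies on `functools.partial`, which pre-binds some arguments and returns a new callable that can be invoked later with the remaining ones. The sketch below illustrates the pattern with a hypothetical stand-in function, `make_toy_environment`. It is not part of Mava, and its `evaluation` parameter is an assumption made purely for illustration.

```python
import functools
from typing import Dict


def make_toy_environment(
    env_name: str, action_space: str, evaluation: bool = False
) -> Dict[str, object]:
    """Hypothetical stand-in for an environment constructor."""
    return {
        "env_name": env_name,
        "action_space": action_space,
        "evaluation": evaluation,
    }


# Pre-bind the arguments that are known up front.
toy_factory = functools.partial(
    make_toy_environment,
    env_name="simple_spread",
    action_space="discrete",
)

# Whoever receives the factory can create environments later on,
# supplying only the arguments that were left open.
train_env = toy_factory()
eval_env = toy_factory(evaluation=True)
print(train_env)
print(eval_env)
```

Passing a factory rather than an environment instance lets each process that needs an environment create its own copy.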

Next, we specify the logging and checkpointing configuration for our system. 

In [None]:
# Directory to store checkpoints and log data. 
base_dir = "~/mava"

# File name 
mava_id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Log every [log_every] seconds
log_every = 15
logger_factory = functools.partial(
    logger_utils.make_logger,
    directory=base_dir,
    to_terminal=True,
    to_tensorboard=True,
    time_stamp=mava_id,
    time_delta=log_every,
)

# Checkpointer appends "Checkpoints" to experiment_path
experiment_path = f"{base_dir}/{mava_id}"
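
Note that Python does not expand the leading `~` in `base_dir` by itself; the f-string keeps it verbatim. If you want to inspect the experiment directory from Python rather than from a shell command, a sketch along the following lines may help. The variable names and the timestamp are made up for illustration.

```python
import os
from datetime import datetime

example_base_dir = "~/mava"

# A fixed, made-up timestamp in the same format as `mava_id`.
example_id = datetime(2022, 1, 31, 12, 0, 0).strftime("%Y-%m-%d_%H:%M:%S")

# The f-string keeps "~" as a literal character.
example_path = f"{example_base_dir}/{example_id}"

# expanduser replaces "~" with the current user's home directory.
absolute_path = os.path.expanduser(example_path)

print(example_path)
print(absolute_path)
```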

Finally, we construct our multi-agent PPO system.

In [None]:
# Create the system.
system = mappo.MAPPOSystem()

# Add the gameplay monitor component
system.update(MonitorExecutorEnvironmentLoop)

# Build the system.
system.build(
    environment_factory=environment_factory,
    network_factory=network_factory,
    logger_factory=logger_factory,
    experiment_path=experiment_path,
    optimizer=optimizer,
    run_evaluator=True,
    sample_batch_size=20,
    num_epochs=15,
    num_executors=1,
    multi_process=False,
    single_process_max_episodes=500,
    record_every=1,
 )

In this example we use MAPPO, which is already implemented inside Mava. However, we add the `MonitorExecutorEnvironmentLoop` component to record episodes for later inspection. Mava is designed with flexibility in mind and allows users to easily build on top of existing systems. A system is just a collection of components, which can be thought of as self-contained pieces of code (building blocks) that add functionality to a system. The MAPPO system's components can be found [here](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/mappo/system.py).

We now run our system for 500 episodes. The `system.launch` method initialises and runs an executor (experience generator), a trainer (network updater), a data server and a parameter server.

In [None]:
%%capture

# Launch the system.
system.launch()
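
As a rough mental model of the four roles launched above, the toy loop below lets an "executor" generate experience, a "data server" buffer it, a "trainer" turn it into a parameter update, and a "parameter server" hold the latest parameters. This is a conceptual sketch only: every name in it is hypothetical, it runs in a single process, and it does not reflect how Mava implements or schedules these parts.

```python
import random
from collections import deque
from typing import Deque, Dict

random.seed(0)

parameter_server: Dict[str, float] = {"weight": 0.0}  # holds the latest parameters
data_server: Deque[float] = deque()  # buffers experience until it is consumed


def executor_step(params: Dict[str, float]) -> float:
    """Generate one noisy 'experience': the error between a target and the weight."""
    target = 1.0
    return target - params["weight"] + random.uniform(-0.01, 0.01)


def trainer_step(error: float, params: Dict[str, float]) -> Dict[str, float]:
    """Use one piece of experience to move the weight towards the target."""
    learning_rate = 0.5
    return {"weight": params["weight"] + learning_rate * error}


for _ in range(20):
    # Executor: read the current parameters and add experience to the data server.
    data_server.append(executor_step(parameter_server))
    # Trainer: consume experience and write new parameters to the parameter server.
    parameter_server = trainer_step(data_server.popleft(), parameter_server)

print(parameter_server)
```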

### 4. Visualise our training results using TensorBoard


Load the TensorBoard extension. You will need to wait for the system to complete its 500 episodes.

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

To view the training results, start TensorBoard and point it to the directory where we told Mava to log (`~/mava/$mava_id`).

A good score is a `RawEpisodeReturn` between 30 and 40. Although this system is stochastic, it should at least reach that score by 500 executor episodes.

In [None]:
%tensorboard --logdir ~/mava/$mava_id

### 5. View agent recordings
Once a good score is reached, you can view the learned multi-agent behaviour by viewing the agent recordings.

#### Check if any agent recordings are available. 

In [None]:
! ls ~/mava/$mava_id/recordings

#### View the latest agent recording. 

In [None]:
import glob
import os 
import IPython

# Recordings
list_of_files = glob.glob(os.path.expanduser(f"~/mava/{mava_id}/recordings/*.html"))

if not list_of_files:
  print("No recordings are available yet. Please wait, or run the cell that calls system.launch() if you haven't already done so.")
else:
  latest_file = max(list_of_files, key=os.path.getctime)
  print("Run the next cell to visualize your agents!")

If the agents are trained (*usually after around 500 episodes*), they should move to the assigned landmarks.

<img src="https://raw.githubusercontent.com/instadeepai/Mava/develop/docs/images/simple_spread.png" width="250" height="250" />

In [None]:
# Latest file needs to point to the latest recording
IPython.display.HTML(filename=latest_file)

That's it! You have now trained a multi-agent system in Mava.

## What's next?
Mava is explicitly designed to make it easy to build custom multi-agent systems. As mentioned above, each system is just a combination of various [components](https://github.com/instadeepai/Mava/tree/develop/mava/components/jax). These components place self-contained pieces of code at specific places (called hooks) in the system. The name of every hook function starts with `on`. While the system is being built, it calls the hooks inside the [builder](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/builder.py); while it is executing, it calls the hooks inside the [executor](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/executor.py), [trainer](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/trainer.py) and [parameter_server](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/parameter_server.py).
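
The hook mechanism described above can be pictured with a small toy system. The sketch below is not Mava code: the classes, the two hook names (`on_building_start`, `on_training_step`) and the shared namespace (called `store` here) are hypothetical and only illustrate how a system can call every component's `on_*` method at a fixed point.

```python
from types import SimpleNamespace
from typing import List


class ToyComponent:
    """Base class: every hook is a no-op unless a component overrides it."""

    def on_building_start(self, system: "ToySystem") -> None:
        pass

    def on_training_step(self, system: "ToySystem") -> None:
        pass


class StepCounter(ToyComponent):
    """Counts training steps and keeps the count in the shared store."""

    def on_building_start(self, system: "ToySystem") -> None:
        system.store.steps = 0

    def on_training_step(self, system: "ToySystem") -> None:
        system.store.steps += 1


class StepLogger(ToyComponent):
    """Reads the value written by StepCounter and logs it."""

    def on_training_step(self, system: "ToySystem") -> None:
        system.store.log.append(f"step {system.store.steps}")


class ToySystem:
    def __init__(self, components: List[ToyComponent]) -> None:
        self.components = components
        self.store = SimpleNamespace(log=[])

    def _call_hook(self, hook_name: str) -> None:
        # Call the same hook on every component, in the order they were added.
        for component in self.components:
            getattr(component, hook_name)(self)

    def build(self) -> None:
        self._call_hook("on_building_start")

    def train(self, num_steps: int) -> None:
        for _ in range(num_steps):
            self._call_hook("on_training_step")


toy_system = ToySystem([StepCounter(), StepLogger()])
toy_system.build()
toy_system.train(num_steps=3)
print(toy_system.store.log)
```

In this toy version, swapping `StepLogger` for a different component changes the behaviour at the `on_training_step` hook without touching `ToySystem` itself.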

Let us now build our own custom component and add it to the PPO implementation. We will replace the [advantage estimation component](https://github.com/instadeepai/Mava/blob/develop/mava/components/jax/training/advantage_estimation.py) with a simpler advantage estimate by overriding the `on_training_utility_fns` hook located in the [trainer](https://github.com/instadeepai/Mava/blob/develop/mava/systems/jax/trainer.py#L45). This component creates a function called `gae_fn`, which the system calls to calculate an [advantage estimate](https://towardsdatascience.com/generalized-advantage-estimate-maths-and-code-b5d5bd3ce737) during training.

We start by defining our new component below. Notice that we store our new advantage function inside `trainer.store.gae_fn`. The `store` is a place where components can save variables for other components to access.

In [None]:
"""Trainer components for advantage calculations."""

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import jax
import jax.numpy as jnp

from mava.components.jax.training.advantage_estimation import GAE
from mava.core_jax import SystemTrainer

@dataclass
class AConfig:
    a_lambda: float = 0.95


class simpler_advantage_estimate(GAE):
    def __init__(
        self,
        config: AConfig = AConfig(),
    ):
        self.config = config

    def on_training_utility_fns(self, trainer: SystemTrainer) -> None:
        def simpler_advantage_estimate(
            rewards: jnp.ndarray, discounts: jnp.ndarray, values: jnp.ndarray
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:


            # Instead of GAE, use a simpler estimate: the sum of the current
            # and future rewards, decayed by `a_lambda` per step, minus the
            # value estimate. Note that `discounts` is not used here.

            # Pad the rewards with zeros so that the sliding window below also
            # covers the time steps near the end of the sequence.
            zeros_mask = jnp.zeros(shape=rewards.shape)
            padded_rewards = jnp.concatenate([rewards, zeros_mask], axis=0)
            cum_rewards = rewards.copy()
            seq_len = len(rewards)
            for i in range(1, seq_len):
                cum_rewards += padded_rewards[i : i + seq_len] * jnp.power(
                    self.config.a_lambda, i
                )

            # Calculate the advantage estimate.
            advantages = cum_rewards[:-1] - values[:-1]
            
            # Stop gradients from flowing through the advantage estimate.
            advantages = jax.lax.stop_gradient(advantages)

            # Set the target values and stop gradients from flowing backwards
            # through the target values.
            target_values = cum_rewards[:-1]
            target_values = jax.lax.stop_gradient(target_values)

            return advantages, target_values

        trainer.store.gae_fn = simpler_advantage_estimate

    @staticmethod
    def name() -> str:
        return "gae_fn"

    @staticmethod
    def config_class() -> Optional[Callable]:
        return AConfig
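
To see what the loop inside the component computes: each entry of `cum_rewards` is the sum of the current and all later rewards in the sequence, where the reward `i` steps ahead is weighted by `a_lambda**i`. The pure-Python sketch below uses plain lists instead of JAX arrays and hypothetical helper names. It reproduces that sum directly and compares it with the equivalent backward recursion.

```python
from typing import List


def decayed_returns_direct(rewards: List[float], decay: float) -> List[float]:
    """Sum decay**i * rewards[t + i] over the remaining steps, for every t."""
    seq_len = len(rewards)
    return [
        sum(decay**i * rewards[t + i] for i in range(seq_len - t))
        for t in range(seq_len)
    ]


def decayed_returns_recursive(rewards: List[float], decay: float) -> List[float]:
    """Same quantity, computed backwards: G_t = r_t + decay * G_{t+1}."""
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + decay * running
        returns.append(running)
    returns.reverse()
    return returns


toy_rewards = [1.0, 1.0, 1.0]
toy_values = [0.5, 0.5, 0.5]

returns = decayed_returns_direct(toy_rewards, decay=0.5)

# As in the component, the last time step is dropped before subtracting values.
advantages = [g - v for g, v in zip(returns[:-1], toy_values[:-1])]

print(returns)     # [1.75, 1.5, 1.0]
print(advantages)  # [1.25, 1.0]
```

The backward recursion needs a single pass over the sequence, whereas the direct sum (like the loop in the component) does work that grows quadratically with the sequence length.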

Now that we have this new component we can add it to the PPO system using `system.update`. You can retrain the system by executing the cell below. Once training is complete you can again run steps 4 and 5 (above) to view the training results and gameplay footage. Does this system perform better or worse than the previous system? 

Feel free to update the advantage component to see if you can achieve better training results. For more examples using different systems, environments and architectures, visit our [GitHub page](https://github.com/instadeepai/Mava/tree/develop/examples/tf).

In [None]:
# File name 
mava_id = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Log every [log_every] seconds
log_every = 15
logger_factory = functools.partial(
    logger_utils.make_logger,
    directory=base_dir,
    to_terminal=True,
    to_tensorboard=True,
    time_stamp=mava_id,
    time_delta=log_every,
)

# Checkpointer appends "Checkpoints" to experiment_path
experiment_path = f"{base_dir}/{mava_id}"

# Create the system.
system = mappo.MAPPOSystem()

# Add the gameplay monitor component
system.update(MonitorExecutorEnvironmentLoop)

# Update the system with our custom component.
system.update(simpler_advantage_estimate)

# Build the system.
system.build(
    environment_factory=environment_factory,
    network_factory=network_factory,
    logger_factory=logger_factory,
    experiment_path=experiment_path,
    optimizer=optimizer,
    run_evaluator=True,
    sample_batch_size=5,
    num_epochs=15,
    num_executors=1,
    multi_process=False,
    single_process_max_episodes=500,
    record_every=10,
 )

# Launch the system.
system.launch()