# JAX Mava Quickstart Notebook
<img src="https://raw.githubusercontent.com/instadeepai/Mava/develop/docs/images/mava.png" />

### This notebook offers a simple introduction to [Mava](https://github.com/instadeepai/Mava) by training a multi-agent PPO (MAPPO) system on the Robot Warehouse environment.

<a href="https://colab.research.google.com/github/instadeepai/Mava/blob/develop/quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Requirements

We start by installing and importing the necessary packages.

In [None]:
# @title Install Mava
! pip install git+https://github.com/instadeepai/mava.git@feat/pure-jax-mava
! pip install "jax[cuda12_pip]<=0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [17]:
#@title Import required packages. (Run Cell)

from typing import Any, Callable, Dict, Sequence, Tuple
from colorama import Fore, Style

import optax
from optax._src.base import OptState
import chex
import distrax
import flax.linen as nn
from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.linen.initializers import constant, orthogonal
import jax
import jax.numpy as jnp
import numpy as np

# Env requirements
import jumanji
from jumanji.env import Environment
from jumanji.environments.routing.robot_warehouse import Observation, State
from jumanji.environments.routing.robot_warehouse.generator import RandomGenerator
from jumanji import specs
from jumanji.wrappers import AutoResetWrapper

# Mava Helpful functions and types
from mava.utils.jax import merge_leading_dims
from mava.utils.timing_utils import TimeIt
from mava.wrappers.jumanji import (
    AgentIDWrapper,
    LogWrapper,
    ObservationGlobalState,
    RwareMultiAgentWithGlobalStateWrapper,
)
from mava.types import ExperimentOutput, LearnerState, PPOTransition
from mava.evaluator import evaluator_setup

# Plot requirements
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
%matplotlib inline
import seaborn as sns
sns.set()
sns.set_style("white")
# `sns.color_palette(...)` only *returns* a palette; `sns.set_palette(...)` is the call
# that actually installs it as the default colour cycle for subsequent plots.
sns.set_palette("colorblind")
import time

# Trainer



## Network

Initially, we construct the ActorCritic network using components from the Flax library. The application of the actor-critic's function will then be "vmapped" across distinct agents, with the in_axes parameter applied solely to the observation and not the network parameters.

In [18]:
class ActorCritic(nn.Module):
    """Actor-critic network with separate actor and critic MLPs.

    The actor reads `observation.agents_view` and produces a categorical policy
    over `action_dim` actions, with unavailable actions masked out using
    `observation.action_mask`. The critic reads `observation.global_state` and
    produces one value estimate per input row.

    Attributes:
        action_dim: Number of discrete actions, i.e. the width of the actor's
            final dense layer.
    """

    # Passed as the `features` argument of the last actor `nn.Dense`, so this is a
    # single integer rather than a sequence.
    action_dim: int

    @nn.compact
    def __call__(self, observation: Observation) -> Tuple[distrax.Categorical, chex.Array]:
        """Forward pass.

        Args:
            observation: Observation exposing `agents_view`, `action_mask` and
                `global_state` fields.

        Returns:
            A tuple `(actor_policy, value)` where `actor_policy` is a
            `distrax.Categorical` built from the masked logits and `value` is the
            critic output with its trailing size-1 axis squeezed away.
        """
        # ----- Actor: two hidden layers of 64 units, then one logit per action. -----
        x = observation.agents_view

        actor_output = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        actor_output = nn.relu(actor_output)
        actor_output = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(
            actor_output
        )
        actor_output = nn.relu(actor_output)
        actor_output = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_output)

        # Replace the logits of masked-out actions with the most negative float32 so
        # that they receive (numerically) zero probability under the softmax.
        masked_logits = jnp.where(
            observation.action_mask,
            actor_output,
            jnp.finfo(jnp.float32).min,
        )
        actor_policy = distrax.Categorical(logits=masked_logits)

        # ----- Critic: same hidden sizes, fed with the global state, scalar output. -----
        y = observation.global_state

        critic_output = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(y)
        critic_output = nn.relu(critic_output)
        critic_output = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(
            critic_output
        )
        critic_output = nn.relu(critic_output)
        critic_output = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic_output
        )

        return actor_policy, jnp.squeeze(critic_output, axis=-1)

## Learner Function
The `get_learner_fn` function returns a learner function that, when called, returns an `ExperimentOutput` encapsulating the updated learner state, episode information, and loss metrics. This function is essential for training MAPPO.

In [19]:
def get_learner_fn(
    env: jumanji.Environment, apply_fn: Callable, update_fn: Callable, config: Dict
) -> Callable:
    """Get the learner function.

    Args:
        env: Environment whose `step` is vmapped over the batch of environment states.
        apply_fn: Network apply function, called as `apply_fn(params, observation)` and
            expected to return `(actor_policy, value)`.
        update_fn: Optimiser update function, called as `update_fn(grads, opt_state)` and
            expected to return `(updates, new_opt_state)`.
        config: Hyperparameter dictionary. The keys read here are "num_agents",
            "num_envs", "rollout_length", "gamma", "gae_lambda", "clip_eps", "vf_coef",
            "ent_coef", "num_minibatches", "ppo_epochs" and "num_updates_per_eval".

    Returns:
        `learner_fn`, which maps a `LearnerState` to an `ExperimentOutput`.
    """

    def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]:
        """A single update of the network.

        This function steps the environment and records the trajectory batch for
        training. It then calculates advantages and targets based on the recorded
        trajectory and updates the actor and critic networks based on the calculated
        losses.

        Args:
            learner_state (NamedTuple):
                - params (FrozenDict): The current model parameters.
                - opt_state (OptState): The current optimizer state.
                - rng (PRNGKey): The random number generator state.
                - env_state (State): The environment state.
                - last_timestep (TimeStep): The last timestep in the current trajectory.
            _ (Any): Unused; placeholder for the per-iteration input of `jax.lax.scan`.

        Returns:
            The updated learner state and a tuple `(metric, loss_info)`, where `metric`
            is the `info` field of the collected trajectory batch and `loss_info` holds
            the losses gathered over all epochs and minibatches.
        """

        def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
            """Step the environment once and record the resulting transition."""
            params, opt_state, rng, env_state, last_timestep = learner_state

            # SELECT ACTION
            rng, policy_rng = jax.random.split(rng)
            actor_policy, value = apply_fn(params, last_timestep.observation)
            action = actor_policy.sample(seed=policy_rng)
            log_prob = actor_policy.log_prob(action)

            # STEP ENVIRONMENT
            env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

            # LOG EPISODE METRICS
            # Repeat each value `num_agents` times and reshape to `(num_envs, -1)` so that
            # `done` and `reward` carry an agent axis.
            done, reward = jax.tree_util.tree_map(
                lambda x: jnp.repeat(x, config["num_agents"]).reshape(config["num_envs"], -1),
                (timestep.last(), timestep.reward),
            )
            info = {
                "episode_return": env_state.episode_return_info,
                "episode_length": env_state.episode_length_info,
            }

            # The transition stores the observation the action was selected from
            # (`last_timestep`), not the one produced by this step.
            transition = PPOTransition(
                done, action, value, reward, log_prob, last_timestep.observation, info
            )
            learner_state = LearnerState(params, opt_state, rng, env_state, timestep)
            return learner_state, transition

        # STEP ENVIRONMENT FOR ROLLOUT LENGTH
        learner_state, traj_batch = jax.lax.scan(
            _env_step, learner_state, None, config["rollout_length"]
        )

        # CALCULATE ADVANTAGE
        params, opt_state, rng, env_state, last_timestep = learner_state
        # Value of the observation reached after the rollout, used to bootstrap the GAE.
        _, last_val = apply_fn(params, last_timestep.observation)

        def _calculate_gae(
            traj_batch: PPOTransition, last_val: chex.Array
        ) -> Tuple[chex.Array, chex.Array]:
            """Calculate the GAE.

            Returns:
                A tuple `(advantages, targets)` where `targets` is
                `advantages + traj_batch.value`.
            """

            def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
                """Calculate the GAE for a single transition."""
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                # `(1 - done)` zeroes both the bootstrap term and the accumulated
                # advantage wherever `done` is set.
                delta = reward + config["gamma"] * next_value * (1 - done) - value
                gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
                # The carry hands this step's value to the preceding step as `next_value`.
                return (gae, value), gae

            # Scan backwards over the trajectory, starting from a zero advantage and the
            # bootstrap value.
            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
            """Update the network for a single epoch."""

            def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                """Update the network for a single minibatch."""
                params, opt_state = train_state
                traj_batch, advantages, targets = batch_info

                def _loss_fn(
                    params: FrozenDict,
                    opt_state: OptState,
                    traj_batch: PPOTransition,
                    gae: chex.Array,
                    targets: chex.Array,
                ) -> Tuple:
                    """Calculate the loss.

                    `opt_state` is accepted but not used in the computation.

                    Returns:
                        `(total_loss, (value_loss, loss_actor, entropy))`.
                    """
                    # RERUN NETWORK
                    actor_policy, value = apply_fn(params, traj_batch.obs)
                    log_prob = actor_policy.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    # Clip the new value prediction to within `clip_eps` of the value
                    # recorded at rollout time, then take the larger of the two errors.
                    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                        -config["clip_eps"], config["clip_eps"]
                    )
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    # Normalise advantages over the minibatch; 1e-8 guards against a zero std.
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - config["clip_eps"],
                            1.0 + config["clip_eps"],
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = actor_policy.entropy().mean()

                    # The entropy term is subtracted, so higher entropy lowers the loss.
                    total_loss = (
                        loss_actor + config["vf_coef"] * value_loss - config["ent_coef"] * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                loss_info, grads = grad_fn(params, opt_state, traj_batch, advantages, targets)

                # Compute the parallel mean (pmean) over the batch.
                # This calculation is inspired by the Anakin architecture demo notebook.
                # available at https://tinyurl.com/26tdzs5x
                # This pmean could be a regular mean as the batch axis is on the same device.
                grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="batch")
                # pmean over devices.
                # The "device" axis name must be bound by whoever maps `learner_fn`.
                grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="device")

                updates, new_opt_state = update_fn(grads, opt_state)
                new_params = optax.apply_updates(params, updates)

                return (new_params, new_opt_state), loss_info

            params, opt_state, traj_batch, advantages, targets, rng = update_state
            rng, shuffle_rng = jax.random.split(rng)

            # SHUFFLE MINIBATCHES
            # Merge the two leading axes, permute the merged axis, then split it into
            # `num_minibatches` equally sized chunks.
            batch_size = config["rollout_length"] * config["num_envs"]
            permutation = jax.random.permutation(shuffle_rng, batch_size)
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch)
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.reshape(x, [config["num_minibatches"], -1] + list(x.shape[1:])),
                shuffled_batch,
            )

            # UPDATE MINIBATCHES
            (params, opt_state), loss_info = jax.lax.scan(
                _update_minibatch, (params, opt_state), minibatches
            )

            # The unshuffled trajectory, advantages and targets are passed on unchanged so
            # that the next epoch reshuffles the same data with a fresh key.
            update_state = (params, opt_state, traj_batch, advantages, targets, rng)
            return update_state, loss_info

        update_state = (params, opt_state, traj_batch, advantages, targets, rng)

        # UPDATE EPOCHS
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, config["ppo_epochs"]
        )

        params, opt_state, traj_batch, advantages, targets, rng = update_state
        learner_state = LearnerState(params, opt_state, rng, env_state, last_timestep)
        metric = traj_batch.info
        return learner_state, (metric, loss_info)

    def learner_fn(learner_state: LearnerState) -> ExperimentOutput:
        """Learner function.

        This function represents the learner, it updates the network parameters
        by iteratively applying the `_update_step` function for a fixed number of
        updates. The `_update_step` function is vectorized over a batch of inputs.

        Args:
            learner_state (NamedTuple):
                - params (FrozenDict): The initial model parameters.
                - opt_state (OptState): The initial optimizer state.
                - rng (chex.PRNGKey): The random number generator state.
                - env_state (LogEnvState): The environment state.
                - timesteps (TimeStep): The initial timestep in the initial trajectory.

        Returns:
            An `ExperimentOutput` holding the final learner state, the episode info and
            the individual loss terms collected over `config["num_updates_per_eval"]`
            update steps.
        """

        # Vectorise the update over the leading axis of the learner state; the axis is
        # named "batch" so that the `pmean` in `_update_minibatch` can reduce over it.
        batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch")

        learner_state, (metric, loss_info) = jax.lax.scan(
            batched_update_step, learner_state, None, config["num_updates_per_eval"]
        )
        total_loss, (value_loss, loss_actor, entropy) = loss_info
        return ExperimentOutput(
            learner_state=learner_state,
            episodes_info=metric,
            total_loss=total_loss,
            value_loss=value_loss,
            loss_actor=loss_actor,
            entropy=entropy,
        )

    return learner_fn


## Trainer Setup
The learner setup initializes components for training: the learner function, neural network, optimizer, environment, and states. It creates a function for learning, employs parallel processing over the cores for efficiency, and sets up initial states.

In [6]:
def learner_setup(
    env: Environment, rngs: chex.Array, config: Dict
) -> Tuple[Callable, ActorCritic, LearnerState]:
    """Initialise learner_fn, network, optimiser, environment and states.

    Args:
        env: Training environment; its action and observation specs size the network.
        rngs: Pair of PRNG keys `(rng, rng_p)`. `rng` seeds the environment resets and
            the per-learner step keys, `rng_p` seeds the parameter initialisation.
        config: Hyperparameter dictionary. Reads "max_grad_norm", "lr",
            "update_batch_size" and "num_envs"; writes "num_agents" as a side effect.

    Returns:
        A tuple `(learn, network, init_learner_state)`: the pmapped learner function,
        the un-vmapped network module, and the initial learner state whose leaves have
        leading axes `(n_devices, update_batch_size, ...)`.
    """
    # Get available devices.
    n_devices = len(jax.devices())
    update_batch_size = config["update_batch_size"]
    num_envs = config["num_envs"]

    # Get number of actions and agents.
    num_actions = int(env.action_spec().num_values[0])
    num_agents = env.action_spec().shape[0]
    # Stored in the config because the learner function reads it when shaping rewards.
    config["num_agents"] = num_agents
    # PRNG keys.
    rng, rng_p = rngs

    # Define network and optimiser.
    network = ActorCritic(num_actions)
    optim = optax.chain(
        optax.clip_by_global_norm(config["max_grad_norm"]),
        optax.adam(config["lr"], eps=1e-5),
    )

    # Initialise observation.
    obs = env.observation_spec().generate_value()
    # Select only obs for a single agent; the global state is taken as is.
    init_x = ObservationGlobalState(
        agents_view=obs.agents_view[0],
        action_mask=obs.action_mask[0],
        global_state=obs.global_state,
        step_count=obs.step_count[0],
    )
    # Add a leading batch axis of size one.
    init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x)

    # Initialise params and optimiser state.
    params = network.init(rng_p, init_x)
    opt_state = optim.init(params)

    # Vmap network apply function over number of agents (axis 1 of the per-agent
    # fields); parameters and the global state are shared and therefore not mapped.
    vmapped_network_apply_fn = jax.vmap(
        network.apply,
        in_axes=(None, ObservationGlobalState(1, 1, None, 1)),
        out_axes=(1, 1),
    )

    # Get batched iterated update and replicate it to pmap it over devices.
    learn = get_learner_fn(env, vmapped_network_apply_fn, optim.update, config)
    learn = jax.pmap(learn, axis_name="device")

    def _broadcast(x: chex.Array) -> chex.Array:
        """Replicate a leaf across the device and update-batch axes."""
        return jnp.broadcast_to(x, (n_devices, update_batch_size) + x.shape)

    def _reshape_step_rngs(x: chex.Array) -> chex.Array:
        """Split the leading axis into (n_devices, update_batch_size)."""
        return x.reshape((n_devices, update_batch_size) + x.shape[1:])

    def _reshape_states(x: chex.Array) -> chex.Array:
        """Split the leading axis into (n_devices, update_batch_size, num_envs)."""
        return x.reshape((n_devices, update_batch_size, num_envs) + x.shape[1:])

    # Broadcast params and optimiser state to devices and batch.
    # `jax.tree_util.tree_map` is used (as elsewhere in this file) instead of the
    # deprecated `jax.tree_map` alias.
    params = jax.tree_util.tree_map(_broadcast, params)
    opt_state = jax.tree_util.tree_map(_broadcast, opt_state)

    # Initialise environment states and timesteps: one key per environment instance.
    rng, *env_rngs = jax.random.split(rng, n_devices * update_batch_size * num_envs + 1)
    env_states, timesteps = jax.vmap(env.reset, in_axes=0)(jnp.stack(env_rngs))

    # Split rngs for each learner replica on each device.
    rng, *step_rngs = jax.random.split(rng, n_devices * update_batch_size + 1)
    # Add the leading dimensions to pmap/vmap over.
    step_rngs = _reshape_step_rngs(jnp.stack(step_rngs))
    env_states = jax.tree_util.tree_map(_reshape_states, env_states)
    timesteps = jax.tree_util.tree_map(_reshape_states, timesteps)

    init_learner_state = LearnerState(params, opt_state, step_rngs, env_states, timesteps)
    return learn, network, init_learner_state

# Rendering and logging tools

## Rendering
The `render_one_episode` function simulates and visualizes one episode using a trained MAPPO model that will be passed to the function using `params`.

In [20]:
def render_one_episode(config, params, seed) -> None:
    """Render one episode using trained MAPPO.

    Rolls out a single episode with the given parameters, saves an animation of the
    visited states to "./rware.gif" and prints the episode return and length.

    Args:
        config: Hyperparameter dictionary; reads "env_name" and "evaluation_greedy".
        params: Network parameters for a single learner (no device/batch axes).
        seed: Integer seed for the PRNG key used to reset the environment and to
            sample actions.
    """

    def env_step(episode_state, apply_fn):
        """Step the environment once and return the updated episode state."""
        rng, env_state, last_timestep, step_count_, return_, states = episode_state

        # Select action.
        rng, _rng = jax.random.split(rng)
        pi, _ = apply_fn(params, last_timestep.observation)

        if config["evaluation_greedy"]:
            action = pi.mode()
        else:
            action = pi.sample(seed=_rng)

        # Step environment.
        env_state, timestep = env.step(env_state, action)

        # Log episode metrics.
        return_ += timestep.reward
        step_count_ += 1
        states.append(env_state)
        episode_state = (rng, env_state, timestep, step_count_, return_, states)
        return episode_state

    # Network
    env = jumanji.make(config["env_name"])
    env = RwareMultiAgentWithGlobalStateWrapper(env)
    num_actions = int(env.action_spec().num_values[0])
    network = ActorCritic(num_actions)
    # Map over the agent axis (axis 0 here, as there is no environment batch axis);
    # parameters and the global state are shared across agents.
    vmapped_network_apply_fn = jax.vmap(
        network.apply,
        in_axes=(None, ObservationGlobalState(0, 0, None, 0)),
    )

    # Rng: use separate keys for the reset and for action sampling, since a JAX key
    # must not be consumed twice.
    rng = jax.random.PRNGKey(seed)
    rng, reset_rng = jax.random.split(rng)

    # Build and initialise env.
    state, timestep = env.reset(reset_rng)

    states = []
    episode_state = (rng, state, timestep, 0, 0, states)
    while not episode_state[2].last():
        episode_state = env_step(episode_state, vmapped_network_apply_fn)

    # Record and print results.
    rng, env_state, last_timestep, step_count_, return_, states = episode_state
    env.animate(states=states, save_path="./rware.gif")
    print(f"{Fore.CYAN}{Style.BRIGHT}EPISODE RETURN: {return_}{Style.RESET_ALL}")
    print(f"{Fore.CYAN}{Style.BRIGHT}EPISODE LENGTH:{step_count_}{Style.RESET_ALL}")

## Logging
The `plot_performance` function visualizes the performance of the algorithm. The plot is refreshed after every evaluation interval.

In [21]:
def plot_performance(metrics, ep_returns, start_time):
    """Append the latest mean episode return and redraw the learning curve.

    Args:
        metrics: Evaluation output exposing `episodes_info["episode_return"]`.
        ep_returns: List of mean episode returns recorded so far. It is extended
            in place with the mean of the latest evaluation.
        start_time: Wall-clock time (as returned by `time.time()`) at which the
            experiment started; used to scale the x-axis in minutes.

    Returns:
        The same `ep_returns` list, including the newly appended value.
    """
    plt.figure(figsize=(8, 4))
    # Clear the previous output only once the new figure is ready, to avoid flicker.
    clear_output(wait=True)

    ep_returns.append(metrics.episodes_info["episode_return"].mean())

    # The x-axis spreads the recorded points evenly between 0 and the elapsed time.
    elapsed_minutes = (time.time() - start_time) / 60.0
    plt.plot(np.linspace(0, elapsed_minutes, len(ep_returns)), ep_returns)
    plt.xlabel('Run Time [Minutes]')
    plt.ylabel('Episode Return')
    plt.title('Robotic Warehouse with 4 Agents')
    # Show the plot
    plt.show()
    return ep_returns

# Experiment Setup (Function and Hyperparameters)


The `run_experiment` function executes an experiment by training MAPPO and evaluating its performance:

The function creates environments, sets up the learner and evaluator, and calculates the total timesteps for training. It then trains the model, evaluates it, plots the performance, and updates the learner state. After completing the specified number of evaluations, it returns the trained parameters.

In [22]:
def run_experiment(config: Dict) -> FrozenDict:
    """Train MAPPO, evaluating and plotting after every training interval.

    Args:
        config: Hyperparameter dictionary. Reads "env_name", "seed", "num_updates",
            "num_evaluation", "rollout_length", "update_batch_size" and "num_envs";
            writes "num_updates_per_eval" (and "num_agents", via `learner_setup`).

    Returns:
        The trained parameters of the first learner on the first device, i.e. with
        the device and update-batch axes indexed away, ready for rendering or testing.
    """
    # Create envs: the training env auto-resets and logs episode statistics, the
    # evaluation env is only wrapped to expose the global state.
    env = jumanji.make(config["env_name"])
    env = RwareMultiAgentWithGlobalStateWrapper(env)
    env = AutoResetWrapper(env)
    env = LogWrapper(env)
    eval_env = jumanji.make(config["env_name"])
    eval_env = RwareMultiAgentWithGlobalStateWrapper(eval_env)

    # PRNG keys.
    rng, rng_e, rng_p = jax.random.split(jax.random.PRNGKey(config["seed"]), num=3)

    # Setup learner.
    learn, network, learner_state = learner_setup(env, (rng, rng_p), config)

    # Setup evaluator. Only the evaluator function is needed here; the parameters and
    # keys used for evaluation are rebuilt after every training interval below.
    evaluator, _, _ = evaluator_setup(
        eval_env=eval_env,
        rng_e=rng_e,
        network=network,
        params=learner_state.params,
        config=config,
        centralised_critic=True,
    )

    # Calculate total timesteps.
    n_devices = len(jax.devices())
    # Read by the learner function when it is first called, so it only has to be set
    # before the training loop starts.
    config["num_updates_per_eval"] = config["num_updates"] // config["num_evaluation"]
    timesteps_per_training = (
        n_devices
        * config["num_updates_per_eval"]
        * config["rollout_length"]
        * config["update_batch_size"]
        * config["num_envs"]
    )

    # Run experiment for a total number of evaluations.
    start_time = time.time()
    ep_returns = []
    for i in range(config["num_evaluation"]):
        # Train. The first call includes compilation, hence the different tag.
        with TimeIt(
            tag=("COMPILATION" if i == 0 else "EXECUTION"),
            environment_steps=timesteps_per_training,
        ):
            learner_output = learn(learner_state)
            jax.block_until_ready(learner_output)

        # Prepare for evaluation: keep the device axis but take the first learner of
        # the update batch on each device.
        trained_params = jax.tree_util.tree_map(
            lambda x: x[:, 0, ...], learner_output.learner_state.params
        )
        rng_e, *eval_rngs = jax.random.split(rng_e, n_devices + 1)
        eval_rngs = jnp.stack(eval_rngs)
        eval_rngs = eval_rngs.reshape(n_devices, -1)

        # Evaluate.
        evaluator_output = evaluator(trained_params, eval_rngs)
        jax.block_until_ready(evaluator_output)
        ep_returns = plot_performance(evaluator_output, ep_returns, start_time)

        # Update runner state to continue training.
        learner_state = learner_output.learner_state

    # Return trained params to be used for rendering or testing.
    trained_params = jax.tree_util.tree_map(
        lambda x: x[0, 0, ...], learner_output.learner_state.params
    )
    return trained_params

## Config

The provided config dictionary sets various hyperparameters for the experiment.

In [23]:
# Hyperparameters for the MAPPO quickstart experiment.
config = {
    # --- Optimisation ---
    "lr": 2.5e-4,  # Learning rate passed to Adam.
    "update_batch_size": 2,  # Learner replicas vmapped on each device.
    "rollout_length": 128,  # Environment steps collected per update.
    "num_updates": 1000,  # Total number of training updates.
    "num_envs": 64,  # Parallel environments per learner replica.
    # --- PPO ---
    "ppo_epochs": 4,  # Passes over each collected batch.
    "num_minibatches": 2,  # Minibatches per epoch.
    "gamma": 0.99,  # Discount factor.
    "gae_lambda": 0.95,  # GAE lambda.
    "clip_eps": 0.2,  # Clipping range for the ratio and the value prediction.
    "ent_coef": 0.01,  # Entropy bonus coefficient.
    "vf_coef": 0.5,  # Value loss coefficient.
    "max_grad_norm": 0.5,  # Global-norm gradient clipping threshold.
    # --- Environment and evaluation ---
    "env_name": "RobotWarehouse-v0",
    "num_eval_episodes": 32,  # Number of episodes for evaluation.
    "num_evaluation": 30,  # Number of evaluation runs.
    "evaluation_greedy": False,  # Whether to use a greedy policy during evaluation.
    "seed": 42,
}

# Run Experiment

Now we train MAPPO on the `small-4ag-easy` scenario from RobotWarehouse.

In [None]:
# Train MAPPO with the config above and keep the returned parameters for rendering.
trained_params=run_experiment(config)
print(f"{Fore.CYAN}{Style.BRIGHT}MAPPO experiment completed{Style.RESET_ALL}")

Now let's render one episode using the trained system

In [None]:
# Roll out one episode with the trained parameters (seed 42); the animation is
# written to ./rware.gif.
render_one_episode(config, trained_params, 42)

In [None]:
import IPython
from IPython.display import Image
# `render_one_episode` saves the animation to "./rware.gif"; read it from the same
# relative path so this cell also works outside Colab's /content working directory.
Image(filename="./rware.gif", embed=True)