# Welcome to Multi-agent Training using IPPO!
You will be responsible for completing two functions below (search for `NotImplementedError` to jump straight to them). Once you have implemented them, you can run all cells and begin training your IPPO policies!

This tutorial will walk you through code for teaching a robotic ant to walk!


![GIF](https://gymnasium.farama.org/_images/ant.gif)

In this assignment, we will treat each of the ant's 4 legs as a separate agent:

![Diagram](https://robotics.farama.org/_images/ant_2x4.png)

Accordingly, the first agent (`agent_0` in the code) controls joints 0 and 1, the second (`agent_1`) controls joints 2 and 3, and so on.
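
This joint ownership amounts to slicing the Ant's 8-dimensional action vector. The sketch below is a plain NumPy illustration that assumes the same index layout as `agent_action_mapping` in `IPPO_Ant_Env` (two consecutive actuators per agent); `split_actions` and `merge_actions` are hypothetical helpers, not functions from this notebook.

```python
import numpy as np

# Hypothetical copy of the index layout used by IPPO_Ant_Env.agent_action_mapping:
# each agent owns two consecutive entries of the 8-dim global action vector.
AGENT_ACTION_MAPPING = {
    "agent_0": np.array([0, 1]),
    "agent_1": np.array([2, 3]),
    "agent_2": np.array([4, 5]),
    "agent_3": np.array([6, 7]),
}

def split_actions(global_action):
    """Slice a global action vector into one 2-dim action per agent."""
    return {agent: global_action[idx] for agent, idx in AGENT_ACTION_MAPPING.items()}

def merge_actions(agent_actions, action_size=8):
    """Scatter per-agent actions back into a single global action vector."""
    global_action = np.zeros(action_size, dtype=np.float32)
    for agent, idx in AGENT_ACTION_MAPPING.items():
        global_action[idx] = agent_actions[agent]
    return global_action

# agent i outputs the constant action [i, i]
actions = {a: np.full(2, i, dtype=np.float32) for i, a in enumerate(AGENT_ACTION_MAPPING)}
global_action = merge_actions(actions)  # == [0, 0, 1, 1, 2, 2, 3, 3]
```

Because each agent writes to a disjoint set of indices, merging and splitting are exact inverses of each other.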

# Initial Setup
In this section we just install the necessary packages and download the `demo.xml` file from the shared drive.

In [1]:
# !pip install --upgrade pip                # Update pip
# !pip install equinox brax distreqx tensorboardX  # Install necessary libraries
# !gdown 1ulG1WBbTFxkr2N9Yuf7DbGWSjBZBN04j  # Download demo.xml file

# Multi-Agent Reinforcement Learning
This notebook is set up so that you perform the necessary installs and imports once and can then run all cells below to build the classes and helper functions needed to train policies. Each agent in the scenario is modelled by an `ActorCritic` object, which holds the parameters of its multi-layer perceptron (MLP) neural networks: an actor that outputs the mean of a Gaussian action distribution (with a learned, state-independent standard deviation) and a critic that estimates the value of the current observation.
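
As a rough picture of what one agent's `ActorCritic` computes, here is a NumPy sketch of the forward pass. It assumes plain Gaussian weight initialisation (the notebook uses orthogonal initialisation) and an 18-dimensional observation, which is the length of `agent_0`'s observation index list in `IPPO_Ant_Env`; `mlp` and `init_mlp` are illustrative helpers only.

```python
import numpy as np

def mlp(params, x, activation=np.tanh):
    """Apply a stack of (W, b) layers, using `activation` on every hidden layer."""
    for W, b in params[:-1]:
        x = activation(W @ x + b)
    W, b = params[-1]
    return W @ x + b  # the output layer is linear

def init_mlp(rng, sizes, scale=0.1):
    """Gaussian weights and zero biases (assumption: the notebook uses orthogonal init)."""
    return [(scale * rng.standard_normal((out_f, in_f)), np.zeros(out_f))
            for in_f, out_f in zip(sizes[:-1], sizes[1:])]

rng = np.random.default_rng(0)
obs_dim, act_dim = 18, 2                       # per-agent observation / action sizes
actor = init_mlp(rng, [obs_dim, 64, 64, act_dim])
critic = init_mlp(rng, [obs_dim, 64, 64, 1])
log_std = np.zeros(act_dim)                    # learned parameter, independent of the state

obs = rng.standard_normal(obs_dim)
mean = mlp(actor, obs)                         # (act_dim,) Gaussian mean per actuator
scale = np.exp(log_std)                        # (act_dim,) standard deviation
value = mlp(critic, obs).squeeze()             # scalar state-value estimate
```

Keeping the actor and critic as two separate MLPs (rather than a shared trunk) mirrors the layer lists built in `ActorCritic.__init__`.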

We also have a training environment class, `IPPO_Ant_Env`, which reads our description of the Ant robot in `demo.xml` and creates a Brax/MuJoCo physics simulator. The simulator can be vectorized: we can make many copies of the training environment at once and run them in parallel on a GPU (via `jax.vmap`), which makes trajectory collection very fast.
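
The vectorization pattern used later (`jax.vmap(env.reset)(reset_rng)` and `jax.vmap(env.step)(rng_step, env_state, env_act)`) can be illustrated without Brax. In the sketch below, `toy_step` is a hypothetical stand-in for `env.step` with made-up point-mass dynamics; only the batching mechanics are meant to carry over.

```python
import jax
import jax.numpy as jnp

def toy_step(key, state, action):
    """Hypothetical stand-in for env.step: a noisy 3-D point mass."""
    noise = 0.01 * jax.random.normal(key, state.shape)
    next_state = state + 0.1 * action + noise
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

num_envs = 4096
key = jax.random.PRNGKey(0)
step_keys = jax.random.split(key, num_envs)   # one independent key per environment copy
states = jnp.zeros((num_envs, 3))
actions = jnp.ones((num_envs, 3))

# vmap maps toy_step over the leading axis of every argument;
# jit compiles the whole batched step once.
batched_step = jax.jit(jax.vmap(toy_step))
next_states, rewards = batched_step(step_keys, states, actions)
```

`toy_step` is written for a single environment; `vmap` adds the batch dimension, so the same function scales to thousands of copies without any Python loop.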

We then have separate training functions **(some of which you will have to fill in as part of your assignment)** that use our `ActorCritic` networks to sample trajectories from the `IPPO_Ant_Env` training environment and update the network parameters!
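
The parameter update relies on PPO's clipped surrogate objective, which `_loss_actor` inside `make_train` computes. The NumPy sketch below restates that objective in isolation; `ppo_actor_loss` is an illustrative name, and the toy numbers are chosen only so the clipping is easy to follow.

```python
import numpy as np

def ppo_actor_loss(log_prob, old_log_prob, advantages, clip_eps=0.2):
    """PPO clipped surrogate objective, negated so it can be minimised."""
    ratio = np.exp(log_prob - old_log_prob)                        # r_t(theta)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Take the smaller (pessimistic) objective per sample, then negate the mean.
    return -np.minimum(unclipped, clipped).mean()

adv = np.array([1.0, -1.0])
old = np.log(np.array([0.5, 0.5]))
new = np.log(np.array([0.75, 0.75]))      # probability ratio is 1.5 for both samples
loss = ppo_actor_loss(new, old, adv, clip_eps=0.2)
```

With a ratio of 1.5, the positive-advantage sample is capped at 1.2 (no extra reward for moving further), while the negative-advantage sample keeps the unclipped -1.5, so the loss is -(1.2 - 1.5) / 2 = 0.15.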

## Imports

In [2]:
# Environment Wrapper
import os, json
from datetime import datetime

import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray
import optax
import equinox as eqx
from tensorboardX import SummaryWriter
import mujoco as mj # we use mj to change some foundational things in the simulation
from brax import math # this contains some useful things, like safe_norm and quaternion utils
from typing import Dict, Literal, Optional, Tuple, Callable, List, Union, NamedTuple, Sequence
from collections import OrderedDict

# Mujoco/BRAX requires that our env inherit from PipelineEnv and use State as its state type
from brax.envs.base import PipelineEnv, State

# mjcf interprets Mujoco XML files for BRAX, html will render rollouts of BRAX States
from brax.io import mjcf, html

import distreqx.distributions as dist
import dataclasses

## Training Code (Assignment Functions are here!)

In [None]:
def make_train(env, train_config, rng_init):
    # create a directory for saving the model and logs
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = f'data/ippo_ant/{current_datetime}'
    os.makedirs(save_path, exist_ok=False)
    writer = SummaryWriter(log_dir=save_path)
    with open(os.path.join(save_path, 'config.json'), 'w') as f:
        json.dump(train_config, f, indent=4)

    train_config["NUM_ACTORS"] = env.num_agents * train_config["NUM_ENVS"]
    train_config["NUM_UPDATES"] = (
        train_config["TOTAL_TIMESTEPS"] // train_config["NUM_STEPS"] // train_config["NUM_ENVS"]
    )
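    train_config["MINIBATCH_SIZE"] = (
        train_config["NUM_ACTORS"] * train_config["NUM_STEPS"] // train_config["NUM_MINIBATCHES"]
    )
    # The rest of make_train refers to the config as `config`; alias it so the
    # function does not silently depend on a global variable of that name.
    config = train_config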

    def linear_schedule(count):
        frac = 1.0 - (count // (train_config["NUM_MINIBATCHES"] * train_config["UPDATE_EPOCHS"])) / train_config["NUM_UPDATES"]
        return config["LR"] * frac

    network = ActorCritic(
        key=rng_init,
        actor_layer_sizes=[env.observation_space(env.agents[0]).shape[0], 64, 64, env.action_space(env.agents[0]).shape[0]],
        critic_layer_sizes=[env.observation_space(env.agents[0]).shape[0], 64, 64, 1],
        actor_kernel_init=[jnp.sqrt(2), jnp.sqrt(2), 0.01],
        critic_kernel_init=[jnp.sqrt(2), jnp.sqrt(2), 1],
        activation=jax.nn.tanh,
    )

    if config["ANNEAL_LR"]:
        opt = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        opt = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["LR"], eps=1e-5)
        )

    # only the array leaves of the module are trainable parameters
    opt_state = opt.init(eqx.filter(network, eqx.is_array))

    def train(rng):
        # INIT ENV
        rng, _rng = jr.split(rng)
        reset_rng = jr.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset)(reset_rng)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                network, opt_state, env_state, last_obs, update_count, rng = runner_state
                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
                # SELECT ACTION
                rng, _rng = jr.split(rng)
                mean, scale, value = eqx.filter_vmap(network)(obs_batch)

                pi = eqx.filter_vmap(dist.MultivariateNormalDiag)(mean, scale)
                pi_log_prob = lambda d, a: d.log_prob(a)  # helper for filter_vmap
                action = pi.sample(_rng)
                log_prob = eqx.filter_vmap(pi_log_prob)(pi, action)

                env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)

                # STEP ENV
                rng, _rng = jr.split(rng)
                rng_step = jr.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step)(
                    rng_step, env_state, env_act,
                )

                info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
                transition = Transition(
                    batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
                    action,
                    value,
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob,
                    obs_batch,
                    info,
                )
                runner_state = (network, opt_state, env_state, obsv, update_count, rng)
                return runner_state, transition

            runner_state, traj_batch = filter_scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )
            # CALCULATE ADVANTAGE
            network, opt_state, env_state, last_obs, update_count, rng = runner_state

            last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
            _, _, last_val = eqx.filter_vmap(network)(last_obs_batch)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:

                    gae, next_value = gae_and_next_value       
                    done, value, reward = transition.done, transition.value, transition.reward
                    gamma      = config["GAMMA"] * (1.0 - done)
                    gae_lambda = config["GAE_LAMBDA"]

                    delta = reward + gamma * next_value - value
                    gae = delta + gamma * gae_lambda * gae

                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=8,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(network, traj_batch, gae, targets):
                        # RERUN NETWORK
                        mean, scale, value = eqx.filter_vmap(network)(traj_batch.obs)
                        pi = eqx.filter_vmap(dist.MultivariateNormalDiag)(mean, scale)
                        pi_log_prob = lambda d, a: d.log_prob(a)  # helper for filter_vmap
                        log_prob = eqx.filter_vmap(pi_log_prob)(pi, traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)

                        # def _loss_actor(log_prob, log_prob_k, normalized_gae, clip_epsilon):
                        #   raise NotImplementedError('Actor Loss not implemented yet!')
                        #   # TODO: USE THE INPUTS TO THIS FUNCTION TO CALCULATE THE ACTOR LOSS

                        #   return loss_actor
                        def _loss_actor(log_prob, log_prob_k, normalized_gae, clip_epsilon):
                            # PPO clipped surrogate with r_t(θ) = π_θ(a|s) / π_θk(a|s)
                            ratio        = jnp.exp(log_prob - log_prob_k)
                            surrogate_1  = ratio * normalized_gae
                            surrogate_2  = jnp.clip(ratio, 1.0 - clip_epsilon,
                                                    1.0 + clip_epsilon) * normalized_gae
                            # element-wise pessimistic (lower) bound of the two surrogates
                            loss_clipped = jnp.minimum(surrogate_1, surrogate_2)
                            return loss_clipped, ratio

                        loss_actor, ratio = _loss_actor(log_prob, traj_batch.log_prob, gae, config["CLIP_EPS"])
                        # maximise the surrogate by minimising the negative of its mean
                        loss_actor = -loss_actor.mean()

                        pi_entropy = lambda d: d.entropy()
                        entropy = eqx.filter_vmap(pi_entropy)(pi).mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy, ratio)

                    network, opt_state = train_state
                    grad_fn = eqx.filter_value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(network, traj_batch, advantages, targets)
                    updates, opt_state = opt.update(grads, opt_state)
                    network = eqx.apply_updates(network, updates)

                    loss_info = {
                        "total_loss": total_loss[0],
                        "actor_loss": total_loss[1][1],
                        "critic_loss": total_loss[1][0],
                        "entropy": total_loss[1][2],
                        "ratio": total_loss[1][3],
                    }
                    return (network, opt_state), loss_info
                network, opt_state, traj_batch, advantages, targets, rng = update_state

                rng, _rng = jr.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ACTORS"]
                ), "batch size must be equal to number of steps * number of actors"
                permutation = jr.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree.map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree.map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree.map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                (network, opt_state), loss_info = filter_scan(
                    _update_minbatch, (network, opt_state), minibatches
                )
                update_state = (network, opt_state, traj_batch, advantages, targets, rng)
                return update_state, loss_info

            def callback(metric):
                step = metric["update_step"]
                for key, value in metric.items():
                    if key != "update_step":
                        writer.add_scalar(key, value, step)

            update_state = (network, opt_state, traj_batch, advantages, targets, rng)

            update_state, loss_info = filter_scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            network, opt_state = update_state[0], update_state[1]

            metric = traj_batch.info
            rng = update_state[-1]

            update_count = update_count + 1
            r0 = {"ratio0": loss_info["ratio"][0,0].mean()}
            loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
            metric = jax.tree.map(lambda x: x.mean(), metric)
            metric["update_step"] = update_count
            metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"]
            metric = {**metric, **loss_info, **r0}
            jax.experimental.io_callback(callback, None, metric)
            runner_state = (network, opt_state, env_state, last_obs, update_count, rng)
            return runner_state, metric

        rng, _rng = jr.split(rng)
        runner_state = (network, opt_state, env_state, obsv, jnp.array(0), _rng)
        runner_state, metric = filter_scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )

        # save and return the trained parameters (runner_state[0]), not the initial network
        network = runner_state[0]
        current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        eqx.tree_serialise_leaves(f"data/training_network_{current_datetime}.eqx", network)
        return {"runner_state": runner_state, "metrics": metric, "network": network}

    return train
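
`_calculate_gae` above runs Generalised Advantage Estimation with a reverse `jax.lax.scan`, computing $\delta_t = r_t + \gamma(1 - d_t)V_{t+1} - V_t$ and $A_t = \delta_t + \gamma\lambda(1 - d_t)A_{t+1}$. The NumPy loop below is a sketch of the same recursion for a `(T, N)` rollout (T steps, N actors); `compute_gae` is an illustrative name, and the default `gamma`/`gae_lambda` values are placeholders rather than the notebook's config.

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE over a (T, N) rollout, iterating backwards in time like the reverse scan."""
    T = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    gae = np.zeros_like(last_value)
    next_value = last_value                       # bootstrap from the value after the rollout
    for t in reversed(range(T)):
        discount = gamma * (1.0 - dones[t])       # stop bootstrapping across episode ends
        delta = rewards[t] + discount * next_value - values[t]
        gae = delta + discount * gae_lambda * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages, advantages + values        # (advantages, value targets)
```

With `gamma = gae_lambda = 1`, zero values and no episode ends, the advantages reduce to reward-to-go: three unit rewards give `[3, 2, 1]`.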

## Our Actor-Critic Network Class

In [4]:
class ActorCritic(eqx.Module):
    """
    A feed-forward actor-critic module. Grouping the actor and critic in a single
    eqx.Module is convenient for IPPO: both are updated together by one loss, so
    optax can keep a single optimisation state across the entire module.
    """

    # Learned variables
    actor_layers: Tuple[eqx.nn.Linear, ...]
    critic_layers: Tuple[eqx.nn.Linear, ...]
    log_std: jax.Array

    # Static parameters
    activation: Callable = eqx.field(static=True)

    def __init__(
        self,
        key: jr.PRNGKey,
        actor_layer_sizes: List[int] = [6, 64, 64, 2],
        critic_layer_sizes: List[int] = [6, 64, 64, 1],
        actor_kernel_init: List[float] = [jnp.sqrt(2), jnp.sqrt(2), 0.01],
        critic_kernel_init: List[float] = [jnp.sqrt(2), jnp.sqrt(2), 1.0],
        activation: Callable = jax.nn.relu,
    ):
        self.activation = activation
        actor_key, critic_key = jr.split(key)

        # —— actor network ——
        actor_keys = jr.split(actor_key, num=len(actor_layer_sizes))
        self.actor_layers = []

        for i, (in_f, out_f) in enumerate(
            zip(actor_layer_sizes[:-1], actor_layer_sizes[1:])
        ):
            layer = eqx.nn.Linear(in_f, out_f, key=actor_keys[i])

            # (Re-)initialise using orthogonal(scale) weights + constant(0) bias
            wkey, _ = jr.split(actor_keys[i])
            weight = jax.nn.initializers.orthogonal(actor_kernel_init[i])(
                wkey, (out_f, in_f), jnp.float32
            )
            bias = jnp.zeros((out_f,), dtype=jnp.float32)

            # eqx Modules are immutable; eqx.tree_at returns a copy of the layer
            # with the new weight and bias leaves swapped in
            layer = eqx.tree_at(lambda l: (l.weight, l.bias), layer, (weight, bias))

            self.actor_layers.append(layer)

        # —— critic network ——
        critic_keys = jr.split(critic_key, len(critic_layer_sizes) - 1)
        self.critic_layers = []

        for i, (in_f, out_f) in enumerate(
            zip(critic_layer_sizes[:-1], critic_layer_sizes[1:])
        ):
            layer = eqx.nn.Linear(in_f, out_f, key=critic_keys[i])

            wkey, _ = jr.split(critic_keys[i])
            weight = jax.nn.initializers.orthogonal(critic_kernel_init[i])(
                wkey, (out_f, in_f), jnp.float32
            )
            bias = jnp.zeros((out_f,), dtype=jnp.float32)

            layer = eqx.tree_at(lambda l: (l.weight, l.bias), layer, (weight, bias))

            self.critic_layers.append(layer)

        # —— learnable log‑std parameter ——
        self.log_std = jnp.zeros((actor_layer_sizes[-1],))  # broadcasted over batch at runtime

    def __call__(self, x: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Returns (actor_mean, actor_scale, value) for a single observation `x`.

        Batches are handled by the caller with eqx.filter_vmap.
        """

        # —— actor ——
        h = x
        for layer in self.actor_layers[:-1]:
            h = self.activation(layer(h))
        actor_mean = self.actor_layers[-1](h)                        # (action_dim,)
        actor_scale = jnp.exp(self.log_std)                          # (action_dim,)

        # —— critic ——
        h = x
        for layer in self.critic_layers[:-1]:
            h = self.activation(layer(h))
        value = jnp.squeeze(self.critic_layers[-1](h), axis=-1)      # scalar

        return actor_mean, actor_scale, value


def filter_scan(f, init, xs, length=None, reverse=False, unroll=1):
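    """
    Like jax.lax.scan, but the carry and inputs may contain non-array leaves
    (e.g. static fields of an eqx.Module); only the array leaves are scanned.
    Equinox does not include a filter_scan by default, although it is simple to
    implement; see: https://github.com/patrick-kidger/equinox/issues/709
    """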

    # Partition the initial carry and sequence inputs into dynamic and static parts
    init_dynamic, init_static = eqx.partition(init, eqx.is_array)
    xs_dynamic, xs_static = eqx.partition(xs, eqx.is_array)

    # Define the scanned function, handling the combination and partitioning
    def scanned_fn(carry_dynamic, x_dynamic):
        # Combine dynamic and static parts for the carry and input
        carry = eqx.combine(carry_dynamic, init_static)
        x = eqx.combine(x_dynamic, xs_static)

        # Apply the original function
        out_carry, out_y = f(carry, x)

        # Partition the outputs into dynamic and static parts
        out_carry_dynamic, out_carry_static = eqx.partition(out_carry, eqx.is_array)
        out_y_dynamic, out_y_static = eqx.partition(out_y, eqx.is_array)

        # Return dynamic outputs and wrap static outputs using Static to prevent tracing
        return out_carry_dynamic, (out_y_dynamic, eqx.internal.Static((out_carry_static, out_y_static)))

    # Use lax.scan with the modified scanned function
    final_carry_dynamic, (ys_dynamic, static_out) = jax.lax.scan(
        scanned_fn, init_dynamic, xs_dynamic, length=length, reverse=reverse, unroll=unroll
    )

    # Extract static outputs
    out_carry_static, ys_static = static_out.value

    # Combine dynamic and static parts of the outputs
    final_carry = eqx.combine(final_carry_dynamic, out_carry_static)
    ys = eqx.combine(ys_dynamic, ys_static)

    return final_carry, ys
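
`ActorCritic` returns a mean and a scale, which `make_train` wraps in `dist.MultivariateNormalDiag` before calling `log_prob` and `entropy`. For a diagonal Gaussian these reduce to sums of one-dimensional terms. The NumPy sketch below shows those standard formulas; it is not distreqx's implementation, and the helper names are illustrative.

```python
import numpy as np

def diag_gaussian_log_prob(action, mean, scale):
    """Log-density of a Gaussian with independent dimensions: sum of 1-D log-densities."""
    z = (action - mean) / scale
    return np.sum(-0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi), axis=-1)

def diag_gaussian_entropy(scale):
    """Entropy of a diagonal Gaussian: sum over dimensions of 0.5 * log(2*pi*e*sigma^2)."""
    return np.sum(0.5 * np.log(2.0 * np.pi * np.e) + np.log(scale), axis=-1)

mean = np.zeros(2)
scale = np.exp(np.zeros(2))      # log_std is initialised to zeros, so scale starts at 1
action = np.zeros(2)
lp = diag_gaussian_log_prob(action, mean, scale)   # -log(2*pi) at the mean of a 2-D standard normal
ent = diag_gaussian_entropy(scale)                 # log(2*pi*e) for a 2-D standard normal
```

Since `log_std` does not depend on the observation, the entropy bonus in the loss only changes through `log_std`, not through the actor MLP.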



## Helper Functions

In [5]:
"""
Built off Gymnax spaces.py, this module contains jittable classes for action and
observation spaces.
"""
class Space(object):
    """
    Minimal jittable class for abstract jaxmarl space.
    """

    def sample(self, rng: PRNGKeyArray) -> Array:
        raise NotImplementedError

    def contains(self, x: jnp.int_) -> bool: # pyright: ignore
        raise NotImplementedError

class Box(Space):
    """
    Minimal jittable class for array-shaped gymnax spaces.
    TODO: add unboundedness (sampling from other distributions, etc.).
    """
    def __init__(
        self,
        low: float,
        high: float,
        shape: Tuple[int],
        dtype: jnp.dtype = jnp.float32,
    ):
        self.low = low
        self.high = high
        self.shape = shape
        self.dtype = dtype

    def sample(self, rng: PRNGKeyArray) -> Array:
        """Sample random action uniformly from 1D continuous range."""
        return jax.random.uniform(
            rng, shape=self.shape, minval=self.low, maxval=self.high
        ).astype(self.dtype)

    def contains(self, x: jnp.int_) -> bool: # pyright: ignore
        """Check whether specific object is within space."""
        # type_cond = isinstance(x, self.dtype)
        # shape_cond = (x.shape == self.shape)
        range_cond = jnp.logical_and(
            jnp.all(x >= self.low), jnp.all(x <= self.high)
        )
        return range_cond

class Transition(NamedTuple):
  done: jnp.ndarray
  action: jnp.ndarray
  value: jnp.ndarray
  reward: jnp.ndarray
  log_prob: jnp.ndarray
  obs: jnp.ndarray
  info: jnp.ndarray

def batchify(x: dict, agent_list, num_actors):
    max_dim = max([x[a].shape[-1] for a in agent_list])
    def pad(z):
        return jnp.concatenate([z, jnp.zeros(z.shape[:-1] + (max_dim - z.shape[-1],))], -1)
    x = jnp.stack([x[a] if x[a].shape[-1] == max_dim else pad(x[a]) for a in agent_list])
    return x.reshape((num_actors, -1))

def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_agents):
    # inverse of batchify: (num_agents * num_envs, ...) -> {agent: (num_envs, ...)}
    x = x.reshape((num_agents, num_envs, -1))
    return {a: x[i] for i, a in enumerate(agent_list)}

class LogEnvState(NamedTuple):
    env_state: State
    episode_returns: float
    episode_lengths: int
    returned_episode_returns: float
    returned_episode_lengths: int
    step_in_episode: jnp.ndarray
    first_pipeline_state: State
    first_obs: jnp.ndarray
    truncation: jnp.ndarray
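
`batchify` and `unbatchify` fix the row ordering of the actor batch: agents are stacked first, so row `k` belongs to agent `k // num_envs` and environment `k % num_envs`. The NumPy sketch below re-implements that layout (without the padding branch, which only matters when agents have different observation sizes) using toy sizes chosen for illustration.

```python
import numpy as np

agents = ["agent_0", "agent_1", "agent_2", "agent_3"]
num_envs, obs_dim = 3, 5
num_actors = len(agents) * num_envs

# Per-agent observations of shape (num_envs, obs_dim); agent i's entries are all i.
obs = {a: np.full((num_envs, obs_dim), i, dtype=np.float32) for i, a in enumerate(agents)}

# batchify: stack agent-major to (num_agents, num_envs, obs_dim), then flatten to rows.
batch = np.stack([obs[a] for a in agents]).reshape((num_actors, -1))
row_agent = batch[:, 0].astype(int)   # which agent each row came from

# unbatchify: reshape to (num_agents, num_envs, -1) and split along the first axis.
restored = {a: x for a, x in zip(agents, batch.reshape((len(agents), num_envs, -1)))}
```

The `info` reshape in `_env_step` does not go through `batchify`, so it is flattened environment-major instead; that appears harmless there because `info` is only averaged for logging.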

## IPPO Training Environment

In [None]:
# LEAVE EVERYTHING IN THIS CELL UNCHANGED FOR ASSIGNMENT


# we must map each agent to their observations and actions in the whole environment
# This is a helper function to convert the ranges of observations
def listerize(ranges: List[Union[int, Tuple[int, int]]]) -> List[int]:
    """
    Expand a list of observation ranges into a flat list of global observation indices.
    A tuple (a, b) includes every global observation index from a to b inclusive;
    an integer adds that single global observation index to the agent's observation.
    Here is an example:

    ranges = {
        "agent_0": [(0, 5), 6, 7],   # == [0, 1, 2, 3, 4, 5, 6, 7]
        "agent_1": [(2, 5), 9, 10],  # == [2, 3, 4, 5, 9, 10]
    }
    agent_obs_mapping = {k: jnp.array(listerize(v)) for k, v in ranges.items()}
    """
    return [
        i
        for r in ranges
        for i in (range(r[0], r[1] + 1) if isinstance(r, tuple) else [r])
    ]

# ================================ #
# Our Multi-agent Ant Training Environment!
# ================================ #
class IPPO_Ant_Env(PipelineEnv):
    def __init__(
        self,
        mode: Literal["centralized", "decentralized"] = "decentralized",

        # base env settings
        xml_path='demo.xml',
        backend='positional', # this is important for the ant
        ctrl_cost_weight=0.5,
        use_contact_forces=False,
        contact_cost_weight=5e-4,
        healthy_reward=1.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(0.2, 1.0),
        contact_force_range=(-1.0, 1.0),
        reset_noise_scale=0.1,
        exclude_current_positions_from_observation=True,

        # multi agent env settings
        episode_length: int = 1000,
        action_repeat: int = 1,
        auto_reset: bool = True,
        homogenisation_method: Optional[Literal["max", "concat"]] = None,

        # logging settings
        replace_info: bool = False,
        **kwargs
    ):

        # ============================= #
        # Mujoco/BRAX System Init #
        # ============================= #

        # brax provides utils to load xml file into mujoco
        sys = mjcf.load(xml_path)

        # n_frames are the number of simulation timesteps between new actions
        # in the interim we zero-order-hold the previous action
        n_frames = 5

        # for the Ant the physics backend matters: for the 'spring' and 'positional' backends
        # the physics timestep is reduced from 0.01 to 0.005 and n_frames is doubled, keeping
        # the interaction timestep the same. I speculate that the finer timestep is required
        # for reliable, robust physics calculations with these backends
        if backend in ['spring', 'positional']:
            sys = sys.tree_replace({'opt.timestep': 0.005})
            n_frames = 10

        # again more examples of using the "sys" to modify fundamental parts of the simulator.
        if backend == 'mjx':
            sys = sys.tree_replace({
                'opt.solver': mj.mjtSolver.mjSOL_NEWTON, # this determines how we solve contact forces
                'opt.disableflags': mj.mjtDisableBit.mjDSBL_EULERDAMP, # I believe this damping term stabilizes some dynamics
                'opt.iterations': 1, # number of optimization iterations for contact
                'opt.ls_iterations': 4, # number of line search iterations per optimization iteration for contact
            })

        # again more examples of using the "sys" to modify fundamental parts of the simulator
        if backend == 'positional':
            # does the same actuator strength work as in spring
            sys = sys.replace(
                actuator=sys.actuator.replace(
                    gear=200 * jnp.ones_like(sys.actuator.gear)
                )
            )

        # the kwargs are designed to pass options to the brax backend, in this case we
        # modify the kwargs n_frames according to the above code before passing it to super
        kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

        # actually creating the BRAX system
        super().__init__(sys=sys, backend=backend, **kwargs)

        # various parameters for the Ant environment
        self._ctrl_cost_weight = ctrl_cost_weight
        self._use_contact_forces = use_contact_forces
        self._contact_cost_weight = contact_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._contact_force_range = contact_force_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )
        if self._use_contact_forces:
            raise NotImplementedError('use_contact_forces not implemented.')


        # ================================== #
        # Multi-Agent Environment Init #
        # ================================== #

        # UNCHANGED logging settings
        self.replace_info = replace_info
        self.episode_length = episode_length
        self.action_repeat = action_repeat
        self.auto_reset = auto_reset
        self.homogenisation_method = homogenisation_method

        # tuples (a, b) include every global observation index from a to b (inclusive) in that
        # agent's observation; integers add that single global observation index.
        # Each agent observes a subset of the global observation rather than all of it
        ranges = {
            # below is an example, also check out the listerize helper function above
            "agent_0": [(0, 5), 6, 7, 9, 11, (13, 18), 19, 20],
            "agent_1": [(0, 5), 7, 8, 9, 11, (13, 18), 21, 22],
            "agent_2": [(0, 5), 7, 9, 10, 11, (13, 18), 23, 24],
            "agent_3": [(0, 5), 7, 9, 11, 12, (13, 18), 25, 26],
        }
        self.agent_obs_mapping = {k: jnp.array(listerize(v)) for k, v in ranges.items()} # _agent_observation_mapping[env_name]

        # the agent action mapping is simpler, so we just use the indices of the actions
        self.agent_action_mapping = {
            "agent_0": jnp.array([0, 1]),
            "agent_1": jnp.array([2, 3]),
            "agent_2": jnp.array([4, 5]),
            "agent_3": jnp.array([6, 7]),
        }
        self.agents = list(self.agent_obs_mapping.keys())

        # ======================================================== #
        # UNCHANGED: Boilerplate That Doesn't Require Modification #
        # ======================================================== #

        # setup the obs and action spaces for each agent
        self.num_agents = len(self.agent_obs_mapping)
        obs_sizes = {
            agent: self.num_agents
            + max([o.size for o in self.agent_obs_mapping.values()])
            if homogenisation_method == "max"
            else self.env.observation_size
            if homogenisation_method == "concat"
            else obs.size
            for agent, obs in self.agent_obs_mapping.items()
        }
        act_sizes = {
            agent: max([a.size for a in self.agent_action_mapping.values()])
            if homogenisation_method == "max"
            else self.env.action_size
            if homogenisation_method == "concat"
            else act.size
            for agent, act in self.agent_action_mapping.items()
        }
        self.observation_spaces = {
            agent: Box(-jnp.inf, jnp.inf, shape=(obs_sizes[agent],),)
            for agent in self.agents
        }
        self.action_spaces = {
            agent: Box(-1.0, 1.0, shape=(act_sizes[agent],),)
            for agent in self.agents
        }

        # utility function to batchify floats originally placed in JaxMARLWrapper
        self._batchify_floats = lambda x: jnp.stack([x[a] for a in self.agents])

        # utility function to get obs, action spaces for each agent - required by the ppo algs
        self.observation_space = lambda agent: self.observation_spaces[agent]
        self.action_space = lambda agent: self.action_spaces[agent]

    # ======================================== #
    # design global observation function #
    # ======================================== #

    def get_global_obs(self, pipeline_state: State) -> jax.Array:

        # the pipeline_state is an attribute of env_state, which in turn is an attribute of state.
        # pipeline_state is what defines the physical simulation state, primarily through
        # pipeline_state.q and pipeline_state.qd which are the generalized positions and
        # velocities respectively. The rest of pipeline_state is read_only and useful
        # for creating expressive observations - as we can do in this function

        qpos = pipeline_state.q
        qvel = pipeline_state.qd

        if self._exclude_current_positions_from_observation:
            qpos = pipeline_state.q[2:]

        return jnp.concatenate([qpos, qvel])

    # ================================================= #
    # design pipeline_state random reset function #
    # ================================================= #

    def get_random_pipeline_state(self, rng):

        # here you must decide how to randomly instantiate the generalized positions and
        # velocities q, and qd of the pipeline_state and then create the pipeline_state
        # to be used in the reset function and the automatic reset functionality later on.
        # I will explain the automatic reset later! fear not!

        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        q = self.sys.init_q + jax.random.uniform(
            rng1, (self.sys.q_size(),), minval=low, maxval=hi
        )
        qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

        pipeline_state = self.pipeline_init(q, qd)
        return pipeline_state

    def reset(self, rng: jr.PRNGKey) -> Tuple[Dict[str, jax.Array], State]:
        # =========================== #
        # Reset the environment #
        # =========================== #

        # reset the global environment as you see fit for your environment. You
        # should end up with a new env_state: State and agent_obs: Dict[str, jax.Array]
        # as the below example does. NOTE you should also include the exact same info
        # dict in the resulting env_state as the example does.

        pipeline_state = self.get_random_pipeline_state(rng); rng, _ = jr.split(rng)
        global_obs = self.get_global_obs(pipeline_state)

        reward, done, zero = jnp.zeros(3)
        metrics = {
            'reward_forward': zero,
            'reward_survive': zero,
            'reward_ctrl': zero,
            'reward_contact': zero,
            'x_position': zero,
            'y_position': zero,
            'distance_from_origin': zero,
            'x_velocity': zero,
            'y_velocity': zero,
            'forward_reward': zero,
        }

        # NOTE it is essential that info is created like this and added to env_state
        info = {
            "returned_episode_returns": jnp.zeros(self.num_agents),
            "returned_episode_lengths": jnp.zeros(self.num_agents),
            "returned_episode": jnp.zeros(self.num_agents).astype(jnp.bool_)
        }

        env_state = State(pipeline_state, global_obs, reward, done, metrics, info)

        agent_obs = self.map_global_obs_to_agents(global_obs)

        # ============================= #
        # UNCHANGED: log state wrapping #
        # ============================= #

        # NOTE we change the "first_pipeline_state" and "first_obs" at every
        # usage, therefore we generate a new pair here to be used as the first
        # upon the next automatic reset in step - I will explain the automatic reset
        # later! fear not!

        new_first_pipeline_state = self.get_random_pipeline_state(rng)
        new_first_obs = self.get_global_obs(new_first_pipeline_state)

        # the struct we use to log the agent observations and the env state
        log_state = LogEnvState(
            env_state,
            jnp.zeros((self.num_agents,)),
            jnp.zeros((self.num_agents,)),
            jnp.zeros((self.num_agents,)),
            jnp.zeros((self.num_agents,)),
            jnp.zeros((), jnp.int32), # step_in_episode: env steps taken in the current episode
            new_first_pipeline_state,
            new_first_obs,
            jnp.array(0.)
        )

        return agent_obs, log_state

    def step(
        self,
        rng: jr.PRNGKey, # physics is deterministic; rng only seeds the next auto-reset start state
        state: State, # this is actually the LogEnvState returned by reset/step
        actions: Dict[str, jax.Array], # this is the agentic actions
    ) -> Tuple[
        Dict[str, jax.Array], State, Dict[str, float], Dict[str, bool], Dict
    ]:

        # Clear the done flag left over from the previous step: any environment that
        # finished on that step has already been auto-reset
        state = state._replace(
            env_state=state.env_state.replace(
                done=jnp.zeros_like(state.env_state.done)
            )
        )

        # =============================================================== #
        # global env_state step (reward, obs, state, metrics, done) #
        # =============================================================== #
        # here we calculate the global reward, obs, state, metrics, and global_done.
        # NOTE: global_done only covers early termination; the time-limit end of an
        # episode is handled automatically further down this method

        # we get the global actions
        global_action = self.map_agents_to_global_action(actions)

        # save the old pipeline_state to calculate some velocities
        pipeline_state0 = state.env_state.pipeline_state
        assert pipeline_state0 is not None

        # get the NEXT pipeline state yaaay
        next_pipeline_state = self.pipeline_step(state.env_state.pipeline_state, global_action)  # type: ignore

        # environment specific calculations
        velocity = (next_pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
        forward_reward = velocity[0]

        min_z, max_z = self._healthy_z_range

        is_healthy = jnp.where(next_pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
        is_healthy = jnp.where(next_pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)

        if self._terminate_when_unhealthy:
            healthy_reward = self._healthy_reward
        else:
            healthy_reward = self._healthy_reward * is_healthy
        ctrl_cost = self._ctrl_cost_weight * jnp.sum(jnp.square(global_action))
        contact_cost = 0.0

        # finalise the global_obs, global_reward, and global_done (just for early termination)
        global_obs = self.get_global_obs(next_pipeline_state)
        global_reward = forward_reward + healthy_reward - ctrl_cost - contact_cost
        global_done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

        # finally we update the environment specific metrics that you have chosen to track
        state.env_state.metrics.update(
            reward_forward=forward_reward,
            reward_survive=healthy_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            x_position=next_pipeline_state.x.pos[0, 0],
            y_position=next_pipeline_state.x.pos[0, 1],
            distance_from_origin=math.safe_norm(next_pipeline_state.x.pos[0]),
            x_velocity=velocity[0],
            y_velocity=velocity[1],
            forward_reward=forward_reward,
        )

        # agent obs is formed AFTER we decide if the env is reset or not as this will
        # change the state which we observe

        # ==================================================================== #
        # UNCHANGED: Episode Wrapping Step (brax.envs.wrappers.EpisodeWrapper) #
        # ==================================================================== #

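        step_in_episode = state.step_in_episode + self.action_repeat
        episode_over = step_in_episode >= self.episode_length

        # truncation (time limit reached without early termination) must be computed from
        # the early-termination global_done BEFORE global_done is forced to 1 below,
        # otherwise it is always 0; this is the same order brax's EpisodeWrapper uses
        state = state._replace(
            truncation=jnp.where(
                episode_over, 1 - global_done, jnp.zeros_like(state.env_state.done)
            )
        )
        global_done = jnp.where(episode_over, jnp.ones_like(state.env_state.done), global_done)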

        # ===================================== #
        # Agentic Rewards Dones Transform #
        # ===================================== #
        # default behaviour is just to give all agents same global reward
        reward = {agent: global_reward for agent in self.agents}
        reward["__all__"] = global_reward
        done = {agent: global_done.astype(jnp.bool_) for agent in self.agents}
        done["__all__"] = global_done.astype(jnp.bool_)

        # create the new env_state here in case global rewards or global dones rely on agentic things
        env_state = state.env_state.replace(
            pipeline_state=next_pipeline_state,
            obs=global_obs,
            reward=global_reward,
            done=global_done
        )

        # ============================================================================== #
        # UNCHANGED: Auto Reset Wrapping Post-Step (brax.envs.wrappers.AutoResetWrapper) #
        # ============================================================================== #
        def where_done(x, y):
            done = env_state.done
            if done.shape:
                done = jnp.reshape(done, [x.shape[0]] + [1] * (len(x.shape) - 1))  # type: ignore
            return jnp.where(done, x, y)


        next_pipeline_state = jax.tree.map(
            where_done, state.first_pipeline_state, env_state.pipeline_state
        )

        global_obs = jax.tree.map(where_done, state.first_obs, global_obs)
        env_state = env_state.replace(pipeline_state=next_pipeline_state, obs=global_obs)
        agent_obs = self.map_global_obs_to_agents(global_obs)

        first_pipeline_state = self.get_random_pipeline_state(rng); rng, _ = jr.split(rng)
        first_obs = self.get_global_obs(first_pipeline_state)
        state = state._replace(
            first_pipeline_state=first_pipeline_state,
            first_obs=first_obs
        )

        if self.auto_reset is True:
            step_in_episode = jnp.where(env_state.done, jnp.zeros_like(step_in_episode), step_in_episode)
            state = state._replace(step_in_episode=step_in_episode)

        # ======================= #
        # UNCHANGED: Log wrapping #
        # ======================= #
        ep_done = done["__all__"]
        new_episode_return = state.episode_returns + self._batchify_floats(reward)
        new_episode_length = state.episode_lengths + 1
        state = state._replace(
            env_state=env_state,
            episode_returns=new_episode_return * (1 - ep_done),
            episode_lengths=new_episode_length * (1 - ep_done),
            returned_episode_returns=state.returned_episode_returns * (1 - ep_done)
            + new_episode_return * ep_done,
            returned_episode_lengths=state.returned_episode_lengths * (1 - ep_done)
            + new_episode_length * ep_done,
            step_in_episode=step_in_episode
        )

        info = env_state.info
        if self.replace_info:
            info = {}
        info["returned_episode_returns"] = state.returned_episode_returns
        info["returned_episode_lengths"] = state.returned_episode_lengths
        info["returned_episode"] = jnp.full((self.num_agents,), ep_done)

        return agent_obs, state, reward, done, info

    # ================================================================ #
    # UNCHANGED: mapping agent actions and obs to and from global actions and obs #
    # ================================================================ #
    def map_agents_to_global_action(
        self, agent_actions: Dict[str, jnp.ndarray]
    ) -> jnp.ndarray:
        global_action = jnp.zeros(self.action_size)
        for agent_name, action_indices in self.agent_action_mapping.items():
            if self.homogenisation_method == "max":
                global_action = global_action.at[action_indices].set(
                    agent_actions[agent_name][: action_indices.size]
                )
            elif self.homogenisation_method == "concat":
                global_action = global_action.at[action_indices].set(
                    agent_actions[agent_name][action_indices]
                )
            else:
                global_action = global_action.at[action_indices].set(
                    agent_actions[agent_name]
                )
        return global_action

    def map_global_obs_to_agents(self, global_obs: jax.Array) -> Dict[str, jax.Array]:
        """Maps the global observation vector to the individual agent observations.
        Args:
            global_obs: The global observation vector.
        Returns:
            A dictionary mapping agent names to their observations. The mapping method
            is determined by the homogenisation_method parameter.
        """
        agent_obs = {}
        for agent_idx, (agent_name, obs_indices) in enumerate(
            self.agent_obs_mapping.items()
        ):
            if self.homogenisation_method == "max":
                # Vector with the agent idx one-hot encoded as the first num_agents
                # elements and then the agent's own observations (zero padded to
                # the size of the largest agent observation vector)
                agent_obs[agent_name] = (
                    jnp.zeros(
                        self.num_agents
                        + max([v.size for v in self.agent_obs_mapping.values()])
                    )
                    .at[agent_idx]
                    .set(1)
                    .at[self.num_agents : self.num_agents + obs_indices.size]
                    .set(global_obs[obs_indices])
                )
            elif self.homogenisation_method == "concat":
                # Zero vector except for the agent's own observations
                # (size of the global observation vector)
                agent_obs[agent_name] = (
                    jnp.zeros(global_obs.shape)
                    .at[obs_indices]
                    .set(global_obs[obs_indices])
                )
            else:
                # Just agent's own observations
                agent_obs[agent_name] = global_obs[obs_indices]
        return agent_obs

def IPPO_generate_rollout(env, rng, jit=True, num_timesteps=100):
    agent_obs, state = env.reset(rng=rng); rng, _rng = jr.split(rng)
    rollout = []
    if jit is True:
        env_step = jax.jit(env.step)
    else:
        env_step = env.step
    ctrl = jr.uniform(_rng, shape=(env.action_size,), minval=-1, maxval=1)  # shape must be a tuple

    agent_ctrls = {"agent_0": ctrl[:2], "agent_1": ctrl[2:4], "agent_2": ctrl[4:6], "agent_3": ctrl[6:]}
    for i in range(num_timesteps):
        print(f"ctrl action chosen: {ctrl}")
        rng, step_rng = jr.split(rng)  # fresh key per step so each auto-reset samples a new start state
        agent_obs, state, reward, done, info = env_step(step_rng, state, agent_ctrls)
        print(f"state.done: {done}")
        print(f"state.reward: {reward}")
        rollout.append(state.env_state.pipeline_state)
        print(f"step: {i}")
        print(f"info: {info}")
        if i == env.episode_length:
            print("THE PRIOR STATE.DONE SHOULD HAVE BEEN TRUE")
    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_name = f'data/rollouts/{current_datetime}.html'
    os.makedirs("data/rollouts", exist_ok=True)
    with open(save_name, 'w') as f:
        f.write(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))

def generate_rollout(env, policy, rng, jit=True, num_timesteps=100):
    agent_obs, state = env.reset(rng=rng)
    rng, _rng = jr.split(rng)
    rollout = []

    env_step = jax.jit(env.step) if jit else env.step

    def get_actions(agent_obs_dict):
        actions = []
        for agent_id in env.agents:
            obs = agent_obs_dict[agent_id].reshape(1, -1)  # (1, obs_dim)
            mean, scale, value = eqx.filter_vmap(policy)(obs)
            actions.append(mean[0])
        return jnp.concatenate(actions)

    ctrl = get_actions(agent_obs)
    agent_ctrls = {
        "agent_0": ctrl[:2],
        "agent_1": ctrl[2:4],
        "agent_2": ctrl[4:6],
        "agent_3": ctrl[6:]
    }

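    for i in range(num_timesteps):
        rng, step_rng = jr.split(rng)  # fresh key per step so each auto-reset samples a new start state
        agent_obs, state, reward, done, info = env_step(step_rng, state, agent_ctrls)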
        rollout.append(state.env_state.pipeline_state)
        ctrl = get_actions(agent_obs)
        agent_ctrls = {
            "agent_0": ctrl[:2],
            "agent_1": ctrl[2:4],
            "agent_2": ctrl[4:6],
            "agent_3": ctrl[6:]
        }

    current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_name = f'data/rollouts/{current_datetime}.html'
    os.makedirs("data/rollouts", exist_ok=True)
    with open(save_name, 'w') as f:
        f.write(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))
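
The comment in the `"max"` branch of `map_global_obs_to_agents` describes the layout each agent sees: a one-hot agent id in the first `num_agents` entries, then the agent's own observations, zero-padded to the largest per-agent observation size. The sketch below reproduces that layout on a made-up two-agent mapping; `max_homogenised_obs` and the index arrays are hypothetical and not part of `IPPO_Ant_Env`.

```python
import jax.numpy as jnp


def max_homogenised_obs(global_obs, obs_mapping):
    """Hypothetical helper mirroring the layout described for the "max" method.

    Each agent gets a one-hot id in the first num_agents slots, followed by its own
    observations, zero-padded to the largest per-agent observation size.
    """
    num_agents = len(obs_mapping)
    max_obs = max(idx.size for idx in obs_mapping.values())
    per_agent = {}
    for agent_idx, (name, idx) in enumerate(obs_mapping.items()):
        vec = jnp.zeros(num_agents + max_obs)
        vec = vec.at[agent_idx].set(1.0)                                     # agent id
        vec = vec.at[num_agents:num_agents + idx.size].set(global_obs[idx])  # own obs
        per_agent[name] = vec
    return per_agent


# made-up mapping: agent_0 observes global entries 0-1, agent_1 observes entries 2-4
obs_mapping = {"agent_0": jnp.array([0, 1]), "agent_1": jnp.array([2, 3, 4])}
per_agent = max_homogenised_obs(jnp.arange(5.0), obs_mapping)
print(per_agent["agent_0"])  # [1. 0. 0. 1. 0.]
print(per_agent["agent_1"])  # [0. 1. 2. 3. 4.]
```

Because each agent's own block always starts at offset `num_agents`, the one-hot id prefix and the observation block never overlap.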

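The reward and termination logic in `step()` can be read in isolation. The sketch below is a hypothetical standalone function operating on the torso position before and after one physics step. Its default weights (`healthy_z_range=(0.2, 1.0)`, `healthy_reward=1.0`, `ctrl_cost_weight=0.5`) are assumptions borrowed from brax's Ant; the values `IPPO_Ant_Env` actually uses are set in its constructor, which is not shown here, and `dt=0.05` is only illustrative.

```python
import jax.numpy as jnp


def ant_reward(pos_before, pos_after, action, dt,
               healthy_z_range=(0.2, 1.0), healthy_reward=1.0,
               ctrl_cost_weight=0.5, terminate_when_unhealthy=True):
    """Hypothetical standalone version of the reward/termination logic in step().

    pos_before / pos_after are the torso (body 0) positions (x, y, z).
    Default weights are assumed, not read from IPPO_Ant_Env.
    """
    velocity = (pos_after - pos_before) / dt       # finite-difference torso velocity
    forward_reward = velocity[0]                    # reward progress along +x
    min_z, max_z = healthy_z_range
    z = pos_after[2]
    is_healthy = jnp.where((z < min_z) | (z > max_z), 0.0, 1.0)
    # as in step(): when unhealthy states terminate, the survive bonus is paid unconditionally
    if terminate_when_unhealthy:
        survive = healthy_reward
    else:
        survive = healthy_reward * is_healthy
    ctrl_cost = ctrl_cost_weight * jnp.sum(jnp.square(action))
    reward = forward_reward + survive - ctrl_cost
    done = 1.0 - is_healthy if terminate_when_unhealthy else 0.0
    return reward, done


pos_before = jnp.array([0.0, 0.0, 0.5])
pos_after = jnp.array([0.05, 0.0, 0.5])  # moved 0.05 along +x, torso still at a healthy height
reward, done = ant_reward(pos_before, pos_after, jnp.zeros(8), dt=0.05)
print(float(reward), float(done))
```

Moving the torso 0.05 along x in one 0.05 s step gives a forward reward of 1.0; adding the survive bonus of 1.0 with zero control effort gives a total reward of 2.0 and `done = 0`.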

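The two "UNCHANGED" wrapper blocks in `step()` implement end-of-episode handling in the style of brax's `EpisodeWrapper` and `AutoResetWrapper`. A minimal sketch of that bookkeeping on a toy pytree follows; `end_of_step` is a hypothetical helper, and `state` / `first_state` stand in for the pipeline state and the pre-sampled `first_pipeline_state`. Note the ordering: truncation (the time limit was reached without an early termination) is computed from the early-termination flag before that flag is forced to 1.

```python
import jax
import jax.numpy as jnp


def end_of_step(step_in_episode, early_done, episode_length, state, first_state,
                action_repeat=1):
    """Hypothetical helper: episode time limit plus auto-reset for one environment.

    early_done is the early-termination flag (1.0 if the robot fell over);
    state / first_state are pytrees with the same structure.
    """
    steps = step_in_episode + action_repeat
    over = steps >= episode_length
    # truncation: time limit hit without early termination; read early_done BEFORE overwriting it
    truncation = jnp.where(over, 1.0 - early_done, 0.0)
    done = jnp.where(over, 1.0, early_done)
    # auto-reset: wherever the episode ended, swap in the pre-sampled first state
    next_state = jax.tree.map(lambda first, cur: jnp.where(done, first, cur),
                              first_state, state)
    steps = jnp.where(done, 0, steps)  # restart the step counter for the new episode
    return next_state, done, truncation, steps


cur = {"q": jnp.array([1.0, 2.0])}
first = {"q": jnp.array([0.0, 0.0])}
next_state, done, truncation, steps = end_of_step(999, 0.0, 1000, cur, first)  # hits the time limit
print(next_state, done, truncation, steps)
```

With `early_done=1.0` at step 10 instead, `done` is 1 but `truncation` stays 0, which is what lets a learner distinguish a fall from a time-out.
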
## Perform IPPO Training

Here we can begin training our IPPO policies! The first time you run this cell, set `debug_mode` to `True` to check that the environment compiles and that a single test step and a test rollout run without errors. **Note: This will print out a lot of information**

If there are no errors, turn `debug_mode` to `False` to train your IPPO policies!
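
Debug mode switches on `jax_debug_nans`, which makes JAX raise a `FloatingPointError` as soon as a computation produces a NaN instead of letting the NaN propagate silently into the losses. A tiny, self-contained check of that behaviour follows; the helper name is hypothetical, and it switches the flag back off afterwards.

```python
import jax
import jax.numpy as jnp


def raises_on_nan():
    """Hypothetical check: does JAX raise when an op produces a NaN under jax_debug_nans?"""
    jax.config.update("jax_debug_nans", True)
    try:
        jnp.log(-1.0)  # log of a negative number is NaN
        return False
    except FloatingPointError:
        return True
    finally:
        jax.config.update("jax_debug_nans", False)  # switch the flag back off


print(raises_on_nan())
```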

# ***REMEMBER TO DOWNLOAD YOUR `data/` FOLDER AFTER TRAINING!***

In [7]:
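# NOTE: XLA_PYTHON_CLIENT_PREALLOCATE and CUDA_VISIBLE_DEVICES are only read when JAX
# initialises its GPU backend (the first JAX computation in the kernel). If earlier cells
# have already run JAX code, move these lines to the top of the notebook for them to take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # allocate GPU memory on demand (like PyTorch) instead of preallocating most of it
os.environ['CUDA_VISIBLE_DEVICES'] = '1'              # expose only physical GPU 1, which JAX then enumerates as device 0
os.environ["MUJOCO_GL"] = "egl"                       # if you have NVIDIA + EGL drivers for rendering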

jax.config.update('jax_default_matmul_precision', "highest") # sometimes certain contact dynamics need higher accuracy to prevent nans

# persist compiled programs on disk so that re-running the notebook spends less time on JIT compilation
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)  # cache entries of any size
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)  # cache even fast-to-compile programs
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")  # also persist XLA's GPU autotuning cache

config = {
    "LR": 1e-3,
    "NUM_ENVS": 64,
    "NUM_STEPS": 300,
    "TOTAL_TIMESTEPS": 1e7,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 2e-6,
    "VF_COEF": 4.5,
    "MAX_GRAD_NORM": 0.5,
    "SEED": 0,
    "ANNEAL_LR": True,
    "DEVICE": 0,
    "DISABLE_JIT": False,
}

rng = jr.PRNGKey(config["SEED"])
rng, _rng = jr.split(rng)
env = IPPO_Ant_Env()
agent_obs, state = env.reset(jr.PRNGKey(0))

debug_mode = False  # set to True to debug the environment, False to train
if debug_mode:
    jax.config.update("jax_debug_nans", True)   # will error if a nan is detected
    jax.config.update("jax_log_compiles", True) # will print out recompilations

    # a single test step
    action = {
        "agent_0": jnp.array([0.5, 0.5]),
        "agent_1": jnp.array([-0.5, -0.5]),
        "agent_2": jnp.array([-0.5, -0.5]),
        "agent_3": jnp.array([-0.5, -0.5]),
    }
    agent_obs, state, reward, done, info = env.step(jr.PRNGKey(0), state, action)

    # a test rollout
    IPPO_generate_rollout(env, jr.PRNGKey(0), jit=True, num_timesteps=700)

else:
    jax.config.update("jax_debug_nans", False)
    jax.config.update("jax_log_compiles", False)

    # Start Training!
    print("INFO: training started!")
    config["ENV_NAME"] = env.__class__.__name__
    train = make_train(env, config, _rng)
    train_jit = jax.jit(train, device=jax.devices()[config["DEVICE"]])
    out = train_jit(rng)
    print(out)

    print("INFO: training complete")



INFO: training started!
{'metrics': {'actor_loss': Array([-1.12140214e-03, -7.33750989e-04, -5.76914113e-04, -2.98127910e-04,
       -6.34284923e-04, -1.30952033e-03, -1.24479982e-03, -8.60543747e-04,
       -3.63127328e-04, -1.45370024e-04,  2.73191574e-04, -7.59777759e-05,
       -2.81912275e-04, -5.58715139e-04, -4.19545482e-04, -4.70658095e-04,
       -3.44070868e-04, -3.09201860e-04, -3.08899267e-04, -1.52517678e-04,
       -1.86100064e-04, -3.15444137e-04, -2.27219862e-04, -2.43391813e-04,
       -1.70153638e-04, -1.41900207e-04, -2.75313476e-04, -1.84186720e-04,
       -1.11115325e-04, -2.11455801e-04, -1.10984598e-04, -1.41910074e-04,
       -1.30015906e-04, -1.83560376e-04, -1.86068166e-04, -1.61261181e-04,
       -9.95705050e-05, -1.31487934e-04, -6.11129217e-05, -1.15360548e-04,
       -1.23354839e-04, -1.18051088e-04, -6.57164928e-05, -1.04671606e-04,
       -7.27884908e-05, -9.73453862e-05, -9.81140038e-05, -1.12852831e-04,
       -8.09247285e-05, -9.68022941e-05, -6.94216
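
To get a feel for what the training configuration means, the arithmetic below derives the per-update batch, the number of updates, and a linearly annealed learning rate. This is a sketch under an assumption: `make_train` is defined earlier in the notebook and not shown here, so the formulas follow the usual PureJaxRL / JaxMARL IPPO conventions rather than this notebook's exact code.

```python
# relevant entries copied from the config dict (named cfg so the real `config` is not overwritten)
cfg = {"LR": 1e-3, "NUM_ENVS": 64, "NUM_STEPS": 300, "TOTAL_TIMESTEPS": 1e7,
       "UPDATE_EPOCHS": 4, "NUM_MINIBATCHES": 4}
num_agents = 4  # one agent per leg

# environment steps collected before each PPO update
steps_per_update = cfg["NUM_ENVS"] * cfg["NUM_STEPS"]
num_updates = int(cfg["TOTAL_TIMESTEPS"] // steps_per_update)

# assumption: JaxMARL-style IPPO treats every (agent, env) pair as one "actor"
num_actors = num_agents * cfg["NUM_ENVS"]
minibatch_size = num_actors * cfg["NUM_STEPS"] // cfg["NUM_MINIBATCHES"]


def linear_lr(count):
    """Assumed ANNEAL_LR schedule: decay LR linearly to 0 over num_updates updates.

    `count` is the number of optimizer steps so far
    (NUM_MINIBATCHES * UPDATE_EPOCHS optimizer steps per update).
    """
    frac = 1.0 - (count // (cfg["NUM_MINIBATCHES"] * cfg["UPDATE_EPOCHS"])) / num_updates
    return cfg["LR"] * frac


print(steps_per_update, num_updates, minibatch_size)  # 19200 520 19200
print(linear_lr(0), linear_lr(260 * 16))              # 0.001 0.0005
```

Under these assumptions the run performs 520 PPO updates of 19,200 environment steps each, about 9.98 M steps in total (slightly under `TOTAL_TIMESTEPS` because of the floor division), and the learning rate reaches half its initial value halfway through training.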

In [13]:
!tensorboard --logdir data/ippo_ant/ --port 6006

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
E0801 13:12:19.863596 139850669860416 _internal.py:97] Error on request:
Traceback (most recent call last):
  File "/home/eason/.local/lib/python3.10/site-packages/werkzeug/serving.py", line 370, in run_wsgi
    execute(self.server.app)
  File "/home/eason/.local/lib/python3.10/site-packages/werkzeug/serving.py", line 331, in execute
    application_iter = app(environ, start_response)
  File "/home/eason/.local/lib/python3.10/site-packages/tensorboard/backend/application.py", line 528, in __call__
    return self._app(environ, start_response)
  File "/home/eason/.local/l

In [36]:
policy = out['network']
generate_rollout(env, policy, rng, jit=True, num_timesteps=520)
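
`generate_rollout` and `IPPO_generate_rollout` split the concatenated action vector with hard-coded slices (`ctrl[:2]`, `ctrl[2:4]`, ...). The same split, and its inverse as performed by the default branch of `map_agents_to_global_action`, can be driven by an index mapping instead. The sketch assumes each agent owns two consecutive joints, matching those slices; the mapping `IPPO_Ant_Env` really uses lives in `env.agent_action_mapping`, and the helper names and `AGENT_ACTION_INDICES` here are hypothetical.

```python
import jax.numpy as jnp

# hypothetical leg layout: each agent owns two consecutive joints (as in generate_rollout's slicing)
AGENT_ACTION_INDICES = {
    "agent_0": jnp.array([0, 1]),
    "agent_1": jnp.array([2, 3]),
    "agent_2": jnp.array([4, 5]),
    "agent_3": jnp.array([6, 7]),
}


def split_actions(global_action, mapping=AGENT_ACTION_INDICES):
    """Global (8,) action vector -> {agent_name: (2,) action}."""
    return {name: global_action[idx] for name, idx in mapping.items()}


def merge_actions(agent_actions, action_size=8, mapping=AGENT_ACTION_INDICES):
    """{agent_name: action} -> global action, scattering each agent's entries into place."""
    global_action = jnp.zeros(action_size)
    for name, idx in mapping.items():
        global_action = global_action.at[idx].set(agent_actions[name])
    return global_action


ctrl = jnp.arange(8.0)
agent_ctrls = split_actions(ctrl)
print(agent_ctrls["agent_1"])      # [2. 3.]
print(merge_actions(agent_ctrls))  # [0. 1. 2. 3. 4. 5. 6. 7.]
```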