# Visualizing Experiments in Brax

This notebook is adapted from https://github.com/google/brax/blob/main/notebooks/training.ipynb 

In [2]:
from IPython.display import HTML, clear_output
from IPython.display import Image, Video

import numpy as np
import jax
import jax.numpy as jnp
from jax import config
import matplotlib.pyplot as plt
import wandb
import brax
from brax import envs
from brax import jumpy as jp
from brax.io import html
from brax.io import model
from brax.training.acme import running_statistics

from private_envs import multipusher, turningant, original_multipusher
from ppo import networks as ppo_networks

import brax_utils
import predictor

In [None]:
# Initialise Weights & Biases in "disabled" mode.
# NOTE(review): presumably this lets project modules that log to wandb run
# without a live run or network access -- confirm against predictor/brax_utils.
wandb.init(mode="disabled")
# Episode length in environment steps, shared by the cells below.
# We use T=100 for all our experiments
T = 100

First, let's pick the environment whose trained agent we want to visualize:

In [None]:
# Select the environment: "pusher" builds the MultiPusher task with
# freeze_balls=True; any other value falls back to TurningAnt.
env_name = "pusher"
env = (
    original_multipusher.MultiPusher(freeze_balls=True)
    if env_name == "pusher"
    else turningant.TurningAnt()
)

# Print the action size, then the observation size, one per line.
print(env.action_size, env.observation_size, sep="\n")

# Initial state from a fixed seed; used by the optional render below.
state = env.reset(rng=jp.random_prngkey(seed=3))

# Uncomment to see the environment
# HTML(html.render(env.sys, [state.qp]))

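In the original Brax training notebook, the trainers return an inference function, parameters, and the final set of metrics gathered during evaluation. This notebook does not train; instead, it rebuilds the inference function and loads previously saved parameters.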

# Saving and Loading Policies

Brax can save and load trained policies:

In [None]:
# Build the PPO networks and the policy factory used to turn loaded
# parameters into an inference function.
network_factory = ppo_networks.make_ppo_networks
normalize = running_statistics.normalize

# Hidden-layer sizes passed to the network factory. The original cell first
# built a network without these sizes and immediately overwrote it; that
# unused build has been dropped so the network is constructed exactly once.
flag = False  # set to True if using "constrained_ant"
if env_name == "ant" and flag:
    layers = {"policy_hidden_layer_sizes": (64,) * 3, "value_hidden_layer_sizes": (64,) * 3}
else:
    layers = {"policy_hidden_layer_sizes": (32,) * 4, "value_hidden_layer_sizes": (256,) * 5}

ppo_network = network_factory(
    env.observation_size, env.action_size, preprocess_observations_fn=normalize, **layers
)

# make_policy(params) returns the inference function; make_inference_fn is
# kept as an alias because later cells refer to it by that name.
make_policy = ppo_networks.make_inference_fn(ppo_network)
make_inference_fn = make_policy


Load the desired parameters:

In [None]:
# Path of the saved policy parameters; also passed to the MI computation
# further down.
params_path = 'params/constrained_pusher'
# Load the saved parameters from disk.
params = model.load_params(params_path)
# Bind the parameters to the policy factory to obtain the inference function
# used for rollouts.
inference_fn = make_inference_fn(params)

# Visualizing a Policy's Behavior

We can use the policy to generate a rollout for visualization:

In [None]:
#@title Visualizing a trajectory of the learned inference function

# JIT-compile the environment's reset/step functions and the policy.
jit_env_reset, jit_env_step, jit_inference_fn = (
    jax.jit(fn) for fn in (env.reset, env.step, inference_fn)
)

# Fixed seed so the rendered trajectory is repeatable.
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
rollout = []

# Record T states: store the current state, then split off a fresh key,
# query the policy with it, and advance the environment by one step.
for _ in range(T):
    rollout.append(state)
    act_rng, rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, act)

# The blue ball position is the hidden state, i.e., down is u=0 and up is u=1
HTML(html.render(env.sys, [visited.qp for visited in rollout]))

## Compute the full-trajectory MI of a model

In [None]:
# Configure the predictor module for this environment by overwriting its
# module-level settings: sequence length = rollout length T, vocabulary
# size = the environment's action size, and the environment itself.
predictor.SEQUENCE_LENGTH = T
predictor.VOCAB_SIZE = env.action_size
predictor.env = env

# Transformer arguments. The two values set above are read back from the
# predictor module; the remaining values (OUTPUT_VOCAB_SIZE, EMB_SIZE,
# NUM_HEADS, NUM_LAYERS, w_init) are taken from the module as-is.
transformer_config = {
    "vocab_size": predictor.VOCAB_SIZE,
    "output_vocab_size": predictor.OUTPUT_VOCAB_SIZE,
    "emb_dim": predictor.EMB_SIZE,
    "num_heads": predictor.NUM_HEADS,
    "qkv_dim": predictor.EMB_SIZE,
    "mlp_dim": predictor.EMB_SIZE,
    "num_layers": predictor.NUM_LAYERS,
    "max_len": predictor.SEQUENCE_LENGTH,
    "kernel_init": predictor.w_init,
    "logits_via_embedding": False,
}

# NOTE(review): `layers={}` differs from the `layers` dict used when building
# the policy network earlier in this notebook -- confirm that
# full_trajectory_MI does not need the same hidden-layer sizes to load the
# parameters stored at params_path.
predictor.full_trajectory_MI(params_path, seed=0, model_config=transformer_config, layers={})

## Save videos and frames of a model's rollouts

In [None]:
# Render rollouts of the loaded policy to video (and optionally still frames).
# - The rollout length reuses the notebook-wide T instead of repeating 100.
# - Output names are derived from env_name, so files are labelled with the
#   environment actually being rendered (the second call previously wrote
#   "ant_..." names even when env_name was "pusher").
video_prefix = f'{env_name}_constrained'

# First render: seed 4, 800x600, additionally saving frame 57 as an image.
brax_utils.make_video(params, make_inference_fn, env, T=T, flip_camera=env_name != 'ant', curr_seed=4, n_seeds=1,
                      width=800, height=600, save_frames=[57],
                      frame_name=f'{video_prefix}_u=0', video_name=f'{video_prefix}_u=0.mp4'
)

# Second render: seed 1, 1600x1200 with hfov=50, no still frames saved.
brax_utils.make_video(params, make_inference_fn, env, T=T, flip_camera=env_name != 'ant', curr_seed=1, n_seeds=1,
                      hfov=50, width=1600, height=1200, save_frames=[],
                      frame_name=f'{video_prefix}_u=1', video_name=f'{video_prefix}_u=1.mp4'
)

In [None]:
# Display the video file named by the first make_video call above
# (assumes make_video writes its output under videos/ -- confirm in brax_utils).
Video("videos/pusher_constrained_u=0.mp4")

In [None]:
# Display the still of frame 57 requested via save_frames=[57] above
# (assumes make_video writes frames under frames/ as <frame_name>_frame_<i>.png -- confirm in brax_utils).
Image('frames/pusher_constrained_u=0_frame_57.png')