# Training in Brax

Once an environment is created in brax, we can quickly train it using brax's built-in training algorithms. Let's try it out!

In [1]:
import functools
import jax
import os

from datetime import datetime
# from jax import numpy as jp
import matplotlib.pyplot as plt
import brax
from IPython.display import HTML, clear_output
# try:
#   import brax
# except ImportError:
#   !pip install git+https://github.com/google/brax.git@main
#   clear_output()
#   import brax

import flax
from brax.io import model
from brax.io import json
from brax.io import html
from learning.agents.apg import train as apg
# from brax.training.agents.sac import train as sac
from mujoco_playground import registry
# if 'COLAB_TPU_ADDR' in os.environ:
#   from jax.tools import colab_tpu
#   colab_tpu.setup_tpu()

First let's pick an environment and a backend to train an agent in. 

Recall from the [Brax Basics](https://github.com/google/brax/blob/main/notebooks/basics.ipynb) colab, that the backend specifies which physics engine to use, each with different trade-offs between physical realism and training throughput/speed. The engines generally decrease in physical realism but increase in speed in the following order: `generalized`,  `positional`, then `spring`.


In [None]:
#@title Load Env { run: "auto" }

env_name = 'HumanoidStand'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
env = registry.load(env_name)

state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

# HTML(html.render(env.mjx_model, [state.pipeline_state]))

In [10]:
state.data._impl.contact.elasticity

AttributeError: 'Contact' object has no attribute 'elasticity'

# Training

Brax provides out of the box the following training algorithms:

* [Proximal policy optimization](https://github.com/google/brax/blob/main/brax/training/agents/apg/train.py)
* [Soft actor-critic](https://github.com/google/brax/blob/main/brax/training/agents/sac/train.py)
* [Evolutionary strategy](https://github.com/google/brax/blob/main/brax/training/agents/es/train.py)
* [Analytic policy gradients](https://github.com/google/brax/blob/main/brax/training/agents/apg/train.py)
* [Augmented random search](https://github.com/google/brax/blob/main/brax/training/agents/ars/train.py)

Trainers take as input an environment function and some hyperparameters, and return an inference function to operate the environment.

# Training

Let's train the Ant policy using the `generalized` backend with apg.

In [None]:
#@title Training

# We determined some reasonable hyperparameters offline and share them here.
train_fn = {
  'inverted_pendulum': functools.partial(apg.train, num_evals=20, policy_updates=1e6, episode_length=1000, normalize_observations=True, action_repeat=1, learning_rate=3e-4,  num_envs=2048, seed=1),
  # 'inverted_double_pendulum': functools.partial(apg.train, num_evals=20,  episod_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, seed=1),
  # 'ant': functools.partial(apg.train,  num_evals=10,  episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096, batch_size=2048, seed=1),
  # 'humanoid': functools.partial(apg.train,  num_evals=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=10, num_minibatches=32, num_updates_per_batch=8, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048, seed=1),
  # 'reacher': functools.partial(apg.train, num_evals=20, episode_length=1000, normalize_observations=True, action_repeat=4, unroll_length=50, num_minibatches=32, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048, batch_size=256, max_devices_per_host=8, seed=1),
   'HumanoidStand': functools.partial(apg.train, num_evals=20, policy_updates=1e6, episode_length=1000, normalize_observations=True, action_repeat=1, learning_rate=6e-4, num_envs=2048, seed=1),
  # 'hopper': functools.partial(sac.train, num_evals=20, episode_length=1000, normalize_observations=True, action_repeat=1, discounting=0.997, learning_rate=6e-4, num_envs=128, grad_updates_per_step=64, max_devices_per_host=1, max_replay_size=1048576, min_replay_size=8192, seed=1),
  # 'walker2d': functools.partial(sac.train, num_evals=20, episode_length=1000, normalize_observations=True, action_repeat=1, discounting=0.997, learning_rate=6e-4, num_envs=128, batch_size=128, grad_updates_per_step=32, max_devices_per_host=1, max_replay_size=1048576, min_replay_size=8192, seed=1),
  # 'halfcheetah': functools.partial(apg.train, num_evals=20, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=20, num_minibatches=32, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=0.001, num_envs=2048, seed=3),
  # 'pusher': functools.partial(apg.train, num_evals=20, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=30, num_minibatches=16, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4,entropy_cost=1e-2, num_envs=2048, seed=3),
}[env_name]


# max_y = {'ant': 8000, 'halfcheetah': 8000, 'hopper': 2500, 'humanoid': 13000, 'humanoidstandup': 75_000, 'reacher': 5, 'walker2d': 5000, 'pusher': 0}[env_name]
# min_y = {'reacher': -100, 'pusher': -150}.get(env_name, 0)

xdata, ydata = [], []
times = [datetime.now()]

def progress(num_steps, metrics):
  times.append(datetime.now())
  xdata.append(num_steps)
  ydata.append(metrics['eval/episode_reward'])
  clear_output(wait=True)
  # plt.xlim([0, train_fn.keywords['num_timesteps']])
  # plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(xdata, ydata)
  plt.show()

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

env observation size 67
data contact Contact(dist=Traced<float32[16]>with<BatchTrace> with
  val = Traced<float32[128,16]>with<DynamicJaxprTrace>
  batch_dim = 0, pos=Traced<float32[16,3]>with<BatchTrace> with
  val = Traced<float32[128,16,3]>with<DynamicJaxprTrace>
  batch_dim = 0, frame=Traced<float32[16,3,3]>with<BatchTrace> with
  val = Traced<float32[128,16,3,3]>with<DynamicJaxprTrace>
  batch_dim = 0, includemargin=Traced<float32[16]>with<BatchTrace> with
  val = Traced<float32[128,16]>with<DynamicJaxprTrace>
  batch_dim = 0, friction=Traced<float32[16,5]>with<BatchTrace> with
  val = Traced<float32[128,16,5]>with<DynamicJaxprTrace>
  batch_dim = 0, solref=Traced<float32[16,2]>with<BatchTrace> with
  val = Traced<float32[128,16,2]>with<DynamicJaxprTrace>
  batch_dim = 0, solreffriction=Traced<float32[16,2]>with<BatchTrace> with
  val = Traced<float32[128,16,2]>with<DynamicJaxprTrace>
  batch_dim = 0, solimp=Traced<float32[16,5]>with<BatchTrace> with
  val = Traced<float32[128,16,

TypeError: scan body function carry input and carry output must have the same pytree structure, but they differ:

The input carry component state.data._impl.contact is a <class 'mujoco.mjx._src.types.Contact'> but the corresponding component of the carry output is a <class 'brax.base.Contact'>, so their Python types differ.

Revise the function so that the carry output has the same pytree structure as the carry input.

The trainers return an inference function, parameters, and the final set of metrics gathered during evaluation.

# Saving and Loading Policies

Brax can save and load trained policies:

In [None]:
model.save_params('/tmp/params', params)
params = model.load_params('/tmp/params')
inference_fn = make_inference_fn(params)

The trainers return an inference function, parameters, and the final set of metrics gathered during evaluation.

# Saving and Loading Policies

Brax can save and load trained policies:

In [None]:
#@title Visualizing a trajectory of the learned inference function

# create an env with auto-reset
# env = envs.create(env_name=env_name, backend=backend)
import mediapy as media
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(1000):
  rollout.append(state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)
frames = env.render(rollout)
media.show_video(frames, fps=1.0 / env.dt)

🙌 See you soon!