# Training Goal-Conditioned and Unsupervised RL Agents in Brax

In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments and PPO, SAC, evolutionary search, and trajectory optimization algorithms. We can build various RL algorithms on top of these ultra-fast implementations. This colab runs a family of [variational GCRL](https://arxiv.org/abs/2106.01404) algorithms, which includes [goal-conditioned RL](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.51.3077) and [DIAYN](https://arxiv.org/abs/1802.06070) as special cases. Let's try it out!

This colab provides a bare-bones implementation built with minimal modifications to the
baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),
so training takes only a few minutes. More features, tuning, and benchmark results will be added soon.

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

from datetime import datetime
import functools
import os

import jax
import jax.numpy as jnp
from jax.tools import colab_tpu
from IPython.display import HTML, clear_output 
import matplotlib.pyplot as plt

# Install a fork of brax (`frt03/brax`, branch `features/spectral_norm`) only
# when `import brax` fails. The `!pip` line is IPython shell syntax, so this
# cell must run inside a notebook (e.g. Colab), not as a plain Python script.
# NOTE(review): if some other brax is already installed, the fork is NOT
# installed and the `brax.experimental.braxlines` imports below may fail.
try:
  import brax
except ImportError:
  !pip install git+https://github.com/frt03/brax.git@features/spectral_norm
  clear_output()  # hide the pip install log
  import brax

from brax import envs
from brax.io import html
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines.vgcrl import utils as vgcrl_utils

# `colab_tpu` is already imported at the top of this cell; only the TPU
# initialisation itself has to be conditional on running in a TPU runtime.
if 'COLAB_TPU_ADDR' in os.environ:
  colab_tpu.setup_tpu()

In [None]:
#@title Visualizing pre-included Brax environments { run: "auto" }
#@markdown Note: As in the experiments of
#@markdown [DIAYN](https://arxiv.org/abs/1802.06070)
#@markdown and [variational GCRL](https://arxiv.org/abs/2106.01404),
#@markdown we assume prior knowledge of which dimensions of the
#@markdown environment are interesting (`env_indices`); the discriminator
#@markdown ignores them only when `exp_name`='diayn_full'.
#@markdown These dimensions are also used for skill visualization later.

env_name = "ant"  # @param ['ant', 'halfcheetah']
exp_name = "diayn"  # @param ['gcrl', 'cdiayn', 'diayn', 'diayn_full']
diayn_num_skills =   8# @param {type: 'integer'}
logits_clip_range = 10.0# @param {'type': 'number'} 
# Observation dimensions treated as "interesting" for each environment. They
# feed the discriminators below and the skill scatter plot at the end.
# 'humanoid' is listed here even though the env_name dropdown does not offer it.
env_indices = {
    'ant': (13, 14),  # x-y velocities
    'humanoid': (22, 23),  # x-y velocities
    'halfcheetah': (11,),  # x velocity
}[env_name]
# Build the plain (unwrapped) environment to read its observation size.
base_env_fn = envs.create_fn(env_name=env_name)
base_env = base_env_fn()
env_obs_size = base_env.observation_size

# One Discriminator configuration per experiment type, selected by `exp_name`.
# Every entry passes its q_fn options through the same `q_fn_params` keyword.
disc_fn = {
    'gcrl': functools.partial(
        vgcrl_utils.Discriminator,
        q_fn='indexing',
        z_size=len(env_indices),
        q_fn_params=dict(indices=env_indices),
        dist_p_params=dict(scale=2.),
        dist_q_params=dict(scale=2.),
    ),
    'cdiayn': functools.partial(
        vgcrl_utils.Discriminator,
        q_fn='indexing_mlp',
        z_size=len(env_indices),
        q_fn_params=dict(
            indices=env_indices,
            output_size=len(env_indices),
        ),
    ),
    'diayn': functools.partial(
        vgcrl_utils.Discriminator,
        q_fn='indexing_mlp_s',
        z_size=diayn_num_skills,
        q_fn_params=dict(
            indices=env_indices,
            output_size=diayn_num_skills,
        ),
        dist_p='UniformCategorial',
        dist_q='Categorial',
        logits_clip_range=logits_clip_range,
    ),
    # Uses the full observation instead of `env_indices`.
    'diayn_full': functools.partial(
        vgcrl_utils.Discriminator,
        q_fn='mlp',
        z_size=diayn_num_skills,
        q_fn_params=dict(
            input_size=env_obs_size,
            output_size=diayn_num_skills,
        ),
        dist_p='UniformCategorial',
        dist_q='Categorial',
        logits_clip_range=logits_clip_range,
    ),
}[exp_name]
# Instantiate the selected discriminator and initialise its parameters; they
# are later handed to the trainer as `extra_params`.
disc = disc_fn()
extra_params = disc.init_model(rng=jax.random.PRNGKey(seed=0))
# Environment factory built by vgcrl_utils around the chosen discriminator.
env_fn = vgcrl_utils.create_fn(env_name=env_name, disc=disc)
env = env_fn()

def visualize(sys, qps, save_path: str = None):
  """Return an HTML object rendering the states `qps` of `sys` in 3D.

  When `save_path` is non-empty, the visualization is first written to that
  path with `html.save_html`.
  """
  if not save_path:
    return HTML(html.render(sys, qps))
  html.save_html(save_path, sys, qps)
  return HTML(html.render(sys, qps))

# Jit-compile reset once; the compiled function is reused by later cells.
jit_env_reset = jax.jit(env.reset)
initial_state_rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=initial_state_rng)
# Drop the jax.lax warnings printed during compilation before rendering.
clear_output()

# Render the initial state in 3D (last expression, so the notebook displays it).
visualize(env.sys, [state.qp])

In [None]:
#@title Training some pre-included Brax environments
num_timesteps_multiplier =  2# @param {type: 'integer'}
disc_update_ratio = 1.0# @param {'type': 'number'}

# Per-environment PPO hyperparameters, determined offline; `n` scales the
# total training budget.
n = num_timesteps_multiplier
_ppo_kwargs_by_env = {
    'ant': dict(
        num_timesteps=50_000_000 * n,
        log_frequency=20,
        reward_scaling=10,
        episode_length=1000,
        normalize_observations=True,
        action_repeat=1,
        unroll_length=5,
        num_minibatches=32,
        num_update_epochs=4,
        discounting=0.95,
        learning_rate=3e-4,
        entropy_cost=1e-2,
        num_envs=2048,
        batch_size=1024,
    ),
    'halfcheetah': dict(
        num_timesteps=50_000_000 * n,
        log_frequency=10,
        reward_scaling=1,
        episode_length=1000,
        normalize_observations=True,
        action_repeat=1,
        unroll_length=20,
        num_minibatches=32,
        num_update_epochs=8,
        discounting=0.95,
        learning_rate=3e-4,
        entropy_cost=0.001,
        num_envs=2048,
        batch_size=512,
    ),
}
train_fn = functools.partial(ppo.train, **_ppo_kwargs_by_env[env_name])

times = [datetime.now()]
plotdata = {}
plotkeys = [
  'eval/episode_reward',
  'losses/disc_loss',
]

def progress(num_steps, metrics, _):
  """Record the latest training metrics and redraw one subplot per plot key.

  Args:
    num_steps: number of environment steps so far (the x-axis value).
    metrics: mapping from metric name to value; every entry is recorded, but
      only the names listed in `plotkeys` are drawn.
    _: third positional argument supplied by the trainer (unused here).
  """
  times.append(datetime.now())
  for key, v in metrics.items():
    series = plotdata.setdefault(key, dict(x=[], y=[]))
    series['x'].append(num_steps)
    series['y'].append(v)
  clear_output(wait=True)
  num_figs = len(plotkeys)
  # squeeze=False keeps `axs` 2-D even for a single subplot, so the loop
  # below works for any number of plotkeys (a bare Axes is not indexable).
  fig, axs = plt.subplots(
      ncols=num_figs, figsize=(3.5*num_figs, 3), squeeze=False)
  for ax, key in zip(axs[0], plotkeys):
    if key in plotdata:
      ax.plot(plotdata[key]['x'], plotdata[key]['y'])
    ax.set(xlabel='# environment steps', ylabel=key)
    ax.set_xlim([0, train_fn.keywords['num_timesteps']])
  fig.tight_layout()
  plt.show()

# The discriminator gets its own extra loss only when it has parameters.
if extra_params:
  extra_loss_fns = dict(
      disc_loss=functools.partial(vgcrl_utils.disc_loss_fn, disc=disc))
  extra_loss_update_ratios = dict(disc_loss=disc_update_ratio)
else:
  extra_loss_fns = None
  extra_loss_update_ratios = None

inference_fn, params, _ = train_fn(
    environment_fn=env_fn,
    progress_fn=progress,
    extra_params=extra_params,
    extra_loss_fns=extra_loss_fns,
    extra_loss_update_ratios=extra_loss_update_ratios,
)

# times[0] is taken before training; times[1] at the first progress callback.
jit_duration = times[1] - times[0]
print(f'time to jit: {jit_duration}')
train_duration = times[-1] - times[1]
print(f'time to train: {train_duration}')

In [None]:
#@title Visualizing a trajectory of the learned inference function
rollout_z = "fix"  # @param ['sample', 'fix']
diayn_index =   7# @param {type: 'integer'}
seed = 0  # @param {type: 'integer'}
save_path = '/tmp/{env_name}_{exp_name}_{index}.html' # @param {'type': 'raw'}
save_path = save_path.format(
    env_name=env_name, exp_name=exp_name, index=diayn_index)

jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

# With rollout_z == "fix", condition on an experiment-specific latent z;
# otherwise pass z=None through to reset.
if rollout_z == "fix":
  z = {
      'gcrl': jnp.ones(env.z_size) * 2.,
      'cdiayn': jnp.ones(env.z_size),
      'diayn': jax.nn.one_hot(jnp.array(diayn_index), env.z_size),
      'diayn_full': jax.nn.one_hot(jnp.array(diayn_index), env.z_size),
  }[exp_name]
else:
  z = None

qps, states = [], []
state = jit_env_reset(rng=jax.random.PRNGKey(seed=seed), z=z)
# Roll the learned policy out until the episode terminates.
while not state.done:
  states.append(state)
  qps.append(state.qp)
  act = jit_inference_fn(params, state.obs, state.rng)
  state = jit_env_step(state, act, params[-1])

visualize(env.sys, qps, save_path=save_path)

In [None]:
#@title Visualizing skills of the learned inference function in 2D plot
import numpy as np
from itertools import product
num_samples_per_z = 5  # @param {type: 'integer'}
time_subsampling = 10  # @param {type: 'integer'}
time_last_n = 500 # @param {type: 'integer'}
seed = 0  # @param {type: 'integer'}

O = env_obs_size
Z = env.z_size
M = num_samples_per_z

# Sample {D} z's: continuous-latent experiments use every corner of
# [-1, 1]^param_size; categorical (DIAYN) experiments use one one-hot z per
# skill. 'cdiayn' previously matched neither branch, leaving batch_z undefined.
# NOTE(review): assumes env.param_size equals the z dimension for 'cdiayn'
# as it does for 'gcrl' — confirm.
if exp_name in ('gcrl', 'cdiayn', 'vgcrl'):
  batch_z = jnp.array(list(product(*([[-1, 1]] * env.param_size))))
elif exp_name in ('diayn', 'diayn_full'):
  batch_z = jax.nn.one_hot(jnp.arange(0, Z), Z)
else:
  raise ValueError(f'Unsupported exp_name for skill plotting: {exp_name!r}')
D = batch_z.shape[0]

# Repeat each z by {M} times
batch_z = jnp.repeat(batch_z, M, axis=0)  # [D*M, Z]

# Reset one environment per (z, sample) pair, each seeded from `seed + i`.
batch_env = env_fn(batch_size=D*M)
reset_keys = jnp.array([jax.random.PRNGKey(seed + i) for i in range(D * M)])
state = batch_env.reset(reset_keys, z=batch_z)
states = [state]
jit_step = jax.jit(batch_env.step)
jit_inference_fn = jax.jit(inference_fn)
# Step the whole batch until every environment reports done.
while not state.done.all():
  act = jit_inference_fn(params, state.obs, state.rng[0])
  state = jit_step(state, act, params[-1])
  states.append(state)

# Get indices of interest
# Stack the batched observations along a new leading time axis, keep only the
# last `time_last_n` steps, then take every `time_subsampling`-th one.
# NOTE(review): the trailing Z dims are presumably the latent z carried in the
# observation (D equals Z only for one-hot DIAYN z's) — confirm.
obses = jnp.stack([state.obs for state in states],
                  axis=0)[-time_last_n:][::time_subsampling] # [T, D*M, O+Z]
print(f'T={obses.shape[0]}, O={O}, Z={Z}, D={D}, M={M}')
env_obses, _ = batch_env.disc.split_obs(obses) # [T, D*M, O]
# Keep only the environment dimensions of interest (`env_indices`).
env_vars = env_obses[..., env_indices] # [T, D*M, 1 or 2]
# Pad 1-D variables with a zero second coordinate so every point is 2-D.
if env_vars.shape[-1] == 1:
  env_vars = jnp.concatenate([env_vars, jnp.zeros(env_vars.shape)], axis=-1)
assert env_vars.shape[1:] == (D*M, 2), f'{env_vars.shape} incompatible {(D*M,2)}'
# Batch rows are z-major (each z repeated M times by jnp.repeat); regroup so
# each row of the result holds one sample for every one of the D z's.
env_vars = env_vars.reshape(-1, D, M, 2).swapaxes(
    1,2).reshape(-1, D, 2) # [T*M, D, 2]

# Plot
def spec(N):
  """Return an (N, 3) array of RGB colours in [0, 1], sweeping red->green->blue."""
  ramp = np.linspace(-510, 510, N)
  channels = (-ramp, 510 - np.abs(ramp), ramp)
  rgb = np.clip(np.column_stack(channels), 0, 255)
  return rgb / 255.
# One colour per z, repeated for each of the T*M row groups so that colours
# line up with the flattened (T*M*D) points below.
colours = np.tile(spec(D), (env_vars.shape[0], 1))  # [T*M*D, 3]
env_vars = env_vars.reshape(-1, 2)  # [T*M*D, 2]
plt.scatter(x=env_vars[:, 0], y=env_vars[:, 1], c=colours)
plt.show()