# Training Goal-Conditioned and Unsupervised RL Agents in Brax

In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments and PPO, SAC, evolutionary search, and trajectory optimization algorithms. We can build various RL algorithms on top of these ultra-fast implementations. This colab runs a family of [variational GCRL](https://arxiv.org/abs/2106.01404) algorithms, which includes [goal-conditioned RL](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.51.3077) and [DIAYN](https://arxiv.org/abs/1802.06070) as special cases. Let's try it out!

This notebook provides a bare-bones implementation built with minimal modifications to the
baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),
enabling training in a few minutes. More features, tuning, and benchmark results will be added soon.




[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/vgcrl.ipynb)

In [None]:
#@title Colab setup and imports
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

from datetime import datetime
import functools
import os

import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax.io import file as io_file
from brax.io import html
from brax.experimental.composer import composer
from brax.experimental.braxlines.common import evaluators
from brax.experimental.braxlines.common import logger_utils
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines.vgcrl import evaluators as vgcrl_evaluators
from brax.experimental.braxlines.vgcrl import utils as vgcrl_utils

# Register the Colab TPU backend with JAX when COLAB_TPU_ADDR is present
# (presumably set only on Colab TPU runtimes); otherwise do nothing.
if os.environ.get("COLAB_TPU_ADDR") is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title Define experiment parameters

#@markdown **Task Parameters**
#@markdown 
#@markdown As in [DIAYN](https://arxiv.org/abs/1802.06070)
#@markdown and [variational GCRL](https://arxiv.org/abs/2106.01404),
#@markdown we assume some task knowledge: the interesting observation
#@markdown dimensions `obs_indices`, selected by `env_space`.
#@markdown They are also used for evaluation and visualization.
#@markdown
#@markdown When the **task parameters** are the same, the metrics computed by
#@markdown [vgcrl/evaluators.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/vgcrl/evaluators.py)
#@markdown are directly comparable across experiment runs with different
#@markdown **experiment parameters**. 
env_name = 'ant'  # @param ['ant', 'halfcheetah', 'ant_cheetah', 'uni_ant', 'bi_ant']
env_space = 'vel'  # @param ['vel', 'pos', 'ang']
env_scale = 5.0  # @param{'type': 'number'}

# Observation dimensions of interest, keyed env_space -> env_name. Entries are
# either flat observation indices or (key, index) pairs for composer envs.
# Not every (env_space, env_name) pair offered by the form is defined here.
OBS_INDICES_BY_SPACE = {
    'vel': {  # x-y velocity
        'ant': (13, 14),
        'humanoid': (22, 23),
        'halfcheetah': (11,),
        'uni_ant': (('vel:torso_ant1', 0), ('vel:torso_ant1', 1)),
        'bi_ant': (('vel:torso_ant1', 0), ('vel:torso_ant2', 0)),
    },
    'ang': {  # angular velocity
        'ant': (17,),
        'uni_ant': (('ang:torso_ant1', 2),),
    },
}
supported_obs_indices = OBS_INDICES_BY_SPACE.get(env_space, {})
if env_name not in supported_obs_indices:
  # Fail with an actionable message instead of a bare KeyError.
  supported_pairs = ', '.join(
      f'{space}/{name}'
      for space, names in OBS_INDICES_BY_SPACE.items()
      for name in names)
  raise ValueError(
      f'No obs_indices defined for env_space={env_space!r}, '
      f'env_name={env_name!r}. Supported (env_space/env_name): '
      f'{supported_pairs}')
obs_indices = supported_obs_indices[env_name]

#@markdown **Experiment Parameters**
#@markdown See below and [vgcrl/utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/vgcrl/utils.py)
algo_name = 'diayn'  # @param ['gcrl', 'cdiayn', 'diayn', 'diayn_full', 'fixed_gcrl']
logits_clip_range = 5.0  # @param {'type': 'number'}
normalize_obs_for_disc = False  # @param {'type': 'boolean'}
seed = 0  # @param {type: 'integer'}
diayn_num_skills = 8  # @param {type: 'integer'}
spectral_norm = True  # @param {type: 'boolean'}
output_path = None  # @param {'type': 'raw'}
# output_path may use {date}, {env_space}, {env_scale}, {env_name} and
# {algo_name} placeholders; an empty/None value disables file output.
output_path = output_path.format(
    date=datetime.now().strftime('%Y%m%d'),
    env_space=env_space,
    env_scale=env_scale,
    env_name=env_name,
    algo_name=algo_name) if output_path else None

In [None]:
# @title Visualizing Brax environments
# The base environment supplies the observation size and is handed to the
# discriminator below.
base_env_fn = composer.create_fn(env_name=env_name)
base_env = base_env_fn()
env_obs_size = base_env.observation_size

# Discriminator configuration for the selected algorithm.
disc_kwargs = dict(
    algo_name=algo_name,
    observation_size=env_obs_size,
    obs_indices=obs_indices,
    scale=env_scale,
    diayn_num_skills=diayn_num_skills,
    logits_clip_range=logits_clip_range,
    spectral_norm=spectral_norm,
)
disc_fn = vgcrl_utils.create_disc_fn(**disc_kwargs)
disc = disc_fn(env=base_env, normalize_obs=normalize_obs_for_disc)

# One PRNG key (derived from `seed`) seeds both the discriminator
# initialization and the first environment reset.
rng_key = jax.random.PRNGKey(seed=seed)
extra_params = disc.init_model(rng=rng_key)
env_fn = vgcrl_utils.create_fn(env_name=env_name, disc=disc)

# Reset the discriminator-parameterized environment once and render it.
env = env_fn()
jit_env_reset = jax.jit(env.reset)
state = jit_env_reset(rng=rng_key)
clear_output()  # clear out jax.lax warning before rendering
HTML(html.render(env.sys, [state.qp]))

In [None]:
#@title Training
num_timesteps_multiplier = 4  # @param {type: 'integer'}
disc_update_ratio = 1.0  # @param {'type': 'number'}

# Only write the training-curve CSV when an output directory is configured;
# formatting a None output_path would otherwise produce the literal path
# 'None/training_curves.csv'.
# NOTE(review): assumes Tabulator treats output_path=None as "don't write".
tab = logger_utils.Tabulator(
    output_path=f'{output_path}/training_curves.csv' if output_path else None,
    append=False)

# We determined some reasonable hyperparameters offline and share them here.
# `num_timesteps_multiplier` scales the total environment-step budget.
n = num_timesteps_multiplier
train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000 * n,
    log_frequency=20,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_update_epochs=4,
    discounting=0.95,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=2048,
    batch_size=1024)

times = [datetime.now()]  # wall-clock timestamps; progress() appends one per call
plotdata = {}  # metric name -> {'x': [steps], 'y': [values]}, filled by progress()
plotkeys = ['eval/episode_reward', 'losses/disc_loss']  # metrics drawn by plot()

def plot(output_path: str = None, output_name: str = 'training_curves'):
  """Draw one panel per metric in `plotkeys` from the data in `plotdata`.

  Metrics not yet present in `plotdata` get an empty, labelled panel. The
  x-axis spans the full training budget configured in `train_fn`.

  Args:
    output_path: if set, the figure is also saved as
      `{output_path}/{output_name}.png`.
    output_name: file stem used when saving.
  """
  num_figs = len(plotkeys)
  # squeeze=False keeps `axs` 2-D even for a single panel, so iterating over
  # axs[0] works for any number of plot keys.
  fig, axs = plt.subplots(
      ncols=num_figs, figsize=(3.5 * num_figs, 3), squeeze=False)
  for ax, key in zip(axs[0], plotkeys):
    if key in plotdata:
      ax.plot(plotdata[key]['x'], plotdata[key]['y'])
    ax.set(xlabel='# environment steps', ylabel=key)
    ax.set_xlim([0, train_fn.keywords['num_timesteps']])
  fig.tight_layout()
  if output_path:
    with io_file.File(f'{output_path}/{output_name}.png', 'wb') as f:
      # Save this figure explicitly (not pyplot's "current" figure); the
      # format is given because a file object carries no extension.
      fig.savefig(f, format='png')

def progress(num_steps, metrics, _):
  """Training callback: record metrics, log them, and redraw the curves.

  Args:
    num_steps: environment steps taken so far.
    metrics: mapping from metric name to its current value.
    _: third positional argument passed by the trainer; unused here.
  """
  times.append(datetime.now())
  for name, value in metrics.items():
    series = plotdata.setdefault(name, dict(x=[], y=[]))
    series['x'].append(num_steps)
    series['y'].append(value)
  # The first call (step 0) does not include losses, so it is not tabulated.
  if num_steps > 0:
    tab.add(num_steps=num_steps, **metrics)
    tab.dump()
  clear_output(wait=True)
  plot()
  plt.show()

# Attach the discriminator loss only when the discriminator has parameters;
# otherwise no extra losses are passed to the trainer.
if extra_params:
  extra_loss_fns = dict(disc_loss=disc.disc_loss_fn)
  extra_loss_update_ratios = dict(disc_loss=disc_update_ratio)
else:
  extra_loss_fns = None
  extra_loss_update_ratios = None

inference_fn, params, _ = train_fn(
    environment_fn=env_fn,
    progress_fn=progress,
    extra_params=extra_params,
    extra_loss_fns=extra_loss_fns,
    extra_loss_update_ratios=extra_loss_update_ratios,
)
clear_output(wait=True)
plot(output_path=output_path)

# times[0] is taken before training; times[1] at the first progress callback.
jit_duration = times[1] - times[0]
train_duration = times[-1] - times[1]
print(f'time to jit: {jit_duration}')
print(f'time to train: {train_duration}')

In [None]:
#@title Visualizing a trajectory of the learned inference function
#@markdown If `z_value` is `None`, sample `z`, else fix `z` to `z_value`.
z_value = 1  # @param {'type': 'raw'}
eval_seed = 0  # @param {'type': 'integer'}

# Build the conditioning latent for the selected algorithm only, so that e.g.
# a vector-valued goal for a continuous-latent algorithm never reaches int().
if z_value is None:
  z = None  # sample z on reset
elif algo_name in ('fixed_gcrl', 'gcrl', 'cdiayn'):
  # Continuous latent: broadcast a scalar (or length-z_size sequence) goal.
  z = jnp.ones(env.z_size) * jnp.asarray(z_value)
elif algo_name in ('diayn', 'diayn_full'):
  # Discrete skill: one-hot encode the skill index. An out-of-range index
  # would silently yield an all-zero vector, so reject it up front.
  skill_index = int(z_value)
  if not 0 <= skill_index < env.z_size:
    raise ValueError(
        f'z_value={z_value!r} is not a valid skill index for '
        f'{env.z_size} skills.')
  z = jax.nn.one_hot(jnp.array(skill_index), env.z_size)
else:
  raise ValueError(f'Unknown algo_name: {algo_name!r}')

states = evaluators.visualize_env(
    env_fn,
    inference_fn,
    params,
    batch_size=0,
    seed=eval_seed,
    reset_args=(z,),
    step_args=(params[0], params[-1],),
    output_path=output_path,
    output_name=f'video_z_{z_value}',
)
HTML(html.render(env.sys, [s.qp for s in states]))

In [None]:
#@title Visualizing skills of the learned inference function in 2D plot
num_samples_per_z = 5  # @param {type: 'integer'}
time_subsampling = 10  # @param {type: 'integer'}
time_last_n = 500  # @param {type: 'integer'}
eval_seed = 0  # @param {type: 'integer'}

# Sampling and plotting options, forwarded as keyword arguments.
skill_plot_kwargs = dict(
    verbose=True,
    num_samples_per_z=num_samples_per_z,
    time_subsampling=time_subsampling,
    time_last_n=time_last_n,
    seed=eval_seed,
)
vgcrl_evaluators.visualize_skills(
    env_fn, inference_fn, obs_indices, params, env_scale, algo_name,
    output_path, **skill_plot_kwargs)
plt.show()