# Training Divergence Minimization (D-min) RL algorithms in Brax

In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments with PPO, SAC, evolutionary search, and trajectory optimization algorithms. Various RL algorithms can be built on top of these ultra-fast implementations. This colab runs a family of [adversarial inverse RL](https://arxiv.org/abs/1911.02256) algorithms, which includes [GAIL](https://papers.nips.cc/paper/2016/hash/cc7e2b878868cbae992d1fb743995d8f-Abstract.html) and [AIRL](https://arxiv.org/abs/1710.11248) as special cases. These algorithms minimize D(p(s,a), p\*(s,a)) or D(p(s), p\*(s)): the divergence D between the policy's state-action marginal p(s,a) (or state marginal p(s)) and a given target distribution p\*(s,a) (or p\*(s)). As discussed in [f-MAX](https://arxiv.org/abs/1911.02256), these algorithms can be used for [state-marginal matching](https://arxiv.org/abs/1906.05274) RL as well as imitation learning. Let's try them out!

This colab provides a bare-bones implementation based on minimal modifications to the
baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),
enabling training in a few minutes. More features, tuning, and benchmarked results will be added.
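To make the divergence-minimization idea concrete, here is a minimal sketch of how a discriminator's logits can be turned into a per-step reward. It assumes the logits approximate log p\*(s) - log p(s) and uses the reward forms commonly given in the GAIL, AIRL, and f-MAX papers. The function names `softplus` and `disc_reward`, and the way clipping is applied, are illustrative assumptions; the `reward_type` options in `irl_smm/utils.py` (including `gail2` and `mle`, which are not covered here) may be defined differently.

```python
import numpy as np


def softplus(x):
  """Numerically stable log(1 + exp(x))."""
  return np.logaddexp(0.0, x)


def disc_reward(logits, reward_type='airl', clip_range=10.0):
  """Maps discriminator logits to rewards (hypothetical helper, not Brax API).

  Assumes logits approximate log p*(s) - log p(s), with D(s) = sigmoid(logits).
  """
  # Assumption: logits are clipped symmetrically before computing the reward.
  logits = np.clip(logits, -clip_range, clip_range)
  if reward_type == 'gail':   # -log(1 - D(s))
    return softplus(logits)
  if reward_type == 'airl':   # log D(s) - log(1 - D(s)), i.e. the logits
    return logits
  if reward_type == 'fairl':  # forward-KL reward from f-MAX: -h * exp(h)
    return -logits * np.exp(logits)
  raise ValueError(f'unknown reward_type: {reward_type}')
```

With this sketch, `disc_reward(np.array([0.0]), 'airl')` gives a zero reward where the policy and target densities agree, and positive rewards where the target density is higher.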




[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/dmin.ipynb)

In [None]:
#@title Install Brax and some helper modules
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change runtime type, then select 'TPU' in the dropdown.

from datetime import datetime
import functools
import os

import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output 
import matplotlib.pyplot as plt
import numpy as np

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

import tensorflow_probability as tfp
from brax.io import html
from brax.experimental.composer import composer
from brax.experimental.braxlines import experiments
from brax.experimental.braxlines.common import evaluators
from brax.experimental.braxlines.envs.obs_indices import OBS_INDICES
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines.irl_smm import evaluators as irl_evaluators
from brax.experimental.braxlines.irl_smm import utils as irl_utils

tfp = tfp.substrates.jax
tfd = tfp.distributions

if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title Define task and experiment parameters

#@markdown **Task Parameters**
#@markdown 
#@markdown As in [SMM](https://arxiv.org/abs/1906.05274)
#@markdown and [f-MAX](https://arxiv.org/abs/1911.02256),
#@markdown we assume some task knowledge: the interesting dimensions
#@markdown of the observation (`obs_indices`) and their range (`obs_scale`).
#@markdown This knowledge is also used for evaluation and visualization.
#@markdown
#@markdown When the **task parameters** are the same, the metrics computed by
#@markdown [irl_smm/evaluators.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/irl_smm/evaluators.py)
#@markdown are directly comparable across experiment runs with different
#@markdown **experiment parameters**. 
env_name = 'ant'  # @param ['ant', 'humanoid', 'halfcheetah', 'uni_ant', 'bi_ant']
obs_indices = 'vel'  # @param ['vel']
target_num_modes = 2  # @param {'type': 'integer'}
obs_scale = 8.0  # @param {'type': 'number'}
# Keep the string name for logging before replacing it with the actual indices.
obs_indices_str = obs_indices
obs_indices = OBS_INDICES[obs_indices][env_name]

#@markdown **Experiment Parameters**
#@markdown See [irl_smm/utils.py](https://github.com/google/brax/blob/main/brax/experimental/braxlines/irl_smm/utils.py)
reward_type = 'mle'  # @param ['gail', 'airl', 'gail2', 'fairl', 'mle']
logits_clip_range = 10.0  # @param {'type': 'number'}
normalize_obs_for_disc = False  # @param {'type': 'boolean'}
normalize_obs_for_rl = True  # @param {'type': 'boolean'}
spectral_norm = False  # @param {'type': 'boolean'}
gradient_penalty_weight = 0.0  # @param {'type': 'number'}
env_reward_multiplier = 0.0  # @param {'type': 'number'}
evaluate_dist = False  # @param {'type': 'boolean'}
seed = 0  # @param {'type': 'integer'}

output_path = '' # @param {'type': 'string'}
task_name = "" # @param {'type': 'string'}
exp_name = '' # @param {'type': 'string'}
if output_path:
  output_path = output_path.format(
    date=datetime.now().strftime('%Y%m%d'))
  # Use the string name (e.g. 'vel'); `obs_indices` now holds the index values.
  task_name = task_name or f'{env_name}_{obs_indices_str}_{obs_scale}_{target_num_modes}'
  exp_name = exp_name or f'{reward_type}'
  output_path = f'{output_path}/{task_name}/{exp_name}'
print(f'output_path={output_path}')
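As a rough illustration of what `obs_indices` and `obs_scale` encode, the sketch below selects a few "interesting" observation dimensions and rescales them by their assumed range. The helper `index_obs` and the example indices are hypothetical; the actual preprocessing inside `irl_utils.IRLDiscriminator` may differ.

```python
import numpy as np


def index_obs(obs, obs_indices, obs_scale):
  """Selects observation dims and rescales them (hypothetical helper)."""
  obs = np.asarray(obs, dtype=float)
  # Index the last axis so this works for single observations and batches.
  return obs[..., list(obs_indices)] / obs_scale


# Made-up example: a 6-dim observation, keeping dims 1 and 3.
example_obs = np.arange(6.0)
print(index_obs(example_obs, obs_indices=(1, 3), obs_scale=2.0))
```

Fixing these two task parameters is what keeps the evaluation metrics comparable across runs: every run is then scored in the same low-dimensional space.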

In [None]:
# @title Generate target distribution to match
target_num_samples = 250  # @param{type: 'integer'}

rng = jax.random.PRNGKey(seed=seed)
jit_get_dist = jax.jit(
    functools.partial(
        irl_utils.get_multimode_dist,
        indexed_obs_dim=len(obs_indices),
        num_modes=target_num_modes, scale=obs_scale))
target_dist = jit_get_dist()
target_data = target_dist.sample(
    seed=rng, sample_shape=(target_num_samples,))
target_data_2d = irl_utils.make_2d(target_data)

print(f'target_data={target_data.shape}')
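# Pass the RGB color as a single-row 2D sequence so that matplotlib does not
# confuse it with per-point values.
plt.scatter(
    x=target_data_2d[:, 0], y=target_data_2d[:, 1], c=[[0, 0, 1]])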
plt.xlim((-obs_scale, obs_scale))
plt.ylim((-obs_scale, obs_scale))
plt.title('target')
plt.show()
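For intuition about a multi-modal target, here is a minimal NumPy sketch that samples from a mixture of Gaussians whose modes are evenly spaced on a circle. This is an assumption about what such a target could look like, not a description of `irl_utils.get_multimode_dist`; the function name `sample_multimode_target`, the mode radius (`0.5 * scale`), and `std` are all hypothetical.

```python
import numpy as np


def sample_multimode_target(num_samples, num_modes=2, scale=8.0, std=0.5,
                            seed=0):
  """Samples 2D points from a mixture of Gaussians (illustrative only)."""
  rng = np.random.default_rng(seed)
  # Assumption: modes sit evenly on a circle of radius scale / 2.
  angles = 2.0 * np.pi * np.arange(num_modes) / num_modes
  centers = 0.5 * scale * np.stack([np.cos(angles), np.sin(angles)], axis=-1)
  # Pick a mode uniformly at random for each sample, then add Gaussian noise.
  modes = rng.integers(num_modes, size=num_samples)
  noise = rng.normal(scale=std, size=(num_samples, 2))
  return centers[modes] + noise


samples = sample_multimode_target(250)
print(samples.shape)
```

Increasing `num_modes` in a setup like this makes the matching problem harder, since the policy's state marginal has to cover every mode rather than collapse onto one.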

In [None]:
# @title Make environment and inference_fn
visualize = False # @param {'type': 'boolean'}

base_env_fn = composer.create_fn(env_name=env_name)
base_env = base_env_fn()
disc = irl_utils.IRLDiscriminator(
    input_size=len(obs_indices),
    obs_indices=obs_indices,
    obs_scale=obs_scale,
    include_action=False,
    logits_clip_range=logits_clip_range,
    spectral_norm=spectral_norm,
    gradient_penalty_weight=gradient_penalty_weight,
    reward_type=reward_type,
    normalize_obs=normalize_obs_for_disc,
    target_data=target_data,
    target_dist_fn=jit_get_dist,
    env=base_env)
extra_params = disc.init_model(rng=jax.random.PRNGKey(seed=0))
env_fn = irl_utils.create_fn(
    env_name=env_name,
    wrapper_params=dict(
        disc=disc,
        env_reward_multiplier=env_reward_multiplier,
    ))
eval_env_fn = functools.partial(env_fn, auto_reset=False)

# make inference functions and goals for evaluation
core_env = env_fn()
params, inference_fn = ppo.make_params_and_inference_fn(
    core_env.observation_size,
    core_env.action_size,
    normalize_observations=normalize_obs_for_rl,
    extra_params=extra_params)
inference_fn = jax.jit(inference_fn)

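# Visualize in 3D
if visualize:
  from IPython.display import display
  env = env_fn()
  jit_env_reset = jax.jit(env.reset)
  state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))
  clear_output()  # clear out jax.lax warning before rendering
  # A bare expression inside an `if` block is not auto-displayed by the
  # notebook, so display it explicitly.
  display(HTML(html.render(env.sys, [state.qp])))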

In [None]:
#@title Training
num_timesteps_multiplier = 4  # @param {type: 'number'}

# We determined some reasonable hyperparameters offline and share them here.
ppo_params = experiments.defaults.get_ppo_params(
    env_name, num_timesteps_multiplier, default='ant')
train_fn = functools.partial(ppo.train, **ppo_params)

times = [datetime.now()]
plotdata = {}
plotkeys = [
    'eval/episode_reward', 'losses/disc_loss', 'losses/total_loss',
    'losses/policy_loss', 'losses/value_loss', 'losses/entropy_loss',
    'metrics/energy_dist']

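def progress(num_steps, metrics, params):
  """Records metrics and redraws the learning curves after each evaluation."""
  times.append(datetime.now())
  if evaluate_dist:
    # Optionally add the energy distance between policy and target samples.
    metrics.update(irl_evaluators.estimate_energy_distance_metric(
        params=params, disc=disc, target_data=target_data, env_fn=env_fn,
        inference_fn=inference_fn))

  # Append the latest value of every metric to its learning curve.
  for key, v in metrics.items():
    plotdata[key] = plotdata.get(key, dict(x=[], y=[]))
    plotdata[key]['x'] += [num_steps]
    plotdata[key]['y'] += [v]
  clear_output(wait=True)
  # One panel per plotted metric, plus one extra panel passed to visualize_disc.
  num_figs = len(plotkeys) + 1
  fig, axs = plt.subplots(ncols=num_figs, figsize=(3.5 * num_figs, 3))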
  # plot learning curves
  for i, key in enumerate(plotkeys):
    if key in plotdata:
      axs[i].plot(plotdata[key]['x'], plotdata[key]['y'])
    axs[i].set(xlabel='# environment steps', ylabel=key)
    axs[i].set_xlim([0, train_fn.keywords['num_timesteps']])
  irl_evaluators.visualize_disc(
      params=params, disc=disc, num_grid=25, fig=fig, axs=axs)
  plt.show()

extra_loss_fns = dict(disc_loss=disc.disc_loss_fn)
inference_fn, params, _ = train_fn(
    seed=seed,
    environment_fn=env_fn,
    progress_fn=progress,
    extra_params=extra_params,
    extra_loss_fns=extra_loss_fns)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
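The training call above passes a discriminator loss through `extra_loss_fns`. As a hedged sketch of what such a loss typically looks like in adversarial imitation learning, the snippet below computes the standard binary cross-entropy that labels target samples as 1 and policy samples as 0. The function name `disc_loss` here is illustrative, and `disc.disc_loss_fn` may add other terms (for example the gradient penalty controlled by `gradient_penalty_weight`).

```python
import numpy as np


def disc_loss(target_logits, policy_logits):
  """Binary cross-entropy for a discriminator (illustrative only)."""
  # -log sigmoid(l) == softplus(-l): target samples should get high logits.
  loss_target = np.logaddexp(0.0, -target_logits).mean()
  # -log(1 - sigmoid(l)) == softplus(l): policy samples should get low logits.
  loss_policy = np.logaddexp(0.0, policy_logits).mean()
  return loss_target + loss_policy


# An uninformative discriminator (all logits zero) has loss 2 * log(2).
print(disc_loss(np.zeros(4), np.zeros(4)))
```

When the policy's state marginal matches the target, the best the discriminator can do is output logits near zero, so a `losses/disc_loss` curve drifting toward 2 log 2 is one rough sign that the two distributions are getting harder to tell apart.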

In [None]:
#@title Visualizing the learned inference function's state distribution in a 2D plot
num_samples = 10  # @param {type: 'integer'}
time_subsampling = 10  # @param {type: 'integer'}
time_last_n = 500 # @param {type: 'integer'}
eval_seed = 0  # @param {type: 'integer'}

metrics = irl_evaluators.estimate_energy_distance_metric(
    params=params,
    disc=disc,
    target_data=target_data,
    env_fn=eval_env_fn,
    inference_fn=inference_fn,
    num_samples=num_samples,
    time_subsampling=time_subsampling,
    time_last_n=time_last_n,
    visualize=True,
    figsize=(3.5, 3),
    seed=eval_seed,
    output_path=output_path,
)
print(metrics)
plt.show()
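The metric above is named after the energy distance between two sample sets. Below is a minimal sketch of the standard sample-based statistic, 2 E|X - Y| - E|X - X'| - E|Y - Y'|, using Euclidean norms. It is offered as background under that assumption; `irl_evaluators.estimate_energy_distance_metric` may subsample, normalize, or otherwise compute it differently, and some references take the square root of this quantity.

```python
import numpy as np


def mean_pairwise_dist(x, y):
  """Mean Euclidean distance over all pairs (x_i, y_j)."""
  diff = x[:, None, :] - y[None, :, :]
  return np.linalg.norm(diff, axis=-1).mean()


def energy_distance(x, y):
  """Sample estimate of 2 E|X-Y| - E|X-X'| - E|Y-Y'| (illustrative only)."""
  return (2.0 * mean_pairwise_dist(x, y)
          - mean_pairwise_dist(x, x)
          - mean_pairwise_dist(y, y))


rng = np.random.default_rng(0)
policy_samples = rng.normal(size=(100, 2))
target_samples = rng.normal(size=(100, 2)) + 3.0
print(energy_distance(policy_samples, target_samples))
```

The statistic is zero when both sample sets are identical and grows as they move apart, which makes it a convenient scalar for comparing runs that share the same task parameters.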

In [None]:
#@title Visualizing a trajectory of the learned inference function
eval_seed = 0  # @param {'type': 'integer'}

env, states = evaluators.visualize_env(
    env_fn=eval_env_fn,
    inference_fn=inference_fn,
    params=params,
    batch_size=0,
    seed=eval_seed,
    step_args=(params['normalizer'], params['extra']),
    output_path=output_path,
)
HTML(html.render(env.sys, [state.qp for state in states]))