# Training Adversarial Inverse RL and State Marginal Matching Agents in Brax

In [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) we tried out [gym](https://gym.openai.com/)-like environments and PPO, SAC, evolutionary search, and trajectory optimization algorithms. We can build various RL algorithms on top of these ultra-fast implementations. This colab runs a family of [adversarial inverse RL](https://arxiv.org/abs/1911.02256) algorithms, which includes [GAIL](https://papers.nips.cc/paper/2016/hash/cc7e2b878868cbae992d1fb743995d8f-Abstract.html) and [AIRL](https://arxiv.org/abs/1710.11248) as special cases. These algorithms minimize D(p(s,a), p\*(s,a)) or D(p(s), p\*(s)), the divergence D between the policy's state(-action) marginal distribution p(s,a) or p(s), and a given target distribution p\*(s,a) or p\*(s). As discussed in [f-MAX](https://arxiv.org/abs/1911.02256), these algorithms could also be used for [state-marginal matching](https://arxiv.org/abs/1906.05274) RL besides imitation learning. Let's try them out!

This provides a bare-bones implementation based on minimal modifications to the
baseline [PPO](https://github.com/google/brax/blob/main/brax/training/ppo.py),
enabling training in a few minutes. More features, tuning, and benchmarked results will be added soon, including:  
* Support for training a mixture of policies 
* Examples for imitation learning



```
# This is formatted as code
```

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/brax/blob/main/notebooks/braxlines/irl_smm.ipynb)

In [None]:
#@title Install Brax and some helper modules
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change runtime type, then select 'TPU' in the dropdown.

from datetime import datetime
import functools
import os

import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output 
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

import tensorflow_probability as tfp
from brax.io import html
from brax.experimental.composer import composer
from brax.experimental.composer import observers
from brax.experimental.braxlines.training import ppo
from brax.experimental.braxlines.irl_smm import utils as irl_utils

# Rebind `tfp` to its JAX substrate and alias the distributions module, so that
# `tfd.*` refers to the distributions from `tfp.substrates.jax`.
tfp = tfp.substrates.jax
tfd = tfp.distributions

def visualize(sys, qps, save_path: str = None):
  """Builds an inline HTML view of a trajectory, optionally saving it to disk.

  Args:
    sys: system object forwarded unchanged to `html.save_html` / `html.render`
      (callers in this notebook pass `env.sys`).
    qps: sequence of states to render (callers pass lists of `state.qp`).
    save_path: when truthy, the page is additionally written to this path via
      `html.save_html(..., make_dir=True)`; when falsy nothing is saved.

  Returns:
    An `IPython.display.HTML` object wrapping the output of `html.render`.
  """
  should_save = bool(save_path)
  if should_save:
    html.save_html(save_path, sys, qps, make_dir=True)
  page = html.render(sys, qps)
  return HTML(page)

# Only when the COLAB_TPU_ADDR environment variable exists (presumably set by
# Colab TPU runtimes -- confirm) do we import and run the JAX TPU setup helper.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [None]:
#@title Visualizing pre-included Brax environments
env_name = "ant"  # @param ['ant', 'halfcheetah', 'ant_cheetah', 'uni_ant', 'bi_ant']
env_space = "vel" # @param ['vel', 'pos', 'ang']
exp_name = "smm_multimode"  # @param ['smm', 'smm_multimode', 'smm_multimode3', 'smm_maxent']
algo_name = "gail2"  # @param ['gail', 'airl', 'gail2', 'fairl']
disc_arch = "mlp"  # @param ['linear', 'mlp']
logits_clip_range = 10.0# @param {'type': 'number'}
normalize_obs_for_disc = False # @param {'type': 'boolean'}
balance_data_for_disc = True # @param {'type': 'boolean'}

# Observation entries fed to the discriminator, keyed by env_space and then by
# env_name. Entries are either integer observation indices or (name, index)
# pairs (presumably resolved by the composer's named observers -- confirm).
env_indices_table = {
    'vel': { # x-y velocity
      'ant': (13, 14),
      'humanoid': (22, 23),
      'halfcheetah': (11,),
      'uni_ant': (('vel:torso_ant1', 0),('vel:torso_ant1', 1)),
      'bi_ant': (('vel:torso_ant1', 0),('vel:torso_ant2', 0)),
    },
    'ang': { # angular velocity
      'ant': (17,),
      'uni_ant': (('ang:torso_ant1', 2),),
    },
}
# Not every combination offered by the form above has an entry in the table
# (e.g. env_space='pos'). Fail with an explicit message instead of a bare
# KeyError that only names the missing key.
try:
  env_indices = env_indices_table[env_space][env_name]
except KeyError as err:
  raise NotImplementedError(
      f'No observation indices are defined for env_space={env_space!r}, '
      f'env_name={env_name!r}.') from err
base_env_fn = composer.create_fn(env_name=env_name)
base_env = base_env_fn()

# `disc_arch` is rebound from its string name to the tuple handed to the
# discriminator as `arch` (presumably hidden-layer widths -- confirm).
disc_arch = {'linear': (), 'mlp': (32, 32)}[disc_arch]

# Only the experiments listed here are wired up; anything else is rejected.
if exp_name not in ['smm', 'smm_multimode', 'smm_multimode3', 'smm_maxent']:
  raise NotImplementedError(exp_name)
disc_fn = functools.partial(
    irl_utils.IRLDiscriminator,
    input_size=len(env_indices),
    obs_indices=env_indices,
    include_action=False,
    arch=disc_arch,
    logits_clip_range=logits_clip_range,
)
disc = disc_fn(
    reward_type=algo_name,
    normalize_obs=normalize_obs_for_disc,
    balance_data=balance_data_for_disc,
    env=base_env,
)
# Discriminator parameters, later passed to training as `extra_params`.
extra_params = disc.init_model(rng=jax.random.PRNGKey(seed=0))
env_fn = irl_utils.create_fn(env_name=env_name, disc=disc)
env = env_fn()

# Visualize in 3D
# `env` was already constructed from `env_fn` just above, so it is reused here
# rather than building a second, identical environment.
jit_env_reset = jax.jit(env.reset)
state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))
clear_output()  # clear out jax.lax warning before rendering
# Last expression of the cell, so the returned HTML object is displayed.
visualize(env.sys, [state.qp])

In [None]:
#@title Generate and visualize target data distribution p\*(s, a) or p\*(s)
# Sample count passed as `N` to every `draw_2d_uniform` call in this cell.
N =  250# @param{type: 'integer'}

def draw_2d_uniform(rng, N, low, high):
  """Samples from a uniform distribution bounded by `low` and `high`.

  Args:
    rng: JAX PRNG key; it is split so one half can be returned for later use.
    N: value forwarded as `sample_shape` to the distribution's `sample`.
    low: lower bounds, converted with `jnp.array`.
    high: upper bounds, converted with `jnp.array`.

  Returns:
    A `(rng, samples)` pair: the carried-over key and the drawn samples
    (presumably of shape [N, len(low)] -- confirm).
  """
  next_rng, sample_key = jax.random.split(rng)
  lower = jnp.array(low)
  upper = jnp.array(high)
  samples = tfd.Uniform(low=lower, high=upper).sample(
      sample_shape=N, seed=sample_key)
  return next_rng, samples

rng = jax.random.PRNGKey(seed=0)

# Each experiment's target is described by one or more (low, high) bound pairs;
# N points are drawn from every pair, in the order listed.
target_boxes = {
    'smm': [
        ([-6.,-0.5], [-4.,0.5]),
    ],
    'smm_multimode': [
        ([-6.,-0.5], [-4.,0.5]),
        ([4.,-0.5], [6.,0.5]),
    ],
    'smm_multimode3': [
        ([-2.,-2.], [-1.,-1.]),
        ([1.,-2.], [2.,-1.]),
        ([-0.5,1.], [0.5,2.]),
    ],
    'smm_maxent': [
        ([-6.,-6.], [6.,6.]),
    ],
}
if exp_name not in target_boxes:
  raise NotImplementedError(exp_name)

box_samples = []
for box_low, box_high in target_boxes[exp_name]:
  # Thread the key through so every box is sampled with a fresh subkey.
  rng, box_sample = draw_2d_uniform(rng, N=N, low=box_low, high=box_high)
  box_samples.append(box_sample)
target_data_2d = jnp.concatenate(box_samples, axis=0)
# Keep only as many leading columns as there are discriminator indices.
target_data = target_data_2d[..., :len(env_indices)]

disc.set_target_data(target_data)

print(f'target_data={target_data.shape}')
# Symmetric axis limits: largest absolute coordinate plus a 0.5 margin.
lim = jnp.max(jnp.abs(target_data_2d)) + 0.5
# The colour is given as a single-row 2D array: a flat 3-element sequence is
# ambiguous in `scatter` (it can be read as per-point values to colormap,
# which silently happens when there are exactly 3 points).
plt.scatter(x=target_data_2d[:, 0],
            y=target_data_2d[:, 1],
            c=[[0., 0., 1.]])
plt.xlim((-lim, lim))
plt.ylim((-lim, lim))
plt.title('target (e.g. x-y velocities)')
plt.show()

In [None]:
#@title Training some pre-included Brax environments
num_timesteps_multiplier =  2# @param {type: 'integer'}

# We determined some reasonable hyperparameters offline and share them here.
n = num_timesteps_multiplier
ppo_hparams = dict(
    num_timesteps=50_000_000*n,
    log_frequency=20,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_update_epochs=4,
    discounting=0.95,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=2048,
    batch_size=1024,
)
train_fn = functools.partial(ppo.train, **ppo_hparams)

# Wall-clock timestamps: one taken now, more appended by `progress`.
times = [datetime.now()]
# Metric name -> {'x': [...], 'y': [...]} history, filled in by `progress`.
plotdata = {}
# Metrics that get their own learning-curve panel.
plotkeys = [
    'eval/episode_reward',
    'losses/disc_loss',
    'losses/total_loss',
    'losses/policy_loss',
    'losses/value_loss',
    'losses/entropy_loss',
]
# 25x25 lattice over [-6.5, 6.5]^2, flattened to rows of (x, y) coordinates.
grid = jnp.linspace(-6.5, 6.5, 25)
xgrid, ygrid = jnp.meshgrid(grid, grid)
datagrid = jnp.stack([xgrid.reshape(-1), ygrid.reshape(-1)], axis=-1)
lim = jnp.max(jnp.abs(datagrid)) + 0.5

def progress(num_steps, metrics, optimizer_params):
  """Training callback: records metrics and redraws the live dashboard.

  Appends a timestamp to the module-level `times`, stores every reported
  metric in the module-level `plotdata`, then draws one learning-curve panel
  per entry of `plotkeys` plus a panel showing the discriminator output over
  `datagrid`.

  Args:
    num_steps: x-coordinate recorded for this report (plotted as
      '# environment steps').
    metrics: mapping from metric name to value.
    optimizer_params: mapping whose 'extra' entry is passed to `disc.dist` as
      the discriminator parameters.
  """
  times.append(datetime.now())
  for key, value in metrics.items():
    series = plotdata.setdefault(key, dict(x=[], y=[]))
    series['x'].append(num_steps)
    series['y'].append(value)
  clear_output(wait=True)
  num_figs = len(plotkeys) + 1
  fig, axs = plt.subplots(ncols=num_figs, figsize=(3.5*num_figs, 3))
  # plot learning curves
  for i, key in enumerate(plotkeys):
    if key in plotdata:
      axs[i].plot(plotdata[key]['x'], plotdata[key]['y'])
    axs[i].set(xlabel='# environment steps', ylabel=key)
    axs[i].set_xlim([0, train_fn.keywords['num_timesteps']])
  # plot discriminator visualization
  distgrid = disc.dist(datagrid[..., :len(env_indices)], params=optimizer_params['extra'])
  probsgrid = jax.nn.sigmoid(distgrid.logits)
  print(f'disc probs: max={probsgrid.max()}, min={probsgrid.min()}')
  # Red channel grows as the probability drops below 0.5, blue as it rises
  # above 0.5; negative channel values are floored at 0. `jnp.maximum` is used
  # rather than `jnp.clip(..., a_min=0)` because the `a_min` keyword is
  # deprecated in newer JAX releases; the result is identical.
  colors = jnp.maximum(jnp.array([[-2, 0, 2]]) * (probsgrid-0.5), 0)
  axs[-1].scatter(x=datagrid[:, 0],
                  y=datagrid[:, 1],
                  c=colors)
  axs[-1].set_xlim((-lim, lim))
  axs[-1].set_ylim((-lim, lim))
  axs[-1].set(title='discriminator output (red=0, black=0.5, blue=1)')
  fig.tight_layout()
  plt.show()
  # A new figure is created on every call; release it explicitly so repeated
  # callbacks do not accumulate open figures on backends that keep them alive.
  plt.close(fig)

# Pass the discriminator loss to the trainer under the name 'disc_loss'.
extra_loss_fns = {'disc_loss': disc.disc_loss_fn}
inference_fn, params, _ = train_fn(
    environment_fn=env_fn,
    progress_fn=progress,
    extra_params=extra_params,
    extra_loss_fns=extra_loss_fns,
)

# times[0] was taken before training started; later entries are appended by
# the `progress` callback.
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
#@title Visualizing a trajectory of the learned inference function
seed = 0  # @param {'type': 'integer'}
save_path = None # @param {'type': 'raw'}
# A truthy path may contain {date}, {env_space}, {env_name}, {exp_name} and
# {algo_name} placeholders; a falsy path is left untouched.
if save_path:
  save_path = save_path.format(
      date=datetime.now().strftime('%Y%m%d'),
      env_space=env_space,
      env_name=env_name,
      exp_name=exp_name,
      algo_name=algo_name)

jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)
qps = []
states = []
state = env.reset(rng=jax.random.PRNGKey(seed=seed))
# Roll the policy forward until the environment reports the episode is done,
# recording each visited state (and its `qp`) before stepping.
while not state.done:
  qps.append(state.qp)
  states.append(state)
  act = jit_inference_fn(params, state.obs, state.rng)
  state = jit_env_step(state, act, params[0], params[-1])
# Last expression of the cell, so the returned HTML object is displayed.
visualize(env.sys, qps, save_path=save_path)

In [None]:
#@title Visualizing skills of the learned inference function in 2D plot
import numpy as np
from itertools import product
num_samples = 5  # @param {type: 'integer'}
time_subsampling = 10  # @param {type: 'integer'}
time_last_n = 500 # @param {type: 'integer'}
seed = 0  # @param {type: 'integer'}

# Reset and run environment
batch_env = env_fn(batch_size=num_samples)
# One PRNG key per batch entry, seeded with consecutive integers.
reset_keys = [jax.random.PRNGKey(seed+i) for i in range(num_samples)]
state = batch_env.reset(jnp.array(reset_keys))
states = [state]
jit_step = jax.jit(batch_env.step)
jit_inference_fn = jax.jit(inference_fn)
# Step until every entry in the batch reports done.
while not state.done.all():
  act = jit_inference_fn(params, state.obs, state.rng[0])
  state = jit_step(state, act, params[0], params[-1])
  states.append(state)

# Get indices of interest
obses_full = jnp.stack([s.obs for s in states], axis=0)
# Keep the last `time_last_n` steps, then every `time_subsampling`-th of those.
recent_obses = obses_full[-time_last_n:]
obses = recent_obses[::time_subsampling]
env_vars = batch_env.disc.index_obs(obses) # [T, num_samples, #env_indices]
target_vars = target_data
if env_vars.shape[-1] == 1:
  # Append an all-zero second coordinate so the data can be plotted in 2D.
  env_vars = jnp.concatenate([env_vars, jnp.zeros(env_vars.shape)], axis=-1)
  target_vars = jnp.concatenate([target_vars, jnp.zeros(target_vars.shape)], axis=-1)
print(f'env_vars.shape={env_vars.shape}')
print(f'target_vars.shape={target_vars.shape}')
# Collapse all leading axes into rows of two coordinates.
env_vars_flat = env_vars.reshape(-1, 2)
target_vars_flat = target_vars.reshape(-1, 2)

# Plot
# Both panels share symmetric limits: the largest absolute coordinate across
# agent and target points, plus a 0.5 margin.
lim = jnp.max(jnp.abs(
    jnp.concatenate([env_vars_flat, target_vars_flat], axis=0))) + 0.5
fig, axs = plt.subplots(ncols=2, figsize=(3.5*3, 3))
fig.tight_layout()
# Colours are given as single-row 2D arrays: a flat 3-element sequence is
# ambiguous in `scatter` and is treated as per-point colormap values when
# there happen to be exactly 3 points.
axs[0].set(title='agent policy')
axs[0].set_xlim([-lim, lim])
axs[0].set_ylim([-lim, lim])
axs[0].scatter(x=env_vars_flat[:, 0], y=env_vars_flat[:, 1], c=[[1., 0., 0.]], alpha=0.3)
axs[1].set(title='target')
axs[1].set_xlim([-lim, lim])
axs[1].set_ylim([-lim, lim])
axs[1].scatter(x=target_vars_flat[:, 0], y=target_vars_flat[:, 1], c=[[0., 0., 1.]], alpha=0.3)
plt.show()