Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

### **Discrete tabular experiments for evaluating imitation learning algorithms**

In [None]:
# @title Setup and define MDPs

# Experiment switches, read further down in this cell when the MDP is built.
# True: all initial probability mass sits on state 0; False: the initial state
# distribution is uniform over every state.
SPARSE_INITIAL_STATES = False
# True: four states give reward 1 and are absorbing; False: a sparse
# positive/negative reward with a single absorbing state is used.
EASY_REWARD = True
# Small constant added inside jnp.log(...) calls throughout the notebook.
EPS = 1e-6  # avoid nans

import jax
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

%matplotlib inline

def softmax(values, temperature, prior_policy):
  """Returns the prior-weighted Boltzmann distribution along axis 1.

  Each row of the result is proportional to
  `prior_policy * exp(values / temperature)` and sums to one.

  Args:
    values: 2-D array of scores; the softmax is taken over axis 1.
    temperature: scalar dividing `values` before exponentiation.
    prior_policy: array broadcastable to `values`; its log is added to the
      scaled scores.

  Returns:
    Array with the same shape as the broadcast inputs.
  """
  scaled_scores = values / temperature
  logits = scaled_scores + jnp.log(prior_policy)
  return jax.nn.softmax(logits, axis=1)

def soft_policy_iteration(reward, transition, initial,  prior_policy, alpha, gamma, n_iters=1000, pi_init=None, q_init=None):
  """Runs a fixed number of KL-regularised (soft) policy-iteration sweeps.

  Args:
    reward: array indexed [state, action].
    transition: array indexed [state, action, next_state].
    initial: weights over states; the `initial`-weighted value is added to
      every Q entry with weight `1 - gamma`.
    prior_policy: array indexed [state, action], the reference policy of the
      KL penalty and of the softmax improvement step.
    alpha: weight of the KL penalty and temperature of the softmax.
    gamma: discount factor.
    n_iters: number of sweeps.
    pi_init: optional starting policy; uniform over actions when None.
    q_init: optional starting Q-function; zeros when None.

  Returns:
    Tuple `(pi, q, v)` from the last sweep. `v` is evaluated on the policy and
    Q-function that entered that sweep.
  """
  n_states, n_actions, _ = transition.shape
  pi = pi_init if pi_init is not None else jnp.ones(((n_states, n_actions))) / n_actions
  q = q_init if q_init is not None else jnp.zeros((n_states, n_actions))

  @jax.jit
  def sweep(policy, q_values):
    # Negative KL(policy || prior) for every state.
    neg_kl = jnp.sum(policy * (jnp.log(prior_policy) - jnp.log(EPS + policy)), axis=1)
    values = jnp.sum(q_values * policy, axis=1) + alpha * neg_kl
    # Broadcasting `values` over the trailing next-state axis gives the
    # expected next-state value of every (state, action) pair.
    expected_next = jnp.sum(transition * values[None, None, :], 2)
    new_q = reward + gamma * expected_next + (1-gamma) * jnp.sum(initial * values)
    new_policy = softmax(new_q, alpha, prior_policy)
    return new_policy, new_q, values

  for _ in tqdm(range(n_iters)):
    pi, q, v = sweep(pi, q)
  return pi, q, v

def stationary_distribution(initial, transition, policy, discount, n_iters=1000):
  """Iterates the discounted state-distribution recursion until it settles.

  Every step pushes the current distribution through `policy` and
  `transition`, mixes the result with `initial` using weights `discount` and
  `1 - discount`, and renormalises it to sum to one.

  Args:
    initial: 1-D distribution over states; also the starting iterate.
    transition: array indexed [state, action, next_state].
    policy: array indexed [state, action].
    discount: mixing weight of the propagated distribution.
    n_iters: maximum number of iterations.

  Returns:
    The last iterate. Prints "Didn't converge!" if the largest absolute
    change never dropped below 1e-4 within `n_iters` iterations.
  """
  @jax.jit
  def propagate(current):
    pushed = jnp.einsum("ijk,ij,i->k", transition, policy, current)
    updated = discount * pushed + (1 - discount) * initial
    updated = updated / updated.sum()
    change = jnp.abs(updated - current).max()
    return updated, change

  distribution = initial
  converged = False
  for _ in range(n_iters):
    distribution, delta = propagate(distribution)
    if delta < 1e-4:
      converged = True
      break
  if not converged:
    print("Didn't converge!")
  return distribution

def simulate(initial, transition, policy, absorbing_states, n_episodes):
  """Rolls out the greedy version of `policy` and records every visit.

  The action taken in a state is `argmax(policy[state])`. When the current
  state is in `absorbing_states` the episode counter is incremented and the
  next state is drawn from `initial`; otherwise it is drawn from
  `transition[state, action, :]`. The PRNG key is fixed to `PRNGKey(0)`, so
  repeated calls with the same arguments return the same data.

  Args:
    initial: 1-D distribution over states used for (re)starts.
    transition: array indexed [state, action, next_state].
    policy: array indexed [state, action].
    absorbing_states: container of state indices that end an episode.
    n_episodes: number of episodes to complete.

  Returns:
    Array of shape (t, d_state, d_action), where t is the number of simulated
    steps; slice `[k]` holds a single 1 at the (state, action) pair visited at
    step k. At most 5000 steps are simulated; if that cap is reached before
    `n_episodes` episodes have finished, a message is printed and the data
    collected so far is returned.
  """
  t_max = 5_000
  key = jax.random.PRNGKey(0)
  d_state, d_action = policy.shape
  state_ids = jnp.arange(d_state)
  experience = np.zeros((t_max, d_state, d_action))
  state = jax.random.choice(key, state_ids, p=initial)
  n, t = 0, 0
  # `t < t_max` keeps every write inside the preallocated buffer; without it a
  # long rollout raised IndexError once the buffer was full.
  while n < n_episodes and t < t_max:
    _, key = jax.random.split(key)
    # act greedy
    action = jnp.argmax(policy[state])
    experience[t, state, action] += 1
    t += 1
    # Exactly one draw is made per step with this step's key. Sampling is a
    # pure function of the key, so skipping the draw whose result would be
    # discarded leaves the trajectory unchanged.
    if state in absorbing_states:
      state = jax.random.choice(key, state_ids, p=initial)
      n += 1
    else:
      state = jax.random.choice(key, state_ids, p=transition[state, action, :])
  if n < n_episodes:
    print(f"simulate stopped after {t_max} steps with {n}/{n_episodes} episodes finished.")
  return experience[:t, :, :]

# Grid-world dimensions: a width x width grid with four actions per state.
d_action = 4
width = 10
d_state = width * width
discount = 0.95
gamma = discount

# Reward table indexed [state, action]; every action in a rewarded state
# receives the same value.
reward = jnp.zeros((d_state, d_action))
if not EASY_REWARD:  # sparse positive and negative
  sparse_magnitude = 1. / (1 - discount)
  reward = reward.at[-width, :].set(sparse_magnitude)
  reward = reward.at[3 * width, :].set(-sparse_magnitude)
  absorbing_state = [d_state - width,]
else:
  # Four rewarded states, each of which is also absorbing.
  absorbing_state = [multiple * width // 4 for multiple in (13, 15, 29, 31)]
  for s in absorbing_state:
    reward = reward.at[s, :].set(1.)

def policy_eval(eval_policy, state_dist):
  """Returns the expected reward of a policy under a state distribution.

  Args:
    eval_policy: array indexed [state, action].
    state_dist: 1-D array of state probabilities.

  Returns:
    Scalar: the notebook-level `reward` table summed against the state-action
    occupancy `state_dist[s] * eval_policy[s, a]`.
  """
  # Scaling each policy row by its state probability equals
  # `jnp.diag(state_dist) @ eval_policy` without building a dense
  # (d_state x d_state) diagonal matrix.
  occupancy = state_dist[:, None] * eval_policy
  return (reward * occupancy).sum()

def make_dynamics(width, d_state, d_action, initial_idxs, windy):
  """Builds the grid-world transition tensor.

  State index `i` maps to grid coordinates `(i % width, i // width)`. Actions
  0..3 move to x+1, x-1, y+1 and y-1 respectively, with coordinates clamped
  to the grid. States listed in the notebook-level `absorbing_state` instead
  transition to `initial_idxs` for every action.

  Args:
    width: side length of the square grid.
    d_state: number of states.
    d_action: number of actions.
    initial_idxs: indices of the states that absorbing states transition to.
    windy: if True, every action in every state except `d_state - width` also
      gets a transition to the x+1 neighbour before normalisation.

  Returns:
    Array of shape (d_state, d_action, d_state), normalised to sum to one over
    the last axis.
  """
  transition_matrix = jnp.zeros((d_state, d_action, d_state))

  def clamp(coord):
    return max(min(coord, width - 1), 0)

  def idx_to_coord(idx):
    # Uses its own argument; the previous lambda read the enclosing loop
    # variable `i` and only worked because it was always called with `i`.
    return idx % width, idx // width

  def coord_to_idx(x, y):
    return clamp(y) * width + clamp(x)

  for i in range(d_state):
    x, y = idx_to_coord(i)
    is_absorbing = i in absorbing_state
    next_states = [coord_to_idx(nx, ny)
                   for nx, ny in ((x+1, y), (x-1, y), (x, y+1), (x, y-1))]
    for j, k in zip(range(d_action), next_states):
      if is_absorbing:
        transition_matrix = transition_matrix.at[i, j, initial_idxs].set(1.)
      else:
        transition_matrix = transition_matrix.at[i, j, k].set(1.)
    if windy and i != (d_state - width):
      disturance_idx = coord_to_idx(x+1, y)
      for j in range(d_action):
        transition_matrix = transition_matrix.at[i, j, disturance_idx].set(1.0)
  transition_matrix = transition_matrix / transition_matrix.sum(axis=2, keepdims=True)
  return transition_matrix

# Initial state distribution and the state indices that absorbing states
# transition to.
if not SPARSE_INITIAL_STATES:
  initial_idxs = jnp.arange(d_state)
  initial = jnp.ones((d_state))
else:
  initial_idxs = jnp.zeros((1,)).astype(jnp.int32)
  initial = jnp.zeros((d_state)).at[0].set(1.)

initial = initial / initial.sum()

# Nominal dynamics first, then the windy variant.
transition_matrix, windy_transition_matrix = (
    make_dynamics(width, d_state, d_action, initial_idxs, windy=is_windy)
    for is_windy in (False, True))

In [None]:
# @title Expert (SPI)

# Solve for the soft expert under a uniform prior policy.
expert_alpha = 0.1
prior_policy = jnp.ones((d_state, d_action)) / d_action
expert_pi, expert_q, expert_v = soft_policy_iteration(
    reward, transition_matrix, initial, prior_policy,
    alpha=expert_alpha, gamma=discount)

# Expert state distribution under the nominal and the windy dynamics.
expert_rho, expert_windy_rho = (
    stationary_distribution(initial, dynamics, expert_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# Demonstration data: per-step visit indicators, summed into state-action
# counts and normalised into an empirical distribution.
experience = simulate(initial, transition_matrix, expert_pi, absorbing_state, n_episodes=100)
expert_counts = experience.sum(0)
expert_distribution = expert_counts / expert_counts.sum(keepdims=True)

# Six-panel overview of the task and the expert. Each panel is an image with
# its own colorbar.
fig, ax = plt.subplots(1, 6, figsize=(40, 10))
# Reward table, shown as a (state x action) image.
ax[0].set_title("Reward")
r = ax[0].imshow(reward)
plt.colorbar(r, ax=ax[0])

# Per-state quantities are reshaped to the (width, width) grid for display.
ax[1].set_title("Initial state distribution")
mu_ = ax[1].imshow(initial.reshape((width, width)))
plt.colorbar(mu_, ax=ax[1])

ax[2].set_title("Expert Soft Value Function")
v_ = ax[2].imshow(expert_v.reshape((width, width)))
plt.colorbar(v_, ax=ax[2])

ax[3].set_title("Expert Stationary Distribution")
rho_ = ax[3].imshow(expert_rho.reshape((width, width)))
plt.colorbar(rho_, ax=ax[3])

# NOTE: from here on `fig` is rebound to the image returned by imshow, not
# the Figure created above.
ax_ = ax[4]
ax_.set_title("Expert Stationary Distribution (Windy)")
fig = ax_.imshow(expert_windy_rho.reshape((width, width)))
plt.colorbar(fig, ax=ax_)

# Visit counts of the simulated dataset, summed over time steps and actions.
ax[5].set_title("Dataset Histogram")
rho_ = ax[5].imshow(experience.sum(0).sum(1).reshape((width, width)))
plt.colorbar(rho_, ax=ax[5])

In [None]:
# @title Behavioural Cloning (BC)

def fit_policy(data, prior):
  """Fits a tabular policy from visit data and a prior policy.

  The state-action counts (`data` summed over its first axis) are added to the
  log-prior and the result is normalised with a softmax over axis 1, so each
  row is proportional to `prior * exp(counts)`.

  Args:
    data: array whose first axis is summed out, leaving [state, action] counts.
    prior: array indexed [state, action] with strictly positive entries.

  Returns:
    Array indexed [state, action] whose rows sum to one.
  """
  counts = data.sum(0)
  logits = counts + jnp.log(prior)
  log_normaliser = logsumexp(logits, axis=1, keepdims=True)
  return jnp.exp(logits - log_normaliser)

# EPS keeps every BC probability strictly positive.
bc_pi = fit_policy(experience, prior_policy) + EPS
bc_rho, bc_windy_rho = (
    stationary_distribution(initial, dynamics, bc_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# One panel per dynamics; `fig` ends up bound to the last image drawn.
fig, ax = plt.subplots(1, 2, figsize=(40, 10))
panels = (
    ("BC Stationary Distribution", bc_rho),
    ("BC Stationary Distribution (Windy)", bc_windy_rho),
)
for ax_, (panel_title, panel_rho) in zip(ax, panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(panel_rho.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)


In [None]:
#@title Coherent Soft Imitation Learning (CSIL)
# CSIL reward: the log-ratio of the BC policy to the prior, scaled by the
# alpha in force when the reward is built.
log_ratio = jnp.log(bc_pi) - jnp.log(prior_policy)
csil_alpha = 1.
csil_reward = csil_alpha  * log_ratio

# The policy is then solved with a different alpha, warm-started from the BC
# policy and with the reward as the initial Q-function.
csil_alpha = 1e-3
csil_pi, csil_q, csil_v = soft_policy_iteration(
    csil_reward, transition_matrix, initial,  prior_policy,
    alpha=csil_alpha, gamma=discount, pi_init=bc_pi, q_init=csil_reward)
rho_pcirl, windy_rho_pcirl = (
    stationary_distribution(initial, dynamics, csil_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

csil_R_det, csil_R_windy = (
    policy_eval(csil_pi, rho) for rho in (rho_pcirl, windy_rho_pcirl))

# Figure 1: per-state CSIL quantities laid out on the grid.
fig, ax = plt.subplots(1, 3, figsize=(40, 10))
grid_panels = (
    ("CSIL Soft Value Function", csil_v),
    ("CSIL Stationary Distribution", rho_pcirl),
    ("CSIL Stationary Distribution (Windy)", windy_rho_pcirl),
)
for ax_, (panel_title, per_state) in zip(ax, grid_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(per_state.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)

# Figure 2: (state x action) tables, shown without reshaping.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
table_panels = (
    ("Expert Policy", expert_pi),
    ("CSIL reward", csil_reward),
    ("BC Policy", bc_pi),
    ("CSIL Policy", csil_pi),
)
for ax_, (panel_title, table) in zip(ax, table_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(table)
  plt.colorbar(fig, ax=ax_)

In [None]:
#@title Offline Classifier-style IRL (i.e. ORIL, SQIL, ...)
# Reward of 1 for every (state, action) pair that appears in the data.
visited = experience.sum(0) > 0.
classifier_reward = jnp.array(visited).astype(np.float32)
new_alpha = 0.01

classifier_pi, classifier_q, classifier_v = soft_policy_iteration(
    classifier_reward, transition_matrix, initial, prior_policy,
    alpha=new_alpha, gamma=discount, pi_init=bc_pi)
classifier_rho, classifier_windy_rho = (
    stationary_distribution(initial, dynamics, classifier_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# The reward is shown as a (state x action) table; per-state quantities are
# reshaped to the grid.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
panels = (
    ("Classifier-based Reward", classifier_reward),
    ("Classifier-based IRL Stationary Distribution",
     classifier_rho.reshape((width, width))),
    ("Classifier-based IRL Stationary Distribution (Windy)",
     classifier_windy_rho.reshape((width, width))),
    ("IRL Value Function", classifier_v.reshape((width, width))),
)
for ax_, (panel_title, panel_image) in zip(ax, panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(panel_image)
  plt.colorbar(fig, ax=ax_)

In [None]:
#@title Maximum Entropy IRL
def max_ent_irl(demonstratioon_distribution, transition, initial,  prior_policy, alpha, gamma, n_iters=500, r_lr=1):
  """Fits a reward by matching the demonstration state-action distribution.

  Each iteration solves for a soft policy under the current reward, computes
  the state-action distribution that policy induces, and adds `r_lr` times
  the difference (demonstration minus induced) to the reward. The loop stops
  early once the mean squared difference changes by a relative amount of at
  most 1e-5 between consecutive iterations.

  Args:
    demonstratioon_distribution: array indexed [state, action]; it is
      normalised to sum to one here.
    transition: array indexed [state, action, next_state].
    initial: initial state distribution.
    prior_policy: array indexed [state, action].
    alpha: temperature passed to `soft_policy_iteration`.
    gamma: discount factor, used for both the policy solve and the induced
      state distribution.
    n_iters: maximum number of reward updates.
    r_lr: step size of the reward update.

  Returns:
    Tuple `(pi, q, v, reward)`, where `pi`, `q`, `v` come from the last policy
    solve.
  """
  demonstratioon_distribution = demonstratioon_distribution / demonstratioon_distribution.sum()
  d_state, d_action, _ = transition.shape
  meirl_reward = jnp.zeros((d_state, d_action))
  errors = []
  for _ in range(n_iters):
    pi_irl, q, v = soft_policy_iteration(meirl_reward, transition, initial, prior_policy, alpha=alpha, gamma=gamma, n_iters=100)
    # Use the `gamma` argument; the notebook-level `discount` was read here
    # before, silently ignoring the caller's value.
    irl_rho = stationary_distribution(initial, transition, pi_irl, gamma, n_iters=100)
    irl_distribution = irl_rho[:, None] * pi_irl
    error = (demonstratioon_distribution - irl_distribution)
    violation = (error ** 2).mean()
    errors.append(violation)
    # Relative-change test written without a division, so a previous error of
    # exactly zero cannot produce a NaN.
    if len(errors) > 1 and abs(errors[-1] - errors[-2]) <= 1e-5 * abs(errors[-2]):
      break
    meirl_reward += r_lr * error
  return pi_irl, q, v, meirl_reward

# Fit ME-IRL on the expert data, then evaluate its state distribution under
# both dynamics.
alpha_meirl = 1.
meirl_pi, meirl_q, meirl_v, meirl_reward = max_ent_irl(
    expert_distribution, transition_matrix, initial, prior_policy,
    alpha_meirl, gamma)

meirl_rho, meirl_windy_rho = (
    stationary_distribution(initial, dynamics, meirl_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# Five panels, one per axis. The action-averaged reward is drawn on ax[4];
# it was previously drawn on ax[3], on top of the "True R" panel, leaving the
# fifth axis empty.
fig, ax = plt.subplots(1, 5, figsize=(40, 10))
panels = (
    ("ME-IRL Policy", meirl_pi),
    ("ME-IRL Q", meirl_q),
    ("ME-IRL R", meirl_reward),
    ("True R", reward),
    # Reward averaged over actions, laid out on the grid.
    ("ME-IRL R", meirl_reward.mean(1).reshape((width, width))),
)
for ax_, (panel_title, panel_image) in zip(ax, panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(panel_image)
  plt.colorbar(fig, ax=ax_)

# Per-state quantities on the grid; the expert distribution is summed over
# actions first.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
panels = (
    ("Expert Distribution", expert_distribution.sum(1)),
    ("ME-IRL Stationary Distribution", meirl_rho),
    ("ME-IRL IRL Stationary Distribution (Windy)", meirl_windy_rho),
    ("ME-IRL Value Function", meirl_v),
)
for ax_, (panel_title, per_state) in zip(ax, panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(per_state.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)

In [None]:
# @title Generative adversarial imitation learning (GAIL)

# 0 for (state, action) pairs present in the data and 1 elsewhere; `gail`
# reads this as its initial discriminator weights.
classifier_reward = 1. - jnp.array(experience.sum(0) > 0.).astype(np.float32)

@jax.jit
def discrimator_loss(weights, expert_distribution, policy_distribution):
  """Returns the discriminator objective that `gail` minimises.

  With `D = sigmoid(weights)`, the loss is the sum over all entries of
  `log(EPS + 1 - D) * expert_distribution + log(EPS + D) * policy_distribution`.
  """
  prob = jax.nn.sigmoid(weights)
  expert_term = jnp.log(EPS + 1. - prob) * expert_distribution
  policy_term = jnp.log(EPS + prob) * policy_distribution
  return (expert_term + policy_term).sum()

# Returns (loss, d loss / d weights).
discrimator_grad = jax.value_and_grad(discrimator_loss)

def gail(expert_distribution, transition, initial,  prior_policy, alpha, gamma, n_iters=100, r_lr=1):
  """Alternates soft policy solves with discriminator gradient steps.

  The reward is `log(EPS + sigmoid(weights))`. Each outer iteration solves a
  soft policy for the current reward, computes the state-action distribution
  it induces, takes 200 gradient steps on `discrimator_loss`, and prints the
  last loss value.

  Args:
    expert_distribution: array indexed [state, action]; normalised to sum to
      one here.
    transition: array indexed [state, action, next_state].
    initial: initial state distribution.
    prior_policy: array indexed [state, action].
    alpha: temperature passed to `soft_policy_iteration`.
    gamma: discount factor.
    n_iters: number of outer iterations.
    r_lr: step size of the discriminator updates.

  Returns:
    Tuple `(pi, q, v, reward)`; the policy is solved for the returned reward.
  """
  expert_distribution = expert_distribution / expert_distribution.sum()
  # The starting weights come from the notebook-level `classifier_reward`.
  discriminator_weights = classifier_reward
  for _ in range(n_iters):
    gail_reward = jnp.log(EPS + jax.nn.sigmoid(discriminator_weights))
    # `transition` and `gamma` are the arguments of this function; the
    # notebook globals `transition_matrix` and `discount` were used before.
    pi_irl, q, v = soft_policy_iteration(gail_reward, transition, initial, prior_policy, alpha=alpha, gamma=gamma, n_iters=100)
    pi_rho = stationary_distribution(initial, transition, pi_irl, gamma, n_iters=100)
    irl_distribution = pi_rho[:, None] * pi_irl
    for _step in range(200):
      r_loss, r_grad = discrimator_grad(discriminator_weights, expert_distribution, irl_distribution)
      discriminator_weights -= r_lr * r_grad
    print(r_loss)
  # Recompute the reward from the final weights so that the returned policy
  # and the returned reward belong together; the final solve previously
  # reused the reward from before the last discriminator update.
  gail_reward = jnp.log(EPS + jax.nn.sigmoid(discriminator_weights))
  pi, q, v = soft_policy_iteration(gail_reward, transition, initial, prior_policy, alpha=alpha, gamma=gamma, n_iters=100)
  return pi, q, v, gail_reward

# Run GAIL on the expert data, then evaluate its state distribution under
# both dynamics.
gail_alpha = 1e-3
gail_pi, gail_q, gail_v, gail_reward = gail(
    expert_distribution, transition_matrix, initial, prior_policy,
    gail_alpha, gamma)

gail_rho, gail_windy_rho = (
    stationary_distribution(initial, dynamics, gail_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# Figure 1: (state x action) tables plus the action-averaged reward on the
# grid.
fig, ax = plt.subplots(1, 5, figsize=(40, 10))
table_panels = (
    ("GAIL Policy", gail_pi),
    ("GAIL Q", gail_q),
    ("GAIL R", gail_reward),
    ("True R", reward),
    ("GAIL R", gail_reward.mean(1).reshape((width, width))),
)
for ax_, (panel_title, panel_image) in zip(ax, table_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(panel_image)
  plt.colorbar(fig, ax=ax_)

# Figure 2: per-state quantities on the grid.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
grid_panels = (
    ("Expert Distribution", expert_distribution.sum(1)),
    ("GAIL Stationary Distribution", gail_rho),
    ("GAIL Stationary Distribution (Windy)", gail_windy_rho),
    ("GAIL Value Function", gail_v),
)
for ax_, (panel_title, per_state) in zip(ax, grid_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(per_state.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)

In [None]:
# @title Inverse Soft Q Learning (IQLearn)
def convex_reg(x):
  """Regulariser applied to the implied reward: `x - 0.25 * x ** 2`."""
  return x - 0.25 * x ** 2

@jax.jit
def iq_learn_loss(q, expert_distribution, pi, transition, initial, prior_policy, discount, alpha):
  """Returns the negated IQ-Learn objective for a tabular Q-function.

  The soft value `v` of `pi` is treated as a constant (stop-gradient), so
  gradients flow only through the explicit `q` term of the implied reward
  `q - discount * E[v(next_state)]`.

  Args:
    q: array indexed [state, action]; the quantity being optimised.
    expert_distribution: array indexed [state, action].
    pi: current policy, array indexed [state, action].
    transition: array indexed [state, action, next_state].
    initial: initial state distribution.
    prior_policy: array indexed [state, action].
    discount: discount factor.
    alpha: weight of the KL term in the soft value.

  Returns:
    Scalar loss (the negated objective).
  """
  neg_kl = jnp.sum(pi * (jnp.log(prior_policy) - jnp.log(EPS + pi)), axis=1)
  v = jax.lax.stop_gradient(jnp.sum(q * pi, axis=1) + alpha * neg_kl)
  # Broadcasting `v` over the trailing next-state axis replaces a tile built
  # from the notebook globals `d_state`/`d_action`, and the `discount`
  # argument replaces the global `gamma` that was read here before.
  imp_reward = q - discount * jnp.sum(transition * v, axis=2)
  objective = (expert_distribution * convex_reg(imp_reward)).sum() - (1 - discount) * jnp.sum(initial * v)
  return -objective

# Gradient with respect to `q` (the first argument).
iq_learn_grad = jax.grad(iq_learn_loss)

def inverse_q_learning(expert_distribution, transition, initial,  prior_policy, alpha, gamma, n_iters=1000, q_lr=1e-2):
  """Runs tabular IQ-Learn with plain gradient descent on the Q-function.

  After every gradient step the policy is reset to the prior-weighted softmax
  of the Q-function. Up to five snapshots of the soft value function are
  drawn, evenly spaced over the run, into a new figure.

  Args:
    expert_distribution: array indexed [state, action].
    transition: array indexed [state, action, next_state].
    initial: initial state distribution.
    prior_policy: array indexed [state, action].
    alpha: temperature of the softmax and weight of the KL term.
    gamma: discount factor.
    n_iters: number of gradient steps.
    q_lr: step size.

  Returns:
    Tuple `(pi, q, v, implied_reward)`.
  """
  d_state, d_action, _ = transition.shape
  q_ = jnp.zeros((d_state, d_action))
  pi_ = jnp.ones(((d_state, d_action))) / d_action

  def soft_value(q_values, policy):
    policy_kl = jnp.sum(policy * (jnp.log(EPS + policy) - jnp.log(prior_policy)), axis=1)
    return jnp.sum(q_values * policy, axis=1) - alpha * policy_kl

  n_snapshots = 5
  fig, ax = plt.subplots(1, n_snapshots, figsize=(40, 10))
  # max(1, ...) avoids a zero stride (ZeroDivisionError) when n_iters < 5.
  d = max(1, n_iters // n_snapshots)
  for i in tqdm(range(n_iters)):
    # The bound on i // d keeps the axis index valid when n_iters is not a
    # multiple of the number of snapshots.
    if i % d == 0 and i // d < n_snapshots:
      ax_ = ax[i // d]
      ax_.set_title(i)
      # `width` is the notebook-level grid side length.
      fig = ax_.imshow(soft_value(q_, pi_).reshape((width, width)))
      plt.colorbar(fig, ax=ax_)
    # `transition` and `gamma` are the arguments of this function; the
    # notebook globals `transition_matrix` and `discount` were used before.
    q_grad = iq_learn_grad(q_, expert_distribution, pi_, transition, initial, prior_policy, gamma, alpha)
    q_ = q_ - q_lr * q_grad
    pi_ = softmax(q_, alpha, prior_policy)

  v_ = soft_value(q_, pi_)
  imp_reward = q_ - gamma * jnp.sum(transition * v_, axis=2)
  return pi_, q_, v_, imp_reward

# Run IQ-Learn on the expert data, then evaluate its state distribution under
# both dynamics.
iq_alpha = 1e-3
iq_pi, iq_q, iq_v, iq_r = inverse_q_learning(expert_distribution, transition_matrix, initial, prior_policy, iq_alpha, gamma)

iq_rho = stationary_distribution(initial, transition_matrix, iq_pi, discount)
iq_windy_rho = stationary_distribution(initial, windy_transition_matrix, iq_pi, discount)

# Figure 1: (state x action) tables. Five axes are created and the first four
# are drawn on.
fig, ax = plt.subplots(1, 5, figsize=(40, 10))
table_panels = (
    ("IQ-Learn Policy", iq_pi),
    ("IQ-Learn Q", iq_q),
    ("IQ-Learn R", iq_r),
    ("True Reward R", reward),
)
for ax_, (panel_title, table) in zip(ax, table_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(table)
  plt.colorbar(fig, ax=ax_)


# Figure 2: per-state quantities on the grid.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
grid_panels = (
    ("Expert Distribution", expert_distribution.sum(1)),
    ("IQ-Learn IRL Stationary Distribution", iq_rho),
    ("IQ-Learn IRL Stationary Distribution (Windy)", iq_windy_rho),
    ("IQ-Learn Value Function", iq_v),
)
for ax_, (panel_title, per_state) in zip(ax, grid_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(per_state.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)

In [None]:
# @title Proximal Point Imitation Learning (PPIL)

@jax.jit
def dual(q, r, expert_distribution, policy_distribution, pi, transition, initial, prior_policy, discount, alpha):
    expert_reward = (r * expert_distribution).sum()
    v = -alpha * jax.scipy.special.logsumexp(-q / alpha + jnp.log(EPS + prior_policy), axis=1)
    q_target = r + discount * jnp.sum(transition * jnp.tile(v, (d_state, d_action, 1)), 2)
    be = q_target - q
    w = jax.lax.stop_gradient(jax.nn.softmax(1e-2 * be))
    sbe = (policy_distribution * w * be).sum()
    dual = expert_reward - sbe - (1 - discount) * jnp.sum(initial * v)
    return -dual

dual_grad = jax.value_and_grad(dual, argnums=(0, 1))

def proximal_point_imitiation_learning(expert_distribution, transition, initial,  prior_policy, alpha, gamma, n_iters=250, lr=1e-2):
  """Runs tabular PPIL: 25 outer policy updates of `n_iters` dual steps each.

  Each outer round computes the state-action distribution of the current
  policy, takes `n_iters` gradient steps on `dual` with respect to the
  Q-function and the reward, and then sets the policy to the prior-weighted
  softmax of the Q-function.

  Args:
    expert_distribution: array indexed [state, action].
    transition: array indexed [state, action, next_state].
    initial: initial state distribution.
    prior_policy: array indexed [state, action]; the starting policy and the
      prior of the softmax policy update.
    alpha: temperature.
    gamma: discount factor.
    n_iters: number of gradient steps per outer round.
    lr: step size for both the Q-function and the reward.

  Returns:
    Tuple `(pi, q, v, r)`.
  """
  d_state, d_action, _ = transition.shape
  q_ = jnp.zeros((d_state, d_action))
  r_ = jnp.zeros((d_state, d_action))
  pi_ = prior_policy
  for _ in tqdm(range(25)):
    # `transition` and `gamma` are the arguments of this function; the
    # notebook globals `transition_matrix` and `discount` were used before.
    state_distribution = stationary_distribution(initial, transition, pi_, gamma, n_iters=10)
    policy_distribution = state_distribution[:, None] * pi_
    for _step in tqdm(range(n_iters)):
      # The current policy is passed both as `pi` and as the prior of the dual.
      _, (q_grad, r_grad) = dual_grad(q_, r_, expert_distribution, policy_distribution, pi_, transition, initial, pi_, gamma, alpha)
      q_ = q_ - lr * q_grad
      r_ = r_ - lr * r_grad
      # NOTE(review): this divides by the squared norm of r_, not its norm.
      # Confirm that this is the intended normalisation.
      r_ = r_ / (r_ ** 2).sum()
    pi_ = softmax(q_, alpha, prior_policy)

  v_ = -alpha * jax.scipy.special.logsumexp(-q_ / alpha + jnp.log(EPS + pi_), axis=1)
  return pi_, q_, v_, r_

# Run PPIL on the expert data, then evaluate its state distribution under
# both dynamics.
alpha_ppil = 1e-3
ppil_pi, ppil_q, ppil_v, ppil_r = proximal_point_imitiation_learning(
    expert_distribution, transition_matrix, initial, prior_policy,
    alpha_ppil, gamma)

ppil_rho, ppil_windy_rho = (
    stationary_distribution(initial, dynamics, ppil_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))

# Figure 1: (state x action) tables. Five axes are created and the first four
# are drawn on.
fig, ax = plt.subplots(1, 5, figsize=(40, 10))
table_panels = (
    ("PPIL Policy", ppil_pi),
    ("PPIL Q", ppil_q),
    ("PPIL R", ppil_r),
    ("True Reward R", reward),
)
for ax_, (panel_title, table) in zip(ax, table_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(table)
  plt.colorbar(fig, ax=ax_)


# Figure 2: per-state quantities on the grid.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
grid_panels = (
    ("Expert Distribution", expert_distribution.sum(1)),
    ("PPIL IRL Stationary Distribution", ppil_rho),
    ("PPIL IRL Stationary Distribution (Windy)", ppil_windy_rho),
    ("PPIL Value Function", ppil_v),
)
for ax_, (panel_title, per_state) in zip(ax, grid_panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(per_state.reshape((width, width)))
  plt.colorbar(fig, ax=ax_)

In [None]:
# @title Report results

import tabulate
from IPython.display import HTML, display

# Greedy expert: equal mass on every action that attains the row maximum of
# the soft expert policy.
is_best_action = expert_pi == jnp.max(expert_pi, axis=1, keepdims=True)
greedy_pi = jnp.array(is_best_action).astype(np.float32)
greedy_pi = greedy_pi / greedy_pi.sum(axis=1, keepdims=True)
greedy_rho, windy_greedy_rho = (
    stationary_distribution(initial, dynamics, greedy_pi, discount)
    for dynamics in (transition_matrix, windy_transition_matrix))


def _nominal_and_windy_return(policy, nominal_rho, windy_rho):
  """Evaluates `policy` under its nominal and windy state distributions."""
  return policy_eval(policy, nominal_rho), policy_eval(policy, windy_rho)


pi_R_det, pi_R_windy = _nominal_and_windy_return(expert_pi, expert_rho, expert_windy_rho)
greedy_R_det, greedy_R_windy = _nominal_and_windy_return(greedy_pi, greedy_rho, windy_greedy_rho)
bc_pi_R_det, bc_pi_R_windy = _nominal_and_windy_return(bc_pi, bc_rho, bc_windy_rho)
meirl_pi_R_det, meirl_pi_R_windy = _nominal_and_windy_return(meirl_pi, meirl_rho, meirl_windy_rho)
gail_pi_R_det, gail_pi_R_windy = _nominal_and_windy_return(gail_pi, gail_rho, gail_windy_rho)
csil_R_det, csil_R_windy = _nominal_and_windy_return(csil_pi, rho_pcirl, windy_rho_pcirl)
irl_R_det, irl_R_windy = _nominal_and_windy_return(classifier_pi, classifier_rho, classifier_windy_rho)
iq_R_det, iq_R_windy = _nominal_and_windy_return(iq_pi, iq_rho, iq_windy_rho)
ppil_R_det, ppil_R_windy = _nominal_and_windy_return(ppil_pi, ppil_rho, ppil_windy_rho)

# One row per method: label, return under nominal dynamics, return under
# windy dynamics.
data = [
    ["Greedy Soft Expert ", greedy_R_det, greedy_R_windy],
    ["Soft Expert ", pi_R_det, pi_R_windy],
    ["BC ", bc_pi_R_det, bc_pi_R_windy],
    ["CLASSIFIER ", irl_R_det, irl_R_windy],
    ["ME-IRL ", meirl_pi_R_det, meirl_pi_R_windy],
    ["GAIL ", gail_pi_R_det, gail_pi_R_windy],
    ["IQLEARN ", iq_R_det, iq_R_windy],
    ["PPIL ", ppil_R_det, ppil_R_windy],
    ["CSIL ", csil_R_det, csil_R_windy],
]

table = tabulate.tabulate(
    data,
    headers=["","Deterministic","Windy"],
    tablefmt='html',
    floatfmt=".3f",
)
display(HTML(table))

# Greedy-expert state distributions on the grid, followed by the greedy and
# soft policies as (state x action) tables.
fig, ax = plt.subplots(1, 4, figsize=(40, 10))
panels = (
    ("Greedy Stationary Distribution", greedy_rho.reshape((width, width))),
    ("Greedy Stationary Distribution (Windy)", windy_greedy_rho.reshape((width, width))),
    ("Greedy Policy", greedy_pi),
    ("Soft Policy", expert_pi),
)
for ax_, (panel_title, panel_image) in zip(ax, panels):
  ax_.set_title(panel_title)
  fig = ax_.imshow(panel_image)
  plt.colorbar(fig, ax=ax_)

In [None]:
#@title Plotting imports and constants
import matplotlib
import matplotlib.pyplot as plt

# Base font size; used for the font, axes-label and title rcParams below.
base = 10
# Subtracted from `base` for the legend and tick-label font sizes.
small_offset = 2
# Value of the "figure.constrained_layout.use" rcParam.
constrained_layout = True
# Value of the "figure.autolayout" rcParam.
tight_layout = True
_GOLDEN_RATIO = (5.0**0.5 - 1.0) / 2.0
# Figure width passed to `figsize` when the result figures are saved.
width_in = 6.9
height_in = width_in
# NOTE(review): `//` is floor division, so this evaluates to 1.0 rather than
# 1.725. Confirm that `/` was not intended.
height_per_row = height_in // 4
# Value of the "savefig.pad_inches" rcParam.
pad_inches = 0.015
from matplotlib import rc
rc('mathtext', fontset='cm')

# Matplotlib settings for the saved figures. Each key appears exactly once;
# "figure.constrained_layout.use" was previously listed twice, and only the
# later entry (the `constrained_layout` flag) took effect.
mpl_params = {
    "figure.constrained_layout.use": constrained_layout,
    "lines.markersize": 0.1,
    "lines.linewidth": 1.0,
    "font.family": "Crimson Text",
    "text.usetex": False,
    "font.serif": ["Crimson Text"] + plt.rcParams['font.serif'],
    # Overrides the 'cm' mathtext fontset selected by the rc(...) call above.
    "mathtext.fontset": "stix",  # free ptmx replacement, for ICML and NeurIPS
    "mathtext.rm": "Times New Roman",
    "mathtext.it": "Times New Roman:italic",
    "mathtext.bf": "Times New Roman:bold",
    "font.size": base,
    "axes.labelsize": base,
    "legend.fontsize": base - small_offset,
    "xtick.labelsize": base - small_offset,
    "ytick.labelsize": base - small_offset,
    "axes.titlesize": base,
    "figure.autolayout": tight_layout,
    "savefig.bbox": "tight",
    "savefig.pad_inches": pad_inches,
}

plt.rcParams.update(mpl_params)

# Font keyword arguments intended for titles.
title_params = {'fontfamily':'monospace'}

# Hex colour constants.
# Colorblind-safe and print-friendly colors selected from colorbrewer2.org
YELLOW = '#F2E34C'
GREEN = '#82b392'
AQUA = '#41b6c4'
BLUE = '#2c7fb8'
ROYALBLUE = '#253494'
PURPLE = '#998ec3'

RED = '#d7191c'
LIGHTRED = '#ffcccb'
ORANGE = '#fdae61'
LIGHTBLUE = '#abd9e9'
MEDBLUE = '#2c7bb6'

LIGHTORANGE = '#fed98e'
MEDORANGE = '#fe9929'
DARKORANGE = '#cc4c02'

GREY = '#808080'

# Alpha value for shaded regions.
OPACITY = 0.2  # for uncertainty

In [None]:
#@title Plot and save results
from colabtools import fileedit
# Per method: (value function or None, nominal state distribution, windy
# state distribution).
RESULTS = {
    "expert": (expert_v, greedy_rho, windy_greedy_rho),
    "bc": (None, bc_rho, bc_windy_rho),
    "csil": (csil_v, rho_pcirl, windy_rho_pcirl),
    "classifier": (classifier_v, classifier_rho, classifier_windy_rho),
    "me-irl": (meirl_v, meirl_rho, meirl_windy_rho),
    "gail": (gail_v, gail_rho, gail_windy_rho),
    "iqlearn": (iq_v, iq_rho, iq_windy_rho),
    "ppil": (ppil_v, ppil_rho, ppil_windy_rho),
}

# File-name prefix that encodes the experiment switches.
initial_tag = 'sparse' if SPARSE_INITIAL_STATES else 'dense'
reward_tag = 'easy' if EASY_REWARD else 'hard'
exp = f"{initial_tag}_{reward_tag}"

for name, (v, rho, wrho) in RESULTS.items():
  fig, ax = plt.subplots(1, 3, figsize=(width_in, width_in // 3), sharex=True, sharey=True, gridspec_kw={'wspace': 0.1})
  # Strip all ticks and tick labels from every panel.
  for ax_ in ax:
    ax_.set_xticklabels([])
    ax_.set_yticklabels([])
    ax_.set_xticks([])
    ax_.set_yticks([])

  # Left panel: value function, or a blank axis when the method has none.
  ax_ = ax[0]
  if v is None:
    ax_.axis('off')
  else:
    ax_.set_title("Value function")
    ax_.imshow(v.reshape((width, width)))

  # Middle and right panels: state distributions under the two dynamics.
  distribution_panels = (
      ("Nominal stationary distribution", rho),
      ("Windy stationary distribution", wrho),
  )
  for ax_, (panel_title, per_state) in zip(ax[1:], distribution_panels):
    ax_.set_title(panel_title)
    ax_.imshow(per_state.reshape((width, width)))

  # Save as PDF and hand the file to the Colab download helper.
  filename = f'{exp}_{name}.pdf'
  fig.savefig(filename, bbox_inches = 'tight')
  fileedit.download_file(filename, ephemeral=True)