<a href="https://colab.research.google.com/github/mauricef/jax-ml/blob/main/jax-actor-critic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install dm-haiku optax

# Jax Actor Critic
This is an implementation of the [TensorFlow actor critic example](https://www.tensorflow.org/tutorials/reinforcement_learning/actor_critic) using JAX, Haiku and Optax. The TF version executes ~ 15 it/sec while the JAX version executes ~ 35 it/sec on the same CPU Colab. They are both much slower using GPU acceleration.


The code is similar between the two implementations but there are some important differences.

- The JAX version generates fixed length episodes with a reward of `np.nan` when the environment enters the `done` state. The TF version generates variable length episodes. If we generate variable length episodes then the `jit` on the `train_step` will keep triggering a recompile every time it encounters a different episode length which seriously slows things down.

- Due to the fixed episode length, we need to normalize the returns while taking into account `np.nan` values, that is what the `safe_*` methods are for.

- Also due to the variable length episodes, we need to filter out steps that are past the end of the episode when computing the gradient.

- JAX computes gradients by tracing a function call taking the model weights as the first parameter vs TF which traces the execution using the gradient tape context. To get this working we needed to refactor the responsabilities of the `generate_episode` and `compute_loss_and_grads`. This does require two model evaluations per step - the first to sample an action in `generate_episode` and the second to compute the `value` and `policy` in `compute_loss_and_grads`.

In [1]:
import collections

import gym
import haiku as hk
import matplotlib.pyplot as plt
import numpy as onp
import optax
from jax import jit, lax, vmap, value_and_grad, partial
import jax.nn as nn
import jax.numpy as np
import jax.random as random
from jax.tree_util import tree_map
import tqdm

In [2]:
Episode = collections.namedtuple('Episode', 'state action reward value')

In [3]:
env = gym.make("CartPole-v0")
gamma = .99
num_actions = env.action_space.n  # 2
num_hidden_units = 128
episode_steps = 200
optimizer = optax.adam(learning_rate=.01)
eps = np.finfo(np.float32).eps.item()

In [4]:
@jit
def scan_stack(ys):
    return tuple(map(np.array, zip(*ys)))

In [5]:
def scan(f, init, xs=None, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, scan_stack(ys)

In [6]:
# nan aware helpers

def safe_mean(x):
    mask = ~np.isnan(x)
    n = np.count_nonzero(mask)
    total = np.sum(np.nan_to_num(x))
    return total / n

def safe_var(x):
    return safe_mean(np.square(x)) - np.square(safe_mean(x))

def safe_normalize(x, epsilon=eps):
    mean = safe_mean(x)
    var = safe_var(x)
    return (x - mean) * lax.rsqrt(var + epsilon)

def safe_npv(rate, xs):
    def step(total, x):
        total = x + rate * total
        return total, total
    mask = ~np.isnan(xs)
    xs = np.nan_to_num(xs)
    _, ys = lax.scan(step, init=0., xs=xs, reverse=True)
    ys = np.where(mask, ys, np.nan)
    return ys

In [7]:
@jit
def get_values(rewards, gamma=gamma):
    values = safe_npv(gamma, rewards)
    return safe_normalize(values, epsilon=eps)

In [8]:
def env_is_done(env):
    return env.steps_beyond_done is not None

In [9]:
def env_init(rng):
    seed = random.randint(r, shape=(), minval=0, maxval=np.iinfo(np.int32).max)
    env.seed(int(seed))
    return np.array(env.reset())

In [10]:
def env_step(action):
    if env_is_done(env):
        state, reward = onp.array(env.state), np.nan
    else:
        state, reward, _, _ = env.step(int(action))
    return state, reward

In [11]:
@hk.without_apply_rng
@hk.transform
def model(x):
    x = hk.Linear(num_hidden_units, name='common')(x)
    x = nn.relu(x)
    actor = hk.Linear(2, name='actor')(x)
    critic = hk.Linear(1, name='critic')(x)
    return actor, critic

In [12]:
@jit
def sample_policy(rng, model_state, state):
    policy, _ = model.apply(model_state, state)
    return random.categorical(rng, policy)

In [13]:
def generate_episode(rng, model_state):
    def generate_step(state, rng):
        action = sample_policy(rng, model_state, state)
        next_state, reward = env_step(action)
        return next_state, (state, action, reward)

    rng, r = random.split(rng)
    initial_state = env_init(r)
    rngs = np.array(random.split(rng, episode_steps))
    _, (state, action, reward) = scan(generate_step, initial_state, xs=rngs)
    value = get_values(reward)
    return Episode(state, action, reward, value)

In [14]:
def huber_loss(yp, y, delta=1.):
    residual = np.abs(y - yp)
    return np.where(residual < delta, .5 * residual ** 2, residual - .5)

In [15]:
@jit
def loss_fn(model_state, step):
    policy_logits, predicted_value = model.apply(model_state, step.state)
    advantage = step.value - predicted_value
    policy_probs = nn.softmax(policy_logits)
    action_prob = policy_probs[step.action]
    action_log_prob = np.log(action_prob)
    actor_loss = -action_log_prob * advantage
    critic_loss = huber_loss(predicted_value, step.value)
    loss = actor_loss + critic_loss
    return loss

In [16]:
@value_and_grad
def compute_loss_and_grads(model_state, episode):
    def compute_loss_step(step):
        w = tree_map(lambda a: np.where(np.isnan(step.reward), lax.stop_gradient(a), a), model_state)
        loss = loss_fn(w, step)
        loss = np.nan_to_num(loss)
        return loss
    losses = vmap(partial(compute_loss_step))(episode)
    return np.sum(losses)

In [17]:
@jit
def train_step(episode, model_state, opt_state):
    loss_value, loss_grads = compute_loss_and_grads(model_state, episode)
    model_updates, opt_state = optimizer.update(loss_grads, opt_state, model_state)
    model_state = optax.apply_updates(model_state, model_updates)
    episode_reward = np.sum(np.nan_to_num(episode.reward))
    return episode_reward, loss_value, model_state, opt_state

In [18]:
rng = random.PRNGKey(42)



In [19]:
rng, r = random.split(rng)
model_state = model.init(r, env_init(r))
opt_state = optimizer.init(model_state)

max_episodes = 10000
max_steps_per_episode = 200
min_episodes_criterion = 100
reward_threshold = 195
running_reward = 0
gamma = 0.99
episodes_reward = collections.deque(maxlen=min_episodes_criterion)

with tqdm.trange(max_episodes) as t:
    for i in t:
        rng, r = random.split(rng)
        episode = generate_episode(r, model_state)
        rng, r = random.split(rng)
        episode_reward, loss_value, model_state, opt_state = train_step(episode, model_state, opt_state)
        episodes_reward.append(episode_reward)
        running_reward = onp.mean(episodes_reward)
        
        t.set_description(f'Episode {i}')
        t.set_postfix(
            episode_reward=episode_reward, running_reward=running_reward)

        if running_reward > reward_threshold and i >= min_episodes_criterion:  
            break
            
print(f'\nSolved at episode {i}: average reward: {running_reward:.2f}!')

Episode 1255:  13%|█▎        | 1255/10000 [00:31<03:40, 39.73it/s, episode_reward=200.0, running_reward=195]


Solved at episode 1255: average reward: 195.14!



