In [2]:
# Reference: TensorFlow actor-critic tutorial (CartPole)
# https://www.tensorflow.org/tutorials/reinforcement_learning/actor_critic

In [3]:
import gym
import haiku as hk
import optax
from jax import jit, lax, vmap, grad, partial
import jax.nn as nn
import jax.numpy as np
import jax.random as random

# Root PRNG key; every other key in the notebook is split from this one.
rng = random.PRNGKey(42)

# Environment and the action count read from it.
env = gym.make("CartPole-v0")
num_actions = env.action_space.n  # 2
# Upper bound on the number of steps taken in a single episode.
max_steps = 200

# Learning hyper-parameters.
gamma = 0.99
num_hidden_units = 128
optimizer = optax.adam(learning_rate=1e-2)

# float32 machine epsilon as a plain Python float (passed as `epsilon` to nn.normalize).
eps = np.finfo(np.float32).eps.item()

In [4]:
def net_present_value(rate, values):
    """Return the discounted running totals of ``values``, accumulated back to front.

    Entry ``i`` of the result equals ``values[i] + rate * result[i + 1]``, where
    the entry past the end is taken to be 0.

    Args:
        rate: Multiplier applied to the running total at every step.
        values: Array scanned along its leading axis.

    Returns:
        The stacked running totals, in the same order as ``values``.
    """
    def accumulate(carry, current):
        discounted = current + rate * carry
        # The new total is both the next carry and the value emitted for this step.
        return discounted, discounted

    # reverse=True walks from the last element to the first; index 1 picks the
    # stacked per-step outputs and drops the final carry.
    return lax.scan(accumulate, 0, values, reverse=True)[1]

In [5]:
def get_expected_returns(rewards):
    """Turn per-step rewards into normalised discounted returns.

    The rewards are discounted with the module-level ``gamma`` via
    ``net_present_value`` and the result is passed through ``nn.normalize``
    with the module-level ``eps`` as its epsilon.

    Args:
        rewards: Array of per-step rewards for one episode.

    Returns:
        The output of ``nn.normalize`` applied to the discounted returns.
    """
    discounted = net_present_value(rate=gamma, values=rewards)
    return nn.normalize(discounted, epsilon=eps)

In [6]:
def env_init(rng):
    """Seed the global environment from ``rng`` and reset it.

    Args:
        rng: JAX PRNG key used to draw the integer seed for the environment.

    Returns:
        The observation returned by ``env.reset()``, converted to a JAX array.
    """
    # Draw the seed from the key that was passed in, so the function does not
    # depend on any key that happens to live in the notebook's global scope.
    seed = random.randint(rng, shape=(), minval=0, maxval=np.iinfo(np.int32).max)
    # Convert the JAX scalar to a Python int before handing it to the environment.
    env.seed(int(seed))
    return np.array(env.reset())

In [7]:
def env_step(action):
    """Advance the global environment by one step.

    Args:
        action: Action to take; it is converted to a Python int before being
            passed to ``env.step``.

    Returns:
        A ``(state, reward, done)`` tuple of JAX arrays. The fourth value
        returned by ``env.step`` is discarded.
    """
    observation, reward, done, _info = env.step(int(action))
    return tuple(np.array(field) for field in (observation, reward, done))

In [8]:
@jit
def sample_policy(rng, model_state, state):
    """Sample an action index from the model's policy output for one state.

    Args:
        rng: JAX PRNG key; it is split once and the second half drives sampling.
        model_state: Parameters passed to ``model.apply``.
        state: A single observation fed to the model.

    Returns:
        The index drawn by ``random.categorical``, which treats the model's
        first output as unnormalised logits.
    """
    # Only the policy head is needed here; the critic output is ignored.
    logits = model.apply(model_state, state)[0]
    _, sample_key = random.split(rng)
    return random.categorical(sample_key, logits)

In [9]:
def run_episode(rng, model_state):
    """Play one episode in the global environment using the current policy.

    The environment is reset with a key split from ``rng``; then up to
    ``max_steps`` actions are sampled with ``sample_policy`` and applied with
    ``env_step`` until the environment reports ``done``.

    Args:
        rng: JAX PRNG key for the reset seed and the action sampling.
        model_state: Parameters passed on to ``sample_policy``.

    Returns:
        A tuple ``(states, actions, rewards)`` with one entry per step taken,
        where each state is the observation the action was sampled from.
    """
    rng, reset_key = random.split(rng)
    state = env_init(reset_key)

    trajectory = []
    for _ in range(max_steps):
        rng, action_key = random.split(rng)
        action = sample_policy(action_key, model_state, state)
        next_state, reward, done = env_step(action)
        # Record the state the action was chosen in, not the one it led to.
        trajectory.append((state, action, reward))
        if done:
            break
        state = next_state

    # Transpose the list of per-step triples into three stacked arrays.
    return tuple(np.array(column) for column in zip(*trajectory))

In [10]:
def huber_loss(yp, y, delta=1.):
    """Summed Huber loss between predictions and targets.

    For an absolute residual ``r = |y - yp|`` the per-element loss is
    ``0.5 * r**2`` when ``r < delta`` and ``delta * (r - 0.5 * delta)``
    otherwise, so the two branches meet continuously at ``r == delta``.

    Args:
        yp: Predicted values.
        y: Target values, broadcast-compatible with ``yp``.
        delta: Residual size at which the loss switches from quadratic to linear.

    Returns:
        The sum of the per-element losses as a scalar array.
    """
    residual = np.abs(y - yp)
    quadratic = .5 * residual ** 2
    # Scale the linear branch by delta so it stays correct for delta != 1;
    # for the default delta this reduces to ``residual - .5``.
    linear = delta * (residual - .5 * delta)
    return np.sum(np.where(residual < delta, quadratic, linear))

In [11]:
def compute_loss(action_probs, values, returns):
    """Combined actor-critic loss for one episode.

    The actor term is the negated sum of ``log(action_probs)`` weighted by the
    advantage ``returns - values``; the critic term is ``huber_loss`` between
    ``values`` and ``returns``.

    Args:
        action_probs: Probability assigned to each action that was taken.
        values: Critic estimates for the visited states.
        returns: Target returns for the visited states.

    Returns:
        The sum of the actor and critic losses.
    """
    # NOTE(review): values and returns are combined element-wise, so they should
    # have the same shape; a trailing axis of size 1 on one of them would
    # broadcast to a square matrix instead - confirm against the caller.
    critic_loss = huber_loss(values, returns)
    log_likelihood = np.log(action_probs)
    actor_loss = -np.sum(log_likelihood * (returns - values))
    return actor_loss + critic_loss

In [12]:
@hk.without_apply_rng
@hk.transform
def model(x):
    """Actor-critic network with a shared hidden layer and two linear heads.

    Args:
        x: Input observation.

    Returns:
        A tuple ``(actor, critic)``: the actor head has ``num_actions`` outputs
        and the critic head has a single output.
    """
    hidden = nn.relu(hk.Linear(num_hidden_units, name='common')(x))
    # Size the policy head from the environment's action count rather than a literal.
    actor = hk.Linear(num_actions, name='actor')(hidden)
    critic = hk.Linear(1, name='critic')(hidden)
    return actor, critic

In [13]:
def train_step(rng, model_state, opt_state):
    """Roll out one episode and compute the actor-critic loss for it.

    Args:
        rng: JAX PRNG key used for the episode rollout.
        model_state: Parameters used both to act and to evaluate the episode.
        opt_state: Optimiser state. Currently unused.

    Returns:
        The scalar loss from ``compute_loss`` for the sampled episode.
    """
    # TODO: no gradient is taken and no optimiser update is applied yet, so
    # model_state and opt_state are left unchanged by this function.
    rng, r = random.split(rng)
    # run_episode yields raw (states, actions, rewards); the probabilities and
    # value estimates needed by the loss are obtained by evaluating the model.
    states, actions, rewards = run_episode(r, model_state)
    returns = get_expected_returns(rewards)

    # Evaluate the model on every visited state with shared parameters.
    policy, values = vmap(model.apply, in_axes=(None, 0))(model_state, states)
    # Probability the policy assigned to the action actually taken at each step.
    action_probs = nn.softmax(policy, axis=-1)[np.arange(policy.shape[0]), actions]
    # The critic head has one output per state; drop that trailing axis so the
    # values line up element-wise with returns.
    loss = compute_loss(action_probs, values[:, 0], returns)
    return loss

In [16]:
# Use independent keys for parameter initialisation and for seeding the
# environment; a JAX key should not be consumed by two different consumers.
rng, r, env_key = random.split(rng, 3)
# The sample observation from env_init is the example input for model.init.
model_state = model.init(r, env_init(env_key))
opt_state = optimizer.init(model_state)

In [17]:
# Roll out one episode with the freshly initialised parameters.
rng, r = random.split(rng)
states, actions, rewards = run_episode(r, model_state)
# Evaluate the model on every visited state: parameters are shared (None),
# states are mapped over their leading axis (0).
policy, values = vmap(model.apply, in_axes=(None, 0))(model_state, states)
# Probability the policy assigned to the action actually taken at each step.
action_probs = nn.softmax(policy, axis=-1)[np.arange(len(policy)), actions]