<a href="https://colab.research.google.com/github/bacey/open_spiel/blob/master/Podracer_Architectures_for_Scalable_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Copyright 2021 DeepMind Technologies Limited.
>
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
>
> You may obtain a copy of the License at
> https://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

### **Podracer Architectures for Scalable RL**
*By Matteo Hessel, Manuel Kroiss, Fabio Viola, Hado van Hasselt*


This is a minimal demonstration of how to implement the `Anakin` architecture in JAX.

### Imports

In [None]:
# Install dependencies
! pip install chex -q
! pip install dm_haiku -q
! pip install rlax -q
! pip install optax -q

# Setup TPU
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

# Imports
import chex
import jax
import haiku as hk
from jax import lax
from jax import random
from jax import numpy as jnp
import jax.numpy as jnp
jax.devices()
import optax
import rlax
import timeit

In [None]:
class TimeIt():
  """Context manager that prints how long its body took to run.

  On exit it prints `<tag>: Elapsed time=<seconds>s`. When a truthy `frames`
  count was given, a frames-per-second figure (`frames / elapsed`) is appended.
  The measured duration is also kept on the instance as `elapsed_secs`.
  """

  def __init__(self, tag, frames=None):
    self.frames = frames
    self.tag = tag

  def __enter__(self):
    # Remember when the block started; the duration is derived in __exit__.
    self.start = timeit.default_timer()
    return self

  def __exit__(self, *args):
    finished_at = timeit.default_timer()
    self.elapsed_secs = finished_at - self.start
    report = self.tag + ': Elapsed time=%.2fs' % self.elapsed_secs
    if not self.frames:
      # No frame count supplied: only the elapsed time is reported.
      print(report)
      return
    throughput = self.frames / self.elapsed_secs
    print(report + ', FPS=%.2e' % throughput)

### Environment

In [None]:
class Catch:
  """Catch gridworld implemented with JAX primitives.

  The state is an int32 vector `(ball_y, ball_x, paddle_y, paddle_x)`. The ball
  starts in the top row and drops one row per step; the paddle stays on the
  bottom row and shifts by `action - 1` columns (clipped to the board).
  """

  def __init__(self, rows: int = 10, columns: int = 5):
    self.num_actions = 3
    self._columns = columns
    self._rows = rows

  def initial_state(self, rng):
    """Returns a new episode: ball in a random top-row column, paddle centred."""
    top_row, bottom_row = 0, self._rows - 1
    centre_column = self._columns // 2
    drop_column = random.randint(rng, (), 0, self._columns)
    fresh = jnp.array(
        (top_row, drop_column, bottom_row, centre_column), dtype=jnp.int32)
    return lax.stop_gradient(fresh)

  def step(self, rng, state, action):
    """Advances one step; a terminal input state is replaced by a new episode."""
    episode_over = self.is_terminal(state)
    ball_y, ball_x, paddle_y, paddle_x = state[0], state[1], state[2], state[3]
    # Action 0/1/2 shifts the paddle by -1/0/+1, kept inside the board.
    shifted_paddle_x = jnp.clip(paddle_x + action - 1, 0, self._columns - 1)
    advanced = jnp.array([ball_y + 1, ball_x, paddle_y, shifted_paddle_x])
    restarted = self.initial_state(rng)
    return lax.stop_gradient(lax.select(episode_over, restarted, advanced))

  def observation(self, state):
    """Returns the tuple (board image, reward, discount, terminal flag)."""
    board = self.render(state)
    reward = self.reward(state)
    discount = self.discount(state)
    done = self.is_terminal(state)
    return (board, reward, discount, done)

  def render(self, state):
    """Draws a (rows, columns, 1) board holding 1. at the ball and the paddle."""
    def cell_value(row, col):
      on_ball = (row == state[0]) & (col == state[1])
      on_paddle = (row == state[2]) & (col == state[3])
      return lax.select(on_ball | on_paddle, 1., 0.)
    # Row-major coordinates of every cell, evaluated in one vmapped pass.
    rows_flat = jnp.repeat(jnp.arange(self._rows), self._columns)
    cols_flat = jnp.tile(jnp.arange(self._columns), self._rows)
    cells = jax.vmap(cell_value)(rows_flat, cols_flat)
    return cells.reshape((self._rows, self._columns, 1))

  def reward(self, state):
    """Returns +1 (catch) or -1 (miss) in a terminal state, and 0 otherwise."""
    caught = state[1] == state[3]
    outcome = lax.select(caught, 1., -1.)
    return lax.select(self.is_terminal(state), outcome, 0.)

  def discount(self, state):
    """Returns 0 in a terminal state and 1 otherwise."""
    return lax.select(self.is_terminal(state), 0., 1.)

  def is_terminal(self, state):
    """True once the ball has reached the bottom row."""
    last_row = self._rows - 1
    return state[0] == last_row

### Anakin

In [None]:
@chex.dataclass(frozen=True)
class TimeStep:
  """Immutable record of the data collected at one step of a rollout."""
  q_values: chex.Array  # network outputs for the step's observation.
  action: chex.Array  # action selected at this step.
  discount: chex.Array  # environment discount observed at this step.
  reward: chex.Array  # environment reward observed at this step.

def get_network_fn(num_outputs: int):
  """Builds a Haiku-transformed MLP: flatten -> 256 units -> ReLU -> outputs.

  Args:
    num_outputs: width of the final linear layer.

  Returns:
    The transformed network, wrapped so that `apply` takes no rng argument.
  """
  def network_fn(obs: chex.Array) -> chex.Array:
    flat = hk.Flatten()(obs)
    hidden = jax.nn.relu(hk.Linear(256)(flat))
    return hk.Linear(num_outputs)(hidden)
  transformed = hk.transform(network_fn)
  return hk.without_apply_rng(transformed)

def get_learner_fn(
    env, forward_pass, opt_update, rollout_len, agent_discount,
    lambda_, iterations):
  """Define the minimal unit of computation in Anakin.

  Args:
    env: environment exposing `observation(state)` and
      `step(rng, state, action)`.
    forward_pass: callable `(params, obs_batch) -> q_values_batch`.
    opt_update: callable `(grads, opt_state) -> (updates, new_opt_state)`.
    rollout_len: number of environment steps per trajectory.
    agent_discount: scalar multiplied into the environment discounts.
    lambda_: mixing parameter forwarded to `rlax.td_lambda`.
    iterations: number of updates run inside one `lax.fori_loop`.

  Returns:
    `learner_fn(params, opt_state, rngs, env_states)`, which returns the
    updated `(params, opt_state, rngs, env_states)`. It reduces gradients over
    the axis names 'j' (bound by its own `vmap`) and 'i', which the caller
    must bind (e.g. `jax.pmap(learner_fn, axis_name='i')`).
  """

  def loss_fn(params, outer_rng, env_state):
    """Compute the loss on a single trajectory.

    Returns the mean squared TD(lambda) error together with the final
    environment state (used as the auxiliary output of `jax.grad`).
    """

    def step_fn(env_state, rng):
      """Scan body: observe, act greedily, step the environment."""
      # `is_terminal` is unpacked but not used here.
      obs, reward, discount, is_terminal = env.observation(env_state)
      q_values = forward_pass(params, obs[None,])[0]  # add/remove batch dim.
      action = jnp.argmax(q_values)  # greedy policy.
      env_state = env.step(rng, env_state, action)  # step environment.
      return env_state, TimeStep(  # return env state and transition data.
          q_values=q_values, action=action, discount=discount, reward=reward)

    step_rngs = random.split(outer_rng, rollout_len)  # one key per step.
    env_state, rollout = lax.scan(step_fn, env_state, step_rngs)  # trajectory.
    # Reward/discount at index t were observed in the state reached by the
    # action at index t-1, hence the [:-1] / [1:] alignment below.
    qa_tm1 = rlax.batched_index(rollout.q_values[:-1], rollout.action[:-1])
    td_error = rlax.td_lambda(  # compute multi-step temporal diff error.
        v_tm1=qa_tm1,  # predictions.
        r_t=rollout.reward[1:],  # rewards.
        discount_t=agent_discount * rollout.discount[1:],  # discount.
        v_t=jnp.max(rollout.q_values[1:], axis=-1),  # bootstrap values.
        lambda_=lambda_)  # mixing hyper-parameter lambda.
    return jnp.mean(td_error**2), env_state

  def update_fn(params, opt_state, rng, env_state):
    """Compute a gradient update from a single trajectory."""
    rng, loss_rng = random.split(rng)  # keep `rng` for the next iteration.
    grads, new_env_state = jax.grad(  # compute gradient on a single trajectory.
        loss_fn, has_aux=True)(params, loss_rng, env_state)
    grads = lax.pmean(grads, axis_name='j')  # mean over the vmapped batch axis.
    grads = lax.pmean(grads, axis_name='i')  # mean over axis 'i' (caller's pmap).
    updates, new_opt_state = opt_update(grads, opt_state)  # transform grads.
    new_params = optax.apply_updates(params, updates)  # update parameters.
    return new_params, new_opt_state, rng, new_env_state

  def learner_fn(params, opt_state, rngs, env_states):
    """Vectorise and repeat the update."""
    batched_update_fn = jax.vmap(update_fn, axis_name='j')  # vectorize across batch.
    def iterate_fn(_, val):  # repeat many times to avoid going back to Python.
      params, opt_state, rngs, env_states = val
      return batched_update_fn(params, opt_state, rngs, env_states)
    return lax.fori_loop(0, iterations, iterate_fn, (
        params, opt_state, rngs, env_states))

  return learner_fn

In [None]:
def run_experiment(env, batch_size, rollout_len, step_size, iterations, seed):
  """Runs experiment.

  Builds the network, optimiser and pmapped learner, replicates parameters
  and optimiser state over a leading `(cores, batch)` shape, then runs the
  learner twice: once to compile, and once timed to report frames per second.

  Args:
    env: environment providing `num_actions`, `initial_state(rng)`,
      `render(state)` and the methods used by the learner.
    batch_size: number of environments per core.
    rollout_len: number of environment steps per trajectory.
    step_size: learning rate for Adam.
    iterations: number of updates performed inside each `learn` call.
    seed: integer seed for the root PRNG key.

  Returns:
    None. Timing results are printed via `TimeIt`.
  """
  cores_count = len(jax.devices())  # get available TPU cores.
  network = get_network_fn(env.num_actions)  # define network.
  optim = optax.adam(step_size)  # define optimiser.

  rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3)  # prng keys.
  dummy_obs = env.render(env.initial_state(rng_e))[None,]  # dummy for net init.
  params = network.init(rng_p, dummy_obs)  # initialise params.
  opt_state = optim.init(params)  # initialise optimiser stats.

  learn = get_learner_fn(  # get batched iterated update.
      env, network.apply, optim.update, rollout_len=rollout_len,
      agent_discount=1, lambda_=0.99, iterations=iterations)
  learn = jax.pmap(learn, axis_name='i')  # replicate over multiple cores.

  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias was deprecated and later removed from JAX.
  broadcast = lambda x: jnp.broadcast_to(x, (cores_count, batch_size) + x.shape)
  params = jax.tree_util.tree_map(broadcast, params)  # to cores and batch.
  opt_state = jax.tree_util.tree_map(broadcast, opt_state)  # to cores and batch.

  rng, *env_rngs = jax.random.split(rng, cores_count * batch_size + 1)
  env_states = jax.vmap(env.initial_state)(jnp.stack(env_rngs))  # init envs.
  rng, *step_rngs = jax.random.split(rng, cores_count * batch_size + 1)

  reshape = lambda x: x.reshape((cores_count, batch_size) + x.shape[1:])
  step_rngs = reshape(jnp.stack(step_rngs))  # add dimension to pmap over.
  env_states = reshape(env_states)  # add dimension to pmap over.

  def wait_until_ready(tree):
    """Blocks until every array in `tree` is computed, then returns `tree`."""
    return jax.tree_util.tree_map(lambda x: x.block_until_ready(), tree)

  # JAX dispatches work asynchronously, so a call can return before the device
  # has finished. Block inside each timed region so the clock covers the real
  # work, and so the warm-up run cannot leak into the EXECUTION measurement.
  with TimeIt(tag='COMPILATION'):
    wait_until_ready(
        learn(params, opt_state, step_rngs, env_states))  # compiles

  num_frames = cores_count * iterations * rollout_len * batch_size
  with TimeIt(tag='EXECUTION', frames=num_frames):
    params, opt_state, step_rngs, env_states = wait_until_ready(
        learn(params, opt_state, step_rngs, env_states))  # runs compiled fn

In [None]:
# Report how many devices JAX can see before launching the experiment.
print('Running on', len(jax.devices()), 'cores.', flush=True)  # the recorded run below reports 8.
# Positional args: env, batch_size, rollout_len, step_size, iterations, seed.
run_experiment(Catch(), 128, 16, 1e-4, 100, 42)

Running on 8 cores.
COMPILATION: Elapsed time=4.08s
EXECUTION: Elapsed time=0.37s, FPS=4.41e+06
