### Install bsuite along with the baseline agents

In [None]:
! pip install -q bsuite[baselines_jax]  # quiet import along with agent implemented in jax

### Imports (click to expand the hidden cell)

In [None]:
import math
import time
import bsuite
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

from typing import Any, Callable, NamedTuple, Sequence

from bsuite.baselines import base
from bsuite.baselines.utils import replay

import dm_env
from dm_env import specs
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax

## Implementation of DQN and bootstrapped DQN agents

### Functions that need to be completed

In [None]:
def sample_td_error(q_t, a_t, r, discount, q_next):
  """Compute the one-step Q-learning TD error for a single sample.

  For a sample (s_t, a_t, r, s_{t+1}) the TD error is

      delta = r + discount * max_a Q(s_{t+1}, a) - Q(s_t, a_t).

  The function handles one transition; the agents batch it with `jax.vmap`.

  Args:
    q_t: Q-values corresponding to state s_t. shape = (number of actions,).
    a_t: action taken in s_t. An integer; actions are indexed from 0 to
      number of actions - 1.
    r: scalar reward received after taking a_t.
    discount: scalar discount applied to the bootstrap value. The agents pass
      `gamma * d_t`, where d_t is the environment's per-step discount.
    q_next: Q-values corresponding to the next state s_{t+1}.
      shape = (number of actions,).

  Returns:
    The scalar TD error (target minus current estimate).
  """
  # The bootstrap target is treated as a constant so gradients only flow
  # through Q(s_t, a_t) (semi-gradient Q-learning).
  target = jax.lax.stop_gradient(r + discount * jnp.max(q_next))
  return target - q_t[a_t]


def epsilon_greedy(q_values, num_actions, epsilon):
  """Epsilon-greedy action selection scheme.

  With probability `epsilon` an action is drawn uniformly at random from
  {0, ..., num_actions - 1}. Otherwise the action with the largest Q-value
  is returned; ties go to the lowest index.

  Args:
    q_values: Q-values corresponding to the actions of the current state.
      shape = (1, number of actions). Any array whose flattened form holds one
      value per action is accepted.
    num_actions: number of actions (same as the size of q_values array).
    epsilon: the probability with which a random action is selected, a float
      in [0, 1].

  Returns:
    The selected action as a Python int.
  """
  # Uses NumPy's global RNG, which the agent factories seed with
  # `np.random.seed(seed)`, so exploration is reproducible per seed.
  if np.random.rand() < epsilon:
    return int(np.random.randint(num_actions))
  # np.argmax flattens, so a (1, num_actions) batch yields an action index.
  return int(np.argmax(np.asarray(q_values)))

### DQN agent code

**You don't need to modify this part of the code**

In [None]:
class TrainingState(NamedTuple):
  """Holds the agent's training state.

  Instances are immutable; the agent replaces them (via the SGD step or
  `_replace`) instead of mutating them.
  """
  params: hk.Params  # online Q-network parameters, updated by each SGD step
  target_params: hk.Params  # target Q-network parameters used for bootstrapping
  opt_state: Any  # optimizer state from `optimizer.init` / `optimizer.update`
  step: int  # number of SGD steps taken so far

class DQN(base.Agent):
  """A simple DQN agent using JAX.

  The agent stores every transition in a replay buffer. It trains an online
  Q-network with one-step Q-learning against a periodically refreshed target
  network, and acts epsilon-greedily with respect to the online network.
  """

  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: Callable[[jnp.ndarray], jnp.ndarray],
      optimizer: optax.GradientTransformation,
      batch_size: int,
      epsilon: float,
      rng: hk.PRNGSequence,
      discount: float,
      replay_capacity: int,
      min_replay_size: int,
      sgd_period: int,
      target_update_period: int,
  ):
    """Builds the networks, optimizer state and replay buffer.

    Args:
      obs_spec: spec of the observations; its shape sizes the dummy input
        used to initialize the networks.
      action_spec: spec of the discrete action space.
      network: Haiku network function mapping a batch of observations to one
        Q-value per action.
      optimizer: optax optimizer used for the online network.
      batch_size: number of transitions sampled from replay per SGD step.
      epsilon: probability of a random action in `select_action`.
      rng: Haiku PRNG sequence used to initialize the online and target
        networks.
      discount: discount factor applied to the bootstrap value.
      replay_capacity: maximum number of transitions kept in replay.
      min_replay_size: SGD starts only once replay holds this many items.
      sgd_period: an SGD step is attempted every `sgd_period` updates.
      target_update_period: the target network is overwritten with the online
        parameters every `target_update_period` SGD steps.
    """
    self.name = 'dqn'

    # Transform the (impure) network into a pure function. Recent Haiku
    # releases deprecate the `apply_rng` argument of `hk.transform`, so it is
    # not passed; `without_apply_rng` removes the unused rng from `apply`.
    network = hk.without_apply_rng(hk.transform(network))

    # Define loss function.
    def loss(params: hk.Params,
             target_params: hk.Params,
             transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
      """Computes the standard TD(0) Q-learning loss on a batch of transitions."""
      o_tm1, a_tm1, r_t, d_t, o_t = transitions
      q_tm1 = network.apply(params, o_tm1)
      q_t = network.apply(target_params, o_t)
      # The target network is not trained: block gradients through it.
      q_t = jax.lax.stop_gradient(q_t)
      batch_q_learning = jax.vmap(sample_td_error)
      # `d_t` is the environment's per-step discount; scaling it by `discount`
      # gives the effective discount of each transition.
      td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
      return jnp.mean(td_error**2)

    # Define update function.
    @jax.jit
    def sgd_step(state: TrainingState,
                 transitions: Sequence[jnp.ndarray]) -> TrainingState:
      """Performs an SGD step on a batch of transitions."""
      gradients = jax.grad(loss)(state.params, state.target_params, transitions)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      return TrainingState(
          params=new_params,
          target_params=state.target_params,
          opt_state=new_opt_state,
          step=state.step + 1)

    # Initialize the networks and optimizer. The dummy observation carries a
    # leading batch axis of size 1.
    dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32)
    initial_params = network.init(next(rng), dummy_observation)
    initial_target_params = network.init(next(rng), dummy_observation)
    initial_opt_state = optimizer.init(initial_params)

    # This carries the agent state relevant to training.
    self._state = TrainingState(
        params=initial_params,
        target_params=initial_target_params,
        opt_state=initial_opt_state,
        step=0)
    self._sgd_step = sgd_step
    self._forward = jax.jit(network.apply)
    self._replay = replay.Replay(capacity=replay_capacity)

    # Store hyperparameters.
    self._num_actions = action_spec.num_values
    self._batch_size = batch_size
    self._sgd_period = sgd_period
    self._target_update_period = target_update_period
    self._epsilon = epsilon
    self._total_steps = 0
    self._min_replay_size = min_replay_size

  def select_action(self, timestep: dm_env.TimeStep) -> base.Action:
    """Selects actions according to an epsilon-greedy policy."""
    # Add a leading batch axis of size 1 before calling the network.
    observation = timestep.observation[None, ...]
    q_values = self._forward(self._state.params, observation)
    action = epsilon_greedy(q_values, self._num_actions, self._epsilon)
    return int(action)

  def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Adds transition to replay and periodically does SGD.

    Args:
      timestep: the timestep in which `action` was selected.
      action: the action that was taken.
      new_timestep: the timestep returned by the environment for `action`.
    """
    # Add this transition to replay. Reward and discount are stored as
    # float32, as in BootstrappedDqn and matching JAX's default precision.
    self._replay.add([
        timestep.observation,
        action,
        np.float32(new_timestep.reward),
        np.float32(new_timestep.discount),
        new_timestep.observation,
    ])

    self._total_steps += 1
    if self._total_steps % self._sgd_period != 0:
      return

    if self._replay.size < self._min_replay_size:
      return

    # Do a batch of SGD.
    transitions = self._replay.sample(self._batch_size)
    self._state = self._sgd_step(self._state, transitions)

    # Periodically update target parameters.
    if self._state.step % self._target_update_period == 0:
      self._state = self._state._replace(target_params=self._state.params)


def dqn_agent(obs_spec: specs.Array,
              action_spec: specs.DiscreteArray,
              seed: int = 0) -> base.Agent:
  """Create a DQN agent with this notebook's default hyperparameters.

  Args:
    obs_spec: observation spec of the environment.
    action_spec: discrete action spec; sets the number of Q-value outputs.
    seed: seeds NumPy's global RNG and the Haiku PRNG sequence handed to the
      agent.

  Returns:
    A `DQN` instance.
  """
  np.random.seed(seed)
  num_actions = action_spec.num_values

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    # Flatten the observation, then two hidden layers of 50 units and one
    # output per action.
    flattened = hk.Flatten()(inputs)
    return hk.nets.MLP([50, 50, num_actions])(flattened)

  adam = optax.adam(1e-3)
  key_sequence = hk.PRNGSequence(seed)
  return DQN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=adam,
      # Replay and SGD schedule.
      batch_size=128,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      # Return and exploration.
      discount=0.99,
      epsilon=0.2,
      rng=key_sequence,
  )

### Bootstrapped DQN agent code

**You don't need to modify this part of the code**

In [None]:
class TrainingState(NamedTuple):
  """Training state of one Bootstrapped DQN ensemble member.

  Declares the same fields as the TrainingState defined for DQN above and
  rebinds the name.
  """
  params: hk.Params  # online Q-network parameters of this ensemble member
  target_params: hk.Params  # target Q-network parameters of this member
  opt_state: Any  # optimizer state for this member's parameters
  step: int  # number of SGD steps this member has taken


class BootstrappedDqn(base.Agent):
  """Bootstrapped DQN with randomized prior functions.

  The agent keeps an ensemble of `num_ensemble` Q-networks, each with its own
  target parameters and optimizer state. At the end of every episode one
  member is picked uniformly at random and used to act (Thompson sampling).
  Every transition is stored together with a per-member bootstrap mask and
  per-member reward noise.
  """

  def __init__(
      self,
      obs_spec: specs.Array,
      action_spec: specs.DiscreteArray,
      network: Callable[[jnp.ndarray], jnp.ndarray],
      num_ensemble: int,
      batch_size: int,
      discount: float,
      replay_capacity: int,
      min_replay_size: int,
      sgd_period: int,
      target_update_period: int,
      optimizer: optax.GradientTransformation,
      mask_prob: float,
      noise_scale: float,
      epsilon_fn: Callable[[int], float] = lambda _: 0.,
      seed: int = 1,
  ):
    """Builds the ensemble, its optimizer states and the replay buffer.

    Args:
      obs_spec: spec of the observations; its shape sizes the dummy input
        used to initialize the networks.
      action_spec: spec of the discrete action space.
      network: Haiku network function mapping a batch of observations to one
        Q-value per action.
      num_ensemble: number of Q-networks in the ensemble.
      batch_size: number of transitions sampled from replay per SGD step.
      discount: discount factor applied to the bootstrap value.
      replay_capacity: maximum number of transitions kept in replay.
      min_replay_size: no SGD happens until replay holds this many items.
      sgd_period: SGD runs when the total step count is a multiple of this.
      target_update_period: a member's target parameters are overwritten with
        its online parameters when its step count is a multiple of this.
      optimizer: optax optimizer; one state is kept per ensemble member.
      mask_prob: probability that a transition is included in a member's
        loss (Bernoulli bootstrap mask).
      noise_scale: scale of the Gaussian noise added to rewards in the loss.
      epsilon_fn: maps the total number of steps to the exploration epsilon.
      seed: seed of the Haiku PRNG sequence used for initialization.
    """
    self.name = 'boot dqn'
    # Transform the (impure) network into a pure function. Recent Haiku
    # releases deprecate the `apply_rng` argument of `hk.transform`, so it is
    # not passed; `without_apply_rng` removes the unused rng from `apply`.
    network = hk.without_apply_rng(hk.transform(network))

    # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
    def loss(params: hk.Params, target_params: hk.Params,
             transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
      """Q-learning loss with added reward noise + half-in bootstrap."""
      o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
      q_tm1 = network.apply(params, o_tm1)
      # The target network is not trained: block gradients through it, as in
      # the DQN loss.
      q_t = jax.lax.stop_gradient(network.apply(target_params, o_t))
      # Perturb rewards with this member's noise sample.
      r_t = r_t + noise_scale * z_t
      batch_q_learning = jax.vmap(sample_td_error)
      td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
      # The mask zeroes out transitions this member is not trained on.
      return jnp.mean(m_t * td_error**2)

    # Define update function for each member of the ensemble.
    @jax.jit
    def sgd_step(state: TrainingState,
                 transitions: Sequence[jnp.ndarray]) -> TrainingState:
      """Does a step of SGD for one ensemble member over `transitions`."""

      gradients = jax.grad(loss)(state.params, state.target_params, transitions)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      return TrainingState(
          params=new_params,
          target_params=state.target_params,
          opt_state=new_opt_state,
          step=state.step + 1)

    # Initialize parameters and optimizer state for an ensemble of Q-networks.
    rng = hk.PRNGSequence(seed)
    dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
    initial_params = [
        network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
    ]
    initial_target_params = [
        network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
    ]
    initial_opt_state = [optimizer.init(p) for p in initial_params]

    # Internalize state.
    self._ensemble = [
        TrainingState(p, tp, o, step=0) for p, tp, o in zip(
            initial_params, initial_target_params, initial_opt_state)
    ]
    self._forward = jax.jit(network.apply)
    self._sgd_step = sgd_step
    self._num_ensemble = num_ensemble
    self._optimizer = optimizer
    self._replay = replay.Replay(capacity=replay_capacity)

    # Agent hyperparameters.
    self._num_actions = action_spec.num_values
    self._batch_size = batch_size
    self._sgd_period = sgd_period
    self._target_update_period = target_update_period
    self._min_replay_size = min_replay_size
    self._epsilon_fn = epsilon_fn
    self._mask_prob = mask_prob

    # Agent state.
    self._active_head = self._ensemble[0]
    self._total_steps = 0

  def select_action(self, timestep: dm_env.TimeStep) -> base.Action:
    """Select values via Thompson sampling, then use epsilon-greedy policy."""
    self._total_steps += 1
    # Add a leading batch axis of size 1 before calling the network.
    batched_obs = timestep.observation[None, ...]
    q_values = self._forward(self._active_head.params, batched_obs)
    action = epsilon_greedy(q_values, self._num_actions,
                            self._epsilon_fn(self._total_steps))
    return int(action)

  def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Update the agent: add transition to replay and periodically do SGD.

    Args:
      timestep: the timestep in which `action` was selected.
      action: the action that was taken.
      new_timestep: the timestep returned by the environment for `action`.
    """

    # Thompson sampling: every episode pick a new Q-network as the policy.
    # `_active_head` keeps the TrainingState value chosen here. Later SGD
    # steps replace entries of `_ensemble` but do not refresh this snapshot,
    # so the acting parameters stay fixed until the next episode ends.
    if new_timestep.last():
      k = np.random.randint(self._num_ensemble)
      self._active_head = self._ensemble[k]

    # Generate bootstrapping mask & reward noise (one entry per member).
    mask = np.random.binomial(1, self._mask_prob, self._num_ensemble)
    noise = np.random.randn(self._num_ensemble)

    # Make transition and add to replay.
    transition = [
        timestep.observation,
        action,
        np.float32(new_timestep.reward),
        np.float32(new_timestep.discount),
        new_timestep.observation,
        mask,
        noise,
    ]
    self._replay.add(transition)

    if self._replay.size < self._min_replay_size:
      return

    # Periodically sample from replay and do SGD for the whole ensemble.
    # Every member sees the same batch but its own mask/noise column.
    if self._total_steps % self._sgd_period == 0:
      transitions = self._replay.sample(self._batch_size)
      o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
      for k, state in enumerate(self._ensemble):
        transitions = [o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]]
        self._ensemble[k] = self._sgd_step(state, transitions)

    # Periodically update target parameters. This is checked on every call
    # once replay is warm, not only right after an SGD step.
    for k, state in enumerate(self._ensemble):
      if state.step % self._target_update_period == 0:
        self._ensemble[k] = state._replace(target_params=state.params)


def boot_dqn_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 10,
) -> BootstrappedDqn:
  """Create a Bootstrapped DQN agent with this notebook's default settings.

  Each ensemble member is a trainable MLP Q-network. A second MLP acts as a
  fixed randomized prior: its output is scaled and added with gradients
  blocked.

  Args:
    obs_spec: observation spec of the environment.
    action_spec: discrete action spec; sets the number of Q-value outputs.
    seed: seeds NumPy's global RNG and is forwarded to the agent.
    num_ensemble: number of Q-networks in the ensemble.

  Returns:
    A `BootstrappedDqn` instance.
  """
  np.random.seed(seed)

  # Randomized-prior network settings.
  prior_scale = 5.
  hidden_sizes = [50, 50]
  num_actions = action_spec.num_values

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Trainable MLP plus a scaled, gradient-blocked prior MLP."""
    # Module construction and call order fix Haiku's module names and the
    # order in which parameters are initialized: `trainable` comes first.
    trainable = hk.nets.MLP([*hidden_sizes, num_actions])
    prior = hk.nets.MLP([*hidden_sizes, num_actions])
    flat = hk.Flatten()(inputs)
    q_values = trainable(flat)
    prior_values = jax.lax.stop_gradient(prior(flat))
    return q_values + prior_scale * prior_values

  adam = optax.adam(learning_rate=1e-3)
  return BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      num_ensemble=num_ensemble,
      optimizer=adam,
      # Replay and SGD schedule.
      batch_size=128,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      # Return, bootstrapping and exploration.
      discount=.99,
      mask_prob=1.,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  )

### Run the DQN and bootstrapped DQN agents on the Deep Sea environment of size (depth) 10

**You don't need to modify this part of the code**

In [None]:
# @title run the agent and collect data
num_simulations = 5
num_episodes = 250

deep_sea_sizes = [10]
AGENT_LIST = [dqn_agent, boot_dqn_agent]

# Print a progress line about five times per run. `max(1, ...)` keeps the
# interval positive, so the modulo below cannot divide by zero when
# num_episodes < 5.
print_every = max(1, num_episodes // 5)

results = []
for seed in range(num_simulations):
  for deep_sea_size in deep_sea_sizes:
    # One environment per (seed, size), reused by every agent in AGENT_LIST.
    env = bsuite.load('deep_sea', {'size': deep_sea_size,
                                   'seed': seed, 'mapping_seed': seed})
    for AGENT in AGENT_LIST:

      agent = AGENT(obs_spec=env.observation_spec(),
                    action_spec=env.action_spec(),
                    seed=seed)
      for episode in range(num_episodes):
        # Run an episode.
        timestep = env.reset()
        reward = 0
        while not timestep.last():
          # Generate an action from the agent's policy.
          action = agent.select_action(timestep)
          # Step the environment.
          new_timestep = env.step(action)
          # Tell the agent about what just happened.
          agent.update(timestep, action, new_timestep)
          # Book-keeping.
          timestep = new_timestep
          reward += timestep.reward

        # Regret is measured against an episode return of 0.99, which this
        # notebook treats as the optimal deep-sea return.
        result = {'episode': episode, 'reward': round(reward, 3),
                  'regret': round(0.99 - reward, 3), 'seed': seed,
                  'agent': agent.name, 'deep_sea_size': deep_sea_size}
        results.append(result)
        if episode % print_every == 0:
          print(result)
df = pd.DataFrame(results)

### Plot results

**You don't need to modify this part of the code**

In [None]:
# Average reward and regret over seeds for each (episode, agent, size).
# Several columns must be selected with a list after `groupby`. The
# tuple-style `['reward', 'regret']` form was deprecated in pandas 1.1 and
# raises a ValueError in pandas 2.x.
ave_df = (df.groupby(['episode', 'agent', 'deep_sea_size'])[['reward', 'regret']]
          .mean()
          .reset_index())

for deep_sea_size in ave_df.deep_sea_size.unique():
  plt.figure()
  for agent in ave_df.agent.unique():
    agent_df = ave_df[(ave_df.agent == agent) &
                      (ave_df.deep_sea_size == deep_sea_size)]
    # groupby sorts by its keys with 'episode' first, so rows are in episode
    # order and the cumulative sum is the regret accumulated so far.
    plt.plot(agent_df.episode, np.cumsum(agent_df.regret), label=agent)
  plt.xlabel(r'episode', fontsize=20)
  plt.ylabel('cumulative regret', fontsize=20)
  plt.title('Deep sea of size '+str(deep_sea_size))
  plt.legend(loc='best')
  plt.ylim([0, 200])
  plt.show()