> Copyright 2022 DeepMind Technologies Limited.
>
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
>
> You may obtain a copy of the License at
> https://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

# Reinforcement Learning
You may have already encountered **supervised learning**, where we have an input and a target value or class that we want to predict. There is also **unsupervised learning**, where we are only given an input and look for patterns in that input. In this practical, we look into **reinforcement learning** (RL), which can loosely be defined as training an **agent** to maximise a numerical **reward** it obtains through interaction with an **environment**.

The environment defines a set of **actions** that an agent can take. The agent observes the current **state** of the environment, tries actions, and *learns* a **policy** which is a distribution over the possible actions given a state of the environment.

The following diagram illustrates the interaction between the agent and the environment. We will explore each of the terms in more detail throughout this practical.

<center><img src="https://storage.googleapis.com/dm-educational/assets/reinforcement-learning-summer-school/rl_agent_environment.png" width="500" /></center>

# Imports and dependencies.
Please run the cells below.

In [None]:
#@title Install dependencies.
%%capture

!pip install bsuite
!pip install dm-haiku
!pip install optax

In [None]:
#@title Imports.

import collections
import random
from typing import Sequence

import chex
import dm_env
from dm_env import specs
import haiku as hk
import jax
import jax.numpy as jnp
from matplotlib import animation
from matplotlib import rc
import matplotlib.pyplot as plt
import numpy as np
import optax

rc('animation', html='jshtml')

## Catch Environment
Let's begin by making a very simple game called "Catch", which is often used as a test bed for RL algorithms.

In this environment, a ball drops from the top and the agent controls the paddle at the bottom via three possible <font color='#0175c2'>**Actions**</font>: `left`, `stay`, and `right`. The <font color='#00ba47'>**Reward**</font> is given at the end of the episode, and is either `+1` for catching the ball or `-1` for dropping the ball (reward is `0` in all intermediate steps). The episode ends when the ball reaches the bottom of the screen.

In [None]:
#@title Defining the Catch Environment

_ACTIONS = (-1, 0, 1)  # Move paddle left, no-op, move paddle right.

class Catch(dm_env.Environment):
  """A Catch environment built on the dm_env.Environment class.

  The agent must move a paddle to intercept falling balls. Falling balls only
  move downwards on the column they are in.

  The observation is an array with shape (rows, columns) containing the
  values: 0 if a space is empty; 1 if it contains the paddle; 2 if it contains
  the ball (the paddle takes precedence if the two overlap).

  The actions are discrete, and by default there are three available actions:
  move left, stay, and move right.

  The episode terminates when the ball reaches the bottom of the screen.
  """

  def __init__(self,
               rows: int = 10,
               columns: int = 5,
               discount: float = 1.0):
    """Initializes a new Catch environment.

    Args:
      rows: number of rows.
      columns: number of columns.
      discount: discount factor returned with each non-terminal timestep.
    """
    self._rows = rows
    self._columns = columns
    self._discount = discount
    self._board = np.zeros((rows, columns), dtype=np.float32)
    self._ball_x = None
    self._ball_y = None
    self._paddle_x = None
    self._reset_next_step = True

  def reset(self) -> dm_env.TimeStep:
    """Returns the first `TimeStep` of a new episode."""
    self._reset_next_step = False
    # Ball can drop from any column.
    self._ball_x = np.random.randint(self._columns)
    self._ball_y = 0  # Top of matrix.
    self._paddle_x = self._columns // 2  # Centre.

    return dm_env.restart(self._observation())

  def step(self, action: int) -> dm_env.TimeStep:
    """Updates the environment according to the action."""
    if self._reset_next_step:
      return self.reset()

    # Move the paddle.
    dx = _ACTIONS[action]  # Get action. dx = change in x position.
    # Clip to keep paddle in bounds of the environment matrix.
    self._paddle_x = np.clip(self._paddle_x + dx, 0, self._columns - 1)

    # Drop the ball down one row.
    self._ball_y += 1

    # Check for termination.
    if self._ball_y == self._rows - 1:  # Ball has reached the bottom row.
      # Reward depends on whether the paddle is on the ball (positions match).
      reward = 1. if self._paddle_x == self._ball_x else -1.
      self._reset_next_step = True
      return dm_env.termination(reward=reward, observation=self._observation())

    return dm_env.transition(reward=0., observation=self._observation(),
                             discount=self._discount)

  def observation_spec(self) -> specs.BoundedArray:
    """Returns the observation spec."""
    return specs.BoundedArray(
        shape=self._board.shape,
        dtype=self._board.dtype,
        name='board',
        minimum=0,
        maximum=2)

  def action_spec(self) -> specs.DiscreteArray:
    """Returns the action spec."""
    return specs.DiscreteArray(
        dtype=int, num_values=len(_ACTIONS), name='action')

  def _observation(self) -> np.ndarray:
    self._board.fill(0.)
    self._board[self._ball_y, self._ball_x] = 2.
    self._board[self._rows - 1, self._paddle_x] = 1.

    return self._board.copy()



Before we start building an agent to interact with this environment, let's first look at the types of objects the environment returns (observations) and consumes (actions). The environment's spec methods (`observation_spec`, `reward_spec` and `action_spec`) show the form of the *observations* and *rewards* that the environment exposes and the form of the *actions* that can be taken:


In [None]:
environment = Catch()
print(environment.observation_spec())
print(environment.reward_spec())
print(environment.action_spec())

We can see that by default the <font color='#ed005a'>**observations**</font> consist of a matrix of shape (10, 5) with values between 0 and 2. You can change the size of the game by setting the `rows` and `columns` parameters to different values. An <font color='#0175c2'>**action**</font> is a single integer with possible values [0, 1, 2]. Finally, the <font color='#00ba47'>**reward**</font> is a scalar.

Now we want to take an action using the `step` method to interact with the environment, which will return a `TimeStep` namedtuple with fields:

```none
step_type, reward, discount, observation
```

In [None]:
# Reset and initialise the environment.
step_type, reward, discount, observation = environment.reset()
plt.imshow(observation)
plt.show()

# Let's take a single action.
step_type, reward, discount, observation = environment.step(0)
plt.imshow(observation)
plt.show()

print('\nstep_type:', step_type)
print('reward:', reward)
print('discount:', discount)

`observation`: our observations are zero everywhere except at the paddle, which is marked with 1, and the ball, which is marked with 2.

`step_type`: indicates whether we're at the beginning, middle, or end of the episode. For more details, look [here](https://github.com/deepmind/dm_env/blob/master/dm_env/_environment.py#L32).

`reward`: is the reward returned by the environment.

`discount`: is the discount factor $\gamma$. More details on what this is can be found in the *Agent->Run Loop* section below.

## The Agent
We now turn to the agent. An agent receives the current <font color='#ed005a'>**state**</font> and (previous) <font color='#00ba47'>**reward**</font> from the environment, then uses an internal policy to determine an <font color='#0175c2'>**action**</font> to take. We implement the agent as a Python [**class**](https://en.wikibooks.org/wiki/A_Beginner%27s_Python_Tutorial/Classes), which is just a logical wrapper of variables and methods (functions) that operate on those variables. The methods our agent will have are the following:


* `__init__`: Initialises the agent when it is first created.
* `actor_step`: Receives the timestep information from the environment and returns an action.



# Random Agent

To get a feel for an agent and the methods it has, let's first implement an agent that ignores the observations and just takes a *random* action at every step.

The main piece of information we need in order to implement this agent is the number of available actions in this environment. We can get this information from the `num_values` attribute of an action spec `env.action_spec()`:

In [None]:
class RandomAgent(object):
  """An agent which simply takes random actions."""

  def __init__(self,
               action_spec: specs.DiscreteArray):
    self._num_actions = action_spec.num_values

  def actor_step(self, timestep: dm_env.TimeStep):
    # This agent is ignoring the observations, so we delete timestep.
    del timestep
    # Return a random integer between 0 and (self._num_actions - 1).
    return np.random.randint(self._num_actions)

### Run Loop

Now we can loop through the environment using our random agent until the environment terminates. We call each sequence of interactions with the environment, from reset until termination, an **episode**.

We also calculate the **episode return** at the end of an episode. The episode return is the sum of the (discounted) rewards obtained during the episode. If the reward for episode $i$ at time-step $t$ with trajectory $\tau_{i}$ is denoted $\color{#00ba47}{r_{i, t}}$, and the **discount factor** is $\gamma$, then the episode return is calculated as:


$$\color{#00ba47}{r}(\tau_i) = \sum_{t=1}^{T_i} \gamma^{t-1} \color{#00ba47}{r_{i,t}}$$

where $T_i$ is the episode length.

The discount factor allows us to increase the importance of rewards received quickly and decrease the importance of rewards that take long to receive. It is especially important in environments that could have infinitely long episodes. In our particular environment where every episode is of the same length and the only non-zero reward is received at the end of the game, the discount factor doesn't make much difference and so we will ignore it (effectively set it to $1$) for now.
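As a rough illustration of the return calculation, here is a minimal sketch that works on a plain list of rewards. The function name `discounted_return` is made up for this example, and it assumes the first reward of the episode is left undiscounted.

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], discount: float = 1.0) -> float:
  """Sums rewards, weighting the k-th reward (k = 0, 1, ...) by discount**k."""
  total = 0.0
  # Work backwards through the episode: G_k = r_k + discount * G_{k+1}.
  for reward in reversed(rewards):
    total = reward + discount * total
  return total


# A default-sized Catch episode has 9 steps: zero reward until the final catch.
rewards = [0.0] * 8 + [1.0]
print(discounted_return(rewards, discount=1.0))
print(discounted_return(rewards, discount=0.9))
```

With a discount of 1 the return is just the final reward, while with a discount of 0.9 the single reward at the end is scaled by $0.9^8$.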

In [None]:
num_episodes = 10

# Initialise the environment.
env = Catch()
timestep = env.reset()

# Initialise the agent.
agent = RandomAgent(env.action_spec())

# Run loop.
for episode in range(num_episodes):
  timesteps = []  # Accumulate data for the episode.

  # Prepare agent, environment and accumulator for a new episode.
  timestep = env.reset()

  while not timestep.last():
    timesteps.append(timestep)
    action = agent.actor_step(timestep)  # Acting.
    timestep = env.step(action)  # Agent-environment interaction.

  # Save the last timestep too.
  timesteps.append(timestep)

  # The first timestep is ignored because its reward is `None`.
  returns = sum([item.reward for item in timesteps[1:]])
  print(f'Episode {episode:2d}: Returns: {returns:.2f}.')

### Visualisation

Let's build a function to animate the observations.

In [None]:
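def animate(data):
  fig = plt.figure(1)
  img = plt.imshow(data[0])
  plt.axis('off')

  def update(i):
    # Show the i-th frame.
    img.set_data(data[i])
    return (img,)

  # Use len(data), not a global variable, so any list of frames can be shown.
  anim = animation.FuncAnimation(fig, update, frames=len(data))
  plt.close(1)
  return anim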

We can now look at the game play for the last episode:

In [None]:
animate([item.observation for item in timesteps])

### **[Coding Task]**
We assumed that $\gamma = 1$ in the return calculation in the code above. Modify the code so that it computes the episode return for discount factors other than 1.


# Value-Based Reinforcement Learning

Unsurprisingly, our random agent is not very good at this game, so we need it to learn.


In **value-based** reinforcement learning methods, agents maintain a value for all state-action pairs and use those estimates to choose actions that maximise that value (instead of maintaining a policy directly like policy gradient methods, which we will cover later).

We represent the function mapping state-action pairs to values (otherwise known as a **Q-function**) for a specific policy $\pi$ in a given [MDP](https://en.wikipedia.org/wiki/Markov_decision_process) as:

$$ Q^{\pi}(\color{#ed005a}{s},\color{#0175c2}{a}) = \mathbb{E}_{\tau \sim P^{\pi}} \left[ \sum_t \gamma^t \color{#00ba47}{R_t} \;\middle|\; \color{#ed005a}{s_0}=\color{#ed005a}{s}, \color{#0175c2}{a_0}=\color{#0175c2}{a} \right]$$

where $\tau = \{\color{#ed005a}{s_0}, \color{#0175c2}{a_0}, \color{#00ba47}{r_0}, \color{#ed005a}{s_1}, \color{#0175c2}{a_1}, \color{#00ba47}{r_1}, \cdots \}$. In other words, $Q^{\pi}(\color{#ed005a}{s},\color{#0175c2}{a})$ is the expected **value** (sum of discounted rewards) of being in a given <font color='#ed005a'>**state**</font> $\color{#ed005a}s$ and taking the <font color='#0175c2'>**action**</font> $\color{#0175c2}a$ and then following policy ${\pi}$ thereafter.

Efficient value estimation is based on the famous **_Bellman equation_**, written here for a fixed policy $\pi$ (the Bellman expectation equation):

$$ Q^\pi(\color{#ed005a}{s},\color{#0175c2}{a}) =  \color{#00ba47}{r}(\color{#ed005a}{s},\color{#0175c2}{a}) + \gamma  \sum_{\color{#ed005a}{s'}\in \color{#ed005a}{\mathcal{S}}} P(\color{#ed005a}{s'} |\color{#ed005a}{s},\color{#0175c2}{a}) V^\pi(\color{#ed005a}{s'}) $$

which breaks down $Q^{\pi}(\color{#ed005a}{s},\color{#0175c2}{a})$ into two parts: the immediate reward associated with being in state $\color{#ed005a}{s}$ and taking action $\color{#0175c2}{a}$, and the discounted expected value of the next state, which summarises all future rewards. Note that $V^\pi$ here is the expected $Q^\pi$ value for a particular state under the policy, i.e.

$$V^\pi(\color{#ed005a}{s}) = \sum_{\color{#0175c2}{a} \in \color{#0175c2}{\mathcal{A}}} \pi(\color{#0175c2}{a} |\color{#ed005a}{s}) Q^\pi(\color{#ed005a}{s},\color{#0175c2}{a})$$
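To see the two equations above working together, here is a small sketch that repeatedly applies them to a made-up MDP until the values stop changing. The arrays `P`, `r` and `pi` and all of their numbers are hypothetical and chosen only for illustration.

```python
import numpy as np

gamma = 0.9
# Hypothetical 2-state, 2-action MDP (all numbers are made up).
# P[s, a, s'] is the transition probability, r[s, a] the expected reward.
P = np.array([[[0.8, 0.2], [0.1, 0.9]],
              [[0.5, 0.5], [0.0, 1.0]]])
r = np.array([[1.0, 0.0],
              [0.0, 2.0]])
# pi[s, a] is the probability of taking action a in state s.
pi = np.array([[0.5, 0.5],
               [0.2, 0.8]])

q = np.zeros((2, 2))
for _ in range(1000):
  v = np.sum(pi * q, axis=1)  # V(s) = sum_a pi(a|s) Q(s, a).
  q = r + gamma * P @ v       # Q(s, a) = r(s, a) + gamma * sum_s' P(s'|s, a) V(s').

v = np.sum(pi * q, axis=1)
print(q)
print(v)
```

Because $\gamma < 1$, each sweep shrinks the error, so after enough sweeps `q` and `v` satisfy both equations up to numerical precision.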

**Note**: If you have not previously encountered Reinforcement Learning theory, these definitions might seem a bit dense! We recommend the following resources for learning about reinforcement learning:
- [DeepMind x UCL RL Lecture Series](https://www.youtube.com/watch?v=TCCjZe0y4Qc)
- [Introduction to Reinforcement Learning with David Silver](https://www.deepmind.com/learning-resources/introduction-to-reinforcement-learning-with-david-silver)




## Q-learning Agent

One of the simplest forms of value-based learning is [Q-learning](https://en.wikipedia.org/wiki/Q-learning). To implement this, we are going to change the random agent as follows:

1. **Represent Q values.** We need to have a tabular representation of $Q$, which is a matrix of size `(number of states, number of actions)`. Our state space is the position of the ball and paddle in our grid so its size is $c*r*c$, where $r, c$ are the numbers of rows and columns, respectively. Our number of actions is 3 (move left, stay, move right).

2. **Implement a policy.** For now, we are going to implement a greedy policy that returns the action with the highest $Q$ value: $$\pi_{greedy}(\color{#ed005a}s) = \arg\max_{\color{#0175c2}a} Q(\color{#ed005a}s,\color{#0175c2}a)$$

3. **Implement a learning step.** We need to add a new method to our agent class to do the learning step, meaning update the $Q$ values based on a learning algorithm. We will call this new method  `learner_step`. The learning algorithm that we are going to use is called [temporal difference learning](https://en.wikipedia.org/wiki/Temporal_difference_learning). We will be updating our $Q$ value estimates at each step with the following update rule: 

$$Q(\color{#ed005a}s, \color{#0175c2}a) \gets Q(\color{#ed005a}s, \color{#0175c2}a) + \alpha \delta$$

The step size $\alpha$ (usually small) controls how quickly our $Q$ values are updated given new observations.

The measure of error, the TD-error $\delta$, is defined as:

$$\delta = \color{#00ba47}R + \gamma Q(\color{#ed005a}{s'}, \underbrace{\pi_{greedy}(\color{#ed005a}{s'})}_{\color{#0175c2}{a'}}) - Q(\color{#ed005a}s, \color{#0175c2}a)$$
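A single application of this update rule can be worked through by hand. The sketch below uses a made-up table with 4 states and 3 actions and one invented transition; none of the numbers come from the Catch environment.

```python
import numpy as np

step_size = 0.1  # alpha
discount = 1.0   # gamma

# Hypothetical Q table with 4 states and 3 actions, initialised to zero.
q = np.zeros((4, 3))
q[2] = [0.0, 0.5, -0.2]  # Pretend we already have estimates for state 2.

s, a, reward, s_next = 1, 0, 0.0, 2  # One observed transition.

a_next = np.argmax(q[s_next])  # Greedy action in the next state.
td_error = reward + discount * q[s_next, a_next] - q[s, a]
q[s, a] += step_size * td_error

print(td_error)
print(q[s, a])
```

Here the TD error is 0.5, so the estimate for the visited state-action pair moves a tenth of the way from 0 towards the target.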


In [None]:
class QlearningAgent(object):
  """Q-learning agent."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               observation_spec: specs.BoundedArray,
               step_size: float = 0.1):
    self._num_actions = action_spec.num_values
    self._step_size = step_size
    r, c = observation_spec.shape
    self._q = np.zeros((c * r * c, self._num_actions))

  def _obs_to_index(self, obs):
    """Convert the observation into an index for accessing q values."""
    # The paddle location is always at the bottom.
    obs_shape = obs.shape
    paddle = np.where(obs[-1, :].flatten() == 1)[0][0]
    obs = obs.flatten().astype(int)
    # Case where the ball and paddle overlap.
    if obs.sum() == 1:
      ball = (obs_shape[0] - 1) * obs_shape[1]  + paddle
    else:
      ball = np.where(obs == 2)[0][0]
    return paddle * np.prod(obs_shape) + ball

  def actor_step(self, timestep):
    # Index into the Q value matrix.
    qvalue = self._q[self._obs_to_index(timestep.observation)]
    # Greedy policy.
    return np.argmax(qvalue)

  def learner_step(self, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    # Q-learning update (off-policy: the target uses the greedy action).
    obs_t = self._obs_to_index(obs_t)
    obs_tm1 = self._obs_to_index(obs_tm1)
    # Greedy policy.
    a_t = np.argmax(self._q[obs_t])
    td_error = r_t + discount_t * self._q[obs_t, a_t] - self._q[obs_tm1, a_tm1]
    self._q[obs_tm1, a_tm1] += self._step_size * td_error
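
To make the layout of the Q table concrete, the sketch below re-implements the indexing scheme of `_obs_to_index` as a free function and applies it to a hand-built board. The function name `obs_to_index` and the example board are assumptions made for illustration only.

```python
import numpy as np


def obs_to_index(obs: np.ndarray) -> int:
  """Same indexing scheme as `_obs_to_index`, written as a free function."""
  rows, columns = obs.shape
  paddle = int(np.where(obs[-1] == 1)[0][0])
  flat = obs.flatten().astype(int)
  if flat.sum() == 1:
    # Only the paddle is visible: the ball is hidden underneath it.
    ball = (rows - 1) * columns + paddle
  else:
    ball = int(np.where(flat == 2)[0][0])
  return paddle * rows * columns + ball


rows, columns = 10, 5
obs = np.zeros((rows, columns), dtype=np.float32)
obs[0, 3] = 2.   # Ball in the top row, column 3.
obs[-1, 2] = 1.  # Paddle in the bottom row, column 2.

print(obs_to_index(obs))         # paddle * 50 + ball = 2 * 50 + 3
print(columns * rows * columns)  # Number of rows in the Q table.
```

Each paddle position owns a contiguous block of `rows * columns` entries, one per possible ball position, which is why the table needs $c \times r \times c$ rows.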


To compute the TD error, we need information from both the current and the previous timestep. We could keep this in local variables, but it is cleaner to define a new class to handle the data. We call this new class `TransitionAccumulator`.

At each timestep, we will save the data using the `push` method and retrieve data using the `sample` method. The `sample` method returns the previous observation in addition to the data for the current timestep:

In [None]:
Transition = collections.namedtuple(
    'Transition', 'obs_tm1 a_tm1 r_t discount_t obs_t')

# For now, this only handles batch_size=1 but you can modify the code
# to handle other batch sizes.
class TransitionAccumulator:
  """Simple Python accumulator for transitions."""

  def __init__(self):
    self._prev = None
    self._action = None
    self._latest = None

  def push(self, env_output, action):
    self._prev = self._latest
    self._action = action
    self._latest = env_output

  def sample(self, batch_size):
    assert batch_size == 1
    return Transition(self._prev.observation, self._action, self._latest.reward,
                      self._latest.discount, self._latest.observation)

  def is_ready(self, batch_size):
    """Checks if there is previous data stored."""
    assert batch_size == 1
    return self._prev is not None

### Run Loop

There is one final addition that we need to make to finish our run loop: a learning step.

Since we are now training the agent using actual observations, we will need to run the environment for more episodes in order to gather sufficient data. We will only evaluate occasionally, every `evaluate_every` episodes, to reduce the amount of logging info and computation (the latter point mostly applies later on when the agent's `actor_step` differs between train and evaluation time).

In [None]:
batch_size = 1
train_episodes = 100
evaluate_every = 10
eval_episodes = 10
seed = 1221

# Initialise the environment.
env = Catch()
timestep = env.reset()

# Build and initialise the agent.
agent = QlearningAgent(env.action_spec(),
                       env.observation_spec())

# Initialise the accumulator.
accumulator = TransitionAccumulator()

# Run loop
avg_returns = []

for episode in range(train_episodes):
  # Prepare agent, environment and accumulator for a new episode.
  timestep = env.reset()
  accumulator.push(timestep, None)
  while not timestep.last():
    # Acting.
    action = agent.actor_step(timestep)
    # Agent-environment interaction.
    timestep = env.step(action)
    # Accumulate experience.
    accumulator.push(timestep, action)
    # Learning.
    if accumulator.is_ready(batch_size):
      agent.learner_step(*accumulator.sample(batch_size))
  # Evaluation.
  if not episode % evaluate_every:
    returns = []
    for _ in range(eval_episodes):
      timestep = env.reset()
      timesteps = [timestep]
      while not timestep.last():
        action = agent.actor_step(timestep)
        timestep = env.step(action)
        timesteps.append(timestep)
      returns.append(np.sum([item.reward for item in timesteps[1:]]))

    avg_returns.append(np.mean(returns))
    print(f'Episode {episode:4d}: Average returns: {avg_returns[-1]:.2f}.')

Let's animate the last episode again:

In [None]:
animate([item.observation for item in timesteps])

Hmmm, looks like the agent could use a bit more work!

### **[Coding Task]**
A greedy policy with respect to a given estimate of $Q^\pi$ fails to explore the environment as needed. An $\epsilon$-greedy policy is a simple policy that at each time-step with probability $\epsilon$ will choose a random action instead of the greedy action. Update the QlearningAgent's policy to an $\epsilon$-greedy policy.


In [None]:
class EGQlearningAgent(object):
  """Epsilon-greedy Q learning agent."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               observation_spec: specs.BoundedArray,
               epsilon: float = 0.1,
               step_size: float = 0.1):
    self._num_actions = action_spec.num_values
    self._epsilon = epsilon
    self._step_size = step_size
    r, c = observation_spec.shape
    self._q = np.zeros((c * r * c, self._num_actions))

  def _obs_to_index(self, obs):
    """Convert the observation into an index for accessing q values."""
    # The paddle location is always at the bottom.
    obs_shape = obs.shape
    paddle = np.where(obs[-1, :].flatten() == 1)[0][0]
    obs = obs.flatten().astype(int)
    # Case where the ball and paddle overlap.
    if obs.sum() == 1:
      ball = (obs_shape[0] - 1) * obs_shape[1]  + paddle
    else:
      ball = np.where(obs == 2)[0][0]
    return paddle * np.prod(obs_shape) + ball

  def actor_step(self, timestep, evaluation):
    # Index into the Q value matrix.
    qvalue = self._q[self._obs_to_index(timestep.observation)]
    # Epsilon-greedy policy.
    if np.random.random() > self._epsilon:
      train_a = np.argmax(qvalue)
    else:
      train_a = np.random.choice(self._num_actions)
    if evaluation:
      return np.argmax(qvalue)
    else:
      return train_a

  def learner_step(self, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    # Q-learning update (off-policy: the target uses the greedy maximum).
    obs_t = self._obs_to_index(obs_t)
    obs_tm1 = self._obs_to_index(obs_tm1)
    td_error = r_t + discount_t * np.max(self._q[obs_t]) - self._q[obs_tm1, a_tm1]
    self._q[obs_tm1, a_tm1] += self._step_size * td_error

In [None]:
# @title **[Solution]** Epsilon-greedy run-loop { form-width: "30%" }

batch_size = 1
train_episodes = 100
evaluate_every = 10
eval_episodes = 10
seed = 1221

# Initialise the environment.
env = Catch(5, 2)  # Smaller environment to have smaller state.
timestep = env.reset()

# Build and initialise the agent.
agent = EGQlearningAgent(env.action_spec(),
                         env.observation_spec())

# Initialise the accumulator.
accumulator = TransitionAccumulator()

# Run loop.
avg_returns = []

for episode in range(train_episodes):

  # Prepare agent, environment and accumulator for a new episode.
  timestep = env.reset()
  accumulator.push(timestep, None)

  while not timestep.last():
    # Acting.
    action = agent.actor_step(timestep, False)
    # Agent-environment interaction.
    timestep = env.step(action)
    # Accumulate experience.
    accumulator.push(timestep, action)

    # Learning.
    if accumulator.is_ready(batch_size):
      agent.learner_step(*accumulator.sample(batch_size))

  # Evaluation.
  if not episode % evaluate_every:
    returns = []
    for _ in range(eval_episodes):
      timestep = env.reset()
      timesteps = [timestep]
      while not timestep.last():
        action = agent.actor_step(timestep, True)
        timestep = env.step(action)
        timesteps.append(timestep)
      returns.append(np.sum([item.reward for item in timesteps[1:]]))

    avg_returns.append(np.mean(returns))
    print(f"Episode {episode:4d}: Average returns: {avg_returns[-1]:.2f}.")

### **[Coding Task]**

Compare the result of greedy and epsilon greedy Q-Learning. Which one learns faster?

# Deep Reinforcement Learning

So far, we have only considered look-up tables: in all the previous cases, every state-action pair $(\color{#ed005a}{s}, \color{#0175c2}{a})$ had an entry in our Q table. This is possible in this environment because the number of states is quite small. But it does not scale to situations where, say, the goal location changes or the obstacles are in different locations in every episode (consider how big the table would need to be in this situation).

An example (not covered in this tutorial) is playing Atari from pixels, where the number of possible frames an agent can see is exponential in the number of pixels on the screen.

<center><img width="200" src="https://storage.googleapis.com/dm-educational/assets/reinforcement-learning-summer-school/atari.gif"></center>

But what we **really** want is just to be able to *compute* the Q-value for a particular $(\color{#ed005a}{s}, \color{#0175c2}{a})$ pair. So if we had a function to do this work instead of keeping a big table, we'd get around this problem.

To address this, we can use **function approximation** as a way to generalise Q-values over some representation of a very large state space, and **train** our approximator to get it to output accurate Q-value estimates. In this section, we will explore Q-learning with function approximation, which, although theoretically proven to diverge for some degenerate MDPs, can yield impressive results in very large environments. [Playing Atari with Deep Reinforcement Learning](https://deepmind.com/research/publications/playing-atari-deep-reinforcement-learning)  introduced the first deep learning model to successfully learn control policies directly from high-dimensional pixel inputs using RL, and we're going to implement a simplified version of that agent here!

<center><img src="https://storage.googleapis.com/dm-educational/assets/reinforcement-learning-summer-school/dqn.jpeg" width="500" /></center>

We will predict $Q(\color{#ed005a}s, \color{#0175c2}a)$ using a neural network $f()$, which given a vector $\color{#ed005a}s$, will output a vector of Q-values for all possible actions $\color{#0175c2}a$.
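The shape of such a function can be sketched in plain NumPy before bringing in a neural network library. The parameter names, layer sizes and random weights below are assumptions for illustration; this is not the network used later in the practical.

```python
import numpy as np

rng = np.random.default_rng(0)
rows, columns, num_actions, num_hidden = 10, 5, 3, 50

# Randomly initialised weights stand in for trained parameters.
params = {
    'w1': rng.normal(scale=0.1, size=(rows * columns, num_hidden)),
    'b1': np.zeros(num_hidden),
    'w2': rng.normal(scale=0.1, size=(num_hidden, num_actions)),
    'b2': np.zeros(num_actions),
}


def q_network(params, obs):
  """Maps one observation to a vector of Q-values, one per action."""
  x = obs.reshape(-1)  # Flatten the board into a vector.
  hidden = np.maximum(x @ params['w1'] + params['b1'], 0.)  # ReLU layer.
  return hidden @ params['w2'] + params['b2']


obs = np.zeros((rows, columns), dtype=np.float32)
obs[0, 3] = 2.   # Ball.
obs[-1, 2] = 1.  # Paddle.
q_values = q_network(params, obs)
print(q_values.shape)       # (3,)
print(np.argmax(q_values))  # Greedy action under the random weights.
```

The number of parameters depends on the size of the board and the hidden layer, not on the number of distinct states, which is what lets this approach scale where a table cannot.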

When using function approximations, particularly with neural networks, we need to have a loss to optimise. But looking back at the tabular setting above, you can see that we already have some notion of error: the **TD error**.

By training our neural network to output values such that the *TD error is minimised*, we push it towards satisfying the Bellman optimality equation (the Bellman equation above with the average over next actions replaced by a maximum), which the optimal Q-function satisfies; acting greedily with respect to that Q-function then gives an optimal policy.
Thanks to automatic differentiation, we can just write the TD error as a loss (e.g. with an $L2$ loss, but others would work too), compute its gradient with respect to the individual parameters of the neural network, and slowly improve our Q-value approximation:

$$Loss = \mathbb{E}\left[ \left( \color{#00ba47}{r} + \gamma \max_{\color{#0175c2}{a'}} Q(\color{#ed005a}{s'}, \color{#0175c2}{a'}) - Q(\color{#ed005a}{s}, \color{#0175c2}{a})  \right)^2\right]$$
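To show what one gradient step on this loss does, here is a minimal sketch with the gradient written out by hand. It swaps the neural network for a linear Q-function so that no automatic differentiation is needed; the features, the transition and the hyperparameters are all made up.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, num_actions = 4, 3
learning_rate, discount = 0.1, 0.9

# Linear Q-function: Q(s, a) = w[a] . s (a stand-in for the neural network).
w = rng.normal(scale=0.1, size=(num_actions, num_features))

# One made-up transition (s, a, r, s').
s = np.array([1.0, 0.0, 0.5, 0.0])
a = 1
reward = 1.0
s_next = np.array([0.0, 1.0, 0.0, 0.5])

# The target is computed first and then held fixed (no gradient flows into it).
target = reward + discount * np.max(w @ s_next)
td_error = target - w[a] @ s
loss_before = 0.5 * td_error ** 2

# d(loss)/d(w[a]) = -td_error * s, so gradient descent adds td_error * s.
w[a] += learning_rate * td_error * s

loss_after = 0.5 * (target - w[a] @ s) ** 2
print(loss_before, loss_after)
```

Treating the target as a constant is a design choice (often called a semi-gradient update): only the prediction $Q(s, a)$ is moved towards the target, rather than moving both towards each other.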



## Neural Net-Based Q-Learning Agent

For our function approximator, we're going to use an [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) that takes the observation and outputs Q values for each of the actions. We construct the MLP inside the `__init__` function. We are going to use [JAX](https://github.com/google/jax) and [Haiku](https://github.com/deepmind/dm-haiku) to implement and train our neural nets. Please have a look [here](https://dm-haiku.readthedocs.io/en/latest/api.html) to understand Haiku transformations.

One other class method we need to add to our agent is `initial_params`, which initialises the parameters of the neural network:

In [None]:
class QlearningAgent(object):
  """Q-learning agent."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               observation_spec: specs.BoundedArray,
               num_hiddens: Sequence[int] = [50],
               epsilon: float = 0.01,
               learning_rate: float = 0.005):
    self._observation_spec = observation_spec
    self._num_actions = action_spec.num_values
    self._epsilon = epsilon
    self._optimizer = optax.adam(learning_rate)

    def network(obs):
      """Q network of the agent."""
      flatten = lambda x: jnp.reshape(x, (-1,))
      mlp = hk.Sequential(
          [flatten,
           hk.nets.MLP(list(num_hiddens) + [self._num_actions])])
      return mlp(obs)

    self._network = hk.without_apply_rng(hk.transform(network))
    # Jitting for speed. (`learner_step` is added, and jitted, in the next
    # version of this class.)
    self.actor_step = jax.jit(self.actor_step)

  def initial_params(self, rng_key):
    """Initialises the agent params given the RNG key."""
    sample_input = self._observation_spec.generate_value()
    sample_input = jnp.expand_dims(sample_input, 0)
    return self._network.init(rng_key, sample_input)

  def initial_learner_state(self, params):
    return self._optimizer.init(params)

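  def actor_step(self, params, timestep, rng_key, evaluation):
    """Given the observation, computes the action using epsilon-greedy algorithm."""
    qvalues = self._network.apply(params, timestep.observation)
    greedy_a = jnp.argmax(qvalues)
    # Draw the exploration decision from the JAX key: a NumPy random call here
    # would be evaluated only once, when the function is traced by jax.jit.
    explore_key, action_key = jax.random.split(rng_key)
    explore = jax.random.uniform(explore_key) < self._epsilon
    random_a = jax.random.randint(action_key, (), 0, self._num_actions)
    train_a = jnp.where(explore, random_a, greedy_a)

    # If evaluating, return the greedy action. Otherwise, return the
    # epsilon-greedy action.
    return jnp.where(evaluation, greedy_a, train_a)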

Now we have an agent that uses an MLP to compute Q-values. But in its current state, the MLP params are just initialised randomly and not changed at all. We need to add a TD-Learning algorithm to our agent.

`learner_step` will receive a collection of data that is collected from interacting with the environment using the `actor_step` function and then update the network parameters by computing the gradient of the loss function with respect to the network parameters.

We also need an optimiser to turn the gradients into parameter updates. We are going to use the [Adam optimizer](https://arxiv.org/abs/1412.6980), which is a very popular optimiser that usually requires little tuning:

In [None]:
class QlearningAgent(object):
  """Q-learning agent."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               observation_spec: specs.BoundedArray,
               num_hiddens: Sequence[int] = [50],
               epsilon: float = 0.01,
               learning_rate: float = 0.005):
    self._observation_spec = observation_spec
    self._num_actions = action_spec.num_values
    self._epsilon = epsilon
    self._optimizer = optax.adam(learning_rate)

    def network(obs):
      """Q network of the agent."""
      flatten = lambda x: jnp.reshape(x, (-1,))
      mlp = hk.Sequential(
          [flatten,
           hk.nets.MLP(list(num_hiddens) + [self._num_actions])])
      return mlp(obs)

    self._network = hk.without_apply_rng(hk.transform(network))
    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)

  def initial_params(self, rng_key):
    """Initialises the agent params given the RNG key."""
    sample_input = self._observation_spec.generate_value()
    sample_input = jnp.expand_dims(sample_input, 0)
    return self._network.init(rng_key, sample_input)

  def initial_learner_state(self, params):
    return self._optimizer.init(params)

  def actor_step(self, params, timestep, rng_key, evaluation):
    """Given the observation, computes the action using epsilon-greedy algorithm."""
    qvalues = self._network.apply(params, timestep.observation)
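    # Draw the exploration decision with JAX's RNG: actor_step is jitted, so a
    # NumPy draw would be evaluated only once, when the function is traced.
    explore_key, action_key = jax.random.split(rng_key)
    explore = jax.random.uniform(explore_key) < self._epsilon
    random_a = jax.random.choice(action_key, self._num_actions)
    train_a = jnp.where(explore, random_a, jnp.argmax(qvalues))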

    # If evaluating, return the greedy action. Otherwise, return the
    # epsilon-greedy action.
    return jax.lax.select(evaluation, jnp.argmax(qvalues), train_a)

  def learner_step(self, params: hk.Params, data, learner_state, rng_key):
    """Computes loss, its gradient w.r.t. params, and runs an optimisation step."""
    dloss_dtheta, loss = jax.grad(self._loss, has_aux=True)(params, *data)
    updates, learner_state = self._optimizer.update(
        dloss_dtheta, learner_state)
    params = optax.apply_updates(params, updates)
    return params, learner_state, loss

  def _loss(self, params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    """Computes the TD error loss."""
    q_tm1 = self._network.apply(params, obs_tm1)
    q_t = self._network.apply(params, obs_t)

    chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t], [1, 0, 0, 0, 1])
    chex.assert_type([q_tm1, a_tm1, r_t, discount_t, q_t],
                     [float, int, float, float, float])

    target_tm1 = r_t + discount_t * jnp.max(q_t)
    target_tm1 = jax.lax.stop_gradient(target_tm1)
    td_error = target_tm1 - q_tm1[a_tm1]
    loss = 0.5 * td_error ** 2
    return loss, loss
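
As a worked example of the TD loss, the sketch below evaluates it on hand-picked numbers. The standalone `td_loss` function and the toy Q-values are assumptions for illustration; they replace the network outputs with fixed arrays so the arithmetic can be checked by hand.

```python
import jax
import jax.numpy as jnp


def td_loss(q_tm1, a_tm1, r_t, discount_t, q_t):
  """Half squared TD error for a single transition."""
  target = r_t + discount_t * jnp.max(q_t)
  # The target is treated as a constant when differentiating.
  target = jax.lax.stop_gradient(target)
  td_error = target - q_tm1[a_tm1]
  return 0.5 * td_error ** 2


q_tm1 = jnp.array([0.5, 1.0, -0.2])  # Q-values in the previous state.
q_t = jnp.array([0.0, 2.0, 1.0])     # Q-values in the next state.

# target = 1.0 + 0.9 * 2.0 = 2.8, td_error = 2.8 - 1.0 = 1.8.
loss = td_loss(q_tm1, 1, 1.0, 0.9, q_t)

# Only the Q-value of the action taken receives a gradient (-td_error).
grad_q_tm1 = jax.grad(td_loss, argnums=0)(q_tm1, 1, 1.0, 0.9, q_t)

# stop_gradient blocks all gradient flow through the bootstrap target.
grad_q_t = jax.grad(td_loss, argnums=4)(q_tm1, 1, 1.0, 0.9, q_t)
```

Here the loss is `0.5 * 1.8 ** 2 = 1.62`, the gradient with respect to `q_tm1` is non-zero only at the action taken, and the gradient with respect to `q_t` is zero everywhere.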


### Run Loop
Now we are ready to write the training loop. We're going to use the same accumulator from the previous section. Before starting the training, we need to initialise the agent's and optimiser's parameters.

Like the tabular Q-learning agent, this agent first acts in the environment and then stores the resulting transition in the accumulator. Then, in `learner_step`, it performs one step of learning, i.e. a single parameter update:


In [None]:
batch_size = 1
train_episodes = 500
evaluate_every = 50
eval_episodes = 10
seed = 1221

rng = hk.PRNGSequence(jax.random.PRNGKey(seed))

# Initialise the environment.
env = Catch()
timestep = env.reset()

# Build and initialise the agent.
agent = QlearningAgent(env.action_spec(),
                       env.observation_spec())
params = agent.initial_params(next(rng))
learner_state = agent.initial_learner_state(params)

# Initialise the accumulator.
accumulator = TransitionAccumulator()

# Run loop.
avg_returns = []
losses = []

for episode in range(train_episodes):

  # Prepare agent, environment and accumulator for a new episode.
  timestep = env.reset()
  accumulator.push(timestep, None)

  while not timestep.last():
    # Acting.
    action = agent.actor_step(params, timestep, next(rng), False)
    # Agent-environment interaction.
    timestep = env.step(action)
    # Accumulate experience.
    accumulator.push(timestep, action)

    # Learning.
    if accumulator.is_ready(batch_size):
      params, learner_state, loss = agent.learner_step(
          params, accumulator.sample(batch_size), learner_state, next(rng))
      losses.append(np.asarray(loss))

  # Evaluation.
  if not episode % evaluate_every:
    returns = []
    for _ in range(eval_episodes):
      timestep = env.reset()
      timesteps = [timestep]
      while not timestep.last():
        action = agent.actor_step(params, timestep, next(rng), True)
        timestep = env.step(action)
        timesteps.append(timestep)
      returns.append(np.sum([item.reward for item in timesteps[1:]]))

    avg_returns.append(np.mean(returns))
    print(f"Episode {episode:4d}: Average returns: {avg_returns[-1]:.2f}")

Let's plot the loss and look at how it changes during training. But since we have a lot of data points, it's better to plot a moving average of the loss:

In [None]:
def moving_average(x, w):
  return np.convolve(x, np.ones(w), 'valid') / w
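
As a quick check of what this function returns, here is a small sketch on a hand-made sequence (the input values are arbitrary and the function is repeated only to keep the snippet self-contained). Note that with the `'valid'` mode the output has `len(x) - w + 1` entries, so the smoothed curve is slightly shorter than the raw one.

```python
import numpy as np


def moving_average(x, w):
  return np.convolve(x, np.ones(w), 'valid') / w


# Each output entry is the mean of a window of 2 consecutive inputs.
smoothed = moving_average([1.0, 2.0, 3.0, 4.0, 5.0], 2)
print(smoothed)  # [1.5 2.5 3.5 4.5]
```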

In [None]:
plt.plot(moving_average(losses, 50))

We can plot average returns during evaluations too:

In [None]:
plt.plot(avg_returns)

#### **[Coding Task]**
At the moment, our Q-learning agent's learning only works with `batch_size=1`, which is very inefficient (can you think of why this is?). Try to update the agent and the accumulator so that the training can be done with a batch size larger than 1.

## DQN Agent

The agent we implemented above, while very successful on some tasks like [TD-Gammon](https://en.wikipedia.org/wiki/TD-Gammon), suffers from divergence issues and is hard to train for more complicated tasks.

The [Deep Q-Network (DQN) agent](https://deepmind.com/research/publications/2019/playing-atari-deep-reinforcement-learning) improves on the Q-learning agent by incorporating two main ideas:

*   `Replay buffer`: "To alleviate the problems of correlated data and non-stationary distributions." [1]
*   `Target network`: "Use of an iterative update that adjusts the action-values (Q) towards target values that are only periodically updated, thereby reducing correlations with the target" [1]

First, let's make the replay buffer. We can modify our `TransitionAccumulator` slightly by adding a queue to collect more than one transition:


In [None]:
class ReplayBuffer(object):
  """A simple Python replay buffer."""

  def __init__(self, capacity, discount_factor=0.99):
    self._discount_factor = discount_factor
    self._prev = None
    self._action = None
    self._latest = None
    self.buffer = collections.deque(maxlen=capacity)

  def push(self, env_output, action):
    self._prev = self._latest
    self._action = action
    self._latest = env_output

    if action is not None:
      self.buffer.append(
          (self._prev.observation, self._action, self._latest.reward,
           self._latest.discount, self._latest.observation))

  def sample(self, batch_size):
    obs_tm1, a_tm1, r_t, discount_t, obs_t = zip(
        *random.sample(self.buffer, batch_size))
    return (jnp.stack(obs_tm1), jnp.asarray(a_tm1), jnp.asarray(r_t),
            jnp.asarray(discount_t) * self._discount_factor, jnp.stack(obs_t))

  def is_ready(self, batch_size):
    return batch_size <= len(self.buffer)
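
The two standard-library pieces the buffer relies on can be seen in isolation in the sketch below. The integer "transitions" and the capacity of 3 are placeholders chosen for illustration: a `deque` with `maxlen` silently drops its oldest entries once full, and `random.sample` draws a batch without replacement.

```python
import collections
import random

# A bounded queue: once 3 items are stored, appending evicts the oldest one.
buffer = collections.deque(maxlen=3)
for transition_id in range(5):
  buffer.append(transition_id)

contents = list(buffer)  # Only the 3 most recent items remain: [2, 3, 4].

# Uniformly sample 2 distinct stored items, as ReplayBuffer.sample does.
random.seed(0)
batch = random.sample(buffer, 2)
```

Sampling uniformly from a window of past transitions is what breaks up the correlation between consecutive training examples.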

### The agent

The second ingredient we need is a *target network*.

At each iteration $i$, the DQN loss $L_i$ is computed for the online parameters $\theta_i$, using a set of target parameters $\theta^{\text{target}}_i$ and a batch of transitions sampled from the replay buffer. As described in the DQN paper [1], the loss function is defined as:

$$L_i (\theta_i) = \mathbb{E}_{\color{#ed005a}{s},\color{#0175c2}{a} \sim \rho(\cdot)} \left[ \left( y_i - Q(\color{#ed005a}{s},\color{#0175c2}{a} ;\theta_i) \right)^2\right]$$

where the target $y_i$ is computed using a bootstrap value from the Q-value network evaluated with the target parameters:

$$ y_i = \mathbb{E}_{\color{#ed005a}{s'} \sim \mathcal{E}} \left[ \color{#00ba47}{r} + \gamma \max_{\color{#0175c2}{a'} \in \color{#0175c2}{\mathcal{A}}} Q(\color{#ed005a}{s'}, \color{#0175c2}{a'} ; \theta^{\text{target}}_i) \; | \; \color{#ed005a}{s}, \color{#0175c2}{a} \right] $$
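
The two equations above can be sketched for a small batch as follows. This is an illustration under stated assumptions, not the agent's implementation: the function name `dqn_loss` is hypothetical, and fixed arrays stand in for the outputs of the online network $Q(\cdot\,;\theta_i)$ and the target network $Q(\cdot\,;\theta^{\text{target}}_i)$.

```python
import jax
import jax.numpy as jnp


def dqn_loss(q_online_tm1, a_tm1, r_t, discount_t, q_target_t):
  """Mean squared TD error; Q arrays have shape [batch, num_actions]."""
  # Bootstrap target y: reward plus discounted max over target Q-values.
  y = r_t + discount_t * jnp.max(q_target_t, axis=-1)
  y = jax.lax.stop_gradient(y)
  # Online Q-value of the action actually taken in each transition.
  q_sa = jnp.take_along_axis(q_online_tm1, a_tm1[:, None], axis=-1)[:, 0]
  return jnp.mean((y - q_sa) ** 2)


q_online_tm1 = jnp.array([[1.0, 2.0], [0.5, 0.0]])
q_target_t = jnp.array([[0.0, 1.0], [3.0, 2.0]])
a_tm1 = jnp.array([1, 0])
r_t = jnp.array([1.0, 0.0])
# A discount of 0 marks the second transition as terminal: no bootstrapping.
discount_t = jnp.array([0.9, 0.0])

# y = [1.9, 0.0], q_sa = [2.0, 0.5], squared errors = [0.01, 0.25].
loss = dqn_loss(q_online_tm1, a_tm1, r_t, discount_t, q_target_t)
```

The resulting loss is the mean of the two squared errors, `0.13`.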


In [None]:
Params = collections.namedtuple("Params", "online target")
LearnerState = collections.namedtuple("LearnerState", "count opt_state")

class DQNAgent(object):
  """DQN agent."""

  def __init__(self,
               action_spec: specs.DiscreteArray,
               observation_spec: specs.DiscreteArray,
               num_hiddens: Sequence[int] = [50],
               epsilon: float = 0.01,
               learning_rate: float = 0.005,
               target_period = 10):
    self._observation_spec = observation_spec
    self._num_actions = action_spec.num_values
    self._epsilon = epsilon
    self._target_period = target_period
    self._optimizer = optax.adam(learning_rate)

    def network(obs):
      """Q network of the agent."""
      # Unlike the previous version of the agent, here the observation has a
      # leading batch dimension. Hence, we can use hk.Flatten(), which will
      # flatten an array but leave the batch dimension intact.
      mlp = hk.Sequential(
          [hk.Flatten(),
           hk.nets.MLP(num_hiddens + [self._num_actions])])
      return mlp(obs)

    self._network = hk.without_apply_rng(hk.transform(network))
    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)

  def initial_params(self, rng_key):
    """Initialises the agent params given the RNG key."""
    sample_input = self._observation_spec.generate_value()
    sample_input = jnp.expand_dims(sample_input, 0)
    online_params = self._network.init(rng_key, sample_input)
    return Params(online_params, online_params)

  def initial_learner_state(self, params):
    learner_count = jnp.zeros((), dtype=jnp.int32)
    opt_state = self._optimizer.init(params.online)
    return LearnerState(learner_count, opt_state)

  def actor_step(self, params, timestep, rng_key, evaluation):
    """Given the observation, computes the action using epsilon-greedy algorithm."""
    # The actor step works with batch size 1 but our network expects
    # the inputs to have a batch dimension.
    obs = jnp.expand_dims(timestep.observation, 0)  # Add dummy batch.
    qvalues = self._network.apply(params.online, obs)[0]  # Remove dummy batch.

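    # Draw the exploration decision with JAX's RNG: actor_step is jitted, so a
    # NumPy draw would be evaluated only once, when the function is traced.
    explore_key, action_key = jax.random.split(rng_key)
    explore = jax.random.uniform(explore_key) < self._epsilon
    random_a = jax.random.choice(action_key, self._num_actions)
    train_a = jnp.where(explore, random_a, jnp.argmax(qvalues))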

    # If evaluating, return the greedy action. Otherwise, return the
    # epsilon-greedy action.
    return jax.lax.select(evaluation, jnp.argmax(qvalues), train_a)

  def learner_step(self, params: hk.Params, data, learner_state, rng_key):
    """Computes the loss and its gradient with respect to the parameters and
    does a step of optimisation."""
    # Update the target network parameters periodically.
    is_time = learner_state.count % self._target_period == 0
    target_params = jax.tree_util.tree_map(
        lambda new, old: jax.lax.select(is_time, new, old),
        params.online, params.target)

    dloss_dtheta, loss = jax.grad(self._loss, has_aux=True)(
        params.online, target_params, *data)

    updates, opt_state = self._optimizer.update(
        dloss_dtheta, learner_state.opt_state)
    online_params = optax.apply_updates(params.online, updates)
    return (
        Params(online_params, target_params),
        LearnerState(learner_state.count + 1, opt_state),
        loss)

  def _loss(self, online_params, target_params, obs_tm1, a_tm1, r_t,
            discount_t, obs_t):
    """Computes the TD error loss."""
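    q_tm1 = self._network.apply(online_params, obs_tm1)
    # Double Q-learning variant: the online network selects the next action
    # and the target network evaluates it.
    q_t_val = self._network.apply(target_params, obs_t)
    q_t_select = self._network.apply(online_params, obs_t)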

    def q_learning_loss(q_tm1, a_tm1,  r_t, discount_t, q_t_value,
                        q_t_selector):
      target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
      target_tm1 = jax.lax.stop_gradient(target_tm1)
      return target_tm1 - q_tm1[a_tm1]

    batched_loss = jax.vmap(q_learning_loss)
    td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val, q_t_select)
    loss = jnp.mean(0.5 * td_error ** 2)
    return loss, loss
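
The periodic target-network update performed at the start of `learner_step` can be looked at in isolation. The sketch below is a simplified stand-in: the helper name `maybe_update_target` and the one-entry parameter dictionaries are assumptions for illustration, and it uses `jnp.where` where the agent uses `jax.lax.select`.

```python
import jax
import jax.numpy as jnp


def maybe_update_target(online, target, count, period):
  """Copies online params into target params every `period` steps."""
  is_time = count % period == 0
  return jax.tree_util.tree_map(
      lambda new, old: jnp.where(is_time, new, old), online, target)


online = {"w": jnp.array([1.0, 2.0])}
target = {"w": jnp.array([0.0, 0.0])}

# Step 10 is a multiple of the period, so the target is synchronised.
synced = maybe_update_target(online, target, jnp.asarray(10), 5)

# Step 7 is not, so the old target parameters are kept.
kept = maybe_update_target(online, target, jnp.asarray(7), 5)
```

Selecting between the two parameter sets with an array operation, rather than a Python `if`, keeps the update compatible with `jax.jit`, where the step count is a traced value.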

### Run Loop

The training loop for the DQN agent is identical to the one we used above for the Q-learning agent. We only need to change the accumulator and the agent's class:

In [None]:
batch_size = 10  #@param
discount_factor = 0.99  #@param
replay_buffer_capacity = 100  #@param
train_episodes = 300  #@param
evaluate_every = 25  #@param
eval_episodes = 20  #@param
seed = 1221  #@param

rng = hk.PRNGSequence(jax.random.PRNGKey(seed))

# Initialise the environment.
env = Catch()
timestep = env.reset()

# Build and initialise the agent.
agent = DQNAgent(env.action_spec(),
                 env.observation_spec())
params = agent.initial_params(next(rng))
learner_state = agent.initial_learner_state(params)

# Initialise the accumulator.
accumulator = ReplayBuffer(replay_buffer_capacity, discount_factor)

# Run loop
avg_returns = []
losses = []

for episode in range(train_episodes):
  # Prepare agent, environment and accumulator for a new episode.
  timestep = env.reset()
  accumulator.push(timestep, None)

  while not timestep.last():
    # Acting.
    action = agent.actor_step(params, timestep, next(rng), False)
    # Agent-environment interaction.
    timestep = env.step(action)
    # Accumulate experience.
    accumulator.push(timestep, action)
    # Learning.
    if accumulator.is_ready(batch_size):
      params, learner_state, loss = agent.learner_step(
          params, accumulator.sample(batch_size), learner_state, next(rng))
      losses.append(np.asarray(loss))

  # Evaluation.
  if not episode % evaluate_every:
    returns = []
    for _ in range(eval_episodes):
      timestep = env.reset()
      timesteps = [timestep]
      while not timestep.last():
        action = agent.actor_step(params, timestep, next(rng), True)
        timestep = env.step(action)
        timesteps.append(timestep)
      returns.append(np.sum([item.reward for item in timesteps[1:]]))

    avg_returns.append(np.mean(returns))
    print(f"Episode {episode:4d}: Average returns: {avg_returns[-1]:.2f}.")


In [None]:
plt.plot(moving_average(losses, 50))

In [None]:
animate([item.observation for item in timesteps])

That's looking like a much better ball-catching agent already! :)

### **[Coding Task]**


*   Collect loss and average returns for the whole training run and plot them.
*   Play around with the parameters and observe their effect. A few suggestions:
    *   Number of rows and columns for the game.
    *   Number of hidden units and number of layers.
    *   Learning rate.




## Policy Gradients

In this tutorial, we have looked at **value-based methods**. A popular alternative approach is using **policy gradient methods**.
The name "policy gradient" comes from the fact that these methods estimate the gradient of the expected return with respect to the policy parameters and follow it to improve the policy directly. This contrasts with value-based RL methods such as Q-learning, which use iterative update rules to estimate the expected return associated with a state and action, and derive the policy from those estimates.

In order to learn, we need a loss function or *objective*. In RL, the general objective is to maximise the expected episode return (rewards) by taking actions in the environment. The actions our agent takes are determined by the policy $\pi_\theta(a|s)$, which is in turn determined by the neural network parameters $\theta$. So, we want to find the neural network parameters $\theta$ that maximise

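$$J(\theta) = \mathbb{E}_{\tau}[r(\tau)]$$

where $\tau$ denotes a trajectory (the sequence of states, actions and rewards in an episode) generated by following $\pi_\theta$, and $r(\tau)$ is the total reward collected along it.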
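
Since this expectation usually cannot be computed exactly, one common way to approximate it is a Monte Carlo average over sampled episodes, much like the average returns computed in the evaluation loops above. As a sketch, assuming we can sample $N$ episodes $\tau_1, \dots, \tau_N$ by running the current policy (the symbols $N$ and $\tau_i$ are notation introduced here for illustration):

$$J(\theta) \approx \frac{1}{N} \sum_{i=1}^{N} r(\tau_i), \qquad \tau_i \sim \pi_\theta$$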

Thanks for reading through this tutorial! We hope you enjoyed it and found it helpful for learning more about reinforcement learning. :)


# References


1.   Human-level control through deep reinforcement learning, V. Mnih *et al*., *Nature*, 2015. https://www.nature.com/articles/nature14236

The materials in this colab are heavily borrowed from the following sources:

1. [RLax](https://github.com/deepmind/rlax)
2. [EEML2020 RL Tutorial](https://colab.research.google.com/github/eemlcommunity/PracticalSessions2020/blob/master/rl/EEML2020_RL_Tutorial.ipynb#scrollTo=acqPbd8zXH_K)
3. [Indaba 2018](https://colab.sandbox.google.com/github/deep-learning-indaba/indaba-2018/blob/master/Practical_4_Reinforcement_Learning.ipynb#scrollTo=hQldYOWuu9RO)

