# Deep Reinforcement Learning practical

For some problems formulated as Markov decision processes (MDPs), the number of states and/or actions is very large, or even uncountable (continuous case). Such problems cannot be solved by tabular RL methods. However, they often have an underlying structure, or at least some continuity, which makes it possible to use approximation methods to estimate the optimal value function. For more details, please refer to the landmark textbook by Sutton and Barto: *Reinforcement Learning: An Introduction*, Part II.

The **goal** of this practical is to introduce the modern tools used to build current state-of-the-art deep reinforcement learning (deep RL) algorithms. We will start with a very simple example: linear regression. Then, because we only have an hour and a half, you will play with the hyperparameters of the famous [Deep Q-learning algorithm (DQN)](https://arxiv.org/pdf/1312.5602). In this notebook, the DQN agent has to learn how to control a simple simulated [lunar lander](https://gymnasium.farama.org/environments/box2d/lunar_lander/).

Many RL algorithms derive from DQN. This notebook provides a standalone implementation, which makes it easy to implement extensions, variants or even new methods through simple *inheritance* or *composition* (see object-oriented programming for more information). Therefore, if you choose this practical for your **graded project**, you will implement and test an existing algorithm that derives from DQN, and provide some analysis of the results you have obtained.

**Project instructions:**

For your graded project, please focus on **ONE** of the options below. We also recommend that you **ask questions** and **discuss with the supervisor** before you implement your variant. Your implementation will inherit from a single base class: you will only need to identify this class, choose which methods to override, and add hyperparameters if necessary. Please provide a .pdf with a detailed description of your project, together with your code (notebook or .py). Feel free to come up with your own ideas, scientific questions and modeling choices, but please describe them in detail. Provide a comparison with DQN, and explain the theoretical differences between the algorithms and how they affect the results.

- [QR-DQN](https://arxiv.org/pdf/1710.10044)
- [DQN + Prioritized Experience Replay](https://arxiv.org/pdf/1511.05952)



In [None]:
try:
    if "google.colab" in str(get_ipython()):
        print('Running on Google Colab.')
        !pip install swig
        !pip install gymnasium==0.29.1
        !pip install gymnasium[box2d]
    else:
        print("Not running on Google Colab. No install.")
except NameError:  # get_ipython is only defined inside IPython/Jupyter
    print("Not running on Google Colab. No install.")


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple

## 1) Linear regression with JAX

In this part, we use linear regression as a pretext to introduce JAX, an open-source library developed by Google that provides all the elements needed to build deep learning algorithms:

- Numpy-like tensor manipulations
- Parallel computation
- Automatic differentiation
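The three features listed above can each be seen in a couple of lines. This is a minimal illustration, not part of the original practical: the functions ```f``` and ```dot``` are made-up examples, and ```jax.vmap``` is shown as JAX's vectorization transform (multi-device parallelism with ```jax.pmap``` is not covered here).

```python
import jax
import jax.numpy as jnp

# NumPy-like tensor manipulation: jnp mirrors much of the NumPy API
angles = jnp.linspace(0.0, 1.0, 5)
identity = jnp.sin(angles) ** 2 + jnp.cos(angles) ** 2  # equal to 1 up to rounding

# Automatic differentiation: jax.grad(f) is a NEW function computing df/dw
def f(w):
    return jnp.sum(w ** 3)

df = jax.grad(f)
w = jnp.array([1.0, 2.0, 3.0])
print(df(w))  # 3 * w**2, i.e. [3, 12, 27]

# Vectorization: jax.vmap maps a per-example function over a leading batch axis
def dot(a, b):
    return jnp.dot(a, b)

batch_a = jnp.ones((4, 3))
batch_b = jnp.arange(12.0).reshape(4, 3)
print(jax.vmap(dot)(batch_a, batch_b))  # row sums of batch_b: [3, 12, 21, 30]
```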

Just as deep learning was made possible by parallel progress in several disciplines, such as signal processing or physics, JAX can be used in many contexts other than deep learning.

You will also see that JAX, unlike PyTorch or TensorFlow, follows a functional paradigm. This has advantages in terms of computation but also of understanding, since the code stays close to the underlying mathematics.
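In practice, the functional style shows up in two places that often surprise newcomers: arrays cannot be modified in place, and random number generation takes an explicit key (as in the dataset cell below). A minimal illustration, with made-up variable names:

```python
import jax
import jax.numpy as jnp

# Arrays are immutable: .at[...].set(...) returns a NEW array
a = jnp.zeros(3)
b = a.at[0].set(5.0)
print(a)  # [0. 0. 0.]  (unchanged)
print(b)  # [5. 0. 0.]

# Randomness is explicit: the generator state is a key that we split and pass around
demo_key = jax.random.PRNGKey(42)
demo_key, subkey = jax.random.split(demo_key)
s1 = jax.random.normal(subkey, (2,))
s2 = jax.random.normal(subkey, (2,))  # same key, same numbers: sampling is a pure function
```

Reusing a key reproduces the same numbers, which is why the notebook splits keys before each random draw.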

### Dataset generation

In [None]:
def gaussian_mixture(key, weights, mus, sigmas, shape):
    """
    Draw samples from a Gaussian mixture distribution.
    
    weights: array of shape (K,) that sum to 1
    mus: array of shape (K,)
    sigmas: array of shape (K,)
    shape: output shape, e.g. (n,)
    """
    K = len(weights)
    key_cat, key_norm = jax.random.split(key)

    # Step 1: choose mixture component for each sample
    comp_ids = jax.random.categorical(
        key_cat,
        jnp.log(weights),
        shape=shape
    )

    # Step 2: sample Gaussians
    eps = jax.random.normal(key_norm, shape)

    # Gather mu_k and sigma_k for each sample
    mu = mus[comp_ids]
    sigma = sigmas[comp_ids]

    return mu + sigma * eps

def generate_weird_data(key: jnp.ndarray, n=500):
    key1, key2, key3 = jax.random.split(key, 3)

    # Linear component
    x = jnp.linspace(-3, 3, n)

    # # Nonlinear twist + heteroscedastic + fat-tailed noise
    # noise_heavy = jax.random.t(df=10, key=key1)  # Student-t
    # noise = (0.3 + 0.15 * x**2) * noise_heavy

    # Mixture parameters
    weights = jnp.array([0.5, 0.4, 0.1])
    mus     = jnp.array([0, 2.0, -3.0])
    sigmas  = jnp.array([2.0, 1.2, 0.4])*10

    # heteroscedastic mixture noise
    noise = gaussian_mixture(key1, weights, mus, sigmas, (n,))
    noise = noise * (0.4 + 0.1 * x**2)

    y = 2.0 * x + 1.0 + 0.8*jnp.sin(2*x) + noise

    # Add a few big outliers
    idx = jax.random.choice(key2, n, (10,), replace=False)
    y = y.at[idx].set(y[idx] + jax.random.normal(key3, (10,)) * 10)

    return x, y

In [None]:
# In JAX we need to propagate a key for pseudo-random number generation
key = jax.random.PRNGKey(0)

# Then we generate our data according to some weird distribution
x, y = generate_weird_data(key)
print("x:", x.shape, x.dtype)
print("y:", y.shape, y.dtype)

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y, marker="+")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_title("Dataset")
# ax.legend()
fig.tight_layout()

### Linear Least Squares (LLS)

The $(x_i, y_i)$ datapoints seem to follow a linear trend. We therefore propose a linear model to describe our data on average, and look for the parameters $(a, b)$ that minimize the sum of squares $\mathcal{L}(a, b) = \sum_i (a x_i + b - y_i)^2$.

#### i) Analytical solution

This problem has an analytical solution. If $x = (x_1, \ldots, x_n)$ and $y = (y_1, \ldots, y_n)$, then $\mathcal{L}(a, b) = \| a x + b \mathbf{1} - y \|_2^2$, where $\mathbf{1}$ is the all-ones vector of size $n$.

Such a function is convex with respect to $(a, b)$. Therefore, its gradient with respect to $(a, b)$ is $0$ if and only if $(a, b)$ is a global minimizer, and setting the gradient to zero leads to a linear system of two equations.
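Concretely, setting both partial derivatives of $\mathcal{L}$ to zero gives the so-called normal equations, a $2 \times 2$ system whose matrix and right-hand side are the ```A``` and ```B``` arrays built in the next cell:

$$
\frac{\partial \mathcal{L}}{\partial a} = 2 \sum_i x_i \,(a x_i + b - y_i) = 0,
\qquad
\frac{\partial \mathcal{L}}{\partial b} = 2 \sum_i (a x_i + b - y_i) = 0
$$

$$
\iff
\begin{pmatrix} \sum_i x_i^2 & \sum_i x_i \\ \sum_i x_i & n \end{pmatrix}
\begin{pmatrix} a \\ b \end{pmatrix}
=
\begin{pmatrix} \sum_i x_i y_i \\ \sum_i y_i \end{pmatrix}.
$$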

In [None]:
x = np.array(x)
y = np.array(y)
A = np.zeros((2, 2))
A[0, 0] = np.sum(np.square(x))
A[0, 1] = np.sum(x)
A[1, 0] = np.sum(x)
A[1, 1] = x.shape[0]
B = np.zeros((2,))
B[0] = np.dot(x, y)
B[1] = np.sum(y)
params = np.linalg.solve(A, B)  # solves A @ params = B; more stable than inverting A
print("[a, b] =", params)

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y, marker="+")
ax.plot(x, params[0]*x + params[1], label="Analytical", color="red")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_title("Analytical LLS")
# ax.legend()
fig.tight_layout()
# fig.savefig("linear_regression.pdf")
# plt.close(fig)

#### ii) Iterative solving via gradient descent

Some problems do not have an analytical solution but can be solved iteratively. Before using such methods, we can check whether they work in a case where the solution is known. 

To train deep learning models, we specify a loss function that our model should minimize given some data, and perform gradient descent to find optimal parameters. The process is always the same, even for a simple LLS regression.

We define the loss function ```linear_least_square_loss```. **WARNING:** its first argument is ```params```, the variable with respect to which the gradient is computed. The function returns a tuple ```loss_value, metrics```: JAX differentiates only the loss value, while the metrics (auxiliary outputs, hence ```has_aux=True``` in the code) provide information on the learning process. Note that the code uses the mean rather than the sum of squared errors, which has the same minimizer.

The function ```jax.grad``` (see the JAX documentation) computes the gradient using reverse-mode automatic differentiation (backpropagation). The following function ```gradient_step``` wraps this operation and performs one step of gradient descent.
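A small, self-contained check of these mechanisms on three points lying exactly on $y = 2x + 1$ (the names ```loss_with_aux``` and ```demo_*``` are illustrative). It also shows ```jax.value_and_grad```, which returns the loss value together with the gradient:

```python
import jax
import jax.numpy as jnp

def loss_with_aux(params, x, y):
    """Mean squared error of a line a*x + b, plus a dict of metrics."""
    pred = params[0] * x + params[1]
    loss = jnp.mean(jnp.square(pred - y))
    return loss, {"loss": loss}  # (value to differentiate, auxiliary data)

demo_x = jnp.array([0.0, 1.0, 2.0])
demo_y = jnp.array([1.0, 3.0, 5.0])  # exactly y = 2x + 1
demo_params = jnp.zeros((2,))

# has_aux=True: returns (gradient w.r.t. the FIRST argument, auxiliary data)
demo_grad, demo_aux = jax.grad(loss_with_aux, has_aux=True)(demo_params, demo_x, demo_y)
print(demo_grad)  # analytically [-26/3, -6], i.e. about [-8.667, -6.0]

# value_and_grad also returns the loss value: ((value, aux), gradient)
(demo_loss, _), demo_grad2 = jax.value_and_grad(loss_with_aux, has_aux=True)(
    demo_params, demo_x, demo_y
)
print(demo_loss)  # 35/3, about 11.667
```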

In [None]:
def linear_least_square_loss(
    params: jnp.ndarray,
    x: jnp.ndarray,
    y: jnp.ndarray
) -> Tuple[jnp.ndarray, dict]:
    y_model = params[0] * x + params[1]
    loss = jnp.mean(jnp.square(y_model - y))
    return loss, {"loss": loss, "a": params[0], "b": params[1]}

def gradient_step(
    params: jnp.ndarray,
    x: jnp.ndarray,
    y: jnp.ndarray,
    step: float
) -> Tuple[jnp.ndarray, dict]:
    grad, metrics = jax.grad(linear_least_square_loss, has_aux=True)(
        params, x, y
    )
    return params - step * grad, metrics

**Question 1:** Implement a for loop that performs $100$ steps of gradient descent with a step size (learning rate) of 1e-1 (two lines).

In [None]:
learning_rate = 1e-1
nb_steps = 100
params_gd = jnp.zeros((2,))
# BEGIN SOLUTION
raise NotImplementedError
# END SOLUTION
print(metrics)
print("[a, b] =", params_gd)

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y, marker="+")
ax.plot(x, params[0]*x + params[1], label="Analytical", color="red")
ax.plot(x, params_gd[0]*x + params_gd[1], label="Simple GD", color="purple")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_title("Analytical vs Simple gradient descent")
ax.legend()
fig.tight_layout()
# fig.savefig("linear_regression.pdf")
# plt.close(fig)

To avoid reinventing the wheel, several packages are built around JAX, such as ```optax```, which implements many common optimization algorithms and routines. We can therefore rewrite the descent with ```optax```. As the optimizer, we choose [Adam](https://arxiv.org/abs/1412.6980), which does not make much sense in this specific case since Adam was designed for stochastic optimization, but it serves as an illustration. Adam has an internal state (running averages of past gradients and of their element-wise squares) that it uses to adapt the step size of each parameter. Fortunately, we do not need to know the details of this state to perform a descent: we simply store it in a variable ```opt_state```.

In [None]:
import optax
optimizer = optax.adam(learning_rate=learning_rate)

The following cell performs Adam gradient descent with ```optax```.

Algorithm:
- Initialize the optimizer state: ```opt_state = optimizer.init(params_gd_adams)```
- For ```nb_steps```:
    - Compute the gradient of the loss with ```jax.grad```
    - Compute the parameter updates and the new optimizer state: ```updates, opt_state = optimizer.update(grad, opt_state, params_gd_adams)```
    - Use ```optax.apply_updates``` to apply parameters modifications.

**Don't hesitate to take a look at the ```optax``` documentation.**

In [None]:

params_gd_adams = jnp.zeros((2,))
opt_state = optimizer.init(params_gd_adams)
for _ in range(nb_steps):
    grad, metrics = jax.grad(linear_least_square_loss, has_aux=True)(
        params_gd_adams, x, y
    )
    updates, opt_state = optimizer.update(grad, opt_state, params_gd_adams)
    params_gd_adams = optax.apply_updates(params_gd_adams, updates)
print(metrics)
print("[a, b] =", params_gd_adams)

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y, marker="+")
ax.plot(x, params[0]*x + params[1], label="Analytical", color="red")
ax.plot(x, params_gd[0]*x + params_gd[1], label="Simple GD", color="purple")
ax.plot(x, params_gd_adams[0]*x + params_gd_adams[1], label="Adam GD", color="green")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_title("Comparison of the three methods")
ax.legend()
fig.tight_layout()
# fig.savefig("linear_regression.pdf")
# plt.close(fig)

#### Conclusion on LLS

Here the dataset is small and the model has only two parameters, so the LLS problem is easy to solve analytically. With many more datapoints or parameters, building and solving the linear system becomes expensive, and for many problems, such as image classification with neural networks, no analytical solution is available at all. However, computing a loss and its gradient with respect to the model parameters remains easy. In deep learning, we compute the gradient over batches of data to approximate the descent direction. This process is called Stochastic Gradient Descent (SGD), and Adam is one of the most widely used algorithms of the SGD family.

Solving LLS via gradient descent is instructive: the problem is very simple, yet it uses the same tools as complex deep learning algorithms.

### Quantile regression

Instead of finding the line that passes on average through all the data points, we will find the line that splits our data into two parts. One part will contain a certain proportion of the data, and the other will contain the remaining points. This is known as quantile regression.

Below you will find the loss associated with quantile regression. Then, taking inspiration from the LLS case, you will implement the gradient descent algorithm with ```optax```.

In [None]:
def huber(td: jnp.ndarray) -> jnp.ndarray:
    """Huber function."""
    abs_td = jnp.abs(td)
    return jnp.where(abs_td <= 1.0, jnp.square(td), abs_td)

def quantile_loss(
    params: jnp.ndarray,
    x: jnp.ndarray,
    y: jnp.ndarray,
    expectile: float
) -> Tuple[jnp.ndarray, dict]:
    y_model = params[0] * x + params[1]
    difference = y - y_model
    element_wise_loss = huber(difference)
    element_wise_loss *= jax.lax.stop_gradient(
        jnp.abs(expectile - (difference < 0))
    )
    loss = jnp.mean(element_wise_loss)
    return loss, {"loss": loss, "a": params[0], "b": params[1]}

**Question 2:** Implement the gradient descent algorithm using ```optax``` (a few lines).

Take inspiration from LLS and use the same ```optimizer```.

*Don't hesitate to take a look at the ```optax``` documentation.*

**Question 3:** Execute the code with ```expectile``` equal to 0.9. Then play with the value to see what happens on the plot. 

In [None]:
expectile = 0.9
assert expectile > 0 and expectile < 1.0
nb_steps = 1_000
params_qr = jnp.zeros((2,))
# BEGIN SOLUTION
raise NotImplementedError
# END SOLUTION
print(metrics)
print("[a, b] =", params_qr)

**Question 4:** How far is $a$ from the LLS value of $a$? What about $b$? Why? What is the meaning of ```expectile```? Why should it lie between 0 and 1?

In [None]:
fig, ax = plt.subplots()
ax.scatter(x, y, marker="+")
ax.plot(x, params_qr[0]*x + params_qr[1], label=r"QR ($\tau = {}$)".format(expectile), color="red")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")
ax.set_title("Linear Quantile Regression")
ax.legend()
fig.tight_layout()

#### Conclusion on Linear Quantile Regression

Quantile regression is very useful when we are interested in "best cases" or "worst cases". For example, with ```expectile``` equal to $0.99$, roughly 99% of the (training) points lie below the line.
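This property can be checked numerically on the simplest possible model, a constant prediction $q$. The sketch below uses the plain quantile (pinball) loss, whereas ```quantile_loss``` above uses a Huber-smoothed version. Minimizing it over a grid of constants should recover the empirical $\tau$-quantile, with about a fraction $\tau$ of the samples below the minimizer. The names ```pinball_loss``` and ```tau``` are illustrative.

```python
import jax
import jax.numpy as jnp

def pinball_loss(q, samples, tau):
    """Quantile (pinball) loss of a constant prediction q."""
    diff = samples - q
    # tau * diff when diff >= 0, (tau - 1) * diff when diff < 0
    return jnp.mean(jnp.maximum(tau * diff, (tau - 1.0) * diff))

samples = jax.random.normal(jax.random.PRNGKey(0), (10_000,))
tau = 0.9
grid = jnp.linspace(-3.0, 3.0, 601)  # candidate constants, spacing 0.01
losses = jax.vmap(lambda q: pinball_loss(q, samples, tau))(grid)
q_best = grid[jnp.argmin(losses)]

print("argmin of the pinball loss:", q_best)
print("empirical 0.9-quantile:", jnp.quantile(samples, tau))
print("fraction of samples below q_best:", jnp.mean(samples < q_best))
```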

Least squares estimates the (conditional) expectation, which is why it is widely used in RL, where value functions are expected returns. Quantile regression is at the core of quantile-based distributional RL methods (QR-DQN, TQC, Worst-case SAC, etc.). This family of methods captures more information about the distribution of the return, which makes it possible to solve worst-case problems or to tackle overestimation.

In both cases, learning a value function amounts to performing gradient steps on a regression loss in order to reduce a Bellman error, which is why we studied a simple linear regression example in this practical.
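To make the link explicit, here is the regression target used in Q-learning, $r + \gamma (1 - \text{terminated}) \max_{a'} Q(s', a')$, computed on a hand-made batch of three transitions. This is a sketch: the Q-values are fixed numbers standing in for network outputs, and a plain squared error replaces the Huber-type loss used by the DQN class below.

```python
import jax.numpy as jnp

def td_targets(rewards, terminated, next_q_values, gamma=0.99):
    """Bellman targets r + gamma * (1 - terminated) * max_a' Q(s', a')."""
    masks = 1.0 - terminated  # no bootstrapping from terminal states
    return rewards + gamma * masks * jnp.max(next_q_values, axis=-1)

rewards = jnp.array([1.0, 0.0, -1.0])
terminated = jnp.array([0.0, 0.0, 1.0])
next_q_values = jnp.array([[0.5, 2.0], [1.0, -1.0], [3.0, 4.0]])  # Q(s', .) for 2 actions
q_taken = jnp.array([2.5, 1.0, 0.0])  # Q(s, a) for the actions actually taken

targets = td_targets(rewards, terminated, next_q_values)
td_loss = jnp.mean(jnp.square(q_taken - targets))
print(targets)  # approximately [2.98, 0.99, -1.0]
print(td_loss)  # approximately 0.4102
```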

### General conclusion on JAX 

What we have seen:
- LLS and linear quantile regression
- Automatic differentiation of a loss function in JAX
- Gradient descent and its implementation with optax
- The code for gradient updates, and often even for the loss functions, remains almost identical whether the model is linear or not, including neural networks

What we have not covered:
- JAX just-in-time compilation to accelerate computation (jax.jit)
- JAX parallelization (jax.vmap, jax.pmap)
- Neural network implementation (Flax)
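For reference, just-in-time compilation is a one-line change: ```jax.jit``` traces a pure function on its first call and compiles it with XLA, and later calls with the same input shapes and dtypes reuse the compiled version. A minimal sketch on a gradient step like the ones used earlier (```lls_gd_step``` is an illustrative name):

```python
import jax
import jax.numpy as jnp

def lls_gd_step(params, x, y, step):
    """One gradient-descent step on the mean squared error of a line a*x + b."""
    def loss(p):
        return jnp.mean(jnp.square(p[0] * x + p[1] - y))
    return params - step * jax.grad(loss)(params)

lls_gd_step_jit = jax.jit(lls_gd_step)  # compiled on the first call

x_demo = jnp.linspace(-1.0, 1.0, 100)
y_demo = 3.0 * x_demo - 2.0

p_ref = jnp.zeros((2,))
p_jit = jnp.zeros((2,))
for _ in range(200):
    p_ref = lls_gd_step(p_ref, x_demo, y_demo, 0.1)
    p_jit = lls_gd_step_jit(p_jit, x_demo, y_demo, 0.1)

print(p_jit)  # same result as the non-compiled version, close to [3, -2]
```

The DQN class below uses the same mechanism through ```functools.partial(jax.jit, static_argnames=("self",))```.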

## 2) Deep Q-learning

In this part, you will experiment with an implementation of [DQN](https://arxiv.org/abs/1312.5602) and [DoubleDQN](https://arxiv.org/abs/1509.06461). The agent has to solve the [Lunar Lander v2](https://gymnasium.farama.org/environments/box2d/lunar_lander/) environment.

Since a training run takes about 10 minutes, different hyperparameters will be assigned to different groups. We will then compare results across groups to highlight the effect of each hyperparameter.

In [None]:
import os
import numpy as np
import functools
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import flax
import flax.linen as nn
import optax
import gymnasium as gym
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

### Buffer and Sampler (Nothing to do)

The replay buffer (Buffer class) stores the transitions collected during training. At each gradient step, a batch of transitions is sampled from the replay buffer (Sampler class). In the DQN class (a few cells below), the loss, and hence the gradient with respect to the parameters, is averaged over the batch items; this mean gradient gives the descent direction (see [SGD](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)).
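By linearity of differentiation, the gradient of the batch-averaged loss equals the mean of the per-transition gradients, so in practice one simply differentiates the mean loss (as the ```jnp.mean``` in the DQN loss does). A quick check with ```jax.vmap``` on a toy squared-error loss (illustrative names, not part of the notebook's classes):

```python
import jax
import jax.numpy as jnp

def sample_loss(params, x, y):
    """Squared error of a line a*x + b on ONE sample."""
    return jnp.square(params[0] * x + params[1] - y)

def batch_loss(params, xs, ys):
    """Mean of the per-sample losses over a batch."""
    return jnp.mean(jax.vmap(sample_loss, in_axes=(None, 0, 0))(params, xs, ys))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (64,))
ys = jax.random.normal(key_y, (64,))
demo_params = jnp.array([0.5, -1.0])

# One gradient per batch item: shape (64, 2)
per_sample_grads = jax.vmap(jax.grad(sample_loss), in_axes=(None, 0, 0))(demo_params, xs, ys)
mean_of_grads = per_sample_grads.mean(axis=0)
grad_of_mean = jax.grad(batch_loss)(demo_params, xs, ys)
print(mean_of_grads, grad_of_mean)  # equal up to float rounding
```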


The following cell is adapted from [**xpag**](https://github.com/perrin-isir/xpag), an RL library developed by Nicolas Perrin-Gilbert et al.

Copyright (c) 2022-2023, CNRS - Licensed under BSD 3-Clause License.

**Take a look if interested.**

In [None]:
# This file is adapted from xpag
# Original source: https://github.com/perrin-isir/xpag/xpag/buffers/buffer.py
# Commit: fad4d9cf77053cf29b42f322091bdf182012a501
# Copyright (c) 2022-2023, CNRS - Licensed under BSD 3-Clause License.


from abc import ABC, abstractmethod
from enum import Enum

import joblib  # used by DefaultBuffer.save / DefaultBuffer.load

class DataType(Enum):
    NUMPY = "data represented as numpy arrays"
    JAX = "data represented as jax.numpy arrays"


def get_datatype(x: Union[np.ndarray, jnp.ndarray]) -> DataType:
    if isinstance(x, jnp.ndarray):
        return DataType.JAX
    elif isinstance(x, np.ndarray):
        return DataType.NUMPY
    else:
        raise TypeError(f"{type(x)} not handled.")


def datatype_convert(
    x: Union[np.ndarray, jnp.ndarray, list, float],
    datatype: Union[DataType, None] = DataType.NUMPY,
) -> Union[np.ndarray, jnp.ndarray]:
    if datatype is None:
        return x
    elif datatype == DataType.NUMPY:
        if isinstance(x, np.ndarray):
            return x
        else:
            return np.array(x)
    elif datatype == DataType.JAX:
        if isinstance(x, jnp.ndarray):
            return x
        else:
            return jnp.array(x)


class Sampler(ABC):
    def __init__(self, *, seed: Union[int, None] = None):
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        pass

    @abstractmethod
    def sample(
        self,
        buffer,
        batch_size: int,
    ) -> Dict[str, Union[np.ndarray, jnp.ndarray]]:
        """Return a batch of transitions"""
        pass


class DefaultSampler(Sampler):
    def __init__(self, *, seed: Union[int, None] = None):
        super().__init__(seed=seed)

    def sample(
        self,
        buffer: Dict[str, Union[np.ndarray]],
        batch_size: int,
    ) -> Dict[str, Union[np.ndarray]]:
        buffer_size = next(iter(buffer.values())).shape[0]
        idxs = self.rng.choice(
            buffer_size,
            size=batch_size,
            replace=True,
        )
        transitions = {key: buffer[key][idxs] for key in buffer.keys()}
        return transitions


class Buffer(ABC):
    """Base class for buffers"""

    def __init__(
        self,
        buffer_size: int,
        sampler: Optional[Sampler] = None,
    ):
        self.buffer_size = buffer_size
        self.sampler = sampler

    @abstractmethod
    def insert(self, step: Dict[str, Any]):
        """Inserts a transition in the buffer"""
        pass

    @abstractmethod
    def sample(self, batch_size) -> Dict[str, Union[np.ndarray, jnp.ndarray]]:
        """Uses the sampler to returns a batch of transitions"""
        pass


class DefaultBuffer(Buffer):
    def __init__(
        self,
        buffer_size: int,
        sampler: Sampler,
    ):
        super().__init__(buffer_size, sampler)
        self.current_size = 0
        self.buffers = {}
        self.size = buffer_size
        self.dict_sizes = None
        self.num_envs = None
        self.keys = None
        self.zeros = None
        self.where = None
        self.first_insert_done = False

    def init_buffer(self, step: Dict[str, Any]):
        self.dict_sizes = {}
        self.keys = list(step.keys())
        assert "terminated" in self.keys
        for key in self.keys:
            if isinstance(step[key], dict):
                for k in step[key]:
                    assert len(step[key][k].shape) == 2
                    self.dict_sizes[key + "." + k] = step[key][k].shape[1]
            else:
                assert len(step[key].shape) == 2
                self.dict_sizes[key] = step[key].shape[1]
        self.num_envs = step["terminated"].shape[0]
        for key in self.dict_sizes:
            self.buffers[key] = np.zeros([self.size, self.dict_sizes[key]])
        self.zeros = lambda i: np.zeros(i).astype("int")
        self.where = np.where
        self.first_insert_done = True

    def insert(self, step: Dict[str, Any]):
        if not self.first_insert_done:
            self.init_buffer(step)
        idxs = self._get_storage_idx(inc=self.num_envs)
        for key in self.keys:
            if isinstance(step[key], dict):
                for k in step[key]:
                    self.buffers[key + "." + k][idxs, :] = datatype_convert(
                        step[key][k], DataType.NUMPY
                    ).reshape((self.num_envs, self.dict_sizes[key + "." + k]))
            else:
                self.buffers[key][idxs, :] = datatype_convert(
                    step[key], DataType.NUMPY
                ).reshape((self.num_envs, self.dict_sizes[key]))

    def pre_sample(self):
        temp_buffers = {}
        for key in self.buffers.keys():
            temp_buffers[key] = self.buffers[key][: self.current_size]
        return temp_buffers

    def sample(self, batch_size):
        return self.sampler.sample(self.pre_sample(), batch_size)

    def _get_storage_idx(self, inc=None):
        inc = inc or 1
        if self.current_size + inc <= self.size:
            idx = np.arange(self.current_size, self.current_size + inc)
        elif self.current_size < self.size:
            overflow = inc - (self.size - self.current_size)
            idx_a = np.arange(self.current_size, self.size)
            idx_b = np.random.randint(0, self.current_size, overflow)
            idx = np.concatenate([idx_a, idx_b])
        else:
            idx = np.random.randint(0, self.size, inc)
        self.current_size = min(self.size, self.current_size + inc)
        return idx

    def save(self, directory: str):
        os.makedirs(directory, exist_ok=True)
        list_vars = [
            ("current_size", self.current_size),
            ("buffers", self.buffers),
            ("size", self.size),
            ("dict_sizes", self.dict_sizes),
            ("num_envs", self.num_envs),
            ("keys", self.keys),
            ("first_insert_done", self.first_insert_done),
        ]
        for cpl in list_vars:
            with open(os.path.join(directory, cpl[0] + ".joblib"), "wb") as f_:
                joblib.dump(cpl[1], f_)

    def load(self, directory: str):
        self.current_size = joblib.load(os.path.join(directory, "current_size.joblib"))
        self.buffers = joblib.load(os.path.join(directory, "buffers.joblib"))
        self.size = joblib.load(os.path.join(directory, "size.joblib"))
        self.dict_sizes = joblib.load(os.path.join(directory, "dict_sizes.joblib"))
        self.num_envs = joblib.load(os.path.join(directory, "num_envs.joblib"))
        self.keys = joblib.load(os.path.join(directory, "keys.joblib"))
        self.first_insert_done = joblib.load(
            os.path.join(directory, "first_insert_done.joblib")
        )
        self.zeros = lambda i: np.zeros(i).astype("int")
        self.where = np.where

### From Multi-Layer Perceptron (MLP) to Deep Q-network (Nothing to do)

Since the state of our environment is represented as a vector rather than as an image or another high-dimensional input, the deep Q-network reduces to a standard multi-layer perceptron (MLP). It consists of several hidden layers and produces one output per possible action.

**Take a look if interested.**

In [None]:
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    init_last_layer_zeros: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, size in enumerate(self.hidden_dims):
            if i + 1 == len(self.hidden_dims) and self.init_last_layer_zeros:
                x = nn.Dense(
                    size,
                    kernel_init=nn.initializers.zeros_init(),
                    bias_init=nn.initializers.zeros_init(),
                )(x)
            else:
                x = nn.Dense(size, kernel_init=default_init())(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
        return x


class DeepQNetwork(nn.Module):
    action_dim: int
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    init_last_layer_zeros: bool = False

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        action_values = MLP(
            hidden_dims=(*self.hidden_dims, self.action_dim),
            activations=self.activations,
            activate_final=self.activate_final,
            init_last_layer_zeros=self.init_last_layer_zeros
        )(observations)
        return action_values
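The same idea can be written without Flax: a Q-network is an MLP whose last layer has one linear output per action, and the greedy action is the ```argmax``` over that axis. The sketch below is a hypothetical pure-JAX equivalent for illustration only. It uses He-style Gaussian initialization instead of the orthogonal initialization of ```default_init```, and assumes LunarLander's 8-dimensional observations and 4 discrete actions.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Hypothetical helper: one (W, b) pair per layer, e.g. sizes = [8, 128, 128, 4]."""
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def q_values(params, observations):
    """Forward pass: ReLU hidden layers, linear output layer (one value per action)."""
    h = observations
    for w, b in params[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = params[-1]
    return h @ w + b  # no final activation: Q-values are unbounded

obs_dim, n_actions = 8, 4
q_params = init_mlp(jax.random.PRNGKey(0), [obs_dim, 128, 128, n_actions])
observations = jnp.ones((32, obs_dim))   # a dummy batch of 32 observations
q = q_values(q_params, observations)     # shape (32, 4)
greedy_actions = jnp.argmax(q, axis=-1)  # shape (32,), integers in [0, 4)
```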

### DQN and DoubleDQN (Nothing to do)

The DQN class represents a DQN agent. It can initialize the Q-network parameters, act based on a received observation, and learn from batches of transitions.
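During training, the agent needs to keep exploring, which is what ```epsilon``` is for. The action rule below is an assumption (the standard epsilon-greedy rule), while the schedule reproduces the arithmetic of ```DQN.update_epsilon_exponential```: starting from 1.0 and multiplying by ```eps_dec = eps_end ** (1 / eps_end_at)``` once per call gives ```epsilon = eps_end ** (n / eps_end_at)``` after ```n``` calls, which reaches ```eps_end``` after exactly ```eps_end_at``` calls. Both helper names are hypothetical.

```python
import numpy as np

def epsilon_greedy_action(q_values, epsilon, rng):
    """Hypothetical helper: random action with probability epsilon, greedy otherwise."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

def epsilon_after(n_updates, eps_end=0.01, eps_end_at=4e5):
    """Closed form (up to float rounding) of n_updates calls to the exponential schedule."""
    eps_dec = eps_end ** (1.0 / eps_end_at)
    return max(eps_end, eps_dec ** n_updates)

rng = np.random.default_rng(0)
print(epsilon_greedy_action(np.array([0.1, 0.5, 0.2]), epsilon=0.0, rng=rng))  # 1 (greedy)
print(epsilon_after(200_000))    # about 0.1 = sqrt(eps_end), halfway through the schedule
print(epsilon_after(1_000_000))  # 0.01: clipped at eps_end
```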
 
The only difference between DoubleDQN and DQN is the loss, which is why the DoubleDQN class inherits from the DQN class.
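Concretely, the two targets differ in which network picks the next action. The sketch below follows the Double DQN paper linked above and is not necessarily identical to the notebook's ```DoubleDQN``` loss. DQN lets the target network both select and evaluate the next action, whereas Double DQN selects it with the online network and evaluates it with the target network, which reduces the overestimation caused by the max.

```python
import jax.numpy as jnp

def dqn_target(rewards, masks, next_q_target, gamma=0.99):
    """Target network both selects and evaluates the next action."""
    return rewards + gamma * masks * jnp.max(next_q_target, axis=-1)

def double_dqn_target(rewards, masks, next_q_online, next_q_target, gamma=0.99):
    """Online network selects the next action, target network evaluates it."""
    next_actions = jnp.argmax(next_q_online, axis=-1)
    next_values = jnp.take_along_axis(
        next_q_target, next_actions[:, None], axis=-1
    ).squeeze(axis=-1)
    return rewards + gamma * masks * next_values

rewards = jnp.array([0.0, 1.0])
masks = jnp.array([1.0, 1.0])  # 1 - terminated
next_q_online = jnp.array([[1.0, 3.0], [2.0, 0.0]])
next_q_target = jnp.array([[2.0, 1.0], [0.5, 4.0]])

print(dqn_target(rewards, masks, next_q_target))                        # approx. [1.98, 4.96]
print(double_dqn_target(rewards, masks, next_q_online, next_q_target))  # approx. [0.99, 1.495]
```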

**If interested:** read the cell from the bottom to the top.

In [None]:
Params = flax.core.FrozenDict[str, Any]
Metrics = Dict[str, Any]

def random_key_split(rng: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    return jax.random.split(rng)

def init(
    model_def: nn.Module,
    inputs: Sequence[jnp.ndarray],
    tx: Optional[optax.GradientTransformation] = None,
) -> Tuple[Params, optax.OptState]:
    """Generic initialization function"""
    # Initialize params
    variables = model_def.init(*inputs)
    params = variables.pop("params")

    # Initialize optimizer state
    if tx is not None:
        opt_state = tx.init(params)
    else:
        opt_state = None

    return params, opt_state

def grad_norm(grad):
    flattened_grads, _ =  ravel_pytree(grad)
    return jax.numpy.linalg.norm(flattened_grads)

def get_shapes(params: Params):
    return jax.tree_util.tree_map(lambda p: p.shape, params)

def apply_updates(
    optimizer: optax.GradientTransformation,
    grad: Params, 
    opt_state: optax.OptState, 
    params: Params,
) -> Tuple[Params, optax.OptState, Params]:
    updates, new_opt_state = optimizer.update(grad, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, updates

def ema_target_update(
    params: Params, target_params: Params, tau: float
) -> Params:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), params, target_params
    )

    return new_target_params

def huber(td: jnp.ndarray) -> jnp.ndarray:
    """Huber function."""
    abs_td = jnp.abs(td)
    return jnp.where(abs_td <= 1.0, jnp.square(td), abs_td)


class DQN:
    """Class implementing a DQN algorithm."""
    def __init__(
        self,
        observation_dim: int,
        action_dim: int,
        gamma: float = 0.99,
        learning_rate: float = 3e-4,
        update_target_every_x_steps: int = 10_000,
        network_features: dict = dict(hidden_dims=(128, 128)),
        eps_greedy_hp: dict = dict(
            eps_decrease="exponential",
            eps_end=0.01,
            eps_end_at=4e5
        )
    ):
        # Observation and action spaces
        self.observation_dim = observation_dim
        self.action_dim = action_dim

        # Discount factor
        self.gamma = gamma

        # Epsilon-greedy parameters
        self.eps_greedy_hp = eps_greedy_hp.copy()
        if self.eps_greedy_hp["eps_decrease"] == "exponential":
            self.eps_dec = self.eps_greedy_hp["eps_end"]**(
                1 / self.eps_greedy_hp["eps_end_at"]
            )
            print("DQN.eps_dec (exponential):", self.eps_dec)
            self.update_epsilon = self.update_epsilon_exponential
        elif self.eps_greedy_hp["eps_decrease"] == "linear":
            self.eps_dec = (
                self.eps_greedy_hp["eps_end"] - 1.0
            ) / self.eps_greedy_hp["eps_end_at"]
            print("DQN.eps_dec (linear):", self.eps_dec)
            self.update_epsilon = self.update_epsilon_linear
        else:
            raise NotImplementedError

        # Optimization
        self.optimizer = optax.adam(learning_rate=learning_rate)
        self.update_target_every_x_steps = update_target_every_x_steps

        # Neural network
        self.network = self.make_network(network_features)
    
    def make_network(self, network_features: dict) -> DeepQNetwork:
        return DeepQNetwork(
            action_dim=self.action_dim,
            **network_features
        )
    
    def init(self, rng: jnp.ndarray) -> Tuple[Params, optax.OptState, Params, float]:
        init_obs = jnp.ones((1, 1, self.observation_dim))

        params, opt_state = init(
            model_def=self.network,
            inputs=[rng, init_obs],
            tx=self.optimizer
        )
        print(get_shapes(params))

        target_params, _ = init(
            model_def=self.network,
            inputs=[rng, init_obs]
        )
        print(get_shapes(target_params))

        epsilon = 1.0
        print("epsilon:", epsilon)

        return params, opt_state, target_params, epsilon
    
    @functools.partial(
        jax.jit,
        static_argnames=(
            "self",
        )
    )
    def select_action(
        self,
        params: Params,
        observation: jnp.ndarray
    ) -> jnp.ndarray:
        print("COMPILE: DQN.select_action")
        actions_values = self.network.apply({"params": params}, observation)
        return jnp.argmax(actions_values, axis=-1)
    
    def select_single_action(
        self,
        params: Params,
        observation: np.ndarray
    ) -> jnp.ndarray:
        return int(
            jnp.squeeze(
                self.select_action(
                    params=params,
                    observation=jnp.expand_dims(jnp.array(observation), axis=0)
                )
            )
        )
    
    def dqn_loss(
        self,
        params: Params,
        target_params: Params,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        next_observations: jnp.ndarray,
        rewards: jnp.ndarray,
        masks: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Metrics]:
        # Compute current value Q(s, a) of the actions taken in the batch
        actions_values = self.network.apply({"params": params}, observations)
        actions = jnp.array(actions, dtype=jnp.int32)
        current_action_values = jnp.take_along_axis(
            actions_values, actions, axis=-1
        ).squeeze(axis=-1)
        # Compute target values: the target network both selects and evaluates the next action (max)
        next_actions_values = self.network.apply({"params": target_params}, next_observations)
        next_values = jnp.max(next_actions_values, axis=-1)
        rewards = jnp.squeeze(rewards, axis=-1)
        masks = jnp.squeeze(masks, axis=-1)
        # masks = 1 - terminated: no bootstrapping from terminal states
        target_values = rewards + self.gamma * masks * next_values
        # Debug shape prints (under jax.jit they only run when the function is traced)
        print(rewards.shape, masks.shape, next_values.shape)
        print(current_action_values.shape, target_values.shape)
        # Compute loss
        batch_loss = huber(current_action_values - target_values)
        loss = jnp.mean(batch_loss)
        return loss, {
            "loss": loss,
            "q_mean": jnp.mean(current_action_values),
            "next_q_mean": jnp.mean(next_values)
        }
    
    @functools.partial(
        jax.jit,
        static_argnames=(
            "self",
        )
    )
    def gradient_step(
        self,
        params: Params,
        opt_state: optax.OptState,
        target_params: Params,
        batch: Dict[str, np.ndarray]
    ) -> tuple:
        print("COMPILE: DQN.gradient_step")
        grad, metrics = jax.grad(self.dqn_loss, has_aux=True)(
            params,
            target_params=target_params,
            observations=batch["observation"],
            actions=batch["action"],
            next_observations=batch["next_observation"],
            rewards=batch["reward"],
            masks=1-batch["terminated"]
        )
        params, opt_state, updates = apply_updates(
            optimizer=self.optimizer,
            grad=grad,
            opt_state=opt_state,
            params=params
        )
        return params, opt_state, updates, grad, metrics
    
    def update_epsilon_exponential(self, epsilon: float) -> float:
        return max(self.eps_greedy_hp["eps_end"], self.eps_dec * epsilon)
    
    def update_epsilon_linear(self, epsilon: float) -> float:
        return max(self.eps_greedy_hp["eps_end"], self.eps_dec + epsilon)
    
    def update_target_params(
        self,
        params: Params,
        target_params: Params,
        step: int
    ) -> Params:
        if step % self.update_target_every_x_steps == 0:
            return params
        else:
            return target_params
    
    def update(
        self,
        params: Params,
        opt_state: optax.OptState,
        target_params: Params,
        epsilon: float,
        batch: Dict[str, np.ndarray],
        step: int
    ) -> tuple:
        # Perform gradient descent step
        params, opt_state, updates, grad, metrics = self.gradient_step(
            params=params,
            opt_state=opt_state,
            target_params=target_params,
            batch=batch
        )

        # Update epsilon
        epsilon = self.update_epsilon(epsilon)

        # Update target params
        target_params = self.update_target_params(
            params=params,
            target_params=target_params,
            step=step
        )

        return params, opt_state, target_params, epsilon, metrics


class DoubleDQN(DQN):
    def dqn_loss(
        self,
        params: Params,
        target_params: Params,
        observations: jnp.ndarray,
        actions: jnp.ndarray,
        next_observations: jnp.ndarray,
        rewards: jnp.ndarray,
        masks: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Metrics]:
        # Compute current value
        actions_values = self.network.apply({"params": params}, observations)
        actions = jnp.array(actions, dtype=jnp.int32)
        current_action_values = jnp.take_along_axis(
            actions_values, actions, axis=-1
        ).squeeze(axis=-1)
        # Compute next actions according to current q function
        next_actions_values = jax.lax.stop_gradient(
            self.network.apply({"params": params}, next_observations)
        )
        next_actions = jnp.argmax(next_actions_values, axis=-1, keepdims=True)
        # Compute target values according to https://arxiv.org/pdf/1509.06461
        target_next_actions_values = self.network.apply(
            {"params": target_params}, next_observations
        )
        next_values = jnp.take_along_axis(
            target_next_actions_values, next_actions, axis=-1
        ).squeeze(axis=-1)
        rewards = jnp.squeeze(rewards, axis=-1)
        masks = jnp.squeeze(masks, axis=-1)
        target_values = rewards + self.gamma * masks * next_values
        print(rewards.shape, masks.shape, next_values.shape)
        print(current_action_values.shape, target_values.shape)
        # Compute loss
        batch_loss = huber(current_action_values - target_values)
        loss = jnp.mean(batch_loss)
        return loss, {
            "loss": loss,
            "q_mean": jnp.mean(current_action_values),
            "next_q_mean": jnp.mean(next_values)
        }
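
To make the difference between the two losses concrete, here is a small sketch that computes both targets on a toy batch. It uses NumPy so it runs on its own (the `jax.numpy` functions used in `dqn_loss` have NumPy counterparts with the same names), and the helpers `gather_q`, `dqn_target` and `double_dqn_target` are hypothetical, not part of the classes above. Shapes follow `dqn_loss`: Q-values are `(batch, action_dim)`, actions are `(batch, 1)` integer indices, and rewards and masks are already squeezed to `(batch,)`.

```python
import numpy as np


def gather_q(q_values, actions):
    # q_values: (batch, action_dim), actions: (batch, 1) integer indices -> (batch,)
    return np.take_along_axis(q_values, actions, axis=-1).squeeze(axis=-1)


def dqn_target(rewards, masks, next_q_target, gamma):
    # DQN: the target network both selects and evaluates the next action (max)
    return rewards + gamma * masks * np.max(next_q_target, axis=-1)


def double_dqn_target(rewards, masks, next_q_online, next_q_target, gamma):
    # Double DQN: the online network selects the next action...
    next_actions = np.expand_dims(np.argmax(next_q_online, axis=-1), axis=-1)
    # ...and the target network evaluates it
    return rewards + gamma * masks * gather_q(next_q_target, next_actions)


# Toy batch: 2 transitions, 3 actions (made-up numbers)
toy_q_online = np.array([[1.0, 2.0, 0.5], [0.0, 1.0, 3.0]])
toy_actions = np.array([[1], [0]], dtype=np.int32)
toy_next_q_online = np.array([[0.0, 4.0, 1.0], [2.0, 0.0, 1.0]])
toy_next_q_target = np.array([[3.0, 1.0, 2.0], [1.0, 5.0, 0.0]])
toy_rewards = np.array([1.0, -1.0])
toy_masks = np.array([1.0, 0.0])  # 1 - terminated: the second transition is terminal
toy_gamma = 0.9

print(gather_q(toy_q_online, toy_actions))  # Q(s, a): [2. 0.]
# Row 0: DQN bootstraps on max(3, 1, 2) = 3, so 1 + 0.9 * 3 = 3.7
#        Double DQN selects argmax(0, 4, 1) = 1, target value 1, so 1 + 0.9 * 1 = 1.9
# Row 1: terminal, both targets reduce to the reward -1
print(dqn_target(toy_rewards, toy_masks, toy_next_q_target, toy_gamma))
print(double_dqn_target(toy_rewards, toy_masks, toy_next_q_online, toy_next_q_target, toy_gamma))
```

In this toy batch the target network's max picks an action that the online network does not prefer, so the DQN target (3.7) is larger than the Double DQN target (1.9). Decoupling action selection (online network) from action evaluation (target network) is how Double DQN aims to reduce the overestimation bias of the max operator.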


### Logging and Evaluation (Nothing to do)

A simple pandas logger: the logged data can then be plotted and/or saved to a CSV file.

In [None]:
class PandasLogger:
    def __init__(self):
        self.df = pd.DataFrame()
    
    @property
    def data(self):
        return self.df
    
    def __str__(self):
        return self.df.__str__()
    
    def __repr__(self):
        return self.df.__repr__()
    
    def log(self, data: dict):
        if not isinstance(data, dict):
            raise TypeError("log() requires a dict")

        if self.df.empty and len(self.df.columns) == 0:
            self.df = pd.DataFrame(columns=data.keys())
        
        self.df = pd.concat([self.df, pd.DataFrame([data])], ignore_index=True)
    
    def save(self, filepath: str):
        """
        Save internal dataframe to a CSV file.
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        self.df.to_csv(filepath, index=False)
    
    def load(self, filepath: str):
        """
        Load data from a CSV file into the internal dataframe.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"No file found at {filepath}")

        self.df = pd.read_csv(filepath)

def eval_episode(
    eval_env: gym.Env,
    agent: DQN,
    params: Params
) -> float:
    """Run one episode with the greedy policy (no epsilon exploration) and return the undiscounted sum of rewards."""
    obs, info = eval_env.reset()
    done = False
    sum_of_rewards = 0.0
    while not done:
        action = agent.select_single_action(params, obs)
        next_obs, reward, term, trunc, info = eval_env.step(action)
        sum_of_rewards += reward

        done = term or trunc
        obs = next_obs
    return sum_of_rewards

def eval(
    eval_env: gym.Env,
    agent: DQN,
    params: Params,
    nb_episodes: int
) -> Metrics:
    returns = np.zeros((nb_episodes,))
    for i in range(nb_episodes):
        returns[i] = eval_episode(
            eval_env=eval_env, agent=agent, params=params
        )
    return {
        "return_mean": np.mean(returns),
        "return_median": np.median(returns),
        "return_25": np.quantile(returns, 0.25),
        "return_75": np.quantile(returns, 0.75)
    }

def plot_from_dataframe(
    fig, ax,
    df: pd.DataFrame,
    x_key: str, y_key: str,
    fill_between: Optional[Tuple[str, str]] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None
):
    x = pd.to_numeric(df[x_key], errors="coerce").to_numpy()
    y = pd.to_numeric(df[y_key], errors="coerce").to_numpy()
    
    if fill_between is not None:
        ymin = pd.to_numeric(df[fill_between[0]], errors="coerce").to_numpy()
        ymax = pd.to_numeric(df[fill_between[1]], errors="coerce").to_numpy()
        ax.fill_between(x, ymin, ymax, alpha=0.2)
    
    ax.plot(x, y)
    ax.grid()

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    return fig, ax
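
`PandasLogger.log` rebuilds the whole DataFrame with `pd.concat` at every call, so the total cost grows quadratically with the number of logged rows. That is negligible for one row every `eval_every_x_steps`, but it can matter if you log at every step. A hedged alternative sketch for that case: keep rows as a list of dicts and build the DataFrame only when it is read. `ListLogger` and the logged values below are illustrative, not part of the notebook.

```python
import pandas as pd


class ListLogger:
    """Hypothetical logger: rows are kept in a list, the DataFrame is built on demand."""

    def __init__(self):
        self.rows = []

    def log(self, data: dict):
        if not isinstance(data, dict):
            raise TypeError("log() requires a dict")
        # Shallow copy so that later changes to `data` do not alter the log
        self.rows.append(dict(data))

    @property
    def data(self) -> pd.DataFrame:
        # Keys missing from some rows (e.g. training metrics before the first update) become NaN
        return pd.DataFrame(self.rows)


# Illustrative values only
demo_logger = ListLogger()
demo_logger.log({"step": 0, "return_median": -150.0})
demo_logger.log({"step": 50_000, "return_median": -80.0, "loss": 1.2})
print(demo_logger.data)
```

Appending to a Python list is cheap, and the single `pd.DataFrame(...)` call infers column dtypes directly from the logged values.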

### Training

**Question 5:** Identify the hyperparameters of DQN. What do they mean? How do you think they influence the training?

**Question 6:** Call the teacher and launch the training when ready.

**Question 7:** What is the difference between DQN and DoubleDQN? Of course, the loss... but what exactly is different in the loss?

**Question 8:** How is data sampled from the replay buffer? How is new data inserted?
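
For Question 8, it can help to compare `DefaultBuffer` and `DefaultSampler` (whose code is not shown in this section) with one common design. The sketch below is a hypothetical FIFO ring buffer with uniform sampling, not necessarily how those classes work: it stores transitions in the same dict-of-arrays format as the training loop (each value has a leading axis of length 1), overwrites the oldest transition once full, and samples indices uniformly with replacement. `RingReplayBuffer` and the demo data are illustrative names and values.

```python
import numpy as np


class RingReplayBuffer:
    """Hypothetical FIFO replay buffer with uniform sampling (not DefaultBuffer)."""

    def __init__(self, buffer_size: int, seed: int = 0):
        self.buffer_size = buffer_size
        self.storage = None  # dict of preallocated arrays, created at the first insert
        self.position = 0    # index of the next slot to write
        self.size = 0        # number of valid transitions currently stored
        self.rng = np.random.default_rng(seed)

    def insert(self, transition: dict):
        # Each value has a leading axis of length 1, as in the training loop
        if self.storage is None:
            self.storage = {
                key: np.zeros((self.buffer_size,) + value.shape[1:], dtype=value.dtype)
                for key, value in transition.items()
            }
        for key, value in transition.items():
            self.storage[key][self.position] = value[0]
        # Once the buffer is full, the oldest transition is overwritten first
        self.position = (self.position + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample(self, batch_size: int) -> dict:
        # Uniform sampling with replacement among the stored transitions
        indices = self.rng.integers(0, self.size, size=batch_size)
        return {key: array[indices] for key, array in self.storage.items()}


demo_buffer = RingReplayBuffer(buffer_size=3)
for t in range(5):  # 5 inserts into 3 slots: transitions 0 and 1 get overwritten
    demo_buffer.insert({
        "observation": np.full((1, 2), t, dtype=np.float32),
        "action": np.array([[t]], dtype=np.int64),
    })
demo_batch = demo_buffer.sample(4)
print(demo_batch["observation"].shape, demo_batch["action"].shape)  # (4, 2) (4, 1)
```

Sampled arrays keep the trailing dimensions of the stored values, so `"action"` comes back with shape `(batch_size, 1)`, which is the shape `jnp.take_along_axis` expects in `dqn_loss`.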

In [None]:
# Environment
env_kwargs = {
    "LunarLander-v2": dict(
        max_episode_steps=1_000,
        continuous=False,
        gravity=-9.81,
        enable_wind=False, 
        wind_power=15.0,
        turbulence_power=1.5
    )
}
env_name = "LunarLander-v2"
env = gym.make(
    env_name,
    render_mode=None,
    **env_kwargs[env_name]
)
num_eval_episodes = 20
if num_eval_episodes == 1:
    eval_render_mode = "human"
else:
    eval_render_mode = None
eval_env = gym.make(
    env_name,
    render_mode=eval_render_mode,
    **env_kwargs[env_name]
)
observation_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Random states (NumPy for epsilon-greedy exploration, JAX for network initialization)
seed = 0
np_rand_state = np.random.RandomState(seed)
rng = jax.random.PRNGKey(seed)

# Agent
# agent = DQN(
agent = DoubleDQN(
    observation_dim=observation_dim, action_dim=action_dim,
    gamma=0.999,
    learning_rate=1e-3,
    update_target_every_x_steps=5_000,
    network_features=dict(hidden_dims=(256, 256)),
    eps_greedy_hp=dict(
        eps_decrease="exponential",
        eps_end=0.01,
        eps_end_at=4.5e5
    )
)
rng, sub = jax.random.split(rng)
params, opt_state, target_params, epsilon = agent.init(sub)

# Buffer and sampler
buffer_size = 200_000
buffer = DefaultBuffer(
    buffer_size=buffer_size,
    sampler=DefaultSampler()
)

# Main loop
batch_size = 64
max_steps = 400_100
start_training_after_x_steps = 10_000
eval_every_x_steps = 50_000
update_params_every = 4

logger = PandasLogger()
metrics = None
observation, info = env.reset()
for step in tqdm(range(max_steps)):
    # Evaluation
    if step % eval_every_x_steps == 0:
        eval_metrics = eval(
            eval_env=eval_env,
            agent=agent,
            params=params,
            nb_episodes=num_eval_episodes
        )
        if metrics is not None:
            logger.log({"step": step, "epsilon": epsilon, **eval_metrics, **metrics})
            print(
                "step:", step, ";",
                "epsilon: {:.2f}".format(epsilon), ";",
                "return_median: {:.3f}".format(eval_metrics["return_median"]), ";",
                "return_25: {:.3f}".format(eval_metrics["return_25"]), ";",
                "return_75: {:.3f}".format(eval_metrics["return_75"]), ";",
                "loss:", metrics["loss"],
                "q_mean:", metrics["q_mean"],
                "next_q_mean:", metrics["next_q_mean"]
            )
        else:
            logger.log({"step": step, "epsilon": epsilon, **eval_metrics})
            print(
                "step:", step, ";",
                "epsilon: {:.2f}".format(epsilon), ";",
                "return_median: {:.3f}".format(eval_metrics["return_median"]), ";",
                "return_25: {:.3f}".format(eval_metrics["return_25"]), ";",
                "return_75: {:.3f}".format(eval_metrics["return_75"]), ";"
            )

    # Select action
    if np_rand_state.uniform() < epsilon:
        action = env.action_space.sample()
    else:
        action = agent.select_single_action(params, observation)
    
    # Perform step
    next_observation, reward, terminated, truncated, info = env.step(action)

    # Store transition
    transition = {
        "observation": np.expand_dims(observation, axis=0),
        "action": np.array([[action]], dtype=np.int64),
        "next_observation": np.expand_dims(next_observation, axis=0),
        "reward": np.array([[reward]]),
        "terminated": np.array([[terminated]])
    }
    buffer.insert(transition)

    # Update if necessary
    if step > start_training_after_x_steps:
        if step % update_params_every == 0:
            # Sample a batch only when a gradient step actually uses it
            batch = buffer.sample(batch_size)
            # Perform gradient descent step
            params, opt_state, updates, grad, metrics = agent.gradient_step(
                params=params,
                opt_state=opt_state,
                target_params=target_params,
                batch=batch
            )
        # Update target params
        target_params = agent.update_target_params(
            params=params,
            target_params=target_params,
            step=step
        )
        # Update epsilon
        epsilon = agent.update_epsilon(epsilon)

    # Prepare next iter
    if terminated or truncated:
        observation, info = env.reset()
    else:
        observation = next_observation

# Save learning
logger.save("./crash_test.csv")
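
With the settings above, the loop does not perform one gradient step per environment step. The hypothetical helper `count_updates` below replays the loop's counting logic (updates only when `step > start_training_after_x_steps`, a gradient step every `update_params_every` steps, a target check and an epsilon update at every training step) to show how many of each actually happen. Note that `update_target_every_x_steps` is compared with the environment step, not with the number of gradient steps, so with `update_params_every = 4` the target network is synchronized every 5,000 / 4 = 1,250 gradient steps.

```python
def count_updates(max_steps, start_training_after_x_steps,
                  update_params_every, update_target_every_x_steps):
    """Hypothetical helper replaying the counting logic of the training loop."""
    # Updates only happen once step > start_training_after_x_steps
    training_steps = range(start_training_after_x_steps + 1, max_steps)
    gradient_steps = sum(1 for s in training_steps if s % update_params_every == 0)
    target_syncs = sum(1 for s in training_steps if s % update_target_every_x_steps == 0)
    epsilon_updates = len(training_steps)
    return gradient_steps, target_syncs, epsilon_updates


print(count_updates(400_100, 10_000, 4, 5_000))
```

For `max_steps = 400_100` this gives 97,524 gradient steps, 78 target synchronizations and 390,099 epsilon updates.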


In [None]:
# Some figures
df = logger.data
fig, ax = plt.subplots()
fig, ax = plot_from_dataframe(
    fig, ax, df,
    x_key="step", y_key="return_median",
    fill_between=("return_25", "return_75"),
    xlabel="Step",
    ylabel="Return"
)
fig.tight_layout()
plt.show()
# fig.savefig("return.pdf")
# plt.close(fig)

for y_key in df.keys():
    if y_key != "step" and "return" not in y_key:
        fig, ax = plt.subplots()
        fig, ax = plot_from_dataframe(
            fig, ax, df,
            x_key="step", y_key=y_key,
            xlabel="Step",
            ylabel=y_key
        )
        fig.tight_layout()
        plt.show()
        # fig.savefig("{}.pdf".format(y_key))
        # plt.close(fig)
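
The `epsilon` curve plotted above depends on how `eps_dec` is computed from `eps_greedy_hp` in the constructor, which is not shown in this section. Assuming `eps_dec` is chosen so that epsilon, starting from 1.0 as in `init`, reaches `eps_end` after `eps_end_at` updates, the two rules `update_epsilon_exponential` and `update_epsilon_linear` behave as in this hypothetical sketch (`make_epsilon_update` and `run_schedule` are illustrative names).

```python
def make_epsilon_update(eps_start, eps_end, n_updates, kind):
    """Hypothetical helper: return an update rule that reaches eps_end after n_updates."""
    if kind == "exponential":
        # Multiplicative factor: eps_start * eps_dec ** n_updates == eps_end
        eps_dec = (eps_end / eps_start) ** (1.0 / n_updates)
        return lambda eps: max(eps_end, eps_dec * eps)  # as in update_epsilon_exponential
    if kind == "linear":
        # Additive (negative) step: eps_start + n_updates * eps_dec == eps_end
        eps_dec = (eps_end - eps_start) / n_updates
        return lambda eps: max(eps_end, eps_dec + eps)  # as in update_epsilon_linear
    raise ValueError(f"unknown schedule: {kind}")


def run_schedule(eps_start, eps_end, n_updates, kind):
    update = make_epsilon_update(eps_start, eps_end, n_updates, kind)
    history = [eps_start]
    for _ in range(n_updates):
        history.append(update(history[-1]))
    return history


# Decay from 1.0 to 0.01 over 100 updates
exp_history = run_schedule(1.0, 0.01, 100, "exponential")
lin_history = run_schedule(1.0, 0.01, 100, "linear")
# Halfway through, the exponential schedule is already close to 0.1,
# while the linear one is still close to 0.505
print(exp_history[50], lin_history[50])
```

Under the same assumption, and since the training loop updates epsilon once per environment step after `start_training_after_x_steps` (390,099 updates for `max_steps = 400_100`), the exponential schedule with `eps_end_at = 4.5e5` would end training around epsilon ≈ 0.018 rather than at `eps_end = 0.01`. If the constructor uses a different convention, this estimate does not apply.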