<a href="https://colab.research.google.com/github/eemlcommunity/PracticalSessions2023/blob/omardd%2Frl/reinforcement_learning/part3_deep_q_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [EEML 2023] Reinforcement Learning Tutorial - Part 3

## Deep Q-Learning

# Colab Setup

In [None]:
# Colab setup: install the tutorial's dependencies, but only when this notebook
# runs on Google Colab. Every shell command below redirects stdout and stderr
# to /dev/null, so a failed install is silent -- drop the redirect to debug.
from IPython import get_ipython

# The string form of the active IPython shell is used to detect Colab.
if 'google.colab' in str(get_ipython()):
  # optax, haiku, rlax
  !pip install optax > /dev/null 2>&1
  !pip install dm-haiku > /dev/null 2>&1
  !pip install rlax > /dev/null 2>&1

  # gymnasium (swig is installed first; presumably needed to build the box2d extra -- TODO confirm)
  !pip install -q swig > /dev/null 2>&1
  !pip install "gymnasium[box2d]" > /dev/null 2>&1

  # install rlberry library (https://github.com/rlberry-py/rlberry)
  !pip install rlberry==0.5.0 > /dev/null 2>&1

  # reinstall numpy to avoid errors
  # NOTE(review): this runs after the other installs, so it overrides whatever
  # numpy version they pulled in.
  !pip install "numpy<1.23.0" > /dev/null 2>&1

  # install ffmpeg-python for saving videos
  !pip install ffmpeg-python > /dev/null 2>&1

  # packages required to show video
  !pip install pyvirtualdisplay > /dev/null 2>&1
  !apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1

# Check rlberry version
import rlberry
print(rlberry.__version__)

# Create directory for saving videos
# (errors are discarded, so re-running the cell when the directory exists is harmless)
!mkdir videos > /dev/null 2>&1

# Initialize display and import function to show videos
# The bare module import is kept for its import-time side effects
# (presumably it starts the virtual display -- TODO confirm).
import rlberry.colab_utils.display_setup
from rlberry.colab_utils.display_setup import show_video

In [None]:
# Print the numpy version that is actually loaded in this kernel
# (the Colab setup cell pins "numpy<1.23.0").
import numpy as np
import rlberry
print(np.__version__)

In [None]:
# Useful imports
import gymnasium as gym
from gymnasium.utils.save_video import save_video
import torch
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from rlberry.agents import Agent
from rlberry.manager import AgentManager, plot_writer_data, read_writer_data
from typing import Callable, NamedTuple, Sequence

import torch.nn as nn
import torch.optim as optim

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import optax

# torch device
# Use the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any other cell shown in this notebook.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Gymnasium environment id read by get_dqn_env(); editable through the Colab form field.
env_id = 'MountainCar-v0' #@param ["CartPole-v1", "LunarLander-v2", "MountainCar-v0"]

In [None]:
def get_dqn_env(for_render=False):
  """
  Build the gymnasium environment selected by the global `env_id`.

  Parameters
  ----------
  for_render : bool
    If truthy, the environment is created with render_mode="rgb_array_list";
    otherwise it is created with gymnasium's default arguments.

  Returns
  -------
  The environment returned by `gym.make`.
  """
  make_kwargs = {"render_mode": "rgb_array_list"} if for_render else {}
  return gym.make(env_id, **make_kwargs)

def render_dqn_policy(agent=None, max_steps=500):
  """
  Roll out a single episode, save it as a video and display it.

  Parameters
  ----------
  agent : object with a `policy(observation)` method, or None
    Agent used to select actions. If None, actions are sampled uniformly
    from the environment's action space.
  max_steps : int
    Maximum number of environment steps to record. If the episode has not
    ended by then, the frames recorded so far are saved anyway, so that the
    video shown below always corresponds to this call.
  """
  env = get_dqn_env(for_render=True)
  try:
    state, _ = env.reset()
    for _ in range(max_steps):
      if agent is None:
        action = env.action_space.sample()
      else:
        action = agent.policy(state)
      state, _, terminated, truncated, _ = env.step(action)
      if terminated or truncated:  # show only one episode
        break

    # With render_mode="rgb_array_list", env.render() returns the frames
    # collected since the last reset; write them out as episode 0.
    save_video(
      env.render(),
      "videos",
      fps=env.metadata["render_fps"],
      step_starting_index=0,
      episode_index=0
    )
  finally:
    # Always release the environment, even if the rollout fails.
    env.close()
  show_video("videos/rl-video-episode-0.mp4")


# render_dqn_policy()

# Replay Buffer

In [None]:
class ReplayBuffer:
  """Fixed-capacity circular store of transitions with uniform random sampling."""

  def __init__(self, capacity, rng):
    """
    Parameters
    ----------
    capacity : int
      Maximum number of transitions kept; older ones are overwritten.
    rng :
      instance of numpy's default_rng, used to draw sample indices
    """
    self.capacity = capacity
    self.rng = rng  # random number generator
    self.memory = []
    self.position = 0  # index of the slot written by the next push()

  def push(self, sample):
    """Store one transition, overwriting the oldest slot once the buffer is full."""
    buffer_is_full = len(self.memory) >= self.capacity
    if not buffer_is_full:
      # grow the storage by one slot; it is filled on the next line
      self.memory.append(None)
    self.memory[self.position] = sample
    self.position = (self.position + 1) % self.capacity

  def sample(self, batch_size):
    """
    Draw `batch_size` stored transitions (indices drawn with `rng.choice`).

    Returns an iterator yielding one numpy array per transition field,
    each stacking that field across the drawn transitions.
    """
    picked = self.rng.choice(len(self.memory), size=batch_size)
    fields = zip(*(self.memory[i] for i in picked))
    return map(np.asarray, fields)

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self.memory)

# Neural Network Architecture

In [None]:
import haiku as hk
from typing import Tuple


class MLPQNetwork(hk.Module):
    """
    MLP for Q functions with discrete number of actions.

    The network maps an observation to one output per action.

    Parameters
    ----------
    num_actions : int
        Number of actions (size of the output layer).
    hidden_sizes : Tuple[int, ...]
        Width of each hidden layer, in order. The number of entries is the
        number of hidden layers. Any sequence of ints (e.g. a list) is accepted.
    name : str
        Identifier of the module.
    """

    def __init__(
        self, num_actions: int, hidden_sizes: Tuple[int, ...] = (64, 64), name: str = "MLPQNetwork"
    ):
        super().__init__(name=name)
        # tuple(...) so that a list of hidden sizes (e.g. from a config) also works.
        layer_sizes = tuple(hidden_sizes) + (num_actions,)
        self._mlp = hk.nets.MLP(output_sizes=layer_sizes)

    def __call__(self, observation):
        """Return the MLP output for `observation` (one value per action)."""
        return self._mlp(observation)


# Hyperparameters


In [None]:
# Parameters
# Each constant below is exposed as a Colab form field (the trailing #@param
# annotation) and collected into DQN_PARAMS, the kwargs passed to DQNAgent.

# Number of timesteps for training.
DQN_TRAINING_TIMESTEPS = 10000  #@param {type:"integer"}
# Discount factor
GAMMA = 0.99  #@param {type:"number"}
# Batch size (in number of chunks).
BATCH_SIZE = 64  #@param {type:"integer"}
# Size of trajectory chunks to sample from the buffer.
CHUNK_SIZE = 8  #@param {type:"integer"}
# Interval (in number of transitions) between updates of the online network.
ONLINE_UPDATE_INTERVAL = 1  #@param {type:"integer"}
# Interval (in total number of online updates) between updates of the target network.
TARGET_UPDATE_INTERVAL = 512  #@param {type:"integer"}
# Learning rate
LEARNING_RATE = 0.001 #@param {type:"number"}
# Initial value of epsilon
EPSILON_INIT = 1.0  #@param {type:"number"}
# Minimum (final) value of epsilon
EPSILON_END = 0.05  #@param {type:"number"}
# Number of steps over which epsilon is annealed from EPSILON_INIT to EPSILON_END
EPSILON_STEPS = 5000  #@param {type:"integer"}
# Maximum size of replay buffer
MAX_REPLAY_SIZE = 100000  #@param {type:"integer"}
# Interval (in number of transitions) between agent evaluations in fit().
EVAL_INTERVAL = 256 #@param {type:"integer"}


# Keyword arguments for DQNAgent.__init__ (keys match its parameter names).
# DQN_TRAINING_TIMESTEPS is not included: it is the budget passed to fit().
DQN_PARAMS = dict(
    gamma=GAMMA,
    batch_size=BATCH_SIZE,
    chunk_size=CHUNK_SIZE,
    online_update_interval=ONLINE_UPDATE_INTERVAL,
    target_update_interval=TARGET_UPDATE_INTERVAL,
    learning_rate=LEARNING_RATE,
    epsilon_init=EPSILON_INIT,
    epsilon_end=EPSILON_END,
    epsilon_steps=EPSILON_STEPS,
    max_replay_size=MAX_REPLAY_SIZE,
    eval_interval=EVAL_INTERVAL,
)

# DQN Agent Implementation


Implement the DQN loss in the `_loss` method of the `DQNAgent` class below.

The loss is given by:


$$
L(\theta) = \sum_{(s_i, a_i, r_i, s_i') \in \mathcal{B}}
\left[
Q(s_i, a_i, \theta) -  y_i
\right]^2
$$
where the $y_i$ are the **targets** computed with the **target network** $\theta^-$:


$$
y_i = r_i + \gamma \max_{a'} Q(s_i', a', \theta^-).
$$

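Note: the network parameterized by $\theta$ is called the **online network**, since it is the one used to interact with the environment (online). In the code, `batch["discounts"]` already holds the per-transition discount factor ($\gamma$, or $0$ when no bootstrapping should take place), so use it in place of $\gamma$ when computing $y_i$.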

In [None]:
import chex
import functools
import haiku as hk
import jax
import jax.numpy as jnp

import numpy as np
import optax
import dill
import rlax

from gymnasium import spaces
from pathlib import Path
from rlberry import types
from rlberry.agents import AgentWithSimplePolicy
from rlberry.agents.utils.replay import ReplayBuffer
from typing import Any, Callable, Mapping, Optional

import rlberry

# Module-level alias for rlberry's logger.
# NOTE(review): not referenced by the cells shown in this notebook.
logger = rlberry.logger


@chex.dataclass
class AllParams:
    """Container for the two sets of Q-network parameters held by the agent."""
    # Parameters of the online network (the ones updated by the optimizer).
    online: chex.ArrayTree
    # Parameters of the target network (refreshed from `online` via rlax.periodic_update).
    target: chex.ArrayTree


@chex.dataclass
class AllStates:
    """Training state passed to, and returned by, the agent's step functions."""
    # Optax optimizer state for the online parameters.
    optimizer: chex.ArrayTree
    # Number of learner updates performed so far. DQNAgent initialises it with
    # jnp.array(0), i.e. a scalar array rather than a Python int.
    learner_steps: int
    # Number of actor steps taken so far (incremented by one in _actor_step);
    # also initialised with jnp.array(0).
    actor_steps: int


@chex.dataclass
class ActorOutput:
    """Result of one actor step."""
    # Action selected for the observation (greedy or epsilon-greedy).
    actions: chex.Array
    # Q-values of the online network for that observation, one per action.
    q_values: chex.Array


class DQNAgent(AgentWithSimplePolicy):
    """
    Implementation of Deep Q-Learning using JAX.

    The agent keeps two copies of the Q-network parameters (online and target),
    acts epsilon-greedily with the online network, stores transitions in a
    replay buffer and updates the online parameters from sampled batches.

    Parameters
    ----------
    env : types.Env
        Environment.
    gamma : float
        Discount factor.
    batch_size : int
        Batch size (in number of chunks).
    chunk_size : int
        Size of trajectory chunks to sample from the buffer.
    online_update_interval : int
        Interval (in number of transitions) between updates of the online network.
    target_update_interval : int
        Interval (in total number of online updates) between updates of the target network.
    learning_rate : float
        Optimizer learning rate.
    epsilon_init : float
        Initial value of epsilon-greedy exploration.
    epsilon_end : float
        End value of epsilon-greedy exploration.
    epsilon_steps : int
        Number of steps over which annealing over epsilon takes place.
    max_replay_size : int
        Maximum number of transitions in the replay buffer.
    eval_interval : int
        Interval (in number of transitions) between agent evaluations in fit().
        If None, never evaluate.
    max_episode_length : int
        Maximum length of an episode. If None, episodes only end when the
        environment reports termination or truncation.
    net_constructor : callable
        Constructor for Q network. If None, uses default MLP.
    net_kwargs : dict
        kwargs for network constructor (net_constructor).
    max_gradient_norm : float, default: 100.0
        Maximum gradient norm.
    """

    name = "JaxDqnAgent"

    def __init__(
        self,
        env: types.Env,
        gamma: float = 0.99,
        batch_size: int = 64,
        chunk_size: int = 8,
        online_update_interval: int = 1,
        target_update_interval: int = 512,
        learning_rate: float = 0.001,
        epsilon_init: float = 1.0,
        epsilon_end: float = 0.05,
        epsilon_steps: int = 5000,
        max_replay_size: int = 100000,
        eval_interval: Optional[int] = None,
        max_episode_length: Optional[int] = None,
        net_constructor: Optional[Callable[..., hk.Module]] = None,
        net_kwargs: Optional[Mapping[str, Any]] = None,
        max_gradient_norm: float = 100.0,
        **kwargs
    ):
        AgentWithSimplePolicy.__init__(self, env, **kwargs)
        env = self.env
        # Seed the JAX PRNG from the agent's own generator (self.rng).
        self.rng_key = jax.random.PRNGKey(self.rng.integers(2**32).item())

        # checks
        if not isinstance(self.env.observation_space, spaces.Box):
            raise ValueError("DQN only implemented for Box observation spaces.")
        if not isinstance(self.env.action_space, spaces.Discrete):
            raise ValueError("DQN only implemented for Discrete action spaces.")

        # params
        self._gamma = gamma
        self._batch_size = batch_size
        self._chunk_size = chunk_size
        self._online_update_interval = online_update_interval
        self._target_update_interval = target_update_interval
        self._max_replay_size = max_replay_size
        self._eval_interval = eval_interval
        # None (or 0) means "no agent-side cap on episode length".
        self._max_episode_length = max_episode_length or np.inf
        self._max_gradient_norm = max_gradient_norm

        #
        # Setup replay buffer
        #

        # dtypes of the entries stored in the buffer
        sample_obs, _ = env.reset()
        try:
            obs_dtype = sample_obs.dtype
        except AttributeError:  # in case sample_obs has no .dtype attribute
            obs_dtype = env.observation_space.dtype
        action_dtype = env.action_space.dtype

        # create replay buffer
        # NOTE: this is the ReplayBuffer imported from rlberry.agents.utils.replay
        # in this cell (setup_entry/append/end_episode API), not the simple
        # ReplayBuffer class defined earlier in the notebook.
        self._replay_buffer = ReplayBuffer(
            max_replay_size=self._max_replay_size,
            rng=self.rng
        )

        self._replay_buffer.setup_entry("actions", action_dtype)
        self._replay_buffer.setup_entry("observations", obs_dtype)
        self._replay_buffer.setup_entry("next_observations", obs_dtype)
        self._replay_buffer.setup_entry("rewards", np.float32)
        self._replay_buffer.setup_entry("discounts", np.float32)

        # initialize network and params
        net_constructor = net_constructor or MLPQNetwork
        net_kwargs = net_kwargs or dict(
            num_actions=self.env.action_space.n, hidden_sizes=(64, 64)
        )
        net_ctor = functools.partial(net_constructor, **net_kwargs)
        self._q_net = hk.without_apply_rng(hk.transform(lambda x: net_ctor()(x)))

        # Dummy observation, only used to initialise the network parameters.
        self._dummy_obs = jnp.ones(self.env.observation_space.shape)

        self.rng_key, subkey1 = jax.random.split(self.rng_key)
        self.rng_key, subkey2 = jax.random.split(self.rng_key)

        self._all_params = AllParams(
            online=self._q_net.init(subkey1, self._dummy_obs),
            target=self._q_net.init(subkey2, self._dummy_obs),
        )

        # initialize optimizer and states
        # Gradients are clipped by global norm before the Adam update.
        self._optimizer = optax.chain(
            optax.clip_by_global_norm(self._max_gradient_norm),
            optax.adam(learning_rate),
        )
        self._all_states = AllStates(
            optimizer=self._optimizer.init(self._all_params.online),
            learner_steps=jnp.array(0),
            actor_steps=jnp.array(0),
        )

        # epsilon decay: power=1.0 makes the polynomial schedule linear,
        # going from epsilon_init to epsilon_end over epsilon_steps actor steps.
        self._epsilon_schedule = optax.polynomial_schedule(
            init_value=epsilon_init,
            end_value=epsilon_end,
            transition_steps=epsilon_steps,
            transition_begin=0,
            power=1.0,
        )

        # update functions (jit)
        self.actor_step = jax.jit(self._actor_step)
        self.learner_step = jax.jit(self._learner_step)

    def policy(self, observation):
        """
        Return the greedy action for `observation`.

        The state returned by the actor step is discarded, so calling this
        method does not advance the actor step counter (and hence epsilon).
        """
        self.rng_key, subkey = jax.random.split(self.rng_key)
        actor_out, _ = self.actor_step(
            self._all_params,
            self._all_states,
            observation,
            subkey,
            evaluation=True,
        )
        return actor_out.actions.item()

    def fit(self, budget: int, **kwargs):
        """
        Train DQN agent.

        Parameters
        ----------
        budget: int
            Number of timesteps to train the agent.
        """
        del kwargs
        timesteps_counter = 0
        episode_rewards = 0.0
        episode_timesteps = 0
        observation, _ = self.env.reset()
        while timesteps_counter < budget:
            # act (epsilon-greedy) and step the environment
            self.rng_key, subkey = jax.random.split(self.rng_key)
            actor_out, self._all_states = self.actor_step(
                self._all_params,
                self._all_states,
                observation,
                subkey,
                evaluation=False,
            )
            action = actor_out.actions.item()
            next_obs, reward, terminated, truncated, _ = self.env.step(action)

            # counters
            timesteps_counter += 1
            episode_timesteps += 1

            # The episode is over when the environment says so, or when the
            # agent-side cap on the episode length has been reached.
            done = bool(
                terminated
                or truncated
                or episode_timesteps >= self._max_episode_length
            )

            # store data
            episode_rewards += reward
            self._replay_buffer.append(
                {
                    "actions": action,
                    "observations": observation,
                    "rewards": reward,
                    # Only a true terminal state stops bootstrapping. A time
                    # limit (truncation or max_episode_length) cuts the
                    # episode short but next_obs still has a value, so the
                    # discount stays at gamma in that case.
                    "discounts": self._gamma * (1.0 - float(terminated)),
                    "next_observations": next_obs,
                }
            )
            observation = next_obs

            # update
            total_timesteps = self._all_states.actor_steps.item()
            if total_timesteps % self._online_update_interval == 0:
                if len(self._replay_buffer) > self._batch_size:
                    sample = self._replay_buffer.sample(
                        batch_size=self._batch_size, chunk_size=self._chunk_size
                    )
                    batch = sample.data
                    (
                        self._all_params,
                        self._all_states,
                        learner_info,
                    ) = self.learner_step(self._all_params, self._all_states, batch)
                    if self.writer:
                        self.writer.add_scalar(
                            "q_loss", learner_info["loss"].item(), total_timesteps
                        )
                        self.writer.add_scalar(
                            "learner_steps",
                            self._all_states.learner_steps.item(),
                            total_timesteps,
                        )

            # eval
            if (
                self._eval_interval is not None
                and total_timesteps % self._eval_interval == 0
            ):
                eval_rewards = self.eval(
                    eval_horizon=self._max_episode_length,
                    n_simulations=2,
                    gamma=1.0,
                )
                if self.writer:
                    self.writer.add_scalar(
                        "eval_rewards", eval_rewards, total_timesteps
                    )

            # check if episode ended
            if done:
                if self.writer:
                    self.writer.add_scalar(
                        "episode_rewards", episode_rewards, total_timesteps
                    )
                self._replay_buffer.end_episode()
                episode_rewards = 0.0
                episode_timesteps = 0
                observation, _ = self.env.reset()

    def _loss(self, all_params, batch):
        """
        Compute the DQN loss on a batch (exercise: to be implemented).

        Parameters
        ----------
        all_params : AllParams
            Online and target network parameters.
        batch : mapping
            Entries "observations", "actions", "rewards", "discounts" and
            "next_observations", as stored by fit().

        Returns
        -------
        (loss, info) where info is a dict containing the loss.
        """
        obs_tm1 = batch["observations"]
        a_tm1 = batch["actions"]
        r_t = batch["rewards"]
        discount_t = batch["discounts"]
        obs_t = batch["next_observations"]

        # remove time (chunk) dim (batch has shape [batch, chunk_size, ...])
        # they're reshaped to [batch * chunk_size, ...]
        a_tm1 = a_tm1.flatten()
        r_t = r_t.flatten()
        discount_t = discount_t.flatten()
        obs_tm1 = jnp.reshape(obs_tm1, (-1, obs_tm1.shape[-1]))
        obs_t = jnp.reshape(obs_t, (-1, obs_t.shape[-1]))

        # Q(s, .) with the online network, Q(s', .) with the target network,
        # and Q(s', .) with the online network.
        # NOTE: q_t_select is not needed for the plain DQN target
        # (max over the target network's values); it is what a Double-DQN
        # style target would use to select the action.
        q_tm1 = self._q_net.apply(all_params.online, obs_tm1)
        q_t_val = self._q_net.apply(all_params.target, obs_t)
        q_t_select = self._q_net.apply(all_params.online, obs_t)

        # ====================================================
        # YOUR IMPLEMENTATION HERE
        #
        loss = jnp.array(0.0) # ...
        # ====================================================
        info = dict(loss=loss)
        return loss, info

    def _actor_step(self, all_params, all_states, observation, rng_key, evaluation):
        """
        Select an action for a single observation.

        When `evaluation` is true the greedy action is returned, otherwise an
        epsilon-greedy one, with epsilon given by the schedule evaluated at the
        current actor step count.

        Returns
        -------
        (ActorOutput, AllStates) where the returned state has actor_steps
        incremented by one and is otherwise unchanged.
        """
        obs = jnp.expand_dims(observation, 0)  # dummy batch
        q_val = self._q_net.apply(all_params.online, obs)[0]  # remove batch
        epsilon = self._epsilon_schedule(all_states.actor_steps)
        # Both candidate actions are computed and one is selected, since
        # `evaluation` is a traced value inside the jitted function.
        train_action = rlax.epsilon_greedy(epsilon).sample(rng_key, q_val)
        eval_action = rlax.greedy().sample(rng_key, q_val)
        action = jax.lax.select(evaluation, eval_action, train_action)
        return (
            ActorOutput(actions=action, q_values=q_val),
            AllStates(
                optimizer=all_states.optimizer,
                learner_steps=all_states.learner_steps,
                actor_steps=all_states.actor_steps + 1,
            ),
        )

    def _learner_step(self, all_params, all_states, batch):
        """
        Perform one gradient update of the online network.

        The target parameters are refreshed from the online ones through
        rlax.periodic_update (driven by learner_steps and
        target_update_interval); the online parameters are updated with the
        optimizer using the gradient of `_loss`.

        Returns
        -------
        (AllParams, AllStates, info) with learner_steps incremented by one and
        `info` being the auxiliary output of `_loss`.
        """
        target_params = rlax.periodic_update(
            all_params.online,
            all_params.target,
            all_states.learner_steps,
            self._target_update_interval,
        )
        # has_aux=True: _loss returns (loss, info); only the loss is differentiated.
        grad, info = jax.grad(self._loss, has_aux=True)(all_params, batch)
        updates, optimizer_state = self._optimizer.update(
            grad.online, all_states.optimizer
        )
        online_params = optax.apply_updates(all_params.online, updates)
        return (
            AllParams(online=online_params, target=target_params),
            AllStates(
                optimizer=optimizer_state,
                learner_steps=all_states.learner_steps + 1,
                actor_steps=all_states.actor_steps,
            ),
            info,
        )

# Training & Evaluation

In [None]:
# Alternative (kept for reference): train a single agent directly, without a manager.
# # Training one instance of DQN
# dqn_agent = DQNAgent(
#     env=(get_dqn_env, dict()),  # we can send (constructor, kwargs) as an env
#     **DQN_PARAMS
# )
# dqn_agent.fit(DQN_TRAINING_TIMESTEPS)

#
# Training several instances using AgentManager
#
# Arguments for AgentManager; agent hyperparameters are passed separately
# (init_kwargs) when the manager is created.
manager_kwargs = dict(
    agent_class=DQNAgent,
    train_env=(get_dqn_env, dict()),   # (constructor, kwargs) pair
    eval_env=(get_dqn_env, dict()),    # (constructor, kwargs) pair
    fit_budget=DQN_TRAINING_TIMESTEPS,  # budget given to each agent's fit()
    n_fit=2,                   # NOTE: You may increase this parameter (number of agents to train)
    parallelization='thread',
    seed=456,
    default_writer_kwargs=dict(maxlen=None,log_interval=10),  # presumably: keep all logged points, log every 10 s -- TODO confirm
)

In [None]:
# Create the manager for DQN agents configured with DQN_PARAMS and run the
# training (n_fit agents, each with a budget of fit_budget timesteps).
dqn_manager = AgentManager(
    init_kwargs=DQN_PARAMS,
    agent_name='DQN',
    **manager_kwargs
)
dqn_manager.fit()



In [None]:
all_dqn_managers = [dqn_manager]

# We can plot the data that was
# stored by the agent with self.writer.add_scalar(tag, value, global_step):
_ = plot_writer_data(all_dqn_managers, tag='q_loss', title='Q Loss')
# 'episode_rewards' is written at the end of every *training* episode
# (epsilon-greedy behaviour), so it is labelled as such.
_ = plot_writer_data(all_dqn_managers, tag='episode_rewards', title='Rewards (Training)')
# 'eval_rewards' holds the results of the periodic evaluations run by fit()
# every `eval_interval` transitions.
_ = plot_writer_data(all_dqn_managers, tag='eval_rewards', title='Rewards (Evaluation)')

In [None]:
# Take the first agent instance from the manager and render one episode
# played with its policy.
agent = dqn_manager.get_agent_instances()[0]
render_dqn_policy(agent)