# Deep Q-Networks

## Tabular Q-Learning
By *Tabular Q-Learning* we mean a *model-free* set of techniques which work in discrete (or discretized) environments with small amount of states and actions and keep around a table of Q values. Contrary to the Q-Value Learning from the last chapter, here we don't explicitly model the environment.

After getting an experience $(s, a, r, s')$ we perform a blending Bellman approximation update:
$$
Q(s, a) \gets (1 - \alpha) Q(s, a) + \alpha (r + \max_{a'} Q(s', a'))
$$
which can be reformulated in terms of *Temporal Difference learning (TD learning)* as
$$
Q(s, a) \gets Q(s, a) + \alpha (r + \max_{a'} Q(s', a') - Q(s, a)) = Q(s, a) + \alpha \delta(s, a, r, s')
$$
where
* $\delta(s, a, r, s') = r + \max_{a'} Q(s', a') - Q(s, a)$ is called *TD error* and
* $r + \max_{a'} Q(s', a')$ is the *TD target*

Finally, for efficiency reasons we don't actually have to construct the full Q table as we don't really care about states that we've never experience. So we'll estimate values only for those states that we've seen and iterate over a smaller set.

In [1]:
import collections
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple

import cv2
import gym
import numpy as np
from tensorboardX import SummaryWriter


class Experience(NamedTuple):
    state: int
    action: int
    reward: float
    next_state: int


class Agent:
    """Q-Learning agent"""

    def __init__(
        self,
        n_actions: int,
        alpha: float,
        gamma: float,
    ) -> None:
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.values = collections.defaultdict(float)

    def policy(self, s: int) -> int:
        """Returns the best known action"""
        return np.argmax([self.values[s, a] for a in range(self.n_actions)])

    def __iadd__(self, e: Experience) -> "Agent":
        Q_max = self.values[e.next_state, self.policy(e.next_state)]
        td_target = e.reward + self.gamma * Q_max
        td_error = td_target - self.values[e.state, e.action]
        self.values[e.state, e.action] += self.alpha * td_error
        return self


def explore(env: gym.Env, state: int) -> Tuple[Experience, int]:
    """
    Samples and applies a random action in given environment from given state.
    Returns experience (s, a, r, s') and new state (resets env if necessary).
    """
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    experience = Experience(state, action, reward, next_state)
    state = env.reset() if done else next_state
    return experience, state


def evaluate(
    env: gym.Env,
    n_episodes: int,
    policy: Callable[[int], int],
) -> float:
    """
    Runs n episodes in given environment using given policy and
    computes the mean non-disounted reward.
    """

    total_reward = 0.0

    for _ in range(n_episodes):

        state = env.reset()
        episode_done = False

        while not episode_done:
            action = policy(state)
            next_state, reward, episode_done, _ = env.step(action)
            state = next_state
            total_reward += reward

    return total_reward / n_episodes


def train(
    env: gym.Env,
    eval_episodes: int = 20,
    alpha: float = 0.2,
    gamma: float = 0.9,
    solution_bound: float = 0.8,
    max_iters: int = 5_000,
) -> None:
    with SummaryWriter(comment="-q-learning") as writer:

        # Bind experience sampling and policy evaluation to environments
        #  - Note: We use a copy of the env. for testing.
        explore_env = partial(explore, env)
        eval_policy = partial(evaluate, gym.make(env.spec.id), eval_episodes)

        # Initialize the Q-Learning agent
        agent = Agent(env.action_space.n, alpha, gamma)

        i = 0
        reward = 0.0
        best_reward = 0.0

        # Initialize the environment
        state = env.reset()

        while reward < solution_bound and i < max_iters:
            i += 1

            # Sample new experience from the environment
            # and pass it to the agent to learn from it.
            experience, state = explore_env(state)
            agent += experience

            # Evaluate current policy
            mean_reward = eval_policy(agent.policy)
            best_reward = max(mean_reward, best_reward)

            # Record metrics
            writer.add_scalar("reward", mean_reward, i)
            writer.add_scalar("best_reward", best_reward, i)

    print(f"Solved in {i} iterations with best reward: {best_reward:.3f}")


# Run Q-Learning in FrozenLake
train(env=gym.make("FrozenLake-v0"))

Solved in 5000 iterations with best reward: 0.950


## Deep Q-Learning
Even Q-Learning does not iterate over all possible states, the size of the Q table might become intractable. The idea of *Deep Q-Learning* is to use a deep NN to represent $Q(s, a) \approx Q(s, a; \mathbf{w})$ ($\mathbf{w}$ are the NN parameters). We treat the problem as a supervised regression task and train the network with a SGD on a dataset of collected experiences.

### Problems in Deep Q-Learning
Although a NN model is a compact representation of the Q table, the learning dynamics we described above has some issues:

#### Exploration-Exploitation Dilema
In all the previous examples we've used random actions to sample experiences from the environment (exploration). However, if we're confident enough in the policy we have, it'd be much better to exploit it and use $a = argmax_a Q(s, a)$ in a state $s$.

A straightforward trick that is used to balance exploration and exploitation a bit is called $\epsilon$*-greedy* policy which randomly switches between these two extremes with $\epsilon$ probability. We typically define a schedule and decrease $\epsilon$ as the training proceeds.

#### SGD Optimization
SGD is heavily based on the assumption of the training instances being *i.i.d* which is definitely not the case in Q-Learning dynamics.

First, independence is broken due to the fact further states in an episode do depend on previous actions and states.

Furthermore, the training data (experiences) are not identically distributed. Regardless of which policy we use to sample the experiences (during training) it won't produce the same distibution of experiences as if we had used the optimal policy. In short, our targets do not follow the same distribution as what we're trying to learn.

To mitigate the i.i.d problem we introduce a *replay buffer* - a large collection of experinces from which we sample training mini-batches.

#### Experience Correlations
Next issue is that the future Q values $Q(s', a')$ are highly correlated with current $Q(s, a)$ that we try to update. The model starts by modifying $Q(s', a')$ which might cause $Q(s, a)$ to get worse but then by trying fixing $Q(s, a)$ we break $Q(s', a')$. Basically, the training can be quite unstable.

A technique that is frequently used to address this is described as *fixing (freezing) the targets*. We introduce second NN $\hat{Q}$ which is a copy of our $Q$ NN and use DT targets $r + \gamma \max_{a'} \hat{Q}(s', a')$. NN parameters are then regularly synchronized with quite a large period (1k or even 10k iterations).

#### The Markov Property
Last but not least, in certain environments (e.g. Atari games) the observations do not capture all the information required to make a decesion. For instance, from a single game screenshot we can't tell the dynamics of a ball. This breaks the Markov property we assume in MDPs and shifts the problem to a much harder class of *Partially Observable Markov Decision Processes (POMDPs)*.

One common trick to mitigate this issue and pretend we have a MDP is to capture the dynamic by stacking a number of consecutive frames and thread the resulting tensor as single observation. This is sufficient in many cases.

### DQN Training
Putting all the pieces together, here is an overview of the training algorithm for a *Deep Q-Network (DQN)*.
1. Initialize parameters of $Q$ and $\hat{Q}$, start with $\epsilon \gets 1$
1. Select a random action $a$ with $\epsilon$ probability, otherwise pick $a = argmax_a Q(s, a)$
1. Run $a$ in the environment and gain experience $e = (s, a, r, s')$
1. Store $e$ in the replay buffer
1. Sample random mini-batch $B$ from the replay buffer
1. For each $(s, a, r, s') \in B$ alculate targets $y$: $y = r$ if episode ends, otherwise $y = r + \gamma \max_{a'} \hat{Q}(s', a')$
1. Compute the loss $\mathcal{L} = (Q(s, a) - y)^2$
1. Using $\mathcal{L}$ run a SGD step to update $Q(s, a)$
1. Every $N$ steps copy $\mathbf{w}_Q \to \mathbf{w}_{\hat{Q}}$
1. Repeat from 2. until convergence or other termination condition

### Atari Environment Wrappers
The original Atari DQN paper defines a set of nowadays common environment transformations that make the training easier or even possible.

In [2]:
class FireResetEnv(gym.Wrapper):
    """
    For environments where the user need to press FIRE for the game to start.
    """

    def __init__(self, env: Optional[gym.Env] = None) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action: int) -> np.ndarray:
        return self.env.step(action)

    def reset(self) -> np.ndarray:
        self.env.reset()

        # Press FIRE
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()

        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()

        return obs


class MaxAndSkipEnv(gym.Wrapper):
    """
    Return only every `skip`-th frame, the intermediate frames are max-pooled.
    This is to prevent the flickering effect in Atari games.
    """

    def __init__(self, env: Optional[gym.Env] = None, skip: int = 4) -> None:
        super().__init__(env)
        # Small frame buffer to max-poole the last two frames
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action: int) -> np.ndarray:
        total_reward = 0.0
        done = False

        # Make `skip` steps instead of just one and accumulate the reward
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)

            self._obs_buffer.append(obs)
            total_reward += reward

            if done:
                break

        # Max pooling of buffered frames
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self) -> np.ndarray:
        # Reset the environment and frame buffer
        obs = self.env.reset()
        self._obs_buffer.clear()
        self._obs_buffer.append(obs)
        return obs


class ProcessFrame84(gym.ObservationWrapper):
    """
    Reshape, crop and reduce RGB frames to 84x84 grayscale images.
    """

    def __init__(self, env: Optional[gym.Env] = None) -> None:
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 1),
            dtype=np.uint8,
        )

    def observation(self, obs: np.ndarray) -> np.ndarray:
        return self.process(obs)

    @staticmethod
    def process(frame: np.ndarray) -> np.ndarray:

        # Handle screen shapes in different Atari games
        if frame.size == 210 * 160 * 3:
            shape = [210, 160, 3]
        elif frame.size == 250 * 160 * 3:
            shape = [250, 160, 3]
        else:
            raise Exception(f"Unknown resolution: {frame.shape}")

        # Reshape the image
        img = frame.reshape(shape).astype(np.float32)

        # Convert RGB image to grayscale
        #  - colorimetric grayscale conversion
        #  - img[..., 0] * 0.299 + img[..., 1] * 0.587 + img[..., 2] * 0.114
        img = img.dot((0.299, 0.587, 0.114))

        # Resize and crop the image to final 84x84x1 tensor
        #  - We crop the image to keep the central part
        resized_screen = cv2.resize(
            src=img,
            dsize=(84, 110),
            interpolation=cv2.INTER_AREA,
        )
        cropped_screen = resized_screen[18:102, :]
        return np.reshape(cropped_screen, [84, 84, 1]).astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Reshape observations to conform to PyTorch's conventions.

    (height, width, channels) -> (channels, height, width)
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        old_shape = self.observation_space.shape
        new_shape = (old_shape[-1], old_shape[0], old_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=new_shape,
            dtype=np.float32,
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        # Move color channel to the front
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Scale frame intensities to [0, 1] interval"""

    def observation(self, obs: np.ndarray) -> np.ndarray:
        return np.array(obs).astype(np.float32) / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """
    Make a sliding window of N consecutive frames and
    return these windows as new observations.
    """

    buffer: np.ndarray

    def __init__(
        self,
        env: gym.Env,
        n_steps: int,
        dtype: np.dtype = np.float32,
    ) -> None:
        super().__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            old_space.low.repeat(n_steps, axis=0),
            old_space.high.repeat(n_steps, axis=0),
            dtype=dtype,
        )

    def reset(self) -> np.ndarray:
        self.buffer = np.zeros_like(
            self.observation_space.low,
            dtype=self.dtype,
        )
        return self.observation(self.env.reset())

    def observation(self, observation: np.ndarray) -> np.ndarray:
        # Pop the oldest frame from the buffer and add new one
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


def make_env(env_name: str) -> gym.Env:
    """Make a new wrapped Atari environment"""
    env = gym.make(env_name)
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    return ScaledFloatFrame(env)

### Atari DQN
The DQN architecture is quite simple and consists of two sequential parts:
1. Convolutional part with three 2D conv. layers with ReLU activations
1. Fully connected part with one dense layer with ReLU activation and final linar layer

There is an implicit flatten between these two parts to convert a 4D output of convolutions (batch of 3D images) to 2D input of the FC layer (batch of 1D feature vectors).

Finally, the we output Q value for each action as it is more efficient to model $Q(\cdot, a)$ rather than pass both states and actions on the input and output single $Q(s, a)$ value.

In [3]:
import torch  # noqa
import torch.nn as nn  # noqa


class DQN(nn.Module):
    def __init__(self, input_shape: Tuple[int, ...], n_actions: int) -> None:
        super().__init__()

        n_conv_inputs = input_shape[0]
        n_fc_inputs = self._conv_output_dim(input_shape)

        # Stack of 2D convolutional layers with ReLU activations
        self.conv = nn.Sequential(
            nn.Conv2d(n_conv_inputs, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Fully connected layers for the regression part
        #  - Outputs Q(., a) for each action a
        self.fc = nn.Sequential(
            nn.Linear(n_fc_inputs, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _conv_output_dim(self, shape: Tuple[int, ...]) -> int:
        dummy_conv_input = torch.zeros(1, *shape)
        dummy_conv_output = self.conv(dummy_conv_input)
        return int(np.prod(dummy_conv_output.size()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `view(batch_size, -1)` flattens all the feature dimensions
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)