# Introduction

In Q-learning, we maintain a table that stores all the action values. However, this does not work for continuous state spaces, or for complex environments with huge state and action spaces. It is natural to consider replacing the Q-table with a function; in other words, we use a function to approximate $Q$. This is the core idea of DQN (Deep Q-Network).

Note that DQN is not well suited to continuous action spaces: the value update takes $\max_{a'} Q$, and maximizing the network output over actions is hard unless the actions are discrete.

First, we write out the update rule for Q-learning: $$Q(s, a) \leftarrow Q(s, a) + \alpha (r + \gamma \max_{a'}Q(s', a')- Q(s, a))$$
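To make the rule concrete, here is a minimal tabular sketch. The function name `q_learning_update`, the state labels, and the values of `alpha` and `gamma` are illustrative assumptions, not something fixed by the text.

```python
from collections import defaultdict

def q_learning_update(Q, s, a, r, s_next, actions, alpha=0.1, gamma=0.9):
    """Apply one tabular Q-learning update in place and return the TD error."""
    # Value of the best action in the next state (the max over a').
    best_next = max(Q[(s_next, b)] for b in actions)
    td_error = r + gamma * best_next - Q[(s, a)]
    Q[(s, a)] += alpha * td_error
    return td_error

Q = defaultdict(float)      # every unseen (state, action) pair starts at 0
actions = [0, 1]
Q[("s1", 1)] = 2.0          # pretend something was already learned about s1

td = q_learning_update(Q, "s0", 0, r=1.0, s_next="s1", actions=actions)
print(td, Q[("s0", 0)])
```

Here the TD error is $1 + 0.9 \cdot 2 - 0$, and the table entry for `("s0", 0)` moves a fraction $\alpha$ of the way toward the target.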

# CartPole env
In the CartPole environment, a pole stands on top of a cart. The agent must move the cart horizontally to keep the pole upright. The episode ends when the pole tilts too far, when the cart moves too far from its starting position, or when neither has happened after 200 frames. The state is a tuple (cart_position, cart_velocity, pole_angle, pole_tip_velocity). There are 2 actions: 0 for left and 1 for right.
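The three ending conditions can be sketched as a plain function. The limits `X_LIMIT` and `ANGLE_LIMIT` below are assumptions based on the classic Gym `CartPole-v0` settings (2.4 units and 12 degrees); check the version you have installed. The helper `episode_over` is hypothetical and not part of Gym.

```python
import math

X_LIMIT = 2.4                      # assumed cart position limit
ANGLE_LIMIT = 12 * math.pi / 180   # assumed pole angle limit, in radians
MAX_STEPS = 200                    # frame cap mentioned in the text

def episode_over(state, step):
    """Return True if any of the three ending conditions holds."""
    cart_position, cart_velocity, pole_angle, pole_tip_velocity = state
    return (abs(cart_position) > X_LIMIT
            or abs(pole_angle) > ANGLE_LIMIT
            or step >= MAX_STEPS)

print(episode_over((0.0, 0.0, 0.0, 0.0), step=10))
print(episode_over((3.0, 0.0, 0.0, 0.0), step=10))
```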

# DQN for CartPole env

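Instead of learning a function $f: (s, a) \rightarrow Q$ that takes a state-action pair as input, we can let the network take only the state and output one action value per action, i.e. $f: s \rightarrow (Q(s, a_1), \dots, Q(s, a_n))$. The greedy action is then the index of the largest output.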
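As a rough illustration of a state-in, action-values-out interface, the sketch below uses a hypothetical linear scorer in place of the neural network. The weights are made up for the example; only the shape of the computation matters.

```python
def q_values(state, weights, biases):
    """Hypothetical linear stand-in for the Q-network: one output per action."""
    return [sum(w * x for w, x in zip(row, state)) + b
            for row, b in zip(weights, biases)]

def greedy_action(state, weights, biases):
    """Pick the action whose estimated value is largest."""
    qs = q_values(state, weights, biases)
    return max(range(len(qs)), key=qs.__getitem__)

state = (0.0, 0.5, 0.1, -0.2)        # (position, velocity, angle, tip velocity)
weights = [[0.0, -1.0, -2.0, 0.0],   # made-up scores for action 0 (left)
           [0.0, 1.0, 2.0, 0.0]]     # made-up scores for action 1 (right)
biases = [0.0, 0.0]

qs = q_values(state, weights, biases)
action = greedy_action(state, weights, biases)
print(qs, action)
```

A single forward pass yields the values of all actions at once, so taking the max over actions costs one evaluation rather than one per action.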

Now an important question arises: how do we define the loss function?

A simple choice is the MSE loss between $Q_\omega(s, a)$ (the learned action value) and the TD target $r + \gamma \max_{a'}Q_\omega(s', a')$, which in theory should equal the actual action value:

$$\omega^* = \arg \min_{\omega} \frac{1}{2N}\sum_{i=1}^N[Q_\omega(s_i, a_i) - (r_i + \gamma \max_{a'}Q_\omega(s'_i, a'))]^2$$
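A small worked example of this loss, with a lookup table standing in for $Q_\omega$. The `done` flag (no bootstrapping from a terminal state) is an addition to the formula above, and `dqn_loss`, the table values, and `gamma=0.5` are illustrative assumptions.

```python
def dqn_loss(batch, q_fn, gamma):
    """Average of 1/2 * squared TD error over (s, a, r, s_next, done) tuples."""
    total = 0.0
    for s, a, r, s_next, done in batch:
        # TD target: reward plus discounted best next value, unless terminal.
        target = r if done else r + gamma * max(q_fn(s_next))
        total += 0.5 * (q_fn(s)[a] - target) ** 2
    return total / len(batch)

# Made-up action values for three states and two actions.
table = {"s0": [1.0, 2.0], "s1": [0.0, 4.0], "s2": [0.0, 0.0]}
batch = [("s0", 0, 1.0, "s1", False),
         ("s1", 1, 1.0, "s2", True)]

loss = dqn_loss(batch, table.__getitem__, gamma=0.5)
print(loss)
```

For the first transition the target is $1 + 0.5 \cdot 4 = 3$ against a prediction of 1; for the second, the terminal target is 1 against a prediction of 4.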

Now that we have the loss function, we have extended Q-learning to its neural-network form. Since DQN uses the same idea as Q-learning, it is also off-policy. Therefore, we can balance exploration and exploitation with an $\epsilon$-greedy policy, and collect the sampled data for later updates.

Before we implement DQN, there are 2 modules we need to know that make DQN training stable and efficient, namely experience replay and the target network.

## Experience replay

Consider a supervised learning task, where we sample one data point or a batch for each gradient update. As training goes on, especially over multiple epochs, a given sample is used multiple times. We can do this because supervised learning relies on an important assumption: the i.i.d. assumption, which says that all samples in the training set are drawn independently from an identical distribution. This is also why RL, which involves temporal data, can be hard to train with supervised methods: the temporal structure breaks the i.i.d. assumption.

Since we use an NN to estimate $Q$, we will need much more data to feed it so that the network is thoroughly trained. This is why we need experience replay. 

In experience replay, we maintain a replay buffer that stores all the sampled 4-element tuples $(s, a, r, s')$. When training the Q-network, we sample from the buffer uniformly at random. This weakens the correlation between consecutive samples, so the training data comes closer to satisfying the i.i.d. assumption, and each sample can be reused multiple times to train the network.
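The contrast between consecutive and randomly drawn transitions can be seen in a few lines. This is only an illustration with dummy transitions; the capacity of 1000 and the batch size of 4 are arbitrary choices.

```python
import random
from collections import deque

random.seed(0)                  # fixed seed so the run is repeatable
buffer = deque(maxlen=1000)     # oldest transitions are dropped when full

# Dummy transitions (s, a, r, s'), where the state is just the timestep.
for t in range(1500):
    buffer.append((t, t % 2, 1.0, t + 1))

sequential = list(buffer)[-4:]          # last four steps, strongly correlated
minibatch = random.sample(buffer, 4)    # uniform draw across the whole buffer

print([s for s, _, _, _ in sequential])
print(sorted(s for s, _, _, _ in minibatch))
```

The sequential batch covers four adjacent timesteps, while the random minibatch is drawn from anywhere in the 1000 stored steps.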

## Target Network
The goal of DQN training is to make $Q_\omega(s, a)$ approach the TD target $r + \gamma \max_{a'} Q_\omega(s', a')$. Since the TD target itself contains the output of the network, and that output changes with every update, the target keeps moving, which can easily make training unstable. To solve this issue, we introduce the target network.

The idea of the target network is to hold a copy of the Q-network fixed while computing the loss. Thus, we have 2 Q-networks:
- the online Q-network, which is trained and provides $Q_\omega(s, a)$ in the loss
- the target Q-network, which provides $\max_{a'}Q_{\omega^-} (s', a')$ for a stable target

Every $C$ steps, the target network $Q_{\omega^-}$ is synchronized with $Q_\omega$, while $Q_\omega$ itself is updated at every step by gradient descent.
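The synchronization schedule can be sketched without any neural network. Below, a dictionary with a single number stands in for the parameters, and adding 1.0 stands in for a gradient step. The value `C = 3` is an arbitrary choice for the example.

```python
import copy

C = 3                               # sync period (arbitrary for this sketch)
online = {"w": 0.0}                 # stand-in for the parameters of Q_omega
target = copy.deepcopy(online)      # the target network starts as an exact copy
history = []

for step in range(1, 8):
    online["w"] += 1.0              # stand-in for one gradient update
    if step % C == 0:
        target = copy.deepcopy(online)   # hard update: omega^- <- omega
    history.append((step, online["w"], target["w"]))

for row in history:
    print(row)
```

Between synchronizations the target parameters stay frozen, so the TD target only changes every $C$ steps.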

To sum up, the DQN algorithm is:

- Init $Q_\omega(s, a)$
- Copy $Q_\omega(s, a)$ to $Q_{\omega^-}(s, a)$
- Init replay pool $R$
- for episode $e \leftarrow 1$ to $E$:
    - get init state $s_1$
    - for timestep $t \leftarrow 1$ to $T$:
        - use $\epsilon$-greedy to choose action $a_t$
        - take $a_t$ and get the response $r_t, s_{t+1}$
        - put $(s_t, a_t, r_t, s_{t+1})$ into $R$
        - if $R$ has enough data, sample $N$ transitions $\{(s_i, a_i, r_i, s'_i)\}_{i=1, ..., N}$ from $R$
        - for each sampled transition, calculate the target $y_i = r_i + \gamma \max_{a'} Q_{\omega^-}(s'_i, a')$
        - minimize the loss $L = \frac{1}{N} \sum_{i} (y_i - Q_\omega(s_i, a_i))^2$ to update $Q_\omega$
        - every $C$ steps, update the target network: $\omega^- \leftarrow \omega$
    - end for
- end for

In [2]:
!pip install torch

In [3]:
import random
import gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils

In [4]:
class ReplayBuffer:
    ''' Experience replay buffer '''
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)  # FIFO queue; oldest entries are dropped when full

    def add(self, state, action, reward, next_state, done):  # append one transition to the buffer
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):  # draw batch_size transitions uniformly at random
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):  # number of transitions currently stored
        return len(self.buffer)

In [5]:
class Qnet(torch.nn.Module):
    ''' Q-network with a single hidden layer '''
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Qnet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))  # ReLU activation on the hidden layer
        return self.fc2(x)  # one Q-value per action

In [None]:
class DQN:
    ''' DQN algorithm '''
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,
                 epsilon, target_update, device):
        self.action_dim = action_dim
        self.q_net = Qnet(state_dim, hidden_dim,
                          self.action_dim).to(device)  # online Q-network
        # target network
        self.target_q_net = Qnet(state_dim, hidden_dim,
                                 self.action_dim).to(device)
        # start the target network as an exact copy of the online network
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        # use the Adam optimizer
        self.optimizer = torch.optim.Adam(self.q_net.parameters(),
                                          lr=learning_rate)
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate of the epsilon-greedy policy
        self.target_update = target_update  # target network update period
        self.count = 0  # counter for the number of updates
        self.device = device

    def take_action(self, state):  # choose an action with the epsilon-greedy policy
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            # wrap in np.array first; building a tensor from a list of arrays is slow
            state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
            action = self.q_net(state).argmax().item()
        return action

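    def update(self, transition_dict):
        # Sketch of one DQN update step. The keys 'states', 'actions', 'rewards',
        # 'next_states' and 'dones' are an assumption; the caller is not shown here.
        states = torch.tensor(np.array(transition_dict['states']),
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.int64).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']),
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        q_values = self.q_net(states).gather(1, actions)  # Q_w(s_i, a_i)
        with torch.no_grad():  # the target is treated as a constant
            max_next_q = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        # TD target; no bootstrapping from terminal states
        q_targets = rewards + self.gamma * max_next_q * (1 - dones)
        loss = F.mse_loss(q_values, q_targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # synchronize the target network every target_update updates
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1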