# Import libraries

In [1]:
import random
import collections
import numpy as np

import imageio

import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym

# Hyperparameters

In [2]:
ENV_NAME = "CartPole-v1"

# Seed the agent-side RNGs (Python `random` drives ε-greedy and replay sampling,
# PyTorch drives weight initialisation; NumPy is seeded for completeness) so runs
# are repeatable. NOTE: this does not seed the Gymnasium environment or its action
# space; pass `seed=` to `env.reset` / call `env.action_space.seed(...)` for that.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

GAMMA = 0.99                 # discount factor for future rewards
LR = 1e-3                    # Adam learning rate

BATCH_SIZE = 64
BUFFER_CAPACITY = 50_000
WARMUP_STEPS = 1000          # collect enough transitions in the buffer before training starts
TRAIN_EVERY = 4              # one gradient step every 4 environment steps

TARGET_UPDATE_EVERY = 1000   # copy online weights into the target network every 1000 steps

EPS_START = 1.0
EPS_END = 0.05
EPS_DECAY = 0.995            # multiplicative ε decay, applied once per finished episode

MAX_EPISODES = 5000

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Replay Buffer

In [3]:
Transition = collections.namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random mini-batch sampling.

    Once `capacity` transitions are held, every new push evicts the oldest one
    (the semantics of ``collections.deque(maxlen=capacity)``).
    """

    def __init__(self, capacity: int):
        # A bounded deque drops items from the left when appending past `capacity`.
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append a single transition to the buffer."""
        record = Transition(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )
        self.buffer.append(record)

    def sample(self, batch_size: int):
        """Draw `batch_size` distinct stored transitions uniformly and stack them as tensors.

        Returns:
            (states, actions, rewards, next_states, dones). States and next_states
            are float32 stacks of the stored observations. Actions are int64;
            rewards and dones are float32. Those three scalar fields get a
            trailing singleton dimension, i.e. shape (batch_size, 1). No device
            is specified, so the tensors land on torch's default device; the
            caller moves them.

        Raises:
            ValueError: from ``random.sample`` if `batch_size` exceeds ``len(self)``.
        """
        picked = random.sample(self.buffer, batch_size)

        def as_matrix(arrays):
            # Stack via NumPy first: torch.tensor on a list of ndarrays is slow (PyTorch warns).
            return torch.tensor(np.array(arrays), dtype=torch.float32)

        def as_column(values, dtype):
            # One scalar per transition -> (batch_size, 1) column tensor.
            return torch.tensor(values, dtype=dtype).unsqueeze(-1)

        return (
            as_matrix([t.state for t in picked]),
            as_column([t.action for t in picked], torch.int64),
            as_column([t.reward for t in picked], torch.float32),
            as_matrix([t.next_state for t in picked]),
            as_column([t.done for t in picked], torch.float32),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# Q-Network

In [4]:
class QNetwork(nn.Module):
    """MLP mapping an observation vector to one Q-value per discrete action.

    Architecture: obs_dim -> hidden_dim -> hidden_dim -> n_actions, with ReLU
    between layers and a linear output (Q-values are not bounded).

    Args:
        obs_dim: length of the flat observation vector.
        n_actions: number of discrete actions (width of the output layer).
        hidden_dim: width of both hidden layers; defaults to 128 (the original size).
    """

    def __init__(self, obs_dim: int, n_actions: int, hidden_dim: int = 128):
        super().__init__()
        # Keep everything in one nn.Sequential named `net` so the state_dict keys
        # (net.0.*, net.2.*, net.4.*) stay stable for target syncing and checkpoints.
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x):
        """Map observations (..., obs_dim) to Q-values (..., n_actions)."""
        return self.net(x)

# ε-greedy action selection

In [5]:
def select_action(q_net: QNetwork, state: np.ndarray, eps: float, env: gym.Env):
    """Choose an action ε-greedily for a single observation.

    With probability `eps` a random action is sampled from `env.action_space`;
    otherwise the action with the largest Q(s, a) under `q_net` is returned.
    Passing eps=0.0 yields the purely greedy policy. Assumes a discrete action
    space (the greedy branch is an argmax over action indices).

    Args:
        q_net: network mapping a batch of observations to per-action Q-values.
        state: one un-batched observation.
        eps: exploration probability.
        env: environment whose action space is sampled when exploring.

    Returns:
        The chosen action as a Python int, on both branches.
    """
    if random.random() < eps:
        # Cast so exploration returns the same type as the greedy branch
        # (a discrete space's sample() yields a NumPy integer).
        return int(env.action_space.sample())

    # Build the input on whichever device the network actually lives on instead of
    # the global DEVICE, so this still works if the network has been moved.
    device = next(q_net.parameters()).device
    state_t = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)  # [1, obs_dim]
    with torch.no_grad():
        q_values = q_net(state_t)  # [1, n_actions]
    return int(torch.argmax(q_values, dim=1).item())

# Double DQN training step

In [6]:
def train_step(q_online, q_target, optimizer, replay_buffer):
    """Run one Double DQN gradient update on a mini-batch from `replay_buffer`.

    Target:  y = r + GAMMA * (1 - done) * Q_target(s', argmax_a Q_online(s', a)).
    The online network selects the next action and the target network evaluates
    it (the Double DQN decoupling). The MSE loss between Q_online(s, a) and y is
    minimised; gradients are clipped to a global norm of 10 before the step.

    Returns:
        The loss as a Python float, or None when the buffer holds fewer than
        BATCH_SIZE transitions (no update is made in that case).
    """
    if len(replay_buffer) < BATCH_SIZE:
        return None  # not enough transitions for a full batch yet

    states, actions, rewards, next_states, dones = (
        tensor.to(DEVICE) for tensor in replay_buffer.sample(BATCH_SIZE)
    )

    # Q(s, a) for the actions actually taken, from the online network.
    predicted = q_online(states).gather(1, actions)  # [B, 1]

    with torch.no_grad():
        # Online network picks a* = argmax_a Q_online(s', a) ...
        greedy_next = q_online(next_states).argmax(dim=1, keepdim=True)  # [B, 1]
        # ... and the target network scores it: Q_target(s', a*).
        next_value = q_target(next_states).gather(1, greedy_next)  # [B, 1]
        # Bellman target; (1 - done) removes the bootstrap term on terminal transitions.
        not_done = 1.0 - dones
        target = rewards + GAMMA * not_done * next_value  # [B, 1]

    loss = nn.functional.mse_loss(predicted, target)

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(q_online.parameters(), max_norm=10.0)
    optimizer.step()

    return loss.item()

# Helper: record a greedy episode to video

In [7]:
def record_policy(env, q_online, eps=0.0, filename="cartpole_demo.mp4"):
    """Roll out one episode with `q_online` and save the rendered frames as a video.

    `env` must be created with render_mode="rgb_array" so that `env.render()`
    returns an image. The environment is closed by this function, even if the
    rollout fails part-way.

    Args:
        env: environment to roll out in (closed on exit).
        q_online: Q-network used to choose actions.
        eps: exploration rate; 0.0 (the default) is the greedy policy.
        filename: output path for imageio.mimsave (written at 30 fps).

    Raises:
        RuntimeError: if `env.render()` returns None (wrong render mode).
    """
    frames = []
    try:
        reset_out = env.reset()
        state = reset_out[0] if isinstance(reset_out, tuple) else reset_out

        done = False
        while not done:
            frames.append(env.render())
            action = select_action(q_online, state, eps, env)
            step_out = env.step(action)
            if len(step_out) == 5:
                state, _, terminated, truncated, _ = step_out
                done = terminated or truncated
            else:
                state, _, done, _ = step_out
        # Capture the state the episode ended in too; otherwise the video stops one step early.
        frames.append(env.render())
    finally:
        env.close()

    if frames[0] is None:
        raise RuntimeError("env.render() returned None; create the env with render_mode='rgb_array'")
    imageio.mimsave(filename, frames, fps=30)
    print(f"✅ Video saved: {filename}")

# Main loop

In [8]:
def _reset_env(env):
    """Reset `env` and return only the initial observation (Gymnasium or legacy gym API)."""
    reset_out = env.reset()
    if isinstance(reset_out, tuple):
        state, _ = reset_out
        return state
    return reset_out


def _step_env(env, action):
    """Step `env` once and return (next_state, reward, terminated, truncated).

    Gymnasium reports `terminated` and `truncated` separately. The legacy gym
    4-tuple has a single `done`; there, gym's TimeLimit wrapper flags a
    time-limit cut-off with info["TimeLimit.truncated"].
    """
    step_out = env.step(action)
    if len(step_out) == 5:
        next_state, reward, terminated, truncated, _ = step_out
        return next_state, reward, terminated, truncated
    next_state, reward, done, info = step_out
    truncated = bool(info.get("TimeLimit.truncated", False))
    return next_state, reward, done and not truncated, truncated


def main():
    """Train a Double DQN agent on ENV_NAME, then record one greedy episode to video.

    Per environment step: act ε-greedily, store the transition, train the online
    network every TRAIN_EVERY steps once WARMUP_STEPS steps have elapsed, and copy
    the online weights into the target network every TARGET_UPDATE_EVERY steps.
    ε decays multiplicatively once per episode, down to EPS_END.
    """
    env = gym.make(ENV_NAME)
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n

    # Online network (trained) and target network (periodically synced copy used for targets).
    q_online = QNetwork(obs_dim, n_actions).to(DEVICE)
    q_target = QNetwork(obs_dim, n_actions).to(DEVICE)
    q_target.load_state_dict(q_online.state_dict())
    q_target.eval()

    optimizer = optim.Adam(q_online.parameters(), lr=LR)
    replay_buffer = ReplayBuffer(BUFFER_CAPACITY)

    eps = EPS_START
    global_step = 0

    try:
        for episode in range(1, MAX_EPISODES + 1):
            state = _reset_env(env)
            episode_return = 0.0
            episode_len = 0

            while True:
                global_step += 1

                # 1) ε-greedy action, 2) environment step.
                action = select_action(q_online, state, eps, env)
                next_state, reward, terminated, truncated = _step_env(env, action)
                done = terminated or truncated

                # 3) Store the transition. The stored flag is what train_step uses to drop
                #    the bootstrap term via (1 - done). Only a true terminal state should do
                #    that: a time-limit truncation (e.g. CartPole-v1's step cap) still has
                #    future value, so store `terminated`, not `done`.
                replay_buffer.push(state, action, reward, next_state, float(terminated))

                state = next_state
                episode_return += reward
                episode_len += 1

                # 4) Train the online network after warm-up, every TRAIN_EVERY steps.
                if global_step > WARMUP_STEPS and global_step % TRAIN_EVERY == 0:
                    train_step(q_online, q_target, optimizer, replay_buffer)

                # 5) Hard-sync the target network every TARGET_UPDATE_EVERY steps.
                if global_step % TARGET_UPDATE_EVERY == 0:
                    q_target.load_state_dict(q_online.state_dict())

                # 6) End of episode (terminated or truncated): decay ε and log.
                if done:
                    eps = max(EPS_END, eps * EPS_DECAY)
                    print(
                        f"Episode {episode:03d} | Return: {episode_return:6.1f} "
                        f"| Len: {episode_len:4d} | eps: {eps:.3f}"
                    )
                    break
    finally:
        # Release the training env even if training is interrupted (e.g. kernel interrupt).
        env.close()

    record_env = gym.make(ENV_NAME, render_mode="rgb_array")
    record_policy(record_env, q_online, eps=0.0)

# Run training and record a demo video

In [9]:
# Launch training followed by video recording. Inside a Jupyter kernel `__name__`
# is also "__main__", so executing this cell always starts the full training run.
_run_entry_point = __name__ == "__main__"
if _run_entry_point:
    main()

Episode 001 | Return:   18.0 | Len:   18 | eps: 0.995
Episode 002 | Return:   23.0 | Len:   23 | eps: 0.990
Episode 003 | Return:   24.0 | Len:   24 | eps: 0.985
Episode 004 | Return:   15.0 | Len:   15 | eps: 0.980
Episode 005 | Return:   10.0 | Len:   10 | eps: 0.975
Episode 006 | Return:   14.0 | Len:   14 | eps: 0.970
Episode 007 | Return:   12.0 | Len:   12 | eps: 0.966
Episode 008 | Return:   14.0 | Len:   14 | eps: 0.961
Episode 009 | Return:   40.0 | Len:   40 | eps: 0.956
Episode 010 | Return:   19.0 | Len:   19 | eps: 0.951
Episode 011 | Return:   17.0 | Len:   17 | eps: 0.946
Episode 012 | Return:   15.0 | Len:   15 | eps: 0.942
Episode 013 | Return:   30.0 | Len:   30 | eps: 0.937
Episode 014 | Return:   32.0 | Len:   32 | eps: 0.932
Episode 015 | Return:   20.0 | Len:   20 | eps: 0.928
Episode 016 | Return:   20.0 | Len:   20 | eps: 0.923
Episode 017 | Return:   13.0 | Len:   13 | eps: 0.918
Episode 018 | Return:   17.0 | Len:   17 | eps: 0.914
Episode 019 | Return:   13.0



✅ Video saved: cartpole_demo.mp4
