# DQN from Scratch on CartPole

In this tutorial, you will implement Deep Q-Networks (DQN) from scratch using PyTorch on `CartPole-v1`.

Objectives:
- Build a replay buffer and Q-network
- Train with target network and ε-greedy exploration
- Evaluate and visualize learning

Read each markdown cell before running the corresponding code cell.


In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt

# Reached only if every import above succeeded.
print("Imports ready.")


## Environment and Seeding
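We'll use `CartPole-v1` because it is small and trains quickly, which keeps the focus on the algorithm. We also seed the random number generators so that runs are reproducible.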


In [None]:
SEED = 0

env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n

# Seed the global RNGs: numpy, Python's `random` (used for the epsilon-greedy
# coin flip and for replay-buffer sampling) and torch (network initialization).
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Gymnasium keeps its own RNGs that the global seeds above do not touch:
# - the environment's RNG is seeded by the first reset(seed=...); later
#   reset() calls without a seed keep drawing from that seeded generator;
# - the action space used by env.action_space.sample() (random exploration
#   actions) has a separate RNG that must be seeded explicitly.
env.reset(seed=SEED)
env.action_space.seed(SEED)

obs_dim, act_dim


## Replay Buffer
We store transitions `(s, a, r, s', done)` and sample random minibatches for training.


In [None]:
Transition = namedtuple("Transition", ["s", "a", "r", "s2", "d"])


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random minibatch sampling.

    The backing ``deque`` has ``maxlen=capacity``, so once it is full each new
    transition silently evicts the oldest one.
    """

    def __init__(self, capacity=50_000):
        self.buf = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition; positional args map to (s, a, r, s2, d)."""
        self.buf.append(Transition(*args))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a tuple ``(s, a, r, s2, d)`` of tensors: ``s`` and ``s2`` are the
        per-transition states stacked along dim 0 as float32; ``a`` is an int64
        column of shape ``(batch_size, 1)``; ``r`` and ``d`` are float32 columns
        of shape ``(batch_size, 1)``. ``random.sample`` raises ``ValueError`` if
        ``batch_size`` exceeds the number of stored transitions.
        """
        picked = random.sample(self.buf, batch_size)

        def column(field):
            # Gather one named field across all sampled transitions, in order.
            return [getattr(tr, field) for tr in picked]

        states = torch.tensor(np.array(column("s")), dtype=torch.float32)
        actions = torch.tensor(column("a"), dtype=torch.int64).unsqueeze(-1)
        rewards = torch.tensor(column("r"), dtype=torch.float32).unsqueeze(-1)
        next_states = torch.tensor(np.array(column("s2")), dtype=torch.float32)
        dones = torch.tensor(column("d"), dtype=torch.float32).unsqueeze(-1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buf)


rb = ReplayBuffer()
len(rb)


## Q-Network and Target Network
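We use two MLPs with the same architecture. The online network is trained and selects actions. The target network is a periodically synced copy of the online network, used to compute stable TD targets.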


In [None]:
class QNet(nn.Module):
    """Two-hidden-layer ReLU MLP mapping an observation to one Q-value per action.

    Args:
        obs_dim: Length of the observation vector (input features).
        act_dim: Number of discrete actions (output features).
        hidden: Width of each hidden layer. Defaults to 128, the size this
            tutorial originally hard-coded, so existing calls are unchanged.
    """

    def __init__(self, obs_dim, act_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, x):
        """Return Q-values with last dimension ``act_dim`` for inputs whose last dimension is ``obs_dim``."""
        return self.net(x)

# Online network: the one the optimizer updates and the policy queries.
qnet = QNet(obs_dim, act_dim)
# Target network: same architecture, started as an exact weight copy of qnet.
qtarget = QNet(obs_dim, act_dim)
qtarget.load_state_dict(qnet.state_dict())
opt = optim.Adam(params=qnet.parameters(), lr=1e-3)

qnet, qtarget


## Training Loop
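- Collect experience with an ε-greedy policy (ε decays once per episode)
- Sample minibatches from the replay buffer and compute TD targets
- Copy the online weights into the target network every `target_update` steps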


In [None]:
epsilon = 1.0
eps_min, eps_decay = 0.05, 0.995
batch_size = 64
gamma = 0.99
target_update = 1000  # copy online weights into the target net every this many env steps
steps, returns = 0, []
loss_fn = nn.MSELoss()  # stateless; build once instead of on every update

for episode in range(200):
    s, _ = env.reset()
    done = False
    ret = 0.0
    while not done:
        steps += 1

        # Epsilon-greedy action selection.
        if random.random() < epsilon:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                q = qnet(torch.as_tensor(s, dtype=torch.float32).unsqueeze(0))
                a = int(q.argmax(dim=1).item())

        s2, r, term, trunc, _ = env.step(a)
        # Only true termination means s2 has zero future value. Truncation
        # (CartPole-v1's time limit) is an artificial cutoff, so the TD target
        # must still bootstrap from s2: store `term`, not `term or trunc`.
        rb.push(s, a, r, s2, float(term))
        # Either condition ends the episode itself.
        done = term or trunc
        s = s2
        ret += r

        if len(rb) >= batch_size:
            S, A, R, S2, D = rb.sample(batch_size)
            # Q(s, a) for the actions actually taken.
            q_vals = qnet(S).gather(1, A)
            with torch.no_grad():
                # (1 - D) drops the bootstrap term only for terminal next states.
                target = R + (1 - D) * gamma * qtarget(S2).max(1, keepdim=True)[0]
            loss = loss_fn(q_vals, target)
            opt.zero_grad()
            loss.backward()
            opt.step()

        if steps % target_update == 0:
            qtarget.load_state_dict(qnet.state_dict())

    epsilon = max(eps_min, epsilon * eps_decay)
    returns.append(ret)
    if (episode + 1) % 10 == 0:
        print(f"Ep {episode+1:3d} | Ret {np.mean(returns[-10:]):6.2f} | eps {epsilon:5.3f}")

plt.plot(returns)
plt.title('DQN from Scratch - Returns')
plt.xlabel('Episode')
plt.ylabel('Return')
plt.show()
