# CartPole DQN (PyTorch)

Fixed epsilon-greedy + replay buffer + terminal masking.

In [3]:
import random
from collections import deque
from dataclasses import dataclass
from typing import Deque, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Try gymnasium first, fall back to gym
try:
    import gymnasium as gym
    GYMNASIUM = True
except ImportError:
    import gym
    GYMNASIUM = False


# -------------------------
# Utils
# -------------------------
def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def to_tensor(x, device):
    return torch.as_tensor(x, dtype=torch.float32, device=device)


# -------------------------
# Replay Buffer
# -------------------------
@dataclass
class Transition:
    s: np.ndarray
    a: int
    r: float
    s2: np.ndarray
    done: float  # 1.0 if terminal else 0.0


class ReplayBuffer:
    def __init__(self, capacity: int):
        self.buf: Deque[Transition] = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buf)

    def push(self, s, a, r, s2, done: bool):
        self.buf.append(Transition(s, int(a), float(r), s2, float(done)))

    def sample(self, batch_size: int):
        batch = random.sample(self.buf, batch_size)
        s  = np.stack([t.s  for t in batch], axis=0)
        a  = np.array([t.a  for t in batch], dtype=np.int64)
        r  = np.array([t.r  for t in batch], dtype=np.float32)
        s2 = np.stack([t.s2 for t in batch], axis=0)
        d  = np.array([t.done for t in batch], dtype=np.float32)
        return s, a, r, s2, d


# -------------------------
# Q Network
# -------------------------
class QNet(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )

    def forward(self, x):
        return self.net(x)


# -------------------------
# DQN Agent
# -------------------------
class DQNAgent:
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        device: torch.device,
        gamma: float = 0.99,
        lr: float = 5e-4,
        batch_size: int = 64,
        buffer_size: int = 50_000,
        min_buffer: int = 2_000,
        target_update_freq: int = 1_000,  # in env steps
        eps_start: float = 1.0,
        eps_end: float = 0.05,
        eps_decay_steps: int = 50_000,
        double_dqn: bool = True,
    ):
        self.device = device
        self.gamma = gamma
        self.batch_size = batch_size
        self.min_buffer = min_buffer
        self.target_update_freq = target_update_freq

        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay_steps = eps_decay_steps
        self.step = 0

        self.double_dqn = double_dqn

        self.q = QNet(state_dim, action_dim).to(device)
        self.q_tgt = QNet(state_dim, action_dim).to(device)
        self.q_tgt.load_state_dict(self.q.state_dict())
        self.q_tgt.eval()

        self.opt = optim.Adam(self.q.parameters(), lr=lr)
        self.loss_fn = nn.SmoothL1Loss()  # Huber

        self.rb = ReplayBuffer(buffer_size)
        self.action_dim = action_dim

    def epsilon(self):
        # linear decay
        frac = min(1.0, self.step / float(self.eps_decay_steps))
        return self.eps_start - (self.eps_start - self.eps_end) * frac

    @torch.no_grad()
    def act(self, s, greedy: bool = False):
        eps = 0.0 if greedy else self.epsilon()
        self.step += 1

        if random.random() < eps:
            return random.randrange(self.action_dim)

        s_t = to_tensor(s, self.device).unsqueeze(0)
        qvals = self.q(s_t)
        return int(torch.argmax(qvals, dim=1).item())

    def push(self, s, a, r, s2, done: bool):
        self.rb.push(s, a, r, s2, done)

    def learn(self):
        if len(self.rb) < self.min_buffer:
            return None

        s, a, r, s2, d = self.rb.sample(self.batch_size)

        s_t  = to_tensor(s, self.device)
        s2_t = to_tensor(s2, self.device)
        a_t  = torch.as_tensor(a, dtype=torch.int64, device=self.device).unsqueeze(1)
        r_t  = torch.as_tensor(r, dtype=torch.float32, device=self.device).unsqueeze(1)
        d_t  = torch.as_tensor(d, dtype=torch.float32, device=self.device).unsqueeze(1)

        q_sa = self.q(s_t).gather(1, a_t)

        with torch.no_grad():
            if self.double_dqn:
                # action selection from online net, evaluation from target net
                a2 = torch.argmax(self.q(s2_t), dim=1, keepdim=True)
                q_next = self.q_tgt(s2_t).gather(1, a2)
            else:
                q_next = self.q_tgt(s2_t).max(dim=1, keepdim=True)[0]

            y = r_t + (1.0 - d_t) * self.gamma * q_next  # terminal masking

        loss = self.loss_fn(q_sa, y)

        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q.parameters(), 10.0)
        self.opt.step()

        # target update
        if self.step % self.target_update_freq == 0:
            self.q_tgt.load_state_dict(self.q.state_dict())

        return float(loss.item())


# -------------------------
# Train / Eval
# -------------------------
def reset_env(env):
    out = env.reset()
    if isinstance(out, tuple):
        obs = out[0]
    else:
        obs = out
    return obs

def step_env(env, action):
    out = env.step(action)
    if GYMNASIUM:
        obs2, reward, terminated, truncated, info = out
        done = terminated or truncated
    else:
        obs2, reward, done, info = out
    return obs2, reward, done

@torch.no_grad()
def evaluate(env, agent: DQNAgent, episodes: int = 20):
    scores = []
    for _ in range(episodes):
        s = reset_env(env)
        done = False
        total = 0.0
        while not done:
            a = agent.act(s, greedy=True)
            s, r, done = step_env(env, a)
            total += r
        scores.append(total)
    return float(np.mean(scores)), float(np.max(scores))


def train_cartpole(
    env_name: str = "CartPole-v1",
    seed: int = 0,
    episodes: int = 800,
    eval_every: int = 25,
):
    set_seed(seed)
    env = gym.make(env_name)
    env_eval = gym.make(env_name)

    # seed envs
    try:
        env.reset(seed=seed)
        env_eval.reset(seed=seed + 1)
    except TypeError:
        pass

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # agent = DQNAgent(
    #     state_dim=state_dim,
    #     action_dim=action_dim,
    #     device=device,
    #     gamma=0.99,
    #     lr=5e-4,
    #     batch_size=64,
    #     buffer_size=50_000,
    #     min_buffer=2_000,
    #     target_update_freq=1_000,
    #     eps_start=1.0,
    #     eps_end=0.05,
    #     eps_decay_steps=50_000,
    #     double_dqn=True,
    # )
    agent = DQNAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    device=device,
    gamma=0.99,
    lr=1e-4,
    batch_size=128,
    buffer_size=100_000,
    min_buffer=10_000,
    target_update_freq=2000,
    eps_start=1.0,
    eps_end=0.02,
    eps_decay_steps=100_000,
    double_dqn=True,
)

    returns = []
    losses = []

    for ep in range(1, episodes + 1):
        s = reset_env(env)
        done = False
        ep_ret = 0.0

        while not done:
            a = agent.act(s, greedy=False)
            s2, r, done = step_env(env, a)
            agent.push(s, a, r, s2, done)
            s = s2
            ep_ret += r

            loss = agent.learn()
            if loss is not None:
                losses.append(loss)

        returns.append(ep_ret)

        if ep % eval_every == 0:
            mean_eval, best_eval = evaluate(env_eval, agent, episodes=20)
            print(f"Ep {ep:4d} | train_return(last)={ep_ret:6.1f} | "
                  f"eval_mean(20)={mean_eval:6.1f} | eval_best(20)={best_eval:6.1f} | "
                  f"eps={agent.epsilon():.3f} | buffer={len(agent.rb)}")

            # Early stop heuristic for CartPole-v1
            if mean_eval >= 475:
                print("Solved (eval_mean >= 475).")
                break

    env.close()
    env_eval.close()
    return agent, np.array(returns, dtype=float), np.array(losses, dtype=float)


if __name__ == "__main__":
    agent, returns, losses = train_cartpole(
        env_name="CartPole-v1",   # if you must use v0: "CartPole-v0"
        seed=0,
        episodes=1500,
        eval_every=25,
    )

    print("\nFinal 100-episode mean return:", returns[-100:].mean() if len(returns) >= 100 else returns.mean())

Ep   25 | train_return(last)=  15.0 | eval_mean(20)=  89.3 | eval_best(20)= 105.0 | eps=0.977 | buffer=589
Ep   50 | train_return(last)=  26.0 | eval_mean(20)=  91.3 | eval_best(20)= 107.0 | eps=0.952 | buffer=1325
Ep   75 | train_return(last)=  18.0 | eval_mean(20)=  92.2 | eval_best(20)= 108.0 | eps=0.927 | buffer=1950
Ep  100 | train_return(last)=  37.0 | eval_mean(20)=  94.3 | eval_best(20)= 107.0 | eps=0.903 | buffer=2589
Ep  125 | train_return(last)=  15.0 | eval_mean(20)=  92.1 | eval_best(20)= 104.0 | eps=0.879 | buffer=3187
Ep  150 | train_return(last)=  31.0 | eval_mean(20)=  90.5 | eval_best(20)= 102.0 | eps=0.854 | buffer=3870
Ep  175 | train_return(last)=  20.0 | eval_mean(20)=  92.7 | eval_best(20)= 106.0 | eps=0.829 | buffer=4556
Ep  200 | train_return(last)=  27.0 | eval_mean(20)=  89.4 | eval_best(20)= 105.0 | eps=0.807 | buffer=5074
Ep  225 | train_return(last)=  13.0 | eval_mean(20)=  90.0 | eval_best(20)= 107.0 | eps=0.783 | buffer=5738
Ep  250 | train_return(last)=

In [5]:
env_eval = gym.make("CartPole-v1")
mean_100, best_100 = evaluate(env_eval, agent, episodes=100)
print(f"[FINAL TEST] mean(100)={mean_100:.1f}, best(100)={best_100:.1f}")
env_eval.close()

[FINAL TEST] mean(100)=500.0, best(100)=500.0
