# 🎮 Lab 9: Deep Q-Network (DQN) on Atari

In [1]:
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensordict import TensorDict
from torch.utils.tensorboard import SummaryWriter
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import GymEnv, TransformedEnv, Compose
from torchrl.envs.transforms import ToTensorImage, GrayScale, Resize, CatFrames, DoubleToFloat, RewardClipping
from torchrl.envs.utils import step_mdp
from torchrl.modules import QValueModule
from torchrl.objectives import DQNLoss

In [5]:
# One TensorBoard run directory per launch, suffixed with the wall-clock start
# time so that repeated runs never write into the same directory.
writer = SummaryWriter(
    log_dir=f"runs/pong_dqn_{time.strftime('%Y%m%d-%H%M%S')}",
)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ---- Q-learning / optimisation ----
gamma: float = 0.99                 # discount factor for the bootstrapped target
LEARNING_RATE: float = 1e-4         # Adam step size
BATCH_SIZE: int = 32                # transitions per gradient step

# ---- Experience replay ----
REPLAY_SIZE: int = 50000            # replay buffer capacity (transitions)
REPLAY_START_SIZE: int = 50000      # transitions collected before training starts

# ---- Target network ----
SYNC_TARGET_FRAMES: int = 1000      # copy online -> target weights every N frames

# ---- Epsilon-greedy exploration (linear decay) ----
EPSILON_START: float = 1.0
EPSILON_FINAL: float = 0.01
EPSILON_DECAY_LAST_FRAME: int = 400000  # frame at which epsilon reaches EPSILON_FINAL

In [6]:
# Base Gymnasium environment; with pixels_only=True the observation is just the
# rendered RGB frame under the "pixels" key.
base_env = GymEnv("ALE/Pong-v5", from_pixels=True, pixels_only=True, render_mode="rgb_array")
n_actions = base_env.action_space.n  # size of the discrete action set
obs_shape = (4, 84, 84)  # observation shape produced by the transform stack below

# Apply preprocessing transforms (applied in order to the "pixels" observation)
env = TransformedEnv(
    base_env,
    Compose(
        ToTensorImage(),         # HWC uint8 image -> CHW float tensor scaled to [0, 1]
        GrayScale(),             # Convert RGB -> grayscale (1 channel)
        Resize(84, 84),          # Resize to 84x84
        CatFrames(N=4, dim=-3),  # Stack the last 4 frames along channels -> (4, 84, 84)
        DoubleToFloat(),         # Ensure float32 precision
        RewardClipping(-1, 1),   # Clip rewards to [-1, 1]
    ),
)

In [7]:
# Experience replay: uniformly sampled transitions kept in memory-mapped storage,
# which is allocated lazily from the structure of the first transition written.
rb = TensorDictReplayBuffer(
    batch_size=BATCH_SIZE,                            # default size of rb.sample()
    sampler=RandomSampler(),                          # uniform random sampling
    storage=LazyMemmapStorage(max_size=REPLAY_SIZE),  # capacity in transitions
)

In [8]:
class QNet(nn.Module):
    """Convolutional Q-network mapping stacked frames to one Q-value per action.

    The convolutional trunk uses the Nature-DQN layer sizes, each followed by
    BatchNorm and ReLU. The head is a 512-unit LayerNorm'd hidden layer followed
    by a linear layer with ``n_actions`` outputs. The sub-module names
    (``conv``/``fc``) and their Sequential indices define the state_dict keys
    used when syncing the target network.

    Note: BatchNorm normalises with batch statistics in train mode and with
    running statistics in eval mode.
    """

    def __init__(self, n_actions):
        super().__init__()
        trunk = []
        # (in_channels, out_channels, kernel_size, stride) for each conv stage.
        for c_in, c_out, kernel, stride in ((4, 32, 8, 4), (32, 64, 4, 2), (64, 64, 3, 1)):
            trunk += [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*trunk)

        hidden = 512
        # 84x84 input -> 20x20 -> 9x9 -> 7x7 feature maps with 64 channels.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (B, n_actions) for input x of shape (B, 4, 84, 84).

        x is expected to be float32 in [0, 1].
        """
        features = torch.flatten(self.conv(x), start_dim=1)
        return self.fc(features)

def select_action(obs, eps: float):
    """Choose an action for one observation with an epsilon-greedy policy.

    With probability ``eps`` a random action is drawn from ``env.action_spec``.
    Otherwise the greedy action under the online network ``q`` is taken. Both
    branches return the action in the same encoding as ``env.action_spec``.
    For the default GymEnv discrete spec this is presumably one-hot; the
    training loop argmaxes stored actions. Returning the same encoding keeps
    the environment and the replay buffer consistent regardless of the branch.

    Args:
        obs: Observation for a single step (no batch dim), presumably (4, 84, 84).
        eps: Probability of taking a random action.

    Returns:
        An action tensor shaped and typed like ``env.action_spec.zero()``.
    """
    spec = env.action_spec
    if torch.rand(1).item() < eps:
        return spec.rand()

    # Act in eval mode. A single-sample forward pass in train mode would
    # normalise with per-sample BatchNorm statistics and would also push
    # those statistics into the running averages.
    was_training = q.training
    q.eval()
    try:
        with torch.no_grad():
            qvals = q(obs.unsqueeze(0).to(device))  # (1, n_actions)
    finally:
        q.train(was_training)
    best = int(qvals.argmax(dim=1).item())

    action = spec.zero()
    if action.ndim == 0:
        # Index (categorical) encoding: the action is the index itself.
        return torch.full_like(action, best)
    # One-hot encoding: mark the chosen action's slot.
    action[..., best] = 1
    return action

In [9]:
# Online network (trained every step) and target network (a frozen copy used
# for the bootstrapped target), built in that order and moved to `device`.
q, q_target = (QNet(n_actions).to(device) for _ in range(2))
# Start the target as an exact copy of the online network and keep it in eval mode.
q_target.load_state_dict(q.state_dict())
q_target.eval()
optimizer = optim.Adam(params=q.parameters(), lr=LEARNING_RATE)

In [10]:
# Training-loop bookkeeping: global frame counter, finished-episode counter,
# and the rewards collected so far in the current episode.
frame_idx = episode = 0
total_rewards = []
# Initial observation tensordict for the first episode.
td = env.reset()

In [None]:
# DQN training loop:
# - act epsilon-greedily and store every transition in the replay buffer;
# - once the buffer holds REPLAY_START_SIZE transitions, take one gradient step per frame;
# - sync the target network every SYNC_TARGET_FRAMES frames.
while True:
    frame_idx += 1
    # Linear decay from EPSILON_START down to EPSILON_FINAL.
    epsilon = max(EPSILON_FINAL, EPSILON_START - frame_idx / EPSILON_DECAY_LAST_FRAME)

    obs = td.get("pixels")
    a = select_action(obs, epsilon)

    # env.step writes the outcome under the "next" sub-tensordict only; the root
    # keys still describe the *current* state until td is advanced below.
    td = env.step(td.set("action", a))
    next_obs = td.get(("next", "pixels"))
    r = td.get(("next", "reward"))
    d = td.get(("next", "done"))
    # Only a true terminal state should stop bootstrapping (a time-limit
    # truncation still has a future). Fall back to `done` if the env exposes
    # no separate "terminated" entry.
    term = td.get(("next", "terminated"), d)

    transition = TensorDict(
        {
            "obs": obs,
            "action": a,
            "reward": r,
            "next_obs": next_obs,
            "done": d,
            "terminated": term,
        },
        batch_size=[],
    )
    rb.add(transition)
    total_rewards.append(float(r.sum().item()))

    # Print result after an episode (game) finishes, then start a new one.
    if d.item():
        td = env.reset()
        episode += 1
        m_reward = float(np.sum(total_rewards))  # episode return (clipped rewards)
        print(f"{frame_idx}: done {episode} games, reward: {m_reward: .3f}, rb: {len(rb)}, eps: {epsilon}")
        writer.add_scalar("Reward/episode", m_reward, episode)
        writer.add_scalar("Epsilon", epsilon, episode)
        total_rewards = []
    else:
        # Advance the MDP: the "next" observation becomes the current one.
        td = step_mdp(td)

    if len(rb) < REPLAY_START_SIZE:
        continue

    if frame_idx % SYNC_TARGET_FRAMES == 0:
        q_target.load_state_dict(q.state_dict())
        print("Q Network updated")

    batch = rb.sample(BATCH_SIZE)
    obs_b      = batch["obs"].to(device)                             # (B, 4, 84, 84)
    act_b      = batch["action"].long().to(device)                   # one-hot actions (see select_action)
    rew_b      = batch["reward"].to(device).squeeze(-1)              # (B,)
    next_obs_b = batch["next_obs"].to(device)                        # (B, 4, 84, 84)
    term_b     = batch["terminated"].to(device).float().squeeze(-1)  # (B,)

    # Bootstrapped target: r + gamma * max_a' Q_target(s', a'), cut at terminal states.
    with torch.no_grad():
        q_next = q_target(next_obs_b).max(dim=1).values
        target = rew_b + gamma * (1.0 - term_b) * q_next

    act_idx = act_b.argmax(dim=-1)  # one-hot -> action index, shape (B,)
    q_values = q(obs_b).gather(1, act_idx.unsqueeze(-1)).squeeze(1)

    loss = F.mse_loss(q_values, target)  # scalar (mean over the batch)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(q.parameters(), 10.0)
    optimizer.step()

    # Log the loss of the update that was just performed.
    writer.add_scalar("Loss/frame_idx", loss.item(), frame_idx)

3448: done 1 games, reward: -21.000, rb: 3448, eps: 0.99138
