# 🎮 Lab 9: Deep Q-Network (DQN) on Atari

In this lab, you will extend your reinforcement learning skills from classical control environments (such as **CartPole**) to more complex **Atari games** like *Pong-v5*.  
You will implement a **Deep Q-Network (DQN)** that learns directly from raw pixel observations.

### Learning Objectives
- Understand how DQN combines **Q-learning** with **deep neural networks** to handle high-dimensional visual inputs.  
- Implement essential components:  
  - Replay Buffer  
  - Target Network  
  - ε-Greedy Exploration Strategy  
- Train an agent to achieve meaningful performance on an Atari environment.  
- Visualize training progress and recorded gameplay frames.

###  Part 1: Environment Setup

Before starting this lab, you need to create a new Conda environment (Python 3.10) and install the required packages for Atari reinforcement learning.

- Step 1. Create and activate the environment
```bash
conda create -n atari python=3.10 -y
conda activate atari
```

- Step 2. Use pip to install Gymnasium with Atari support, the Atari ROMs, and the utilities used later in the lab (including `ipywidgets` for the frame viewer).
```bash
pip install "gymnasium[atari,accept-rom-license]==0.29.1"
pip install "autorom[accept-rom-license]"
pip install "stable-baselines3[extra]"
pip install opencv-python imageio matplotlib ipywidgets
AutoROM --accept-license
```

- Step 3. Install PyTorch and TorchRL
```bash
# cu126 selects the CUDA 12.6 build; use the index URL that matches your system
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
pip install torchrl
```

- Step 4. Verify your installation using the “Pong” environment

In [1]:
import gymnasium as gym
import numpy as np

env = gym.make("ALE/Pong-v5", render_mode="rgb_array")
frames = []

obs, info = env.reset(seed=0)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    frames.append(env.render())  
    done = terminated or truncated

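env.close()

# Save the frames as one uint8 array of shape (N, 210, 160, 3) for the viewer cell
np.save("pong_frames_uint8.npy", np.asarray(frames, dtype=np.uint8))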


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

# Frames saved by the previous cell: shape (N, 210, 160, 3), dtype uint8
frames = np.load("pong_frames_uint8.npy")
N = len(frames)

slider = widgets.IntSlider(min=0, max=N - 1, step=1, value=0, description="Frame")

# interact re-runs _show (and redraws the figure) whenever the slider moves
@widgets.interact(i=slider)
def _show(i):
    plt.imshow(frames[i])
    plt.axis("off")
    plt.show()



- Overview of the Atari Environment

The **Atari environments** are among the most widely used benchmarks in reinforcement learning research.  
They provide visually rich and challenging tasks that allow agents to learn control policies directly from **raw pixel inputs**.  
These environments are part of the **Arcade Learning Environment (ALE)**, accessible via **Gymnasium**.

-  Components of the Environment

| Component | Description |
|------------|--------------|
| **State (Observation)** | A raw RGB image of size **(210 × 160 × 3)** representing the game screen. For DQN, these frames are usually converted to grayscale, resized (e.g., 84 × 84), and stacked (e.g., 4 frames) to provide temporal context. |
| **Action Space** | A discrete set of joystick actions whose size and meaning differ between games. For example, **Pong** has 6 actions: <br> `0: NOOP` (no operation) <br> `1: FIRE` (serve / start game) <br> `2: RIGHT` (moves the paddle up) <br> `3: LEFT` (moves the paddle down) <br> `4: RIGHTFIRE` <br> `5: LEFTFIRE` <br> Use `env.unwrapped.get_action_meanings()` to list the actions of any game. |
| **Reward** | A scalar signal returned after each action. <br> • In **Pong**, +1 is given when the agent scores a point, and -1 when the opponent scores. <br> • In **Breakout**, the agent earns points for each brick it breaks (bricks in higher rows are worth more). <br> The cumulative reward tracks the game score (in Pong, the agent’s points minus the opponent’s). |
| **Done Flags** | Gymnasium returns two flags: `terminated` (the game is over, e.g., one player reaches 21 points in Pong) and `truncated` (a time or step limit was reached). The episode ends when either is `True`. |

#### Common Preprocessing Steps
To stabilize learning, observations are typically preprocessed as follows:
1. Convert RGB frames to grayscale.  
2. Resize to 84 × 84.  
3. Stack the most recent 4 frames.  
4. Normalize pixel values to `[0, 1]`.  

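Grayscale conversion and resizing reduce computational cost, and frame stacking lets the agent perceive motion (e.g., the direction of the ball), which a single frame cannot show.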
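The four steps can be sketched in plain NumPy to see what each one does to the array shapes. This is a simplified, hypothetical sketch (`preprocess` and `FrameStack` are illustrative names, not library APIs). It assumes the ITU-R BT.601 luminance weights for grayscale and uses nearest-neighbour sampling for the resize, whereas `cv2.resize` or TorchRL's `Resize` typically interpolate.

```python
import numpy as np
from collections import deque

def preprocess(frame_rgb: np.ndarray, size: int = 84) -> np.ndarray:
    """RGB uint8 (H, W, 3) -> grayscale float32 (size, size) in [0, 1]."""
    # Step 1: grayscale as a weighted sum of R, G, B (BT.601 weights, an assumption)
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = frame_rgb.astype(np.float32) @ weights          # (H, W)
    # Step 2: nearest-neighbour resize by picking evenly spaced rows/columns
    h, w = gray.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    small = gray[rows[:, None], cols[None, :]]             # (size, size)
    # Step 4: scale to [0, 1] (clip guards against float rounding just above 1)
    return np.clip(small / 255.0, 0.0, 1.0)

class FrameStack:
    """Step 3: keep the k most recent preprocessed frames."""

    def __init__(self, k: int = 4):
        self.frames = deque(maxlen=k)

    def reset(self, frame_rgb: np.ndarray) -> np.ndarray:
        # Fill the stack with copies of the first frame so the shape is fixed
        f = preprocess(frame_rgb)
        for _ in range(self.frames.maxlen):
            self.frames.append(f)
        return np.stack(self.frames)                       # (k, 84, 84)

    def step(self, frame_rgb: np.ndarray) -> np.ndarray:
        # Appending to a full deque drops the oldest frame automatically
        self.frames.append(preprocess(frame_rgb))
        return np.stack(self.frames)                       # (k, 84, 84)
```

For a (210, 160, 3) uint8 frame, `FrameStack(4).reset(frame)` returns a float32 array of shape (4, 84, 84) with values in [0, 1], which is the `obs_shape` the Q-network expects.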

###  Part 2: Introduction to TorchRL

TorchRL is a PyTorch-based library for **Reinforcement Learning (RL)** research and education. It provides a modular framework that integrates environments, data collection, replay buffers, transforms, and policy learning, all built on top of **PyTorch** and **TensorDict**.

---
Traditional RL implementations often require extensive boilerplate code for:
- environment wrappers and preprocessing (e.g., grayscale, resize, frame stacking)
- replay buffer design and sampling
- batched rollouts and asynchronous data collection
- stable interfacing with PyTorch tensors and GPU devices  

TorchRL simplifies these tasks with consistent data structures and modular components, making RL experiments both **reproducible and scalable**. The core concepts of TorchRL are:

---
| Concept | Description |
|----------|-------------|
| **`TensorDict`** | A dictionary-like container that holds tensors together with consistent batch shapes (used for observations, actions, rewards, etc.). |
| **`EnvBase` / `GymEnv`** | `EnvBase` is TorchRL’s base class for environments; `GymEnv` wraps Gymnasium environments (e.g., `ALE/Pong-v5`) so that they return TensorDicts. |
| **`TransformedEnv`** | A wrapper that applies a chain of **transforms** (e.g., grayscale, resize, normalization) to the environment automatically. |
| **`ReplayBuffer`** | A memory module for storing and sampling transitions. TorchRL supports both simple and prioritized buffers. |
| **Data collectors** (e.g., `SyncDataCollector`, `MultiSyncDataCollector`) | Handle rollouts and data collection efficiently, supporting multiple parallel environments. |
| **`LossModule`** | Base class for ready-to-use loss implementations such as `DQNLoss`, `A2CLoss`, `ClipPPOLoss`, and `SACLoss`. |

---

In [3]:
import torch
from torchrl.envs import GymEnv, TransformedEnv, Compose
from torchrl.envs.transforms import ToTensorImage, GrayScale, Resize, CatFrames, DoubleToFloat, RewardClipping
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from torchrl.data.replay_buffers.samplers import RandomSampler

First, we introduce **TensorDict** and the **TorchRL replay buffer** by sampling trajectories from the Atari environment.

In TorchRL, data collected from the environment (such as observations, actions, rewards, and next states) are stored inside a **TensorDict**: a dictionary-like container that holds PyTorch tensors with consistent batch shapes.  
Each environment step returns a TensorDict that organizes data in a structured and device-aware format, making it easy to manipulate, transform, and store for later use.

Raw Atari frames are high-dimensional RGB images (210×160×3) that are expensive and redundant as direct input to deep Q-learning.  
TorchRL allows automatic preprocessing using **environment transforms**, which wrap the base environment in a `TransformedEnv`.  
Each transform modifies the observation data inside the TensorDict before it’s returned.

Typical transforms for DQN on Atari include:
- `ToTensorImage()`: converts images from HWC uint8 to CHW float32 in [0, 1]  
- `GrayScale()`: converts RGB to grayscale (reducing input channels from 3 to 1)  
- `Resize(84, 84)`: resizes frames to the standard 84×84 input  
- `CatFrames(N=4, dim=-3)`: stacks the last 4 frames along the channel dimension to capture motion information  
- `RewardClipping(-1, 1)`: clips rewards to [-1, 1] so that reward scales stay comparable and training is more stable  
- `DoubleToFloat()`: casts float64 entries to float32 for the network  

Below is an example setup for a preprocessed Atari environment in TorchRL:

In [4]:
# Base Gymnasium environment
base_env = GymEnv("ALE/Pong-v5", from_pixels=True, pixels_only=True, render_mode="rgb_array")
n_actions = base_env.action_space.n
obs_shape = (4, 84, 84)

# Apply preprocessing transforms
env = TransformedEnv(
    base_env,
    Compose(
        ToTensorImage(),         # Convert to tensor format
        GrayScale(),             # Convert RGB → grayscale
        Resize(84, 84),          # Resize to 84×84
        CatFrames(N=4, dim=-3),  # Stack 4 frames → (4, 84, 84)
        DoubleToFloat(),         # Ensure float32 precision
        RewardClipping(-1, 1),   # Clip rewards to [-1, 1]
    ),
)

#### TensorDict: the Core Data Container

A TensorDict acts like a dictionary, but it ensures that all tensors it contains share the same batch dimensions.  
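For example, an environment step may produce a TensorDict of the following form (simplified; TorchRL also stores `terminated` and `truncated` flags):

```python
TensorDict({
    'pixels': Tensor(...),          # current observation s_t
    'action': Tensor(...),          # action a_t
    'next': {
        'pixels': Tensor(...),      # next observation s_{t+1}
        'reward': Tensor(...),      # reward r_t
        'done': Tensor(...),        # True if the episode terminated or was truncated
    }
})
```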
This structure allows TorchRL to manage complex rollouts and batch operations with minimal boilerplate.



#### The Native TorchRL Replay Buffer

TorchRL provides a powerful and flexible replay buffer system built around **TensorDicts**.  
The `TensorDictReplayBuffer` class, together with `LazyMemmapStorage`, allows you to efficiently store and sample transitions collected from the environment.

Key advantages include:
- Seamless integration with TensorDict-based environments  
- Automatic device handling (CPU/GPU)  
- Support for both in-memory and disk-backed storage  
- Built-in random or prioritized sampling strategies  
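To make the store-and-sample cycle concrete, here is a minimal NumPy ring buffer with uniform sampling. It is a hypothetical sketch of the idea behind `TensorDictReplayBuffer` with `RandomSampler`, not TorchRL's implementation, and the class name `RingReplayBuffer` is illustrative. Once the buffer is full, each new transition overwrites the oldest one.

```python
import numpy as np

class RingReplayBuffer:
    """Fixed-capacity FIFO buffer with uniform random sampling."""

    def __init__(self, capacity, obs_shape, seed=0):
        self.capacity = capacity
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.float32)
        self.next_obs = np.zeros((capacity, *obs_shape), dtype=np.float32)
        self.action = np.zeros(capacity, dtype=np.int64)
        self.reward = np.zeros(capacity, dtype=np.float32)
        self.done = np.zeros(capacity, dtype=np.float32)
        self.pos = 0    # next slot to write
        self.size = 0   # number of valid entries
        self.rng = np.random.default_rng(seed)

    def add(self, obs, action, reward, next_obs, done):
        i = self.pos
        self.obs[i], self.action[i], self.reward[i] = obs, action, reward
        self.next_obs[i], self.done[i] = next_obs, done
        # Wrap around so that the oldest entry is overwritten when full
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def __len__(self):
        return self.size

    def sample(self, batch_size):
        # Uniform sampling (with replacement) over the filled part of the buffer
        idx = self.rng.integers(0, self.size, size=batch_size)
        return {
            "obs": self.obs[idx], "action": self.action[idx],
            "reward": self.reward[idx], "next_obs": self.next_obs[idx],
            "done": self.done[idx],
        }
```

Sampling uniformly from a large buffer breaks the strong correlation between consecutive frames, which is the main reason DQN trains on replayed transitions rather than on the most recent ones.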

In [9]:
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from tensordict import TensorDict

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(max_size=50_000),  # disk-backed storage (efficient and scalable)
    sampler=RandomSampler(),                     # uniform random sampling
    batch_size=32,                               # default sample batch size
)

In [10]:
from torchrl.envs.utils import step_mdp

# Collect a trajectory with random transitions
td = env.reset()
for _ in range(5000):

    # sample a random (one-hot) action
    a = env.action_spec.rand()
    obs = td.get("pixels")
    td = env.step(td.set("action", a))
    next_obs = td.get(("next", "pixels"))
    r = td.get(("next", "reward"))
    d = td.get(("next", "done"))

    transition = TensorDict(
        {
            "obs": obs,
            "action": a,
            "reward": r,
            "next_obs": next_obs,
            "done": d,
        },
        batch_size=[],
    )
    rb.add(transition)

    # Start a new episode, or move the "next" entries to the root so that
    # td["pixels"] holds s_{t+1} on the following iteration
    td = env.reset() if d.item() else step_mdp(td)

print("Replay buffer size:", len(rb))

Replay buffer size: 5000


🔍 Note: Difference Between `td` and `transition`

| Variable | Role | Structure | Usage |
|-----------|------|------------|--------|
| **`td`** | The live **TensorDict** returned by the environment through `env.reset()` or `env.step()`. After `env.step()`, it holds the current observation at the root and the next-step information nested under `"next"`. | `{ "pixels": ..., "action": ..., "next": { "pixels": ..., "reward": ..., "done": ... } }` | Used for **interacting with the environment**: passed into `env.step()` and updated after each action. |
| **`transition`** | A **flattened TensorDict** created from the fields of `td`, containing exactly one tuple $(s_t, a_t, r_t, s_{t+1}, d_t)$. | `{ "obs": ..., "action": ..., "reward": ..., "next_obs": ..., "done": ... }` | Used for **storing in the replay buffer** and later sampling for training (e.g., in DQN updates). |

**In short:**
- `td` is the environment’s **structured live output** for the current step.  
- `transition` is the **simplified snapshot** of one experience transition that gets pushed into the replay buffer.  

In [15]:
batch = rb.sample(32)

###  Part 3: DQN on Atari

In this section, you will complete the implementation of the **Deep Q-Network (DQN)** algorithm for an Atari environment (e.g., *Pong*).  
The provided code initializes the environment, replay buffer, and Q-networks. Your task is to **connect all the pieces** to form the full DQN learning process.

---

### ⚙️ Background

DQN learns an approximate action-value function $Q_\theta(s, a)$ by minimizing the squared **Bellman error** (the temporal-difference error):

$$
L(\theta) = \mathbb{E}\Big[(Q_\theta(s_t, a_t) - y_t)^2\Big],
$$
where
$$
y_t = r_t + \gamma (1 - d_t) \max_{a'} Q_{\theta^-}(s_{t+1}, a')
$$
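and $Q_{\theta^-}$ is the **target network**, a copy of the online network whose parameters $\theta^-$ are updated only periodically (delayed). $d_t \in \{0, 1\}$ is the done flag, so the target does not bootstrap past the end of an episode, and $\gamma$ is the discount factor.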
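A small numeric example of the target $y_t$ and the loss, using made-up values for a batch of three transitions and two actions. The third transition is terminal ($d_t = 1$), so its target is just the reward.

```python
import numpy as np

gamma = 0.99
r = np.array([0.0, 1.0, -1.0])          # rewards r_t (made-up)
d = np.array([0.0, 0.0, 1.0])           # done flags d_t; the third transition ends the episode
q_next = np.array([[0.2, 0.5],          # Q_target(s_{t+1}, a') for 2 actions (made-up)
                   [0.1, 0.3],
                   [0.9, 0.4]])

# y_t = r_t + gamma * (1 - d_t) * max_a' Q_target(s_{t+1}, a')
y = r + gamma * (1.0 - d) * q_next.max(axis=1)
# y[0] = 0 + 0.99 * 0.5 = 0.495, y[1] = 1 + 0.99 * 0.3 = 1.297, y[2] = -1 (no bootstrap)

q_sa = np.array([0.4, 1.0, -0.5])       # Q_theta(s_t, a_t) from the online network (made-up)
loss = np.mean((q_sa - y) ** 2)         # L(theta) estimated on this batch
```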

---

### 🧩 Provided Components

You are already given:
- ✅ A TorchRL Atari environment (`env`)
- ✅ Replay buffer `rb`
- ✅ Online and target Q-networks (`q`, `q_target`)
- ✅ An optimizer and discount factor (`optimizer`, `gamma`)
- ‚úÖ Pre-written code for:
  - Sampling from `rb`
  - Computing `target` and `loss`
  - Performing one gradient update

You will now write the **training loop** to combine these elements.
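The training loop also has to keep $\theta^-$ in sync with $\theta$. The sketch below shows the two common schemes: a hard copy every $C$ steps, and soft (Polyak) averaging with rate $\tau$ after every step. Plain NumPy dictionaries stand in for `state_dict()`s, and `hard_update`, `soft_update`, and `sync_every` are hypothetical names and values. With the lab's PyTorch networks, the hard copy is `q_target.load_state_dict(q.state_dict())`.

```python
import numpy as np

def hard_update(target, online):
    """theta^- <- theta: copy every parameter (no aliasing)."""
    for k in online:
        target[k] = online[k].copy()

def soft_update(target, online, tau=0.005):
    """Polyak averaging: theta^- <- tau * theta + (1 - tau) * theta^-."""
    for k in online:
        target[k] = tau * online[k] + (1.0 - tau) * target[k]

online = {"w": np.ones(3), "b": np.zeros(1)}    # stand-in for q.state_dict()
target = {"w": np.zeros(3), "b": np.zeros(1)}   # stand-in for q_target.state_dict()
sync_every = 1_000  # hypothetical C, measured in gradient steps

for step in range(1, 3_001):
    # ... one gradient update of `online` would happen here ...
    if step % sync_every == 0:
        hard_update(target, online)
```

A hard copy keeps the target fixed for $C$ steps, which is the behaviour described in the Background; soft updates move it a little every step instead, trading a slightly less stable target for smoother changes.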

---

In [None]:
# Your time to work on it (See below for some hints)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
'''
Hint 1: Neural Network Design
'''
class QNet(nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        # x: (B,4,84,84) float32 in [0,1]
        z = self.conv(x)
        z = z.view(z.size(0), -1)
        return self.fc(z)

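device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

q = QNet(n_actions).to(device)            # online network Q_theta
q_target = QNet(n_actions).to(device)     # target network Q_theta^-
q_target.load_state_dict(q.state_dict())  # start with identical weights
q_target.eval()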

optimizer = optim.Adam(q.parameters(), lr=1e-4)
gamma = 0.99
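
The input size `64 * 7 * 7` of the first linear layer follows from the convolution output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$ applied three times to the 84×84 input. A quick check (`conv_out` is a hypothetical helper, not part of PyTorch):

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output height/width of an nn.Conv2d layer (dilation 1) on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

sizes = [84]
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:   # the three conv layers in QNet
    sizes.append(conv_out(sizes[-1], kernel, stride))
flat_features = 64 * sizes[-1] * sizes[-1]        # 64 output channels in the last conv
# sizes == [84, 20, 9, 7], flat_features == 3136
```

If you change the input resolution or the conv layers, recompute `flat_features` this way, otherwise `nn.Linear` will raise a shape mismatch.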

In [13]:
'''
Hint 2: How to implement gradient descent for Q-learning
'''

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

batch = rb.sample(batch_size)
obs_b      = batch["obs"].to(device)                       # (B,4,84,84)
act_b      = batch["action"].long().to(device)             # (B,n_actions): GymEnv actions are one-hot
rew_b      = batch["reward"].to(device).squeeze(-1)        # (B,1) -> (B,)
next_obs_b = batch["next_obs"].to(device)                  # (B,4,84,84)
done_b     = batch["done"].to(device).float().squeeze(-1)  # (B,1) -> (B,)


with torch.no_grad():
    # target = r + gamma * (1-done) * max_a' Q_target(s',a')
    q_next = q_target(next_obs_b).max(1).values
    target = rew_b + gamma * (1.0 - done_b) * q_next

# Convert one-hot actions to indices, then pick Q_theta(s_t, a_t) for each sample
act_b_ind = act_b.argmax(dim=-1)
q_values = q(obs_b).gather(1, act_b_ind.view(-1, 1)).squeeze(1)
loss = F.smooth_l1_loss(q_values, target)

optimizer.zero_grad(set_to_none=True)
loss.backward()
nn.utils.clip_grad_norm_(q.parameters(), max_norm=10.0)
optimizer.step()
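
The Background writes the loss as a squared error, while Hint 2 uses `F.smooth_l1_loss`, which with PyTorch's default `beta=1.0` is the Huber loss with threshold 1. The NumPy sketch below compares the two on a few TD errors: they agree for small errors, but the Huber gradient is capped at magnitude 1, so a few large, noisy targets cannot produce huge updates. `huber` is an illustrative helper, not a library function.

```python
import numpy as np

def huber(td_error, delta=1.0):
    """Elementwise Huber loss: 0.5*x**2 if |x| <= delta, else delta*(|x| - 0.5*delta)."""
    a = np.abs(td_error)
    return np.where(a <= delta, 0.5 * a ** 2, delta * (a - 0.5 * delta))

td_errors = np.array([0.1, 1.0, 5.0])
huber_vals = huber(td_errors)          # about 0.005, 0.5, 4.5
squared_vals = 0.5 * td_errors ** 2    # about 0.005, 0.5, 12.5

# Derivatives with respect to the TD error
huber_grad = np.clip(td_errors, -1.0, 1.0)   # 0.1, 1.0, 1.0 (capped)
squared_grad = td_errors                     # 0.1, 1.0, 5.0 (grows with the error)
```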

In [14]:
# Hint 3: Epsilon-greedy policy
def select_action(obs, eps: float):
    if torch.rand(1).item() < eps:
        # Random action from the TorchRL spec: a one-hot vector of length n_actions
        return env.action_spec.rand()
    with torch.no_grad():
        x = obs.unsqueeze(0).to(device)            # (1,4,84,84)
        qvals = q(x)                               # (1,n_actions)
        a = torch.argmax(qvals, dim=1).to("cpu")   # (1,) greedy action index, back on CPU
        # Encode as one-hot so both branches match env.action_spec
        return F.one_hot(a.squeeze(0), num_classes=n_actions).to(env.action_spec.dtype)
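
`select_action` takes `eps` as an argument; a common way to choose it is to anneal it linearly from fully random to mostly greedy. The schedule below is a sketch: `linear_epsilon` is a hypothetical helper, and the start/end values and `decay_steps` are assumptions to tune, not values prescribed by this lab.

```python
def linear_epsilon(step: int, eps_start: float = 1.0, eps_end: float = 0.05,
                   decay_steps: int = 100_000) -> float:
    """Anneal epsilon linearly from eps_start to eps_end over decay_steps, then hold it."""
    frac = min(step / decay_steps, 1.0)   # fraction of the decay completed
    return eps_start + frac * (eps_end - eps_start)
```

In the training loop this would be called as `eps = linear_epsilon(global_step)` right before `select_action(obs, eps)`, where `global_step` is a hypothetical counter of environment steps.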