# Pong - Model Architecture

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

cwd = os.getcwd()
src_path = os.path.join(cwd, 'src')

print(f'adding {src_path} to path')

if src_path not in sys.path:
    sys.path.append(src_path)

import gymnasium as gym
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import itertools
import ale_py
from collections import deque
import cv2
import time

from dqn_a import DQN

gym.register_envs(ale_py)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
adding /Users/jollyjerr/code/@school/pong/src to path


In [34]:
def preprocess(obs):
    gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_AREA)
    return resized / 255.0

def epsilon_by_frame(frame_idx, eps_start, eps_end, eps_decay):
    return max(eps_end, eps_start - (eps_start - eps_end) * frame_idx / eps_decay)

def sample_batch(buffer, batch_size, device):
    batch = random.sample(buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    return (
        torch.tensor(np.array(states), dtype=torch.float32).to(device),
        torch.tensor(np.array(actions), dtype=torch.int64).to(device),
        torch.tensor(np.array(rewards), dtype=torch.float32).to(device),
        torch.tensor(np.array(next_states), dtype=torch.float32).to(device),
        torch.tensor(np.array(dones), dtype=torch.float32).to(device)
    )

In [35]:
grid = {
    "learning_rate": [1e-4, 5e-4],
    "fc_units": [256, 512],
    "gamma": [0.95, 0.99]
}

param_combos = list(itertools.product(*grid.values()))

conv_layers = [
    {"out_channels": 32, "kernel_size": 5, "stride": 2},  # (32, 30, 30)
    {"out_channels": 64, "kernel_size": 4, "stride": 2},  # (64, 14, 14)
    {"out_channels": 64, "kernel_size": 3, "stride": 2},  # (64, 6, 6)
]

In [36]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
REPLAY_BUFFER_SIZE = 100_000
BATCH_SIZE = 32
TARGET_UPDATE_FREQ = 1000
EPS_START = 1.0
EPS_END = 0.1
EPS_DECAY = 1e6
MAX_FRAMES = 1000 # 200_000
STACK_SIZE = 4

In [37]:
env = gym.make('ALE/Pong-v5', render_mode=None)

In [40]:
def run_training(hparams):
    learning_rate, fc_units, gamma = hparams
    print(f"\n=== Training with LR={learning_rate}, FC={fc_units}, Gamma={gamma} ===")

    env = gym.make("ALE/Pong-v5", render_mode=None)
    n_actions = env.action_space.n
    obs_shape = (STACK_SIZE, 64, 64)

    policy_net = DQN(obs_shape, n_actions, fc_units=fc_units, conv_layers=conv_layers).to(DEVICE)
    target_net = DQN(obs_shape, n_actions, fc_units=fc_units, conv_layers=conv_layers).to(DEVICE)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
    replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)

    state = preprocess(env.reset()[0])
    state_stack = np.stack([state] * STACK_SIZE, axis=0)
    episode_reward = 0
    episode_rewards = []

    frame_idx = 0

    while frame_idx < MAX_FRAMES:
        epsilon = epsilon_by_frame(frame_idx, EPS_START, EPS_END, EPS_DECAY)

        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                state_tensor = torch.tensor([state_stack], dtype=torch.float32).to(DEVICE)
                q_values = policy_net(state_tensor)
                action = q_values.argmax(dim=1).item()

        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_frame = preprocess(next_obs)
        next_state_stack = np.roll(state_stack, -1, axis=0)
        next_state_stack[-1] = next_frame

        replay_buffer.append((state_stack, action, reward, next_state_stack, done))
        state_stack = next_state_stack
        episode_reward += reward
        frame_idx += 1

        if len(replay_buffer) >= BATCH_SIZE:
            states, actions, rewards, next_states, dones = sample_batch(replay_buffer, BATCH_SIZE, DEVICE)

            q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            with torch.no_grad():
                next_q = target_net(next_states).max(1)[0]
                expected_q = rewards + gamma * next_q * (1 - dones)

            loss = F.smooth_l1_loss(q_values, expected_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if frame_idx % TARGET_UPDATE_FREQ == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if done:
            state = preprocess(env.reset()[0])
            state_stack = np.stack([state] * STACK_SIZE, axis=0)
            episode_rewards.append(episode_reward)
            print(f"Frame {frame_idx}, Reward: {episode_reward:.2f}, Epsilon: {epsilon:.3f}")
            episode_reward = 0

    avg_reward = np.mean(episode_rewards[-10:])
    print(f"[FINISHED] Avg reward over last 10 episodes: {avg_reward:.2f}")
    env.close()
    torch.save(policy_net.state_dict(), "temp_best_model.pth")
    return avg_reward

In [41]:
results = {}
start_time = time.time()

best_reward = float("-inf")
best_hparams = None
best_model_state = None

for hparams in param_combos:
    avg_reward = run_training(hparams)
    results[hparams] = avg_reward

    if avg_reward > best_reward:
        best_reward = avg_reward
        best_hparams = hparams
        best_model_state = torch.load("temp_best_model.pth")

print("\n=== Grid Search Results ===")
for hparams, reward in sorted(results.items(), key=lambda x: -x[1]):
    print(f"LR={hparams[0]:.1e}, FC={hparams[1]}, Gamma={hparams[2]} --> Avg10 Reward: {reward:.2f}")

if best_model_state is not None:
    torch.save(best_model_state, "best_dqn_model.pth")
    print(f"\nBest model saved to 'best_dqn_model.pth' (LR={best_hparams[0]}, FC={best_hparams[1]}, Gamma={best_hparams[2]})")

print(f"\nTotal tuning time: {(time.time() - start_time)/60:.1f} min")


=== Training with LR=0.0001, FC=256, Gamma=0.95 ===
Frame 852, Reward: -21.00, Epsilon: 0.999
[FINISHED] Avg reward over last 10 episodes: -21.00

=== Training with LR=0.0001, FC=256, Gamma=0.99 ===
Frame 882, Reward: -21.00, Epsilon: 0.999
[FINISHED] Avg reward over last 10 episodes: -21.00

=== Training with LR=0.0001, FC=512, Gamma=0.95 ===


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


[FINISHED] Avg reward over last 10 episodes: nan

=== Training with LR=0.0001, FC=512, Gamma=0.99 ===
[FINISHED] Avg reward over last 10 episodes: nan

=== Training with LR=0.0005, FC=256, Gamma=0.95 ===


KeyboardInterrupt: 

In [None]:
results = {}
start_time = time.time()

best_reward = float("-inf")
best_hparams = None
best_model_state = None

for hparams in param_combos:
    avg_reward = run_training(hparams)
    results[hparams] = avg_reward

    if avg_reward > best_reward:
        best_reward = avg_reward
        best_hparams = hparams
        best_model_state = torch.load("temp_best_model.pth")

print("\n=== Grid Search Results ===")
for hparams, reward in sorted(results.items(), key=lambda x: -x[1]):
    print(f"LR={hparams[0]:.1e}, FC={hparams[1]}, Gamma={hparams[2]} --> Avg10 Reward: {reward:.2f}")



=== Training with LR=0.0001, FC=256, Gamma=0.95 ===
Frame 1025, Reward: -19.00, Epsilon: 0.999


  state_tensor = torch.tensor([state_stack], dtype=torch.float32).to(DEVICE)


Frame 1969, Reward: -21.00, Epsilon: 0.998
Frame 2761, Reward: -21.00, Epsilon: 0.998
Frame 3961, Reward: -18.00, Epsilon: 0.996
Frame 4948, Reward: -20.00, Epsilon: 0.996
Frame 5868, Reward: -19.00, Epsilon: 0.995
Frame 6753, Reward: -21.00, Epsilon: 0.994
Frame 7663, Reward: -21.00, Epsilon: 0.993
Frame 8446, Reward: -21.00, Epsilon: 0.992
Frame 9470, Reward: -19.00, Epsilon: 0.991
Frame 10295, Reward: -21.00, Epsilon: 0.991
Frame 11087, Reward: -21.00, Epsilon: 0.990
Frame 12276, Reward: -19.00, Epsilon: 0.989
Frame 13254, Reward: -21.00, Epsilon: 0.988
Frame 14108, Reward: -21.00, Epsilon: 0.987
Frame 14933, Reward: -21.00, Epsilon: 0.987
Frame 15845, Reward: -21.00, Epsilon: 0.986
Frame 16897, Reward: -19.00, Epsilon: 0.985
Frame 17764, Reward: -20.00, Epsilon: 0.984
Frame 18652, Reward: -20.00, Epsilon: 0.983
Frame 19573, Reward: -19.00, Epsilon: 0.982
Frame 20337, Reward: -21.00, Epsilon: 0.982
Frame 21313, Reward: -21.00, Epsilon: 0.981
Frame 22134, Reward: -21.00, Epsilon: 0.9

KeyboardInterrupt: 

In [9]:
print(torch.__version__)
print(np.__version__)

2.6.0+cu124
2.0.2


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
import cv2
import time
from collections import deque
import ale_py

gym.register_envs(ale_py)

In [3]:
class DQN(nn.Module):
     def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()

        self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        conv_w = self._conv_output_size(input_shape[1], 8, 4)
        conv_w = self._conv_output_size(conv_w, 4, 2)
        conv_w = self._conv_output_size(conv_w, 3, 1)

        conv_h = self._conv_output_size(input_shape[2], 8, 4)
        conv_h = self._conv_output_size(conv_h, 4, 2)
        conv_h = self._conv_output_size(conv_h, 3, 1)

        linear_input_size = conv_w * conv_h * 64

        self.fc1 = nn.Linear(linear_input_size, 512)
        self.fc2 = nn.Linear(512, n_actions)

        self._initialize_weights()

    def _conv_output_size(self, size, kernel_size, stride):
        return (size - kernel_size) // stride + 1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [4]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

In [5]:
def preprocess_frame(frame):
    """Preprocess frame: grayscale, resize, normalize"""
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    normalized = resized.astype(np.float32) / 255.0
    return normalized

def stack_frames(stacked_frames, frame, is_new_episode=False):
    frame = preprocess_frame(frame)

    if is_new_episode:
        stacked_frames = deque([frame for _ in range(4)], maxlen=4)
    else:
        stacked_frames.append(frame)

    stacked_state = np.stack(stacked_frames, axis=0)
    return stacked_state, stacked_frames

def epsilon_greedy_policy(q_values, epsilon):
    if random.random() < epsilon:
        return random.randint(0, len(q_values) - 1)
    else:
        return np.argmax(q_values)

def compute_td_loss(policy_net, target_net, replay_buffer, batch_size, gamma, device):
    if len(replay_buffer) < batch_size:
        return None

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)

    state = torch.FloatTensor(state).to(device)
    action = torch.LongTensor(action).to(device)
    reward = torch.FloatTensor(reward).to(device)
    next_state = torch.FloatTensor(next_state).to(device)
    done = torch.BoolTensor(done).to(device)

    q_values = policy_net(state)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        next_q_values = target_net(next_state)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + (gamma * next_q_value * ~done)

    loss = F.smooth_l1_loss(q_value, expected_q_value)
    return loss

In [6]:
def train_dqn():
    # Hyperparameters
    LEARNING_RATE = 1e-4
    GAMMA = 0.99
    EPSILON_START = 1.0
    EPSILON_END = 0.01
    EPSILON_DECAY = 1000000
    BATCH_SIZE = 16          # Reduced from 32
    TARGET_UPDATE = 10000
    MEMORY_SIZE = 25000      # Reduced from 100000
    MIN_MEMORY = 10000       # Reduced from 50000
    MAX_EPISODES = 10000

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = gym.make('ALE/Pong-v5', render_mode=None)
    n_actions = env.action_space.n

    policy_net = DQN((4, 84, 84), n_actions).to(device)
    target_net = DQN((4, 84, 84), n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
    replay_buffer = ReplayBuffer(MEMORY_SIZE)

    episode_rewards = []
    losses = []
    frame_count = 0
    best_reward = float('-inf')

    print("Starting training...")
    start_time = time.time()

    for episode in range(MAX_EPISODES):
        obs, _ = env.reset()
        stacked_frames = None
        state, stacked_frames = stack_frames(stacked_frames, obs, is_new_episode=True)

        episode_reward = 0
        episode_loss = []

        while True:
            frame_count += 1

            epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                     np.exp(-1.0 * frame_count / EPSILON_DECAY)

            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                q_values = policy_net(state_tensor).cpu().numpy()[0]

            action = epsilon_greedy_policy(q_values, epsilon)

            next_obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            reward = np.clip(reward, -1, 1)

            next_state, stacked_frames = stack_frames(stacked_frames, next_obs, is_new_episode=False)

            replay_buffer.push(state, action, reward, next_state, done)

            state = next_state
            episode_reward += reward

            if len(replay_buffer) >= MIN_MEMORY and frame_count % 4 == 0:
                loss = compute_td_loss(policy_net, target_net, replay_buffer,
                                     BATCH_SIZE, GAMMA, device)
                if loss is not None:
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10)
                    optimizer.step()

                    losses.append(loss.item())
                    episode_loss.append(loss.item())

            if frame_count % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())
                print(f"Updated target network at frame {frame_count}")

            if done:
                break

        episode_rewards.append(episode_reward)

        if episode % 100 == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            avg_loss = np.mean(losses[-1000:]) if len(losses) >= 1000 else np.mean(losses) if losses else 0

            elapsed_time = time.time() - start_time
            print(f"Episode {episode}")
            print(f"  Average Reward (last 100): {avg_reward:.2f}")
            print(f"  Current Epsilon: {epsilon:.3f}")
            print(f"  Average Loss: {avg_loss:.4f}")
            print(f"  Frame Count: {frame_count}")
            print(f"  Elapsed Time: {elapsed_time/60:.1f} min")
            print(f"  Replay Buffer Size: {len(replay_buffer)}")
            print("-" * 50)

            if avg_reward > best_reward:
                best_reward = avg_reward
                torch.save({
                    'policy_net_state_dict': policy_net.state_dict(),
                    'target_net_state_dict': target_net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'episode': episode,
                    'best_reward': best_reward,
                    'frame_count': frame_count
                }, 'best_dqn_pong.pth')
                print(f"New best model saved! Average reward: {best_reward:.2f}")

        if len(episode_rewards) >= 100 and np.mean(episode_rewards[-100:]) >= 18.0:
            print(f"Solved! Average reward over last 100 episodes: {np.mean(episode_rewards[-100:]):.2f}")
            break

    env.close()
    print("Training completed!")

    return episode_rewards

In [8]:
print("Starting DQN training for Pong...")
episode_rewards = train_dqn()

Starting DQN training for Pong...
Using device: cuda
Starting training...
Episode 0
  Average Reward (last 100): -20.00
  Current Epsilon: 0.999
  Average Loss: 0.0000
  Frame Count: 966
  Elapsed Time: 0.1 min
  Replay Buffer Size: 966
--------------------------------------------------
New best model saved! Average reward: -20.00
Updated target network at frame 10000
Updated target network at frame 20000
Updated target network at frame 30000
Updated target network at frame 40000
Updated target network at frame 50000
Updated target network at frame 60000
Updated target network at frame 70000
Updated target network at frame 80000
Updated target network at frame 90000
Episode 100
  Average Reward (last 100): -20.33
  Current Epsilon: 0.912
  Average Loss: 0.0032
  Frame Count: 93358
  Elapsed Time: 4.8 min
  Replay Buffer Size: 25000
--------------------------------------------------
Updated target network at frame 100000
Updated target network at frame 110000
Updated target network at f

KeyboardInterrupt: 