<a href="https://colab.research.google.com/github/ntmdrgl/CSCI-166-DQN-Atari-Project/blob/main/CSCI_166_Play_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://drive.google.com/file/d/1i7p5jXJsxDiqdT2e2M8z1UijP33lUZbk/view?usp=drive_link
model_name = "ALE_Pong-v5-best_1-20250612-2042-test_epsdec150000_rs10000_sync1000.dat"
!gdown 1i7p5jXJsxDiqdT2e2M8z1UijP33lUZbk

Downloading...
From: https://drive.google.com/uc?id=1i7p5jXJsxDiqdT2e2M8z1UijP33lUZbk
To: /content/ALE_Pong-v5-best_-11-20251130-2340-test_epsdec10000_rs1000_sync500.dat
  0% 0.00/6.75M [00:00<?, ?B/s] 70% 4.72M/6.75M [00:00<00:00, 35.0MB/s]100% 6.75M/6.75M [00:00<00:00, 37.1MB/s]


In [None]:
!pip install gymnasium[atari,accept-rom-license]
!pip install autorom
!pip install stable-baselines3



In [None]:
# Download the Atari 2600 ROMs with AutoROM, accepting the ROM license via the CLI flag.
!AutoROM --accept-license

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
import ale_py
import gymnasium as gym

# Smoke test: build the Pong environment once and print the shape of the first raw frame.
env = gym.make("ALE/Pong-v5", render_mode="rgb_array")
try:
    obs, _ = env.reset()
    print(obs.shape)
finally:
    # Always release the environment, even if reset() raises.
    env.close()

(210, 160, 3)


In [None]:
from dataclasses import dataclass
import argparse
import time
import numpy as np
import collections
import typing as tt

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.tensorboard.writer import SummaryWriter

In [None]:
#dqn_model
class DQN(nn.Module):
    """
    Convolutional Q-network.

    Three Conv2d + ReLU stages followed by a flatten (``self.conv``), then two
    linear layers (``self.fc``) that emit one value per action.

    Parameters:
    - input_shape: shape of a single observation without the batch axis;
      ``input_shape[0]`` is used as the number of input channels
    - n_actions: number of output values
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        conv_layers = [
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Push one all-zero observation through the conv stack to discover
        # how many features the first linear layer receives.
        probe = torch.zeros(1, *input_shape)
        flat_features = self.conv(probe).shape[-1]

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x: torch.ByteTensor):
        """Cast the input to float, divide by 255, and return the head's output."""
        scaled = x.float() / 255.0
        features = self.conv(scaled)
        return self.fc(features)

In [None]:
#wrappers

from gymnasium import spaces
from stable_baselines3.common import atari_wrappers
from stable_baselines3.common.atari_wrappers import AtariWrapper


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Observation wrapper that turns channel-last images (H, W, C) into the
    channel-first layout (C, H, W) expected by PyTorch convolutional layers.
    """
    def __init__(self, env):
        super().__init__(env)
        src_space = self.observation_space
        assert isinstance(src_space, gym.spaces.Box)
        assert len(src_space.shape) == 3
        height, width, channels = src_space.shape
        # The new space uses scalar bounds: the smallest low and largest high of the source space.
        self.observation_space = gym.spaces.Box(
            low=src_space.low.min(),
            high=src_space.high.max(),
            shape=(channels, height, width),
            dtype=src_space.dtype,
        )

    def observation(self, observation):
        # Axis order (2, 0, 1) brings the channel axis to the front.
        return np.transpose(observation, (2, 0, 1))


class BufferWrapper(gym.ObservationWrapper):
    """
    BufferWrapper: Maintains a rolling window of the last `n_steps` frames
    to give the agent a sense of temporal context.

    Each observation returned is the concatenation (along axis 0) of the
    frames currently held in the window.

    Parameters:
    - env: environment whose observation space is a Box
    - n_steps: number of most recent frames kept in the window
    """
    def __init__(self, env, n_steps):
        super(BufferWrapper, self).__init__(env)
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        # Bounds are repeated along axis 0 so the declared space has n_steps times
        # as many entries on that axis as the wrapped space.
        new_obs = gym.spaces.Box(
            obs.low.repeat(n_steps, axis=0), obs.high.repeat(n_steps, axis=0),
            dtype=obs.dtype)
        self.observation_space = new_obs
        self.buffer = collections.deque(maxlen=n_steps)

    def reset(self, *, seed: tt.Optional[int] = None, options: tt.Optional[dict[str, tt.Any]] = None):
        """
        Reset the wrapped environment and restart the frame window.

        The window is filled with all-zero frames first, so the first stacked
        observation holds zeros followed by the initial frame. `seed` and
        `options` are forwarded to the wrapped environment so that seeded
        resets actually take effect.

        Returns the stacked observation and the info value from the wrapped reset.
        """
        for _ in range(self.buffer.maxlen):
            self.buffer.append(np.zeros_like(self.env.observation_space.low))
        obs, extra = self.env.reset(seed=seed, options=options)
        return self.observation(obs), extra

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Append the newest frame (evicting the oldest once full) and return the stacked window."""
        self.buffer.append(observation)
        return np.concatenate(self.buffer)

def make_env(env_name: str, n_steps=4, render_mode=None, **kwargs):
    """
    Build the wrapped environment used with the DQN model.

    Wrapping order: gym.make -> AtariWrapper (reward clipping off, noop_max=0)
    -> ImageToPyTorch -> BufferWrapper.

    Parameters:
    - env_name: environment id passed to gym.make
    - n_steps: number of frames stacked by BufferWrapper
    - render_mode: forwarded to gym.make
    - kwargs: any extra keyword arguments forwarded to gym.make
    """
    print(f"Creating environment {env_name}")
    base_env = gym.make(env_name, render_mode=render_mode, **kwargs)
    atari_env = atari_wrappers.AtariWrapper(base_env, clip_reward=False, noop_max=0)
    channel_first_env = ImageToPyTorch(atari_env)
    return BufferWrapper(channel_first_env, n_steps=n_steps)

def make_eval_env(env_name, n_steps=4, render_mode="rgb_array"):
    """
    Build an evaluation environment with the same wrapper stack as make_env
    (AtariWrapper -> ImageToPyTorch -> BufferWrapper), without the log message
    and with render_mode defaulting to "rgb_array".
    """
    wrapped = AtariWrapper(
        gym.make(env_name, render_mode=render_mode),
        clip_reward=False,
        noop_max=0,
    )
    wrapped = ImageToPyTorch(wrapped)
    return BufferWrapper(wrapped, n_steps)


In [None]:
from pathlib import Path
from gymnasium.wrappers import RecordVideo

def make_eval_env_with_recording(env_name, model_name, video_dir="video", n_steps=4, render_mode="rgb_array"):
    """
    Build the environment via make_env and wrap it in RecordVideo.

    Parameters:
    - env_name: environment id passed to make_env
    - model_name: model filename; its stem (name without the final extension)
      becomes the prefix of the recorded video files
    - video_dir: folder handed to RecordVideo as video_folder
    - n_steps, render_mode: forwarded to make_env
    """
    # Path.stem strips the final extension, e.g. ".dat".
    video_prefix = Path(model_name).stem

    def record_every_episode(episode_id):
        # The trigger is unconditionally true, so every episode is recorded.
        return True

    base_env = make_env(env_name, n_steps=n_steps, render_mode=render_mode)
    return RecordVideo(
        base_env,
        video_folder=video_dir,
        name_prefix=video_prefix,
        episode_trigger=record_every_episode,
    )


# Run Model

In [None]:
# Base Configuration
# Gymnasium environment id; the model-loading cell assigns it to `env_name`.
DEFAULT_ENV_NAME = "ALE/Pong-v5"

In [None]:
## Load the Model
env_name = DEFAULT_ENV_NAME
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# env = make_env(env_name, render_mode="rgb_array")
env = make_eval_env_with_recording(env_name, model_name)
obs, _ = env.reset()
# Keep the network on `device`: the evaluation loop sends observations to `device`,
# so a network left on the CPU fails with a device mismatch on a GPU runtime.
net = DQN(env.observation_space.shape, env.action_space.n).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which is the
# safe way to read a checkpoint that was downloaded from the internet.
# This call stays the cell's final expression so the key-matching report is displayed.
net.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))



Creating environment ALE/Pong-v5


  logger.warn(


<All keys matched successfully>

In [None]:
from gymnasium.wrappers import RecordVideo
from IPython.display import HTML
import base64
from pathlib import Path

# Play one greedy episode with the loaded network and record it.
# `env` normally arrives already wrapped in RecordVideo; wrap it here only when it is not,
# so the same episode is not encoded to disk twice.
video_dir = "video"
if not isinstance(env, RecordVideo):
    env = RecordVideo(env, video_folder=video_dir, name_prefix="dqn_pong_eval", episode_trigger=lambda x: True)

# Run inference on whichever device the network's weights live on, so the
# observation tensor and the model can never end up on different devices.
net_device = next(net.parameters()).device

total_reward = 0.0
try:
    obs, _ = env.reset()
    while True:
        with torch.no_grad():
            # Add a batch axis, then pick the action with the highest predicted value.
            state_v = torch.as_tensor(obs).unsqueeze(0).to(net_device)
            q_vals = net(state_v)
            action = int(torch.argmax(q_vals, dim=1).item())

        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        if terminated or truncated:
            break
finally:
    # Close the environment even if the episode is interrupted or fails.
    env.close()
print(f"✅ Episode finished. Total reward: {total_reward:.2f}")

✅ Episode finished. Total reward: -8.00


In [None]:
from IPython.display import HTML
from pathlib import Path
import base64

def show_video_by_episode(episode_number: int, model_name: str, video_dir="video"):
    """
    Displays a specific episode's video recorded using RecordVideo, using the model filename as the prefix.

    Parameters:
    - episode_number: The episode number (0, 1, ...)
    - model_name: The full .dat model filename used as name_prefix in RecordVideo
    - video_dir: Folder where video files are stored

    Returns:
    - An IPython HTML object embedding the video as base64 data, or None
      (after printing a message) when the expected file does not exist.
    """
    prefix = Path(model_name).stem  # Remove .dat extension
    filename = f"{prefix}-episode-{episode_number}.mp4"
    video_path = Path(video_dir) / filename

    if not video_path.exists():
        print(f"❌ Video not found: {video_path}")
        return

    print(f"🎥 Showing episode {episode_number}: {video_path.name}")

    # Inline the file as a base64 data URI so the player needs no separate file URL.
    encoded = base64.b64encode(video_path.read_bytes()).decode("ascii")

    video_html = f"""
    <video width="640" height="480" controls style="max-width: 100%;" controlslist="nodownload" allowfullscreen>
        <source src="data:video/mp4;base64,{encoded}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    <p>💡 Right-click and choose <strong>“Open video in new tab”</strong> to use fullscreen mode.</p>
    """

    return HTML(video_html)

In [None]:
# Episode index 1 matches the recording named in this cell's saved output ("...-episode-1.mp4").
# NOTE(review): presumably index 0 belongs to the extra env.reset() in the model-loading cell;
# confirm against the RecordVideo version in use.
show_video_by_episode(1, model_name)

🎥 Showing episode 1: ALE_Pong-v5-best_1-20250612-2042-test_epsdec150000_rs10000_sync1000-episode-1.mp4
