# Deep Q-Network on Atari games using Custom CNN

This notebook explores how to train a Deep Q-Network (DQN) from Stable Baselines3 with a custom CNN feature extractor on an Atari environment (Pong).

### What you will learn
* Using a Deep Q-Network (DQN) from Stable Baselines3 with a custom CNN feature extractor
* Training on the Pong environment

## Setup
Install necessary dependencies.

In [None]:
# Installing the necessary packages for Atari environments
!pip install gymnasium
!pip install gymnasium[atari]
!pip install ale-py

In [None]:
# Installing necessary packages for visualization and virtual display
!sudo apt-get update
!sudo apt-get install -y cmake
# Updating the package list and installing ffmpeg and freeglut3-dev for visualization, xvfb for virtual display
!sudo apt-get install -y ffmpeg freeglut3-dev xvfb

In [None]:
# Installing Stable Baselines3 with extra dependencies
!pip install "stable-baselines3[extra]"
!pip install moviepy

In [None]:
import os

# Level '3' silences TensorFlow's C++ INFO, WARNING and ERROR log messages.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

Import libraries

In [None]:
# Import necessary libraries and modules
import gymnasium as gym
import stable_baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torch
from pathlib import Path
import base64
from IPython import display as ipythondisplay

# Import utility functions for creating Atari environments and stacking frames
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# Report library versions so results can be tied to a specific install.
print("gym.__version__=" + repr(gym.__version__))
print("stable_baselines3.__version__=" + repr(stable_baselines3.__version__))

Configure Environment

In [None]:
# Rendering needs an X display; on a headless (cloud) machine we start a virtual
# X server on display :1 in the background and point DISPLAY at it.
xvfb_command = "Xvfb :1 -screen 0 1024x768x24 &"
os.system(xvfb_command)
os.environ['DISPLAY'] = ':1'

## Callbacks and directory setup

In [None]:
# Define a callback class for saving models at regular intervals during training
class SaveOnIntervalCallback(BaseCallback):
    """Periodically save the model that is being trained.

    A checkpoint is written to ``<save_path>/model_<num_timesteps>`` (SB3 adds
    the ``.zip`` suffix). This happens each time the timestep counter reaches
    or passes the next multiple of ``save_interval``.

    :param save_interval: number of timesteps between checkpoints; must be > 0.
    :param save_path: directory the checkpoints are written to; created if missing.
    :param verbose: if > 0, print the path of every checkpoint that is written.
    :raises ValueError: if ``save_interval`` is not positive.
    """

    def __init__(self, save_interval: int, save_path: str, verbose=1):
        super().__init__(verbose)
        if save_interval <= 0:
            raise ValueError(f"save_interval must be positive, got {save_interval}")
        self.save_interval = save_interval
        self.save_path = save_path
        # Timestep threshold at which the next checkpoint is due.
        self._next_save_at = save_interval

    def _init_callback(self) -> None:
        # Make sure the target directory exists before the first save.
        os.makedirs(self.save_path, exist_ok=True)
        # Align the first checkpoint to the next multiple of save_interval after
        # the model's current step count. That count is non-zero when training is resumed.
        self._next_save_at = (self.model.num_timesteps // self.save_interval + 1) * self.save_interval

    def _on_step(self) -> bool:
        # On a vectorized env, num_timesteps grows by n_envs per call. An exact
        # `num_timesteps % save_interval == 0` test would therefore never fire
        # when save_interval is not a multiple of n_envs, so use a threshold.
        if self.num_timesteps >= self._next_save_at:
            save_file = os.path.join(self.save_path, f'model_{self.num_timesteps}')
            self.model.save(save_file)
            if self.verbose > 0:
                print(f'Saving model to {save_file}.zip')
            self._next_save_at = (self.num_timesteps // self.save_interval + 1) * self.save_interval
        return True


In [None]:
# Output locations: Monitor training logs and saved model checkpoints.
log_dir = "pong/logs"
models_dir = "pong/models/"

# Create both up front; exist_ok makes re-running this cell harmless.
for directory in (log_dir, models_dir):
    os.makedirs(directory, exist_ok=True)


## Define Custom Feature Extractor and Create DQN Model

In [None]:
# Define a custom CNN feature extractor for processing observations from the environment
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for image observations.

    The network has three Conv2d + ReLU stages (32@8x8 stride 4, 64@4x4 stride 2,
    64@3x3 stride 1). These are followed by a flatten and a Linear + ReLU
    projection to ``features_dim`` features.

    :param observation_space: image observation space. Its first dimension is
        used as the number of input channels, so a channel-first layout is
        assumed (NOTE(review): SB3 transposes channel-last images for
        CnnPolicy -- confirm if used elsewhere).
    :param features_dim: size of the output feature vector.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Infer the flattened conv output width by running one sampled observation
        # (with a leading batch axis) through the conv stack.
        probe = torch.as_tensor(observation_space.sample()[None]).float()
        with torch.no_grad():
            n_flatten = self.cnn(probe).shape[1]

        # Project the flattened conv features down to features_dim.
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to feature vectors of size features_dim."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


#### Initialize the Atari Environment for 'Pong'

In [None]:
import ale_py  # imported so the ALE Atari environments are registered with gymnasium

env_id = "PongNoFrameskip-v4"
# Four parallel copies of the game, seeded from 0, with Monitor logs written to `log_dir`.
env = make_atari_env(env_id, monitor_dir=log_dir, seed=0, n_envs=4)
# Stack the 4 most recent frames so each observation carries motion information.
env = VecFrameStack(env, n_stack=4)

#### Initialize the DQN model using the custom CNN feature extractor

In [None]:
# Initialize the DQN agent, using the custom CNN as its feature extractor.
model = DQN(
    env=env,
    policy='CnnPolicy',
    verbose=1,
    learning_rate=0.0001,
    buffer_size=100000,
    # Timesteps collected before gradient updates begin. This must be small
    # relative to the run: with exploration_fraction=0.1 of the 10M-step
    # training run, epsilon has already reached its final value after 1M steps.
    # A 1M-step warm-up (10x the replay buffer) would use up all exploration
    # before any learning happens.
    learning_starts=100000,
    gradient_steps=1,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    train_freq=4,
    batch_size=32,
    # Plug the CustomCNN feature extractor into SB3's CnnPolicy.
    policy_kwargs={'features_extractor_class': CustomCNN},
)

## Train the Model

In [None]:
# Training length, and how often (in timesteps) a checkpoint is saved.
total_timesteps = 10_000_000
save_interval = 100_000
save_callback = SaveOnIntervalCallback(save_interval, models_dir)

# Train the DQN agent
model.learn(total_timesteps=total_timesteps, callback=save_callback)

# Save the final model. It reuses the same constant that was passed to learn(),
# so the file name matches the configured training length.
final_model_path = os.path.join(models_dir, f'model_{total_timesteps}')
model.save(final_model_path)

## Performance evaluation

In [None]:
def plot_results(log_folder: str, window: int = 100):
    """
    Plots the training curve from the Monitor log file(s).

    :param log_folder: the save directory of the Monitor logs
    :param window: width of the moving-average window used for smoothing. It
        is shrunk to the number of logged episodes when fewer are available.
    """
    x, y = ts2xy(load_results(log_folder), 'timesteps')
    if len(y) == 0:
        print(f"No completed episodes found in {log_folder!r}; nothing to plot.")
        return

    # Moving average of episode rewards. With mode='valid', a kernel longer than
    # the data would produce an output of the wrong length (and a mismatch with
    # x below), so clamp the window to the data size.
    window = max(1, min(window, len(y)))
    y_smooth = np.convolve(y, np.ones(window) / window, mode='valid')
    # Pair each smoothed value with the timestep of the last episode in its window.
    x_smooth = x[len(x) - len(y_smooth):]

    fig, ax = plt.subplots()
    ax.plot(x_smooth, y_smooth, label="Smoothed Reward")
    ax.plot(x, y, alpha=0.2, label="Raw Reward")
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Rewards")
    ax.set_title("Training Curve")
    ax.legend()
    plt.show()

# Call the function after training to see the learning curve
plot_results(log_dir)

## Video Recording and Display Functions

In [None]:
# Functions to record videos of the agent playing and display the videos

def show_videos(video_path="", prefix=""):
    """Embed every ``<prefix>*.mp4`` file found in ``video_path`` as an inline HTML5 video.

    The file contents are base64-encoded into data URIs, and the players are
    rendered through IPython's rich display, separated by ``<br>``.
    """
    snippets = []
    for mp4_file in Path(video_path).glob(f"{prefix}*.mp4"):
        encoded = base64.b64encode(mp4_file.read_bytes()).decode('ascii')
        snippets.append(
            '''<video alt="{0}" autoplay
                      loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{1}" type="video/mp4" />
            </video>'''.format(mp4_file, encoded)
        )
    page = ipythondisplay.HTML(data="<br>".join(snippets))
    ipythondisplay.display(page)


def record_video(env_id, model, video_length=500, prefix="", video_folder="pong/videos", deterministic=False):
    """Record one clip of ``model`` playing ``env_id`` and write it to ``video_folder``.

    :param env_id: Gymnasium id of the Atari environment to play.
    :param model: trained SB3 model; actions come from ``model.predict``.
    :param video_length: number of environment steps to play and record.
    :param prefix: filename prefix for the recorded video.
    :param video_folder: directory the video file is written to.
    :param deterministic: forwarded to ``model.predict``. The default ``False``
        matches ``predict``'s own default.
    """
    # A single Atari env, wrapped with the same 4-frame stacking used in training.
    eval_env = make_atari_env(env_id, n_envs=1, seed=0, vec_env_cls=DummyVecEnv)
    eval_env = VecFrameStack(eval_env, n_stack=4)
    eval_env = VecVideoRecorder(
        eval_env,
        video_folder=video_folder,
        record_video_trigger=lambda step: step == 0,  # one clip, starting at step 0
        video_length=video_length,
        name_prefix=prefix,
    )
    try:
        obs = eval_env.reset()
        for _ in range(video_length):
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _, _, _ = eval_env.step(action)
    finally:
        # Always close, so that the recorder finalizes the video file and the env
        # is released even if prediction or stepping raises.
        eval_env.close()

## Record and Display Videos at Different Training Stages

In [None]:
# Functions for analyzing the trained models

# Function to get the identifiers of saved models
def get_model_identifiers(models_dir):
    """Return the identifiers of the checkpoint files in ``models_dir``.

    A file named ``model_<id>`` yields ``<id>`` with any file extension kept
    (e.g. ``'100000.zip'`` for ``model_100000.zip``). As a result,
    ``os.path.join(models_dir, f'model_{id}')`` rebuilds the file's path.
    Sub-directories are ignored. Results are in ``os.listdir`` order.
    """
    prefix = 'model_'
    return [
        # Slice off only the prefix, so ids that contain '_' are kept whole.
        name[len(prefix):]
        for name in os.listdir(models_dir)
        if name.startswith(prefix) and os.path.isfile(os.path.join(models_dir, name))
    ]

# Function to find key identifiers (earliest, middle, final)
def find_key_identifiers(identifiers):
    """Return the (earliest, middle, final) checkpoint identifiers.

    Identifiers are ordered by the training step at their start (``'100000.zip'``
    -> 100000), not as strings. Compared as strings, ``'900000.zip'`` would sort
    after ``'10000000.zip'``. Identifiers without a numeric step sort after all
    numeric ones, by name. The input list is left unmodified.

    :param identifiers: identifiers as produced by ``get_model_identifiers``.
    :returns: tuple ``(earliest, middle, final)``; the middle is at index ``len // 2``.
    :raises ValueError: if ``identifiers`` is empty.
    """
    if not identifiers:
        raise ValueError("No model checkpoints found to choose from.")

    def _training_step_key(identifier):
        # '100000.zip' -> (0, 100000, ...); non-numeric names go last, by name.
        stem = identifier.split('.', 1)[0]
        if stem.isdecimal():
            return (0, int(stem), identifier)
        return (1, 0, identifier)

    ordered = sorted(identifiers, key=_training_step_key)
    earliest = ordered[0]
    middle = ordered[len(ordered) // 2]
    final = ordered[-1]
    return earliest, middle, final

# Function to view videos of the models at different training stages
def view(models_dir, env_id="PongNoFrameskip-v4", video_length=5000, video_folder="pong/videos/"):
    """Record and display videos of the earliest, middle and final checkpoints.

    :param models_dir: directory containing ``model_<id>`` checkpoint files.
    :param env_id: Gymnasium id of the environment the agents are played in.
    :param video_length: number of environment steps recorded per video.
    :param video_folder: directory the videos are written to and displayed from.
        A single parameter keeps the recording and display locations in sync.
    """
    identifiers = get_model_identifiers(models_dir)
    print(identifiers)
    earliest, middle, final = find_key_identifiers(identifiers)

    # Record and display one video each for the beginning, middle and end of training.
    for stage, identifier in zip(["beginning", "middle", "end"], [earliest, middle, final]):
        model_path = os.path.join(models_dir, f'model_{identifier}')
        model = DQN.load(model_path)
        prefix = f'dqn-pong-{stage}'
        record_video(env_id, model, video_length=video_length, prefix=prefix, video_folder=video_folder)
        show_videos(video_folder, prefix=prefix)

In [None]:
# Directory holding the checkpoints to review; change it to inspect a different run.
models_dir = "pong/models"
view(models_dir=models_dir)

In [None]:
# show_videos(video_path="", prefix="dqn-pong-")