In [4]:
# Update and install display packages and stable baseline 3
!apt-get update && apt-get install swig cmake -y
!apt-get update && apt-get install ffmpeg freeglut3-dev xvfb -y
!pip install box2d-py
!pip install moviepy
!pip install "stable-baselines3[extra]>=2.0.0a4"
!pip install gymnasium


Successfully installed AutoROM.accept-rom-license-0.6.1 absl-py-2.1.0 ale-py-0.8.1 autorom-0.6.1 farama-notifications-0.0.4 grpcio-1.67.0 gymnasium-0.29.1 importlib-resources-6.4.5 markdown-3.7 pygame-2.6.1 shimmy-1.3.0 stable-baselines3-2.4.0a7 tensorboard-2.18.0 tensorboard-data-server-0.7.2
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Import Libraries

In [5]:
# Import rquired libraries and modules
import os
import gymnasium as gym
import stable_baselines3
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import torch
from pathlib import Path
import base64
from IPython import display as ipythondisplay
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack


In [6]:
# Check gyma and stable baseline 3 versions
print(f"{gym.__version__=}")
print(f"{stable_baselines3.__version__=}")


gym.__version__='0.29.1'
stable_baselines3.__version__='2.4.0a7'


## Configure Environment

In [7]:
# Display holder for video
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ['DISPLAY'] = ':1'


# xvfb_command = "Xvfb :1 -screen 0 1024x768x24"
# xvfb_process = subprocess.Popen(xvfb_command.split())

# os.environ['DISPLAY'] = ':1'

# xvfb_process.terminate()  # Gracefully terminate


The XKEYBOARD keymap compiler (xkbcomp) reports:
> Internal error:   Could not resolve keysym XF86AudioPreset
> Internal error:   Could not resolve keysym XF86MonBrightnessCycle
> Internal error:   Could not resolve keysym XF86WWAN
> Internal error:   Could not resolve keysym XF86RFKill
> Internal error:   Could not resolve keysym XF86Keyboard
> Internal error:   Could not resolve keysym XF86RotationLockToggle
> Internal error:   Could not resolve keysym XF86FullScreen
Errors from xkbcomp are not fatal to the X server


## Callbacks and Directory Setup

In [8]:
# class to manage display
class Display:
    def __init__(self, command: str):
        self.command = command
        
    def start(self):
        self.process = subprocess.Popen(self.command.split())  
        os.environ['DISPLAY'] = ':1'
        
    def terminate(self):
        self.process.terminate()
        
# callback for saving model at regular intervals
class SaveOnIntervalCallback(BaseCallback):
    def __init__(self, save_interval: int, save_path: str, verbose=1):
        super().__init__(verbose)
        self.save_interval = save_interval
        self.save_path = save_path

    def _on_step(self) -> bool:
        # Save the model every 'save_interval' steps
        if self.num_timesteps % self.save_interval == 0:
            save_file = os.path.join(self.save_path, f'model_{self.num_timesteps}')
            self.model.save(save_file)
            if self.verbose > 0:
                print(f'Saving model to {save_file}.zip')
        return True


## Define Custom Feature Extractor and Create DQN Model

In [9]:
# Feature extraction from frames as observations / states
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        # convolutional layers
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_space.shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # flatten output tensor
        with torch.no_grad():
            n_flatten = self.cnn(
                torch.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        # fully conntected output layer
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Process the observations through the CNN and linear layers
        return self.linear(self.cnn(observations))


In [10]:
# file saving locations
log_dir = "./Breakout/DQN/logs/"
models_dir = "./Breakout/DQN/models/"
os.makedirs(log_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)


#### Initial Atari Environment for 'Breakout'

In [11]:
# Initialize the Atari environment with the specified game and configurations
env = make_atari_env("ALE/Breakout-v5", n_envs=4, seed=0)
# Stack 4 consecutive frames together to provide temporal information
env = VecFrameStack(env, n_stack=5) # can fine tunenumber of frame images to learn from. Always more than 1. 4 frames in each state

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


#### Initial the DQN model using custon CNN feature extractor'

In [12]:
# Initialize the DQN agent with specified parameters
model = DQN(
    env=env,
    policy='CnnPolicy',
    verbose=1,
    learning_rate=0.0001,
    buffer_size=10000,          # how large is the buffer for the replay storage
    learning_starts=100000,     # starts at timestampe 100000, collecting data previously without learning
    gradient_steps=1,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    train_freq=4,
    batch_size=32,
    policy_kwargs={'features_extractor_class': CustomCNN}
)


Using cuda device
Wrapping the env in a VecTransposeImage.


## Train the Model

In [None]:
# Define the interval at which models are saved during training
save_interval = 10000
save_callback = SaveOnIntervalCallback(save_interval, models_dir)

# Train the DQN agent
model.learn(total_timesteps=100000, callback=save_callback)

# Save the final model after training completes
final_model_path = os.path.join(models_dir, f'model_final')
model.save(final_model_path)

## Video Recording and Display Functions

In [14]:
# Functions to record videos of the agent playing and display the videos

def show_videos(video_path="", prefix=""):
    """Displays videos from a specified directory."""
    html = []
    for mp4 in Path(video_path).glob("{}*.mp4".format(prefix)):
        video_b64 = base64.b64encode(mp4.read_bytes())
        html.append(
            '''<video alt="{0}" autoplay
                      loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{1}" type="video/mp4" />
            </video>'''.format(mp4, video_b64.decode('ascii'))
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(html)))


def record_video(env_id, model, video_length=500, prefix="", video_folder="./Breakout/DQN/videos/"):
    eval_env = make_atari_env(env_id, n_envs=1, seed=0, vec_env_cls=DummyVecEnv)  # Create a single Atari environment
    eval_env = VecFrameStack(eval_env, n_stack=4)  # Stack 4 frames together for temporal information
    eval_env = VecVideoRecorder(  # Initialize video recorder
        eval_env,
        video_folder=video_folder,  # Specify folder to save videos
        record_video_trigger=lambda step: step == 0,  # Set trigger to start recording at step 0
        video_length=video_length,  # Set video length
        name_prefix=prefix,  # Set prefix for video filenames
    )
    obs = eval_env.reset()  # Reset the environment to get initial observation
    for _ in range(video_length):  # Loop through for the specified video length
        action, _ = model.predict(obs)  # Predict action based on current observation
        obs, _, _, _ = eval_env.step(action)  # Execute action in the environment
    eval_env.close()  # Close the environment and video recorder




## Record and Display Videos at Different Training Stages

In [15]:
# Create a directory to save videos
os.makedirs("videos", exist_ok=True)

# Record a video of the trained agent
# record_video("BreakoutNoFrameskip-v4", model, video_length=1000, prefix='dqn-breakout')

# Display the video
import os

def get_model_identifiers(models_dir):
    files = os.listdir(models_dir)
    model_files = [f for f in files if f.startswith('model_')]
    identifiers = [f.split('_')[1] for f in model_files]
    return identifiers

def find_key_identifiers(identifiers):
    identifiers.sort()  # Ensure identifiers are sorted
    earliest = identifiers[0]
    final = identifiers[-1]
    middle = identifiers[len(identifiers) // 2]
    return earliest, middle, final

def view(models_dir):
    identifiers = get_model_identifiers(models_dir)
    earliest, middle, final = find_key_identifiers(identifiers)

    # Record videos at the beginning, middle, and end of training
    for stage, identifier in zip(["beginning", "middle", "end"], [earliest, middle, final]):
        model_path = os.path.join(models_dir, f'model_{identifier}')
        model = DQN.load(model_path)
        record_video("BreakoutNoFrameskip-v4", model, video_length=1000, prefix=f'dqn-breakout-{stage}')

    # Display the videos
    for stage in ["beginning", "middle", "end"]:
        show_videos("videos", prefix=f'dqn-breakout-{stage}')

# Changes the model directory below
view(models_dir)

  logger.warn(


ValueError: Error: Unexpected observation shape (1, 84, 84, 4) for Box environment, please use (5, 84, 84) or (n_env, 5, 84, 84) for the observation shape.