In [None]:
%pip install "gymnasium[atari,accept-rom-license]==1.0.0" "ale-py==0.9.1"

In [4]:
import sys, os
import gymnasium as gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import numpy as np
import random
from gymnasium.spaces import Box
from collections import deque
import copy
from gymnasium.wrappers import FrameStackObservation
import ale_py

%matplotlib inline

In [5]:
class SkipFrame(gym.Wrapper):
    """Repeat each chosen action for ``num_skip`` environment steps and sum the rewards.

    Stops early if the episode terminates or is truncated partway through. The
    observation and info returned are those of the last step actually played.

    Raises:
        ValueError: if ``num_skip`` is smaller than 1.
    """

    def __init__(self, env, num_skip):
        super().__init__(env)
        # step() must play at least one inner step, otherwise it has no observation to return.
        if num_skip < 1:
            raise ValueError(f"num_skip must be >= 1, got {num_skip}")
        self.num_skip = num_skip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self.num_skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        return obs, total_reward, terminated, truncated, info


class GrayScaleObservation(gym.ObservationWrapper):
    """Convert channel-last colour frames into single-channel grayscale float tensors.

    Input frames are transposed from (H, W, C) to (C, H, W) and then passed
    through torchvision's ``Grayscale`` transform. Values keep the input's
    scale (presumably 0-255 Atari pixels, matching the declared Box bounds).

    NOTE(review): the declared observation_space is (H, W), but the returned
    tensor is (1, H, W). ResizeObservation reads ``shape[2:]`` of this
    declaration and squeezes the channel axis away, so the two wrappers must be
    changed together if the declaration is ever made exact.
    """

    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.float32)
        # The transform holds no per-frame state, so build it once instead of on every observation.
        self.to_grayscale = torchvision.transforms.Grayscale()

    def observation(self, observation):
        # (H, W, C) -> (C, H, W): torchvision transforms expect channels first.
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return self.to_grayscale(observation)


class ResizeObservation(gym.ObservationWrapper):
    """Resize frames to ``shape`` and divide pixel values by 255.

    Args:
        env: wrapped environment. Its observations are expected to be the
            (1, H, W) float tensors produced by GrayScaleObservation.
        shape: target size, either an int (square) or an (height, width) pair.

    The leading channel axis is squeezed away, so each observation becomes an
    (height, width) tensor. Assuming inputs on the 0-255 scale, the values lie
    in [0, 1].
    """

    def __init__(self, env, shape):
        super().__init__(env)
        self.shape = (shape, shape) if isinstance(shape, int) else tuple(shape)
        obs_shape = self.shape + self.observation_space.shape[2:]
        # Normalize(0, 255) computes (x - 0) / 255, so the emitted values lie in
        # [0, 1]. Declare that range rather than the pre-normalisation [0, 255].
        self.observation_space = Box(low=0.0, high=1.0, shape=obs_shape, dtype=np.float32)
        # Built once here instead of on every frame.
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.shape),
            torchvision.transforms.Normalize(0, 255),
        ])

    def observation(self, observation):
        # (1, H, W) -> (H, W) after resizing and rescaling.
        return self.transforms(observation).squeeze(0)


class ExperienceReplayMemory(object):
    """Fixed-capacity FIFO store of transitions for experience replay.

    Once ``capacity`` transitions are held, each new one evicts the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def store(self, state, next_state, action, reward, terminated, truncated):
        """Append one transition. Both states are converted with ``__array__()``."""
        # NOTE(review): state and next_state are both kept as full arrays, so memory
        # use is roughly 2 * capacity * (size of one state). Check RAM for large capacities.
        state = state.__array__()
        next_state = next_state.__array__()
        self.memory.append((state, next_state, action, reward, terminated, truncated))

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` distinct stored transitions as batched tensors.

        Returns:
            Tuple ``(state, next_state, action, reward, terminated, truncated)``,
            each with leading dimension ``batch_size``:
            - ``state`` and ``next_state`` are the stored arrays stacked along a new
              first axis, as float32 tensors;
            - ``action`` is int64;
            - ``reward`` is float32;
            - ``terminated`` and ``truncated`` are bool.

        Raises:
            ValueError: if ``batch_size`` is < 1 or exceeds the number of stored transitions.
        """
        if not 1 <= batch_size <= len(self.memory):
            raise ValueError(
                f"cannot sample {batch_size} transitions from a memory holding {len(self.memory)}")

        # Use NumPy's global RNG: it is the one the notebook seeds, so sampling is reproducible.
        indices = np.random.choice(len(self.memory), size=batch_size, replace=False)
        states, next_states, actions, rewards, terminateds, truncateds = zip(
            *(self.memory[i] for i in indices))

        return (
            torch.as_tensor(np.stack(states), dtype=torch.float32),
            torch.as_tensor(np.stack(next_states), dtype=torch.float32),
            # int()/float()/bool() accept Python scalars, NumPy scalars and 1-element tensors alike.
            torch.tensor([int(a) for a in actions], dtype=torch.long),
            torch.tensor([float(r) for r in rewards], dtype=torch.float32),
            torch.tensor([bool(d) for d in terminateds], dtype=torch.bool),
            torch.tensor([bool(d) for d in truncateds], dtype=torch.bool),
        )

In [None]:
seed = 957
# Seed every RNG the notebook can draw from (Python, NumPy, PyTorch) so runs are repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.backends.cudnn.enabled:
    # Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

gym.register_envs(ale_py)

env_rendering = False    # Set to False while training your model on Colab

# Create and preprocess the Atari Breakout environment
if env_rendering:
    env = gym.make("ALE/Breakout-v5", full_action_space=False, render_mode="human")
else:
    env = gym.make("ALE/Breakout-v5", full_action_space=False)

env = SkipFrame(env, num_skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
env = FrameStackObservation(env, stack_size=4)
# The action space has its own RNG; seed it so env.action_space.sample() is reproducible too.
env.action_space.seed(seed)

image_stack, h, w = env.observation_space.shape
num_actions = env.action_space.n
print(f'Number of stacked frames: {image_stack}')
print(f'Resized observation space dimensionality: {h}, {w}')
print(f'Number of available actions by the agent: {num_actions}')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

In [None]:
class DeepQNet(torch.nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Follows the three-convolution / two-linear layout of Mnih et al. (2015).

    Args:
        h, w: height and width of each frame.
        image_stack: number of stacked frames, used as the input channel count.
        num_actions: number of discrete actions, i.e. the output size.
    """

    def __init__(self, h, w, image_stack, num_actions):
        super(DeepQNet, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(image_stack, 32, kernel_size=8, stride=4),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )
        # Infer the flattened feature size from a dummy input. This works for any
        # (h, w) large enough for the conv stack, not just 84x84.
        with torch.no_grad():
            num_features = self.features(torch.zeros(1, image_stack, h, w)).shape[1]
        self.head = torch.nn.Sequential(
            torch.nn.Linear(num_features, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, num_actions),
        )

    def forward(self, x):
        """Map input of shape (batch, image_stack, h, w) to Q-values of shape (batch, num_actions)."""
        return self.head(self.features(x))


# Online network: updated by the optimizer. Target network: a copy used to compute
# the TD targets, refreshed from the online network only periodically.
online_dqn = DeepQNet(h, w, image_stack, num_actions)
target_dqn = copy.deepcopy(online_dqn)
# The target network is never optimised directly, so keep it out of autograd.
target_dqn.requires_grad_(False)
online_dqn.to(device)
target_dqn.to(device)

In [8]:
def convert(x):
    """Return a new float32 torch tensor copied from ``x``.

    ``x`` may be anything that implements ``__array__`` (NumPy arrays, torch
    tensors, ...). The result never shares memory with the input.
    """
    as_array = x.__array__()
    return torch.tensor(as_array, dtype=torch.float32)


class AtariAgent:
    """Deep Q-learning agent (DQN, or Double DQN when ``run_as_ddqn``) for the notebook's env.

    Every step's transition is stored in ``buffer``. Once more than
    ``burn_in_phase`` steps have been taken, the online network takes one
    optimizer step per environment step on a replayed minibatch. The target
    network is overwritten with the online weights every ``sync_target`` steps.

    NOTE(review): ``run_episode`` uses the notebook-level globals ``env`` and
    ``seed``. ``policy`` and evaluation use the module-level ``convert`` helper.
    """

    def __init__(self, buffer, eps, eps_decay, min_eps, gamma, batch_size,
                 online_dqn, target_dqn, run_as_ddqn, 
                 optimizer, criterion, device,
                 max_train_frames, burn_in_phase, sync_target):

        self.buffer = buffer
        self.eps = eps
        self.eps_decay = eps_decay
        self.min_eps = min_eps
        self.gamma = gamma
        self.batch_size = batch_size

        self.online_dqn = online_dqn
        self.target_dqn = target_dqn
        self.run_as_ddqn = run_as_ddqn
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.max_train_frames = max_train_frames
        self.burn_in_phase = burn_in_phase
        self.sync_target = sync_target

        # Agent steps taken across all episodes. Drives burn-in, eps decay and target syncs.
        self.current_step = 0
        # Episodes started so far. Only the first env.reset() receives the seed.
        self.num_episodes = 0


    def policy(self, state, is_training):
        """Choose an action with an epsilon-greedy rule.

        While training, a uniformly random action is taken with probability
        ``self.eps``. Otherwise, and always at test time, the action with the
        highest online Q-value is taken.

        Returns:
            int: index of the chosen action (directly usable by ``env.step``).
        """
        state = convert(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            q_values = self.online_dqn(state)
        if is_training and np.random.rand() < self.eps:
            return int(np.random.randint(q_values.shape[1]))
        return int(q_values.argmax(dim=1).item())


    def compute_loss(self, state, action, reward, next_state, truncated, terminated):
        """Return ``criterion(Q_online(s, a), td_target)`` for a batch of transitions.

        The TD target is
        - DQN:  ``reward + gamma * max_a' Q_target(s', a')``
        - DDQN: ``reward + gamma * Q_target(s', argmax_a' Q_online(s', a'))``
        with the bootstrap term dropped for terminated transitions.

        Args:
            state, next_state: batched observations accepted by the networks.
            action: 1-D tensor of action indices, one per transition.
            reward: 1-D tensor of rewards.
            truncated: bool tensor. It is not used in the target: a time-limit cut
                is not a real terminal state, so the value is still bootstrapped.
            terminated: bool tensor marking true terminal transitions.

        Note the parameter order (``truncated`` before ``terminated``). Pass these
        two by keyword.
        """
        state, action, reward, next_state, truncated, terminated = [x.to(self.device) for x in 
                                (state, action, reward, next_state, truncated, terminated)]

        # Q(s, a) for the actions actually taken, shape (batch,).
        q_taken = self.online_dqn(state).gather(1, action.long().view(-1, 1)).squeeze(1)

        with torch.no_grad():
            next_q_target = self.target_dqn(next_state)
            if self.run_as_ddqn:
                # Double DQN: the online net selects the next action, the target net evaluates it.
                best_next = self.online_dqn(next_state).argmax(dim=1, keepdim=True)
                next_q = next_q_target.gather(1, best_next).squeeze(1)
            else:
                next_q = next_q_target.max(dim=1).values
            not_terminal = 1.0 - terminated.float().view(-1)
            td_target = reward.float().view(-1) + self.gamma * next_q * not_terminal

        return self.criterion(q_taken, td_target)


    def _train_step(self):
        """Sample a minibatch, refresh the target net when due, take one optimizer step, and return the loss."""
        state_batch, next_state_batch, action_batch, \
            reward_batch, terminated_batch, truncated_batch = self.buffer.sample(self.batch_size)

        if self.current_step % self.sync_target == 0:
            # Copy the online weights into the target network.
            self.target_dqn.load_state_dict(self.online_dqn.state_dict())

        # Keyword arguments: compute_loss's positional order is (..., truncated, terminated).
        loss = self.compute_loss(state_batch, action_batch, reward_batch, next_state_batch,
                                 truncated=truncated_batch, terminated=terminated_batch)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.detach().item()


    def _eval_loss(self, state, next_state, action, reward, terminated, truncated):
        """Return the loss of a single transition, evaluated as a batch of one without gradients."""
        with torch.no_grad():
            st = convert(state).unsqueeze(0)
            next_st = convert(next_state).unsqueeze(0)
            # Same shapes and dtypes as a replay minibatch of size 1.
            act = torch.tensor([int(action)], dtype=torch.long)
            rew = torch.tensor([float(reward)], dtype=torch.float32)
            term = torch.tensor([bool(terminated)])
            trunc = torch.tensor([bool(truncated)])
            return self.compute_loss(st, act, rew, next_st,
                                     truncated=trunc, terminated=term).item()


    def run_episode(self, is_training):
        """Play one episode in the global ``env``, training along the way if ``is_training``.

        Returns:
            dict with ``reward`` (sum of rewards over the episode) and ``loss``
            (accumulated loss divided by the number of steps played).
        """
        episode_reward, episode_loss = 0, 0.
        # Seed only the first reset. Passing the same seed to every reset would
        # restart the environment's RNG from the same state each episode.
        reset_kwargs = {'seed': seed} if self.num_episodes == 0 else {}
        self.num_episodes += 1
        state, _ = env.reset(**reset_kwargs)

        num_steps = 0
        for _ in range(self.max_train_frames):
            action = self.policy(state, is_training)
            self.current_step += 1
            num_steps += 1
            next_state, reward, terminated, truncated, _info = env.step(action)

            episode_reward += reward

            if is_training:
                self.buffer.store(state, next_state, action, reward, terminated, truncated)

                if self.current_step > self.burn_in_phase and len(self.buffer) >= self.batch_size:
                    episode_loss += self._train_step()
            else:
                episode_loss += self._eval_loss(state, next_state, action, reward,
                                                terminated, truncated)

            state = next_state

            if is_training and self.current_step > self.burn_in_phase:
                # Multiplicative decay, clamped so eps never falls below min_eps.
                self.eps = max(self.min_eps, self.eps * self.eps_decay)

            if terminated or truncated:
                break

        # Average over the number of steps played. The old 0-based loop index was
        # off by one and divided by zero on 1-step episodes.
        return dict(reward=episode_reward, loss=episode_loss / max(num_steps, 1))


    def save_checkpoint(self, train_metrics, save_filename):
        """Save the step counter, eps, metrics and both networks' weights to ``save_filename``."""
        save_dict = {'curr_step': self.current_step,
                    'train_metrics': train_metrics,
                    'eps': self.eps,
                    'online_dqn': self.online_dqn.state_dict(),
                    'target_dqn': self.target_dqn.state_dict()}

        torch.save(save_dict, save_filename)

In [9]:
def update_metrics(metrics, episode):
    """Append each value of ``episode`` to the list stored under the same key in ``metrics``.

    Raises:
        KeyError: if ``episode`` contains a key that ``metrics`` does not.
    """
    for name in episode:
        history = metrics[name]
        history.append(episode[name])


def print_metrics(it, metrics, is_training, window=100):
    """Print one line with the mean reward and loss over the last ``window`` episodes.

    Args:
        it: episode number shown at the start of the line.
        metrics: dict with ``'reward'`` and ``'loss'`` lists.
        is_training: selects the ``train`` / ``test`` label.
        window: number of most recent episodes to average.
    """
    label = "test"
    if is_training:
        label = "train"
    averages = {key: np.mean(metrics[key][-window:]) for key in ('reward', 'loss')}
    print(f"Episode {it:4d} | {label:5s} | reward {averages['reward']:5.5f} | loss {averages['loss']:5.5f}")

In [None]:
# Hyperparameters (TODO: modify as needed)
batch_size = 32
alpha = 0.00025            # Adam learning rate
gamma = 0.95               # discount factor
# eps is multiplied by eps_decay once per agent step after burn-in, so 0.999 reaches
# min_eps = 0.05 after about ln(0.05) / ln(0.999) ≈ 3,000 steps.
eps, eps_decay, min_eps = 1.0, 0.999, 0.05
buffer = ExperienceReplayMemory(20_000)
burn_in_phase = 20_000     # agent steps before the first gradient update
sync_target = 30_000       # agent steps between target-network refreshes
max_train_frames = 10_000  # cap on agent steps per episode
max_train_episodes = 100_000
max_test_episodes = 1
run_as_ddqn = False # Set the run_as_ddqn flag to True if you want to run the DDQN algorithm
save_filename = './saved_model.pt'

# Only the online network is optimised; the target network is refreshed by copying.
optimizer = torch.optim.Adam(online_dqn.parameters(), lr=alpha)
criterion = torch.nn.MSELoss()

testing_mode = False # Change to True if you want to load a saved model

if testing_mode:
    # The checkpoint is one this notebook wrote via AtariAgent.save_checkpoint.
    checkpoint = torch.load(save_filename, map_location=device)
    online_dqn.load_state_dict(checkpoint['online_dqn'])
    # The target network is not needed for testing; keep it identical to the online one.
    target_dqn.load_state_dict(checkpoint['online_dqn'])

agent = AtariAgent(buffer=buffer, eps=eps, eps_decay=eps_decay, min_eps=min_eps, gamma=gamma, batch_size=batch_size,
                   online_dqn=online_dqn, target_dqn=target_dqn, run_as_ddqn=run_as_ddqn,
                   optimizer=optimizer, criterion=criterion, device=device, 
                   max_train_frames=max_train_frames, burn_in_phase=burn_in_phase, sync_target=sync_target)

if testing_mode:
    test_metrics = dict(reward=[], loss=[])
    for it in range(max_test_episodes):
        episode_metrics = agent.run_episode(is_training=False)
        update_metrics(test_metrics, episode_metrics)
        print_metrics(it + 1, test_metrics, is_training=False)
else:
    train_metrics = dict(reward=[], loss=[])
    try:
        for it in range(max_train_episodes):
            episode_metrics = agent.run_episode(is_training=True)
            update_metrics(train_metrics, episode_metrics)
            if it % 50 == 0:
                # Report the 1-based count of completed episodes, matching the test branch.
                print_metrics(it + 1, train_metrics, is_training=True)
                agent.save_checkpoint(train_metrics, save_filename)
    finally:
        # Also persist the latest weights at the end of the run or after a manual
        # interrupt. Skipped if no episode finished, so an earlier checkpoint is
        # not overwritten by an untrained model.
        if train_metrics['reward']:
            agent.save_checkpoint(train_metrics, save_filename)

In [None]:
# TODO: Plot your train_metrics and test_metrics