In [11]:
import gymnasium as gym
import ale_py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from collections import deque, namedtuple
from dataclasses import dataclass, field

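* **State**: what the agent observes (here, the 128-byte Atari RAM)
* **Action**: what the agent can do (move up, right, left, or down)
* **Reward**: numeric feedback the environment returns after each action
* **Q-value**: an estimate of the total discounted future reward of taking an action in a given state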

## Rewards

* Pac-Dot = 10 Pts
* Power Pellet = 50 Pts
* 1st Ghost = 200 Pts
* 2nd Ghost = 400 Pts
* 3rd Ghost = 800 Pts
* 4th Ghost = 1600 Pts
* Cherry = 100 Pts
* Strawberry = 300 Pts
* Orange = 500 Pts
* Apple = 700 Pts
* Melon = 1000 Pts
* Galaxian = 2000 Pts
* Bell = 3000 Pts
* Key = 5000 Pts

In [12]:
# At the start of an episode the game freezes Ms. Pac-Man for a while;
# we step through that many env steps before the agent starts acting.
# (The training loop only does this at episode start, not after each lost life.)
AVOIDED_STEPS = 80

# Base of the logarithm applied to positive game rewards (see transform_reward)
REWARD_LOG_BASE = 1000
# Added to the shaped reward whenever a life is lost
LOSE_REWARD = -3
# Added when the episode terminates with lives remaining
WIN_REWARD = 3

# Observation size (length of the RAM observation vector)
OBS_SIZE = 128

# ALE action ids the agent may choose from; the network outputs one
# Q-value per entry and the chosen index is mapped through this list.
ACTIONS = [
    1,  # UP
    2,  # RIGHT
    3,  # LEFT
    4,  # DOWN
]
# Derived from ACTIONS so the two can never drift apart
N_ACTIONS = len(ACTIONS)

# Epsilon-greedy exploration: epsilon decays exponentially from
# EPSILON_MAX toward EPSILON_MIN with a time constant of EPSILON_DECAY agent steps.
EPSILON_MAX = 1.0
EPSILON_MIN = 0.1
EPSILON_DECAY = 1_000_000

# How often (in agent steps) to copy the policy network's weights
# into the target network
TARGET_UPDATE_FREQ = 8_000

# How many transitions to sample from the replay buffer per training step
BATCH_SIZE = 64

# Device to run on
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# How often to save a checkpoint (in episodes)
SAVE_MODEL = 20

NUM_EPISODES = 1000

# How much future rewards are valued compared to current rewards
DISCOUNT = 0.99

LEARNING_RATE = 1e-3

# How many env steps each chosen action is repeated for.
# NOTE(review): ALE v5 environments already apply a built-in frameskip of 4,
# so each agent decision may span 16 emulator frames — confirm this is intended.
SKIP_FRAMES = 4

We log-scale positive rewards (base `REWARD_LOG_BASE` = 1000) so that large rewards, such as eating a ghost (worth up to 1600 points), don't have an outsized effect on the model.

In [13]:
def transform_reward(reward: int):
    """Compress a raw game reward with a base-REWARD_LOG_BASE logarithm.

    Positive rewards become log_base(reward); zero and negative rewards
    are passed through unchanged.
    """
    if reward > 0:
        return np.log(reward) / np.log(REWARD_LOG_BASE)
    return reward

In [14]:
# One replay transition. NOTE: not currently used by the training loop,
# which pushes plain tuples in this same field order.
Experience = namedtuple("Experience", "state action reward next_state done")

In [15]:
# Experience Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random minibatch sampling.

    When ``capacity`` transitions are held, each new push evicts the oldest.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, experience):
        """Store one (state, action, reward, next_state, done) transition."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as tensors on DEVICE.

        Returns:
            (states, actions, rewards, next_states, dones): actions are int64,
            all others float32.
        """
        batch = random.sample(self.buffer, batch_size)
        state_col, action_col, reward_col, next_state_col, done_col = zip(*batch)

        def to_f32(values):
            return torch.tensor(values, dtype=torch.float32, device=DEVICE)

        return (
            to_f32(np.array(state_col)),
            torch.tensor(action_col, dtype=torch.int64, device=DEVICE),
            to_f32(reward_col),
            to_f32(np.array(next_state_col)),
            to_f32(done_col),
        )

    def __len__(self):
        return len(self.buffer)

In [16]:
class DQN(nn.Module):
    """MLP mapping an OBS_SIZE observation vector to N_ACTIONS Q-values.

    Two hidden layers of width 256 with ReLU activations.
    """

    def __init__(self):
        super().__init__()
        # The attribute names fc1/fc2/fc3 are the keys in saved state_dicts,
        # so they must stay stable for checkpoints to load.
        width = 256
        self.fc1 = nn.Linear(OBS_SIZE, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, N_ACTIONS)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [17]:
@dataclass
class DataRecord:
    """Per-episode training statistics stored in growable numpy arrays.

    Only the first ``episodes`` entries of each array are meaningful; the
    arrays start with NUM_EPISODES slots and grow geometrically when full.
    """

    # Total (shaped) reward per episode. Float, because transform_reward
    # yields fractional values that an integer array would truncate.
    rewards: np.ndarray = field(
        default_factory=lambda: np.zeros(NUM_EPISODES, dtype=np.float64)
    )
    # Mean greedy Q-value per episode
    qvalues: np.ndarray = field(
        default_factory=lambda: np.zeros(NUM_EPISODES, dtype=np.float64)
    )
    # Mean training loss per episode
    losses: np.ndarray = field(
        default_factory=lambda: np.zeros(NUM_EPISODES, dtype=np.float64)
    )
    # Number of episodes recorded so far (= number of valid slots)
    episodes: int = 0
    successes: int = field(default_factory=int)

    def __iter__(self):
        # NOTE(review): yields ONE (losses, rewards, qvalues) tuple of full
        # backing arrays, not one item per episode.
        yield self.losses, self.rewards, self.qvalues

    @staticmethod
    def _ensure_capacity(arr, needed):
        """Return ``arr`` if it has >= ``needed`` slots, else a zero-padded larger copy.

        Doubles the length (at least to ``needed``), similar to std::vector.
        Zero padding is used instead of np.resize, which would fill the new
        slots with repeated copies of the old data.
        """
        if len(arr) >= needed:
            return arr
        grown = np.zeros(max(needed, 2 * len(arr)), dtype=arr.dtype)
        grown[: len(arr)] = arr
        return grown

    def append(self, reward, qval, loss):
        """Record one episode's total reward, mean Q-value and mean loss."""
        # Arrays restored from a checkpoint may carry an integer dtype;
        # upcast so fractional rewards are stored exactly.
        if not np.issubdtype(self.rewards.dtype, np.floating):
            self.rewards = self.rewards.astype(np.float64)

        needed = self.episodes + 1
        self.rewards = self._ensure_capacity(self.rewards, needed)
        self.qvalues = self._ensure_capacity(self.qvalues, needed)
        self.losses = self._ensure_capacity(self.losses, needed)

        self.rewards[self.episodes] = reward
        self.qvalues[self.episodes] = qval
        self.losses[self.episodes] = loss

        self.episodes += 1

    def to_pandas(self):
        """Return the recorded episodes as a DataFrame (one row per episode)."""
        return pd.DataFrame(
            {
                "Rewards": self.rewards[: self.episodes],
                "Q-Values": self.qvalues[: self.episodes],
                "Losses": self.losses[: self.episodes],
            }
        )

    def plot(self, save_path=None):
        """Plot rewards, mean Q-values and losses per episode side by side.

        Args:
            save_path: if given, save the figure to this path instead of showing it.
        """
        n = self.episodes
        episodes = range(n)
        # (values, line label, line style, panel title, y-axis label)
        panels = [
            (self.rewards, "Reward", {}, "Episode Rewards", "Reward"),
            (self.qvalues, "Q-Value", {"color": "orange"}, "Mean Q-Values", "Q-Value"),
            (self.losses, "Loss", {"color": "red"}, "Training Loss", "Loss"),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (values, label, style, title, ylabel) in zip(axes, panels):
            ax.plot(episodes, values[:n], label=label, **style)
            ax.set_title(title)
            ax.set_xlabel("Episode")
            ax.set_ylabel(ylabel)
            ax.grid(True)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            print(f"Plot saved to {save_path}")
        else:
            plt.show()

In [18]:
class MsPacman:
    """Double-DQN agent: epsilon-greedy acting, replay memory, and learning.

    ``policy_net`` is trained every step; ``target_net`` is only refreshed
    by :meth:`sync` and is used to evaluate bootstrap targets.
    """

    def __init__(self, memory_capacity: int = 10_000):
        """Build both networks, the optimizer, and the replay buffer.

        Args:
            memory_capacity: maximum number of transitions kept in replay memory.
        """
        self.steps_done = 0

        self.policy_net = DQN().to(DEVICE)
        self.target_net = DQN().to(DEVICE)

        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)

        self.memory = ReplayBuffer(capacity=memory_capacity)

    def act(self, state):
        """Choose an action index for ``state`` with epsilon-greedy exploration.

        Epsilon decays exponentially from EPSILON_MAX toward EPSILON_MIN with
        a time constant of EPSILON_DECAY calls.

        Returns:
            (action_index, q_value): ``action_index`` is in [0, N_ACTIONS) and
            indexes ACTIONS; ``q_value`` is the greedy max Q-value, or None
            when a random action was taken.
        """
        eps_threshold = EPSILON_MIN + (EPSILON_MAX - EPSILON_MIN) * np.exp(
            -1.0 * self.steps_done / EPSILON_DECAY
        )

        self.steps_done += 1

        # EXPLORE: uniform random index in [0, N_ACTIONS); the caller maps it
        # through ACTIONS to an ALE action id.
        if random.random() < eps_threshold:
            return np.random.randint(N_ACTIONS), None

        # EXPLOIT: greedy action under the policy network
        with torch.no_grad():
            state_t = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            best_q, best_action = self.policy_net(state_t).max(dim=1)
            return best_action.item(), best_q.item()

    def memorize(self, experience):
        """Store a (state, action_index, reward, next_state, done) transition."""
        self.memory.push(experience)

    def recall(self):
        """Sample a minibatch; actions are reshaped to (batch, 1) for ``gather``."""
        states, actions, rewards, next_states, dones = self.memory.sample(BATCH_SIZE)
        actions = actions.unsqueeze(1)
        return states, actions, rewards, next_states, dones

    def learn(self):
        """Perform one Double-DQN gradient step on a replay minibatch.

        Returns:
            The scalar smooth-L1 loss, or None when the buffer holds fewer than
            BATCH_SIZE transitions (no update is made).
        """
        if len(self.memory) < BATCH_SIZE:
            return None

        states, actions, rewards, next_states, dones = self.recall()

        # squeeze(1) (not squeeze()) keeps a 1-D batch even when the batch has one row.
        q_values = self.policy_net(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            # Double DQN: the online net selects the next action, the target
            # net evaluates it.
            best_actions = self.policy_net(next_states).argmax(1)
            next_q_values = (
                self.target_net(next_states)
                .gather(1, best_actions.unsqueeze(1))
                .squeeze(1)
            )
            # No bootstrapping past terminal transitions
            target_q = rewards + DISCOUNT * next_q_values * (1 - dones)

        loss = F.smooth_l1_loss(q_values, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def sync(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())


In [19]:
class Scheduler:
    """Runs training: owns the Gym environment, the agent, and the stats record.

    Args:
        render: if True the environment is created with ``render_mode="human"``.
    """

    def __init__(self, render: bool = True):
        self.episodes = 0
        # Agent decisions taken so far; schedules target-network syncs.
        self.steps_done = 0

        render_mode = "human" if render else None
        self.env = gym.make("ALE/MsPacman-v5", render_mode=render_mode, obs_type="ram")

        self.record = DataRecord()

        self.actor = MsPacman()

    def _skip_intro(self):
        """Reset the env, step through the start-of-episode freeze, return ``(obs, lives)``.

        The observation returned is the one from the last skipped step, so
        the agent's first state matches the actual game state.
        """
        obs, info = self.env.reset()
        for _ in range(AVOIDED_STEPS):
            # NOTE(review): ALE action 3 is LEFT (see ACTIONS); NOOP (0) may be intended.
            obs, _, _, _, info = self.env.step(3)
        return obs, info.get("lives", 3)

    def _repeat_action(self, action, lives):
        """Apply ``action`` for up to SKIP_FRAMES env steps, accumulating shaped reward.

        Args:
            action: ALE action id (an element of ACTIONS).
            lives: lives remaining before this action.

        Returns:
            (next_obs, step_reward, terminated, done, lives): ``terminated`` is
            the true end-of-game flag (used to stop bootstrapping); ``done``
            also includes time-limit truncation (used to end the episode loop).
        """
        step_reward = 0
        terminated = truncated = False

        for _ in range(SKIP_FRAMES):
            next_obs, raw_reward, terminated, truncated, info = self.env.step(action)

            if raw_reward in {200, 400, 800, 1600}:
                print("Ate a ghost")

            reward = transform_reward(raw_reward)

            # We ran into a ghost
            if info["lives"] < lives:
                lives = info["lives"]
                reward += LOSE_REWARD

            # The reward shaping treats ending with lives remaining as a win.
            if terminated and lives > 0:
                reward += WIN_REWARD
                self.record.successes += 1

            step_reward += reward
            if terminated or truncated:
                break

        return next_obs, step_reward, terminated, terminated or truncated, lives

    def run_one_episode(self):
        """Play one training episode, learning after every decision, and record its stats."""
        obs, lives = self._skip_intro()
        self.episodes += 1

        total_reward = 0
        total_qval, qval_count = 0.0, 0
        total_loss, loss_count = 0.0, 0

        done = False
        while not done:
            action_index, q_val = self.actor.act(obs)
            next_obs, step_reward, terminated, done, lives = self._repeat_action(
                ACTIONS[action_index], lives
            )
            total_reward += step_reward

            # Only a real game-over stops bootstrapping; a time-limit
            # truncation still has future value.
            self.actor.memorize((obs, action_index, step_reward, next_obs, terminated))
            loss = self.actor.learn()

            # q_val is None for exploratory actions; loss is None until the
            # buffer holds a full batch.
            if q_val is not None:
                total_qval += q_val
                qval_count += 1
            if loss is not None:
                total_loss += loss
                loss_count += 1

            obs = next_obs

            if self.steps_done % TARGET_UPDATE_FREQ == 0:
                self.actor.sync()

            self.steps_done += 1

        # Average only over steps that actually produced a value.
        mean_qval = total_qval / qval_count if qval_count else 0.0
        mean_loss = total_loss / loss_count if loss_count else 0.0
        self.record.append(total_reward, mean_qval, mean_loss)

        print(f"[Episode {self.episodes}] Total Reward: {total_reward:.3f}")

    def run(self, num_episodes=1000):
        """Train for ``num_episodes`` episodes, checkpointing every SAVE_MODEL episodes.

        The environment is closed afterwards, including on interruption.
        """
        try:
            for _ in range(num_episodes):
                self.run_one_episode()

                # self.episodes is >= 1 here, so no zero check is needed.
                if self.episodes % SAVE_MODEL == 0:
                    self.save_checkpoint("./results/checkpoint.pth")
        finally:
            self.env.close()

    def save_checkpoint(self, path):
        """Save networks, optimizer, replay buffer and stats to ``path``.

        The parent directory is created if it does not exist.
        """
        import os

        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

        checkpoint = {
            "episode": self.episodes,
            "steps_done": self.actor.steps_done,
            "model_state_dict": self.actor.policy_net.state_dict(),
            "target_state_dict": self.actor.target_net.state_dict(),
            "optimizer_state_dict": self.actor.optimizer.state_dict(),
            "replay_buffer": list(self.actor.memory.buffer),  # deque to list
            "record": {
                "rewards": self.record.rewards,
                "qvalues": self.record.qvalues,
                "losses": self.record.losses,
                "episodes": self.record.episodes,
                "successes": self.record.successes,
            },
        }

        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path="checkpoint.pth"):
        """Restore state written by :meth:`save_checkpoint`.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # weights_only=False unpickles arbitrary objects (the replay buffer
        # holds numpy arrays); only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
        self.episodes = checkpoint["episode"]
        self.actor.steps_done = checkpoint["steps_done"]
        # Keep the target-sync schedule continuous across restarts.
        self.steps_done = checkpoint["steps_done"]

        self.actor.policy_net.load_state_dict(checkpoint["model_state_dict"])
        self.actor.target_net.load_state_dict(checkpoint["target_state_dict"])
        self.actor.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        # Reuse the live buffer's capacity instead of repeating the constant.
        self.actor.memory.buffer = deque(
            checkpoint["replay_buffer"], maxlen=self.actor.memory.buffer.maxlen
        )

        record = checkpoint["record"]
        # Rewards may have been saved as int64; store as float so later
        # fractional episode rewards are not truncated.
        self.record.rewards = np.asarray(record["rewards"], dtype=np.float64)
        self.record.qvalues = record["qvalues"]
        self.record.losses = record["losses"]
        self.record.episodes = record["episodes"]
        self.record.successes = record["successes"]

        print(f"Checkpoint loaded from {path}")

In [10]:
# Train headless, resuming from the last checkpoint when one exists.
scheduler = Scheduler(render=False)

checkpoint_path = "./results/checkpoint.pth"
try:
    scheduler.load_checkpoint(checkpoint_path)
except FileNotFoundError:
    print("No checkpoint found. Starting fresh.")

scheduler.run(num_episodes=NUM_EPISODES)

Checkpoint loaded from ./results/checkpoint.pth
Ate a ghost
Ate a ghost
Ate a ghost
[Episode 3961] Total Reward: 22.867
[Episode 3962] Total Reward: 14.566
Ate a ghost
[Episode 3963] Total Reward: 9.333
Ate a ghost


KeyboardInterrupt: 

In [None]:
# Show (rather than save) the per-episode training curves.
scheduler.record.plot(save_path=None)

In [None]:
# Smooth the noisy per-episode curves with a trailing rolling mean.
window = 20
df = scheduler.record.to_pandas().assign(
    RollingReward=lambda d: d["Rewards"].rolling(window).mean(),
    RollingQ=lambda d: d["Q-Values"].rolling(window).mean(),
)

fig, (ax_reward, ax_q) = plt.subplots(1, 2, figsize=(9, 4))

# Rewards and their rolling average
ax_reward.plot(df["Rewards"], label="Reward")
ax_reward.plot(df["RollingReward"], label=f"Rolling Avg ({window})", linewidth=2)
ax_reward.set_title("Episode Reward with Rolling Average")
ax_reward.set_xlabel("Episode")
ax_reward.set_ylabel("Reward")
ax_reward.legend()

# Q-values and their rolling average
ax_q.plot(df["Q-Values"], label="Q-Value", color="orange")
ax_q.plot(df["RollingQ"], label=f"Rolling Avg ({window})", color="red", linewidth=2)
ax_q.set_title("Mean Q-Value with Rolling Average")
ax_q.set_xlabel("Episode")
ax_q.set_ylabel("Q-Value")
ax_q.legend()

fig.tight_layout()
plt.show()

In [20]:
# Watch the trained agent play greedily in a rendered window.
scheduler = Scheduler(render=True)

try:
    scheduler.load_checkpoint("./results/checkpoint.pth")
except FileNotFoundError:
    print("No checkpoint found. Starting fresh.")

obs, _ = scheduler.env.reset()
total_reward = 0  # raw game score (untransformed)

done = False
try:
    while not done:
        with torch.no_grad():
            state = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            q_values = scheduler.actor.target_net(state)
            action_index = q_values.argmax().item()

        action = ACTIONS[action_index]

        # Hold each decision for SKIP_FRAMES env steps, exactly as in training.
        for _ in range(SKIP_FRAMES):
            obs, raw_reward, terminated, truncated, info = scheduler.env.step(action)
            total_reward += raw_reward
            done = terminated or truncated
            if done:
                break
finally:
    # Release the render window even if the loop is interrupted.
    scheduler.env.close()

print(f"Game score: {total_reward}")

Checkpoint loaded from ./results/checkpoint.pth


In [None]:
import imageio
from PIL import Image


def record_gif(env, policy_net, save_path="pacman_run.gif", max_steps=1000, skip_frames=None):
    """Play one greedy episode with ``policy_net`` and save the rendered frames as a GIF.

    Args:
        env: environment created with ``render_mode="rgb_array"`` and
            ``obs_type="ram"``; ``env.render()`` must return an RGB array.
        policy_net: network mapping a RAM observation to Q-values over ACTIONS.
        save_path: output GIF path.
        max_steps: maximum number of env steps (and thus frames) to record.
        skip_frames: env steps each greedy action is held for; defaults to
            SKIP_FRAMES so playback matches the training loop.

    Raises:
        ValueError: if no frames were captured (e.g. ``max_steps <= 0``).
    """
    if skip_frames is None:
        skip_frames = SKIP_FRAMES

    frames = []
    # Feed raw RAM bytes: the network was trained on unnormalized observations,
    # so scaling by 1/255 here would feed it inputs it never saw.
    obs, _ = env.reset()

    done = False
    steps = 0

    while not done and steps < max_steps:
        with torch.no_grad():
            state = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            action_index = policy_net(state).argmax().item()

        action = ACTIONS[action_index]

        # max(1, ...) guarantees progress even if skip_frames is 0.
        for _ in range(max(1, skip_frames)):
            frames.append(Image.fromarray(env.render()))  # RGB array (H, W, 3)
            obs, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            steps += 1
            if done or steps >= max_steps:
                break

    if not frames:
        raise ValueError("No frames captured; check max_steps and render_mode")

    # Save to GIF
    frames[0].save(
        save_path,
        save_all=True,
        append_images=frames[1:],
        duration=40,  # ms per frame (~25 FPS)
        loop=0,
    )
    print(f"🎥 Saved gameplay to {save_path}")


In [None]:
# Record one greedy episode to a GIF using an off-screen (rgb_array) env.
env = gym.make("ALE/MsPacman-v5", render_mode="rgb_array", obs_type="ram")
model = scheduler.actor.target_net

try:
    record_gif(env, model, save_path="pacman.gif")
finally:
    # Close the emulator even if recording fails.
    env.close()


🎥 Saved gameplay to pacman.gif


: 