In [5]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import ale_py
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

In [6]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS,
# and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
 
# --- Experiment configuration -------------------------------------------------
ENV_NAME = "ALE/Pong-v5"
SEED = 42
FRAME_STACK = 4                        # number of consecutive frames per state
STATE_SHAPE = (FRAME_STACK, 80, 80)    # (frames, height, width) after preprocessing
INPUT_DIM = int(np.prod(STATE_SHAPE))  # flattened state size

# Seed every random number generator the notebook relies on, not only the
# environment: torch drives weight initialisation and action sampling, so
# without this a fresh "Restart & Run All" gives a different run each time.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
 
def make_env(seed: int = SEED):
    """Create the ``ENV_NAME`` environment with RGB observations and seed it.

    The environment is reset once with ``seed`` and its action space is
    seeded with the same value before the environment is returned.
    """
    seeded_env = gym.make(ENV_NAME, obs_type="rgb")
    seeded_env.reset(seed=seed)
    seeded_env.action_space.seed(seed)
    return seeded_env
 
# Environment instance shared by the training cells below.
env = make_env()
# Size of the discrete action space; passed to the agent as `n_actions`.
N_ACTIONS = env.action_space.n

In [7]:
def preprocess_frame(frame):
    """Reduce a raw colour frame to a small float32 grayscale image.

    Rows 35..194 are kept, every second row and column is sampled, the
    channel axis (axis 2) is averaged away, and the result is cast to
    float32 and divided by 255.
    """
    # Crop and 2x downsample in a single slice.
    half_res = frame[35:195:2, ::2]
    gray = half_res.mean(axis=2)
    return gray.astype(np.float32) / 255.0
 
def stack_to_state(frames):
    """Pack a sequence of frames into one new float32 array (frames on axis 0)."""
    state = np.array(frames, dtype=np.float32)
    return state
 
class PolicyNetwork(nn.Module):
    """Convolutional policy mapping a stack of frames to one logit per action.

    Three strided convolutions feed a 512-unit hidden layer followed by a
    linear layer with ``n_actions`` outputs. The outputs are raw logits; no
    softmax is applied here.
    """

    def __init__(self, input_shape, n_actions):
        """Build the layers for an input of shape ``(channels, height, width)``."""
        super().__init__()
        channels, height, width = input_shape

        feature_layers = [
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # The flattened feature size depends on the input resolution, so it is
        # measured with a dummy forward pass instead of being hard-coded.
        n_features = self._get_conv_out((channels, height, width))
        head_layers = [
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def _get_conv_out(self, shape):
        """Return how many values the conv stack emits for one input of ``shape``."""
        with torch.no_grad():
            probe = self.conv(torch.zeros((1, *shape)))
        return probe.numel()

    def forward(self, x):
        """Compute action logits for a batch ``x`` (batch on the first axis)."""
        batch_size = x.size(0)
        features = self.conv(x)
        flat_features = features.view(batch_size, -1)
        return self.fc(flat_features)
 
class PolicyGradientAgent:
    """Monte-Carlo policy-gradient (REINFORCE-style) agent.

    One full episode is played with the stochastic policy, discounted returns
    are computed from the collected rewards, and a single Adam step is taken
    on ``-(log_prob * return).mean()``. The learning-rate scheduler is stepped
    once per episode.

    Attributes:
        device: torch device the network and training tensors live on.
        gamma: discount factor used when computing returns.
        n_frames: number of stacked frames per state (``input_shape[0]``).
        policy_net: the ``PolicyNetwork`` being trained.
        optimizer: Adam optimizer over ``policy_net`` parameters.
        scheduler: ``ExponentialLR`` scheduler wrapping ``optimizer``.
        episode_rewards: total reward of every episode finished so far,
            accumulated across all ``train`` calls.
        episode_losses: loss of every policy update performed so far,
            accumulated across all ``train`` calls.
    """

    def __init__(self, input_shape, n_actions, device, lr_start = 5*1e-4, lr_decay = 0.999, gamma = 0.998):
        """Create the policy network, optimizer and learning-rate scheduler.

        Args:
            input_shape: ``(frames, height, width)`` of one state.
            n_actions: number of discrete actions (network outputs).
            device: torch device to run on.
            lr_start: initial Adam learning rate.
            lr_decay: multiplicative learning-rate factor applied per episode.
            gamma: discount factor for the returns.
        """
        self.device = device
        self.gamma = gamma
        # Take the frame count from the shape the network is built with, so
        # the states produced in `train` always match the first conv layer.
        self.n_frames = int(input_shape[0])
        self.policy_net = PolicyNetwork(input_shape, n_actions).to(device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr_start)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=lr_decay)
        # History is also kept on the agent so that an interrupted `train`
        # call (e.g. KeyboardInterrupt) does not lose finished episodes.
        self.episode_rewards = []
        self.episode_losses = []

    def select_action(self, state: np.ndarray):
        """Sample an action from the current policy.

        Args:
            state: one un-batched state; it is converted to float32.

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is a scalar tensor still attached to the autograd
            graph so it can be used in the policy-gradient loss.
        """
        # Convert on the CPU first so non-float32 input (e.g. float64) works
        # with the float32 network, then add the batch axis and move it.
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
        logits = self.policy_net(state_tensor)
        dist = Categorical(logits=logits)
        action_idx = dist.sample()
        log_prob = dist.log_prob(action_idx).squeeze(0)
        return action_idx.item(), log_prob

    def _compute_returns(self, rewards):
        """Turn per-step rewards into normalised discounted returns.

        The running return is reset whenever a non-zero reward is seen, which
        treats every scored point as the end of an independent segment. The
        result is standardised (zero mean, unit std) unless its std is zero.

        Returns:
            A float32 tensor on ``self.device`` with one entry per reward.
        """
        returns = np.zeros(len(rewards), dtype=np.float32)
        running = 0.0
        for t in reversed(range(len(rewards))):
            if rewards[t] != 0:
                running = 0.0  # reset at game boundaries for Pong
            running = rewards[t] + self.gamma * running
            returns[t] = running
        # Compute the std once; skip it for empty input, where NumPy would
        # emit a RuntimeWarning and return NaN.
        std = returns.std() if returns.size else 0.0
        if std > 0:
            returns = (returns - returns.mean()) / (std + 1e-8)
        return torch.from_numpy(returns).to(self.device)

    def _run_episode(self, env):
        """Play one episode; return ``(log_probs, rewards, total_reward)``."""
        obs, _ = env.reset()
        first_frame = preprocess_frame(obs)
        # Start with the first frame repeated so the state has a full stack.
        frame_stack = deque([first_frame] * self.n_frames, maxlen=self.n_frames)

        log_probs, rewards = [], []
        total_reward = 0.0
        done = False
        while not done:
            state = stack_to_state(frame_stack)
            action, log_prob = self.select_action(state)
            obs, reward, terminated, truncated, _ = env.step(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            total_reward += reward

            frame_stack.append(preprocess_frame(obs))
            done = terminated or truncated
        return log_probs, rewards, total_reward

    def _update_policy(self, log_probs, rewards):
        """Take one optimizer and scheduler step from an episode; return the loss."""
        returns = self._compute_returns(rewards)
        loss = -(torch.stack(log_probs) * returns).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def train(self, env, n_episodes: int = 1000):
        """Train for ``n_episodes`` episodes, one policy update per episode.

        Args:
            env: environment whose ``reset()`` returns ``(obs, info)`` and
                whose ``step(action)`` returns
                ``(obs, reward, terminated, truncated, info)``.
            n_episodes: number of episodes to play in this call.

        Returns:
            ``(episode_rewards, episode_losses)`` for the episodes of this
            call only. The same values are also appended to
            ``self.episode_rewards`` / ``self.episode_losses``, which remain
            available if the call is interrupted before it returns.
        """
        episode_rewards, episode_losses = [], []
        for episode in range(1, n_episodes + 1):
            log_probs, rewards, total_reward = self._run_episode(env)
            loss_value = self._update_policy(log_probs, rewards)

            episode_rewards.append(total_reward)
            episode_losses.append(loss_value)
            self.episode_rewards.append(total_reward)
            self.episode_losses.append(loss_value)

            rolling_mean = np.mean(episode_rewards[-50:])
            current_lr = self.scheduler.get_last_lr()[0]
            print(f"Episode {episode:04d} | reward={total_reward:.1f} | mean_50={rolling_mean:.2f} | lr={current_lr:.2e}")

        return episode_rewards, episode_losses

In [8]:
# Build the Pong agent; these keyword arguments override the constructor defaults.
agent_pong_pg = PolicyGradientAgent(
    input_shape=STATE_SHAPE,
    n_actions=N_ACTIONS,
    device=device,
    lr_start=0.5*1e-3,
    lr_decay=1,  # decay factor of 1: the learning rate stays at lr_start
    gamma=0.97
)

In [9]:
# First training run: returns the per-episode total rewards and losses.
rewards, losses = agent_pong_pg.train(env, n_episodes=1500)

Episode 0001 | reward=-20.0 | mean_50=-20.00 | lr=5.00e-04
Episode 0002 | reward=-21.0 | mean_50=-20.50 | lr=5.00e-04
Episode 0002 | reward=-21.0 | mean_50=-20.50 | lr=5.00e-04
Episode 0003 | reward=-21.0 | mean_50=-20.67 | lr=5.00e-04
Episode 0003 | reward=-21.0 | mean_50=-20.67 | lr=5.00e-04
Episode 0004 | reward=-21.0 | mean_50=-20.75 | lr=5.00e-04
Episode 0004 | reward=-21.0 | mean_50=-20.75 | lr=5.00e-04
Episode 0005 | reward=-21.0 | mean_50=-20.80 | lr=5.00e-04
Episode 0005 | reward=-21.0 | mean_50=-20.80 | lr=5.00e-04
Episode 0006 | reward=-21.0 | mean_50=-20.83 | lr=5.00e-04
Episode 0006 | reward=-21.0 | mean_50=-20.83 | lr=5.00e-04
Episode 0007 | reward=-20.0 | mean_50=-20.71 | lr=5.00e-04
Episode 0007 | reward=-20.0 | mean_50=-20.71 | lr=5.00e-04
Episode 0008 | reward=-20.0 | mean_50=-20.62 | lr=5.00e-04
Episode 0008 | reward=-20.0 | mean_50=-20.62 | lr=5.00e-04
Episode 0009 | reward=-21.0 | mean_50=-20.67 | lr=5.00e-04
Episode 0009 | reward=-21.0 | mean_50=-20.67 | lr=5.00e-

KeyboardInterrupt: 

In [None]:
# Continue training the same agent and append the new per-episode histories
# to the ones collected by the first run.

rewards_2, losses_2 = agent_pong_pg.train(env, n_episodes=5000)

# In-place extend: re-running this cell trains again and appends again.
rewards.extend(rewards_2)
losses.extend(losses_2)

In [None]:
# Raw total reward of every training episode.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(rewards)
ax.set_title("Pong Policy Gradient Rewards")
ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
ax.grid(True)

In [None]:
# Raw policy-gradient loss of every training episode.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses)
ax.set_title("Pong Policy Gradient Loss")
ax.set_xlabel("Episode")
ax.set_ylabel("Loss")
ax.grid(True)

In [None]:
# Smoothed reward curve: mean over the trailing `window` episodes.
window = 50
fig, ax = plt.subplots(figsize=(8, 4))
if len(rewards) < window:
    # Not enough episodes for one full window yet; show the raw curve instead.
    ax.plot(rewards)
else:
    # Each point is plotted at the episode number that closes its window.
    window_ends = range(window, len(rewards) + 1)
    rolling_rewards = [np.mean(rewards[end - window:end]) for end in window_ends]
    ax.plot(window_ends, rolling_rewards)
ax.set_title(f"Rolling Mean Reward (window={window})")
ax.set_xlabel("Episode")
ax.set_ylabel("Rolling Reward")
ax.grid(True)

In [None]:
# Smoothed loss curve: mean over the trailing `window` episodes.
window = 50
fig, ax = plt.subplots(figsize=(8, 4))
if len(losses) < window:
    # Not enough episodes for one full window yet; show the raw curve instead.
    ax.plot(losses)
else:
    # Each point is plotted at the episode number that closes its window.
    window_ends = range(window, len(losses) + 1)
    rolling_losses = [np.mean(losses[end - window:end]) for end in window_ends]
    ax.plot(window_ends, rolling_losses)
ax.set_title(f"Rolling Mean Loss (window={window})")
ax.set_xlabel("Episode")
ax.set_ylabel("Rolling Loss")
ax.grid(True)

In [None]:
# Persist only the policy network's weights (its state_dict); the optimizer
# and scheduler state are not saved.
torch.save(agent_pong_pg.policy_net.state_dict(), "pg_agent_pong_model.pth")

In [None]:
# Write the per-episode reward and loss histories to .npy files.
np.save("pg_pong_rewards.npy", np.array(rewards))
np.save("pg_pong_losses.npy", np.array(losses))