# PPO for Pong (PPO-PONG.ipynb)


In [9]:
# Imports and utilities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
# import gym
import gymnasium as gymn
from collections import deque, namedtuple
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import time
import random
from typing import Tuple
import ale_py

# Compute device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)


Device: cuda


In [10]:
# Config (change as needed)
env_id = 'ALE/Pong-v5'  # change if necessary (e.g., 'PongNoFrameskip-v4')
seed = 0

# Seed every random number generator the notebook uses (torch, numpy, stdlib).
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# PPO hyperparameters
cfg = dict(
    lr=2.5e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.2,
    update_epochs=4,
    minibatch_size=64,
    rollout_steps=2048,     # number of env steps per update
    max_frames=5_000_000,   # cap training frames (adjust as needed)
    eval_interval=100_000,
    num_eval_episodes=10,
)

# environment preprocessing parameters
input_shape = (4, 84, 84)  # frame stack, HxW
frame_skip = 4


In [11]:
# Simple preprocessing wrappers to mimic Atari pipelines
import cv2

class PreprocessFrame:
    """Callable that turns an RGB frame into a resized grayscale uint8 image.

    Args:
        shape: Target size as (height, width), i.e. the ``.shape`` of the
            array returned by ``__call__``. Defaults to (84, 84).
    """

    def __init__(self, shape=(84,84)):
        self.shape = shape

    def __call__(self, frame):
        """Return ``frame`` (HxWxC uint8, RGB) as a grayscale uint8 array."""
        height, width = self.shape
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # cv2.resize expects dsize as (width, height), the reverse of the
        # numpy (rows, cols) convention used by self.shape. Passing the shape
        # through unchanged would transpose any non-square target size.
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
        return img.astype(np.uint8)

class FrameStack:
    """Rolling window over the ``k`` most recent frames.

    Both ``reset`` and ``step`` return the window stacked along a new leading
    axis, oldest frame first.
    """

    def __init__(self, k):
        self.k = k
        # A bounded deque drops the oldest frame automatically on append.
        self.frames = deque(maxlen=k)

    def reset(self, obs):
        """Fill the whole window with ``obs`` and return the stacked frames."""
        self.frames.extend([obs] * self.k)
        return self._stacked()

    def step(self, obs):
        """Push ``obs`` as the newest frame and return the stacked frames."""
        self.frames.append(obs)
        return self._stacked()

    def _stacked(self):
        # Stack the current window into one array along axis 0.
        return np.stack(self.frames, axis=0)

# Helper to create env with preprocessing
def make_env(env_id, seed=0):
    """Build the environment and its frame preprocessor.

    The environment is created through ``gymn.make`` with rendering disabled
    and reset once with ``seed``.

    Returns:
        A ``(env, preproc)`` pair where ``preproc`` is a
        ``PreprocessFrame((84, 84))`` instance.
    """
    environment = gymn.make(env_id, render_mode=None)
    # The initial observation is discarded; this reset only applies the seed.
    environment.reset(seed=seed)
    return environment, PreprocessFrame((84, 84))

# Note: for speed and correctness you might prefer the standard Gymnasium wrappers: gymnasium.wrappers.AtariPreprocessing plus FrameStackObservation (named FrameStack before Gymnasium 1.0)


In [12]:
# Actor-Critic network with CNN backbone (compatible with DQN convs)
class ActorCritic(nn.Module):
    """Shared-backbone actor-critic network for stacked image observations.

    Three DQN-style convolutions feed one hidden linear layer, followed by a
    policy head (action logits) and a value head (scalar state value).

    Args:
        input_channels: Number of stacked frames (channels) per observation.
        n_actions: Size of the discrete action space.
        hidden_dim: Width of the shared fully connected layer.
        input_size: Spatial (height, width) of the observations. It is only
            used to size the first linear layer.
    """

    def __init__(self, input_channels, n_actions, hidden_dim=512, input_size=(84, 84)):
        super().__init__()
        # convs similar to DQN-style
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Push a dummy batch through the convs to find the flattened size.
        with torch.no_grad():
            probe = torch.zeros(1, input_channels, *input_size)
            conv_out_size = int(self._features(probe).numel())

        self.fc = nn.Linear(conv_out_size, hidden_dim)
        # actor & critic heads
        self.policy = nn.Linear(hidden_dim, n_actions)
        self.value = nn.Linear(hidden_dim, 1)

    def _features(self, x):
        # Convolutional trunk shared by __init__ (size probe) and forward.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return F.relu(self.conv3(x))

    def forward(self, x):
        """Return ``(logits, value)`` for a float batch ``x`` of shape (B, C, H, W).

        Inputs are expected to be normalized to [0, 1]. ``logits`` has shape
        (B, n_actions) and ``value`` has shape (B, 1).
        """
        x = self._features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        return self.policy(x), self.value(x)

    @torch.no_grad()
    def act(self, obs):
        """Sample one action for a single raw observation.

        Runs without gradient tracking, so no autograd graph is built during
        environment interaction.

        Args:
            obs: Array of shape (C, H, W) with pixel values in [0, 255].

        Returns:
            ``(action, log_prob, value, probs)``: the sampled action as an
            int, its log-probability and the state value as floats, and the
            action probabilities as a numpy array of shape (1, n_actions).
        """
        # Use the device the parameters live on rather than a module-level
        # global, so the network works wherever it has been moved.
        param_device = next(self.parameters()).device
        obs_t = torch.tensor(obs, dtype=torch.float32, device=param_device).unsqueeze(0) / 255.0
        logits, value = self(obs_t)
        # Building the distribution from logits avoids the softmax -> log
        # round trip, which clamps very small probabilities.
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action).item(), value.item(), dist.probs.cpu().numpy()

    def get_values(self, obs_batch):
        """Return the action distribution and values for a batch.

        Args:
            obs_batch: Float tensor (B, C, H, W) normalized to [0, 1].

        Returns:
            ``(dist, values)``: a ``Categorical`` over actions and a tensor
            of shape (B,) holding the state values.
        """
        logits, values = self(obs_batch)
        dist = torch.distributions.Categorical(logits=logits)
        return dist, values.squeeze(-1)


In [13]:
# PPO Agent
class PPOAgent:
    """PPO agent: collects rollouts from one environment and updates the net.

    Args:
        env: Environment following the Gymnasium ``reset``/``step`` API.
        preproc: Callable applied to every raw observation before stacking.
        n_actions: Size of the discrete action space.
        input_channels: Number of frames stacked into one state.
        cfg: Hyperparameter dict. Required keys: 'lr', 'gamma', 'gae_lambda',
            'clip_eps', 'update_epochs', 'minibatch_size', 'rollout_steps'.
            Optional keys: 'vf_coef' (0.5), 'ent_coef' (0.01) and
            'max_grad_norm' (0.5). When omitted, the module-level ``cfg`` is
            looked up at construction time.
    """

    def __init__(self, env, preproc, n_actions, input_channels=4, cfg=None):
        if cfg is None:
            # Resolve the default when the agent is built, not when the class
            # is defined, so a re-run config cell is picked up.
            cfg = globals()["cfg"]
        self.env = env
        self.preproc = preproc
        self.n_actions = n_actions
        self.input_channels = input_channels
        self.cfg = cfg
        self.net = ActorCritic(input_channels, n_actions).to(device)
        # Device the parameters actually live on; used for all batch tensors.
        self.device = next(self.net.parameters()).device
        self.optimizer = optim.Adam(self.net.parameters(), lr=cfg['lr'])
        self.gamma = cfg['gamma']
        self.gae_lambda = cfg['gae_lambda']
        self.clip_eps = cfg['clip_eps']
        self.update_epochs = cfg['update_epochs']
        self.minibatch_size = cfg['minibatch_size']
        self.rollout_steps = cfg['rollout_steps']
        # Loss weights and gradient clip; defaults are the previously
        # hard-coded constants.
        self.vf_coef = cfg.get('vf_coef', 0.5)
        self.ent_coef = cfg.get('ent_coef', 0.01)
        self.max_grad_norm = cfg.get('max_grad_norm', 0.5)
        # Value of the state following the most recent rollout (see update()).
        self._last_value = 0.0

    def _state_value(self, state):
        # Critic estimate for a single stacked state (C, H, W), no gradients.
        obs_t = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0) / 255.0
        with torch.no_grad():
            _, value = self.net(obs_t)
        return value.item()

    def collect_rollout(self):
        """Run the policy for ``rollout_steps`` environment steps.

        The environment is reset at the start of every call, so an episode
        still in progress when the previous rollout ended is abandoned.
        NOTE(review): episodes longer than ``rollout_steps`` are therefore
        never played to completion.

        Returns:
            ``(rollout, ep_rewards)``: ``rollout`` is a list of
            ``(state, action, log_prob, reward, value, done)`` tuples and
            ``ep_rewards`` holds the total reward of every episode that
            finished during the rollout.
        """
        frames = FrameStack(self.input_channels)
        state = frames.reset(self.preproc(self.env.reset()[0]))
        rollout = []
        ep_rewards = []
        ep_reward = 0
        done = False
        # Acting needs no gradients; skip autograd bookkeeping per step.
        with torch.no_grad():
            while len(rollout) < self.rollout_steps:
                action, logp, value, _ = self.net.act(state)
                next_obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                next_state = frames.step(self.preproc(next_obs))
                rollout.append((state, action, logp, reward, value, done))
                state = next_state
                ep_reward += reward
                if done:
                    ep_rewards.append(ep_reward)
                    ep_reward = 0
                    state = frames.reset(self.preproc(self.env.reset()[0]))
        # If the rollout stopped mid-episode, keep the critic's estimate of
        # the state that follows so GAE can bootstrap from it instead of
        # assuming a future return of zero.
        if rollout and not done:
            self._last_value = self._state_value(state)
        else:
            self._last_value = 0.0
        return rollout, ep_rewards

    def compute_gae(self, rollout, last_value=0.0):
        """Compute GAE advantages and returns for ``rollout``.

        Args:
            rollout: List of ``(state, action, log_prob, reward, value, done)``.
            last_value: Value estimate of the state following the final
                transition. It is ignored when that transition has
                ``done=True``.

        Returns:
            ``(states, actions, old_logps, returns, advantages)`` as tensors
            on ``self.device``. ``states`` is divided by 255. ``advantages``
            is normalized to zero mean and unit standard deviation when the
            rollout has more than one transition.
        """
        states, actions, old_logps, rewards, values, dones = (
            zip(*rollout) if rollout else ((),) * 6
        )
        n = len(rewards)
        # Filling a preallocated array back to front keeps the pass linear;
        # inserting at the head of a list made it quadratic.
        advantages = np.zeros(n, dtype=np.float64)
        gae = 0.0
        next_value = float(last_value)
        for step in reversed(range(n)):
            # mask zeroes everything that would leak across an episode end.
            mask = 1.0 - float(dones[step])
            delta = rewards[step] + self.gamma * next_value * mask - values[step]
            gae = delta + self.gamma * self.gae_lambda * mask * gae
            advantages[step] = gae
            next_value = values[step]
        returns = advantages + np.asarray(values, dtype=np.float64)
        # convert to tensors
        states = torch.tensor(np.array(states), dtype=torch.float32, device=self.device) / 255.0
        actions = torch.tensor(actions, dtype=torch.long, device=self.device)
        old_logps = torch.tensor(old_logps, dtype=torch.float32, device=self.device)
        returns = torch.tensor(returns, dtype=torch.float32, device=self.device)
        advantages = torch.tensor(advantages, dtype=torch.float32, device=self.device)
        # std() of a single element is NaN, so only normalize real batches.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return states, actions, old_logps, returns, advantages

    def update(self, rollout, last_value=None):
        """Run the clipped-surrogate PPO update on ``rollout``.

        Args:
            rollout: Transitions as returned by ``collect_rollout``.
            last_value: Bootstrap value for the state after the rollout. When
                ``None``, the value recorded by the most recent
                ``collect_rollout`` call is used, so ``rollout`` should be
                the one that call returned.
        """
        if last_value is None:
            last_value = self._last_value
        states, actions, old_logps, returns, advantages = self.compute_gae(rollout, last_value)
        n = states.size(0)
        for epoch in range(self.update_epochs):
            # Fresh shuffle every epoch; minibatches partition the rollout.
            perm = torch.randperm(n)
            for i in range(0, n, self.minibatch_size):
                idx = perm[i:i+self.minibatch_size]
                batch_adv = advantages[idx]
                dist, values = self.net.get_values(states[idx])
                new_logps = dist.log_prob(actions[idx])
                entropy = dist.entropy().mean()
                # Probability ratio between the current and behaviour policy.
                ratio = (new_logps - old_logps[idx]).exp()
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * batch_adv
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.mse_loss(values, returns[idx])
                loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
                self.optimizer.step()
        return


In [14]:
# Training loop
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/PPO_Pong')
agent_env, preproc = make_env(env_id, seed=seed)
# Evaluation gets its own environment so eval episodes never reset or step
# the environment the agent is collecting training data from.
eval_env, _ = make_env(env_id, seed=seed + 1)
# get action space
# reset to get initial obs for determining action space
obs0 = agent_env.reset()[0]
# determine discrete action size
n_actions = agent_env.action_space.n
print('Action space:', n_actions)

agent = PPOAgent(agent_env, preproc, n_actions, input_channels=4, cfg=cfg)

frame_count = 0
start_time = time.time()
episode = 0
all_episode_rewards = []
# Frame count at which the next evaluation is due. A plain
# `frame_count % eval_interval == 0` test never fires because frame_count
# grows in rollout-sized jumps that do not land on the interval.
next_eval = cfg['eval_interval']


def _run_eval_episode(agent, env, preproc):
    """Play one full episode with the current policy and return its reward."""
    frames = FrameStack(4)
    state = frames.reset(preproc(env.reset()[0]))
    done = False
    ep_r = 0
    with torch.no_grad():
        while not done:
            action, _, _, _ = agent.net.act(state)
            next_obs, reward, term, trunc, _ = env.step(action)
            done = term or trunc
            state = frames.step(preproc(next_obs))
            ep_r += reward
    return ep_r


try:
    while frame_count < cfg['max_frames']:
        rollout, ep_rewards = agent.collect_rollout()
        frame_count += len(rollout)
        for r in ep_rewards:
            all_episode_rewards.append(r)
            writer.add_scalar('Episode/Reward', r, frame_count)
        # update
        agent.update(rollout)
        # evaluation logging
        if frame_count >= next_eval:
            eval_rewards = [
                _run_eval_episode(agent, eval_env, preproc)
                for _ in range(cfg['num_eval_episodes'])
            ]
            avg_eval = np.mean(eval_rewards)
            writer.add_scalar('Eval/Reward', avg_eval, frame_count)
            print(f'Frames: {frame_count} | AvgEval: {avg_eval:.2f} | Time: {time.time()-start_time:.1f}s')
            # Schedule the next evaluation at the first interval boundary
            # beyond the current frame count.
            next_eval = (frame_count // cfg['eval_interval'] + 1) * cfg['eval_interval']

    print('Training finished. Total frames:', frame_count)
finally:
    # Runs on normal completion and on interruption (e.g. KeyboardInterrupt),
    # so logged scalars are flushed and the eval environment is released.
    writer.close()
    eval_env.close()


Action space: 6


KeyboardInterrupt: 

In [None]:
# Test / Play with trained model (manual run)
# Warning: this cell assumes you'll run after some training or load a saved model checkpoint.

env, preproc = make_env(env_id)
fs = FrameStack(4)

for ep in range(3):
    # Start every episode from a fresh reset. Resetting only once before the
    # loop would make episodes 2 and 3 step an environment that has already
    # terminated.
    obs = env.reset()[0]
    obs_p = preproc(obs)
    state = fs.reset(obs_p)
    done = False
    ep_r = 0
    # Playing the policy needs no gradients.
    with torch.no_grad():
        while not done:
            action, _, _, _ = agent.net.act(state)
            obs, reward, term, trunc, info = env.step(action)
            done = term or trunc
            obs_p = preproc(obs)
            state = fs.step(obs_p)
            ep_r += reward
    print('Eval episode reward:', ep_r)

# Release the emulator once all episodes have been played.
env.close()

# To render, change env creation to include render_mode='human' and call env.render() as needed.


In [None]:
# Persist only the network weights (state_dict) to the working directory.
torch.save(obj=agent.net.state_dict(), f='ppo_pong_net.pth')
print('Saved ppo_pong_net.pth')


In [None]:
# Simple plotting FROM logged arrays (if available)
import matplotlib.pyplot as plt

if len(all_episode_rewards) > 0:
    # Moving average over up to 50 episodes. The window is clamped to the
    # number of episodes because np.convolve(mode='valid') swaps its operands
    # when the kernel is longer than the data, which would plot a meaningless
    # curve for runs with fewer than 50 episodes.
    window = min(50, len(all_episode_rewards))
    smoothed = np.convolve(all_episode_rewards, np.ones(window)/window, mode='valid')
    plt.plot(smoothed)
    plt.title('Smoothed Episode Rewards')
    plt.xlabel('Episodes')
    plt.ylabel('Reward')
    plt.show()
else:
    print('No episode rewards collected yet.')
