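# Create Labeled Dataset

Collect 100,000 labeled Pong frames from `PongDataset` (each a 1×84×84 image paired with a 6-dimensional action-label vector) and save them to `datasets/Xy_train_model.pt` for supervised pre-training.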

In [29]:
from collections import OrderedDict
from pathlib import Path

import gymnasium as gym
import lightning as L
import numpy as np
import torch
from lightning.fabric.loggers import TensorBoardLogger
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

import utils
from dataset import PongDataset, StaticImageDataset
from supervised import PolicyNetwork, SaveBestModel, train
from supervised import PolicyNetwork, SaveBestModel, train

In [1]:
# On-the-fly source of labeled Pong frames; epsilon/sigma are PongDataset settings.
# NOTE(review): nothing is seeded before collection, so the saved dataset is
# presumably not bit-reproducible — confirm how PongDataset draws randomness.
dataset = PongDataset(sigma=0.2, epsilon=0.05)

In [94]:
# Materialise n samples from the on-the-fly PongDataset into preallocated
# tensors: frames are 1x84x84 and labels are 6-dimensional vectors.
n = 100_000  # number of samples to collect
output_samples = torch.zeros((n, 1, 84, 84), dtype=torch.float32)
output_labels = torch.zeros((n, 6), dtype=torch.float32)

for idx in range(n):
    frame, target = dataset[idx]
    output_samples[idx], output_labels[idx] = frame, target

output_samples.shape, output_labels.shape

(torch.Size([100000, 1, 84, 84]), torch.Size([100000, 6]))

In [8]:
# Persist the collected dataset for the supervised pre-training section.
# torch.save does not create missing parent directories, so make sure
# datasets/ exists first (otherwise this cell fails on a fresh checkout).
Path("datasets").mkdir(parents=True, exist_ok=True)
torch.save({"X": output_samples, "y": output_labels}, "datasets/Xy_train_model.pt")

# Supervised Pre-Training

In [8]:
# Settings
# --- optimisation ---
num_epochs = 10
batch_size = 512
lr = 1e-3            # peak LR passed to OneCycleLR as max_lr
pct_start = 0.2      # fraction of the one-cycle schedule spent increasing the LR
weight_decay = 1e-4  # AdamW weight decay
label_smoothing = 0.0

# --- model / data ---
framestack = 1       # channels per observation (single grayscale frame)
features_dim = 512   # CNN feature size; must match PPO's features_extractor_kwargs
num_classes = 6      # width of the label vectors / action head

# --- runtime ---
accelerator = "cuda"
precision = "32-true"
seed = 0

# --- logging / checkpoints ---
log_interval = 20
log_dir = "logs"
model_dir = "models"
name = "supervised_pretraining"

In [10]:
# TensorBoard logger rooted at log_dir, with this experiment's name.
tb_logger = TensorBoardLogger(name=name, root_dir=log_dir)

# Custom callback handed to Fabric; presumably writes checkpoints under
# model_dir (the checkpoint loaded later lives in "models/").
save_best_model = SaveBestModel(model_dir)

# Fabric instance that owns device placement, precision, callbacks and loggers.
fabric = L.Fabric(
    loggers=[tb_logger],
    callbacks=[save_best_model],
    precision=precision,
    accelerator=accelerator,
)

# Seed the global RNGs before the model is created so init/shuffling repeat.
fabric.seed_everything(seed)

Seed set to 0


0

In [11]:
# Create model and optimizer.
# Observations are framestack x 84 x 84 images already scaled to [0, 1]
# (hence normalized_image=True).
obs_space = gym.spaces.Box(low=0, high=1, shape=(framestack, 84, 84), dtype=np.float32)
with fabric.init_module():
    model = PolicyNetwork(obs_space, features_dim=features_dim, normalized_image=True, out_classes=num_classes)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Load the tensors saved by the dataset-creation section. The file holds only
# tensors in a plain dict, so weights_only=True is sufficient and avoids
# unpickling arbitrary objects.
loaded = torch.load("datasets/Xy_train_model.pt", weights_only=True)
images_tensor, labels_tensor = loaded["X"], loaded["y"]

# Use a distinct name so the PongDataset bound to `dataset` earlier is not
# silently replaced (re-running the collection cell would otherwise iterate
# this static dataset instead).
train_dataset = StaticImageDataset(images_tensor, labels_tensor)

# Let Fabric prepare the loader BEFORE sizing the scheduler: under a
# distributed strategy Fabric injects a DistributedSampler, which shrinks
# len(dataloader) per process. Sizing OneCycleLR from the unprepared loader
# would then leave the schedule unfinished at the end of training.
dataloader = fabric.setup_dataloaders(DataLoader(train_dataset, batch_size=batch_size, shuffle=True))

# NOTE(review): assumes train() steps the scheduler once per batch — confirm.
scheduler = OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=len(dataloader),
    epochs=num_epochs,
    pct_start=pct_start,
    div_factor=25,
    final_div_factor=1,
)

torch.set_float32_matmul_precision("high")

# Wrap model/optimizer for the configured accelerator and precision.
model, optimizer = fabric.setup(model, optimizer)

In [12]:
# Run the supervised training loop (supervised.train) with the objects
# prepared above; arguments are passed positionally in train's order.
train(
    fabric,
    model,
    optimizer,
    scheduler,
    dataloader,
    num_epochs,
    log_interval,
    num_classes,
    label_smoothing,
)

Epoch 1/10 completed.
Epoch 2/10 completed.
Epoch 3/10 completed.
Epoch 4/10 completed.
Epoch 5/10 completed.
Epoch 6/10 completed.
Epoch 7/10 completed.
Epoch 8/10 completed.
Epoch 9/10 completed.
Epoch 10/10 completed.


# Initialize RL Agent with Pre-Trained Weights

In [2]:
# 16 parallel Pong environments wrapped by SB3's Atari preprocessing, seeded for repeatability.
vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=16, seed=1)

In [4]:
# PPO agent whose CNN feature extractor (features_dim=512) matches the
# supervised PolicyNetwork, so the pre-trained weights can be loaded below.
model = PPO(
    "CnnPolicy",
    vec_env,
    # Values that differ from SB3's PPO defaults
    learning_rate=2.5e-4,
    n_steps=128,
    batch_size=256,
    n_epochs=3,
    ent_coef=0.01,
    # Values equal to SB3's PPO defaults, spelled out for clarity
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    clip_range_vf=None,
    normalize_advantage=True,
    vf_coef=0.5,
    max_grad_norm=0.5,
    # Logging, architecture and runtime
    tensorboard_log="sl/logs/",
    policy_kwargs={"features_extractor_kwargs": {"features_dim": 512}},
    device="cuda",
    seed=1,
)

utils.print_model_parameters(model, shared_extractor=True)

features_extractor: 1,677,984
pi_features_extractor: 1,677,984
vf_features_extractor: 1,677,984
mlp_extractor: 0
action_net: 3,078
value_net: 513
Total number of parameters: 1,681,575


In [5]:
# Load the supervised checkpoint produced during pre-training.
# NOTE(review): the filename bakes in the step/loss of one particular run and
# must be updated whenever pre-training is re-run.
fabric = L.Fabric(precision="32-true", accelerator="cuda")
full_checkpoint = fabric.load(
    "models/final_checkpoint_step=3920_loss=0.7992.ckpt"
)

In [7]:
# Split the supervised state dict into the two pieces PPO's policy exposes:
#   action_net.*    -> policy.action_net (prefix stripped, e.g. "weight"/"bias")
#   everything else -> the CNN feature extractor (keys already match)
# Matching on the prefix (instead of exact names) also routes any extra
# action_net parameters correctly, and only the leading prefix is removed.
ACTION_NET_PREFIX = "action_net."
feature_extractor_params = OrderedDict()  # shared feature extractor
action_net_params = OrderedDict()  # action network

for key, value in full_checkpoint["model"].items():
    if key.startswith(ACTION_NET_PREFIX):
        action_net_params[key[len(ACTION_NET_PREFIX):]] = value
    else:
        feature_extractor_params[key] = value

feature_extractor_params.keys(), action_net_params.keys()

(odict_keys(['cnn.0.weight', 'cnn.0.bias', 'cnn.2.weight', 'cnn.2.bias', 'cnn.4.weight', 'cnn.4.bias', 'linear.0.weight', 'linear.0.bias']),
 odict_keys(['weight', 'bias']))

In [8]:
# Copy the pre-trained CNN weights into PPO's policy feature extractor
# (strict=True is the default and is spelled out: every key must match).
# NOTE(review): print_model_parameters(..., shared_extractor=True) indicates the
# pi/vf extractors are shared, so this presumably initialises the value branch too.
model.policy.pi_features_extractor.load_state_dict(feature_extractor_params, strict=True)

<All keys matched successfully>

In [9]:
# Copy the pre-trained classifier head into PPO's action logits layer
# (keys were stripped to "weight"/"bias" in the renaming cell).
model.policy.action_net.load_state_dict(action_net_params, strict=True)

<All keys matched successfully>

In [10]:
# Persist the PPO agent initialised with supervised weights (SB3 adds ".zip").
model.save(path="ppo_nature_cnn_pretrained")

# RL Finetuning

In [2]:
# Fresh kernel: recreate the 16 seeded Pong environments used for finetuning.
vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=16, seed=1)

In [3]:
# Reload the pre-trained agent and attach the env so it can keep training.
model = PPO.load(path="ppo_nature_cnn_pretrained", env=vec_env, device="cuda")

In [6]:
# Baseline score of the supervised initialisation before any RL updates
# (stochastic actions, 16 episodes).
mean_reward, std_reward = utils.evaluate(
    model,
    vec_env,
    episodes=16,
    deterministic=False,
)

Mean reward: 11.31 +/- 4.87


In [7]:
# RL finetuning for 1.5M environment steps, starting a fresh timestep counter.
model.learn(
    total_timesteps=1_500_000,
    reset_num_timesteps=True,
    tb_log_name="ppo_nature_cnn_pretrained",
)

<stable_baselines3.ppo.ppo.PPO at 0x1d62e83ffa0>

In [9]:
# Persist the 1.5M-step finetuned agent.
model.save(path="ppo_nature_cnn_finetuned")

In [None]:
# Load the finetuned model and re-attach the env: the "Additional finetuning"
# cell below calls model.learn(...), which needs an env. PPO.load without
# env= leaves it unset, so a Restart & Run All would fail there.
# device="cuda" matches the earlier PPO.load of the pre-trained agent.
model = PPO.load("ppo_nature_cnn_finetuned", env=vec_env, device="cuda")

In [11]:
# Score after 1.5M finetuning steps (stochastic actions, 16 episodes).
mean_reward, std_reward = utils.evaluate(
    model,
    vec_env,
    episodes=16,
    deterministic=False,
)

Mean reward: 19.88 +/- 1.62


## Additional Finetuning

Continue training the 1.5M-step finetuned agent for another 500k steps (2M in total).

In [15]:
# Continue finetuning for another 500k steps. With reset_num_timesteps=False
# SB3 keeps the existing timestep counter, so training resumes from the current
# step count and logging continues under the same run name.
model.learn(
    total_timesteps=500_000,
    reset_num_timesteps=False,
    tb_log_name="ppo_nature_cnn_pretrained",
)

<stable_baselines3.ppo.ppo.PPO at 0x1d62e83ffa0>

In [18]:
# Persist the agent after 2M total finetuning steps.
model.save(path="ppo_nature_cnn_finetuned_2M")

In [17]:
# Score after 2M finetuning steps (stochastic actions, 16 episodes).
mean_reward, std_reward = utils.evaluate(
    model,
    vec_env,
    episodes=16,
    deterministic=False,
)

Mean reward: 20.62 +/- 0.48


# RL Training from Scratch

In [2]:
# Fresh kernel: same 16 seeded Pong environments, now for the from-scratch baseline.
vec_env = make_atari_env("PongNoFrameskip-v4", n_envs=16, seed=1)

In [5]:
# Baseline: PPO from random initialisation with the same hyperparameters as
# the agent that received the supervised weights.
model = PPO(
    "CnnPolicy",
    vec_env,
    learning_rate=2.5e-4,
    n_steps=128,
    batch_size=256,
    n_epochs=3,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    clip_range_vf=None,
    normalize_advantage=True,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    # Relative to the working directory, like the other log paths in this
    # notebook; a leading "/" would write to the filesystem root.
    tensorboard_log="logs/ppo/",
    policy_kwargs=dict(features_extractor_kwargs={"features_dim": 512}),
    device="cuda",
    seed=1,
)

utils.print_model_parameters(model, shared_extractor=True)

features_extractor: 1,677,984
pi_features_extractor: 1,677,984
vf_features_extractor: 1,677,984
mlp_extractor: 0
action_net: 3,078
value_net: 513
Total number of parameters: 1,681,575


In [6]:
# Train the randomly initialised baseline for 5M environment steps.
model.learn(
    total_timesteps=5_000_000,
    reset_num_timesteps=True,
    tb_log_name="ppo_nature_cnn_rl",
)

<stable_baselines3.ppo.ppo.PPO at 0x22e17896ad0>

In [7]:
# Persist the from-scratch baseline.
# NOTE(review): "/models/..." is an absolute path at the filesystem root; the
# load cell that follows uses the same path, so change both together.
model.save(path="/models/ppo/ppo_nature_cnn_rl")

In [8]:
# Reload the baseline without an env; the evaluation cell that follows passes
# vec_env explicitly, so no env needs to be attached here.
model = PPO.load(path="/models/ppo/ppo_nature_cnn_rl")

In [10]:
# Score of the 5M-step from-scratch baseline (stochastic actions, 16 episodes).
mean_reward, std_reward = utils.evaluate(
    model,
    vec_env,
    episodes=16,
    deterministic=False,
)

Mean reward: 20.50 +/- 0.50
