Notebook useful to train using GPUs on Kaggle/Colab, contains the same code of the repo, just reformatted for imports between files

In [None]:
# !pip install procgen
# !pip install moviepy

In [3]:
import wandb
from collections import deque
from tqdm import tqdm
import copy

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms

import gym 
import gym.wrappers

import random

In [11]:
global_batch = 0
global_step = 0

def seed_everything(seed):
    """Seed all sources of randomness for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

In [12]:
class TransitionsDataset(Dataset):
    def __init__(self, transitions, transform=None, normalize_v_targets=False, v_mu=None, v_std=None):
        self.transitions = transitions
        self.transform = transform
        
        self.normalize_v_targets = normalize_v_targets
        if normalize_v_targets:
            self.v_mu = v_mu
            self.v_std = v_std

    def __len__(self):
        return len(self.transitions)
    
    def __getitem__(self, idx):
        state_t = self.transitions[idx]['s_t']
        action_t = self.transitions[idx]['a_t']
        advantage_t = self.transitions[idx]['A_t']
        v_target_t = self.transitions[idx]['v_target_t']

        if self.transform:
            state_t = self.transform(state_t)

        if self.normalize_v_targets:
            v_target_t = (v_target_t - self.v_mu) / max(self.v_std, 1e-6)

        return state_t, action_t, advantage_t, v_target_t.astype(np.float32)


In [13]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm):
        super(ConvBlock, self).__init__()

        if batch_norm:
            self.layer = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.ReLU(),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            )
        else:
            self.layer = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            )

    def forward(self, x):
        return self.layer(x)

class ImpalaNetwork(torch.nn.Module):
    def __init__(self, in_channels, num_actions, batch_norm):
        super(ImpalaNetwork, self).__init__()
        
        self.num_actions = num_actions

        self.stems = nn.ModuleList()
        self.res_blocks1 = nn.ModuleList()
        self.res_blocks2 = nn.ModuleList()

        hidden_channels = [16, 32, 32]

        for out_channels in hidden_channels:

            # Don't use batch_norm in the first layer as it should go after MaxPool2d, 
            # but it's already present in the successive ConvBlock
            self.stems.append(torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding="same"),
                torch.nn.MaxPool2d(kernel_size=3, stride=2)
            ))

            self.res_blocks1.append(torch.nn.Sequential(
                ConvBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same", batch_norm=batch_norm),
                ConvBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same", batch_norm=batch_norm),
            ))

            self.res_blocks2.append(torch.nn.Sequential(
                ConvBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same", batch_norm=batch_norm),
                ConvBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same", batch_norm=batch_norm),
            ))

            in_channels = out_channels

        self.fc = torch.nn.Linear(32 * 7 * 7, out_features=256)

        self.out = torch.nn.Linear(256, num_actions)

        
        if num_actions > 1:
            # policy network initialization
            nn.init.orthogonal_(self.fc.weight, gain=0.01)
            nn.init.constant_(self.fc.bias, 0)
        else:
            # value network initialization
            nn.init.orthogonal_(self.out.weight, gain=1)
            nn.init.constant_(self.out.bias, 0)



    def forward(self, x):
        for stem, res_block1, res_block2 in zip(self.stems, self.res_blocks1, self.res_blocks2):
            x = stem(x)
            x = res_block1(x) + x
            x = res_block2(x) + x

        x = nn.functional.relu(x)

        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        x = nn.functional.relu(x)

        if self.num_actions > 1:
            logits = self.out(x)
            output = torch.distributions.Categorical(logits=logits)
        else:
            output = self.out(x).squeeze()

        return output

class PPO:
    def __init__(self, env, config):
        self.policy_net = ImpalaNetwork(config.stack_size * 3, env.action_space.n, config.batch_norm)
        self.value_net = ImpalaNetwork(config.stack_size * 3, 1, config.batch_norm)

        self.normalize_v_targets = config.normalize_v_targets

        if self.normalize_v_targets:
            self.value_mean = 0
            self.value_std = 1
            self.values_count = 0

    def act(self, state):
        dist, value = self.actions_dist_and_v(state)
        action = dist.sample()

        return action.item(), value.item()
    
    def actions_dist_and_v(self, state):
        dist = self.policy_net(state)
        value = self.value_net(state)

        if self.normalize_v_targets:
            # denormalize value
            value = value * max(self.value_std, 1e-6) + self.value_mean

        return dist, value
      
    def to(self, device):
        self.policy_net.to(device)
        self.value_net.to(device)

    def eval(self):
        self.policy_net.eval()
        self.value_net.eval()

    def train(self):
        self.policy_net.train()
        self.value_net.train()

    def update_v_target_stats(self, v_targets):
        self.value_mean = (self.value_mean * self.values_count + v_targets.mean() * len(v_targets)) / (self.values_count + len(v_targets) + 1e-6)
        self.value_std = (self.value_std * self.values_count + v_targets.std() * len(v_targets)) / (self.values_count + len(v_targets) + 1e-6)
        self.values_count += len(v_targets)

In [14]:
class RecorderWrapper(gym.Wrapper):
    def __init__(self, env, episode_frequency_rec):
        super().__init__(env)
        self.env = env
        self.episode_frequency_rec = episode_frequency_rec

        self.episode_counter = 1
        self.recording = False if episode_frequency_rec > 1 else True
        self.frames = []

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        if self.recording:
            self.frames.append(np.moveaxis(obs, -1, 0))
            
            if terminated:
                self.save_video()
                self.recording = False
                self.frames = []
        
        if terminated:
            self.episode_counter += 1
            if self.episode_counter % self.episode_frequency_rec == 0:
                self.recording = True
            

        return obs, reward, terminated, truncated, info
    
    def save_video(self):
        global global_step
        wandb.log({"video": wandb.Video(np.array(self.frames), caption=f"step: {global_step} - episode: {self.episode_counter}", fps=30, format="mp4")})

    def close(self):
        super().close()

In [15]:
def train(policy, policy_old, train_dataloader, optimizer_policy, optimizer_value, device, config, scheduler_policy=None, scheduler_value=None):

    global global_batch

    policy.train()
    policy_old.eval()
    assert policy_old.policy_net.training == False and policy_old.value_net.training == False, "Old policy should be in evaluation mode here"
    assert policy.policy_net.training == True and policy.value_net.training == True, "Policy should be in training mode here"
    for epoch in tqdm(range(config.epochs)):
        for batch, (states, actions, advantages, value_targets) in enumerate(train_dataloader):
            # normalize advantages between 0 and 1
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            states = states.to(device)
            actions = actions.to(device)
            advantages = advantages.to(device)
            value_targets = value_targets.to(device)
            
            dists, values = policy.actions_dist_and_v(states)
            old_dists, _ = policy_old.actions_dist_and_v(states)

            log_probs = dists.log_prob(actions)
            old_log_probs = old_dists.log_prob(actions)

            # Equivalent of doing exp(log_probs) / exp(old_log_probs) 
            # but avoids overflows and division by (potentially if underflown) zero, breaking loss function
            ratios = torch.exp(log_probs - old_log_probs)

            # clipped surrogate loss
            l_clips = -torch.min(ratios * advantages, torch.clip(ratios, 1-config.eps_clip, 1+config.eps_clip) * advantages)
            loss_pi = torch.mean(l_clips)
            loss_entropy = dists.entropy().mean()
            loss_policy = loss_pi - config.entropy_bonus * loss_entropy

            # mse loss
            loss_value = torch.nn.functional.mse_loss(values, value_targets)

            # with two different optimizers
            loss_policy.backward()
            optimizer_policy.step()
            optimizer_policy.zero_grad()

            loss_value.backward()
            optimizer_value.step()
            optimizer_value.zero_grad()

            if global_batch % config.log_frequency == 0:
                wandb.log({"train/loss_pi": loss_pi, 
                           "train/loss_v": loss_value,
                           "train/entropy": loss_entropy,
                           "train/lr_policy": optimizer_policy.param_groups[0]['lr'],
                           "train/lr_value": optimizer_value.param_groups[0]['lr'],
                           "train/batch": global_batch})
            
            global_batch += 1
        
        if scheduler_policy is not None:
            scheduler_policy.step()
        if scheduler_value is not None:
            scheduler_value.step()

        with torch.no_grad():
            # KL divergence between old and new policy for early stopping
            kl_div = torch.distributions.kl.kl_divergence(dists, old_dists).mean().item()
            wandb.log({"train/kl_div": kl_div, "train/batch": global_batch})
            if kl_div > config.kl_limit:
                print(f"Early stopping at epoch {epoch} due to KL divergence {round(kl_div, 4)} > {config.kl_limit}")
                break

In [16]:

frame_to_tensor = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

def compute_advantages(values, rewards, gamma, lambda_):
    # GAE estimator
    deltas = np.array(rewards) + gamma * np.array(values[1:]) - np.array(values[:-1])
    advantages = [deltas[-1]] 

    for t in range(len(deltas)-2, -1, -1):
        advantage_t = deltas[t] + gamma * lambda_ * advantages[-1]
        advantages.append(advantage_t)

    advantages = advantages[::-1]
    return advantages

def compute_value_targets(advantages, values, rewards, config):
    value_targets = []
    if config.v_target == "TD-lambda":
        for t in range(len(advantages)):
            value_targets.append(advantages[t] + values[t])
    elif config.v_target == "MC":
        value_targets.append(rewards[-1])
        for t in range(len(rewards)-2, -1, -1):
            value_targets.append(rewards[t] + config.gamma * value_targets[-1])
        value_targets = value_targets[::-1]
    else:
        raise ValueError(f"Unknown value target type {config.v_target}, choose between 'TD-lambda' and 'MC'.")
    return value_targets


def play_and_train(env, policy, policy_old, optimizer_policy, optimizer_value, device, config, **kwargs):  

    global global_step
    # obs, _ = env.reset()   # TODO: aggiungi solo se metti step sotto invece di reset

    for iteration in range(config.num_iterations):
        print(f"===============Iteration {iteration+1}===============")

        transitions = []

        obs, _ = env.reset()    # TODO: prova a mettere env.step qui

        # stack frames together to introduce temporal information
        state_deque = deque()
        for _ in range(config.stack_size):
            state_deque.append(frame_to_tensor(obs))

        state = torch.concatenate(list(state_deque), axis=0)

        trajectory = {
            'states': [state],
            'actions': [],
            'rewards': [],
            'values': [],
        }

        policy.eval()

        for step in tqdm(range(config.iteration_timesteps)):
            assert not policy.policy_net.training and not policy.value_net.training, "Policy should be in evaluation mode here"

            state = state.unsqueeze(0).to(device)
            action, value = policy.act(state)

            next_obs, reward, terminated, truncated, info = env.step(action)
            truncated = truncated or step == config.iteration_timesteps - 1

            # update step count
            global_step += 1

            # collect transition info in trajectory
            trajectory['values'].append(value)

            trajectory['actions'].append(action)
            trajectory['rewards'].append(reward)

            # udpate state to become next state using the new observation
            state_deque.popleft()
            state_deque.append(frame_to_tensor(next_obs))
            state = torch.concatenate(list(state_deque), axis=0)

            trajectory['states'].append(state)


            if terminated or truncated:
                # see terminated vs truncated API at https://farama.org/Gymnasium-Terminated-Truncated-Step-API
                if terminated:
                    # final value is 0 if the episode terminated, i.e. reached a final state
                    trajectory['values'].append(0)
                else:
                    # bootstrap if the episode was truncated, i.e. didn't reach a final state
                    state = state.unsqueeze(0).to(device)
                    _, value = policy.act(state)
                    trajectory['values'].append(value)
                
                assert len(trajectory['states']) >= 2, "Trajectory must have at least 2 states to compute advantages."
                assert len(trajectory['states']) == len(trajectory['actions']) + 1 , "Trajectory must have one more state than actions."
                advantages = compute_advantages(trajectory['values'], trajectory['rewards'], config.gamma, config.lambda_)

                value_targets = compute_value_targets(advantages, trajectory['values'], trajectory['rewards'], config)

                if config.normalize_v_targets:
                    policy.update_v_target_stats(np.array(value_targets))


                # convert trajectory into list of transitions
                for t in range(len(trajectory['states'])-1):    # -1 because advantages already encode the value of state t+1
                    transitions.append({
                        's_t': trajectory['states'][t],
                        'a_t': trajectory['actions'][t],
                        'A_t': advantages[t],
                        'v_target_t': value_targets[t],
                    })

                # log and update episodes count only if episode terminated
                if terminated:
                    wandb.log({"play/episodic_reward": sum(trajectory['rewards']), 
                            "play/episode_length": len(trajectory['states'])-1,
                            "play/step": global_step})
                
                if step < config.iteration_timesteps - 1:
                    # reset env and trajectory
                    obs, _ = env.reset()

                    state_deque = deque()
                    for _ in range(config.stack_size):
                        state_deque.append(frame_to_tensor(obs))

                    state = torch.concatenate(list(state_deque), axis=0)

                    trajectory = {
                        'states': [state],
                        'actions': [],
                        'rewards': [],
                        'values': [],
                    }


        # end of play loop
        if config.normalize_v_targets:
            dataset = TransitionsDataset(transitions, normalize_v_targets=True, v_mu=policy.value_mean, v_std=policy.value_std)
        else:
            dataset = TransitionsDataset(transitions)
        train_dataloader = DataLoader(dataset, 
                                    batch_size=config.batch_size, 
                                    shuffle=True)

        print(f"Collected {len(transitions)} transitions, starting training...")

        # update policy
        train(policy, policy_old, train_dataloader, optimizer_policy, optimizer_value, device, config, **kwargs)
        print("Training done!")

        del policy_old
        policy_old = copy.deepcopy(policy)
        policy_old.to(device)


def test(env, policy, device, config):
    obs, _ = env.reset()

    # stack frames together to introduce temporal information
    state_deque = deque()
    for _ in range(config.stack_size):
        state_deque.append(frame_to_tensor(obs))
    state = torch.concatenate(list(state_deque), axis=0)

    policy.eval()
    assert not policy.policy_net.training and not policy.value_net.training, "Policy should be in evaluation mode here"
    
    episode_steps = 0
    cum_reward = 0

    for step in tqdm(range(config.tot_timesteps)):

        state = state.unsqueeze(0).to(device)
        action, _ = policy.act(state)

        next_obs, reward, terminated, truncated, info = env.step(action)

        episode_steps += 1
        cum_reward += reward

        # udpate state to become next state using the new observation
        state_deque.popleft()
        state_deque.append(frame_to_tensor(next_obs))
        state = torch.concatenate(list(state_deque), axis=0)

        if terminated or truncated:
            wandb.log({"test/episodic_reward": cum_reward, 
                    "test/episode_length": episode_steps,
                    "test/step": step})
            
            episode_steps = 0
            cum_reward = 0
                
            if step < config.tot_timesteps - 1:
                # reset env and initial obs
                obs, _ = env.reset()

                state_deque = deque()
                for _ in range(config.stack_size):
                    state_deque.append(frame_to_tensor(obs))
                state = torch.concatenate(list(state_deque), axis=0)

In [17]:

### CONFIGURATION ###
TOT_TIMESTEPS = int(2**18)  #int(2**20)  # approx 1M
ITER_TIMESTEPS = 1024
NUM_ITERATIONS = TOT_TIMESTEPS // ITER_TIMESTEPS
CONFIG = {
    # Game
    "game": "coinrun",
    "num_levels": 200,
    "seed": 6,
    "difficulty": "easy",
    "backgrounds": False,
    "stack_size": 4,

    # Timesteps and iterations
    "tot_timesteps": TOT_TIMESTEPS,
    "iteration_timesteps": ITER_TIMESTEPS,
    "num_iterations": NUM_ITERATIONS,

    # Network architecture
    "batch_norm": True,

    # Training params
    "epochs": 3,
    "batch_size": 64,
    "lr_policy_network": 5e-4,
    "lr_value_network": 5e-4,
    "kl_limit": 0.015,

    # PPO params
    "gamma": 0.999,
    "lambda_": 0.95,
    "eps_clip": 0.2,
    "entropy_bonus": 0.01,
    "v_target": "TD-lambda",  # "TD-lambda" (for advantage + value) or "MC" (for cumulative reward)
    "normalize_v_targets": True,

    # Logging
    "log_frequency": 5,
    "log_video": True,
    "episode_video_frequency": 5,
}


### WANDB ###
wandb.login()
wandb.init(project="ppo-procgen", name=f"{CONFIG['game']}_{CONFIG['num_levels']}_{CONFIG['difficulty']}", config=CONFIG)
config = wandb.config

wandb.define_metric("play/step")
wandb.define_metric("train/batch")
wandb.define_metric("test/step")

wandb.define_metric("play/episodic_reward", step_metric="play/step")
wandb.define_metric("play/episode_length", step_metric="play/step")
wandb.define_metric("train/loss_pi", step_metric="train/batch")
wandb.define_metric("train/loss_v", step_metric="train/batch")
wandb.define_metric("train/entropy", step_metric="train/batch")
wandb.define_metric("train/lr_policy", step_metric="train/batch")
wandb.define_metric("train/lr_value", step_metric="train/batch")
wandb.define_metric("test/episodic_reward", step_metric="test/step")
wandb.define_metric("test/episode_length", step_metric="test/step")


### PLAY AND TRAIN PHASE ###
env = gym.make(
    f"procgen:procgen-{config.game}-v0",
    num_levels=config.num_levels,
    start_level=config.seed,
    distribution_mode=config.difficulty,
    use_backgrounds=config.backgrounds,
    render_mode='rgb_array',
    apply_api_compatibility=True,
    rand_seed=config.seed
)

if config.log_video:
    env = RecorderWrapper(env, config.episode_video_frequency)

seed_everything(config.seed)

### CREATE PPO AGENTS AND OPTIMIZERS ###
policy = PPO(env, config)
policy_old = copy.deepcopy(policy)

print(f"Model has {sum(p.numel() for p in policy.policy_net.parameters()) + sum(p.numel() for p in policy.value_net.parameters())} total parameters.")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

policy.to(device)
policy_old.to(device)

optimizer_policy = torch.optim.Adam(policy.policy_net.parameters(), lr=config.lr_policy_network)
scheduler_policy = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_policy, T_max=config.num_iterations*config.epochs, eta_min=1e-6)

optimizer_value = torch.optim.Adam(policy.value_net.parameters(), lr=config.lr_value_network)
scheduler_value = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_value, T_max=config.num_iterations*config.epochs, eta_min=1e-6)

play_and_train(env, policy, policy_old, optimizer_policy, optimizer_value, device, config, scheduler_policy=scheduler_policy, scheduler_value=scheduler_value)


### TEST PHASE ###
env_test = gym.make(
    f"procgen:procgen-{config.game}-v0",
    num_levels=0,
    start_level=config.seed,
    distribution_mode=config.difficulty,
    use_backgrounds=config.backgrounds,
    render_mode='rgb_array',
    apply_api_compatibility=True,
    rand_seed=config.seed
)

test(env_test, policy, device, config)


wandb.finish()



Model has 1006512 total parameters.
Using cpu device


  if not isinstance(terminated, (bool, np.bool8)):
100%|██████████| 1024/1024 [00:31<00:00, 32.17it/s]


Collected 1024 transitions, starting training...


 67%|██████▋   | 2/3 [01:25<00:42, 42.99s/it]


Early stopping at epoch 2 due to KL divergence 0.025 > 0.015
Training done!


 22%|██▏       | 224/1024 [00:08<00:31, 25.04it/s]


KeyboardInterrupt: 