# PPO

## Imports and hyperparameters

In [2]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

import random
import math
import time

import matplotlib.pyplot as plt

from collections import namedtuple

import gymnasium as gym

In [27]:
env = gym.make('CartPole-v1')

N_OBSERVATIONS = env.observation_space.shape[0]
N_ACTIONS = env.action_space.n

# Hyperparameters
N_AGENTS = 2
T = 50
GAMMA = 0.999
EPSILON = 0.2
BATCH_SIZE = 10

TEST = True
TRAIN = False

layer_dim = 128

## Model and selecting actions

### Model

I'll use the same parameters for the policy and value networks. If this doesn't work well I can change it later.

In [4]:
class PolicyAndValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(N_OBSERVATIONS, layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, layer_dim),
            nn.ReLU(),
            # +1 comes from value prediction
            nn.Linear(layer_dim, N_ACTIONS + 1)
        )

    def forward(self, x):
        logits = self.network(x)
        return logits

### Select action

Select action based on log probabilities.

In [5]:
def select_action(action_logits):
    action_probs = torch.softmax(action_logits, dim=0)
    action = torch.multinomial(action_probs, num_samples=1).item()
    return action

In [6]:
if TEST:
    test_env = gym.make('CartPole-v1', render_mode='human')
    test_model = PolicyAndValueNetwork()

    for _ in range(1):
        state, info = test_env.reset()
        terminated, truncated = False, False
        while not (terminated or truncated):
            tensor_state = torch.from_numpy(state)
            logits = test_model(tensor_state)
            action_logits = logits[:-1]
            action = select_action(action_logits)
            state, reward, terminated, truncated, info = test_env.step(action)

    test_env.close()

## Data

### Data collector

We first collect the data while marking episode ends when they happen.

In [23]:
class DataCollector():
    def __init__(self):
        self.states = [[] for _ in range(N_AGENTS)]
        self.actions = [[] for _ in range(N_AGENTS)]
        self.rewards = [[] for _ in range(N_AGENTS)]
        self.pis = [[] for _ in range(N_AGENTS)]
        self.episode_ends = [[] for _ in range(N_AGENTS)]

    def add_step(self, agent_id, state, action, reward, pi):
        self.states[agent_id].append(state)
        self.actions[agent_id].append(action)
        self.rewards[agent_id].append(reward)
        self.pis[agent_id].append(pi)

    def mark_episode_end(self, agent_id, timestep):
        self.episode_ends[agent_id].append(timestep)

    def fetch_data(self):
        list_of_agent_states = [torch.stack(agent_states) for agent_states in self.states]
        states = torch.stack([agent_states for agent_states in list_of_agent_states])
        actions = self.actions
        rewards = torch.tensor(self.rewards)
        pis = torch.tensor(self.pis)
        episode_ends = self.episode_ends

        return states, actions, rewards, pis, episode_ends

### Compute advantages and target values

Next we'll need to compute advantages and target values. We'll start from the end of the data for each agent and move backwards. This makes the desired values easier to compute and to take into account episode ends.

Gradients should be on here because I'll use the computations here to optimize the model.

> If the trajectory terminated due to the maximal trajectory length T
being reached, Vωold (st+n ) denotes the state value associated with state st+n as predicted by the state value
network. Otherwise, Vωold (st+n ) is set to 0

Instead of computing r_t + V(s_t+1) I'm using V(s_t). It's similar enough and shouldn't be that much of a problem.

In [20]:
def compute_advantages_and_target_values(model, states, rewards, episode_ends):
    # We'll add them from last to first, then reverse list
    reversed_target_values = [[] for _ in range(N_AGENTS)]
    reversed_advantages = [[] for _ in range(N_AGENTS)]

    for agent in range(N_AGENTS):
        for i in range(T):
            # Get step state and reward and compute predicted value
            state = states[agent, -i-1]
            reward = rewards[agent, -i-1]
            predicted_value = model(state)[-1]

            # If step is final
            if i in episode_ends[agent]:
                target_value = reward
            # If step is last but not final
            elif i == 0:
                target_value = predicted_value
            # non-last non-final step
            else:
                # Get previous target value and compute 
                previous_target_value = reversed_target_values[agent][-1]
                target_value = reward + GAMMA * previous_target_value
            advantage = target_value - predicted_value

            # Append target value and advantage
            reversed_target_values[agent].append(target_value)
            reversed_advantages[agent].append(advantage)

    target_values = [torch.stack(agent_target_values[::-1]) for agent_target_values in reversed_target_values]
    advantages = [torch.stack(agent_advantage[::-1]) for agent_advantage in reversed_advantages]

    return target_values, advantages

### PPO training data

I need to modify the data so that I can sample random batches from it.

In [18]:
class PPODataset(Dataset):
    def __init__(self, states, actions, pis, target_values, advantages):
        # Reshape all into (N_AGENTS * T, whatever (possibly 0))
        self.states = states.reshape(N_AGENTS * T, -1)
        self.actions = [action for agent in actions for action in agent]
        self.pis = pis.reshape(N_AGENTS * T)
        self.target_values = torch.cat(target_values)
        self.advantages = torch.cat(advantages)
        
    def __len__(self):
        return len(self.actions)
    
    def __getitem__(self, i):
        state = self.states[i]
        action = self.actions[i]
        pi = self.pis[i]
        target_value = self.target_values[i]
        advantage = self.advantages[i]

        return state, action, pi, target_value, advantage

### Test everything

In [31]:
if TEST:
    # Initialize training data, test environment
    test_data_collector = DataCollector()
    test_env = gym.make('CartPole-v1')
    test_model = PolicyAndValueNetwork()
    
    for agent in range(N_AGENTS):
        state, _ = test_env.reset()
        terminated, truncated = False, False
        
        for step in range(T):
            # Compute and divide model output
            with torch.no_grad():
                state = torch.from_numpy(state)
                output = test_model(state)
            action_logits, value_logit = output[:-1], output[-1]

            # Select and perform action
            action = select_action(action_logits)
            next_state, reward, terminated, truncated, _ = test_env.step(action)

            # Compute pi
            pi = action_logits[action]

            # Store data
            test_data_collector.add_step(agent, state, action, reward, pi)

            # If terminated reset env and mark end, otherwise update state
            if terminated or truncated:
                test_data_collector.mark_episode_end(agent, step)
                state, _ = test_env.reset()
                terminated, truncated = False, False
            else:
                state = next_state

    # Fetch data from data collector
    states, actions, rewards, pis, episode_ends = test_data_collector.fetch_data()

    # Compute advantages and target values
    target_values, advantages = compute_advantages_and_target_values(
        model=test_model,
        states=states,
        rewards=rewards,
        episode_ends=episode_ends,
    )

    # Add everything to dataset and create dataloader
    test_dataset = PPODataset(states, actions, pis, target_values, advantages)
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    state_batch, action_batch, pi_batch, target_value_batch, advantage_batch = next(iter(test_dataloader))
    print(f"state_batch: {state_batch}")
print(f"action_batch: {action_batch}")
print(f"pi_batch: {pi_batch}")
print(f"target_value_batch: {target_value_batch}")
print(f"advantage_batch: {advantage_batch}")

state_batch: tensor([[ 8.7621e-02,  8.1141e-01, -1.6617e-01, -1.3870e+00],
        [-2.1354e-02, -7.1292e-03,  1.5443e-02, -1.5917e-02],
        [-5.6898e-02, -9.7270e-01,  5.5225e-02,  1.4389e+00],
        [-4.3018e-02,  3.6789e-01,  2.1321e-02, -5.9012e-01],
        [ 9.6179e-03, -4.0111e-01,  2.3109e-02,  5.8503e-01],
        [ 2.9481e-02, -1.7103e-01, -2.2468e-02,  3.0906e-01],
        [ 2.9451e-02,  4.8258e-02, -3.7498e-03,  3.6285e-02],
        [ 2.3571e-02, -1.1336e-02,  4.0302e-03,  9.9151e-03],
        [-1.8227e-02,  3.6026e-03,  6.2505e-04, -4.1077e-02],
        [-3.8638e-02, -7.3204e-01,  1.2175e-01,  1.2100e+00]])
action_batch: tensor([0, 0, 0, 0, 1, 0, 1, 1, 0, 0])
pi_batch: tensor([0.0183, 0.0127, 0.0572, 0.0193, 0.0896, 0.0230, 0.0907, 0.0916, 0.0128,
        0.0511])
target_value_batch: tensor([ 2.9303,  1.0000,  5.9850,  4.9154,  4.9900, 11.9342,  1.0000,  2.9970,
        10.9452, 16.8647], grad_fn=<StackBackward0>)
advantage_batch: tensor([ 3.0034,  1.0646,  6.0971,  

It's all working, fuck yeah. Now it's time to write the training code. By the way, I tested the code incrementally, not all at once, but now that it's all working I left only this big chunk of code, since the rest was mostly building blocks for this one.

## Training

### Loss function(s)