# PPO

## Imports and hyperparameters

In [3]:
import torch
from torch import nn, optim
import numpy as np

import random
import math
import time

import matplotlib.pyplot as plt

from collections import namedtuple

import gymnasium as gym

In [4]:
env = gym.make('CartPole-v1')

N_OBSERVATIONS = env.observation_space.shape[0]
N_ACTIONS = env.action_space.n
N_AGENTS = 2
T = 50
GAMMA = 1.

TEST = True
TRAIN = False

layer_dim = 128

## Model and selecting actions

### Model

I'll use the same parameters for the policy and value networks. If this doesn't work well I can change it later.

In [5]:
class PolicyAndValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(N_OBSERVATIONS, layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, layer_dim),
            nn.ReLU(),
            # +1 comes from value prediction
            nn.Linear(layer_dim, N_ACTIONS + 1)
        )

    def forward(self, x):
        logits = self.network(x)
        return logits

### Select action

Select action based on log probabilities.

In [6]:
def select_action(action_logits):
    action_probs = torch.softmax(action_logits, dim=0)
    action = torch.multinomial(action_probs, num_samples=1).item()
    return action

In [7]:
if TEST:
    test_env = gym.make('CartPole-v1', render_mode='human')
    test_model = PolicyAndValueNetwork()

    for _ in range(1):
        state, info = test_env.reset()
        terminated, truncated = False, False
        while not (terminated or truncated):
            tensor_state = torch.from_numpy(state)
            logits = test_model(tensor_state)
            action_logits = logits[:-1]
            action = select_action(action_logits)
            state, reward, terminated, truncated, info = test_env.step(action)

    test_env.close()

## Data

### Data collector

We first collect the data while marking episode ends when they happen.

In [8]:
class DataCollector():
    def __init__(self, n_agents):
        self.states = [[] for _ in range(n_agents)]
        self.actions = [[] for _ in range(n_agents)]
        self.rewards = [[] for _ in range(n_agents)]
        self.pis = [[] for _ in range(n_agents)]
        self.episode_ends = [[] for _ in range(n_agents)]

    def add_step(self, agent_id, state, action, reward, pi):
        self.states[agent_id].append(state)
        self.actions[agent_id].append(action)
        self.rewards[agent_id].append(reward)
        self.pis[agent_id].append(pi)

    def mark_episode_end(self, agent_id, timestep):
        self.episode_ends[agent_id].append(timestep)

    def fetch_data(self):
        list_of_agent_states = [torch.stack(agent_states) for agent_states in self.states]
        states = torch.stack([agent_states for agent_states in list_of_agent_states])
        actions = self.actions
        rewards = torch.tensor(self.rewards)
        pis = torch.tensor(self.pis)
        episode_ends = self.episode_ends

        return states, actions, rewards, pis, episode_ends

In [9]:
if TEST:
    # Initialize training data, test environment
    test_data_collector = DataCollector(N_AGENTS)
    test_env = gym.make('CartPole-v1')
    test_model = PolicyAndValueNetwork()
    
    for agent_id in range(N_AGENTS):
        state, _ = test_env.reset()
        terminated, truncated = False, False
        
        for step in range(T):
            # Compute and divide model output
            with torch.no_grad():
                state = torch.from_numpy(state)
                output = test_model(state)
            action_logits, value_logit = output[:-1], output[-1]

            # Select and perform action
            action = select_action(action_logits)
            next_state, reward, terminated, truncated, _ = test_env.step(action)

            # Compute pi
            pi = action_logits[action]

            # Store data
            test_data_collector.add_step(agent_id, state, action, reward, pi)

            # If terminated reset env and mark end, otherwise update state
            if terminated or truncated:
                test_data_collector.mark_episode_end(agent_id, step)
                state, _ = test_env.reset()
                terminated, truncated = False, False
            else:
                state = next_state

    states, actions, rewards, pis, episode_ends = test_data_collector.fetch_data()
    print(f"states: {states}")
    print(f"actions: {actions}")
    print(f"rewards: {rewards}")
    print(f"pis: {pis}")
    print(f"episode ends: {episode_ends}")

states: tensor([[[ 3.1238e-02,  1.9498e-02,  2.6910e-02,  4.2010e-02],
         [ 3.1628e-02,  2.1422e-01,  2.7750e-02, -2.4206e-01],
         [ 3.5912e-02,  1.8717e-02,  2.2909e-02,  5.9243e-02],
         [ 3.6287e-02, -1.7673e-01,  2.4093e-02,  3.5906e-01],
         [ 3.2752e-02, -3.7218e-01,  3.1275e-02,  6.5925e-01],
         [ 2.5309e-02, -1.7751e-01,  4.4460e-02,  3.7657e-01],
         [ 2.1759e-02,  1.6955e-02,  5.1991e-02,  9.8233e-02],
         [ 2.2098e-02, -1.7887e-01,  5.3956e-02,  4.0686e-01],
         [ 1.8520e-02,  1.5445e-02,  6.2093e-02,  1.3166e-01],
         [ 1.8829e-02, -1.8051e-01,  6.4726e-02,  4.4327e-01],
         [ 1.5219e-02,  1.3640e-02,  7.3591e-02,  1.7167e-01],
         [ 1.5492e-02,  2.0764e-01,  7.7025e-02, -9.6920e-02],
         [ 1.9644e-02,  4.0157e-01,  7.5086e-02, -3.6434e-01],
         [ 2.7676e-02,  5.9555e-01,  6.7800e-02, -6.3244e-01],
         [ 3.9587e-02,  3.9955e-01,  5.5151e-02, -3.1920e-01],
         [ 4.7578e-02,  2.0369e-01,  4.8767e-02

### Compute advantages and target values

Next we'll need to compute advantages and target values. We'll start from the end of the data for each agent and move backwards. This makes the desired values easier to compute and to take into account episode ends.

Gradients should be on here because I'll use the computations here to optimize the model.

> If the trajectory terminated due to the maximal trajectory length T
being reached, Vωold (st+n ) denotes the state value associated with state st+n as predicted by the state value
network. Otherwise, Vωold (st+n ) is set to 0

Instead of computing r_t + V(s_t+1) I'm using V(s_t). It's similar enough and shouldn't be that much of a problem.

In [10]:
def compute_advantages_and_target_values(model, states, rewards, episode_ends):
    n_agents = len(episode_ends)

    # We'll add them from last to first, then reverse list
    reversed_target_values = [[] for _ in range(n_agents)]
    reversed_advantages = [[] for _ in range(n_agents)]

    for agent in range(n_agents):
        for i in range(T):
            # Get step state and reward and compute predicted value
            state = states[agent, -i-1]
            reward = rewards[agent, -i-1]
            predicted_value = model(state)[-1]

            # If step is final
            if i in episode_ends[agent]:
                target_value = reward
            # If step is last but not final
            elif i == 0:
                target_value = predicted_value
            # non-last non-final step
            else:
                # Get previous target value and compute 
                previous_target_value = reversed_target_values[agent][-1]
                target_value = reward + GAMMA * previous_target_value
            advantage = target_value - predicted_value

            # Append target value and advantage
            reversed_target_values[agent].append(target_value)
            reversed_advantages[agent].append(advantage)

    target_values = [torch.stack(agent_target_values[::-1]) for agent_target_values in reversed_target_values]
    advantages = [torch.stack(agent_advantage[::-1]) for agent_advantage in reversed_advantages]

    return target_values, advantages

In [11]:
if TEST:
    target_values, advantages = compute_advantages_and_target_values(test_model, states, rewards, episode_ends)
    print(f"target values: {target_values}")
    print(f"advantages: {advantages}")

target values: [tensor([ 2.0000e+01,  1.9000e+01,  1.8000e+01,  1.7000e+01,  1.6000e+01,
         1.5000e+01,  1.4000e+01,  1.3000e+01,  1.2000e+01,  1.1000e+01,
         1.0000e+01,  9.0000e+00,  8.0000e+00,  7.0000e+00,  6.0000e+00,
         5.0000e+00,  4.0000e+00,  3.0000e+00,  2.0000e+00,  1.0000e+00,
         2.8972e+01,  2.7972e+01,  2.6972e+01,  2.5972e+01,  2.4972e+01,
         2.3972e+01,  2.2972e+01,  2.1972e+01,  2.0972e+01,  1.9972e+01,
         1.8972e+01,  1.7972e+01,  1.6972e+01,  1.5972e+01,  1.4972e+01,
         1.3972e+01,  1.2972e+01,  1.1972e+01,  1.0972e+01,  9.9725e+00,
         8.9725e+00,  7.9725e+00,  6.9725e+00,  5.9725e+00,  4.9725e+00,
         3.9725e+00,  2.9725e+00,  1.9725e+00,  9.7249e-01, -2.7508e-02],
       grad_fn=<StackBackward0>), tensor([ 1.0000e+00,  1.1000e+01,  1.0000e+01,  9.0000e+00,  8.0000e+00,
         7.0000e+00,  6.0000e+00,  5.0000e+00,  4.0000e+00,  3.0000e+00,
         2.0000e+00,  1.0000e+00,  3.6999e+01,  3.5999e+01,  3.4999e+01,


### PPO training data

In [12]:
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'pi'])

class TrainingData():
    def __init__(self):
        # List of Transitions
        self.data = []
        # Current batch index
        self.batch_index = 0

    def __len__(self):
        return len(self.data)

    def push(self, state, action, reward, pi):
        # Initially we don't have the advantage and target value
        # they'll be computed during training
        advantage, target_value = 0., 0.
        self.data.append(Transition(state, action, reward, pi, advantage, target_value))

    def push_A_and_V(self, idx, advantage, target_value):
        # It makes sense to add the advantage and target value index by index
        state, action, reward, pi, _, _ = self.data[idx]
        complete_transition = Transition(state, action, reward, pi, advantage, target_value)
        self.data[idx] = complete_transition

    def randomize_order(self):
        random.shuffle(self.data)

    def fetch_batch(self, batch_size):
        i = self.batch_index
        batch = self.data[i:i+batch_size]
        
        # Update index
        self.batch_index += batch_size
        if i+batch_size >= self.__len__():
            # Reset index
            self.batch_index = 0

        return batch
    
    def unpack(self, batch=None):
        # batch should be list of Transitions or None
        # in which case we unpack all of data
        if not batch:
            batch = self.data
            
        states = torch.stack([t.state for t in batch])
        actions = torch.tensor([t.action for t in batch])
        rewards = torch.tensor([t.reward for t in batch])
        pis = torch.stack([t.pi for t in batch])
        advantages = torch.tensor([t.advantage for t in batch])
        target_values = torch.tensor([t.target_value for t in batch])

        return states, actions, rewards, pis, advantages, target_values

In [13]:
if TEST:
    # Initialize training data, test environment
    test_training_data = TrainingData()
    test_env = gym.make('CartPole-v1')
    test_model = PolicyAndValueNetwork()
    
    state, _ = test_env.reset()
    terminated, truncated = False, False
    
    while not (terminated or truncated):
        # Compute and divide model output
        with torch.no_grad():
            state = torch.from_numpy(state)
            output = test_model(state)
        action_logits, value_logit = output[:-1], output[-1]

        # Select and perform action
        action = select_action(action_logits)
        next_state, reward, terminated, truncated, _ = test_env.step(action)

        # Compute pi
        pi = action_logits

        # Store data
        test_training_data.push(state, action, reward, pi)

        # Update state
        state = next_state

    # Unpack training data
    state, action, reward, pi, _, _ = test_training_data.unpack()
    print(f"state: {state}")
    print(f"action: {action}")
    print(f"reward: {reward}")
    print(f"pi: {pi}")

    # Compute advantage and target value for each transition
    for i in range(len(test_training_data)):
        print(test_training_data.batch_index)
        i_data = test_training_data.fetch_batch(1)
        print(i_data)

TypeError: Transition.__new__() takes 5 positional arguments but 7 were given