# PPO

## Imports and hyperparameters

In [1]:
import torch
from torch import nn, optim
import numpy as np

import random
import math
import time

import matplotlib.pyplot as plt

from collections import namedtuple

import gymnasium as gym

In [17]:
# Module-level CartPole-v1 environment; its observation and action spaces
# are read below to size the network input and output.
env = gym.make('CartPole-v1')

In [67]:
# Length of the observation vector (first dimension of the observation space shape)
N_OBSERVATIONS = env.observation_space.shape[0]
# Number of discrete actions exposed by the environment's action space
N_ACTIONS = env.action_space.n
# NOTE(review): presumably the number of parallel actors collecting rollouts — confirm
N_AGENTS = 10
# NOTE(review): presumably the rollout horizon (timesteps per actor) — confirm
T = 100

# Gates the sanity-check cells (the `if TEST:` blocks)
TEST = True
# NOTE(review): not read in the visible part of the file; presumably gates training — confirm
TRAIN = False

# Width of every hidden layer in PolicyAndValueNetwork
layer_dim = 128

## Model, action and dataset

### Model

The policy and value networks share the same parameters: a single network outputs the action logits plus one extra value prediction. If this doesn't work well, I can split them into separate networks later.

In [84]:
class PolicyAndValueNetwork(nn.Module):
    """MLP with a shared body that emits action logits and a value prediction.

    The final linear layer has ``N_ACTIONS + 1`` outputs: one logit per action
    plus one extra number for the value prediction. Callers in this file treat
    the last entry as the value and everything before it as action logits.
    """

    def __init__(self):
        super().__init__()
        # Three hidden layers of width ``layer_dim``, each followed by a ReLU.
        widths = [N_OBSERVATIONS, layer_dim, layer_dim, layer_dim]
        hidden = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            hidden.append(nn.Linear(in_features, out_features))
            hidden.append(nn.ReLU())
        # +1 comes from value prediction
        head = nn.Linear(layer_dim, N_ACTIONS + 1)
        self.network = nn.Sequential(*hidden, head)

    def forward(self, x):
        """Return the raw network output for ``x``; no final activation is applied."""
        return self.network(x)

### Select action

Sample an action from the categorical distribution obtained by applying a softmax to the action logits.

In [49]:
def select_action(action_logits):
    """Sample one action index from the softmax distribution over the logits.

    Args:
        action_logits: tensor of unnormalized action scores. The softmax is
            taken over dim 0, so a 1-D tensor (one score per action) is
            expected.

    Returns:
        The sampled action index as a Python int.
    """
    probabilities = action_logits.softmax(dim=0)
    sampled = probabilities.multinomial(num_samples=1)
    return sampled.item()

In [76]:
if TEST:
    # Smoke test: play one rendered CartPole episode with an untrained network
    # to check that the model output can be turned into valid actions.
    test_env = gym.make('CartPole-v1', render_mode='human')
    test_model = PolicyAndValueNetwork()

    try:
        for _ in range(1):
            state, info = test_env.reset()
            terminated, truncated = False, False
            while not (terminated or truncated):
                tensor_state = torch.from_numpy(state)
                # Inference only: no gradients are needed for this rollout.
                with torch.no_grad():
                    logits = test_model(tensor_state)
                # Drop the last output (value prediction); the rest are action logits.
                action_logits = logits[:-1]
                action = select_action(action_logits)
                state, reward, terminated, truncated, info = test_env.step(action)
    finally:
        # Close the environment even if the rollout raises, so the
        # human-render window is not left open.
        test_env.close()

### Dataset

I need to decide whether the best format for the training data is one list of tuples or many different lists.

In [100]:
# One collected step plus the two training quantities that are filled in later.
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'pi', 'advantage', 'target_value'])


class TrainingData():
    """List-backed store of Transitions with sequential mini-batch access."""

    def __init__(self):
        # List of Transitions
        self.data = []
        # Index into self.data at which the next batch starts
        self.batch_index = 0

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.data)

    def push(self, state, action, reward, pi):
        """Append a transition with advantage and target value set to 0.

        Those two fields are not known when the step is collected; they are
        filled in afterwards through ``push_A_and_V``.
        """
        self.data.append(Transition(state, action, reward, pi, 0., 0.))

    def push_A_and_V(self, idx, advantage, target_value):
        """Set the advantage and target value of the transition at ``idx``."""
        # namedtuples are immutable, so store an updated copy in the same slot.
        self.data[idx] = self.data[idx]._replace(
            advantage=advantage, target_value=target_value
        )

    def randomize_order(self):
        """Shuffle the stored transitions in place."""
        random.shuffle(self.data)

    def fetch_batch(self, batch_size):
        """Return the next ``batch_size`` transitions as a list.

        The batch starts at ``batch_index``. When a batch reaches or passes the
        end of the data the index wraps back to 0, so the final batch of a
        pass may be shorter than ``batch_size``.
        """
        start = self.batch_index
        end = start + batch_size
        batch = self.data[start:end]

        # Advance the index, or reset it once this batch reached the end.
        self.batch_index = 0 if end >= len(self) else end

        return batch

    def unpack(self, batch=None):
        """Convert a list of Transitions into one tensor per field.

        Args:
            batch: list of Transitions, or None to unpack all stored data.

        Returns:
            Tuple ``(states, actions, rewards, pis, advantages, target_values)``.
            ``states`` and ``pis`` are built with ``torch.stack``, so every
            transition's ``state`` / ``pi`` must be a tensor of the same shape;
            the other fields are built with ``torch.tensor``.

        Raises:
            Whatever ``torch.stack`` raises for an empty list, when ``batch``
            (or the stored data, if ``batch`` is None) is empty.
        """
        # Only None selects the whole dataset. A truthiness check would also
        # turn an explicitly empty batch into "everything".
        if batch is None:
            batch = self.data

        states = torch.stack([t.state for t in batch])
        actions = torch.tensor([t.action for t in batch])
        rewards = torch.tensor([t.reward for t in batch])
        pis = torch.stack([t.pi for t in batch])
        advantages = torch.tensor([t.advantage for t in batch])
        target_values = torch.tensor([t.target_value for t in batch])

        return states, actions, rewards, pis, advantages, target_values

In [102]:
if TEST:
    # Sanity check for TrainingData: roll out one episode with an untrained
    # network, store every step, then unpack and print the stored data.
    # Initialize training data, test environment
    test_training_data = TrainingData()
    test_env = gym.make('CartPole-v1')
    test_model = PolicyAndValueNetwork()

    try:
        state, _ = test_env.reset()
        terminated, truncated = False, False

        while not (terminated or truncated):
            # Compute and divide model output
            with torch.no_grad():
                state = torch.from_numpy(state)
                output = test_model(state)
            # The last output is the value prediction; the rest are action logits.
            action_logits, value_logit = output[:-1], output[-1]

            # Select and perform action
            action = select_action(action_logits)
            next_state, reward, terminated, truncated, _ = test_env.step(action)

            # NOTE(review): pi is stored as the raw action logits, not as
            # probabilities — confirm this is what training expects.
            pi = action_logits

            # Store data
            test_training_data.push(state, action, reward, pi)

            # Update state
            state = next_state
    finally:
        # Release the environment even if the rollout raises.
        test_env.close()

    # Unpack training data
    state, action, reward, pi, _, _ = test_training_data.unpack()
    print(f"state: {state}")
    print(f"action: {action}")
    print(f"reward: {reward}")
    print(f"pi: {pi}")

    # Step through the data one transition at a time. Advantage and target
    # value are not computed yet; this only prints the batch index and batch.
    for _ in range(len(test_training_data)):
        print(test_training_data.batch_index)
        i_data = test_training_data.fetch_batch(1)
        print(i_data)

state: tensor([[-1.0995e-02, -9.2247e-03, -4.9037e-02,  3.5234e-02],
        [-1.1180e-02,  1.8656e-01, -4.8333e-02, -2.7251e-01],
        [-7.4484e-03, -7.8352e-03, -5.3783e-02,  4.5469e-03],
        [-7.6051e-03, -2.0215e-01, -5.3692e-02,  2.7979e-01],
        [-1.1648e-02, -6.3011e-03, -4.8096e-02, -2.9335e-02],
        [-1.1774e-02,  1.8948e-01, -4.8683e-02, -3.3680e-01],
        [-7.9845e-03,  3.8526e-01, -5.5419e-02, -6.4442e-01],
        [-2.7937e-04,  1.9095e-01, -6.8307e-02, -3.6970e-01],
        [ 3.5396e-03,  3.8697e-01, -7.5701e-02, -6.8311e-01],
        [ 1.1279e-02,  5.8306e-01, -8.9363e-02, -9.9863e-01],
        [ 2.2940e-02,  3.8924e-01, -1.0934e-01, -7.3530e-01],
        [ 3.0725e-02,  5.8569e-01, -1.2404e-01, -1.0603e+00],
        [ 4.2439e-02,  7.8221e-01, -1.4525e-01, -1.3892e+00],
        [ 5.8083e-02,  9.7881e-01, -1.7303e-01, -1.7235e+00],
        [ 7.7659e-02,  7.8604e-01, -2.0750e-01, -1.4893e+00]])
action: tensor([1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0])
