# PPO Base Implementation
This will be the baseline implementation for comparing with the other methods.

In [None]:
# Seed shared by the Python, NumPy and PyTorch random number generators.
SEED = 1234
# Step size used by the Adam optimizers of both the policy and value networks.
LEARNING_RATE = 1e-4
# Discount factor applied when accumulating rewards into returns.
GAMMA = 0.99
# Number of optimization passes over each collected episode.
EPOCHS = 20
# Half-width of the clipping interval [1 - eps, 1 + eps] for the probability ratio.
CLIP_EPSILON = 0.2
# Number of consecutive time steps per mini-batch.
BATCH_SIZE = 10

In [None]:
import random
import wandb

import gym
import numpy as np

import torch
from torch.nn import LeakyReLU, Linear, MSELoss, Sequential, Softmax
from torch.optim import Adam

import logging
logging.basicConfig(level=logging.INFO)

In [None]:
# Prefer Apple's Metal backend (the device this notebook was written for), but
# fall back to CUDA or the CPU so the later `.to(device)` calls do not fail on
# machines where MPS is unavailable.
if torch.backends.mps.is_available():
  device = torch.device("mps")
elif torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

env = gym.make('CartPole-v1')

# Seed every random number generator used by the notebook.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Start the experiment-tracking run and record the hyperparameters with it.
wandb.init(
  project="ppo-base",
  config={
    "learning_rate": LEARNING_RATE,
    "gamma": GAMMA,
    "epochs": EPOCHS,
    "clip_epsilon": CLIP_EPSILON,
    "batch_size": BATCH_SIZE,
    "seed": SEED
  },
)

## Network Architecture

**PolicyNetwork**:
- Input: State
- Output: Probability of each action (softmax output; every value lies in [0, 1] and the values sum to 1)
- 2 Hidden layers with LeakyReLU activation

**ValueNetwork**:
- Input: State
- Output: Value
- 2 Hidden layers with LeakyReLU activation

In [None]:
class PolicyNetwork(torch.nn.Module):
  r"""Feed-forward policy that maps a state to a distribution over actions.

  Args:
    input_dim: Size of the state vector fed to the first layer.
    hidden_dim: Width of each of the two hidden layers.
    output_dim: Number of discrete actions. Defaults to 2, the value that
      used to be hard-coded, so existing two-argument calls are unaffected.
  """

  def __init__(self, input_dim: int, hidden_dim: int, output_dim: int = 2):
    super().__init__()
    self.model = Sequential(
      Linear(input_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, output_dim),
      # Normalise over the action dimension; inputs must therefore be batched.
      Softmax(dim=1)
    )

  def forward(self, state):
    r"""Returns the action probabilities for a batch of states."""
    return self.model(state)

  def stochastic_action(self, state):
    r"""Returns an action sampled from the policy network.

    ``state`` is a single NumPy observation; a batch dimension is added
    before the forward pass. The result is ``(action, log_prob)`` where
    ``action`` is a Python int and ``log_prob`` is a tensor of shape ``(1,)``
    that is still attached to the autograd graph.
    """

    state = torch.from_numpy(state).unsqueeze(0).to(device)
    probs = self.forward(state)
    m = torch.distributions.Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)

  def deterministic_action(self, state):
    r"""Returns the action with the highest probability.

    ``state`` is a single NumPy observation. The result is
    ``(action, probability)``, both plain Python numbers.
    """

    state = torch.from_numpy(state).unsqueeze(0).to(device)
    # Only Python scalars leave this method, so no autograd graph is needed.
    with torch.no_grad():
      probs = self.forward(state)
    # Reduce over the action dimension explicitly instead of the flattened tensor.
    action = torch.argmax(probs, dim=1).item()
    return action, probs[0, action].item()

  
class ValueNetwork(torch.nn.Module):
  r"""State-value estimator: two LeakyReLU hidden layers and a single linear output."""

  def __init__(self, input_dim, hidden_dim) -> None:
    super().__init__()
    # Layers are listed in execution order and wrapped in one Sequential so
    # the parameter names stay `model.<index>.*`.
    layers = [
      Linear(input_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, 1),
    ]
    self.model = Sequential(*layers)

  def forward(self, state):
    r"""Returns one value estimate per input state (last dimension of size 1)."""
    value = self.model(state)
    return value
  

# Training
- 64 units in each hidden layer
- Adam optimizer
- MSE loss for value network

In [None]:
# Length of the environment's observation vector; input width of both networks.
_observation_size = env.observation_space.shape[0]

# Both networks use 64 units per hidden layer and are moved to `device`.
policy_net = PolicyNetwork(_observation_size, 64).to(device)
value_net  = ValueNetwork(_observation_size, 64).to(device)

# One Adam optimizer per network so each loss updates only its own parameters.
policy_optimizer = Adam(policy_net.parameters(), lr=LEARNING_RATE)
value_optimizer  = Adam(value_net.parameters(), lr=LEARNING_RATE)

# Mean-squared error, used for the value network's regression onto the returns.
criterion = MSELoss()

In [None]:
def compute_returns(rewards, gamma=None):
  r"""Computes the discounted return for every step of an episode.

  Args:
    rewards: Sequence of per-step rewards in chronological order.
    gamma: Discount factor. When omitted, the module-level ``GAMMA`` is used,
      which matches the previous behaviour.

  Returns:
    A 1-D tensor with ``len(rewards)`` entries where entry ``i`` equals
    ``rewards[i] + gamma * rewards[i+1] + gamma**2 * rewards[i+2] + ...``.
    An empty ``rewards`` yields an empty tensor.
  """
  if gamma is None:
    # Resolved at call time so the default follows the notebook's constant.
    gamma = GAMMA

  returns = torch.zeros(len(rewards))
  R = 0
  # Walk backwards so each return reuses the one computed for the next step.
  for i in reversed(range(len(rewards))):
    R = rewards[i] + gamma * R
    returns[i] = R
  return returns

In [None]:

def evaluate():
    r"""Runs one greedy episode and returns the number of steps it lasted.

    A separate CartPole-v1 environment is created for the rollout and closed
    again before returning. The episode ends when the environment terminates
    or after 10000 steps, whichever comes first.
    """
    max_steps = 10000

    # Raise the environment's own time limit to the evaluation cap so the
    # truncation signal and the step budget coincide instead of stepping past
    # a truncated episode.
    env = gym.make('CartPole-v1', max_episode_steps=max_steps)

    try:
        state, _ = env.reset()
        done, steps = False, 0

        while not done and steps < max_steps:
            action = policy_net.deterministic_action(state)[0]
            state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            steps += 1
    finally:
        # Release the environment even if a step raises.
        env.close()

    return steps


def ppo_step():
    r"""Collects one episode with the current policy and runs the PPO update on it.

    The episode is sampled from the module-level ``env`` with
    ``policy_net.stochastic_action``. The policy and value networks are then
    optimised for ``EPOCHS`` passes over the episode in consecutive
    mini-batches of ``BATCH_SIZE`` steps, logging both losses to wandb after
    every mini-batch.

    Returns:
      A tuple ``(returns.mean(), returns.std(), steps)``: mean and standard
      deviation of the discounted returns (tensors) and the episode length.
    """
    state, _ = env.reset()

    # capture entire episode
    done, steps = False, 0
    states, actions, log_probs_old, rewards = [], [], [], []

    while not done:
        action, log_prob = policy_net.stochastic_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # A truncated episode (time limit reached) must end as well; otherwise
        # the loop keeps stepping an episode the environment already finished.
        done = terminated or truncated

        log_probs_old.append(log_prob)
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        state = next_state
        steps += 1

    # Convert to tensors
    states = torch.from_numpy(np.array(states)).to(device)
    actions = torch.tensor(actions).to(device)
    # The old log-probs are constants during the update, so detach them.
    # Each sampled log-prob has shape (1,); flatten the stack to (T,) so it
    # lines up element-wise with the advantages.
    log_probs_old = torch.stack(log_probs_old).detach().reshape(-1).to(device)

    returns = compute_returns(rewards).detach().to(device)

    # Advantages are constants as well: no graph is needed for the baseline.
    with torch.no_grad():
        values = value_net(states)
    advantages = returns - values.squeeze(-1)

    for _ in range(EPOCHS):
        for i in range(0, len(states), BATCH_SIZE):
            # Grab a batch of data
            batch_states = states[i:i+BATCH_SIZE]
            batch_actions = actions[i:i+BATCH_SIZE]
            batch_log_probs_old = log_probs_old[i:i+BATCH_SIZE]
            batch_advantages = advantages[i:i+BATCH_SIZE]
            batch_returns = returns[i:i+BATCH_SIZE]

            # Calculate new log probabilities of the actions that were taken.
            # squeeze(-1) keeps the shape (B,); a (B, 1) ratio multiplied by
            # (B,) advantages would broadcast to a (B, B) matrix.
            new_action_probs = policy_net(batch_states)
            new_log_probs = torch.log(
                new_action_probs.gather(1, batch_actions.unsqueeze(-1))
            ).squeeze(-1)

            # Ratio between new and old action probabilities, computed in log space
            ratio = (new_log_probs - batch_log_probs_old).exp()

            # Calculate surrogate loss
            surrogate_loss = ratio * batch_advantages
            clipped_surrogate_loss = torch.clamp(ratio, 1-CLIP_EPSILON, 1+CLIP_EPSILON) * batch_advantages
            policy_loss = -torch.min(surrogate_loss, clipped_surrogate_loss).mean()

            policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_optimizer.step()

            # Regress the value estimates onto the discounted returns.
            value_loss = criterion(value_net(batch_states),
                                   batch_returns.unsqueeze(-1))

            value_optimizer.zero_grad()
            value_loss.backward()
            value_optimizer.step()

            wandb.log({
                "policy_loss": policy_loss.item(),
                "value_loss": value_loss.item(),
                "steps": steps,
            })

    return (returns.mean(), returns.std(), steps)

In [None]:

# Training environment: renders frames so that every 30th episode (from
# episode 30 on) can be written to the "video" directory.
env = gym.make('CartPole-v1', render_mode='rgb_array')
env = gym.wrappers.RecordVideo(env, "video", episode_trigger=lambda x: x % 30 == 0 and x >= 30)

try:
  env.reset()
  env.start_video_recorder()

  for i in range(300):
    _, _, steps = ppo_step()
    if i % 5 == 0:
      # NOTE(review): "Return" repeats the step count; presumably this relies
      # on the environment paying a reward of 1 per step - confirm.
      print(f"Episode {i}\tSteps: {steps}\tReturn: {steps}")
finally:
  # Close the environment and the wandb run even when the cell is interrupted
  # or a training step raises.
  env.close()
  wandb.finish()

In [None]:
def record_best_effort():
  r"""Records a video of one greedy episode into the "tests" directory.

  The episode uses ``policy_net.deterministic_action`` and ends when the
  environment terminates or after 10000 steps.

  Returns:
    The sum of the rewards collected during the episode.
  """
  max_steps = 10000
  env = gym.make('CartPole-v1', render_mode='rgb_array', max_episode_steps=max_steps)
  env = gym.wrappers.RecordVideo(env, "tests")

  total_reward = 0
  try:
    state, _ = env.reset()
    env.start_video_recorder()

    done, i = False, 0
    while not done and i < max_steps:
      action, _ = policy_net.deterministic_action(state)
      state, reward, terminated, truncated, _ = env.step(action)
      # The time limit equals max_steps, so truncation and the loop bound
      # end the episode on the same step.
      done = terminated or truncated
      total_reward += reward
      i += 1
  finally:
    # Closing the wrapper finalises the video even if a step raises.
    env.close()

  return total_reward

In [None]:
# Record a greedy rollout of the current policy.
record_best_effort()