# PPO Base Implementation
This will be the baseline implementation for comparing with the other methods.

In [10]:
import random
import wandb

import gym
import numpy as np

import torch
from torch.nn import LeakyReLU, Linear, MSELoss, Sequential, Softmax
from torch.optim import Adam

import logging
logging.basicConfig(level=logging.INFO)

In [11]:
SEED = 1234
LEARNING_RATE = 1e-4
GAMMA = 0.99
EPOCHS = 200
CLIP_EPSILON = 0.2
BATCH_SIZE = 15

DEVICE = torch.device("mps")
TYPE = torch.float32

In [12]:

torch.set_default_dtype(TYPE)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cpu.deterministic = True
torch.backends.cudnn.deterministic = True

class WandbLogger:
  def __init__(self, enable=False):
    self.enable = enable
    
    if enable:
      wandb.init(
          project="ppo-base",

          config={
              "learning_rate": LEARNING_RATE,
              "gamma": GAMMA,
              "epochs": EPOCHS,
              "clip_epsilon": CLIP_EPSILON,
              "batch_size": BATCH_SIZE,
              "seed": SEED
          },
      )

  def log(self, **kwargs):
    if self.enable:
      wandb.log(kwargs)
      
  def finish(self):
    if self.enable:
      wandb.finish()
      
wandb = WandbLogger(enable=False)

In [13]:
class CartPoleQuantized:
  def __init__(self, dtype: torch.dtype, device: torch.device):
    self.dtype = dtype
    self.device = device
    
    env = gym.make('CartPole-v1', render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, f"training/{dtype}/", episode_trigger=lambda x: x % 100 == 0 and x >= 30)
    env.reset()
    env.start_video_recorder()
    
    self.env = env
  
  def reset(self):
    return torch.tensor(self.env.reset()[0], dtype=self.dtype, device=self.device)
  
  def step(self, action):
    state, reward, done, _, _ = self.env.step(action)
    return torch.tensor(state, dtype=self.dtype, device=self.device), reward, done
  
  def close(self):
    self.env.close()
  
  @staticmethod
  def env():
    return gym.make('CartPole-v1', render_mode='rgb_array')

## Network Architecture

**PolicyNetwork**:
- Input: State
- Output: Action distribution (0-1)
- 2 Hidden layers with LeakyReLU activation

**ValueNetwork**:
- Input: State
- Output: Value
- 2 Hidden layers with LeakyReLU activation

In [14]:

class PolicyNetwork(torch.nn.Module):
  def __init__(self, input_dim, hidden_dim):
    super().__init__()
    self.model = Sequential(
      Linear(input_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, 2),
      Softmax(dim=-1)
    )

  def forward(self, state):
    return self.model(state)
 
  def stochastic_action(self, state):
    r"""Returns an action sampled from the policy network."""
    
    probs = self.forward(state)
    
    if torch.isnan(probs).any():
      print(state)
      print(probs)
    
    m = torch.distributions.Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)
  
  def deterministic_action(self, state):
    r"""Returns an action with the highest probability."""
    
    probs = self.forward(state)
    action = torch.argmax(probs)
    return action.item(), probs[0][action].item()

  
class ValueNetwork(torch.nn.Module):
  def __init__(self, input_dim, hidden_dim) -> None:
    super().__init__()
    self.model = Sequential(
      Linear(input_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, hidden_dim),
      LeakyReLU(),
      Linear(hidden_dim, 1)
    )
  
  def forward(self, state):
    return self.model(state)
  

# Training
- 64 hidden nodes
- Adam optimizer
- MSE loss for value network

In [15]:
_observation_size = CartPoleQuantized.env().observation_space.shape[0]

policy_net = PolicyNetwork(_observation_size, 64).to(DEVICE)
value_net  = ValueNetwork(_observation_size, 64).to(DEVICE)

policy_optimizer = Adam(policy_net.parameters(), lr=LEARNING_RATE)
value_optimizer  = Adam(value_net.parameters(), lr=LEARNING_RATE)

criterion = MSELoss()

In [16]:
def compute_returns(rewards):
  returns = torch.zeros(len(rewards))
  R = 0
  for i in reversed(range(len(rewards))):
    R = rewards[i] + GAMMA * R
    returns[i] = R
  return returns

In [17]:
def ppo_step(env: CartPoleQuantized):
    state = env.reset()
    
    # capture entire episode
    done, steps = False, 0
    states, actions, log_probs_old, rewards = [], [], [], []
    
    while not done:
        action, log_prob = policy_net.stochastic_action(state)
        next_state, reward, done = env.step(action)

        log_probs_old.append(log_prob)
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        state = next_state
        steps += 1
    
    # Convert to tensors
    # Be sure to detach() the tensors from the graph as these are "constants"
    states = torch.stack(states).detach().to(DEVICE)
    actions = torch.tensor(actions).detach().to(DEVICE)
    log_probs_old = torch.stack(log_probs_old).detach().to(DEVICE)
    
    returns = compute_returns(rewards).detach().to(DEVICE)
    
    values = value_net(states)
    advantages = (returns - values.squeeze()).detach().to(DEVICE)

    for _ in range(EPOCHS):
        for i in range(0, len(states), BATCH_SIZE):
            # Grab a batch of data
            batch_states = states[i:i+BATCH_SIZE]
            batch_actions = actions[i:i+BATCH_SIZE]
            batch_log_probs_old = log_probs_old[i:i+BATCH_SIZE]
            batch_advantages = advantages[i:i+BATCH_SIZE]
            batch_returns = returns[i:i+BATCH_SIZE]

            # Calculate new log probabilities
            new_action_probs = policy_net(batch_states)
            new_log_probs = torch.log(new_action_probs.gather(1, batch_actions.unsqueeze(-1)))

            # rho is the ratio between new and old log probabilities
            ratio = (new_log_probs - batch_log_probs_old).exp()

            # Calculate surrogate loss
            surrogate_loss = ratio * batch_advantages
            clipped_surrogate_loss = torch.clamp(ratio, 1-CLIP_EPSILON, 1+CLIP_EPSILON) * batch_advantages
            policy_loss = -torch.min(surrogate_loss, clipped_surrogate_loss).mean()

            policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_optimizer.step()
            
            value_loss = torch.pow(value_net(batch_states) - batch_returns.unsqueeze(-1), 2).mean()

            value_optimizer.zero_grad()
            value_loss.backward()
            value_optimizer.step()
            
    return (returns.mean(), returns.std(), steps)

In [18]:
env = CartPoleQuantized(TYPE, DEVICE)

for i in range(300):
  _, _, steps = ppo_step(env)
  if i % 5 == 0:
    print(f"Episode {i}\tSteps: {steps}\tReturn: {steps}")
  
env.close()
wandb.finish()

  logger.warn(


Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4




Episode 0	Steps: 11	Return: 11
Episode 5	Steps: 100	Return: 100
Episode 10	Steps: 33	Return: 33
Episode 15	Steps: 23	Return: 23
Episode 20	Steps: 31	Return: 31
Episode 25	Steps: 11	Return: 11
Episode 30	Steps: 8	Return: 8
Episode 35	Steps: 312	Return: 312
Episode 40	Steps: 36	Return: 36
Episode 45	Steps: 15	Return: 15
Episode 50	Steps: 19	Return: 19
Episode 55	Steps: 71	Return: 71
Episode 60	Steps: 15	Return: 15
Episode 65	Steps: 28	Return: 28
Episode 70	Steps: 22	Return: 22
Episode 75	Steps: 37	Return: 37
Episode 80	Steps: 35	Return: 35
Episode 85	Steps: 58	Return: 58
Episode 90	Steps: 39	Return: 39
Episode 95	Steps: 15	Return: 15
Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-100.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-100.mp4



                                                             

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-100.mp4




Episode 100	Steps: 64	Return: 64
Episode 105	Steps: 10	Return: 10
Episode 110	Steps: 10	Return: 10
Episode 115	Steps: 9	Return: 9
Episode 120	Steps: 10	Return: 10
Episode 125	Steps: 9	Return: 9
Episode 130	Steps: 8	Return: 8
Episode 135	Steps: 29	Return: 29
Episode 140	Steps: 32	Return: 32
Episode 145	Steps: 37	Return: 37
Episode 150	Steps: 15	Return: 15
Episode 155	Steps: 10	Return: 10
Episode 160	Steps: 8	Return: 8
Episode 165	Steps: 9	Return: 9
Episode 170	Steps: 9	Return: 9
Episode 175	Steps: 10	Return: 10
Episode 180	Steps: 9	Return: 9
Episode 185	Steps: 9	Return: 9
Episode 190	Steps: 8	Return: 8
Episode 195	Steps: 10	Return: 10
Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-200.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-200.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-200.mp4




Episode 200	Steps: 21	Return: 21
Episode 205	Steps: 12	Return: 12
Episode 210	Steps: 14	Return: 14
Episode 215	Steps: 10	Return: 10
Episode 220	Steps: 10	Return: 10
Episode 225	Steps: 9	Return: 9
Episode 230	Steps: 75	Return: 75
Episode 235	Steps: 13	Return: 13
Episode 240	Steps: 130	Return: 130
Episode 245	Steps: 42	Return: 42
Episode 250	Steps: 24	Return: 24
Episode 255	Steps: 20	Return: 20
Episode 260	Steps: 97	Return: 97
Episode 265	Steps: 16	Return: 16
Episode 270	Steps: 44	Return: 44
Episode 275	Steps: 9	Return: 9
Episode 280	Steps: 79	Return: 79
Episode 285	Steps: 29	Return: 29
Episode 290	Steps: 97	Return: 97
Episode 295	Steps: 10	Return: 10
Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-300.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-300.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-300.mp4




In [19]:
def record_best_effort():
  env = gym.make('CartPole-v1', render_mode='rgb_array', max_episode_steps=10000)
  env = gym.wrappers.RecordVideo(env, "tests")

  state, _ = env.reset()
  env.start_video_recorder()

  total_reward = 0
  done, i = False, 0
  
  while not done and i < 10000:
    action, _ = policy_net.deterministic_action(state)
    state, reward, done, *_ = env.step(action)
    total_reward += reward
    i += 1

  env.close()

In [20]:
record_best_effort()

  logger.warn(


Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4




TypeError: linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray

: 