# PPO Base Implementation
This notebook is the baseline PPO implementation that the other methods are compared against.

In [10]:
import random
import pandas as pd
import time
from datetime import datetime
import os

import gym
import numpy as np

import torch
from torch.nn import LeakyReLU, ReLU, Linear, MSELoss, Sequential, Softmax, Dropout
from torch.optim import Adam
from torch.quantization import quantize_dynamic
from torch.nn.utils import clip_grad_norm_

import logging
logging.basicConfig(level=logging.INFO)

In [11]:
# Seed shared by Python's `random`, NumPy and torch RNGs
SEED = 1234
# Adam learning rate for both the policy and value networks
LEARNING_RATE = 1e-4
# Discount factor used when computing reward-to-go returns
GAMMA = 0.99
# Number of passes over each collected episode during the PPO update
EPOCHS = 1
# PPO clipping range for the probability ratio: [1 - eps, 1 + eps]
CLIP_EPSILON = 0.2
# Minibatch size (in steps) for the PPO update
BATCH_SIZE = 1

# Fall back to the CPU when CUDA is not available instead of failing on the
# first `.to(DEVICE)` call.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TYPE = torch.float32

In [12]:
torch.set_default_dtype(TYPE)

# Seed every RNG the notebook relies on (Python, NumPy, torch) with one value,
# in the order random -> numpy -> torch.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# Ask cuDNN for deterministic kernels so GPU runs are repeatable
torch.backends.cudnn.deterministic = True

class Logger:
    """Collect per-episode training metrics and persist them as CSV files.

    Each row is ``(time, p_loss, v_loss, steps)``. ``time`` is the number of
    seconds since the first row was added, rounded to 5 decimals.
    """

    # Column names shared by ``df()`` and ``save()``
    COLUMNS = ("time", "p_loss", "v_loss", "steps")

    def __init__(self, extras: dict):
        """
        Args:
            extras: run-level metadata (e.g. hyper-parameters). ``save`` writes
                it as a single row to ``info.csv``.
        """
        self.data = []
        self.base_time = time.time()
        self.extras = extras

    def add(self, p_loss, v_loss, steps):
        """Append one row of metrics; the clock starts at the first call."""
        if not self.data:
            # Measure elapsed time from the first logged row, not from construction
            self.base_time = time.time()

        time_elapsed = round(time.time() - self.base_time, 5)
        self.data.append((time_elapsed, p_loss, v_loss, steps))

    def df(self):
        """Return the logged rows as a DataFrame (empty if nothing was logged)."""
        return pd.DataFrame(self.data, columns=list(self.COLUMNS))

    def save(self, root: str = "logs"):
        """Write ``info.csv`` (extras) and ``data.csv`` (rows) to a new directory.

        The directory is ``<root>/<YYYY-mm-dd_HH-MM-SS>``.

        Args:
            root: parent directory; created if it does not exist.

        Returns:
            The path of the directory that was written.

        Raises:
            FileExistsError: if a save already happened within the same second.
                It is deliberately not overwritten.
        """
        path = os.path.join(root, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        # makedirs also creates `root` when missing. Plain os.mkdir crashed in that case.
        os.makedirs(path)
        pd.DataFrame([self.extras]).to_csv(os.path.join(path, "info.csv"))
        self.df().to_csv(os.path.join(path, "data.csv"))
        return path

# Run-level metadata; Logger.save() writes it to info.csv next to the metrics
logger = Logger(dict(
    seed=SEED,
    learning_rate=LEARNING_RATE,
    gamma=GAMMA,
    epochs=EPOCHS,
    clip_epsilon=CLIP_EPSILON,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    type=TYPE,
))

In [13]:
class CartPoleQuantized:
    """CartPole-v1 environment that records training videos and returns torch tensors.

    Observations are returned as tensors of ``dtype`` on ``device``. ``step``
    reports one ``done`` flag combining gym's ``terminated`` and ``truncated``.
    """

    def __init__(self, dtype: torch.dtype, device: torch.device):
        self.dtype = dtype
        self.device = device

        def record_episode(episode_id):
            # Every 100th episode, skipping the early ones
            return episode_id % 100 == 0 and episode_id >= 30

        base_env = gym.make('CartPole-v1', render_mode='rgb_array')
        recorder = gym.wrappers.RecordVideo(base_env, f"training/{dtype}/", episode_trigger=record_episode)
        recorder.reset()
        # The trigger is False for episode 0, so that first recording is started by hand
        recorder.start_video_recorder()

        # NOTE: on instances this attribute shadows the static `env()` factory
        # below; call the factory as `CartPoleQuantized.env()`.
        self.env = recorder

    def reset(self):
        """Start a new episode and return its initial observation as a tensor."""
        observation = self.env.reset()[0]
        return torch.tensor(observation, dtype=self.dtype, device=self.device)

    def step(self, action):
        """Apply ``action``; return ``(next_state_tensor, reward, done)``."""
        observation, reward, terminated, truncated, _ = self.env.step(action)
        next_state = torch.tensor(observation, dtype=self.dtype, device=self.device)
        return next_state, reward, terminated or truncated

    def close(self):
        """Close the underlying (recording) environment."""
        self.env.close()

    @staticmethod
    def env():
        """Build a fresh CartPole-v1 environment without video recording."""
        return gym.make('CartPole-v1', render_mode='rgb_array')

## Network Architecture

**PolicyNetwork**:
- Input: State
- Output: Probability distribution over the 2 actions (softmax output, each value in 0-1)
- 2 hidden layers with ReLU activation

**ValueNetwork**:
- Input: State
- Output: Value
- 2 hidden layers with ReLU activation

In [14]:

class PolicyNetwork(torch.nn.Module):
    """Actor network: maps a state to a probability distribution over 2 actions.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Softmax.
    Keep the layer order fixed: it determines the ``model.<i>.*`` state-dict
    keys that saved checkpoints are loaded into.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.model = Sequential(
            Linear(input_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, 2),
            # Normalize over the action axis explicitly. The implicit-dim form is
            # deprecated and warns on every call. dim=-1 gives the same result for
            # the 1-D (single state) and 2-D (batch) inputs used in this notebook.
            Softmax(dim=-1),
        )

    def forward(self, state):
        return self.model(state)

    def stochastic_action(self, state):
        r"""Sample an action from the policy.

        Returns:
            ``(action, log_prob)``: the sampled action as a Python int, and its
            log-probability as a 0-d tensor detached from the autograd graph.
        """
        # Quantized callers pass a batch of one (shape (1, n)). Squeeze to a flat
        # (n_actions,) vector so the argmax indexing below addresses an action,
        # not a row, and log_prob comes back 0-d either way.
        probs = self.forward(state).detach().squeeze()
        # Push any floating-point shortfall onto the most likely action so the
        # probabilities sum to exactly 1.
        probs[torch.argmax(probs)] += 1 - probs.sum()

        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

    def deterministic_action(self, state):
        r"""Return the index of the most probable action as a Python int."""
        probs = self.forward(state).detach()
        return torch.argmax(probs).item()

  
class ValueNetwork(torch.nn.Module):
    """Critic network: maps a state to a single scalar value estimate.

    Architecture: two ReLU hidden layers of width ``hidden_dim`` followed by a
    1-unit linear head.
    """

    def __init__(self, input_dim, hidden_dim) -> None:
        super().__init__()
        widths = [input_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [Linear(fan_in, fan_out), ReLU()]
        layers.append(Linear(hidden_dim, 1))
        # Sequential indices 0..4 -> state-dict keys model.0/2/4.* (used by checkpoints)
        self.model = Sequential(*layers)

    def forward(self, state):
        value = self.model(state)
        return value
  

# Training
- 64 hidden nodes
- Adam optimizer
- MSE loss for value network

In [15]:
class PPOSession:
    """Train a PPO actor-critic pair on a ``CartPoleQuantized`` environment.

    Reads the module-level hyper-parameters ``LEARNING_RATE``, ``GAMMA``,
    ``EPOCHS``, ``CLIP_EPSILON`` and ``BATCH_SIZE``, and places both networks
    on ``DEVICE``. Each ``ppo_step`` collects one full episode, then runs the
    clipped-surrogate PPO update on it.
    """

    def __init__(self, env: "CartPoleQuantized", hidden_dim: int = 64):
        """
        Args:
            env: environment used for rollouts during training.
            hidden_dim: width of the hidden layers of both networks.
        """
        self.env = env
        self.episode = 0

        # Throw-away env used only to read the observation size; close it afterwards
        probe_env = CartPoleQuantized.env()
        try:
            observation_size = probe_env.observation_space.shape[0]
        finally:
            probe_env.close()

        self.policy_net = PolicyNetwork(observation_size, hidden_dim).to(DEVICE)
        self.value_net = ValueNetwork(observation_size, hidden_dim).to(DEVICE)

        self.policy_optimizer = Adam(self.policy_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        self.value_optimizer = Adam(self.value_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        # Set to True by the caller after quantizing policy_net; states are then
        # fed to the policy as a batch of one.
        self.quantized = False

    @staticmethod
    def compute_returns(rewards, gamma=None):
        """Discounted reward-to-go for one episode.

        ``returns[i] = rewards[i] + gamma * returns[i + 1]``, computed backwards.

        Args:
            rewards: per-step rewards in episode order.
            gamma: discount factor; defaults to the module-level ``GAMMA``.

        Returns:
            1-D tensor (default dtype, CPU) with one entry per reward.
        """
        if gamma is None:
            gamma = GAMMA
        returns = torch.zeros(len(rewards))
        running = 0
        for i in reversed(range(len(rewards))):
            running = rewards[i] + gamma * running
            returns[i] = running
        return returns

    def run(self, episodes):
        """Train for ``episodes`` episodes, printing progress every 5th episode.

        Networks are in train mode during the loop. They are switched back to
        eval mode afterwards, even if training is interrupted. Training stops
        early if ``ppo_step`` reports a NaN loss.
        """
        self.policy_net.train()
        self.value_net.train()
        try:
            for _ in range(episodes):
                self.episode += 1
                result = self.ppo_step()
                if result is None:
                    # ppo_step already printed which loss went NaN
                    break
                returns, std, steps = result
                if self.episode % 5 == 0:
                    print(f"Episode {self.episode} - Returns: {float(returns)} - Std: {float(std)} - Steps: {steps}")
        finally:
            self.policy_net.eval()
            self.value_net.eval()

    def _collect_episode(self):
        """Roll out one episode with the current stochastic policy.

        Returns:
            ``(states, actions, log_probs_old, rewards, steps)``. ``states``,
            ``actions`` and ``log_probs_old`` are detached tensors on ``DEVICE``
            ("constants" for the update); ``rewards`` is a list.
        """
        state = self.env.reset()
        done, steps = False, 0
        states, actions, log_probs_old, rewards = [], [], [], []

        while not done:
            if self.quantized:
                state = state.unsqueeze(0)
            action, log_prob = self.policy_net.stochastic_action(state)
            next_state, reward, done = self.env.step(action)

            log_probs_old.append(log_prob)
            states.append(state.squeeze())
            actions.append(action)
            rewards.append(reward)

            state = next_state
            steps += 1

        states = torch.stack(states).detach().to(DEVICE)
        actions = torch.tensor(actions, device=DEVICE)
        # reshape(-1): one old log-prob per step, whether each was 0-d or shape (1,)
        log_probs_old = torch.stack(log_probs_old).detach().reshape(-1).to(DEVICE)
        return states, actions, log_probs_old, rewards, steps

    def ppo_step(self):
        """Collect one episode and run the PPO update on it.

        Returns:
            ``(mean_return, std_return, steps)``, where the first two are 0-d
            tensors. Returns ``None`` if a NaN loss is detected; in that case no
            optimizer step is taken with the NaN loss.
        """
        states, actions, log_probs_old, rewards, steps = self._collect_episode()

        returns = self.compute_returns(rewards).to(DEVICE)
        with torch.no_grad():
            values = self.value_net(states)
        # reshape(-1) rather than squeeze() keeps a one-step episode 1-D
        advantages = returns - values.reshape(-1)

        policy_loss = value_loss = None
        for _ in range(EPOCHS):
            for start in range(0, len(states), BATCH_SIZE):
                batch = slice(start, start + BATCH_SIZE)
                batch_states = states[batch]
                batch_actions = actions[batch]
                batch_log_probs_old = log_probs_old[batch]
                batch_advantages = advantages[batch]
                batch_returns = returns[batch]

                # log pi_new(a|s) for the actions actually taken, shape (B,)
                new_action_probs = self.policy_net(batch_states)
                new_log_probs = torch.log(
                    new_action_probs.gather(1, batch_actions.unsqueeze(-1)).squeeze(-1)
                )

                # Ratio pi_new / pi_old. Both operands are (B,), so there is no
                # accidental (B, B) broadcast when BATCH_SIZE > 1.
                ratio = (new_log_probs - batch_log_probs_old).exp()

                # Clipped surrogate objective (negated: the optimizer minimizes)
                surrogate_loss = ratio * batch_advantages
                clipped_surrogate_loss = torch.clamp(ratio, 1 - CLIP_EPSILON, 1 + CLIP_EPSILON) * batch_advantages
                policy_loss = -torch.min(surrogate_loss, clipped_surrogate_loss).mean()

                # Check before stepping so a NaN loss never reaches the weights
                if torch.isnan(policy_loss):
                    print("NaN detected in policy loss")
                    return None

                # NOTE(review): after quantize_dynamic the policy has no trainable
                # float parameters; the notebook shows this backward() then raising
                # "element 0 of tensors does not require grad".
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

                # Value regression target: discounted returns (MSE)
                value_loss = torch.pow(self.value_net(batch_states) - batch_returns.unsqueeze(-1), 2).mean()

                if torch.isnan(value_loss):
                    print("NaN detected in value loss")
                    return None

                self.value_optimizer.zero_grad()
                value_loss.backward()
                self.value_optimizer.step()

        if policy_loss is not None:
            logger.add(policy_loss.item(), value_loss.item(), steps)
        return (returns.mean(), returns.std(), steps)

    def record_best_effort(self):
        """Run one greedy (argmax) episode and record it to ``tests/``.

        The episode is truncated after 10000 steps. The environment is closed
        even if the rollout raises.

        Returns:
            ``(total_reward, steps)``.
        """
        env = gym.make('CartPole-v1', render_mode='rgb_array', max_episode_steps=10000)
        env = gym.wrappers.RecordVideo(env, "tests")

        try:
            observation, _ = env.reset()
            env.start_video_recorder()

            total_reward = 0
            done = truncated = False
            i = 0

            while not done and not truncated:
                state = torch.tensor(observation, dtype=TYPE, device=DEVICE, requires_grad=False)
                if self.quantized:
                    state = state.unsqueeze(0)

                action = self.policy_net.deterministic_action(state)
                observation, reward, done, truncated, _ = env.step(action)
                total_reward += reward
                i += 1
        finally:
            env.close()
        return total_reward, i

In [19]:
env = CartPoleQuantized(dtype=TYPE, device=DEVICE)
session = PPOSession(env=env)

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4




In [20]:
session.run(episodes=600)

  input = module(input)


Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4




Episode 5 - Returns: 10.733765602111816 - Std: 5.823357582092285 - Steps: 22
Episode 10 - Returns: 14.101123809814453 - Std: 7.5943074226379395 - Steps: 30
Episode 15 - Returns: 10.300280570983887 - Std: 5.5916337966918945 - Steps: 21
Episode 20 - Returns: 12.439506530761719 - Std: 6.7269415855407715 - Steps: 26
Episode 25 - Returns: 8.089495658874512 - Std: 4.396872043609619 - Steps: 16
Episode 30 - Returns: 5.804428577423096 - Std: 3.1395034790039062 - Steps: 11
Episode 35 - Returns: 6.267518997192383 - Std: 3.39615535736084 - Steps: 12
Episode 40 - Returns: 6.267518997192383 - Std: 3.39615535736084 - Steps: 12
Episode 45 - Returns: 13.275748252868652 - Std: 7.165063858032227 - Steps: 28
Episode 50 - Returns: 6.267518997192383 - Std: 3.39615535736084 - Steps: 12
Episode 55 - Returns: 5.804428577423096 - Std: 3.1395034790039062 - Steps: 11
Episode 60 - Returns: 17.68506622314453 - Std: 9.420500755310059 - Steps: 39
Episode 65 - Returns: 7.184539794921875 - Std: 3.901630401611328 - Ste

                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-100.mp4
Episode 105 - Returns: 11.164410591125488 - Std: 6.052726745605469 - Steps: 23
Episode 110 - Returns: 30.117435455322266 - Std: 15.235461235046387 - Steps: 75
Episode 115 - Returns: 14.509784698486328 - Std: 7.805652618408203 - Steps: 31
Episode 120 - Returns: 9.863934516906738 - Std: 5.357532978057861 - Steps: 20
Episode 125 - Returns: 21.4298095703125 - Std: 11.260408401489258 - Steps: 49
Episode 130 - Returns: 13.689783096313477 - Std: 7.38078498840332 - Steps: 29
Episode 135 - Returns: 46.393192291259766 - Std: 21.356096267700195 - Steps: 139
Episode 140 - Returns: 14.101123809814453 - Std: 7.5943074226379395 - Steps: 30
Episode 145 - Returns: 8.537507057189941 - Std: 4.640725612640381 - Steps: 17
Episode 150 - Returns: 47.21794509887695 - Std: 21.61115264892578 - Steps: 143
Episode 155 - Returns: 36.17048263549805 - Std: 17.734111785888672 - Steps: 96
Episode 16

                                                             

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-200.mp4




Episode 205 - Returns: 50.145572662353516 - Std: 22.466360092163086 - Steps: 158
Episode 210 - Returns: 35.07728576660156 - Std: 17.300790786743164 - Steps: 92
Episode 215 - Returns: 23.213842391967773 - Std: 12.11121654510498 - Steps: 54
Episode 220 - Returns: 19.96080780029297 - Std: 10.547213554382324 - Steps: 45
Episode 225 - Returns: 33.95683670043945 - Std: 16.848249435424805 - Steps: 88
Episode 230 - Returns: 32.223331451416016 - Std: 16.131746292114258 - Steps: 82
Episode 235 - Returns: 52.659725189208984 - Std: 23.13473892211914 - Steps: 172
Episode 240 - Returns: 40.040138244628906 - Std: 19.19992446899414 - Steps: 111
Episode 245 - Returns: 39.79297637939453 - Std: 19.109582901000977 - Steps: 110
Episode 250 - Returns: 40.530006408691406 - Std: 19.37761688232422 - Steps: 113
Episode 255 - Returns: 41.25383377075195 - Std: 19.636817932128906 - Steps: 116
Episode 260 - Returns: 27.594907760620117 - Std: 14.125996589660645 - Steps: 67
Episode 265 - Returns: 68.26155853271484 - 

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-300.mp4
Episode 305 - Returns: 71.98341369628906 - Std: 25.60456657409668 - Steps: 342
Episode 310 - Returns: 57.86044692993164 - Std: 24.303028106689453 - Steps: 205
Episode 315 - Returns: 60.96955490112305 - Std: 24.847301483154297 - Steps: 228
Episode 320 - Returns: 71.91067504882812 - Std: 25.60679054260254 - Steps: 341
Episode 325 - Returns: 47.62316131591797 - Std: 21.73423957824707 - Steps: 145
Episode 330 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 335 - Returns: 44.24451446533203 - Std: 20.663883209228516 - Steps: 129
Episode 340 - Returns: 66.269775390625 - Std: 25.4649715423584 - Steps: 275
Episode 345 - Returns: 59.25944137573242 - Std: 24.56324005126953 - Steps: 215
Episode 350 - Returns: 74.88227081298828 - Std: 25.419845581054688 - Steps: 386
Episode 355 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 3

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-400.mp4
Episode 405 - Returns: 71.83761596679688 - Std: 25.60890769958496 - Steps: 340
Episode 410 - Returns: 73.8925552368164 - Std: 25.504718780517578 - Steps: 370
Episode 415 - Returns: 80.21537017822266 - Std: 24.509235382080078 - Steps: 497
Episode 420 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 425 - Returns: 71.91067504882812 - Std: 25.60679054260254 - Steps: 341
Episode 430 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 435 - Returns: 68.79418182373047 - Std: 25.601484298706055 - Steps: 302
Episode 440 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 445 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 450 - Returns: 65.03495788574219 - Std: 25.359161376953125 - Steps: 263
Episode 455 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-500.mp4
Episode 505 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 510 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 515 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 520 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 525 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 530 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 535 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 540 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 545 - Returns: 63.26738357543945 - Std: 25.166269302368164 - Steps: 247
Episode 550 - Returns: 71.54199981689453 - Std: 25.616334915161133 - Steps: 336
Episode 555 - Returns: 80.3301010131836 - Std: 24.480113983154297 - Steps: 500
Episode 

In [180]:
# Greedy (argmax) rollout of the trained policy, recorded under tests/;
# the cell displays (total_reward, steps).
session.record_best_effort()

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4
Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4


KeyboardInterrupt: 

In [21]:
# Write the run's hyper-parameters (info.csv) and per-episode metrics (data.csv)
# into a new timestamped directory under logs/.
logger.save()

# Loading and saving the new checkpoints!

In [250]:
# save weights
# The precision label comes from TYPE (e.g. "float-32" for torch.float32), so
# the file name always matches the dtype the networks were trained in.
_checkpoint_prefix = f"float-{torch.finfo(TYPE).bits}-good"
torch.save(session.policy_net.state_dict(), f"{_checkpoint_prefix}-policy.pt")
torch.save(session.value_net.state_dict(), f"{_checkpoint_prefix}-value.pt")

In [113]:
# Load weights
env = CartPoleQuantized(TYPE, DEVICE)
session = PPOSession(env)
session.policy_net.load_state_dict(torch.load("checkpoints/float-32-good-policy.pt"))
session.value_net.load_state_dict(torch.load("checkpoints/float-32-good-value.pt"))

session.policy_net.eval()
session.value_net.eval()

torch.backends.quantized.engine = "qnnpack"

quantize_dynamic(
    session.policy_net,
    {torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True
)

session.quantized = True
session.run(500)

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


  return self._call_impl(*args, **kwargs)


Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
#session.record_best_effort()

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4




Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                                   

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4


(15321.0, 15321)