# PPO Base Implementation
This notebook is the baseline PPO implementation against which the other methods are compared.

In [170]:
import random
import pandas as pd
import time
from datetime import datetime
import os

import gym
import numpy as np

import torch
from torch.nn import LeakyReLU, ReLU, Linear, MSELoss, Sequential, Softmax, Dropout
from torch.optim import Adam
from torch.quantization import quantize_dynamic
from torch.nn.utils import clip_grad_norm_

import logging
logging.basicConfig(level=logging.INFO)

In [171]:
# Hyperparameters shared by every cell in this notebook.
SEED = 1234            # seed passed to the random, numpy and torch generators
LEARNING_RATE = 1e-4   # Adam learning rate for both the policy and value networks
GAMMA = 0.99           # discount factor used when computing returns
EPOCHS = 1             # passes over one episode's data per PPO update
CLIP_EPSILON = 0.2     # PPO ratio clipping range (1 - eps, 1 + eps)
BATCH_SIZE = 1         # number of transitions per optimizer step

# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA instead of failing on the first `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TYPE = torch.float32

In [172]:
# Seed every random number generator used by the notebook with the same value
# so that repeated runs start from the same state.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# All newly created floating point tensors use the configured precision.
torch.set_default_dtype(TYPE)

# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

class Logger:
    """Collect per-episode training metrics and write them to CSV files.

    Each call to :meth:`add` stores one row ``(time, p_loss, v_loss, steps)``
    where ``time`` is the number of seconds since the first row was added.
    :meth:`save` writes the rows plus the ``extras`` dictionary given at
    construction into a fresh timestamped directory under ``logs/``.
    """

    # Column names shared by df() and save().
    COLUMNS = ["time", "p_loss", "v_loss", "steps"]

    def __init__(self, extras: dict):
        """Create an empty logger.

        Args:
            extras: run metadata (e.g. hyperparameters) written to
                ``info.csv`` as a single row by :meth:`save`.
        """
        self.data = []
        self.base_time = time.time()
        self.extras = extras

    def add(self, p_loss, v_loss, steps):
        """Append one metrics row, timestamped relative to the first row."""
        if not self.data:
            # Restart the clock on the first row so that elapsed time measures
            # training, not the gap between construction and the first update.
            self.base_time = time.time()

        time_elapsed = round(time.time() - self.base_time, 5)
        self.data.append((time_elapsed, p_loss, v_loss, steps))

    def df(self):
        """Return the collected rows as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.data, columns=self.COLUMNS)

    def save(self):
        """Write ``info.csv`` and ``data.csv`` to ``logs/<timestamp>/``.

        The directory (including the parent ``logs`` folder) is created when
        it does not exist yet.
        """
        path = os.path.join("logs", datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        # makedirs also creates "logs/" itself; plain mkdir raised
        # FileNotFoundError on a fresh checkout.
        os.makedirs(path, exist_ok=True)
        pd.DataFrame([self.extras]).to_csv(os.path.join(path, "info.csv"))
        self.df().to_csv(os.path.join(path, "data.csv"))

# Module-level logger; the hyperparameters of this run are stored alongside
# the metrics so a saved log is self-describing.
logger = Logger(dict(
    seed=SEED,
    learning_rate=LEARNING_RATE,
    gamma=GAMMA,
    epochs=EPOCHS,
    clip_epsilon=CLIP_EPSILON,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    type=TYPE,
))

In [173]:
class CartPoleQuantized:
    """CartPole-v1 wrapper that returns observations as torch tensors.

    Observations coming out of :meth:`reset` and :meth:`step` are converted to
    tensors of the ``dtype`` and ``device`` given at construction. The wrapped
    environment records videos into ``training/<dtype>/``.
    """

    def __init__(self, dtype: torch.dtype, device: torch.device):
        """Create the video-recording environment and start the recorder."""
        self.dtype = dtype
        self.device = device

        def _should_record(episode_id):
            # Every 100th episode, ignoring anything before episode 30.
            return episode_id % 100 == 0 and episode_id >= 30

        recorded_env = gym.wrappers.RecordVideo(
            gym.make('CartPole-v1', render_mode='rgb_array'),
            f"training/{dtype}/",
            episode_trigger=_should_record,
        )
        recorded_env.reset()
        recorded_env.start_video_recorder()

        # NOTE: this instance attribute shadows the static ``env`` factory
        # below when accessed through an instance.
        self.env = recorded_env

    def reset(self):
        """Reset the environment and return the initial observation tensor."""
        observation = self.env.reset()[0]
        return torch.tensor(observation, dtype=self.dtype, device=self.device)

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, finished)``.

        ``finished`` combines the environment's ``done`` and ``truncated``
        flags with ``or``.
        """
        observation, reward, terminated, truncated, _info = self.env.step(action)
        observation = torch.tensor(observation, dtype=self.dtype, device=self.device)
        return observation, reward, terminated or truncated

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    @staticmethod
    def env():
        """Return a new, unwrapped CartPole-v1 environment."""
        return gym.make('CartPole-v1', render_mode='rgb_array')

## Network Architecture

**PolicyNetwork**:
- Input: State
- Output: Action distribution (softmax probabilities over the 2 actions, each in 0-1)
- 2 hidden layers with ReLU activation

**ValueNetwork**:
- Input: State
- Output: Value (a single scalar estimate)
- 2 hidden layers with ReLU activation

In [174]:

class PolicyNetwork(torch.nn.Module):
    """Policy network mapping a state to probabilities over the 2 actions.

    Architecture: two hidden ``Linear`` + ``ReLU`` layers followed by a
    ``Linear`` layer with 2 outputs and a softmax over the last axis. The
    layers live in ``self.model`` (a ``Sequential``), so state-dict keys are
    ``model.0.*``, ``model.2.*`` and ``model.4.*``.
    """

    def __init__(self, input_dim, hidden_dim):
        """Build the network.

        Args:
            input_dim: size of the state vector.
            hidden_dim: width of both hidden layers.
        """
        super().__init__()
        self.model = Sequential(
            Linear(input_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, 2),
            # Normalise over the action axis explicitly; the implicit-dim form
            # is deprecated and picks axis 0 for 3-D input.
            Softmax(dim=-1),
        )

    def forward(self, state):
        """Return action probabilities for ``state`` (last axis sums to 1)."""
        return self.model(state)

    def stochastic_action(self, state):
        r"""Returns an action sampled from the policy network.

        ``state`` must describe a single state, either as a 1-D vector or with
        extra singleton dimensions (e.g. shape ``(1, input_dim)``).

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is a 0-dim tensor without gradient history.
        """
        with torch.no_grad():
            # Drop singleton batch dimensions so the probabilities are a flat
            # vector; the original discarded the result of squeeze().
            probs = self.forward(state).squeeze()

        # Add the floating point rounding error to the largest probability so
        # that the vector sums to 1.
        probs[torch.argmax(probs)] += 1 - probs.sum()

        distribution = torch.distributions.Categorical(probs)
        action = distribution.sample()
        return action.item(), distribution.log_prob(action)

    def deterministic_action(self, state):
        r"""Returns an action with the highest probability."""
        with torch.no_grad():
            probs = self.forward(state)
        # argmax without ``dim`` works on the flattened tensor, so singleton
        # batch dimensions do not matter here.
        return torch.argmax(probs).item()

  
class ValueNetwork(torch.nn.Module):
    """State-value estimator with two hidden ReLU layers and one output unit."""

    def __init__(self, input_dim, hidden_dim) -> None:
        """Build the network.

        Args:
            input_dim: size of the state vector.
            hidden_dim: width of both hidden layers.
        """
        super().__init__()
        # Layers are created in the same order as they are applied so that the
        # Sequential indices (and therefore the state-dict keys) stay stable.
        stack = [
            Linear(input_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, 1),
        ]
        self.model = Sequential(*stack)

    def forward(self, state):
        """Return the value estimate for ``state`` (last axis has size 1)."""
        estimate = self.model(state)
        return estimate
  

# Training
- 64 hidden nodes per hidden layer
- Adam optimizer
- MSE loss for value network

In [179]:
class PPOSession:
    """Trains a policy/value network pair on CartPole with clipped PPO.

    One call to :meth:`ppo_step` plays a full episode with the current policy
    and then runs ``EPOCHS`` passes of minibatch updates (``BATCH_SIZE``
    transitions each) over that episode. The module-level ``logger`` receives
    the last minibatch's losses and the episode length.

    Attributes set here:
        env: the ``CartPoleQuantized`` wrapper used for training episodes.
        episode: number of training episodes started so far.
        policy_net / value_net: the two networks, moved to ``DEVICE``.
        quantized: when True, states get a leading batch dimension before
            being passed to the policy network. It is False by default and is
            meant to be switched on by the caller after quantizing the policy.
    """

    def __init__(self, env: "CartPoleQuantized"):
        self.env = env
        self.episode = 0

        # Read the observation size from the environment we were given instead
        # of creating (and never closing) a second environment just for this.
        _observation_size = env.env.observation_space.shape[0]

        self.policy_net = PolicyNetwork(_observation_size, 64).to(DEVICE)
        self.value_net = ValueNetwork(_observation_size, 64).to(DEVICE)

        self.policy_optimizer = Adam(self.policy_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        self.value_optimizer = Adam(self.value_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        self.quantized = False

    @staticmethod
    def compute_returns(rewards, gamma=None):
        """Return the discounted return for every step of an episode.

        Args:
            rewards: sequence of per-step rewards, in time order.
            gamma: discount factor; defaults to the module-level ``GAMMA``.

        Returns:
            1-D tensor ``G`` with ``G[t] = rewards[t] + gamma * G[t + 1]``
            (and ``G`` past the end taken as 0). Empty input gives an empty
            tensor.
        """
        if gamma is None:
            gamma = GAMMA

        returns = torch.zeros(len(rewards))
        running = 0
        # Walk backwards so each return reuses the one that follows it.
        for index in reversed(range(len(rewards))):
            running = rewards[index] + gamma * running
            returns[index] = running
        return returns

    def run(self, episodes):
        """Train for up to ``episodes`` episodes.

        Progress is printed every 5th episode. Training stops early when
        :meth:`ppo_step` reports a NaN loss. Both networks are switched back
        to eval mode afterwards, even if training is interrupted.
        """
        self.policy_net.train()
        self.value_net.train()

        try:
            for _ in range(episodes):
                self.episode += 1
                result = self.ppo_step()
                if result is None:
                    # ppo_step already printed which loss became NaN.
                    break

                returns, std, steps = result
                if self.episode % 5 == 0:
                    print(f"Episode {self.episode} - Returns: {returns} - Std: {std} - Steps: {steps}")
        finally:
            self.policy_net.eval()
            self.value_net.eval()

    def ppo_step(self):
        """Play one episode and update both networks on it.

        Returns:
            ``(mean_return, std_return, steps)`` computed over the episode's
            discounted returns, or ``None`` when a loss became NaN (in which
            case the NaN loss is not applied to the weights).
        """
        state = self.env.reset()

        # capture entire episode
        done, steps = False, 0
        states, actions, log_probs_old, rewards = [], [], [], []

        while not done:
            if self.quantized:
                state = state.unsqueeze(0)
            action, log_prob = self.policy_net.stochastic_action(state)
            next_state, reward, done = self.env.step(action)

            # Store the log-probability as a scalar regardless of whether the
            # policy returned it with a singleton dimension.
            log_probs_old.append(log_prob.reshape(()))
            states.append(state.squeeze())
            actions.append(action)
            rewards.append(reward)

            state = next_state
            steps += 1

        # Convert to tensors
        # Be sure to detach() the tensors from the graph as these are "constants"
        states = torch.stack(states).detach().to(DEVICE)
        actions = torch.tensor(actions).detach().to(DEVICE)
        log_probs_old = torch.stack(log_probs_old).detach().to(DEVICE)

        returns = self.compute_returns(rewards).detach().to(DEVICE)
        values = self.value_net(states).detach().to(DEVICE)
        # squeeze(-1) keeps a 1-D result even for a single-step episode.
        advantages = returns - values.squeeze(-1)

        for _ in range(EPOCHS):
            for start in range(0, len(states), BATCH_SIZE):
                # Grab a batch of data
                stop = start + BATCH_SIZE
                batch_states = states[start:stop]
                batch_actions = actions[start:stop]
                batch_log_probs_old = log_probs_old[start:stop]
                batch_advantages = advantages[start:stop]
                batch_returns = returns[start:stop]

                # Calculate new log probabilities. The gathered column is
                # squeezed back to shape (B,) so it lines up element-wise with
                # the old log-probs instead of broadcasting to (B, B).
                new_action_probs = self.policy_net(batch_states)
                new_log_probs = torch.log(
                    new_action_probs.gather(1, batch_actions.unsqueeze(-1))
                ).squeeze(-1)

                # rho is the ratio between new and old action probabilities
                ratio = (new_log_probs - batch_log_probs_old).exp()

                # Calculate surrogate loss
                surrogate_loss = ratio * batch_advantages
                clipped_surrogate_loss = torch.clamp(ratio, 1-CLIP_EPSILON, 1+CLIP_EPSILON) * batch_advantages
                policy_loss = -torch.min(surrogate_loss, clipped_surrogate_loss).mean()

                # check for nan before stepping, so NaN gradients never reach
                # the weights
                if torch.isnan(policy_loss):
                    print("NaN detected in policy loss")
                    return

                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

                value_loss = torch.pow(self.value_net(
                    batch_states) - batch_returns.unsqueeze(-1), 2).mean()

                # check for nan before stepping
                if torch.isnan(value_loss):
                    print("NaN detected in value loss")
                    return

                self.value_optimizer.zero_grad()
                value_loss.backward()
                self.value_optimizer.step()

        logger.add(policy_loss.item(), value_loss.item(), steps)
        return (returns.mean(), returns.std(), steps)

    def record_best_effort(self):
        """Record one greedy episode and return ``(total_reward, steps)``.

        A separate CartPole-v1 environment with ``max_episode_steps=10000`` is
        created and wrapped so the video is written to ``tests``. Actions are
        chosen with ``deterministic_action``. The environment is always
        closed, even if the rollout raises.
        """
        env = gym.make('CartPole-v1', render_mode='rgb_array', max_episode_steps=10000)
        env = gym.wrappers.RecordVideo(env, "tests")

        try:
            state, _ = env.reset()
            state = torch.tensor(state, dtype=TYPE, device=DEVICE, requires_grad=False)
            env.start_video_recorder()

            total_reward = 0
            # Both flags must exist before the loop condition is evaluated.
            done = truncated = False
            i = 0

            while not done and not truncated:
                # Add the batch dimension only for the forward pass so it does
                # not accumulate on ``state`` from one step to the next.
                model_input = state.unsqueeze(0) if self.quantized else state

                action = self.policy_net.deterministic_action(model_input)
                state, reward, done, truncated, _ = env.step(action)
                state = torch.tensor(state, dtype=TYPE, device=DEVICE, requires_grad=False)
                total_reward += reward
                i += 1
        finally:
            env.close()

        return total_reward, i

In [176]:
# Create the tensor-returning CartPole wrapper and a fresh PPO session
# (new, untrained policy and value networks) on top of it.
env = CartPoleQuantized(TYPE, DEVICE)
session = PPOSession(env)

In [177]:
# Train for 2500 episodes; PPOSession.run prints a progress line every 5th episode.
session.run(2500)

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-0.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-0.mp4
Episode 5 - Returns: 8.984375 - Std: 4.8828125 - Steps: 18
Episode 10 - Returns: 4.8671875 - Std: 2.619140625 - Steps: 9
Episode 15 - Returns: 8.0859375 - Std: 4.39453125 - Steps: 16
Episode 20 - Returns: 13.6875 - Std: 7.3828125 - Steps: 29
Episode 25 - Returns: 4.8671875 - Std: 2.619140625 - Steps: 9
Episode 30 - Returns: 22.859375 - Std: 11.9453125 - Steps: 53
Episode 35 - Returns: 8.0859375 - Std: 4.39453125 - Steps: 16
Episode 40 - Returns: 6.26953125 - Std: 3.396484375 - Steps: 12
Episode 45 - Returns: 7.640625 - Std: 4.1484375 - Steps: 15
Episode 50 - Returns: 21.796875 - Std: 11.4375 - Steps: 50
Episode 55 - Returns: 8.984375 - Std: 4.8828125 - Steps: 18
Episode 60 - Returns: 13.6875 - Std: 7.3828125 - Steps: 29
Episode 65 - Returns: 19.96875 - Std: 10.546875 - Steps: 45
Episode 70 - Returns: 19.96875 - Std: 10.546875 - Steps: 45
Episode 75 - Returns: 7.640625 

                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-100.mp4
Episode 105 - Returns: 36.96875 - Std: 18.046875 - Steps: 99
Episode 110 - Returns: 9.421875 - Std: 5.12109375 - Steps: 19
Episode 115 - Returns: 15.71875 - Std: 8.4296875 - Steps: 34
Episode 120 - Returns: 21.796875 - Std: 11.4375 - Steps: 50
Episode 125 - Returns: 16.125 - Std: 8.6328125 - Steps: 35
Episode 130 - Returns: 35.34375 - Std: 17.40625 - Steps: 93
Episode 135 - Returns: 4.3984375 - Std: 2.353515625 - Steps: 8
Episode 140 - Returns: 10.734375 - Std: 5.82421875 - Steps: 22
Episode 145 - Returns: 17.6875 - Std: 9.421875 - Steps: 39
Episode 150 - Returns: 18.453125 - Std: 9.8046875 - Steps: 41
Episode 155 - Returns: 5.8046875 - Std: 3.140625 - Steps: 11
Episode 160 - Returns: 12.4375 - Std: 6.7265625 - Steps: 26
Episode 165 - Returns: 38.28125 - Std: 18.546875 - Steps: 104
Episode 170 - Returns: 32.8125 - Std: 16.375 - Steps: 84
Episode 175 - Returns: 34.25 

                                                             

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-200.mp4
Episode 205 - Returns: 19.59375 - Std: 10.3671875 - Steps: 44
Episode 210 - Returns: 31.03125 - Std: 15.625 - Steps: 78
Episode 215 - Returns: 27.265625 - Std: 13.984375 - Steps: 66
Episode 220 - Returns: 45.125 - Std: 20.953125 - Steps: 133
Episode 225 - Returns: 32.21875 - Std: 16.125 - Steps: 82
Episode 230 - Returns: 24.609375 - Std: 12.7578125 - Steps: 58
Episode 235 - Returns: 11.59375 - Std: 6.28125 - Steps: 24
Episode 240 - Returns: 23.90625 - Std: 12.4375 - Steps: 56
Episode 245 - Returns: 25.953125 - Std: 13.3828125 - Steps: 62
Episode 250 - Returns: 26.28125 - Std: 13.5390625 - Steps: 63
Episode 255 - Returns: 41.0 - Std: 19.546875 - Steps: 115
Episode 260 - Returns: 39.78125 - Std: 19.109375 - Steps: 110
Episode 265 - Returns: 33.6875 - Std: 16.734375 - Steps: 87
Episode 270 - Returns: 21.4375 - Std: 11.2578125 - Steps: 49
Episode 275 - Returns: 30.125 - 

                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-300.mp4
Episode 305 - Returns: 25.625 - Std: 13.234375 - Steps: 61
Episode 310 - Returns: 47.625 - Std: 21.734375 - Steps: 145
Episode 315 - Returns: 29.8125 - Std: 15.1015625 - Steps: 74
Episode 320 - Returns: 31.921875 - Std: 16.0 - Steps: 81
Episode 325 - Returns: 56.375 - Std: 24.0 - Steps: 195
Episode 330 - Returns: 27.921875 - Std: 14.2734375 - Steps: 68
Episode 335 - Returns: 47.40625 - Std: 21.671875 - Steps: 144
Episode 340 - Returns: 35.90625 - Std: 17.625 - Steps: 95
Episode 345 - Returns: 65.25 - Std: 25.375 - Steps: 265
Episode 350 - Returns: 31.03125 - Std: 15.625 - Steps: 78
Episode 355 - Returns: 50.53125 - Std: 22.5625 - Steps: 160
Episode 360 - Returns: 33.96875 - Std: 16.84375 - Steps: 88
Episode 365 - Returns: 40.53125 - Std: 19.375 - Steps: 113
Episode 370 - Returns: 41.25 - Std: 19.640625 - Steps: 116
Episode 375 - Returns: 55.9375 - Std: 23.90625 - Ste

                                                              

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-400.mp4
Episode 405 - Returns: 60.46875 - Std: 24.765625 - Steps: 224
Episode 410 - Returns: 23.5625 - Std: 12.2734375 - Steps: 55
Episode 415 - Returns: 47.625 - Std: 21.734375 - Steps: 145
Episode 420 - Returns: 63.71875 - Std: 25.21875 - Steps: 251
Episode 425 - Returns: 67.0625 - Std: 25.515625 - Steps: 283
Episode 430 - Returns: 61.84375 - Std: 24.96875 - Steps: 235
Episode 435 - Returns: 41.0 - Std: 19.546875 - Steps: 115
Episode 440 - Returns: 57.71875 - Std: 24.28125 - Steps: 204
Episode 445 - Returns: 45.75 - Std: 21.15625 - Steps: 136
Episode 450 - Returns: 62.8125 - Std: 25.109375 - Steps: 243
Episode 455 - Returns: 67.6875 - Std: 25.5625 - Steps: 290
Episode 460 - Returns: 67.875 - Std: 25.5625 - Steps: 292
Episode 465 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 470 - Returns: 51.78125 - Std: 22.90625 - Steps: 167
Episode 475 - Returns: 80.3125 - Std

                                                              

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float16/rl-video-episode-500.mp4
Episode 505 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 510 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 515 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 520 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 525 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 530 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 535 - Returns: 80.3125 - Std: 24.484375 - Steps: 500
Episode 540 - Returns: 11.59375 - Std: 6.28125 - Steps: 24
Episode 545 - Returns: 80.3125 - Std: 24.484375 - Steps: 500


KeyboardInterrupt: 

In [180]:
# Roll out one episode choosing the highest-probability action at every step
# and record it as a video under "tests"; returns (total_reward, steps).
session.record_best_effort()

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4
Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4


KeyboardInterrupt: 

In [181]:
# Write the run's hyperparameters (info.csv) and collected metrics (data.csv)
# to a timestamped directory under logs/.
logger.save()

# Loading and saving the new checkpoints!

In [250]:
# Save the current policy and value network weights (state dicts only, not the
# optimizers) to the working directory.
torch.save(session.policy_net.state_dict(), "float-16-good-policy.pt")
torch.save(session.value_net.state_dict(), "float-16-good-value.pt")

In [113]:
# Load weights into a fresh session.
env = CartPoleQuantized(TYPE, DEVICE)
session = PPOSession(env)

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a CUDA machine also loads on a CPU-only one.
session.policy_net.load_state_dict(
    torch.load("checkpoints/float-32-good-policy.pt", map_location=DEVICE)
)
session.value_net.load_state_dict(
    torch.load("checkpoints/float-32-good-value.pt", map_location=DEVICE)
)

session.policy_net.eval()
session.value_net.eval()

torch.backends.quantized.engine = "qnnpack"

# Replace the policy's Linear layers with dynamically quantized int8 versions,
# modifying the network in place.
quantize_dynamic(
    session.policy_net,
    {torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True
)

# Tell the session to add a batch dimension before calling the policy.
session.quantized = True

# NOTE(review): the output recorded for this cell ends in
# "RuntimeError: element 0 of tensors does not require grad and does not have
# a grad_fn" - presumably the quantized policy has no trainable parameters for
# backward(); confirm whether evaluation rather than training was intended.
session.run(500)

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


  return self._call_impl(*args, **kwargs)


Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
#session.record_best_effort()

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4




Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                                   

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4


(15321.0, 15321)