# PPO Base Implementation
This will be the baseline implementation for comparing with the other methods.

In [9]:
import random
import pandas as pd
import time
from datetime import datetime
import os

import gym
import numpy as np

import torch
from torch.nn import LeakyReLU, ReLU, Linear, MSELoss, Sequential, Softmax, Dropout
from torch.optim import Adam
from torch.quantization import quantize_dynamic
from torch.nn.utils import clip_grad_norm_

import logging
logging.basicConfig(level=logging.INFO)

In [10]:
SEED = 1234
LEARNING_RATE = 1e-4
GAMMA = 0.99
EPOCHS = 1
CLIP_EPSILON = 0.2
BATCH_SIZE = 1

DEVICE = torch.device("cpu")
TYPE = torch.float32

In [11]:
torch.set_default_dtype(TYPE)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# torch.backends.cpu.deterministic = True
torch.backends.cudnn.deterministic = True

class Logger:
    def __init__(self, extras: dict):
        self.data = []
        self.base_time = time.time()
        self.extras = extras
        
    def add(self, p_loss, v_loss, steps):
        if len(self.data) == 0:
            self.base_time = time.time()
        
        time_elapsed = round(time.time() - self.base_time, 5)
        self.data.append((time_elapsed, p_loss, v_loss, steps))
    
    def df(self):
        return pd.DataFrame(self.data, columns=["time", "p_loss", "v_loss", "steps"])
    
    def save(self):
        path = "logs/" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        os.mkdir(path)
        pd.DataFrame([self.extras]).to_csv(path + "/info.csv")
        pd.DataFrame(
            self.data, 
            columns=["time", "p_loss", "v_loss", "steps"]
        ).to_csv(path + "/data.csv")

logger = Logger({
    "seed": SEED,
    "learning_rate": LEARNING_RATE,
    "gamma": GAMMA,
    "epochs": EPOCHS,
    "clip_epsilon": CLIP_EPSILON,
    "batch_size": BATCH_SIZE,
    "device": DEVICE,
    "type": TYPE
})

In [12]:
class CartPoleQuantized:
  def __init__(self, dtype: torch.dtype, device: torch.device):
    self.dtype = dtype
    self.device = device
    
    env = gym.make('CartPole-v1', render_mode='rgb_array')
    env = gym.wrappers.RecordVideo(env, f"training/{dtype}/", episode_trigger=lambda x: x % 100 == 0 and x >= 30)
    env.reset()
    env.start_video_recorder()
    
    self.env = env
  
  def reset(self):
    return torch.tensor(self.env.reset()[0], dtype=self.dtype, device=self.device)
  
  def step(self, action):
    state, reward, done, truncated, _ = self.env.step(action)
    return torch.tensor(state, dtype=self.dtype, device=self.device), reward, done or truncated
  
  def close(self):
    self.env.close()
  
  @staticmethod
  def env():
    return gym.make('CartPole-v1', render_mode='rgb_array')

## Network Architecture

**PolicyNetwork**:
- Input: State
- Output: Action distribution (0-1)
- 2 Hidden layers with LeakyReLU activation

**ValueNetwork**:
- Input: State
- Output: Value
- 2 Hidden layers with LeakyReLU activation

In [13]:

class PolicyNetwork(torch.nn.Module):
  def __init__(self, input_dim, hidden_dim):
    super().__init__()
    self.model = Sequential(
      Linear(input_dim, hidden_dim),
      ReLU(),
      Linear(hidden_dim, hidden_dim),
      ReLU(),
      Linear(hidden_dim, 2),
      Softmax()
    )

  def forward(self, state): 
    return self.model(state)
 
  def stochastic_action(self, state):
    r"""Returns an action sampled from the policy network."""
    
    probs = self.forward(state).detach()
    # adding floating point error to the maximum probability
    probs[torch.argmax(probs)] += 1 - probs.sum()
    
    probs.squeeze() # quantized tensors have an extra dimension
    
    m = torch.distributions.Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)
  
  def deterministic_action(self, state):
    r"""Returns an action with the highest probability."""
    
    probs = self.forward(state).detach()
    action = torch.argmax(probs)
    return action.item()

  
class ValueNetwork(torch.nn.Module):
  def __init__(self, input_dim, hidden_dim) -> None:
    super().__init__()
    self.model = Sequential(
      Linear(input_dim, hidden_dim),
      ReLU(),
      Linear(hidden_dim, hidden_dim),
      ReLU(),
      Linear(hidden_dim, 1)
    )
  
  def forward(self, state):
    return self.model(state)
  

# Training
- 64 hidden nodes
- Adam optimizer
- MSE loss for value network

In [14]:
class PPOSession:
    def __init__(self, env: CartPoleQuantized):
        self.env = env
        self.episode = 0
        
        _observation_size = CartPoleQuantized.env().observation_space.shape[0]

        self.policy_net = PolicyNetwork(_observation_size, 64).to(DEVICE)
        self.value_net  = ValueNetwork(_observation_size, 64).to(DEVICE)

        self.policy_optimizer = Adam(self.policy_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        self.value_optimizer  = Adam(self.value_net.parameters(), lr=LEARNING_RATE, eps=1e-7)
        self.quantized = False
    
    @staticmethod
    def compute_returns(rewards):
        returns = torch.zeros(len(rewards))
        R = 0
        for i in reversed(range(len(rewards))):
            R = rewards[i] + GAMMA * R
            returns[i] = R
        return returns
    
    def run(self, episodes):
        self.policy_net.train()
        self.value_net.train()
        
        for i in range(episodes):
            self.episode += 1
            returns, std, steps = self.ppo_step()
            
            if self.episode % 5 == 0:
                print(f"Episode {self.episode} - Returns: {returns} - Std: {std} - Steps: {steps}")

        self.policy_net.eval()
        self.value_net.eval()
    
    def ppo_step(self):
        state = self.env.reset()
        
        # capture entire episode
        done, steps = False, 0
        states, actions, log_probs_old, rewards = [], [], [], []
        
        while not done:
            if self.quantized:
                state = state.unsqueeze(0)
            action, log_prob = self.policy_net.stochastic_action(state)
            next_state, reward, done = self.env.step(action)

            log_probs_old.append(log_prob)
            states.append(state.squeeze())
            actions.append(action)
            rewards.append(reward)

            state = next_state
            steps += 1
        
        # Convert to tensors
        # Be sure to detach() the tensors from the graph as these are "constants"
        states = torch.stack(states).detach().to(DEVICE)
        actions = torch.tensor(actions).detach().to(DEVICE)
        log_probs_old = torch.stack(log_probs_old).detach().to(DEVICE)
        
        returns = self.compute_returns(rewards).detach().to(DEVICE)
        values = self.value_net(states).detach().to(DEVICE)
        advantages = returns - values.squeeze()

        for _ in range(EPOCHS):
            for i in range(0, len(states), BATCH_SIZE):
                # Grab a batch of data
                batch_states = states[i:i+BATCH_SIZE]
                batch_actions = actions[i:i+BATCH_SIZE]
                batch_log_probs_old = log_probs_old[i:i+BATCH_SIZE]
                batch_advantages = advantages[i:i+BATCH_SIZE]
                batch_returns = returns[i:i+BATCH_SIZE]

                # Calculate new log probabilities
                new_action_probs = self.policy_net(batch_states)
                new_log_probs = torch.log(new_action_probs.gather(1, batch_actions.unsqueeze(-1)))

                # rho is the ratio between new and old log probabilities
                ratio = (new_log_probs - batch_log_probs_old).exp()

                # Calculate surrogate loss
                surrogate_loss = ratio * batch_advantages
                clipped_surrogate_loss = torch.clamp(ratio, 1-CLIP_EPSILON, 1+CLIP_EPSILON) * batch_advantages
                policy_loss = -torch.min(surrogate_loss, clipped_surrogate_loss).mean()
                
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

                # check for nan
                if torch.isnan(policy_loss):
                    print("NaN detected in policy loss")
                    return

                value_loss = torch.pow(self.value_net(
                    batch_states) - batch_returns.unsqueeze(-1), 2).mean()

                self.value_optimizer.zero_grad()
                value_loss.backward()
                self.value_optimizer.step()

                # check for nan
                if torch.isnan(value_loss):
                    print("NaN detected in value loss")
                    return
        
        logger.add(policy_loss.item(), value_loss.item(), steps)
        return (returns.mean(), returns.std(), steps)
    
    def record_best_effort(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array', max_episode_steps=10000)
        env = gym.wrappers.RecordVideo(env, "tests")

        state, _ = env.reset()
        state = torch.tensor(state, dtype=TYPE, device=DEVICE, requires_grad=False)
        env.start_video_recorder()

        total_reward = 0
        done, i = False, 0
        
        while not done and not truncated:
            if self.quantized:
                state = state.unsqueeze(0)
            
            action = self.policy_net.deterministic_action(state)
            state, reward, done, truncated, _ = env.step(action)
            state = torch.tensor(state, dtype=TYPE, device=DEVICE, requires_grad=False).unsqueeze(0)
            total_reward += reward
            i += 1

        env.close()
        return total_reward, i

In [15]:
env = CartPoleQuantized(TYPE, DEVICE)
session = PPOSession(env)

  logger.warn(


In [16]:
session.run(600)

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4
Episode 5 - Returns: 9.863934516906738 - Std: 5.357532978057861 - Steps: 20




Episode 10 - Returns: 17.297117233276367 - Std: 9.225837707519531 - Steps: 38
Episode 15 - Returns: 7.638513565063477 - Std: 4.150516033172607 - Steps: 15
Episode 20 - Returns: 7.638513565063477 - Std: 4.150516033172607 - Steps: 15
Episode 25 - Returns: 9.863934516906738 - Std: 5.357532978057861 - Steps: 20
Episode 30 - Returns: 5.804428577423096 - Std: 3.1395034790039062 - Steps: 11
Episode 35 - Returns: 17.297117233276367 - Std: 9.225837707519531 - Steps: 38
Episode 40 - Returns: 11.592232704162598 - Std: 6.279765605926514 - Steps: 24
Episode 45 - Returns: 14.509786605834961 - Std: 7.805652618408203 - Steps: 31
Episode 50 - Returns: 6.727548599243164 - Std: 3.650186777114868 - Steps: 13
Episode 55 - Returns: 13.275749206542969 - Std: 7.165064334869385 - Steps: 28
Episode 60 - Returns: 10.7337646484375 - Std: 5.823357105255127 - Steps: 22
Episode 65 - Returns: 24.941802978515625 - Std: 12.918790817260742 - Steps: 59
Episode 70 - Returns: 6.727548599243164 - Std: 3.650186777114868 - St

                                                   

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-100.mp4




Episode 105 - Returns: 18.83395767211914 - Std: 9.992586135864258 - Steps: 42
Episode 110 - Returns: 25.61774253845215 - Std: 13.230171203613281 - Steps: 61
Episode 115 - Returns: 21.066091537475586 - Std: 11.084874153137207 - Steps: 48
Episode 120 - Returns: 12.858996391296387 - Std: 6.9471235275268555 - Steps: 27
Episode 125 - Returns: 7.184538841247559 - Std: 3.901630163192749 - Steps: 14
Episode 130 - Returns: 8.982568740844727 - Std: 4.88210391998291 - Steps: 18
Episode 135 - Returns: 25.280845642089844 - Std: 13.075297355651855 - Steps: 60
Episode 140 - Returns: 29.49863624572754 - Std: 14.966854095458984 - Steps: 73
Episode 145 - Returns: 24.60059928894043 - Std: 12.760636329650879 - Steps: 58
Episode 150 - Returns: 19.211999893188477 - Std: 10.179376602172852 - Steps: 43
Episode 155 - Returns: 29.18631362915039 - Std: 14.83039665222168 - Steps: 72
Episode 160 - Returns: 21.4298095703125 - Std: 11.260408401489258 - Steps: 49
Episode 165 - Returns: 26.28515625 - Std: 13.535084724

                                                             

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-200.mp4




Episode 205 - Returns: 41.492225646972656 - Std: 19.721296310424805 - Steps: 117
Episode 210 - Returns: 39.04246520996094 - Std: 18.83245277404785 - Steps: 107
Episode 215 - Returns: 51.60670471191406 - Std: 22.862524032592773 - Steps: 166
Episode 220 - Returns: 48.023681640625 - Std: 21.85443878173828 - Steps: 147
Episode 225 - Returns: 41.729190826416016 - Std: 19.80483055114746 - Steps: 118
Episode 230 - Returns: 56.38336181640625 - Std: 24.00244903564453 - Steps: 195
Episode 235 - Returns: 40.530006408691406 - Std: 19.37761688232422 - Steps: 113
Episode 240 - Returns: 62.802345275878906 - Std: 25.107769012451172 - Steps: 243
Episode 245 - Returns: 38.53449249267578 - Std: 18.642515182495117 - Steps: 105
Episode 250 - Returns: 48.023681640625 - Std: 21.85443878173828 - Steps: 147
Episode 255 - Returns: 54.822811126708984 - Std: 23.657377243041992 - Steps: 185
Episode 260 - Returns: 55.61387252807617 - Std: 23.83574867248535 - Steps: 190
Episode 265 - Returns: 63.15202331542969 - Std

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-300.mp4
Episode 305 - Returns: 53.51039505004883 - Std: 23.34624481201172 - Steps: 177
Episode 310 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 315 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 320 - Returns: 74.33402252197266 - Std: 25.469770431518555 - Steps: 377
Episode 325 - Returns: 74.94189453125 - Std: 25.41397476196289 - Steps: 387
Episode 330 - Returns: 75.63785552978516 - Std: 25.338838577270508 - Steps: 399
Episode 335 - Returns: 79.45845031738281 - Std: 24.69032859802246 - Steps: 478
Episode 340 - Returns: 73.6341552734375 - Std: 25.52305030822754 - Steps: 366
Episode 345 - Returns: 69.89702606201172 - Std: 25.625102996826172 - Steps: 315
Episode 350 - Returns: 63.38215255737305 - Std: 25.180221557617188 - Steps: 248
Episode 355 - Returns: 45.33467102050781 - Std: 21.02001953125 - Steps: 134
Episode 360

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-400.mp4
Episode 405 - Returns: 76.61822509765625 - Std: 25.21173858642578 - Steps: 417
Episode 410 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 415 - Returns: 60.32577896118164 - Std: 24.74488067626953 - Steps: 223
Episode 420 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 425 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 430 - Returns: 73.95645904541016 - Std: 25.49994659423828 - Steps: 371
Episode 435 - Returns: 65.24596405029297 - Std: 25.37897300720215 - Steps: 265
Episode 440 - Returns: 63.03606414794922 - Std: 25.137563705444336 - Steps: 245
Episode 445 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 450 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 455 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Epi

                                                               

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-500.mp4
Episode 505 - Returns: 58.146331787109375 - Std: 24.358173370361328 - Steps: 207
Episode 510 - Returns: 67.80571746826172 - Std: 25.561294555664062 - Steps: 291
Episode 515 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 520 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 525 - Returns: 78.86508178710938 - Std: 24.81931495666504 - Steps: 464
Episode 530 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 535 - Returns: 68.17127227783203 - Std: 25.57819938659668 - Steps: 295
Episode 540 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 545 - Returns: 56.07810974121094 - Std: 23.9371337890625 - Steps: 193
Episode 550 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Episode 555 - Returns: 80.33009338378906 - Std: 24.480113983154297 - Steps: 500
Ep

In [18]:
logger.save()

In [19]:
session.record_best_effort()

Moviepy - Building video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /home/ubuntu/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4




UnboundLocalError: local variable 'truncated' referenced before assignment

# Loading and saving the new checkpoints!

In [250]:
# save weights
torch.save(session.policy_net.state_dict(), "float-16-good-policy.pt")
torch.save(session.value_net.state_dict(), "float-16-good-value.pt")

In [113]:
# Load weights
env = CartPoleQuantized(TYPE, DEVICE)
session = PPOSession(env)
session.policy_net.load_state_dict(torch.load("checkpoints/float-32-good-policy.pt"))
session.value_net.load_state_dict(torch.load("checkpoints/float-32-good-value.pt"))

session.policy_net.eval()
session.value_net.eval()

torch.backends.quantized.engine = "qnnpack"

quantize_dynamic(
    session.policy_net,
    {torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True
)

session.quantized = True
session.run(500)

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


  return self._call_impl(*args, **kwargs)


Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/training/torch.float32/rl-video-episode-0.mp4


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
#session.record_best_effort()

Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                  

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4




Moviepy - Building video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4.
Moviepy - Writing video /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4



                                                                   

Moviepy - Done !
Moviepy - video ready /Users/b0kch01/Documents/Code/QuantizeRL/cart_pole/tests/rl-video-episode-0.mp4


(15321.0, 15321)