<a href="https://colab.research.google.com/github/jbpacker/deep-rl-class/blob/main/unit8/ppo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PPO

resources
* [huggingface deep rl class readme](https://github.com/huggingface/deep-rl-class/tree/main/unit8)
* [course example code](https://github.com/huggingface/deep-rl-class/blob/main/unit8/unit8.ipynb)
* [course ppo chapter](https://huggingface.co/blog/deep-rl-ppo)
* [cleanrl ppo](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py)

TODO:
* extra step() error
* revisit advantage calculation since that's where the error was
* check out https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for small improvements

## Setup

### Installs

In [None]:
!apt install python-opengl
!apt install ffmpeg
!apt install xvfb
!pip3 install pyvirtualdisplay

# Virtual display
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(500, 500))
virtual_display.start()

!pip install pybullet
!pip install gym
!pip install stable-baselines3[extra]
!pip install git+https://github.com/ntasfi/PyGame-Learning-Environment.git
!pip install git+https://github.com/qlan3/gym-games.git
!pip install huggingface_hub
!pip install wandb
!pip install imageio-ffmpeg

!pip install pyyaml==6.0 # avoid key error metadata

!pip install pyglet # Virtual Screen

### Imports

In [28]:
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import wandb

import pybullet_envs
import gym
import gym_pygame

from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.

import imageio

### device allocation

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Helper functions



In [47]:
def record_video(env, policy, out_directory="/content/out.mp4", fps=30):
    images = []  
    done = False
    state = env.reset()
    img = env.render(mode='rgb_array')
    images.append(img)
    while not done:
        # Take the action (index) that have the maximum expected future reward given that state
        with torch.no_grad():
            action, _, _, _ = policy.get_action_and_value(torch.Tensor(state).float().unsqueeze(0))
        state, reward, done, info = env.step(action.item()) # We directly put next_state = state for recording logic
        img = env.render(mode='rgb_array')
        images.append(img)
        action.detach()
    imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)
    wandb.log({"videos": wandb.Video(out_directory, fps=fps)})

# env_id = "CartPole-v1"
# env = gym.make(env_id)
# policy = PolicyNetwork(num_obs, num_act)
# record_video(env, policy, "/home/out.gif", fps=30)

## Network

In [40]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class ActorCriticPolicy(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(env.observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(env.observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, env.action_space.n), std=0.01),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)




# class ActorCriticPolicy(nn.Module):
#     def __init__(self, num_obs, num_acts):
#         super(ActorCriticPolicy, self).__init__()

#         self.l1_actor = nn.Linear(num_obs, 64)
#         self.l2_actor = nn.Linear(64, 64)
#         self.l3_actor = nn.Linear(64, num_acts)
#         torch.nn.init.orthogonal_(self.l3_actor.weight, 1.0)

#         self.l1_critic = nn.Linear(num_obs, 64)
#         self.l2_critic = nn.Linear(64, 64)
#         self.l3_critic = nn.Linear(64, 1)
#         torch.nn.init.orthogonal_(self.l3_critic.weight, 0.1)

#     def forward(self, x):
#         x_actor = self.l1_actor(x)
#         x_actor = F.relu(x_actor)
#         x_actor = self.l2_actor(x_actor)
#         x_actor = F.relu(x_actor)
#         action_scores = self.l3_actor(x_actor)
#         # action_probs = F.softmax(action_scores, dim=1)

#         x_critic = self.l1_critic(x)
#         x_critic = F.relu(x_critic)
#         x_critic = self.l2_critic(x_critic)
#         x_critic = F.relu(x_critic)
#         value = self.l3_critic(x_critic)

#         return action_scores, value

#     def get_action_and_value(self, x, action=None):
#         logits, value = self(x)
#         probs = Categorical(logits=logits)
#         if action is None:
#             action = probs.sample()
#         return action, probs.log_prob(action), probs.entropy(), value

## Training

### util classes

In [44]:
class Buffer():
    def __init__(self, env, batch_size, minibatch_size = None, gamma = 0.99, gae_lambda = 0.95):
        self.batch_size = batch_size
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        if minibatch_size is None:
            self.minibatch_size = batch_size
        else:
            self.minibatch_size = minibatch_size

        assert self.batch_size % self.minibatch_size == 0, "batch size must be evenly divisible by minibatch size"

        self.num_states = env.observation_space.shape[0]
        self.num_actions = env.action_space.n
        
        self.reset()
    
    def add(self, state, action, log_prob, reward, done, value):
        self.states[self.add_idx] = state
        self.actions[self.add_idx] = action
        self.log_probs[self.add_idx] = log_prob
        self.values[self.add_idx] = value
        self.rewards[self.add_idx] = reward
        self.dones[self.add_idx] = done

        self.add_idx += 1
        assert self.add_idx <= self.batch_size, "adding too many samples to buffer!"
        assert len(self) <= self.batch_size, "adding too many samples to buffer!"

    def reset(self):
        self.states = torch.zeros((self.batch_size, self.num_states))
        self.actions = torch.zeros(self.batch_size, dtype=int)
        self.log_probs = torch.zeros(self.batch_size)
        self.values = torch.zeros(self.batch_size)
        self.rewards = torch.zeros(self.batch_size)
        self.dones = torch.zeros(self.batch_size)

        self.advantages = torch.zeros(self.batch_size)
        self.returns = torch.zeros(self.batch_size)

        self.add_idx = 0

        self.shuffled_idxs = torch.zeros(self.batch_size, dtype=int)
        self.minibatch_idxs = torch.zeros(self.minibatch_size, dtype=int)

    def __len__(self):
        return len(self.states)

    ## Calculate with advantages
    # Advantage = gamma * V(s+1) + r - V(s)
    # return = gamma * v(s+1) + r
    # bellman V(s) = r + gamma*V(s+1)
    # def calculate_advantages(self, policy, next_state, next_done):
    #     with torch.no_grad():
    #         next_value = policy.get_value(next_state.float().unsqueeze(0))
    #         mask = 1 - next_done
    #         for i in reversed(range(len(self))):
    #             if i < len(self) - 1:
    #                 next_value = self.returns[i+1]
    #                 mask = 1 - self.dones[i+1]
    #             self.returns[i] = self.rewards[i] + mask * self.gamma * next_value

    #         self.advantages = self.returns - self.values

    ## Calculate with GAE
    def calculate_advantages(self, policy, next_state, next_done):
        with torch.no_grad():
            next_value = policy.get_value(next_state.float().unsqueeze(0))
            # for i in reversed(range(len(self))):
            #     if i < len(self) - 1:
            #         next_value = self.values[i + 1]
            #     self.returns[i] = self.rewards[i] + self.gamma * ~self.dones[i] * next_value

            # self.advantages = self.returns - self.values

            lastgaelam = 0
            for t in reversed(range(len(self))):
                if t == len(self) - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - self.dones[t + 1]
                    nextvalues = self.values[t + 1]
                delta = self.rewards[t] + self.gamma * nextvalues * nextnonterminal - self.values[t]
                self.advantages[t] = lastgaelam = delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
            self.returns = self.advantages + self.values

            assert(self.advantages.shape[0] == len(self)), "final adv sizes don't match (batch size: {} adv size {})".format(len(self), self.advantages.shape[0])

    def num_minibatches(self):
        return (int)(len(self) / self.minibatch_size)

    def shuffle_minibatches(self):
        self.minibatch_idx = 0

        
        self.shuffled_idxs = np.arange(len(self))
        np.random.shuffle(self.shuffled_idxs)
        
    def get_minibatch_idxs(self):
        start_idx = self.minibatch_idx * self.minibatch_size
        end_idx = (self.minibatch_idx+1) * self.minibatch_size
        self.minibatch_idxs = self.shuffled_idxs[start_idx:end_idx]
        self.minibatch_idx += 1
        return self.minibatch_idxs

    def print(self):
        for i in range(len(self)):
            print("[{}] s: {} a: {} r: {} d: {}".format(i, self.states[i], self.actions[i], self.rewards[i], self.dones[i]))

    def print_adv(self):
        for i in range(len(self)):
            print("[{}] r: {} d: {} value: {} returns: {} adv: {}".format(
                i, 
                self.rewards[i], 
                self.dones[i], 
                self.values[i], 
                self.returns[i],
                self.advantages[i]))

### RolloutGenerator

Fills the buffer with an exact number of steps with multiple rollouts.

Important:
the buffer will be filled with the following information

```
[state, action, reward, done, value]
```
and the transition will look like the following for specific entries

Normal
```
O (state, value, done)
|
| a, r
v 
O (state+1, value+1, done+1)
|
| a+1,r+1
v 
O
```

Terminal
```
 O (s_f, v_f, d_f=False)
(|)
(|) a_f,r_f
(v)
 O (s_1, v_1, d_1=True)
 |
 | a_1,r_1
 v
 O
 ```
Note that `done=True` is on the first state of the new sequence. This is because any actions and rewards found past the final state won't be correct or defined with respect to the environment.

In [None]:
class RolloutGenerator():
    def __init__(self, env, batch_size, minibatch_size, max_episode_steps, log):
        self.log = log
        self.max_episode_steps = max_episode_steps
        self.buffer = Buffer(env, batch_size, minibatch_size)
        
        self.episode_reward = 0
        self.episode_steps = 1
        self.num_episodes = 1

        self.next_state = torch.Tensor(env.reset())
        self.next_done = torch.Tensor([False])

    def fill_buffer(self, env, policy):
        self.buffer.reset()
        for step in range(0, self.buffer.batch_size):
            #      (state)
            #  (-->)  o
            state = self.next_state
            done = self.next_done

            #      (state, done)  r,a  (next_state, next_done)
            #  (-->)     o ---------------------> o
            with torch.no_grad():
                action, log_prob, _, value = policy.get_action_and_value(state.float().unsqueeze(0))
            next_state, reward, next_done, info = env.step(action.item())
            self.next_state, self.next_done = torch.Tensor(next_state), torch.Tensor([next_done])

            self.buffer.add(state, action, log_prob, torch.Tensor([reward]), done, value)

            self.episode_reward += reward
            self.episode_steps += 1

            # If episode is next_done or past max steps reset the env
            if self.next_done or self.episode_steps > self.max_episode_steps:
                if self.log:
                    wandb.log({
                        "episode_steps": self.episode_steps,
                        "episode_reward": self.episode_reward,
                        "num_episodes": self.num_episodes,
                    })

                self.num_episodes += 1
                self.episode_reward = 0
                self.episode_steps = 1
                
                # (next_state)
                #      o
                self.next_state = torch.Tensor(env.reset())
                
            if done:
                self.next_done = torch.Tensor([False])
        
        self.buffer.calculate_advantages(policy, self.next_state, self.next_done)

    def get_buffer(self):
        return self.buffer

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True


env = gym.make("CartPole-v1")
policy = ActorCriticPolicy(env)
r = RolloutGenerator(env, 40, 10, 100, False)
r.fill_buffer(env, policy)
# r.buffer.print()
    
## data is added correctly
b = r.buffer
for i in range(len(b)):
    print("[{}]: s: {} a: {} r: {} d: {} v: {}".format(
        i, 
        b.states[i], 
        b.actions[i], 
        b.rewards[i], 
        b.dones[i],
        b.values[i]))
    
for i in range(len(b)):
    print("[{}] r: {} d: {} value: {} returns: {} adv: {}".format(
        i, 
        b.rewards[i], 
        b.dones[i], 
        b.values[i], 
        b.returns[i],
        b.advantages[i]))

### Training loop

In [38]:
def train(env_id, log, lr, batch_size, minibatch_size, max_episode_steps, n_epochs, eps = 0.2):
    if log: 
        name = "ppo_" + env_id
        wandb.init(project=name)

    env = gym.make(env_id)
    policy = ActorCriticPolicy(env)

    if log:
        wandb.watch(policy, log_freq=1)

    optimizer = optim.Adam(policy.parameters(), lr=lr, eps=1e-5)

    rollout = RolloutGenerator(env, batch_size, minibatch_size, max_episode_steps, log)

    # each epoch collects N steps regardless of episode length and trains
    for epoch in range(n_epochs):
        # this also calculates advantages
        rollout.fill_buffer(env, policy)

        for updates in range(4):
            rollout.buffer.shuffle_minibatches()

            if log:
                wandb.log({
                    "epoch": epoch,
                })

            # Go thru all minibatches
            for i in range(rollout.buffer.num_minibatches()):
                # sample minibatch idxs from buffer
                idxs = rollout.buffer.get_minibatch_idxs()

                #
                # Step 1: Sample current policy for new_probs and new_value
                #
                state = rollout.buffer.states[idxs]
                input = state.float()
                # If a single row, then unsqueeze to make a batch of 1
                if len(input.shape) == 1:
                    input = input.unsqueeze(0)

                _, new_log_prob, new_entropy, new_value = policy.get_action_and_value(input, rollout.buffer.actions[idxs])

                #
                # Step 2: Calculate L_clip
                #

                # Calculate r(t)
                # r(t) = pi(a, s) / pi_old(a, s) 
                #      = exp(logprob(pi(a,s)) - logprob(pi_old(a, s)))
                logratio = new_log_prob - rollout.buffer.log_probs[idxs]
                r = logratio.exp()
                
                # Find policy loss
                advantage = rollout.buffer.advantages[idxs]
                policy_loss1 = -advantage * r
                policy_loss2 = -advantage * torch.clamp(r, 1 - eps, 1 + eps)
                policy_loss = torch.max(policy_loss1, policy_loss2).mean()
                
                #
                # Step 3: Calculate L_vf
                #
                returns = rollout.buffer.returns[idxs]
                value_loss = 0.5 * ((new_value - returns)**2).mean()
                
                #
                # Step 4: Calculate L_entropy
                #

                # find L_entropy
                c_entropy = 0.01
                entropy_loss = new_entropy.mean()
                c_vf = 0.5
                
                #
                # Step 4: Train
                #
                loss = policy_loss - c_entropy * entropy_loss + c_vf * value_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # If done next step them reset env
                if log:
                    wandb.log({
                        "policy_loss": policy_loss,
                        "value_loss": value_loss,
                        "entropy_loss": entropy_loss,
                        "loss": loss,
                    })

        if record_vids and epoch % num_epochs_to_vid == 0:
            record_video(env, policy, "/content/out.mp4")


record_vids = False
log = False
batch_size = 8
minibatch_size = 4
max_episode_steps = 100
n_epochs = 1
lr = 1e-3
eps = 0.2
train("CartPole-v1", log, lr, batch_size, minibatch_size, max_episode_steps, n_epochs, eps)

In [None]:
log = True
record_vids = True
num_epochs_to_vid = 100
# env_id = "CartPole-v1"
# env_id = "LunarLander-v2"
env_id = "Pixelcopter-PLE-v0"
batch_size = 512
minibatch_size = 128
max_episode_steps = 1000
n_epochs = 500
lr = 2.5e-4
eps = 0.2
train(env_id, log, lr, batch_size, minibatch_size, max_episode_steps, n_epochs, eps)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
entropy_loss,██▇▇▇▇▇▆▆▅▅▅▄▃▂▁
episode_reward,▅▁▅▁▁▁▅▁▁▁▁▅▁▁▁▁▁▅▁▁▁▁▁▁▁▅▁▁▁▁█▅▅▁▁▁▁▁▁█
episode_steps,▅▂▅▃▂▁▅▂▂▄▁▅▂▂▃▃▂▅▂▃▃▃▂▃▂▆▂▃▄▂█▅▅▂▂▂▄▃▂█
epoch,▁▁▁▁
loss,▇█▆▆▇▄▄▃▃▅▂▂▁▂▁▂
num_episodes,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
policy_loss,▄▆▃▃█▃▃▁▃█▁▃▁▅▃▆
value_loss,██▆▆▇▅▄▄▃▄▂▂▁▂▁▁

0,1
entropy_loss,0.69133
episode_reward,-3.0
episode_steps,17.0
epoch,0.0
loss,5.18056
num_episodes,66.0
policy_loss,3.48082
value_loss,3.4133
