## Intro to Proximal Policy Optimization 
The main idea behind Proximal Policy Optimization (PPO) is to improve training stability by limiting how much the policy changes at each epoch.

There are two reasons to avoid large policy changes:
* Smaller policy updates during training are more likely to converge to an optimal solution.
* Too large a policy update can overshoot the optimal solution and land on a worse policy.

PPO addresses this by updating the policy more conservatively. To do that, we need to measure how much the current policy has changed compared to the old one, which we do with the ratio between the two policies. We then clip this ratio to the range $[1-\epsilon,1+\epsilon]$, so that the policy changes enough for our agent to learn, but not so much that it regresses.


## Intro to Clipped Surrogate Objective function 

The objective function that we are trying to optimize in REINFORCE is the following:

$L^{PG}(\theta) = E_t[\log \pi_{\theta}(a_t|s_t) \, \hat{A}_t]$

Taking a gradient ascent step on this function makes actions with a positive advantage more likely and actions with a negative advantage less likely, which pushes our agent toward higher rewards.
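
As a small sketch of what this objective measures, the snippet below estimates it from a batch of samples in plain Python. The helper name `pg_objective`, the probabilities, and the advantages are all made up for illustration; they do not come from the training run in this notebook.

```python
import math

def pg_objective(log_probs, advantages):
    """Sample estimate of the objective: mean of log pi(a_t|s_t) * advantage."""
    total = sum(lp * adv for lp, adv in zip(log_probs, advantages))
    return total / len(log_probs)

# Hypothetical advantages for three sampled (state, action) pairs
advantages = [2.0, -1.0, 0.5]

# Probability the policy assigns to each sampled action, before and after an
# update that favors positive-advantage actions and disfavors the negative one
probs_before = [0.25, 0.50, 0.25]
probs_after = [0.40, 0.30, 0.35]

objective_before = pg_objective([math.log(p) for p in probs_before], advantages)
objective_after = pg_objective([math.log(p) for p in probs_after], advantages)

print(objective_before, objective_after)
```

The second value is larger than the first, which is the direction a gradient ascent step moves in.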

As with any ML problem, the difficulty is choosing the step size:
* If the step size is too small, our agent takes too long to train.
* If the step size is too large, the updates vary too much and training becomes unstable.
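
As a toy illustration of this trade-off, the sketch below runs gradient ascent on a one-dimensional quadratic. This is not a reinforcement learning problem; the function `gradient_ascent` and the three step sizes are chosen only for this example.

```python
def gradient_ascent(step_size, steps=20, x=1.0):
    """Maximize f(x) = -x**2 (optimum at x = 0) with a fixed step size."""
    for _ in range(steps):
        grad = -2.0 * x            # derivative of -x**2
        x = x + step_size * grad   # gradient ascent update
    return x

too_small = gradient_ascent(0.01)  # still far from the optimum after 20 steps
moderate = gradient_ascent(0.4)    # ends very close to the optimum
too_large = gradient_ascent(1.1)   # overshoots further on every step

print(too_small, moderate, too_large)
```

With the largest step size, each update jumps past the optimum and ends up further away than where it started.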

With PPO, we replace this objective with the *clipped surrogate objective function*, which constrains our policy update to a small range.

Here is the objective function for PPO: 

$L^{CLIP}(\theta) = E_t[\min(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\hat{A}_t)]$

Let's dissect the above function: 

### The ratio function 

$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t) }$

The above is simply the probability of taking action $a_t$ at state $s_t$ in the current policy divided by the probability of taking action $a_t$ at state $s_t$ in the old policy. 

* If $r_t(\theta)$ > 1, then action $a_t$ at state $s_t$ is more likely in the current policy than the old policy.
* If $r_t(\theta)$ is between 0 and 1, then the action is less likely in the current policy than the old policy.

As we can see, the ratio is a good way of measuring how much our current policy diverges from the old policy. 
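
In practice the ratio is usually computed from log-probabilities, as the training loop later in this notebook does with `logratio.exp()`. Below is a minimal sketch in plain Python; the helper name `probability_ratio` and the example probabilities are made up for illustration.

```python
import math

def probability_ratio(new_logprob, old_logprob):
    """r_t = pi_new(a|s) / pi_old(a|s), computed as exp(log pi_new - log pi_old)."""
    return math.exp(new_logprob - old_logprob)

# Hypothetical probability of one action under the old policy
old_prob = 0.25

more_likely = probability_ratio(math.log(0.50), math.log(old_prob))  # ratio > 1
less_likely = probability_ratio(math.log(0.10), math.log(old_prob))  # ratio between 0 and 1
unchanged = probability_ratio(math.log(0.25), math.log(old_prob))    # ratio equal to 1

print(more_likely, less_likely, unchanged)
```

Working in log space avoids dividing by very small probabilities.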

### The unclipped part of the Clipped Surrogate Objective function 

We replace the log probability in the objective function with the ratio, which we then multiply by the advantage. We must still constrain the ratio, since without a constraint we may stray too far from the old policy. 

### The clipped part of the Clipped Surrogate Objective function 
We need to constrain the objective function so that it penalizes changes that move the ratio too far from 1 (with $\epsilon = 0.2$, the allowed range is between 0.8 and 1.2). By clipping the ratio, we ensure that we don't make a policy update that is too large, since the current policy can't be too different from the old policy. 

There are two ways to constrain the policy update: 
* *TRPO (Trust Region Policy Optimization)* uses KL divergence constraints outside of the objective function to constrain the policy updates. This method is quite complex and takes additional compute time.
* *PPO* clips the probability ratio directly in the objective function with the **Clipped surrogate objective function**.

The clipped part of the objective uses $r_t(\theta)$ clipped to the range $[1-\epsilon,1+\epsilon]$. 

With the clipped surrogate objective, we have two probability ratios: one unclipped and one clipped to the range $[1-\epsilon,1+\epsilon]$, where $\epsilon$ is a hyperparameter that defines the clip range (typically $\epsilon = 0.2$).
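
As a worked illustration of the clipping step, the clip operation can be written piecewise:

$$
\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) =
\begin{cases}
1-\epsilon & \text{if } r_t(\theta) < 1-\epsilon \\
r_t(\theta) & \text{if } 1-\epsilon \le r_t(\theta) \le 1+\epsilon \\
1+\epsilon & \text{if } r_t(\theta) > 1+\epsilon
\end{cases}
$$

Assuming the typical $\epsilon = 0.2$, a ratio of 1.5 is clipped to 1.2, a ratio of 0.5 is clipped to 0.8, and a ratio of 1.1 is left unchanged.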

We then take the minimum of the clipped and unclipped objectives, so the final objective is a lower bound (a pessimistic estimate) of the unclipped objective. Which of the two is selected depends on the ratio and the sign of the advantage. 
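
To see which term the minimum selects, here is a per-sample sketch in plain Python. The function name `clipped_surrogate` and the four (ratio, advantage) pairs are assumptions made for this example; the training loop later in this notebook computes the same quantity on tensors, negated so that it can be minimized as a loss.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-sample objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    clipped = clipped_ratio * advantage
    return min(unclipped, clipped)

# (ratio, advantage) pairs chosen to cover the four interesting cases
cases = [(1.5, 1.0), (0.5, 1.0), (1.5, -1.0), (0.5, -1.0)]
for ratio, advantage in cases:
    print(ratio, advantage, clipped_surrogate(ratio, advantage))
```

In the first and last cases the clipped term is selected, so moving the ratio further in the same direction no longer improves the objective. In the two middle cases the policy has moved in the wrong direction, the unclipped term is selected, and the update can still correct it.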

## Example: PPO from scratch

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
from torch.distributions import Categorical 
import numpy as np
import gym 
import time
import random
from tqdm.auto import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
env = gym.make("LunarLander-v2",render_mode='rgb_array')
env = gym.wrappers.RecordEpisodeStatistics(env)
observation, info = env.reset(seed=42)
for _ in range(200):
    action = env.action_space.sample()  # this is where you would insert your policy
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        print(f"Episodic Return: {info['episode']['r']}")
        observation, info = env.reset()
env.close()

Episodic Return: -466.5960388183594




In [5]:
def make_env(gym_id):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env
    return thunk
envs = gym.vector.SyncVectorEnv([make_env("LunarLander-v2")])
observation,info= envs.reset()
for _ in range(200):
    action = envs.action_space.sample()
    observation, reward,terminated,truncated,info = envs.step(action)
    if info:
        print(f"episodic return {info['final_info'][0]['episode']['r']}")

episodic return -123.05753326416016


In [11]:
lr = 0.003
#lr= 0.01
num_steps = 1024
num_envs=16
batch_size = num_envs* num_steps
total_timesteps = 2500000
anneal_lr = True
gae = True
gamma = 0.999
gae_lambda = 0.98
num_minibatches = 64
update_epochs = 4
norm_adv = True
clip_coef = 0.2
clip_vloss = True
ent_coef = 0.01
vf_coef = 0.5
max_grad_norm = 0.5
target_kl = 0.015
n_hidden = 128
#n_hidden =64
#target_kl = None
minibatch_size = batch_size // num_minibatches

In [7]:
# Best model has 128 hidden
def layer_init(layer,std=np.sqrt(2),bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight,std)
    torch.nn.init.constant_(layer.bias,bias_const)
    return layer
class Agent(nn.Module):
    def __init__(self,envs):
        super(Agent,self).__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), n_hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(n_hidden,n_hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(n_hidden,1),std=1.)
        )
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(),n_hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(n_hidden,n_hidden)),
            nn.Tanh(),
            layer_init(nn.Linear(n_hidden,envs.single_action_space.n),std=0.01)
        )
    def get_value(self,x):
        return self.critic(x)
    def get_action_and_value(self,x,action=None):
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None: 
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

In [15]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
envs = gym.vector.SyncVectorEnv(
    [make_env("LunarLander-v2") for i in range(num_envs)]
)
assert isinstance(envs.single_action_space,gym.spaces.Discrete), "only discrete action space is supported"
#print("envs.single_observation_space.shape", envs.single_observation_space.shape)
#print("envs.single_action_space.shape.n", envs.single_action_space.n)
agent = Agent(envs).to(device)
#print(agent)
optimizer = optim.Adam(agent.parameters(),lr=lr,eps=1e-5)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer,gamma =0.9)

obs = torch.zeros((num_steps,num_envs) + envs.single_observation_space.shape).to(device)
actions  = torch.zeros((num_steps,num_envs)+envs.single_action_space.shape).to(device)
logprobs = torch.zeros((num_steps,num_envs)).to(device)
rewards = torch.zeros((num_steps,num_envs)).to(device)
dones = torch.zeros((num_steps,num_envs)).to(device)
values = torch.zeros((num_steps,num_envs)).to(device)

global_step = 0
start_time = time.time()
next_obs = torch.Tensor(envs.reset()[0]).to(device)
next_done  = torch.zeros(num_envs).to(device)
num_updates = total_timesteps // batch_size

for update in tqdm(range(1,num_updates+1)):
    if anneal_lr:
        frac = 1.0- (update-1.0) /num_updates
        lrnow = frac * lr
        optimizer.param_groups[0]['lr'] = lrnow
        
    for step in range(0, num_steps):
        global_step +=1* num_envs
        obs[step] = next_obs
        dones[step] = next_done

        with torch.no_grad():
            action, logprob, _, value = agent.get_action_and_value(next_obs)
            values[step] = value.flatten()
        actions[step] = action
        logprobs[step] = logprob

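        next_obs,reward,terminated,truncated,info = envs.step(action.cpu().numpy())
        # an episode ends on termination or on truncation (e.g. a time limit)
        done = np.logical_or(terminated,truncated)
        rewards[step] = torch.tensor(reward).to(device).view(-1)
        next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)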
        if info:
            for item in info['final_info']:
                if item is not None:
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    break
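    with torch.no_grad():
        # bootstrap from the value of the observation that follows the rollout
        next_value = agent.get_value(next_obs).reshape(1,-1)
        if gae:
            # Generalized Advantage Estimation, computed backwards through the rollout
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(num_steps)):
                if t == num_steps -1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0- dones[t+1]
                    nextvalues = values[t+1]
                # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), with V(s_{t+1}) masked at episode ends
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                # running sum of TD errors, decayed by gamma * gae_lambda
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values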
        else: 
            returns = torch.zeros_like(rewards).to(device)
            for t in reversed(range(num_steps)):
                if t == num_steps -1:
                    nextnonterminal = 1.0 - next_done
                    next_return = next_value
                else: 
                    nextnonterminal = 1.0 - dones[t+1]
                    next_return = returns[t+1]
                returns[t] = rewards[t] + gamma* nextnonterminal * next_return
            advantages = returns-values
    #Flatten batch
    b_obs = obs.reshape((-1,) +envs.single_observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,)+ envs.single_action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)
    # optimizing the policy and value net 
    b_inds = np.arange(batch_size)
    clipfracs = []
    for epoch in range(update_epochs):
        np.random.shuffle(b_inds)
        for start in range(0,batch_size, minibatch_size):
            end = start + minibatch_size
            mb_inds = b_inds[start:end]

            _, newlogprob, entropy, newvalue= agent.get_action_and_value(
                b_obs[mb_inds],b_actions.long()[mb_inds]
            )
            logratio = newlogprob -b_logprobs[mb_inds]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl 
                approx_kl = ((ratio -1) - logratio).mean()
                clipfracs+= [((ratio -1.0).abs() > clip_coef).float().mean()]
            mb_advantages = b_advantages[mb_inds]
            if norm_adv: 
                mb_advantages = (mb_advantages- mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
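            # Policy loss: the clipped surrogate objective, negated because we minimize,
            # so the min over the two terms becomes a max over their negations
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio,1-clip_coef,1+clip_coef)
            pg_loss = torch.max(pg_loss1,pg_loss2).mean()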
            #value loss
            newvalue = newvalue.view(-1)
            if clip_vloss: 
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) **2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue-b_values[mb_inds],
                    -clip_coef,
                    clip_coef
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds])**2
                v_loss_max = torch.max(v_loss_unclipped,v_loss_clipped)
                v_loss = 0.5*v_loss_max.mean()
            else:
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
            entropy_loss = entropy.mean()
            loss = pg_loss-ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(agent.parameters(),max_grad_norm)
            optimizer.step()
        if target_kl is not None: 
            if approx_kl > target_kl:
                break
        scheduler.step()
    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y  = np.var(y_true)
    explained_var = np.nan if var_y ==0 else 1- np.var(y_true - y_pred) / var_y

global_step=1040, episodic_return=-44.8292350769043
global_step=1056, episodic_return=-109.95637512207031
global_step=1088, episodic_return=-92.43936920166016
global_step=1120, episodic_return=-98.76458740234375
global_step=1216, episodic_return=-71.72322082519531
global_step=1248, episodic_return=-114.35911560058594
global_step=1280, episodic_return=-171.85198974609375
global_step=1440, episodic_return=-155.98121643066406
global_step=1536, episodic_return=-222.99221801757812
global_step=1552, episodic_return=-322.8173828125
global_step=1680, episodic_return=-0.8139724731445312
global_step=1696, episodic_return=-327.44842529296875
global_step=1744, episodic_return=-135.767333984375
global_step=1840, episodic_return=-462.30316162109375
global_step=1856, episodic_return=-55.040279388427734
global_step=2208, episodic_return=-68.54396057128906
global_step=2320, episodic_return=-107.22386932373047
global_step=2416, episodic_return=-214.29136657714844
global_step=2480, episodic_return=-106.9

## Saving model

In [12]:
from pathlib import Path 
model_path = Path()
model_dir = model_path / "LunarLander.pth"
if model_dir.exists():
    print("Best model already saved")
else:
    print(f"Saving model to: {model_dir}")
    torch.save(obj = agent.state_dict(),
              f = model_dir)

Best model already saved


## Loading model

In [13]:
envs = gym.vector.SyncVectorEnv(
    [make_env("LunarLander-v2") for i in range(num_envs)]
)
assert isinstance(envs.single_action_space,gym.spaces.Discrete), "only discrete action space is supported"
#print("envs.single_observation_space.shape", envs.single_observation_space.shape)
#print("envs.single_action_space.shape.n", envs.single_action_space.n)
agent = Agent(envs).to(device)
agent.load_state_dict(torch.load(f=model_dir, map_location=device))

<All keys matched successfully>

In [19]:
env = gym.make("LunarLander-v2",render_mode='human')
env = gym.wrappers.RecordEpisodeStatistics(env)
observation, info = env.reset()
done = False
while not done:
    with torch.no_grad():
        action, _, _, _ = agent.get_action_and_value(torch.tensor(observation).to(device))
    observation, reward, terminated, truncated, info = env.step(action.cpu().item())
    # stop on either termination or truncation
    done = terminated or truncated
    if done:
        print(f"Episodic Return: {info['episode']['r']}")
env.close()

Episodic Return: 281.0975036621094


In [17]:
# Evaluate Agent
from tqdm.auto import tqdm
env = gym.make("LunarLander-v2",render_mode='human')
env = gym.wrappers.RecordEpisodeStatistics(env)
scores = []
for i in tqdm(range(10)):
    observation, info = env.reset()
    done = False
    while not done:
        with torch.no_grad():
            action, _, _, _ = agent.get_action_and_value(torch.tensor(observation).to(device))
        observation, reward, terminated, truncated, info = env.step(action.cpu().item())
        # stop on either termination or truncation
        done = terminated or truncated
        if done:
            scores.append(float(info['episode']['r']))
env.close()
print(np.mean(scores))

257.85936126708987


## Playing Doom 