# Proximal Policy Optimization (PPO) Clipped Objective in Pytorch

I am no expert on the Natural Policy or TROP. But the idea of Natural Policy and TROP are to control the step size. Natural policy try to find a consistent step size using qudratic approximation, and then TRPO add an extra layer to make sure the parameter update does not change the policy beyond a KL divergence threshold.

And the PPO is a simplified implementation that can achieve similiar performance as TRPO without the involved calculation of the Hessian matrix approximation or KL divergence

In [40]:
%matplotlib inline
from IPython import display
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Normal, Categorical

import numpy as np
import random
import os
import gym

In [41]:
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor

In [42]:
def save_torch_model(model, filename):
  if not os.path.exists(os.path.dirname(filename)):
    os.makedirs(os.path.dirname(filename))
  torch.save(model.state_dict(), filename)

def load_torch_model(model, filename):
  model.load_state_dict(torch.load(filename))


In [43]:
class PolicyNet_discrete(nn.Module):
  def __init__(self, input_size, output_size):
    super(PolicyNet_continuous,self).__init__()
    self.l1_linear = nn.Linear(input_size, 64)
    self.l2_linear = nn.Linear(64, 32)
    self.l3_linear = nn.Linear(32, output_size)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()
    
  def forward(self,x):
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    out = F.softmax(F.sigmoid(self.l3_linear(out)))
    return out

In [44]:
class PolicyNet_continuous(nn.Module):
  def __init__(self, input_size, output_size):
    super(PolicyNet_continuous,self).__init__()
    self.l1_linear = nn.Linear(input_size, 64)
    self.l2_linear = nn.Linear(64, 32)
    self.l3_linear = nn.Linear(32, output_size)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()
    
  def forward(self,x):
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    out = F.tanh(self.l3_linear(out))
    return out

In [45]:
class ValueNet(nn.Module):
  def __init__(self, input_size):
    super(ValueNet,self).__init__()
    self.l1_linear = nn.Linear(input_size, 64)
    self.l2_linear = nn.Linear(64, 32)
    self.l3_linear = nn.Linear(32, 1)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()

  def forward(self, x):
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    out = self.l3_linear(out)
    return out

In [150]:
class PPO():
  def __init__(self, env, steps_in_state = 1):
    self.is_training = True
    self.state_value_range = [{'max':None, 'min':None}] * env.observation_space.shape[0]
    self.steps_in_state = steps_in_state
    self.actor = PolicyNet_continuous(env.observation_space.shape[0] * steps_in_state, env.action_space.shape[0])
    self.actor_prime = PolicyNet_continuous(env.observation_space.shape[0] * steps_in_state, env.action_space.shape[0])
    self.critic = ValueNet(env.observation_space.shape[0] * steps_in_state)
    if use_cuda:
      self.actor.cuda()
      self.actor_prime.cuda()
      self.critic.cuda()
    # copy the weight in actor, make sure actor and actor_prime start with same weights
    self.actor_prime.load_state_dict(self.actor.state_dict())
    self.env = env
    self.range_scale = (env.action_space.high[0] - env.action_space.low[0]) / 2.0
    self._gamma = 0.96
    self._epsilon = 0.2

  def pick_action(self, state):
    action = self.actor(state)
    # add noise
    if self.is_training:
      action_dist = Normal(action, 0.5)
      action = action_dist.sample()
    else:
      action = action.item()
    return np.clip(action, -1.0, 1.0)
  
  def update_ppo_clip(self, batch):
    (states, actions, rewards, next_states, ended) = zip(*batch)
    states_tensor = torch.stack(states)
    actions_tensor = FloatTensor(actions).view(-1,1)
    rewards_tensor = FloatTensor(rewards).view(-1,1)
    next_states_tensor = torch.stack(next_states)
    ended_tensor = FloatTensor(ended).view(-1,1)
    
    critic_loss = rewards_tensor + self._gamma * (1 - ended_tensor) * self.critic(next_states_tensor) - self.critic(states_tensor)
    advantage = []
    for delta in critic_loss.view(-1).tolist()[::-1]:
      adv = delta
      if len(advantage) > 0:
        adv += self._gamma * advantage[0]
      advantage.insert(0, adv)
    advantage_tensor = FloatTensor(advantage).view(-1,1)
                             
    # hardcode standard deviation for the action probability to 0.5
    action_dist_prime = Normal(self.actor_prime(states_tensor), FloatTensor([0.5]*len(batch)))
    action_dist_old = Normal(self.actor(states_tensor), FloatTensor([0.5]*len(batch)))
    
    action_prob_ratio = torch.exp(action_dist_prime.log_prob(actions_tensor) - action_dist_old.log_prob(actions_tensor))
    actor_loss = torch.min(action_prob_ratio * advantage_tensor, \
                           torch.clamp(action_prob_ratio, 1 - self._epsilon, 1 + self._epsilon) * advantage_tensor)
    
    actor_loss = -actor_loss.mean()
    self.actor_prime_optimizer.zero_grad()
    actor_loss.backward(retain_graph=True)
    self.actor_prime_optimizer.step()

    # mean square of critic_loss
    critic_loss = critic_loss * critic_loss
    critic_loss = critic_loss.mean()
    self.critic_coptimizer.zero_grad()
    critic_loss.backward(retain_graph=True)
    self.critic_coptimizer.step()
    
  def train(self, env, episode_limit=1000, batch_size=64, copy_on_episode=10, lr=1e-3, lr_actor=None, lr_critic=None, checkpoint=100):
    lr_actor = lr if lr_actor == None else lr_actor
    lr_critic = lr if lr_critic == None else lr_critic
    
    self.actor_prime_optimizer = torch.optim.Adam(self.actor_prime.parameters(), lr=lr_actor, weight_decay=1e-3)
    self.critic_coptimizer = torch.optim.Adam(self.critic.parameters(), lr=lr_critic, weight_decay=1e-3)
    
    best_score = -99999
    running_score = None
    self.iteration = 0
    for episode_count in range(episode_limit):
      s0 = env.reset()
      seq = [s0] * self.steps_in_state
      state = FloatTensor(seq).view(-1)
      episode_ended = False
      score = 0
      episode = []
      while not episode_ended:
        action =  self.pick_action(state)
        (s1, reward, episode_ended, info) = env.step([action * self.range_scale])
        seq = seq[1:]
        seq.append(s1)
        next_state = FloatTensor(seq).view(-1)
        if episode_ended:
          ended = 1
        else:
          ended = 0
        episode.append((state, action, reward, next_state, ended))
        s0 = s1
        state = next_state
        score += reward
        if len(episode) == batch_size or episode_ended:
          if (episode_ended):
            episode.pop()
          if len(episode) > 0:
            self.update_ppo_clip(episode)
          episode = []

      if running_score == None:
        running_score = score
      else:
        running_score = running_score * 0.9 + score * 0.1
      
      if (episode_count + 1) % checkpoint == 0 and running_score != None:
        if running_score > best_score:
          best_score = running_score
          save_torch_model(self.actor,'model/ppo_actor_best.pth')          
        save_torch_model(self.actor,'model/ppo_actor_iter_%d.pth' %(episode_count+1))
        print('%d: running_score:%.2f, ' %(episode_count+1, running_score))
        
      if  (episode_count + 1) % copy_on_episode == 0:
        self.actor.load_state_dict(self.actor_prime.state_dict())
        


In [151]:
env = gym.make('Pendulum-v0')
agent = PPO(env)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [54]:
agent.is_training = True
agent.train(env, episode_limit=1, batch_size=1, copy_on_episode=1, lr_actor=1e-4, lr_critic=1e-3, checkpoint=1)

1: running_score:-883.88, 


In [152]:
agent.is_training = True
agent.train(env, episode_limit=3000, batch_size=64, copy_on_episode=5, lr_actor=1e-4, lr_critic=1e-3, checkpoint=200)

200: running_score:-1215.53, 
400: running_score:-1080.34, 
600: running_score:-1174.50, 
800: running_score:-1147.52, 
1000: running_score:-1030.34, 
1200: running_score:-914.68, 
1400: running_score:-862.33, 
1600: running_score:-870.42, 
1800: running_score:-815.15, 
2000: running_score:-729.43, 
2200: running_score:-659.83, 
2400: running_score:-626.04, 
2600: running_score:-595.97, 
2800: running_score:-443.90, 
3000: running_score:-295.77, 


In [147]:
# agent.is_training = True
# agent.train(env, episode_limit=500, batch_size=64, copy_on_episode=5, lr_actor=1e-4, lr_critic=1e-4, checkpoint=50)

In [153]:
# use the final mean as solution and run a sample episode
agent.is_training = False
load_torch_model(agent.actor,'model/ppo_actor_best.pth')
state = env.reset()
frames = []
frames.append(env.render(mode='rgb_array'))
ended = False
score = 0
while not ended:
  action = agent.pick_action(FloatTensor([state]).view(-1))
  (state, reward, ended, info) = env.step([action*2])
  score += reward
  frames.append(env.render(mode='rgb_array'))
print(score)

-132.5008541134088


In [154]:
%%capture
def animate(frames):
  fig, ax = plt.subplots()
  ax.grid('off')
  ax.axis('off')
  ims = []
  for i in range(len(frames)):
      im = plt.imshow(frames[i], animated=True)
      ims.append([im])
  ani = animation.ArtistAnimation(fig, ims, interval=20, blit=True, repeat_delay=1000)
  return ani

ani = animate(frames)
ani.save('pendulum_ppo.mp4')

In [155]:
%%HTML
<video width="400" controls>
  <source src="pendulum_ppo.mp4" type="video/mp4">
</video>