# Proximal Policy Optimization - PPO
PPO is a policy gradient method for reinforcement learning. Simple policy gradient methods do a single gradient update per sample (or batch of samples). Taking multiple gradient steps on the same samples causes problems because the policy drifts too far from the one that collected the data, producing a bad policy. PPO allows multiple gradient updates per sample by keeping the updated policy close to the policy that was used to sample the data. It does so by clipping the probability ratio between the updated and sampling policies, which cuts off the gradient once the updated policy moves too far from the policy used to sample the data.

In [2]:
import torch
import torch.nn as nn


In [3]:
class PPOLoss(nn.Module):
    """PPO clipped surrogate objective, negated so it can be minimised.

    After every forward pass, ``self.clip_fraction`` holds the fraction of
    elements whose probability ratio lies outside ``[1 - clip, 1 + clip]``.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        log_pi: torch.Tensor,
        sampled_log_pi: torch.Tensor,
        advantage: torch.Tensor,
        clip: float,
    ):
        """Return the negated mean clipped surrogate objective.

        Args:
            log_pi: log-probabilities of the taken actions under the current
                policy.
            sampled_log_pi: log-probabilities of the same actions under the
                policy that collected the data.
            advantage: advantage estimates; broadcast against the ratio.
            clip: half-width of the trust interval around a ratio of 1.

        Returns:
            Scalar tensor ``-mean(min(r * A, clamp(r, 1-clip, 1+clip) * A))``.
        """
        # Probability ratio pi / pi_old, formed in log space for stability.
        prob_ratio = (log_pi - sampled_log_pi).exp()

        unclipped_term = prob_ratio * advantage
        clipped_term = prob_ratio.clamp(1 - clip, 1 + clip) * advantage
        # Pessimistic bound: the smaller term wins, so moving the ratio past
        # the clip boundary in the "helpful" direction earns no extra reward.
        surrogate = torch.minimum(unclipped_term, clipped_term)

        # Diagnostic only: share of ratios that landed outside the interval.
        self.clip_fraction = (prob_ratio - 1.0).abs().gt(clip).float().mean()

        return -surrogate.mean()


In [5]:
class ClippedValueFunctionLoss(nn.Module):
    """Clipped squared-error value loss used alongside the PPO policy loss."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        value: torch.Tensor,
        sampled_value: torch.Tensor,
        sampled_return: torch.Tensor,
        clip: float,
    ):
        """Return half the mean of the elementwise worse squared error.

        Args:
            value: value predictions from the current network.
            sampled_value: value predictions recorded when the data was
                sampled.
            sampled_return: return targets for the sampled data.
            clip: maximum allowed change of ``value`` from ``sampled_value``
                in the clipped branch.

        Returns:
            Scalar tensor ``0.5 * mean(max((V - R)^2, (V_clip - R)^2))``.
        """
        # Restrict how far the new prediction may move from the old one.
        step = torch.clamp(value - sampled_value, min=-clip, max=clip)
        clipped_value = sampled_value + step

        unclipped_error = (value - sampled_return).pow(2)
        clipped_error = (clipped_value - sampled_return).pow(2)
        # Take the larger (pessimistic) error of the two per element.
        worst_error = torch.maximum(unclipped_error, clipped_error)

        return worst_error.mean() * 0.5
