# **Proximal Policy Optimization**

In [policy_gradient](https://github.com/kueiwen/reinforcement-learning/blob/main/policy_gradient.ipynb), we have discussed that Proximal Policy Optimization (PPO) is an algorithm based on policy gradient.

PPO improves on TRPO (Trust Region Policy Optimization). The main difference between the two is how they constrain the policy update.

In trust_region_policy_gradient, the surrogate objective function is maximized subject to a constraint on the size of the policy update.

$$\max_{\theta}\hat{E}_{t}[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\hat{A}_{t}(s_t,a_t)]$$

$$\text{subject to  }\hat{E}_{t}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(.|s)||\pi_{\theta}(.|s))]\leq\delta$$

The theory justifying TRPO suggests using a penalty instead of constraint,

$$\max_{\theta}\hat{E}_{t}[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\hat{A}_{t}(s_t,a_t)-\beta D_{\text{KL}}(\pi_{\theta_{\text{old}}}(.|s)||\pi_{\theta}(.|s))]$$

$\beta$ is the penalty coefficient. However, it is hard to choose a single value of $\beta$ that works well across different reinforcement learning problems, so a fixed penalty alone does not give a first-order algorithm that emulates the monotonic improvement of TRPO.
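Both the constrained and the penalized form depend on the KL divergence between the old and the new policy. As a rough illustration for a single state with a discrete action set, the sketch below computes $D_{\text{KL}}(\pi_{\theta_{\text{old}}}||\pi_{\theta})$ and applies it as a constraint check and as a penalty. The function names, the two example distributions, the surrogate value, and the values of `delta` and `beta` are all assumptions made for this example.

```python
import math
from typing import Sequence

def kl_divergence(p_old: Sequence[float], p_new: Sequence[float]) -> float:
    """D_KL(p_old || p_new) for two categorical distributions over the same actions."""
    return sum(p * math.log(p / q) for p, q in zip(p_old, p_new) if p > 0.0)

def penalized_objective(surrogate: float, kl: float, beta: float) -> float:
    """Surrogate objective minus the KL penalty weighted by beta."""
    return surrogate - beta * kl

# Hypothetical action probabilities at one state under the old and the new policy.
pi_old = [0.5, 0.3, 0.2]
pi_new = [0.6, 0.25, 0.15]

kl = kl_divergence(pi_old, pi_new)

delta = 0.01          # assumed trust-region size for the constrained form
print("constraint satisfied:", kl <= delta)

beta = 1.0            # assumed penalty coefficient for the penalized form
surrogate = 0.05      # assumed value of the surrogate objective
print("penalized objective:", penalized_objective(surrogate, kl, beta))
```

A larger `beta` makes the same KL divergence cost more, which is why one fixed value rarely suits every problem.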

## **Clipped surrogate objective**

Let $r_t(\theta)$ denote the probability ratio $r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}$, so that $r_t(\theta_{\text{old}})=1$.

Then the TRPO surrogate objective function can be expressed as

$$L^{\text{CPI}}(\theta)=\hat{E}_{t}[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\hat{A}_{t}(s_t,a_t)]=\hat{E}_{t}[r_t(\theta)\hat{A}_t]$$

*CPI* refers to conservative policy iteration.

If $r_t(\theta)>1$, the action $a_t$ at state $s_t$ is more likely in the current policy than the old policy.     
If $r_t(\theta)$ is between $0$ and $1$, the action $a_t$ at state $s_t$ is less likely in the current policy than the old policy.
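In code, the ratio is usually computed from log-probabilities, because policies typically return $\log\pi(a_t|s_t)$ and subtracting logs is numerically safer than dividing small probabilities. A minimal sketch follows; the function name `probability_ratio` and the sample probabilities are assumptions for illustration.

```python
import math

def probability_ratio(log_prob_new: float, log_prob_old: float) -> float:
    """r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t), computed in log space."""
    return math.exp(log_prob_new - log_prob_old)

# Hypothetical probabilities of the same action under the current and the old policy.
r_up = probability_ratio(math.log(0.6), math.log(0.5))    # action became more likely
r_down = probability_ratio(math.log(0.4), math.log(0.5))  # action became less likely

print(r_up > 1.0)           # True
print(0.0 < r_down < 1.0)   # True
```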

Without a constraint, maximizing $L^{\text{CPI}}$ would lead to an excessively large policy update, so PPO modifies the objective to penalize policy changes that move $r_t(\theta)$ away from 1.

$$L^{\text{CLIP}}(\theta)=\hat{E}_t[\min(r_t(\theta)\hat{A}_t,\text{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\hat{A}_t)]$$

$\epsilon$ is a hyperparameter, usually set to be 0.2.

The first term inside the $\min$ is the unclipped objective $r_t(\theta)\hat{A}_t$ from $L^{\text{CPI}}$. The second term, $\text{clip}(r_t(\theta),1-\epsilon,1+\epsilon)\hat{A}_t$, modifies the surrogate objective by clipping the probability ratio to $[1-\epsilon,1+\epsilon]$, which removes the incentive for moving $r_t(\theta)$ outside that interval. Finally, taking the minimum of the clipped and unclipped terms makes $L^{\text{CLIP}}$ a lower bound on the unclipped objective.

Note that the clip function bounds the probability ratio $r_t(\theta)$ to $[1-\epsilon,1+\epsilon]$ whether $\hat{A}_t$ is positive or negative; the sign of $\hat{A}_t$ only determines which of the two terms the $\min$ selects.
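The definition of $L^{\text{CLIP}}$ can be sketched in a few lines of plain Python. The helper names `clip`, `clipped_surrogate`, and `l_clip` are assumptions for this example, and the empirical expectation $\hat{E}_t$ is taken as a simple mean over the given samples.

```python
from typing import Sequence

def clip(x: float, low: float, high: float) -> float:
    """Restrict x to the interval [low, high]."""
    return max(low, min(x, high))

def clipped_surrogate(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """Per-sample term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return min(unclipped, clipped)

def l_clip(ratios: Sequence[float], advantages: Sequence[float], epsilon: float = 0.2) -> float:
    """Mean of the per-sample clipped terms, standing in for the expectation over t."""
    terms = [clipped_surrogate(r, a, epsilon) for r, a in zip(ratios, advantages)]
    return sum(terms) / len(terms)

# A large ratio with a positive advantage is capped at (1 + eps) * A,
# while the same ratio with a negative advantage keeps its full (worse) value.
print(clipped_surrogate(1.5, 1.0))
print(clipped_surrogate(1.5, -1.0))
```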

<img src="img/ppo_clip.png" width="600"> 

|  $r_t(\theta)$   | $A_t$  | Return of $\min$ | Clipped | Objective | Gradient | 
|  ----  | ----  | ----  | ----  | ----  | ----  |
| $r_t(\theta)\in[1-\epsilon,1+\epsilon]$  | + | $r_t(\theta)A_t$ | No | + | Yes | 
| $r_t(\theta)\in[1-\epsilon,1+\epsilon]$  | - | $r_t(\theta)A_t$ | No | - | Yes | 
| $r_t(\theta)<1-\epsilon$  | + | $r_t(\theta)A_t$ | No | + | Yes | 
| $r_t(\theta)<1-\epsilon$  | - | $(1-\epsilon)A_t$ | Yes | - | 0 | 
| $r_t(\theta)>1+\epsilon$  | + | $(1+\epsilon)A_t$ | Yes | + | 0 | 
| $r_t(\theta)>1+\epsilon$  | - | $r_t(\theta)A_t$ | No | - | Yes | 

If the advantage $A_t$ is positive, the action taken at that state is better than average, so the update is encouraged to increase the probability of that action, but the objective stops rewarding further increases once $r_t(\theta)$ exceeds $1+\epsilon$. In contrast, if the advantage $A_t$ is negative, the action is worse than average, so the update is encouraged to decrease its probability, and the objective stops rewarding further decreases once $r_t(\theta)$ falls below $1-\epsilon$.
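The "Gradient" column of the table can be checked numerically. The sketch below estimates the derivative of the per-sample objective with respect to $r_t(\theta)$ by central finite differences for one representative ratio per row. The function names, the step size `h`, and the sample ratios 0.5, 1.0, and 1.5 are assumptions chosen for illustration with $\epsilon=0.2$.

```python
def objective(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """Per-sample PPO-Clip term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)

def d_objective_d_ratio(ratio: float, advantage: float,
                        epsilon: float = 0.2, h: float = 1e-6) -> float:
    """Central finite-difference estimate of the derivative with respect to the ratio."""
    upper = objective(ratio + h, advantage, epsilon)
    lower = objective(ratio - h, advantage, epsilon)
    return (upper - lower) / (2.0 * h)

# One (ratio, advantage) pair per table row, in the same order as the table.
cases = [(1.0, 1.0), (1.0, -1.0), (0.5, 1.0), (0.5, -1.0), (1.5, 1.0), (1.5, -1.0)]
for ratio, advantage in cases:
    grad = d_objective_d_ratio(ratio, advantage)
    print(f"r={ratio:.1f}  A={advantage:+.1f}  "
          f"objective={objective(ratio, advantage):+.2f}  d/dr={grad:+.2f}")
```

The derivative vanishes only in the two clipped rows, so those samples do not push the ratio any further in the direction it has already moved.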

## **Adaptive KL penalty coefficient**

* Using several epochs of minibatch SGD, optimize the KL-penalized objective

$$L^{\text{KLPEN}}(\theta)=\hat{E}_t[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\hat{A}_{t}(s_t,a_t)-\beta D_{\text{KL}}(\pi_{\theta_{\text{old}}}(.|s)||\pi_{\theta}(.|s))]$$

* Compute $d=\hat{E}_t[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(.|s)||\pi_{\theta}(.|s))]$

    * If $d<d_{\text{targ}}/1.5$, $\beta\leftarrow\beta /2$
    * If $d>d_{\text{targ}}\times 1.5$, $\beta\leftarrow\beta\times 2$


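The updated $\beta$ is used for the next policy update. With this scheme, policy updates whose KL divergence is significantly different from $d_{\text{targ}}$ can occasionally occur, but $\beta$ quickly adjusts. The parameters $1.5$ and $2$ are chosen heuristically, but the algorithm is not very sensitive to them.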
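The adjustment rule for $\beta$ is small enough to write out directly. In this sketch the function name `update_kl_coefficient` and its keyword names are assumptions; the thresholds 1.5 and 2 are the heuristic values quoted above.

```python
def update_kl_coefficient(beta: float, kl: float, kl_target: float,
                          tolerance: float = 1.5, factor: float = 2.0) -> float:
    """Return the penalty coefficient to use for the next policy update."""
    if kl < kl_target / tolerance:
        return beta / factor      # update was too cautious: weaken the penalty
    if kl > kl_target * tolerance:
        return beta * factor      # update moved too far: strengthen the penalty
    return beta                   # KL is close enough to the target: keep beta

# Hypothetical measured KL values against a target of 0.01.
print(update_kl_coefficient(1.0, kl=0.001, kl_target=0.01))  # 0.5
print(update_kl_coefficient(1.0, kl=0.05, kl_target=0.01))   # 2.0
print(update_kl_coefficient(1.0, kl=0.01, kl_target=0.01))   # 1.0
```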

## **Algorithm**

***PPO-Clip***

---
**Input**: initial policy parameter $\theta_0$, initial value function parameter $\phi_0$

**repeat for k in 0,1,...,L**  
    $\quad$ Collect a set of trajectories $D_k=\{\tau_i\}$ by running policy $\pi_k=\pi(\theta_k)$ in the environment       
    $\quad$ Compute rewards-to-go $\hat{R}_t$        
    $\quad$ Compute advantage estimates, $\hat{A}_t$ (using any method of advantage estimation) based on the current value function $V_{\phi_k}$   
    $\quad$ Update the policy by maximizing the PPO-Clip objective, typically via stochastic gradient ascent with Adam    
    $\quad\quad\theta_{k+1}=\arg\max_{\theta}\frac{1}{|D_k|T}\sum_{\tau\in D_k}\sum_{t=0}^T\min(\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_k}(a_t|s_t)}\hat{A}^{\pi_{\theta_k}}(s_t,a_t),g(\epsilon,\hat{A}^{\pi_{\theta_k}}(s_t,a_t)))$      
    $\quad\quad$ where $g(\epsilon,A)=(1+\epsilon)A$ if $A\geq 0$ and $g(\epsilon,A)=(1-\epsilon)A$ otherwise      
    $\quad$ Fit the value function by regression on mean-squared error, typically via some gradient descent algorithm    
    $\quad\quad\phi_{k+1}=\arg\min_{\phi}\frac{1}{|D_k|T}\sum_{\tau\in D_k}\sum_{t=0}^T(V_{\phi}(s_t)-\hat{R}_t)^2$   
**end** 

---
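The two "compute" steps of the loop can be sketched for a single finished trajectory in plain Python. Rewards-to-go follow directly from the definition. For the advantage step, the algorithm allows any estimator, so the use of Generalized Advantage Estimation (GAE) here is one possible choice. The function names, the default `gamma` and `lam` values, and the assumption that the value after the final step is zero are illustrative.

```python
from typing import List, Sequence

def rewards_to_go(rewards: Sequence[float], gamma: float = 0.99) -> List[float]:
    """Discounted sum of future rewards for every step of one finished trajectory."""
    result = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        result[t] = running
    return result

def gae_advantages(rewards: Sequence[float], values: Sequence[float],
                   gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    """GAE for one finished trajectory; the value after the last step is taken as 0."""
    advantages = [0.0] * len(rewards)
    next_value = 0.0
    next_advantage = 0.0
    for t in reversed(range(len(rewards))):
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value - values[t]
        # Recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        advantages[t] = delta + gamma * lam * next_advantage
        next_value = values[t]
        next_advantage = advantages[t]
    return advantages

print(rewards_to_go([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```

With `gamma=1` and `lam=1` the GAE estimate reduces to the reward-to-go minus the value estimate, which gives a quick sanity check.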

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.optimize # For L-BFGS
import numpy as np
from typing import Tuple, Callable, Dict, Any

# Precompute constant
LOG_2_PI = np.log(2 * np.pi)

class Policy(nn.Module):
    """
    A Gaussian policy network for continuous action spaces.

    Args:
        input_dim: Dimension of the state space.
        hidden_dim: Dimension of the hidden layers.
        output_dim: Dimension of the action space.
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(Policy, self).__init__()
        self.inputLayer = nn.Linear(input_dim, hidden_dim)
        self.hiddenLayer = nn.Linear(hidden_dim, hidden_dim)
        self.outputLayer = nn.Linear(hidden_dim, output_dim)

        self.outputLayer.weight.data.uniform_(-0.003, 0.003)
        self.outputLayer.bias.data.uniform_(-0.003, 0.003)

        # Learnable log standard deviation by nn.Parameter
        self.log_std = nn.Parameter(torch.zeros(1, output_dim))
        # Clamping log_std can improve stability
        self.log_std_min = -20
        self.log_std_max = 2


    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass to get action distribution parameters.

        Args:
            x: Input state tensor.

        Returns:
            A tuple containing:
            - action_mean: Mean of the action distribution.
            - action_log_std: Log standard deviation of the action distribution.
            - action_std: Standard deviation of the action distribution.
        """
        x = torch.tanh(self.inputLayer(x))
        x = torch.tanh(self.hiddenLayer(x))
        action_mean = self.outputLayer(x)

        # Clamp log_std for stability
        self.log_std.data.clamp_(self.log_std_min, self.log_std_max)

        action_log_std = self.log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)
        return action_mean, action_log_std, action_std
    
    def get_log_probability_density(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log probability density of actions under the policy.

        Args:
            states: State tensor.
            actions: Action tensor.

        Returns:
            Per-dimension log probability density, shape [N, action_dim].
            Sum over the last dimension to get the joint log probability.
        """
        action_mean, action_log_std, action_std = self.forward(states)
        log_prob_per_dim = -0.5 * (((actions - action_mean) / action_std)**2) \
                           - action_log_std \
                           - 0.5 * LOG_2_PI
        return log_prob_per_dim
    
    def get_KL_divergence(self, states: torch.Tensor, actions: torch.Tensor, old_log_prob: torch.Tensor) -> torch.Tensor:
        """
        Estimate the KL divergence D_KL(old_policy || current_policy) using samples.
        Assumes 'old_log_prob' contains log probabilities from the sampling policy.

        Args:
            states: State tensor.
            actions: Action tensor (sampled from the old policy).
            old_log_prob: Log probability of the actions under the old policy.

        Returns:
            Mean KL divergence estimate.
        """
        # Sum over action dimensions so the shape matches old_log_prob ([N, 1])
        current_log_prob = self.get_log_probability_density(states, actions).sum(dim=1, keepdim=True)
        kl_div = old_log_prob - current_log_prob
        return kl_div.mean()
    
    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """
        Sample or get the mean action from the policy.

        Args:
            state: Input state tensor (should be preprocessed, e.g., unsqueezed).
            deterministic: If True, return the mean action. Otherwise, sample.

        Returns:
            Action tensor.
        """
        with torch.no_grad(): # No need to track gradients for action selection
            action_mean, _, action_std = self.forward(state)
            if deterministic:
                return action_mean
            else:
                normal = torch.distributions.normal.Normal(action_mean, action_std)
                return normal.sample()
    

class Value(nn.Module):
    """
    A simple MLP value function network.

    Args:
        input_dim: Dimension of the state space.
        hidden_dim: Dimension of the hidden layers.
    """
    def __init__(self, input_dim: int, hidden_dim: int):
        super(Value, self).__init__()
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

        self.value_head.weight.data.uniform_(-0.003, 0.003)
        self.value_head.bias.data.uniform_(-0.003, 0.003)


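    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept tensors or array-likes and keep the batch shape: [N, input_dim] -> [N, 1]
        x = torch.as_tensor(x, dtype=torch.float32)
        x = torch.tanh(self.hidden1(x))
        x = torch.tanh(self.hidden2(x))
        value = self.value_head(x)
        return value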
    

class ActorCritic(nn.Module):
    """
    Actor-Critic architecture combining policy and value networks.

    Args:
        input_dim: Dimension of the state space.
        hidden_dim: Dimension of the hidden layers.
        output_dim: Dimension of the action space.
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(ActorCritic, self).__init__()
        self.actor = Policy(input_dim, hidden_dim, output_dim)
        self.critic = Value(input_dim, hidden_dim)

    def set_action_std(self, action_std: float):
        """
        Set the action standard deviation for the policy.

        Args:
            action_std: Standard deviation for the action distribution.
        """
        self.actor.log_std.data.fill_(np.log(action_std))

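    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get an action from the policy, its log probability, and the state value.

        Args:
            state: Input state tensor.
            deterministic: If True, return the mean action. Otherwise, sample.

        Returns:
            A tuple containing:
            - action: Action tensor.
            - log_prob: Log probability of the action, summed over action dimensions.
            - state_val: Value of the state from the critic.
        """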
        action = self.actor.get_action(state, deterministic)
        log_prob = self.actor.get_log_probability_density(state, action).sum(dim=1, keepdim=True)
        state_val = self.critic(state)
        return action.detach(), log_prob.detach(), state_val.detach()
    
    def evaluate(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate the policy and value for a given state-action pair.

        Args:
            state: Input state tensor.
            action: Action tensor.

        Returns:
            A tuple containing:
            - log_prob: Log probability of the action under the policy.
            - state_val: Value of the state from the critic.
            - entropy: Entropy of the action distribution.
        """
        log_prob = self.actor.get_log_probability_density(state, action)
        # Differential entropy of a diagonal Gaussian:
        # sum over dimensions of 0.5 + 0.5 * log(2 * pi) + log_std
        _, action_log_std, _ = self.actor(state)
        entropy = (0.5 + 0.5 * LOG_2_PI + action_log_std).sum(dim=1, keepdim=True)
        state_val = self.critic(state)
        return log_prob.sum(dim=1, keepdim=True), state_val, entropy

In [6]:
import gym

env = gym.make("Pendulum-v0")
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]

In [None]:
hidden_dim = 64
gamma = 0.99 # Discount factor
tau = 0.95 # GAE lambda parameter
epsilon = 0.2 # PPO clipping parameter
policy_net = Policy(num_inputs, hidden_dim, num_actions)
value_net = Value(num_inputs, hidden_dim) # Value takes only the state and hidden dimensions
n_episodes = 1000 # Example total episodes
batch_size = 4000  # Target number of steps per policy update batch
max_episode_steps = 1000 # Max steps per episode
log_interval = 100 # How often to print logs
rewards = []
total_steps_processed = 0 # Running count of environment steps across all batches

def update_policy(batch: Dict[str, Any]):
    """
    Updates the policy and value networks using PPO.

    Args:
        batch: A dictionary containing 'states', 'actions', 'rewards', 'mask'.
               Assumes data corresponds to a single trajectory or episode.
    """
    # --- 1. Data Preparation ---
    # Consider device placement (.to(device)) if using GPU
    # Using torch.as_tensor is generally safer than torch.FloatTensor
    # Avoid squeeze(0) unless batch structure guarantees dim 0 is size 1.
    # Assume batch contains data for N steps: [N, state_dim], [N, action_dim], etc.
    try:
        states = torch.as_tensor(batch["states"], dtype=torch.float32)
        actions = torch.as_tensor(batch["actions"], dtype=torch.float32)
        rewards = torch.as_tensor(batch["rewards"], dtype=torch.float32)
        # Ensure masks are treated as floats for multiplication
        masks = torch.as_tensor(batch["mask"], dtype=torch.float32)
    except KeyError as e:
        print(f"Error: Batch dictionary missing key: {e}")
        return
    except Exception as e:
        print(f"Error processing batch tensors: {e}")
        return
    
    # Normalize shapes to [N, dim]
    if states.dim() == 1: states = states.unsqueeze(0)      # Single step: [state_dim] -> [1, state_dim]
    if rewards.dim() == 1: rewards = rewards.unsqueeze(-1)  # [N] -> [N, 1]
    if masks.dim() == 1: masks = masks.unsqueeze(-1)        # [N] -> [N, 1]
    if actions.dim() == 1: actions = actions.unsqueeze(-1)  # [N] -> [N, 1] when action_dim is 1

    
    # --- 2. Value Function Estimation ---
    with torch.no_grad(): # No gradients needed for calculating targets
        # Flatten to [N, 1] so values[i] lines up with rewards[i] and masks[i]
        values = value_net(states).reshape(-1, 1)


    # --- 3. GAE and Returns Calculation ---
    num_steps = rewards.size(0)
    returns = torch.zeros_like(rewards)     # Use zeros_like for correct shape/device/dtype
    deltas = torch.zeros_like(rewards)
    advantages = torch.zeros_like(rewards)

    prev_return = 0.0
    prev_value = 0.0
    prev_advantage = 0.0
    for i in reversed(range(num_steps)):
        # rewards[i], masks[i], and values[i] are single-element tensors, so they broadcast together
        current_reward = rewards[i]
        current_mask = masks[i]
        current_value = values[i] # From value_net(states) calculated earlier

        # Calculate return G(t) = r_t + gamma * G(t+1) * mask
        returns[i] = current_reward + gamma * prev_return * current_mask

        # Calculate TD error (delta) = r_t + gamma * V(s_{t+1}) * mask - V(s_t)
        # Note: prev_value holds V(s_{t+1}) from the previous iteration
        deltas[i] = current_reward + gamma * prev_value * current_mask - current_value

        # Calculate GAE advantage A(t) = delta_t + gamma * tau * A(t+1) * mask
        advantages[i] = deltas[i] + gamma * tau * prev_advantage * current_mask

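        # Carry this step's quantities back to step t-1 as plain Python floats
        prev_return = returns[i].item()
        prev_value = current_value.item() # V(s_t) becomes V(s_{t+1}) for the next step
        prev_advantage = advantages[i].item()

    # --- 4. PPO-Clip Update ---
    # NOTE: the optimizers, learning rates, and epoch count below are assumed values
    # added to complete the update step; tune them for your problem.
    if not hasattr(update_policy, "optimizers"):
        # Create the optimizers once and reuse them across calls so Adam keeps its state
        update_policy.optimizers = (
            optim.Adam(policy_net.parameters(), lr=3e-4),
            optim.Adam(value_net.parameters(), lr=1e-3),
        )
    policy_optimizer, value_optimizer = update_policy.optimizers

    actions = actions.reshape(num_steps, -1) # [N, action_dim]
    returns = returns.reshape(-1, 1)         # [N, 1]
    advantages = advantages.reshape(-1, 1)   # [N, 1]
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Log probabilities under the policy that collected the data (held fixed during the update)
    with torch.no_grad():
        old_log_prob = policy_net.get_log_probability_density(states, actions).sum(dim=1, keepdim=True)

    for _ in range(10): # Assumed number of optimization epochs per batch
        log_prob = policy_net.get_log_probability_density(states, actions).sum(dim=1, keepdim=True)
        ratio = torch.exp(log_prob - old_log_prob)
        surrogate = torch.min(ratio * advantages,
                              torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages)
        policy_loss = -surrogate.mean() # Gradient ascent on L^CLIP via descent on its negative
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        # Fit the value function to the returns with mean-squared error
        value_loss = (value_net(states).reshape(-1, 1) - returns).pow(2).mean()
        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()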


    

for i_episode in range(n_episodes):
    # Data storage for the current batch (will collect multiple episodes)
    batch_states = []
    batch_actions = []
    batch_rewards = []
    batch_masks = [] # Represents (1 - done)

    steps_in_batch = 0
    episodes_in_batch = 0
    total_reward_in_batch = 0.0

    # Collect experience until batch_size is reached
    while steps_in_batch < batch_size:
        state = env.reset()
        # Ensure state is in the format expected by policy_net (e.g., numpy array)
        # If policy_net expects a tensor, convert here:
        # state_tensor = torch.from_numpy(state).float().unsqueeze(0)

        episode_reward = 0.0
        episode_steps = 0

        # Temporary storage for the current episode's trajectory
        episode_states = []
        episode_actions = []
        episode_rewards = []
        episode_masks = []

        for t in range(max_episode_steps):
            # 1. Get Action
            # Ensure state format matches policy_net.get_action input requirement
            # Assuming get_action returns a tensor
            state_tensor = torch.from_numpy(state.reshape(-1)).float().unsqueeze(0)
            action_tensor = policy_net.get_action(state_tensor)
            # Convert action to numpy for the environment step if needed
            action_numpy = action_tensor.detach().cpu().numpy() # Adjust based on env requirements

            # 2. Step Environment
            # Ensure env.step returns consistent types (usually numpy for state/reward)
            next_state, reward, done, _ = env.step(action_numpy)

            # 3. Store Transition Data (using consistent types, e.g., numpy for states)
            episode_states.append(state.reshape(-1)) # Store original state (numpy)
            episode_actions.append(action_tensor) # Store action tensor
            episode_rewards.append(reward) # Store reward (float/numpy)
            episode_masks.append(1.0 - float(done)) # Store mask (float)

            state = next_state # Update state for next iteration
            episode_reward += reward
            episode_steps += 1

            if done:
                break

        # End of episode: Append episode data to the main batch lists
        batch_states.extend(episode_states)
        batch_actions.extend(episode_actions)
        batch_rewards.extend(episode_rewards)
        batch_masks.extend(episode_masks)

        # Update batch counters
        steps_in_batch += episode_steps
        total_reward_in_batch += episode_reward
        episodes_in_batch += 1

        # Store the reward of the *last completed* episode for logging
        last_episode_reward = episode_reward

    # --- Batch Finalization and Policy Update ---
    # Calculate average reward per episode in this batch
    avg_reward_per_episode = total_reward_in_batch / episodes_in_batch if episodes_in_batch > 0 else 0.0
    total_steps_processed += steps_in_batch

    # Prepare batch dictionary for update_policy
    # Convert lists of data points into single tensors
    # Ensure correct dtypes and device placement (.to(device)) if using GPU
    update_batch = {
        "states": torch.tensor(np.asarray(batch_states), dtype=torch.float32),
        "actions": torch.cat(batch_actions, dim=0), # Each action is [1, action_dim]; concatenate to [N, action_dim]
        "rewards": torch.tensor(np.asarray(batch_rewards, dtype=np.float32).reshape(-1, 1)), # [N, 1]
        "mask": torch.tensor(batch_masks, dtype=torch.float32).unsqueeze(1)      # [N, 1]
        # Note: "next_states" is not needed by the GAE/PPO update,
        # but if it were, you'd collect and tensorize it similarly.
    }

    # Call the policy update function
    update_policy(update_batch) # Pass the correctly formatted batch
    # np.mean handles rewards that arrive as scalars or as 1-element arrays
    rewards.append(float(np.mean(avg_reward_per_episode)))
    # --- Logging ---
    if i_episode % log_interval == 0:
        print(f'Episode {i_episode}\tSteps Collected: {steps_in_batch}\t'
              f'Last Ep Reward: {float(np.mean(last_episode_reward)):.2f}\t'
              f'Avg Batch Ep Reward: {float(np.mean(avg_reward_per_episode)):.2f}')
