**PPO Base is a minimal, research-faithful implementation of Proximal Policy Optimization (PPO) in PyTorch.**
This repository focuses on the core PPO algorithm as described in the original paper, including the clipped surrogate objective, Generalized Advantage Estimation (GAE), and an on-policy actor–critic training loop. The goal is clarity and correctness rather than performance optimization, which makes this implementation easy to read, extend, and use as a reference for understanding PPO from first principles.

Implementing PPO requires:
- an actor-critic network
- a rollout buffer
- GAE
- the PPO clipped objective
- an update step





In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from typing import Tuple
from dataclasses import dataclass



**Policy (Actor) & Value Function (Critic)**
What it is

Actor: learns what action to take

Critic: learns how good a state is

PPO is an actor–critic algorithm.

In math (paper)

**Policy: $\pi_\theta(a \mid s)$**

**Value: $V_\phi(s)$**



In [5]:
class Actor(nn.Module):
  def __init__(self, state_dim, action_dim):
    super().__init__()

    self.net = nn.Sequential(
        nn.Linear(state_dim, 64),
        nn.Tanh(),
        nn.Linear(64, 64),
        nn.Tanh(),

    )

    self.policy_head = nn.Linear(64, action_dim)

  def forward(self, state):
    # Map the state to action logits and wrap them in a categorical distribution
    x = self.net(state)
    logits = self.policy_head(x)
    dist = Categorical(logits=logits)
    return dist


In [9]:
class Critic(nn.Module):
  def __init__(self, state_dim):
    super().__init__()

    self.net = nn.Sequential(
        nn.Linear(state_dim, 64),
        nn.Tanh(),
        nn.Linear(64,64),
        nn.Tanh(),
        nn.Linear(64,1)

    )

  def forward(self, state):
    # Drop the trailing size-1 dimension so each state maps to a scalar value
    return self.net(state).squeeze(-1)
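
To make the two interfaces concrete, here is a minimal sketch of a single interaction step: logits become a `Categorical` distribution, an action is sampled, and its log-probability and the state value are kept for the later update. The two linear layers are stand-ins for the actor and critic (an assumption made so the block runs on its own), and the sizes 4 and 2 are arbitrary.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

state_dim, action_dim = 4, 2                     # illustrative sizes, not from the source
policy_head = nn.Linear(state_dim, action_dim)   # stand-in for the actor
value_head = nn.Linear(state_dim, 1)             # stand-in for the critic

state = torch.randn(state_dim)

with torch.no_grad():                            # no gradients needed while collecting data
    dist = Categorical(logits=policy_head(state))
    action = dist.sample()                       # a_t sampled from pi_theta(. | s_t)
    logprob = dist.log_prob(action)              # log pi_theta(a_t | s_t), kept for the ratio
    value = value_head(state).squeeze(-1)        # V_phi(s_t), a scalar

print(action.item(), logprob.item(), value.item())
```

The log-probability is recorded at collection time because the PPO ratio compares it with the log-probability of the same action under the updated policy.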




**What is a Rollout Buffer?**  
A rollout buffer is a temporary storage that collects experiences generated by running the current policy in the environment for a fixed number of steps.

In [7]:
class RolloutBuffer:
  def __init__(self):
    self.states = []
    self.actions = []
    self.logprobs = []
    self.values = []
    self.rewards = []
    self.dones = []


  def add(self, state, action, logprob, value, reward, done):
    # Store one transition collected with the current policy
    self.states.append(state)
    self.actions.append(action)
    self.logprobs.append(logprob)
    self.values.append(value)
    self.rewards.append(reward)
    self.dones.append(done)

  def clear(self):
    # Discard the rollout once it has been used for an update
    self.states.clear()
    self.actions.clear()
    self.logprobs.clear()
    self.values.clear()
    self.rewards.clear()
    self.dones.clear()
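
A hedged usage sketch of how a buffer with these fields is typically filled during a rollout and then converted to tensors. The compact `Buffer` class is a hypothetical stand-in defined here so the block runs on its own, and the random states and constant rewards replace an environment, which the source does not specify.

```python
import torch

class Buffer:
    """Compact stand-in for the rollout buffer, redefined so this block runs alone."""
    FIELDS = ("states", "actions", "logprobs", "values", "rewards", "dones")

    def __init__(self):
        for name in self.FIELDS:
            setattr(self, name, [])

    def add(self, *transition):
        # One item per field, in the order given by FIELDS
        for name, item in zip(self.FIELDS, transition):
            getattr(self, name).append(item)

    def clear(self):
        for name in self.FIELDS:
            getattr(self, name).clear()

torch.manual_seed(0)
buffer = Buffer()
for t in range(5):
    state = torch.randn(4)               # fake observation
    action = torch.tensor(t % 2)         # fake discrete action
    done = (t == 4)                      # the episode ends on the last step
    buffer.add(state, action, torch.tensor(-0.69), torch.tensor(0.0), 1.0, done)

states = torch.stack(buffer.states)                          # shape (5, 4)
actions = torch.stack(buffer.actions)                        # shape (5,)
rewards = torch.tensor(buffer.rewards, dtype=torch.float32)  # shape (5,)
dones = torch.tensor(buffer.dones, dtype=torch.float32)      # 1.0 marks an episode end

buffer.clear()  # on-policy data is discarded after the update
```

Clearing after every update matters because PPO is on-policy: transitions gathered under older weights no longer describe the current policy.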




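**Generalized Advantage Estimation (GAE)**
GAE estimates how much better an action turned out than the critic expected. It combines the immediate surprise (the one-step TD error) with the surprises at later steps, weighted by an exponentially decaying factor.

It balances:

**Bias (systematic error from trusting the critic's imperfect value estimates)**

**Variance (noise from summing many sampled rewards)**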
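
In symbols, with $d_t = 1$ when step $t$ ends an episode and $0$ otherwise, the recursion computed by `compute_gae` can be written as follows. This is a restatement of the code's update rule, using $A_t$ as in the PPO objective:

$$
\delta_t = r_t + \gamma (1 - d_t) V_\phi(s_{t+1}) - V_\phi(s_t)
$$

$$
A_t = \delta_t + \gamma \lambda (1 - d_t) A_{t+1}
$$

The recursion runs backward from the last collected step, where the advantage beyond the rollout is taken to be zero. With $\lambda = 0$ it reduces to $A_t = \delta_t$, the one-step TD error (low variance, more bias). With $\lambda = 1$ the TD errors telescope into the discounted return minus $V_\phi(s_t)$ (less bias, more variance).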

In [1]:
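def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
  # rewards, values, dones: per-step sequences of length T
  # last_value: critic estimate for the state that follows the final step
  T = len(rewards)
  advantages = torch.zeros(T)
  gae = 0.0
  next_value = last_value  # bootstrap from the state after the rollout

  for t in reversed(range(T)):
    not_done = 1.0 - float(dones[t])  # zero at episode ends, so nothing leaks across episodes
    delta = rewards[t] + gamma * next_value * not_done - values[t]  # one-step TD error
    gae = delta + gamma * lam * not_done * gae
    advantages[t] = gae
    next_value = values[t]

  return advantages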
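
A small sanity check of the two extremes of `lam`, as a sketch under stated assumptions: the compact helper `gae` (a hypothetical name) is written for this check only, and the three-step episode with constant rewards and values is made up.

```python
import torch

def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Backward GAE recursion, written compactly for this check only."""
    adv = torch.zeros(len(rewards))
    running, next_value = 0.0, last_value
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        adv[t] = running
        next_value = values[t]
    return adv

rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5]
dones = [0.0, 0.0, 1.0]   # the episode ends at the last step
gamma = 0.9

# lam = 0: each advantage is just the one-step TD error
td = gae(rewards, values, dones, last_value=0.0, gamma=gamma, lam=0.0)

# lam = 1: each advantage is the discounted return minus the value estimate
mc = gae(rewards, values, dones, last_value=0.0, gamma=gamma, lam=1.0)

print(td)  # approximately [0.95, 0.95, 0.50]
print(mc)  # approximately [2.21, 1.40, 0.50]
```

For the first step with `lam=1`, the discounted return is 1 + 0.9 + 0.81 = 2.71, and subtracting the value estimate 0.5 gives 2.21, which matches the recursion.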






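**PPO Clipped Objective**
PPO maximizes a surrogate objective in which the probability ratio $r_t(\theta)$ is clipped to $[1 - \epsilon, 1 + \epsilon]$, so a single update cannot move the policy far from the one that collected the data.

$$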
L^{CLIP}(\theta) =
\mathbb{E}_t \left[
\min \left(
r_t(\theta) A_t,
\text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t
\right)
\right]
$$

$$
r_t(\theta) =
\frac{
\pi_\theta(a_t \mid s_t)
}{
\pi_{\theta_{\text{old}}}(a_t \mid s_t)
}
$$
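
The effect of the clip is easiest to see on a few hand-picked samples. The sketch below evaluates the per-sample objective and its gradient; the function name `clipped_surrogate`, the probabilities, and `eps=0.2` are illustrative assumptions.

```python
import torch

def clipped_surrogate(logprobs, old_logprobs, advantages, eps=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return torch.min(unclipped, clipped)

old_logprobs = torch.log(torch.tensor([0.5, 0.5, 0.5]))
new_probs = torch.tensor([0.75, 0.75, 0.25], requires_grad=True)
advantages = torch.tensor([1.0, -1.0, -1.0])

objective = clipped_surrogate(torch.log(new_probs), old_logprobs, advantages)
objective.sum().backward()

print(objective)       # ratios are 1.5, 1.5, 0.5 -> approximately [1.2, -1.5, -0.8]
print(new_probs.grad)  # approximately [0.0, -2.0, 0.0]
```

In the first sample the action was good and its probability has already grown past $1 + \epsilon$, so the clipped term is selected and the gradient vanishes. In the second the action was bad but became more likely, so the unclipped term is selected and the full gradient pushes the probability back down. In the third the bad action's probability has already dropped below $1 - \epsilon$, so there is no incentive to reduce it further.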





In [8]:
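  def update(self):
        # Method of the PPO agent class, which is not shown in this notebook.
        # It expects that class to provide self.policy, self.policy_old,
        # self.optimizer, self.MseLoss, self.K_epochs, self.eps_clip,
        # self.compute_returns(), and the rollout lists used below.

        # Convert memory to tensors
        states = torch.stack(self.states)
        actions = torch.tensor(self.actions)
        old_logprobs = torch.stack(self.logprobs).detach()

        returns = self.compute_returns()

        # Advantage here is the return minus the critic's baseline
        # (compute_gae is not used in this method); no gradient flows through it
        with torch.no_grad():
            _, state_values, _ = self.policy.evaluate(states, actions)
        advantages = returns - state_values

        # Optimize policy for K epochs on the same rollout
        for _ in range(self.K_epochs):

            # Log-probs, state values, and entropy under the current weights
            logprobs, state_values, entropy = self.policy.evaluate(states, actions)

            # Probability ratio r_t, computed in log space for numerical stability
            ratios = torch.exp(logprobs - old_logprobs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip,
                                         1 + self.eps_clip) * advantages

            # Final PPO loss: clipped policy term + value loss - entropy bonus
            loss = -torch.min(surr1, surr2) \
                   + 0.5 * self.MseLoss(state_values, returns) \
                   - 0.01 * entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

        # Clear memory
        self.states = []
        self.actions = []
        self.logprobs = []
        self.rewards = []
        self.dones = []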
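
Because the method above depends on a class that is not shown, here is a self-contained sketch of the same update loop that can be run directly. Everything outside the loss formula is an assumption for illustration: the tiny networks, the random batch, the random advantages and returns (stand-ins for real rollout data), the learning rate, `eps_clip = 0.2`, and four epochs. The coefficients 0.5 and 0.01 are taken from the loss above.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Toy networks and a fake rollout; sizes and hyperparameters are illustrative
actor = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
critic = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 1))
params = list(actor.parameters()) + list(critic.parameters())
optimizer = torch.optim.Adam(params, lr=3e-4)

states = torch.randn(16, 4)
with torch.no_grad():
    dist = Categorical(logits=actor(states))
    actions = dist.sample()
    old_logprobs = dist.log_prob(actions)   # fixed reference for the ratio
advantages = torch.randn(16)                # stand-in for advantage estimates
returns = torch.randn(16)                   # stand-in for value targets

eps_clip, value_coef, entropy_coef = 0.2, 0.5, 0.01
losses = []
for _ in range(4):                          # K epochs over the same batch
    dist = Categorical(logits=actor(states))
    logprobs = dist.log_prob(actions)
    values = critic(states).squeeze(-1)

    ratios = torch.exp(logprobs - old_logprobs)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages

    policy_loss = -torch.min(surr1, surr2).mean()         # maximize the clipped objective
    value_loss = nn.functional.mse_loss(values, returns)  # fit the critic to the targets
    entropy = dist.entropy().mean()                       # bonus that encourages exploration
    loss = policy_loss + value_coef * value_loss - entropy_coef * entropy

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses)
```

Keeping `old_logprobs` outside the loop is the important design point: the ratio starts at exactly 1 in the first epoch and drifts as the weights change, which is what the clip then bounds.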
