# Deep Dive: The Computational Graph of Pure JAX PPO

**Objective:** Build a high-performance, single-file implementation of Proximal Policy Optimization (PPO) that runs entirely on the GPU.

## 1. The Philosophy: "End-to-End" Compilation

To understand this code, we must understand the hardware bottleneck in standard Deep RL.

### The Standard approach (Python/Gym)
In libraries like PyTorch + Gym, the training loop looks like a ping-pong match:
1. **CPU:** Simulates the environment (physics, rules).
2. **PCIe Bus:** Transfers observation data to GPU.
3. **GPU:** Neural Network predicts action.
4. **PCIe Bus:** Transfers action back to CPU.
5. **CPU:** Steps the environment.

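For small networks and cheap environments such as CartPole, this per-step synchronization, plus the Python overhead of the loop itself, can cost more time than the actual computation.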

### The Pure JAX approach
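We rewrite the **Environment itself** in JAX (vectorized array math). This allows us to compile the Environment, the Agent, and the Optimizer into a **single XLA program**: one compiled executable (internally many fused GPU kernels) with no Python in the loop.

* **Input:** The training state: parameters, optimizer state, environment states, and a PRNG key (everything derives from one random seed).
* **Output:** The updated training state plus metrics. Here one call performs a full PPO update, and a thin Python loop repeats it so progress can be printed.
* **Mechanism:** `jax.lax.scan` creates a compiled `for` loop that runs entirely on the GPU. Rollout data stays in GPU memory from collection to gradient update and never crosses the PCIe bus.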
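To see what `lax.scan` does in isolation, here is a toy recurrence (hypothetical, not part of the PPO code) run once as a compiled scan and once as a plain Python loop for comparison. The `step` function has the same `(carry, x) -> (new_carry, output)` shape as the rollout, GAE, and update functions later in this notebook.

```python
import jax
import jax.numpy as jnp

# Toy "simulation" step: carry is the running state, x is the per-step input.
def step(carry, x):
    new_carry = carry * 0.9 + x        # update the state
    return new_carry, new_carry        # (next carry, per-step output to stack)

@jax.jit
def run(init, xs):
    # lax.scan traces `step` once and lowers the loop to an XLA While loop,
    # so all iterations execute on the device inside one compiled call.
    final, history = jax.lax.scan(step, init, xs)
    return final, history

xs = jnp.ones(128)
final, history = run(jnp.float32(0.0), xs)

# Reference: the same recurrence as an ordinary Python loop.
ref = 0.0
ref_hist = []
for _ in range(128):
    ref = ref * 0.9 + 1.0
    ref_hist.append(ref)
```

`history` stacks one output per iteration (shape `(128,)`), while `final` is the last carry, exactly the split the rollout uses between the trajectory buffer and the environment states.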

In [None]:
import os
import time
import argparse
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
import numpy as np
import imageio
from PIL import Image, ImageDraw
import warnings
import logging
from typing import NamedTuple, Any

# Suppress warnings
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*os.fork().*")
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)

## 2. The Environment: Functional & Stateless

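In standard OOP (Object-Oriented Programming), an environment keeps internal state (`self.state`) and mutates it. JAX transformations (`jit`, `vmap`, `grad`) require **pure functions**, so the state must be passed in and returned explicitly.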

The environment is defined by two pure functions:
1. `reset(key) -> state`: Deterministically creates a starting state.
2. `step(state, action) -> (next_state, reward, done)`: Advances physics by one tick.

**Why reimplement CartPole?**
Standard `gym` CartPole is plain Python/NumPy (other Gym environments wrap C/C++ engines), so JAX cannot trace it with `jit` or `vmap`. By rewriting the dynamics (Euler integration of the equations of motion) in JAX, we can simulate **thousands** of CartPoles in parallel on a single GPU.
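As a minimal sketch of the pattern, here is a hypothetical 1-D point-mass environment (not CartPole; `reset`, `step`, and all constants are made up for illustration) written as two pure functions, then batched over 4096 environments with `vmap` and compiled with `jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment: a point mass that must stay inside [-1, 1].
def reset(key):
    # Same key in, same state out: no hidden RNG state.
    return jax.random.uniform(key, shape=(2,), minval=-0.05, maxval=0.05)

def step(state, action):
    pos, vel = state
    force = jnp.where(action == 1, 1.0, -1.0)   # action 1 pushes right
    vel = vel + 0.02 * force
    pos = pos + 0.02 * vel
    next_state = jnp.array([pos, vel])
    done = jnp.abs(pos) > 1.0
    return next_state, jnp.float32(1.0), done

num_envs = 4096
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)

# vmap turns the single-env functions into batched ones; jit compiles them.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

states = batched_reset(keys)                       # (4096, 2)
actions = jnp.zeros(num_envs, dtype=jnp.int32)     # push left everywhere
next_states, rewards, dones = batched_step(states, actions)
```

The single-environment functions never mention the batch dimension; `vmap` adds it, and the constant reward is broadcast to one value per environment.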

In [None]:
class PureJaxCartPole:
    def __init__(self):
        # Physical constants (static, not state)
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masscart + self.masspole)
        self.length = 0.5 
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02 # Seconds per step
        self.theta_threshold_radians = 12 * 2 * jnp.pi / 360
        self.x_threshold = 2.4

    def reset(self, key):
        # Pure function: Input Key -> Output State
        state = jax.random.uniform(key, shape=(4,), minval=-0.05, maxval=0.05)
        return state

    def step(self, state, action):
        # Unpack state: [Position, Velocity, Angle, Angular Velocity]
        x, x_dot, theta, theta_dot = state
        
        # Physics Math (Euler Integration)
        force = jnp.where(action == 1, self.force_mag, -self.force_mag)  # action 1 = push right
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        temp = (force + self.polemass_length * theta_dot**2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # State Update
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        next_state = jnp.array([x, x_dot, theta, theta_dot])
        
        # Termination Logic
        done = (x < -self.x_threshold) | (x > self.x_threshold) | (theta < -self.theta_threshold_radians) | (theta > self.theta_threshold_radians)
        reward = 1.0 # Keep pole up = +1 reward
        return next_state, reward, done

### Helper: Visualization
Since our environment is just arrays of numbers on the GPU, there is no built-in window. This CPU-side helper (NumPy + PIL) renders a single state vector into an RGB frame for debugging.

In [None]:
def render_cartpole(state, width=600, height=400):
    x, _, theta, _ = state
    world_width = 2.4 * 2
    scale = width / world_width
    carty = 250
    polewidth, polelen = 10.0, scale * 1.0
    cartwidth, cartheight = 50.0, 30.0

    img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    draw.line([(0, carty), (width, carty)], fill=(0, 0, 0), width=1)

    cartx = x * scale + width / 2.0
    draw.rectangle([cartx - cartwidth / 2, carty - cartheight / 2, cartx + cartwidth / 2, carty + cartheight / 2], fill=(0, 0, 0))

    rotation_angle = -theta
    l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
    coords = []
    for px, py in [(l, b), (l, t), (r, t), (r, b)]:
        px_rot = px * np.cos(rotation_angle) - py * np.sin(rotation_angle)
        py_rot = px * np.sin(rotation_angle) + py * np.cos(rotation_angle)
        coords.append((cartx + px_rot, carty - py_rot))
        
    draw.polygon(coords, fill=(204, 153, 102))
    draw.ellipse([cartx - 2, carty - 2, cartx + 2, carty + 2], fill=(127, 127, 255))
    return np.array(img)

## 3. Networks and Optimizer

In JAX, a Neural Network is not a class holding data. It is a **pure function** that transforms input data using a **parameter tree (dictionary)**.

### Design Choices:
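1.  **Orthogonal Initialization:** PPO is sensitive to initial weights. Orthogonal weights have all singular values equal to the scale factor, so no direction of the signal is amplified or squashed at initialization, which keeps activations and gradients well-scaled. Hidden layers use a scale of $\sqrt{2}$; the actor head uses 0.01 so the initial policy is close to uniform.
2.  **Separate Actor/Critic:** We use two separate parameter sets (Actor and Critic). This avoids the "competition" for features that can happen in shared backbones, where policy and value gradients pull the same weights in different directions.
3.  **Manual Adam:** We implement Adam (with global gradient-norm clipping) by hand as a pure function: the optimizer state goes in and the new state comes out, so it is JIT-compatible without any extra library.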
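Two of these claims are easy to check numerically. The sketch below restates the SVD construction from `orthogonal_init` (so it runs on its own), confirms that the resulting weights have orthonormal rows scaled by $\sqrt{2}$, and compares against JAX's built-in `jax.nn.initializers.orthogonal`. It then computes Adam's very first update by hand, assuming the same beta1 = 0.9, beta2 = 0.999 and eps = 1e-8 as `adam_update` and the config's learning rate: bias correction makes the first step roughly `lr * sign(g)`, whatever the gradient's magnitude.

```python
import jax
import jax.numpy as jnp
import numpy as np

# --- Orthogonal initialization (SVD-based, as in the notebook) -------------
def orthogonal_init(key, shape, scale=1.0):
    flat_shape = (shape[0], int(np.prod(shape[1:])))
    a = jax.random.normal(key, flat_shape)
    u, _, vt = jnp.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else vt   # pick the factor with the right shape
    return scale * q.reshape(shape)

key = jax.random.PRNGKey(0)
w = orthogonal_init(key, (4, 64), scale=np.sqrt(2))   # wide: obs_dim -> hidden
# For a wide matrix the rows are orthonormal, so W W^T = scale^2 * I.
gram = w @ w.T

# JAX also ships a built-in orthogonal initializer.
w_lib = jax.nn.initializers.orthogonal(scale=np.sqrt(2))(key, (64, 64))
gram_lib = w_lib.T @ w_lib   # square: W^T W = 2 * I as well

# --- Adam's first step -----------------------------------------------------
# At step 1, bias correction gives m_hat = g and v_hat = g^2, so each
# parameter moves by about lr * sign(g), regardless of the gradient's scale.
def adam_first_step(g, lr=2.5e-4, b1=0.9, b2=0.999, eps=1e-8):
    m = (1 - b1) * g
    v = (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** 1)
    v_hat = v / (1 - b2 ** 1)
    return -lr * m_hat / (jnp.sqrt(v_hat) + eps)

delta = adam_first_step(jnp.array([1e-3, -50.0, 7.0]))
```

Gradients spanning five orders of magnitude all produce a step of about `2.5e-4`, which is one reason the global gradient clipping in `adam_update` matters more for the direction than for the size of early updates.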

In [None]:
def orthogonal_init(key, shape, scale=1.0):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = jax.random.normal(key, flat_shape)
    u, _, vt = jnp.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else vt
    return scale * q.reshape(shape)

def init_actor_critic(key, obs_dim, action_dim, hidden_dim=64):
    keys = jax.random.split(key, 6)
    def init_layer(k, i, o, s=1.0): 
        return {'w': orthogonal_init(k, (i, o), s), 'b': jnp.zeros((o,))}
    # Returning a dictionary of arrays (The Parameter Tree)
    return {
        'actor': {
            'l1': init_layer(keys[0], obs_dim, hidden_dim, np.sqrt(2)),
            'l2': init_layer(keys[1], hidden_dim, hidden_dim, np.sqrt(2)),
            'head': init_layer(keys[2], hidden_dim, action_dim, 0.01)
        },
        'critic': {
            'l1': init_layer(keys[3], obs_dim, hidden_dim, np.sqrt(2)),
            'l2': init_layer(keys[4], hidden_dim, hidden_dim, np.sqrt(2)),
            'head': init_layer(keys[5], hidden_dim, 1, 1.0)
        }
    }

def forward_mlp(params, x):
    # Functional forward pass
    x = jax.nn.tanh(x @ params['l1']['w'] + params['l1']['b'])
    x = jax.nn.tanh(x @ params['l2']['w'] + params['l2']['b'])
    return x @ params['head']['w'] + params['head']['b']

def get_action_logits(params, obs): return forward_mlp(params['actor'], obs)
def get_value(params, obs): return forward_mlp(params['critic'], obs).squeeze(-1)

def adam_update(grads, opt_state, params, lr, max_grad_norm=0.5):
    # A stateless implementation of Adam optimization
    step = opt_state['step'] + 1
    # 1. Global Gradient Clipping
    leaves, _ = jax.tree_util.tree_flatten(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    grads = jax.tree.map(lambda g: g * jnp.minimum(max_grad_norm / (total_norm + 1e-6), 1.0), grads)
    
    # 2. Adam Momentum Updates
    m = jax.tree.map(lambda m, g: 0.9 * m + 0.1 * g, opt_state['m'], grads)
    v = jax.tree.map(lambda v, g: 0.999 * v + 0.001 * (g ** 2), opt_state['v'], grads)
    
    # 3. Bias Correction
    m_hat = jax.tree.map(lambda m: m / (1 - 0.9 ** step), m)
    v_hat = jax.tree.map(lambda v: v / (1 - 0.999 ** step), v)
    
    # 4. Weight Update
    params = jax.tree.map(lambda p, m, v: p - lr * m / (jnp.sqrt(v) + 1e-8), params, m_hat, v_hat)
    return params, {'m': m, 'v': v, 'step': step}

## 4. Phase 1: The Rollout (Data Collection)

PPO is an **on-policy** algorithm. This means we must collect data using the *current* policy, train on it, and then discard it.

**The Hardware Trick:**
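We use `jax.lax.scan` to run the environment loop. In an eager Python loop over 128 steps, every step would dispatch many small operations separately and return control to the host between steps. Inside `jit`, `scan` lowers to a single XLA `While` loop: the step body is traced and compiled **once**, and all 128 iterations run on the device without returning to Python.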

**Dimensions:**
We run $N$ environments in parallel (via `vmap`).
* Input State: $(N_{envs}, 4)$
* Output Trajectory: $(N_{steps}, N_{envs}, 4)$

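**Crucial Concept:**
No gradients flow through this phase. The rollout runs outside `value_and_grad`, so nothing is backpropagated through time: we simply record observations, integer actions, rewards, done flags, and the log-probabilities and values of the policy that *generated* the data. During the update, these recorded quantities are treated as constants (the "old" policy).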
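The shape bookkeeping and the auto-reset can be sketched with a hypothetical toy environment and a placeholder policy (uniform logits instead of the actor network); all constants are illustrative. As in `rollout_step`, the transition stores `states`, the observation the action was taken *from*, not the post-step state.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, NUM_STEPS, OBS_DIM = 8, 16, 2

# Hypothetical toy env (not CartPole): terminates once |pos| > 0.1.
def env_reset(key):
    return jax.random.uniform(key, (OBS_DIM,), minval=-0.05, maxval=0.05)

def env_step(state, action):
    pos, vel = state
    vel = vel + jnp.where(action == 1, 0.05, -0.05)
    pos = pos + vel
    return jnp.array([pos, vel]), jnp.float32(1.0), jnp.abs(pos) > 0.1

def rollout_step(carry, unused):
    states, key = carry
    key, act_key, reset_key = jax.random.split(key, 3)
    # Placeholder policy: uniform logits. The real code calls the actor here.
    logits = jnp.zeros((NUM_ENVS, 2))
    actions = jax.random.categorical(act_key, logits)
    next_states, rewards, dones = jax.vmap(env_step)(states, actions)
    # Stateless auto-reset: overwrite finished envs with fresh start states.
    fresh = jax.vmap(env_reset)(jax.random.split(reset_key, NUM_ENVS))
    next_states = jnp.where(dones[:, None], fresh, next_states)
    return (next_states, key), (states, actions, rewards, dones)

@jax.jit
def rollout(states, key):
    return jax.lax.scan(rollout_step, (states, key), None, length=NUM_STEPS)

key, init_key = jax.random.split(jax.random.PRNGKey(0))
init_states = jax.vmap(env_reset)(jax.random.split(init_key, NUM_ENVS))
(final_states, _), (obs, acts, rews, dones) = rollout(init_states, key)
```

Every per-step output gains a leading time axis: `obs` is `(NUM_STEPS, NUM_ENVS, OBS_DIM)` and `acts`, `rews`, `dones` are `(NUM_STEPS, NUM_ENVS)`.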

In [None]:
def create_rollout_fn(env, num_steps, num_envs):
    """Creates the rollout function (Closure)."""
    
    def rollout_step(carry, unused):
        # Unpack the carry state
        env_states, episode_returns, params, key = carry
        key, subkey = jax.random.split(key)
        
        # 1. Get Action from Policy
        logits = get_action_logits(params, env_states)
        action = jax.random.categorical(subkey, logits)
        value = get_value(params, env_states)
        
        # 2. Store Log Probs for later gradient calculation
        log_prob_all = jax.nn.log_softmax(logits)
        action_log_prob = jnp.take_along_axis(log_prob_all, action[:, None], axis=1).squeeze(-1)
        
        # 3. Step the Environment (Vectorized over num_envs)
        next_env_states, reward, done = vmap(env.step)(env_states, action)
        
        # 4. Handle Auto-Reset (Stateless)
        # If done, we replace the state with a fresh reset state immediately.
        # Fresh states are drawn for every env and selected with `where` (branch-free).
        key, reset_key = jax.random.split(key)
        reset_states = vmap(env.reset)(jax.random.split(reset_key, num_envs))
        next_env_states = jnp.where(done[:, None], reset_states, next_env_states)
        
        # Track metrics
        episode_returns = episode_returns + reward
        final_return = jnp.where(done, episode_returns, 0.0)
        episode_returns = jnp.where(done, 0.0, episode_returns)
        
        # Pack data for trajectory buffer
        transition = (env_states, action, action_log_prob, reward, done, value, final_return)
        return (next_env_states, episode_returns, params, key), transition

    return rollout_step

## 5. Phase 2: GAE (Generalized Advantage Estimation)

After collecting data, we need to grade the agent's performance. 
We need to calculate the **Advantage** ($A_t$): *"How much better was this action than the average action the Critic expected?"*

### The Bias-Variance Trade-off ($\lambda$)
We estimate the return using **GAE**:
$$ A_t^{GAE} = \delta_t + (\gamma \lambda) \delta_{t+1} + (\gamma \lambda)^2 \delta_{t+2} + ... $$

Where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the **TD Error**.

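* **If $\lambda = 0$:** The advantage is just the one-step TD error $\delta_t$, i.e. the target is $r_t + \gamma V(s_{t+1})$, which leans entirely on the Critic. Low Variance, High Bias (stable, but only as good as the Critic).
* **If $\lambda = 1$:** The target is the discounted sum of actual rewards (Monte Carlo), bootstrapped only at the end of the rollout. High Variance, Low Bias (accurate but noisy).
* **We use $\lambda=0.95$:** The common PPO default, which keeps most of the Monte Carlo signal while damping its variance.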
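Both limits follow directly from the definition of $A_t^{GAE}$. Assuming no episode terminates inside a rollout of length $T$, setting $\lambda = 0$ keeps only the first term, while $\lambda = 1$ makes the sum of TD errors telescope:

$$ \lambda = 0: \quad A_t^{GAE} = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) $$

$$ \lambda = 1: \quad A_t^{GAE} = \sum_{l=0}^{T-t-1} \gamma^l \delta_{t+l} = \sum_{l=0}^{T-t-1} \gamma^l r_{t+l} + \gamma^{T-t} V(s_T) - V(s_t) $$

In the $\lambda = 1$ case all intermediate value estimates cancel, leaving the discounted Monte Carlo return, bootstrapped once with $V(s_T)$ (the value of the state after the last rollout step), minus the baseline $V(s_t)$.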

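**Implementation Note:** Because $A_t$ depends on $A_{t+1}$, we must calculate this **backwards** (from $t=T-1$ down to $0$). We use `jax.lax.scan` with `reverse=True`, which iterates from the last index to the first but still stacks the outputs in forward time order, so `advantages[t]` lines up with `rewards[t]`.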
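A quick way to convince yourself the reverse scan is right is to compare it with a plain backwards loop. The sketch below mirrors the recursion in `calculate_gae` on random data, with the same convention that `dones[t]` means the episode ended after action $t$ (so $V(s_{t+1})$ is masked). `gae_scan` and `gae_loop` are illustrative names, not part of the notebook.

```python
import jax
import jax.numpy as jnp
import numpy as np

def gae_scan(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """Reverse lax.scan over time; arrays are shaped (T, N)."""
    def body(carry, t):
        adv_next, v_next = carry
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * v_next * nonterminal - values[t]
        adv = delta + gamma * lam * nonterminal * adv_next
        return (adv, values[t]), adv
    T = rewards.shape[0]
    init = (jnp.zeros_like(next_value), next_value)
    _, adv = jax.lax.scan(body, init, jnp.arange(T), reverse=True)
    return adv  # stacked in forward time order: adv[t] is A_t

def gae_loop(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """Plain NumPy reference: walk t = T-1, ..., 0."""
    T = rewards.shape[0]
    adv = np.zeros_like(rewards)
    last, v_next = np.zeros_like(next_value), next_value
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * v_next * nonterminal - values[t]
        last = delta + gamma * lam * nonterminal * last
        adv[t] = last
        v_next = values[t]
    return adv

rng = np.random.default_rng(0)
T, N = 32, 4
rewards = np.ones((T, N), dtype=np.float32)
values = rng.normal(size=(T, N)).astype(np.float32)
dones = (rng.random((T, N)) < 0.1).astype(np.float32)
next_value = rng.normal(size=N).astype(np.float32)

adv_scan = gae_scan(jnp.asarray(rewards), jnp.asarray(values),
                    jnp.asarray(dones), jnp.asarray(next_value))
adv_ref = gae_loop(rewards, values, dones, next_value)
```

Setting `lam=0.0` in either function should reproduce the one-step TD errors, which makes a handy second check.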

In [None]:
def calculate_gae(params, next_env_states, traj_batch, args):
    obs, actions, logprobs, rewards, dones, values, final_returns = traj_batch
    
    # Bootstrap with the value of the LAST state in the sequence
    next_value = get_value(params, next_env_states)
    
    def gae_scan_fn(carry, t):
        last_gae_lam, next_val = carry
        
        # TD-Error: delta = r + gamma * V(next) - V(current)
        delta = rewards[t] + args['gamma'] * next_val * (1.0 - dones[t]) - values[t]
        
        # GAE recursive formula
        last_gae_lam = delta + args['gamma'] * args['gae_lambda'] * (1.0 - dones[t]) * last_gae_lam
        return (last_gae_lam, values[t]), last_gae_lam

    _, advantages = jax.lax.scan(
        gae_scan_fn, 
        (jnp.zeros_like(next_value), next_value), 
        jnp.arange(args['num_steps']), 
        reverse=True # <--- Crucial: We calculate backwards!
    )
    
    returns = advantages + values
    return advantages, returns

## 6. Phase 3: The PPO Loss Function

This is the heart of the algorithm. We use gradient descent to update our Policy $\pi$ and Value Function $V$.

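$$ L_{total} = L_{actor} + c_{vf} L_{critic} - c_{ent} H[\pi_\theta] $$

where $L_{actor} = -\mathbb{E}_t[L^{CLIP}]$ (the clipped objective defined below), $L_{critic}$ is the squared value error, $H[\pi_\theta]$ is the policy entropy, and $c_{vf}$, $c_{ent}$ are `vf_coef` and `ent_coef`.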

### 1. The Actor Loss (Clipping)
We want to increase the probability of good actions (high Advantage). However, if we change the policy *too much*, training becomes unstable. 

**The PPO Solution:**
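We define the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}$, computed in code as `exp(logp_new - logp_old)` (not to be confused with the reward $r_t$ in the GAE section).
$$ L^{CLIP} = \min\left(r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t\right) $$
This "Pessimistic Bound" zeroes the gradient once the ratio has left $[1-\epsilon, 1+\epsilon]$ in the direction the advantage rewards ($r_t(\theta) > 1+\epsilon$ with $A_t > 0$, or $r_t(\theta) < 1-\epsilon$ with $A_t < 0$), so one batch cannot push the policy arbitrarily far. Moves in the opposite direction are never clipped and keep their full gradient.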
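To see exactly when the clip removes the gradient, the sketch below differentiates the per-sample clipped objective $L^{CLIP}$ (the quantity being maximized, so signs are opposite to `loss_actor`) with respect to $\log r_t(\theta)$, which equals the gradient with respect to `logp_new`. The ratio/advantage pairs are hand-picked for illustration, with $\epsilon = 0.2$ as in the config.

```python
import jax
import jax.numpy as jnp

EPS = 0.2  # clip_coef

def clipped_objective(log_ratio, adv):
    """Per-sample L^CLIP as a function of log(pi_new / pi_old)."""
    ratio = jnp.exp(log_ratio)
    return jnp.minimum(ratio * adv, jnp.clip(ratio, 1 - EPS, 1 + EPS) * adv)

grad_fn = jax.grad(clipped_objective)  # derivative w.r.t. log_ratio

cases = {
    # name: (ratio, advantage)
    "inside range, A>0": (1.05, 1.0),
    "ratio too high, A>0": (1.50, 1.0),   # already pushed up enough: no gradient
    "ratio too low, A<0": (0.50, -1.0),   # already pushed down enough: no gradient
    "ratio too high, A<0": (1.50, -1.0),  # moved the wrong way: gradient kept
}
grads = {name: float(grad_fn(jnp.log(r), a)) for name, (r, a) in cases.items()}
for name, g in grads.items():
    print(f"{name:22s} d L^CLIP / d log r = {g:+.3f}")
```

Inside the range the gradient is just $r_t A_t$; the two "already far enough" cases give exactly zero, while the wrong-way move keeps its full gradient of $r_t A_t = -1.5$.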

### 2. The Critic Loss (Value Clipping)
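The Critic learns to predict returns (regression). We use a similar clipping trick here: the new value prediction is also clipped to within $\pm\epsilon$ of the value recorded during the rollout, and we take the **larger** of the clipped and unclipped squared errors. Whenever the clipped error is the larger one (for example, when the prediction has already moved more than $\epsilon$ toward the target without overshooting it), the gradient is zero, which limits how far the Critic can move in a single update. Note that $\epsilon$ (`clip_coef`) acts in absolute value units here, so its effect depends on the scale of the returns.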
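The sketch below isolates the value term of `loss_fn` for a single sample, with `clip_coef = 0.2`; the numbers are illustrative. It shows the pessimistic behaviour: once the new prediction sits outside the band around the old one and the clipped error is the larger, the loss is still positive but its gradient with respect to the prediction is zero.

```python
import jax
import jax.numpy as jnp

CLIP = 0.2  # clip_coef

def clipped_value_loss(new_value, old_value, ret):
    """0.5 * max(unclipped, clipped) squared error, as in `loss_fn`."""
    unclipped = (new_value - ret) ** 2
    v_clipped = old_value + jnp.clip(new_value - old_value, -CLIP, CLIP)
    clipped = (v_clipped - ret) ** 2
    return 0.5 * jnp.maximum(unclipped, clipped)

loss_and_grad = jax.value_and_grad(clipped_value_loss)  # grad w.r.t. new_value

# The prediction moved from 1.0 to 2.0 and now matches the return exactly,
# but it left the +/-0.2 band, so the clipped error (1.2 vs 2.0) dominates.
loss_far, grad_far = loss_and_grad(2.0, 1.0, 2.0)
# A small move that stays inside the band behaves like ordinary MSE.
loss_near, grad_near = loss_and_grad(1.1, 1.0, 2.0)
```

Here `loss_far` is 0.32 with zero gradient, while `loss_near` is 0.405 with the usual MSE gradient of -0.9.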

### 3. Entropy Bonus
We subtract the policy entropy from the loss (equivalently, add it to the objective). This rewards action distributions that stay closer to uniform. If entropy collapses too early, the agent becomes deterministic and stops exploring.
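For the two CartPole actions, the entropy ranges from $\ln 2 \approx 0.693$ (uniform policy) down to 0 (deterministic). The sketch below evaluates the same entropy formula used in `loss_fn` at both extremes; the logits are made up for illustration.

```python
import math
import jax
import jax.numpy as jnp

def categorical_entropy(logits):
    # Same formula as in `loss_fn`: -sum(p * log p) over the action axis.
    return -jnp.sum(jax.nn.softmax(logits) * jax.nn.log_softmax(logits), axis=-1)

uniform = categorical_entropy(jnp.array([0.0, 0.0]))   # p = [0.5, 0.5]
peaked = categorical_entropy(jnp.array([8.0, -8.0]))   # p close to [1, 0]

print(f"uniform: {float(uniform):.4f} (ln 2 = {math.log(2):.4f})")
print(f"peaked:  {float(peaked):.2e}")
```

With `ent_coef = 0.01`, a uniform policy contributes about $-0.01 \cdot \ln 2 \approx -0.0069$ to the total loss, so the bonus is a gentle nudge rather than a dominant term.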

In [None]:
def create_update_fn(args):
    """Creates the PPO Update function."""
    
    def loss_fn(params, batch):
        obs, act, logp_old, adv, ret, val_old = batch
        
        # 1. Re-run network on recorded observations
        new_logits = get_action_logits(params, obs)
        new_values = get_value(params, obs)
        
        # 2. Probability Ratios
        logp_new = jnp.take_along_axis(jax.nn.log_softmax(new_logits), act[:, None], axis=1).squeeze(-1)
        ratio = jnp.exp(logp_new - logp_old)
        
        # 3. Actor Loss (Clipped)
        loss_actor_unclipped = adv * ratio
        loss_actor_clipped = adv * jnp.clip(ratio, 1 - args['clip_coef'], 1 + args['clip_coef'])
        loss_actor = -jnp.minimum(loss_actor_unclipped, loss_actor_clipped).mean()
        
        # 4. Critic Loss (Clipped option for stability)
        if args['clip_vloss']:
            v_loss_unclipped = (new_values - ret) ** 2
            v_clipped = val_old + jnp.clip(new_values - val_old, -args['clip_coef'], args['clip_coef'])
            v_loss_clipped = (v_clipped - ret) ** 2
            loss_critic = 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()
        else:
            loss_critic = 0.5 * ((new_values - ret) ** 2).mean()
            
        # 5. Entropy Loss (Exploration Bonus)
        entropy = -jnp.sum(jax.nn.softmax(new_logits) * jax.nn.log_softmax(new_logits), axis=1).mean()
        loss_entropy = -args['ent_coef'] * entropy
        
        total_loss = loss_actor + args['vf_coef'] * loss_critic + loss_entropy
        return total_loss, (loss_actor, loss_critic, entropy)

    def update_minibatch(carry, batch):
        params, opt_state = carry
        (loss, metrics), grads = value_and_grad(loss_fn, has_aux=True)(params, batch)
        params, opt_state = adam_update(grads, opt_state, params, args['lr'], args['max_grad_norm'])
        return (params, opt_state), metrics

    return update_minibatch

## 7. Composition: The `make_train` Factory

This function is the "Orchestrator." It connects the components into a single Computational Graph that JAX can compile.

### Data Flow Architecture
1.  **Collection (Scan):**
    * Input: `(NumEnvs, ObsDim)`
    * Output: `(NumSteps, NumEnvs, ObsDim)` 
2.  **Processing (GAE):**
    * Calculates Advantages using the full time-batch.
3.  **Flattening:**
    * The `NumSteps` and `NumEnvs` dimensions are merged. Once advantages and returns are computed, temporal order is no longer needed, so every (step, env) pair becomes an independent sample.
    * Shape: `(NumSteps * NumEnvs, ObsDim)` = `(BatchSize, ObsDim)`
4.  **SGD Epochs (Scan):**
    * We shuffle the large batch.
    * We iterate over minibatches to update the network.

**The JIT Magic:**
The final line `return jit(train_step)` compiles **all** of this logic (Simulation + GAE + Gradient Descent) into a single XLA executable that runs on the GPU.
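Because `train_step` takes `(carry, unused)` and returns `(carry, metrics)`, the signature `lax.scan` expects, the whole training run could in principle be compiled by scanning over updates, giving the "seed in, trained parameters out" form. The toy below checks this pattern with a hypothetical `toy_train_step` (SGD on a noisy quadratic, not PPO): driving the jitted step from Python and scanning it inside one `jit` should give the same result.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `train_step`: one gradient step on a quadratic,
# written with the same (carry, unused) -> (carry, metrics) signature.
def toy_train_step(carry, unused):
    params, key = carry
    key, subkey = jax.random.split(key)
    target = 3.0 + 0.1 * jax.random.normal(subkey)
    loss, grad = jax.value_and_grad(lambda p: (p - target) ** 2)(params)
    return (params - 0.1 * grad, key), loss

init = (jnp.float32(0.0), jax.random.PRNGKey(0))
NUM_UPDATES = 50

# Option A (as in this notebook): jit one update, drive it from Python.
step_jit = jax.jit(toy_train_step)
carry_a = init
for _ in range(NUM_UPDATES):
    carry_a, _ = step_jit(carry_a, None)

# Option B: put the update loop itself under jit with lax.scan.
@jax.jit
def train_all(carry):
    return jax.lax.scan(toy_train_step, carry, None, length=NUM_UPDATES)

carry_b, losses = train_all(init)
```

The trade-off of option B is that Python cannot print between updates; the stacked `losses` can be inspected afterwards, or `jax.debug.print` can emit values from inside the compiled loop.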

In [None]:
def make_train(args):
    """Combines Env, Policy, and Update into a single JIT-able function."""
    
    env = PureJaxCartPole()
    rollout_fn = create_rollout_fn(env, args['num_steps'], args['num_envs'])
    update_fn = create_update_fn(args)
    
    batch_size = args['num_envs'] * args['num_steps']
    minibatch_size = batch_size // args['num_minibatches']

    def train_step(carry, unused):
        params, opt_state, env_states, episode_returns, key = carry
        
        # --- Phase 1: Data Collection ---
        (next_env_states, episode_returns, params, key), traj_batch = jax.lax.scan(
            rollout_fn, (env_states, episode_returns, params, key), None, length=args['num_steps']
        )
        
        # --- Phase 2: GAE ---
        advantages, returns = calculate_gae(params, next_env_states, traj_batch, args)
        
        # Flatten and Normalize
        obs, actions, logprobs, _, _, values, final_returns = traj_batch
        flat_inds = lambda x: x.reshape(batch_size, -1) if x.ndim > 2 else x.reshape(batch_size)
        
        b_obs, b_act, b_logp, b_adv, b_ret, b_val = map(flat_inds, 
            (obs, actions, logprobs, advantages, returns, values)
        )
        
        if args['norm_adv']:
            b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

        # --- Phase 3: Update (Multiple Epochs) ---
        def update_epoch(carry, unused):
            params, opt_state, key = carry
            key, subkey = jax.random.split(key)
            
            # Shuffle data
            perm = jax.random.permutation(subkey, batch_size)
            b_obs_s, b_act_s, b_logp_s, b_adv_s, b_ret_s, b_val_s = \
                jax.tree_util.tree_map(lambda x: x[perm], (b_obs, b_act, b_logp, b_adv, b_ret, b_val))
            
            # Iterate over minibatches
            def get_batch(i): 
                return jax.tree_util.tree_map(
                    lambda x: jax.lax.dynamic_slice_in_dim(x, i, minibatch_size), 
                    (b_obs_s, b_act_s, b_logp_s, b_adv_s, b_ret_s, b_val_s)
                )
            
            def run_minibatch(carry, i):
                return update_fn(carry, get_batch(i))

            (params, opt_state), metrics = jax.lax.scan(
                run_minibatch, (params, opt_state), jnp.arange(0, batch_size, minibatch_size)
            )
            return (params, opt_state, key), metrics

        (params, opt_state, key), metrics = jax.lax.scan(
            update_epoch, (params, opt_state, key), None, length=args['update_epochs']
        )
        
        # Return packing
        avg_return = final_returns.sum() / (final_returns > 0).sum().clip(1.0)
        metrics = jax.tree.map(lambda x: x.mean(), metrics)
        
        return (params, opt_state, next_env_states, episode_returns, key), (metrics, avg_return)

    return jit(train_step)

## 8. Execution

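**The "First Call" Penalty:**
When you call `train_step` for the first time, JAX traces the Python function and XLA compiles the resulting graph. This might take 5-15 seconds, so the first call should be timed separately from the throughput measurement.

**Subsequent Calls:**
Once compiled, the cached executable is reused and the function runs at full GPU speed. Throughput depends on the hardware and above all on `num_envs`: with the 64 environments configured below the GPU is largely idle, and figures above 2 million Steps Per Second (SPS) on CartPole are only plausible with far more parallel environments.

**Reading the Logged Return:**
Unlike Gym's `CartPole-v1`, this environment has no 500-step time limit. The logged return averages only the episodes that *end* inside a rollout, so returns can exceed 500, and if no environment terminates during a rollout the logged value is 0 even though the agent is balancing perfectly.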
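JAX dispatches work asynchronously: a jitted call returns as soon as the work is queued, so timing must wait on the result with `jax.block_until_ready`, and the first call must be timed separately because it includes tracing and compilation. A minimal sketch with a stand-in workload (`work` is hypothetical, not `train_step`):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    # Stand-in workload: a few matmuls on a square matrix (unrolled at trace time).
    for _ in range(4):
        x = jnp.tanh(x @ x.T / x.shape[0])
    return x

x = jnp.ones((512, 512))

t0 = time.perf_counter()
jax.block_until_ready(work(x))          # first call: trace + compile + run
first_call_s = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(10):
    y = work(x)                         # cached executable, dispatched asynchronously
jax.block_until_ready(y)                # wait for the device before reading the clock
steady_s = (time.perf_counter() - t0) / 10

print(f"first call: {first_call_s:.3f}s, later calls: {steady_s * 1e3:.2f} ms")
```

Without the final `block_until_ready`, the loop would mostly measure how fast Python can enqueue work, which overstates throughput.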

In [None]:
config = {
    'seed': 42,
    'total_timesteps': 500000,
    'num_envs': 64,
    'num_steps': 128,
    'lr': 2.5e-4,
    'num_minibatches': 4,
    'update_epochs': 4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_coef': 0.2,
    'ent_coef': 0.01,
    'vf_coef': 0.5,
    'max_grad_norm': 0.5,
    'norm_adv': True,
    'clip_vloss': True,
}

print("Building train_step (XLA compiles it on the first call)...")
train_step = make_train(config)

# Init State
rng = jax.random.PRNGKey(config['seed'])
rng, init_key = jax.random.split(rng)
params = init_actor_critic(init_key, obs_dim=4, action_dim=2)
opt_state = {'m': jax.tree.map(jnp.zeros_like, params), 'v': jax.tree.map(jnp.zeros_like, params), 'step': 0}
rng, *env_keys = jax.random.split(rng, config['num_envs'] + 1)
env_states = vmap(PureJaxCartPole().reset)(jnp.array(env_keys))
episode_returns = jnp.zeros(config['num_envs'])

runner_state = (params, opt_state, env_states, episode_returns, rng)
batch_size = config['num_envs'] * config['num_steps']
num_updates = config['total_timesteps'] // batch_size

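print(f"Starting training for {num_updates} updates...")

# The first call traces and compiles train_step, then runs one update.
# Time it separately so compilation does not distort the SPS numbers.
t0 = time.perf_counter()
runner_state, (metrics, avg_return) = train_step(runner_state, None)
jax.block_until_ready(runner_state)
print(f"Compilation + first update: {time.perf_counter() - t0:.1f}s")

start_time = time.perf_counter()
for i in range(1, num_updates):
    runner_state, (metrics, avg_return) = train_step(runner_state, None)

    if i % 10 == 0:
        # Calls are dispatched asynchronously: wait for queued updates before reading the clock.
        jax.block_until_ready(runner_state)
        sps = i * batch_size / (time.perf_counter() - start_time)
        actor_loss, critic_loss, entropy = metrics
        print(f"Update {i}: Return {float(avg_return):.2f} | SPS: {int(sps)} | "
              f"Actor loss: {float(actor_loss):.3f} | Entropy: {float(entropy):.3f}")

print("Done!")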