# Deep Dive: The Computational Graph of Pure JAX PPO

This notebook combines high-level intuition, a concrete running example, and the exact mathematical derivation of the gradients with a modular JAX implementation.

## 1. The Computational Context: "End-to-End" Compilation
Before looking at the math, we must understand the hardware execution, which dictates the code structure. In high-performance Reinforcement Learning, data transfer is often the slowest part.

### The Bottleneck Problem
* **Standard RL (Python/Gym):** The CPU simulates physics (Environment) $\leftrightarrow$ GPU computes gradients (Agent). This requires moving data back and forth over the PCIe bus thousands of times per second. This is the "PCIe Bottleneck."
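* **Pure JAX RL (This Code):** The Environment is rewritten as stateless vector math. Each training update (simulation, data collection, and backpropagation) is compiled into a single XLA program that runs entirely on the accelerator.

**Mechanism:** `jax.lax.scan` acts as a compiled `for` loop on the GPU. Within an update, the data never leaves VRAM; only small metrics (and, in this notebook, the parameters used for video rendering) are copied back to the host between updates. This is what lets a pure JAX loop process a very large number of environment steps per second.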
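
A minimal sketch of that mechanism, using a running sum in place of an environment step. The names `step` and `running_sum` are illustrative and do not appear elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # `carry` is the running total; `x` is the current element of the scanned array
    new_carry = carry + x
    return new_carry, new_carry  # (state for the next iteration, per-step output)

@jax.jit
def running_sum(xs):
    # scan threads `carry` through `step` once per element along the leading axis
    final, history = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return final, history

final, history = running_sum(jnp.arange(1.0, 6.0))
print(final, history)
```

The per-step outputs are stacked into one array (`history`), which is the same way the rollout later builds its trajectory buffer.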

In [None]:
import os
import time
import argparse
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
import numpy as np
import imageio
from PIL import Image, ImageDraw
import warnings
import logging
from typing import NamedTuple, Any

# Suppress warnings
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*os.fork().*")
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)

### The Environment: Stateless Vector Math
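Functions compiled with `jax.jit` are traced once, so they must be free of side effects: mutating `self` or a global inside `step` would not be captured by the compiled code. Therefore, the environment must be a **pure function**.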

$$ (S_{t+1}, R_t, D_t) = f(S_t, A_t, \text{PhysicsParams}) $$

**Key Characteristics:**
1.  **Stateless:** The `step` function does not modify `self`. It takes the current state as input and returns the next state.
2.  **Vectorized:** We don't need a `VectorEnv` wrapper. `jax.vmap` automatically vectorizes this single-environment logic across a batch of agents.
3.  **Differentiable (Optional):** Because it's written in JAX, we *could* differentiate through the physics, though PPO (as a model-free algorithm) treats the environment as a black box.
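
As a rough illustration of points 1 and 2, the sketch below vectorizes a pure step function over three environments. `toy_step` is a hypothetical 1-D point mass made up for this example, not the CartPole dynamics defined in the next cell.

```python
import jax
import jax.numpy as jnp

def toy_step(state, action):
    # Hypothetical dynamics: state = (position, velocity), action in {0, 1}
    pos, vel = state
    force = jnp.where(action == 1, 1.0, -1.0)
    vel = vel + 0.1 * force
    pos = pos + 0.1 * vel
    next_state = jnp.array([pos, vel])
    done = jnp.abs(pos) > 1.0
    reward = jnp.asarray(1.0)
    return next_state, reward, done  # nothing is mutated; everything is returned

states = jnp.zeros((3, 2))          # three environments, written as one batch
actions = jnp.array([0, 1, 1])
next_states, rewards, dones = jax.vmap(toy_step)(states, actions)
print(next_states.shape, rewards.shape, dones.shape)
```

The function is written for a single environment; `jax.vmap` adds the leading batch axis to every input and output.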

In [None]:
class PureJaxCartPole:
    def __init__(self):
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masscart + self.masspole)
        self.length = 0.5 
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 10.0
        self.tau = 0.02
        self.kinematics_integrator = 'euler'
        self.theta_threshold_radians = 12 * 2 * jnp.pi / 360
        self.x_threshold = 2.4

    def reset(self, key):
        state = jax.random.uniform(key, shape=(4,), minval=-0.05, maxval=0.05)
        return state

    def step(self, state, action):
        x, x_dot, theta, theta_dot = state
        force = jax.lax.select(action == 1, self.force_mag, -self.force_mag)
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        temp = (force + self.polemass_length * theta_dot**2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == 'euler':
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        next_state = jnp.array([x, x_dot, theta, theta_dot])
        done = (x < -self.x_threshold) | (x > self.x_threshold) | (theta < -self.theta_threshold_radians) | (theta > self.theta_threshold_radians)
        reward = 1.0
        return next_state, reward, done

### Helper: Visualization
We use a custom PIL-based renderer to verify the agent's behavior.

In [None]:
def render_cartpole(state, width=600, height=400):
    x, _, theta, _ = state
    world_width = 2.4 * 2
    scale = width / world_width
    carty = 250
    polewidth, polelen = 10.0, scale * 1.0
    cartwidth, cartheight = 50.0, 30.0

    img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    draw.line([(0, carty), (width, carty)], fill=(0, 0, 0), width=1)

    cartx = x * scale + width / 2.0
    draw.rectangle([cartx - cartwidth / 2, carty - cartheight / 2, cartx + cartwidth / 2, carty + cartheight / 2], fill=(0, 0, 0))

    rotation_angle = -theta
    l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
    coords = []
    for px, py in [(l, b), (l, t), (r, t), (r, b)]:
        px_rot = px * np.cos(rotation_angle) - py * np.sin(rotation_angle)
        py_rot = px * np.sin(rotation_angle) + py * np.cos(rotation_angle)
        coords.append((cartx + px_rot, carty - py_rot))
        
    draw.polygon(coords, fill=(204, 153, 102))
    draw.ellipse([cartx - 2, carty - 2, cartx + 2, carty + 2], fill=(127, 127, 255))
    return np.array(img)

### Models and Optimizer
We define the Actor (Policy) and Critic (Value Function) as simple MLPs. 

**1. Orthogonal Initialization:**
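Standard in PPO. Each weight matrix is initialized with orthonormal rows or columns, so all of its singular values equal the chosen `scale` ($\sqrt{2}$ for hidden layers, 0.01 for the policy head, 1.0 for the value head). This keeps activation and gradient magnitudes from exploding or vanishing in the early stages of training, which helps deep RL stability. The small policy-head scale makes the initial policy close to uniform.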

**2. Manual Adam Implementation:**
Why not `optax`? To keep this notebook self-contained and to demonstrate that an optimizer is just a state update function:
$$ \theta_{t+1} = \text{Update}(\theta_t, \nabla L, \text{OptState}_t) $$
This allows us to easily `scan` over the optimization steps without external dependencies.
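
A quick way to check the orthogonality claim is to look at the singular values of an initialized matrix. The sketch below re-implements the SVD-based recipe under the illustrative name `orthogonal_sketch`, so it can run on its own.

```python
import jax
import jax.numpy as jnp

def orthogonal_sketch(key, shape, scale=1.0):
    # SVD of a Gaussian matrix yields factors with orthonormal columns (u) or rows (vt)
    a = jax.random.normal(key, shape)
    u, _, vt = jnp.linalg.svd(a, full_matrices=False)
    q = u if u.shape == shape else vt  # keep the factor that has the requested shape
    return scale * q

w = orthogonal_sketch(jax.random.PRNGKey(0), (4, 64), scale=2.0 ** 0.5)
singular_values = jnp.linalg.svd(w, compute_uv=False)
print(singular_values)  # every entry should be close to sqrt(2)
```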

In [None]:
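def orthogonal_init(key, shape, scale=1.0):
    # Flatten trailing dims so the SVD sees a 2-D matrix
    flat_shape = (shape[0], int(np.prod(shape[1:])))
    a = jax.random.normal(key, flat_shape)
    u, _, vt = jnp.linalg.svd(a, full_matrices=False)
    # Keep whichever factor has the requested shape (tall -> u, wide -> vt)
    q = u if u.shape == flat_shape else vt
    return scale * q.reshape(shape)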

def init_actor_critic(key, obs_dim, action_dim, hidden_dim=64):
    keys = jax.random.split(key, 6)
    def init_layer(k, i, o, s=1.0): 
        return {'w': orthogonal_init(k, (i, o), s), 'b': jnp.zeros((o,))}
    return {
        'actor': {
            'l1': init_layer(keys[0], obs_dim, hidden_dim, np.sqrt(2)),
            'l2': init_layer(keys[1], hidden_dim, hidden_dim, np.sqrt(2)),
            'head': init_layer(keys[2], hidden_dim, action_dim, 0.01)
        },
        'critic': {
            'l1': init_layer(keys[3], obs_dim, hidden_dim, np.sqrt(2)),
            'l2': init_layer(keys[4], hidden_dim, hidden_dim, np.sqrt(2)),
            'head': init_layer(keys[5], hidden_dim, 1, 1.0)
        }
    }

def forward_mlp(params, x):
    x = jax.nn.tanh(x @ params['l1']['w'] + params['l1']['b'])
    x = jax.nn.tanh(x @ params['l2']['w'] + params['l2']['b'])
    return x @ params['head']['w'] + params['head']['b']

def get_action_logits(params, obs): return forward_mlp(params['actor'], obs)
def get_value(params, obs): return forward_mlp(params['critic'], obs).squeeze(-1)

def adam_update(grads, opt_state, params, lr, max_grad_norm=0.5):
    step = opt_state['step'] + 1
    # Gradient Clipping
    leaves, _ = jax.tree_util.tree_flatten(grads)
    total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    grads = jax.tree.map(lambda g: g * jnp.minimum(max_grad_norm / (total_norm + 1e-6), 1.0), grads)
    # Adam Update
    m = jax.tree.map(lambda m, g: 0.9 * m + 0.1 * g, opt_state['m'], grads)
    v = jax.tree.map(lambda v, g: 0.999 * v + 0.001 * (g ** 2), opt_state['v'], grads)
    m_hat = jax.tree.map(lambda m: m / (1 - 0.9 ** step), m)
    v_hat = jax.tree.map(lambda v: v / (1 - 0.999 ** step), v)
    params = jax.tree.map(lambda p, m, v: p - lr * m / (jnp.sqrt(v) + 1e-8), params, m_hat, v_hat)
    return params, {'m': m, 'v': v, 'step': step}

## 2. Phase 1: The Rollout (Data Collection)

**Code Context:** `create_rollout_fn` (returns `rollout_step`, later bound to `rollout_fn`)

We run $N$ environments in parallel. The policy $\pi_{\theta}(a|s)$ interacts with the environment dynamics $P(s_{t+1}|s_t, a_t)$.

### The Trajectory Buffer
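We collect a batch of data $D = \{ (s_t, a_t, \log \pi(a_t|s_t), r_t, d_t, V(s_t)) \}_{t=0}^{T-1}$. The next state is not stored separately; the state after the final step is carried out of the scan and used for bootstrapping.
This data is stored in fixed-shape GPU arrays with leading dimensions $(T, N)$. Because JAX requires static shapes under `jit`, we must decide the trajectory length $T$ (e.g., 128 steps) beforehand.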

**The Trace: A Concrete Example**
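Let us track one environment (Env #42) over 4 timesteps. The state column shows a single illustrative component of the 4-dimensional CartPole state. The episode terminates at step 2 ($d_2 = 1$), so step 3 is the first step of a new episode after the auto-reset.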

| Step (t) | State ($s_t$) | Critic $V(s_t)$ | Actor Logits $z$ | Softmax $\pi(s_t)$ | Action $a_t$ | Reward $r_t$ | Done $d_t$ |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| 0 | 0.05 | $1.0$ | [-0.5, 0.5] | $\approx [0.27, 0.73]$ | 1 (Right) | +1.0 | 0 |
| 1 | 0.10 | $1.5$ | [-0.1, 0.1] | $\approx [0.45, 0.55]$ | 0 (Left) | +1.0 | 0 |
| 2 | 0.25 | $0.8$ | [-1.0, 1.0] | $\approx [0.12, 0.88]$ | 1 (Right) | +1.0 | 1 |
| 3 | 0.00 | $0.0$ | [0.0, 0.0] | $\approx [0.50, 0.50]$ | 1 (Right) | +1.0 | 0 |

**Crucial Math Note:**
During this phase, **no gradients are computed**. We simply populate fixed-size buffers (Tensors) with integers and floats. This is purely inference.
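
The sketch below reproduces the policy columns of the table's first row for a single environment: logits to probabilities, a sampled action, and the stored log-probability. The PRNG seed is arbitrary, so the sampled action may differ from the one shown in the table.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[-0.5, 0.5]])             # one environment, two actions (row t=0)
probs = jax.nn.softmax(logits, axis=-1)       # roughly [0.27, 0.73]

key = jax.random.PRNGKey(0)                   # arbitrary seed for illustration
action = jax.random.categorical(key, logits)  # shape (1,), one action per environment

# Log-probability of the action that was actually taken; this is what the buffer stores
log_probs = jax.nn.log_softmax(logits, axis=-1)
action_log_prob = jnp.take_along_axis(log_probs, action[:, None], axis=1).squeeze(-1)
print(probs, action, action_log_prob)
```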

In [None]:
def create_rollout_fn(env, num_steps, num_envs):
    """Creates the rollout function (Closure)."""
    
    def rollout_step(carry, unused):
        # Unpack the carry state
        env_states, episode_returns, params, key = carry
        key, subkey = jax.random.split(key)
        
        # 1. Get Action from Policy
        logits = get_action_logits(params, env_states)
        action = jax.random.categorical(subkey, logits)
        value = get_value(params, env_states)
        
        # 2. Store Log Probs for later gradient calculation
        log_prob_all = jax.nn.log_softmax(logits)
        action_log_prob = jnp.take_along_axis(log_prob_all, action[:, None], axis=1).squeeze(-1)
        
        # 3. Step the Environment (Vectorized)
        next_env_states, reward, done = vmap(env.step)(env_states, action)
        
        # 4. Handle Auto-Reset (Stateless)
        # Split once and slice, instead of unpacking a traced array element by element
        keys = jax.random.split(key, num_envs + 1)
        key, reset_keys = keys[0], keys[1:]
        reset_states = vmap(env.reset)(reset_keys)
        next_env_states = jnp.where(done[:, None], reset_states, next_env_states)
        
        # Track metrics
        episode_returns = episode_returns + reward
        final_return = jnp.where(done, episode_returns, 0.0)
        episode_returns = jnp.where(done, 0.0, episode_returns)
        
        # Pack data for trajectory buffer
        transition = (env_states, action, action_log_prob, reward, done, value, final_return)
        return (next_env_states, episode_returns, params, key), transition

    return rollout_step

## 3. Phase 2: Signal Processing (GAE)

**Code Context:** `calculate_gae`

We now calculate the targets for learning. We need to answer: **How good was each action compared to the average?**

### A. The Critic's Target: Discounted Return ($R_t$)
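The Critic aims to predict the sum of future rewards. The Monte-Carlo form is computed recursively backwards (Dynamic Programming):

$$R_t = r_t + \gamma \cdot R_{t+1} \cdot (1 - d_t)$$

The code uses the lower-variance $\lambda$-return $R_t = \hat{A}_t + V(s_t)$ as the regression target. It reduces to the recursion above when $\lambda = 1$, with $V$ of the final state used to bootstrap the truncated rollout.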

### B. The Actor's Weight: GAE Advantage ($\hat{A}_t$)
The Generalized Advantage Estimate balances **Bias** (using the Critic's imperfect prediction) and **Variance** (using the noisy Monte-Carlo rewards).

1.  **TD Error ($\delta_t$):** The one-step surprise.
    $$\delta_t = r_t + \gamma V(s_{t+1})(1-d_t) - V(s_t)$$

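2.  **GAE ($\hat{A}_t$):** The exponentially weighted sum of TD errors, computed backwards in time.
    $$\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}$$
    Within a single episode this unrolls to $\sum_{k \ge 0} (\gamma \lambda)^k \delta_{t+k}$; the $(1 - d_t)$ factor stops the sum at episode boundaries.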

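**Trace Example** (with $\gamma = 0.99$ and $\lambda = 0.95$, so $\gamma\lambda = 0.9405$):
*   **Step 2:** $\delta_2 = 1.0 + 0.99(0) - 0.8 = \mathbf{+0.2}$. (The episode ended, so there is no bootstrap term.) $\hat{A}_2 = 0.2$.
*   **Step 1:** $\delta_1 = 1.0 + 0.99(0.8) - 1.5 = \mathbf{+0.292}$. $\hat{A}_1 = 0.292 + 0.9405(0.2) \approx 0.480$.
*   **Step 0:** $\delta_0 = 1.0 + 0.99(1.5) - 1.0 = \mathbf{+1.485}$. (Large positive surprise.) $\hat{A}_0 = 1.485 + 0.9405(0.480) \approx 1.937$.

Because $\hat{A}_0 > 0$, the action taken at $t=0$ was *better* than expected, so we should encourage it.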
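
The trace can be checked numerically with a small reverse scan. This is a standalone sketch over the three steps of the first episode, assuming $\gamma = 0.99$ and $\lambda = 0.95$; `backward_step` is an illustrative name.

```python
import jax
import jax.numpy as jnp

gamma, lam = 0.99, 0.95
rewards = jnp.array([1.0, 1.0, 1.0])
values = jnp.array([1.0, 1.5, 0.8])
dones = jnp.array([0.0, 0.0, 1.0])

def backward_step(carry, inp):
    next_adv, next_val = carry
    r, v, d = inp
    delta = r + gamma * next_val * (1.0 - d) - v       # one-step TD error
    adv = delta + gamma * lam * (1.0 - d) * next_adv   # GAE recursion
    return (adv, v), (delta, adv)

# The bootstrap value after step 2 is irrelevant here because done=1 masks it out
init = (jnp.zeros(()), jnp.zeros(()))
_, (deltas, advantages) = jax.lax.scan(
    backward_step, init, (rewards, values, dones), reverse=True
)
returns = advantages + values  # lambda-return used as the critic target
print(deltas, advantages, returns)
```

With `reverse=True` the scan iterates from the last step to the first but still stacks its outputs in the original time order.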

In [None]:
def calculate_gae(params, next_env_states, traj_batch, args):
    obs, actions, logprobs, rewards, dones, values, final_returns = traj_batch
    
    # Bootstrap with the value of the LAST state in the sequence
    next_value = get_value(params, next_env_states)
    
    def gae_scan_fn(carry, t):
        last_gae_lam, next_val = carry
        
        # TD-Error: delta = r + gamma * V(next) - V(current)
        delta = rewards[t] + args['gamma'] * next_val * (1.0 - dones[t]) - values[t]
        
        # GAE: advantage = delta + gamma * lambda * last_advantage
        last_gae_lam = delta + args['gamma'] * args['gae_lambda'] * (1.0 - dones[t]) * last_gae_lam
        return (last_gae_lam, values[t]), last_gae_lam

    _, advantages = jax.lax.scan(
        gae_scan_fn, 
        (jnp.zeros_like(next_value), next_value), 
        jnp.arange(args['num_steps']), 
        reverse=True # <--- Crucial: We calculate backwards!
    )
    
    returns = advantages + values
    return advantages, returns

## 4. Phase 3: The Gradient (Backpropagation)

**Code Context:** `create_update_fn` (containing `loss_fn`)

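We define a scalar loss function $J(\theta)$ to minimize and compute $\nabla_\theta J$ using automatic differentiation.

$$J_{total} = L_{actor} + c_{vf} L_{critic} - c_{ent} L^{S}$$

Here $L_{actor} = -\mathbb{E}_t[L^{CLIP}]$, $L_{critic} = L^{VF}$, and $L^{S}$ is the policy entropy, all defined in the parts that follow. The minus sign on the last term means that minimizing $J_{total}$ *maximizes* entropy.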

### Part A: The Actor (PPO Clipped Objective)
We want to increase the probability of good actions, but **not too much** at once to avoid policy collapse.

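1.  **Probability Ratio:** $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{old}(a_t|s_t)}$ (not to be confused with the reward $r_t$). It equals 1.0 on the first minibatch, before the parameters have moved.
2.  **Clipped Objective:**
    $$L^{CLIP} = \min(r_t \hat{A}_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon)\hat{A}_t)$$
    If $\hat{A}_t > 0$ (good action), the objective stops rewarding further increases once the ratio exceeds $1+\epsilon$. If $\hat{A}_t < 0$ (bad action), it stops rewarding further decreases once the ratio drops below $1-\epsilon$. In both cases the gradient is zero beyond the clip boundary. The actor loss is the negated batch mean of $L^{CLIP}$.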
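
To make the clipping behavior concrete, the sketch below differentiates the per-sample objective with respect to the new log-probability at three hand-picked points. `clipped_surrogate` is an illustrative helper, and $\epsilon = 0.2$ is assumed.

```python
import jax
import jax.numpy as jnp

def clipped_surrogate(logp_new, logp_old, adv, eps=0.2):
    ratio = jnp.exp(logp_new - logp_old)
    unclipped = ratio * adv
    clipped = jnp.clip(ratio, 1.0 - eps, 1.0 + eps) * adv
    return jnp.minimum(unclipped, clipped)

grad_fn = jax.grad(clipped_surrogate)  # derivative w.r.t. logp_new (first argument)
logp_old = jnp.log(0.5)

g_inside = grad_fn(jnp.log(0.55), logp_old, 1.0)     # ratio 1.1, good action
g_outside = grad_fn(jnp.log(0.75), logp_old, 1.0)    # ratio 1.5, good action
g_negative = grad_fn(jnp.log(0.75), logp_old, -1.0)  # ratio 1.5, bad action
print(g_inside, g_outside, g_negative)
```

Inside the clip range the gradient is the usual policy-gradient signal; outside it, with a positive advantage, the gradient vanishes. With a negative advantage and a ratio that is already too high, the unclipped branch is the minimum, so the gradient still pushes the probability back down.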

### Part B: The Critic (Value Loss)
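Simple regression. We want the critic to accurately predict the returns computed in Phase 2.
$$L^{VF} = \tfrac{1}{2}\,\mathbb{E}_t\big[(V_\theta(s_t) - R_t)^2\big]$$
The code can optionally clip the change in $V_\theta$ relative to the rollout-time value, mirroring the actor's clipping.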

### Part C: Entropy Bonus
We add a bonus for randomness to prevent premature convergence (the agent committing to one action too early).
$$L^{S} = \mathbb{E}_t\Big[-\sum_a \pi_\theta(a|s_t)\log \pi_\theta(a|s_t)\Big]$$
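
As a small numeric check of the entropy term, the sketch below compares a uniform two-action policy with a peaked one. `categorical_entropy` is an illustrative helper name.

```python
import jax
import jax.numpy as jnp

def categorical_entropy(logits):
    # Entropy of a categorical distribution, computed stably from log-probabilities
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)

uniform = categorical_entropy(jnp.array([0.0, 0.0]))   # maximum for two actions: ln 2
peaked = categorical_entropy(jnp.array([-1.0, 1.0]))   # more confident, lower entropy
print(uniform, peaked)
```

Because the entropy enters the total loss with a negative coefficient, the optimizer is nudged toward the more uniform of two otherwise equally good policies.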

In [None]:
def create_update_fn(args):
    """Creates the PPO Update function."""
    
    def loss_fn(params, batch):
        obs, act, logp_old, adv, ret, val_old = batch
        
        # 1. Re-run network on recorded observations
        new_logits = get_action_logits(params, obs)
        new_values = get_value(params, obs)
        
        # 2. Probability Ratios
        logp_new = jnp.take_along_axis(jax.nn.log_softmax(new_logits), act[:, None], axis=1).squeeze(-1)
        ratio = jnp.exp(logp_new - logp_old)
        
        # 3. Actor Loss (Clipped)
        loss_actor_unclipped = adv * ratio
        loss_actor_clipped = adv * jnp.clip(ratio, 1 - args['clip_coef'], 1 + args['clip_coef'])
        loss_actor = -jnp.minimum(loss_actor_unclipped, loss_actor_clipped).mean()
        
        # 4. Critic Loss (Clipped option for stability)
        if args['clip_vloss']:
            v_loss_unclipped = (new_values - ret) ** 2
            v_clipped = val_old + jnp.clip(new_values - val_old, -args['clip_coef'], args['clip_coef'])
            v_loss_clipped = (v_clipped - ret) ** 2
            loss_critic = 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()
        else:
            loss_critic = 0.5 * ((new_values - ret) ** 2).mean()
            
        # 5. Entropy Loss (Exploration Bonus)
        entropy = -jnp.sum(jax.nn.softmax(new_logits) * jax.nn.log_softmax(new_logits), axis=1).mean()
        loss_entropy = -args['ent_coef'] * entropy
        
        total_loss = loss_actor + args['vf_coef'] * loss_critic + loss_entropy
        return total_loss, (loss_actor, loss_critic, entropy)

    def update_minibatch(carry, batch):
        params, opt_state = carry
        (loss, metrics), grads = value_and_grad(loss_fn, has_aux=True)(params, batch)
        params, opt_state = adam_update(grads, opt_state, params, args['lr'], args['max_grad_norm'])
        return (params, opt_state), metrics

    return update_minibatch

## 5. Composition: The `make_train` Factory

This is the heart of JAX compilation. We define the **entire** training lifecycle as a nested function structure. 

**The Hierarchy of Loops:**

1.  **`train_step` (The Outer Loop):** Represents one "Update".
    *   **Input:** Current Params, Optimizer State, Environment State.
    *   **Output:** New Params, Metrics.

    *   **Step 1: Rollout (`jax.lax.scan`):**
        *   Runs `env.step` $T$ times.
        *   Collects a batch of trajectories.

    *   **Step 2: GAE Calculation:**
        *   Processes the trajectories to find Advantages.

    *   **Step 3: Optimization (`jax.lax.scan` - Epochs):**
        *   Iterates $K$ epochs over the data.
        *   **Step 3a: Minibatches (`jax.lax.scan`):**
            *   Slices the data into small batches.
            *   Computes gradients and updates params.

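`train_step` is passed to `jax.jit` at the end, so one full update (rollout, GAE, and all epochs) compiles into a single optimized XLA program. The outermost loop over updates stays in Python (Section 6).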
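
The nesting of epochs and minibatches can be sketched on a toy problem: plain SGD pulling a scalar `w` toward eight data points. Everything here (`data`, `minibatch_step`, `epoch_step`, `train`) is made up for illustration, and minibatches are formed with a reshape rather than the `dynamic_slice_in_dim` used in the cell below.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(8.0)                      # toy "batch" of 8 samples
num_minibatches, epochs, lr = 4, 3, 0.1

def minibatch_step(w, mb):
    # One SGD step on the loss mean((w - mb)^2)
    grad = jax.grad(lambda w_: jnp.mean((w_ - mb) ** 2))(w)
    return w - lr * grad, None

def epoch_step(carry, _):
    w, key = carry
    key, sub = jax.random.split(key)
    shuffled = jax.random.permutation(sub, data)           # reshuffle every epoch
    minibatches = shuffled.reshape(num_minibatches, -1)    # shape (4, 2)
    w, _ = jax.lax.scan(minibatch_step, w, minibatches)    # inner loop
    return (w, key), w

@jax.jit
def train(w0, key):
    (w, _), history = jax.lax.scan(epoch_step, (w0, key), None, length=epochs)  # outer loop
    return w, history

w, history = train(jnp.zeros(()), jax.random.PRNGKey(0))
print(w, history)
```

The carry of each `scan` holds exactly the state that must survive between iterations (parameters and PRNG key), which is the same role the carry plays in `train_step`.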

In [None]:
def make_train(args):
    """Combines Env, Policy, and Update into a single JIT-able function."""
    
    env = PureJaxCartPole()
    rollout_fn = create_rollout_fn(env, args['num_steps'], args['num_envs'])
    update_fn = create_update_fn(args)
    
    batch_size = args['num_envs'] * args['num_steps']
    minibatch_size = batch_size // args['num_minibatches']

    def train_step(carry, unused):
        params, opt_state, env_states, episode_returns, key = carry
        
        # --- Phase 1: Data Collection ---
        (next_env_states, episode_returns, params, key), traj_batch = jax.lax.scan(
            rollout_fn, (env_states, episode_returns, params, key), None, length=args['num_steps']
        )
        
        # --- Phase 2: GAE ---
        advantages, returns = calculate_gae(params, next_env_states, traj_batch, args)
        
        # Flatten and Normalize
        obs, actions, logprobs, _, _, values, final_returns = traj_batch
        flat_inds = lambda x: x.reshape(batch_size, -1) if x.ndim > 2 else x.reshape(batch_size)
        
        b_obs, b_act, b_logp, b_adv, b_ret, b_val = map(flat_inds, 
            (obs, actions, logprobs, advantages, returns, values)
        )
        
        if args['norm_adv']:
            b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

        # --- Phase 3: Update (Multiple Epochs) ---
        def update_epoch(carry, unused):
            params, opt_state, key = carry
            key, subkey = jax.random.split(key)
            
            # Shuffle data
            perm = jax.random.permutation(subkey, batch_size)
            b_obs_s, b_act_s, b_logp_s, b_adv_s, b_ret_s, b_val_s = \
                jax.tree_util.tree_map(lambda x: x[perm], (b_obs, b_act, b_logp, b_adv, b_ret, b_val))
            
            # Iterate over minibatches
            def get_batch(i): 
                return jax.tree_util.tree_map(
                    lambda x: jax.lax.dynamic_slice_in_dim(x, i, minibatch_size), 
                    (b_obs_s, b_act_s, b_logp_s, b_adv_s, b_ret_s, b_val_s)
                )
            
            def run_minibatch(carry, i):
                return update_fn(carry, get_batch(i))

            (params, opt_state), metrics = jax.lax.scan(
                run_minibatch, (params, opt_state), jnp.arange(0, batch_size, minibatch_size)
            )
            return (params, opt_state, key), metrics

        (params, opt_state, key), metrics = jax.lax.scan(
            update_epoch, (params, opt_state, key), None, length=args['update_epochs']
        )
        
        # Return packing
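        # Mean return over episodes that finished during this rollout.
        # Every finished CartPole episode has return >= 1, so `> 0` marks completions.
        num_finished = jnp.maximum((final_returns > 0).sum(), 1)
        avg_return = final_returns.sum() / num_finished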
        metrics = jax.tree.map(lambda x: x.mean(), metrics)
        
        return (params, opt_state, next_env_states, episode_returns, key), (metrics, avg_return)

    return jit(train_step)

## 6. Execution Loop

We are ready to launch.

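1.  **Compilation:** The first call to `train_step` triggers tracing and XLA compilation, which can take from several seconds to tens of seconds depending on hardware. JAX traces the Python function once, and XLA then fuses operations and plans the memory layout. The `scan` loops are compiled as loops rather than unrolled.
2.  **Execution:** Subsequent calls reuse the cached executable. The Python loop merely dispatches the work to the GPU.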

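Watch the **SPS (Steps Per Second)**. In a pure JAX implementation, this number can reach into the millions on modern hardware, orders of magnitude faster than standard Python loops. Note that the SPS printed below is a cumulative average that includes compilation and video rendering time, so it understates steady-state throughput.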
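
To separate compilation from execution when timing, one option is to time the first and second call of a jitted function and call `block_until_ready()` so that asynchronous dispatch does not hide the real cost. The function `heavy` below is a stand-in workload, not part of the training code.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy(x):
    def body(c, _):
        return jnp.tanh(c @ c), None
    out, _ = jax.lax.scan(body, x, None, length=100)
    return out

x = jnp.eye(64) * 0.5

t0 = time.perf_counter()
heavy(x).block_until_ready()            # first call: trace + compile + run
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
result = heavy(x).block_until_ready()   # second call: cached executable, run only
second_call = time.perf_counter() - t0

print(f"first call: {first_call:.4f}s, second call: {second_call:.4f}s")
```

The first call is typically much slower than the second, but the exact numbers depend on the machine.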

In [None]:
from IPython.display import Image as IPyImage, display

def record_video(params, env, run_name, step):
    frames = []
    # Fixed key: every video starts from the same initial state
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    
    for _ in range(500):
        # Render
        frame = render_cartpole(np.array(state))
        frames.append(frame)
        
        # Step
        logits = get_action_logits(params, state)
        action = jnp.argmax(logits)
        state, _, done = env.step(state, action)
        
        if done:
            break
            
    video_dir = f"videos/{run_name}"
    os.makedirs(video_dir, exist_ok=True)
    # Save as GIF for reliable autoplay in notebooks
    video_path = f"{video_dir}/step_{step}.gif"
    imageio.mimsave(video_path, frames, fps=50, loop=0)
    return video_path

config = {
    'seed': 42,
    'total_timesteps': 500000,
    'num_envs': 64,
    'num_steps': 128,
    'lr': 2.5e-4,
    'num_minibatches': 4,
    'update_epochs': 4,
    'gamma': 0.99,
    'gae_lambda': 0.95,
    'clip_coef': 0.2,
    'ent_coef': 0.01,
    'vf_coef': 0.5,
    'max_grad_norm': 0.5,
    'norm_adv': True,
    'clip_vloss': True,
    'capture_video': True,
    'video_freq': 20,
}

print("Compiling...")
train_step = make_train(config)

# Init State
rng = jax.random.PRNGKey(config['seed'])
rng, init_key = jax.random.split(rng)
params = init_actor_critic(init_key, obs_dim=4, action_dim=2)
opt_state = {'m': jax.tree.map(jnp.zeros_like, params), 'v': jax.tree.map(jnp.zeros_like, params), 'step': 0}
rng, *env_keys = jax.random.split(rng, config['num_envs'] + 1)
env_states = vmap(PureJaxCartPole().reset)(jnp.array(env_keys))
episode_returns = jnp.zeros(config['num_envs'])

runner_state = (params, opt_state, env_states, episode_returns, rng)
batch_size = config['num_envs'] * config['num_steps']
num_updates = config['total_timesteps'] // batch_size

print(f"Starting training for {num_updates} updates...")
start_time = time.perf_counter()
run_name = f"ppo_notebook_{int(time.time())}"

for i in range(num_updates):
    runner_state, (metrics, avg_return) = train_step(runner_state, None)
    
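    if i % 10 == 0:
        # Cumulative average: includes compile time of the first call and any video rendering
        sps = (i + 1) * batch_size / (time.perf_counter() - start_time)
        actor_loss = float(metrics[0])  # metrics = (actor loss, critic loss, entropy)
        print(f"Update {i}: Return {float(avg_return):.2f} | SPS: {int(sps)} | Actor loss: {actor_loss:.3f}")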
        
    if config['capture_video'] and i % config['video_freq'] == 0:
        # Extract params from runner_state (it's the first element)
        current_params = runner_state[0]
        video_path = record_video(current_params, PureJaxCartPole(), run_name, i * batch_size)
        print(f"Displaying video for step {i * batch_size}:")
        display(IPyImage(filename=video_path))

print("Done!")