# JAX PPO: Continuous CartPole Swing-Up

This notebook implements a high-performance PPO agent to solve the **CartPole Swing-Up** problem. Unlike the standard CartPole task (where the pole starts upright), here the pole starts hanging down. The agent must learn to swing it up, stabilize it, and keep it centered.

### Mathematical Framework: Markov Decision Process (MDP)
We model this problem as an MDP defined by the tuple $(\mathcal{S}, \mathcal{A}, P, R, \gamma)$:
- **State Space $\mathcal{S}$**: Continuous state vector $s_t = [x, \dot{x}, \cos(\theta), \sin(\theta), \dot{\theta}] \in \mathbb{R}^5$.
- **Action Space $\mathcal{A}$**: Continuous action $a_t \in [-1, 1]$ representing the force applied to the cart.
- **Dynamics $P(s_{t+1}|s_t, a_t)$**: Deterministic physics simulation of the cart-pole system.
- **Reward $R(s_t, a_t)$**: A dense reward function encouraging swing-up and balance.
- **Discount Factor $\gamma$**: We use $\gamma = 0.995$ to encourage long-term planning.

**Key Features:**
1.  **Continuous Control:** The action space is continuous (force magnitude), modeled by a Gaussian distribution.
2.  **Pure JAX:** The environment, policy, and optimizer are all written in JAX. This allows us to JIT-compile the entire training step (simulation + backprop) into a single GPU kernel.
3.  **End-to-End Compilation:** Data never leaves the GPU during training, eliminating the CPU-GPU communication bottleneck.

In [79]:
import os
import time
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
import numpy as np
import imageio
from PIL import Image, ImageDraw
import warnings
from IPython.display import Video, display

# Suppress warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

## 1. The Environment: Physics & Rewards

We implement the physics of the Swing-Up task from scratch.

### Physics Dynamics
The system consists of a cart of mass $M$ and a pole of mass $m$ and length $l$. The state update is governed by the Euler-Lagrange equations. The angular acceleration $\ddot{\theta}$ and linear acceleration $\ddot{x}$ are derived as:

$$\ddot{\theta} = \frac{g \sin\theta - \cos\theta (\frac{F + m l \dot{\theta}^2 \sin\theta}{M+m})}{l (\frac{4}{3} - \frac{m \cos^2\theta}{M+m})}$$

$$\ddot{x} = \frac{F + m l (\dot{\theta}^2 \sin\theta - \ddot{\theta} \cos\theta)}{M+m}$$

We use semi-implicit Euler integration to update the state: $s_{t+1} = s_t + \dot{s}_t \cdot \Delta t$.

### The Reward Function $R(s, a)$
Designing the reward is the hardest part of this task. We use a composite reward function to guide the agent through the phases of swinging up and balancing:

$$R(s, a) = r_{height} + r_{upright} + r_{pos} + r_{vel} + r_{action}$$

Where:
1.  **Height Reward**: $r_{height} = \cos(\theta)$. Encourages the pole to be above the horizontal line.
2.  **Upright Bonus**: $r_{upright} = 1.5 \exp(-10 (1 - \cos(\theta))^2)$. A sharp peak reward when the pole is nearly vertical.
3.  **Position Penalty**: $r_{pos} = -0.01 x^2 - 0.3 x^2 \cdot \mathbb{I}_{balanced}$. We penalize moving away from the center, but the penalty is *stronger* when balanced to keep it centered, and *weaker* during swing-up to allow momentum gathering.
4.  **Velocity Penalty**: $r_{vel} = -0.05 \dot{\theta}^2 \cdot \mathbb{I}_{balanced}$. Dampens oscillation when upright.
5.  **Action Penalty**: $r_{action} = -0.0005 a^2$. Encourages energy efficiency.

In [80]:
class PureJaxCartPoleSwingUp:
    def __init__(self):
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = (self.masscart + self.masspole)
        self.length = 0.5
        self.polemass_length = (self.masspole * self.length)
        self.force_mag = 20.0
        self.tau = 0.02
        self.x_threshold = 2.4

    def reset(self, key):
        k_x, k_theta, k_vel = jax.random.split(key, 3)
        
        x = jax.random.uniform(k_x, minval=-1.0, maxval=1.0)
        
        # Start the pole at ANY angle (0 to 2pi)
        # It might spawn upright, upside down, or sideways.
        theta = jax.random.uniform(k_theta, minval=0, maxval=2*jnp.pi)
        
        return jnp.array([x, 0.0, theta, 0.0])

        
    def get_obs(self, state):
        x, x_dot, theta, theta_dot = state
        # Normalization hint: Scale raw values roughly to [-1, 1] range
        return jnp.array([
            x / 2.4,
            x_dot / 2.0,
            jnp.cos(theta),
            jnp.sin(theta),
            theta_dot / 3.0
        ])

    def step(self, state, action):
        x, x_dot, theta, theta_dot = state

        # Action is raw network output. We clip it here for physics safety.
        force = jnp.clip(action[0], -1.0, 1.0) * self.force_mag

        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        temp = (force + self.polemass_length * theta_dot**2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        next_state = jnp.array([x, x_dot, theta, theta_dot])

        # --- Conditional Position Penalty (Balanced) ---
        pole_height = jnp.cos(theta)
        r_height = pole_height

        # Upright bonus
        upright_bonus = 1.5 * jnp.exp(-10.0 * (1.0 - pole_height)**2)

        # Balanced position penalty
        is_upright = (pole_height > 0.95).astype(jnp.float32)
        r_pos = -0.01 * (x**2) - 0.3 * (x**2) * is_upright

        # Velocity penalty when upright
        r_vel = -0.05 * (theta_dot**2) * is_upright

        # Action penalty
        r_action = -0.0005 * (action[0]**2)

        reward = r_height + upright_bonus + r_pos + r_vel + r_action

        done = (x < -self.x_threshold) | (x > self.x_threshold)

        return next_state, reward, done

In [81]:
def render_cartpole(state, width=608, height=400):
    x, _, theta, _ = state
    world_width = 2.4 * 2
    scale = width / world_width
    carty = 250
    polelen = scale * 1.0

    img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    draw.line([(0, carty), (width, carty)], fill=(0, 0, 0), width=1)

    cartx = x * scale + width / 2.0
    draw.rectangle([cartx-25, carty-15, cartx+25, carty+15], fill=(0, 0, 0))

    tip_x = cartx + polelen * np.sin(theta)
    tip_y = carty - polelen * np.cos(theta)

    draw.line([(cartx, carty), (tip_x, tip_y)], fill=(204, 153, 102), width=12)
    draw.ellipse([cartx - 5, carty - 5, cartx + 5, carty + 5], fill=(127, 127, 255))
    return np.array(img)

## 2. Neural Networks (Continuous Actor-Critic)

For continuous control, we use an Actor-Critic architecture where the policy $\pi_\theta$ outputs the parameters of a probability distribution.

### Gaussian Policy
The Actor network outputs the mean $\mu_\theta(s)$ of a Gaussian distribution. The standard deviation $\sigma$ is a learnable parameter (state-independent parameter vector). The action $a$ is sampled as:

$$a \sim \mathcal{N}(\mu_\theta(s), \sigma)$$

The log-likelihood of an action $a$, which is needed for the policy gradient, is given by:

$$\log \pi_\theta(a|s) = -\frac{1}{2} \left(\frac{a - \mu_\theta(s)}{\sigma}\right)^2 - \log \sigma - \frac{1}{2} \log 2\pi$$

We clamp `log_std` to the range $[-2.0, 0.5]$ to prevent the policy from collapsing (becoming deterministic) or becoming too entropic (random noise).

### Critic Network
The Critic $V_\phi(s)$ estimates the Value function, which is the expected discounted return from state $s$:

$$V^\pi(s) = \mathbb{E}_{\tau \sim \pi} \left[ \sum_{t=0}^\infty \gamma^t r_t \mid s_0 = s \right]$$

Both networks are MLPs with 2 hidden layers of 128 units and Tanh activation.

In [82]:
def orthogonal_init(key, shape, scale=1.0):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = jax.random.normal(key, flat_shape)
    u, _, vt = jnp.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else vt
    q = q.reshape(shape)
    return scale * q

def init_mlp_layer(key, in_dim, out_dim, scale=1.0, bias_init=0.0):
    k_w, k_b = jax.random.split(key)
    return {
        'w': orthogonal_init(k_w, (in_dim, out_dim), scale),
        'b': jnp.full((out_dim,), bias_init)
    }

def init_actor_critic(key, obs_dim, action_dim, hidden_dim=128):
    k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)
    return {
        'actor': {
            'l1': init_mlp_layer(k1, obs_dim, hidden_dim, scale=np.sqrt(2)),
            'l2': init_mlp_layer(k2, hidden_dim, hidden_dim, scale=np.sqrt(2)),
            'mean': init_mlp_layer(k3, hidden_dim, action_dim, scale=0.01),
            'log_std': jnp.full((action_dim,), -0.5)
        },
        'critic': {
            'l1': init_mlp_layer(k4, obs_dim, hidden_dim, scale=np.sqrt(2)),
            'l2': init_mlp_layer(k5, hidden_dim, hidden_dim, scale=np.sqrt(2)),
            'head': init_mlp_layer(k6, hidden_dim, 1, scale=1.0)
        }
    }

def forward_mlp(params, x, activation=jax.nn.tanh):
    x = x @ params['l1']['w'] + params['l1']['b']
    x = activation(x)
    x = x @ params['l2']['w'] + params['l2']['b']
    x = activation(x)
    return x

def get_action_dist(actor_params, obs):
    x = forward_mlp(actor_params, obs)
    mean = x @ actor_params['mean']['w'] + actor_params['mean']['b']
    log_std = jnp.clip(actor_params['log_std'], -2.0, 0.5)
    return mean, log_std

def get_value(critic_params, obs):
    x = forward_mlp(critic_params, obs)
    return (x @ critic_params['head']['w'] + critic_params['head']['b']).squeeze(-1)

# --- Pure JAX Adam Optimizer ---
def init_adam_state(params):
    m = jax.tree.map(jnp.zeros_like, params)
    v = jax.tree.map(jnp.zeros_like, params)
    return {'m': m, 'v': v, 'step': 0}

def adam_update(grads, opt_state, params, lr, max_grad_norm=None, beta1=0.9, beta2=0.999, eps=1e-8):
    step = opt_state['step'] + 1
    m = opt_state['m']
    v = opt_state['v']
    if max_grad_norm is not None:
        leaves, _ = jax.tree_util.tree_flatten(grads)
        total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
        clip_coef = jnp.minimum(max_grad_norm / (total_norm + 1e-6), 1.0)
        grads = jax.tree.map(lambda g: g * clip_coef, grads)
    m = jax.tree.map(lambda m_i, g_i: beta1 * m_i + (1 - beta1) * g_i, m, grads)
    v = jax.tree.map(lambda v_i, g_i: beta2 * v_i + (1 - beta2) * (g_i ** 2), v, grads)
    m_hat = jax.tree.map(lambda m_i: m_i / (1 - beta1 ** step), m)
    v_hat = jax.tree.map(lambda v_i: v_i / (1 - beta2 ** step), v)
    params = jax.tree.map(lambda p_i, m_h, v_h: p_i - lr * m_h / (jnp.sqrt(v_h) + eps), params, m_hat, v_hat)
    return params, {'m': m, 'v': v, 'step': step}

## 3. The PPO Training Loop

We use Proximal Policy Optimization (PPO), a stable and efficient policy gradient method.

### Generalized Advantage Estimation (GAE)
To reduce variance in our policy gradient estimates, we use GAE. The advantage $\hat{A}_t$ is calculated as an exponentially weighted average of TD-errors $\delta_t$:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$
$$\hat{A}_t = \sum_{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k}$$

### PPO Objective Function
PPO prevents destructive policy updates by clipping the probability ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$. The objective is:

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t [\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t)]$$

### Total Loss
We optimize a joint loss function including the value function error and an entropy bonus to encourage exploration:

$$L^{total} = -L^{CLIP}(\theta) + c_{vf} (V_\phi(s_t) - V_t^{target})^2 - c_{ent} S[\pi](s_t)$$

We compile the entire rollout and update process into a single `train_segment` function using `jax.lax.scan`, which runs entirely on the GPU.

In [83]:
class Args:
    seed = 1
    total_timesteps = 100_000_000
    num_envs = 1024
    num_steps = 128
    learning_rate = 0.0001
    num_minibatches = 4
    update_epochs = 4
    gamma = 0.995
    gae_lambda = 0.95
    clip_coef = 0.1
    ent_coef = 0.01
    vf_coef = 0.5
    max_grad_norm = 1.0
    norm_adv = True
    capture_video = True
    checkpoint = 200

args = Args()

def run_training(args):
    run_name = f"JaxSwingUp_Stable_{int(time.time())}"
    print(f"Running STABLE JAX CartPole Swing-Up: {run_name}")
    print(f"JAX is running on: {jax.default_backend().upper()} device(s).")

    # Derived constants
    num_envs = args.num_envs
    num_steps = args.num_steps
    batch_size = num_envs * num_steps
    minibatch_size = batch_size // args.num_minibatches
    total_iterations = args.total_timesteps // batch_size

    env = PureJaxCartPoleSwingUp()
    key = jax.random.PRNGKey(args.seed)
    key, init_key = jax.random.split(key)

    params = init_actor_critic(init_key, obs_dim=5, action_dim=1, hidden_dim=128)
    opt_state = init_adam_state(params)

    key, *env_keys = jax.random.split(key, num_envs + 1)
    env_states = vmap(env.reset)(jnp.array(env_keys))
    episode_returns = jnp.zeros(num_envs)

    # --- JIT Compiled Update ---
    @jit
    def train_segment(carry, update_i):
        params, opt_state, env_states, episode_returns, key = carry

        # --- 1. Rollout ---
        def rollout_step(carry, step_idx):
            env_states, episode_returns, episode_lengths, key = carry
            key, subkey = jax.random.split(key)
            obs = vmap(env.get_obs)(env_states)

            mean, log_std = get_action_dist(params['actor'], obs)
            std = jnp.exp(log_std)
            noise = jax.random.normal(subkey, mean.shape)
            action = mean + std * noise

            action_log_prob = -0.5 * ((action - mean) / std) ** 2 - log_std - 0.5 * jnp.log(2 * jnp.pi)
            action_log_prob = action_log_prob.sum(axis=-1)

            value = get_value(params['critic'], obs)
            next_env_states, reward, done_boundary = vmap(env.step)(env_states, action)

            episode_lengths = episode_lengths + 1
            done_length = episode_lengths >= 500
            done = done_boundary | done_length

            key, *reset_keys = jax.random.split(key, num_envs + 1)
            reset_states = vmap(env.reset)(jnp.array(reset_keys))
            next_env_states = jnp.where(done[:, None], reset_states, next_env_states)

            episode_returns = episode_returns + reward
            final_return = jnp.where(done, episode_returns, 0.0)
            episode_returns = jnp.where(done, 0.0, episode_returns)
            episode_lengths = jnp.where(done, 0, episode_lengths)

            transition = (obs, action, action_log_prob, reward, done, value, final_return)
            return (next_env_states, episode_returns, episode_lengths, key), transition

        (next_env_states, episode_returns, episode_lengths, key), traj = jax.lax.scan(
            rollout_step, (env_states, episode_returns, jnp.zeros(num_envs, dtype=jnp.int32), key), None, length=num_steps
        )
        obs, actions, logprobs, rewards, dones, values, final_returns = traj

        # --- 2. GAE ---
        next_obs = vmap(env.get_obs)(next_env_states)
        next_value = get_value(params['critic'], next_obs)

        def gae_scan(carry, t):
            last_gae_lam, next_val = carry
            delta = rewards[t] + args.gamma * next_val * (1.0 - dones[t]) - values[t]
            last_gae_lam = delta + args.gamma * args.gae_lambda * (1.0 - dones[t]) * last_gae_lam
            return (last_gae_lam, values[t]), last_gae_lam

        _, advantages = jax.lax.scan(gae_scan, (jnp.zeros_like(next_value), next_value), jnp.arange(num_steps), reverse=True)
        returns = advantages + values

        # Flatten
        b_obs = obs.reshape((batch_size, -1))
        b_logprobs = logprobs.reshape(batch_size)
        b_actions = actions.reshape((batch_size, -1))
        b_advantages = advantages.reshape(batch_size)
        b_returns = returns.reshape(batch_size)
        b_values = values.reshape(batch_size)

        if args.norm_adv:
            b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)

        # --- 3. Update ---
        frac = 1.0 - (update_i / total_iterations)
        current_lr = args.learning_rate * frac

        def update_epoch(carry, _):
            params, opt_state, key = carry
            key, subkey = jax.random.split(key)
            inds = jax.random.permutation(subkey, batch_size)

            sb_obs, sb_logprobs, sb_actions = b_obs[inds], b_logprobs[inds], b_actions[inds]
            sb_advantages, sb_returns, sb_values = b_advantages[inds], b_returns[inds], b_values[inds]

            def process_minibatch(carry, i):
                params, opt_state = carry
                start = i * minibatch_size
                mb_obs = jax.lax.dynamic_slice(sb_obs, (start, 0), (minibatch_size, sb_obs.shape[1]))
                mb_logprobs = jax.lax.dynamic_slice(sb_logprobs, (start,), (minibatch_size,))
                mb_actions = jax.lax.dynamic_slice(sb_actions, (start, 0), (minibatch_size, sb_actions.shape[1]))
                mb_advantages = jax.lax.dynamic_slice(sb_advantages, (start,), (minibatch_size,))
                mb_returns = jax.lax.dynamic_slice(sb_returns, (start,), (minibatch_size,))
                mb_values = jax.lax.dynamic_slice(sb_values, (start,), (minibatch_size,))

                def loss_fn(params):
                    mean, log_std = get_action_dist(params['actor'], mb_obs)
                    std = jnp.exp(log_std)
                    new_logprobs = -0.5 * ((mb_actions - mean) / std) ** 2 - log_std - 0.5 * jnp.log(2 * jnp.pi)
                    new_logprobs = new_logprobs.sum(axis=-1)

                    entropy = (log_std + 0.5 + 0.5 * jnp.log(2 * jnp.pi)).sum(axis=-1).mean()

                    logratio = new_logprobs - mb_logprobs
                    ratio = jnp.exp(logratio)

                    pg_loss = -jnp.minimum(
                        mb_advantages * ratio,
                        mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                    ).mean()

                    new_values = get_value(params['critic'], mb_obs)
                    v_loss = 0.5 * ((new_values - mb_returns) ** 2).mean()

                    loss = pg_loss - args.ent_coef * entropy + args.vf_coef * v_loss
                    return loss, (pg_loss, v_loss, entropy)

                (loss, metrics), grads = value_and_grad(loss_fn, has_aux=True)(params)
                params, opt_state = adam_update(grads, opt_state, params, current_lr, args.max_grad_norm)
                return (params, opt_state), metrics

            (params, opt_state), batch_metrics = jax.lax.scan(process_minibatch, (params, opt_state), jnp.arange(args.num_minibatches))
            return (params, opt_state, key), batch_metrics

        (params, opt_state, key), epoch_metrics = jax.lax.scan(update_epoch, (params, opt_state, key), None, length=args.update_epochs)

        num_dones = dones.sum()
        avg_ret = jnp.where(num_dones > 0, final_returns.sum() / num_dones, 0.0)
        metrics = jax.tree.map(lambda x: x.mean(), epoch_metrics)

        return (params, opt_state, next_env_states, episode_returns, key), (avg_ret, metrics)

    # --- Training Loop ---
    print(f"Starting training for {total_iterations} iterations...")
    moving_avg_ret = 0.0
    start_time = time.perf_counter()
    last_time = start_time

    for i in range(total_iterations):
        (params, opt_state, env_states, episode_returns, key), (avg_ret, metrics) = train_segment(
            (params, opt_state, env_states, episode_returns, key), i
        )
        jax.block_until_ready(params)

        if avg_ret != 0:
            moving_avg_ret = 0.95 * moving_avg_ret + 0.05 * avg_ret if moving_avg_ret != 0 else avg_ret

        if i % args.checkpoint == 0 or i == total_iterations - 1:
            current_time = time.perf_counter()
            sps = (args.checkpoint * batch_size) / (current_time - last_time)
            last_time = current_time
            elapsed = current_time - start_time
            eta = (total_iterations - i) * (elapsed / i) if i > 0 else 0

            pg_loss, v_loss, entropy = metrics
            print(f"Update: {i}/{total_iterations} | Return: {moving_avg_ret:.2f} | SPS: {int(sps)} | "
                  f"PLoss: {pg_loss.item():.3f} | Ent: {entropy.item():.3f}")

            if args.capture_video and i > 0:
                video_path = record_video(params, env, run_name, i * batch_size)
                display(Video(video_path, embed=True, html_attributes="loop autoplay muted"))

def format_time(seconds):
    h = int(seconds // 3600)
    m = int((seconds % 3600) // 60)
    s = int(seconds % 60)
    return f"{h:02d}:{m:02d}:{s:02d}"

def record_video(params, env, run_name, step):
    num_panels = 3
    panel_width = 400
    panel_height = 300
    spacer_width = 20  # --- CONFIG: Width of the gap between videos ---
    
    frames = []
    
    # 1. SETUP KEYS
    master_key = jax.random.PRNGKey(step)
    reset_keys = jax.random.split(master_key, num_panels)
    
    # 2. BATCHED RESET
    states = vmap(env.reset)(reset_keys)
    
    print(f"Recording multi-view video with spacing at step {step}")
    
    # --- PRE-ALLOCATE SPACER ---
    # Create a solid white block of pixels (H x W x C)
    # Using 255 for white (assuming RGB uint8 images)
    spacer = np.full((panel_height, spacer_width, 3), 255, dtype=np.uint8)
    
    # 3. RUN LOOP
    for t in range(400):
        # --- A. Render & Stitch ---
        panel_images = []
        for i in range(num_panels):
            single_state = np.array(states[i])
            # Ensure render output is uint8 for consistent stitching
            img = render_cartpole(single_state, width=panel_width, height=panel_height).astype(np.uint8)
            panel_images.append(img)
            
        # BUILD THE STITCHED ROW WITH SPACERS
        elements_to_join = []
        for i in range(num_panels):
            elements_to_join.append(panel_images[i])
            # Add a spacer after every panel EXCEPT the last one
            if i < num_panels - 1:
                elements_to_join.append(spacer)
                
        # Concatenate: [Img1, Spacer, Img2, Spacer, Img3]
        combined_frame = np.concatenate(elements_to_join, axis=1)
        frames.append(combined_frame)
        
        # --- B. Batched Step ---
        obs = vmap(env.get_obs)(states)
        mean, _ = get_action_dist(params['actor'], obs)
        action = mean 
        states, _, dones = vmap(env.step)(states, action)
        
        if dones.all() and t > 50: # Ensure at least some frames are recorded
            break
            
    # 4. SAVE
    video_dir = f"videos/{run_name}"
    os.makedirs(video_dir, exist_ok=True)
    path = f"{video_dir}/step_{step}_multi_spaced.mp4"
    # Ensure fps is an integer
    imageio.mimsave(path, frames, fps=50)
    
    return path

In [84]:
run_training(args)

Running STABLE JAX CartPole Swing-Up: JaxSwingUp_Stable_1764609263
JAX is running on: GPU device(s).
Starting training for 762 iterations...
Update: 0/762 | Return: -14.97 | SPS: 3226580 | PLoss: -0.001 | Ent: 0.920
Update: 200/762 | Return: 57.43 | SPS: 1214443 | PLoss: -0.000 | Ent: 1.018
Recording multi-view video with spacing at step 26214400




Update: 400/762 | Return: 77.32 | SPS: 1281173 | PLoss: -0.000 | Ent: 1.136
Recording multi-view video with spacing at step 52428800




Update: 600/762 | Return: 3663.62 | SPS: 1136239 | PLoss: -0.000 | Ent: 1.183
Recording multi-view video with spacing at step 78643200




Update: 761/762 | Return: 3663.62 | SPS: 712820 | PLoss: -0.000 | Ent: 1.198
Recording multi-view video with spacing at step 99745792


