![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)

# <h1><center>Tutorial  <a href="https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" width="140" align="center"/></a></center></h1>

This notebook provides an introductory tutorial for [**MuJoCo XLA (MJX)**](https://github.com/google-deepmind/mujoco/blob/main/mjx), a JAX-based implementation of MuJoCo useful for RL training workloads.

**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".










# Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax


In [None]:
#@title Check if MuJoCo installation was successful

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags


In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
import os
os.environ['MUJOCO_GL'] = 'egl' # Ensure EGL rendering is used

from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

# Simple ENV with spider



## First Version of Spider

In [None]:
# Generate a simple spider XML model
spider_xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <geom name="floor" type="plane" size="10 10 .1" rgba=".9 .9 .9 1"/>
    <body name="torso" pos="0 0 0.25">
      <geom type="sphere" size="0.1" rgba="1 0 0 1"/>
      <joint type="free"/>
      <body name="leg1" pos="0.1 0 0">
        <joint name="joint1" type="hinge" axis="0 1 0" pos="0 0 0"/>
        <geom type="capsule" fromto=".0 .0 .0 .3 .0 .0" size=".02" rgba="0 1 0 1"/>
        <body name="foot1" pos=".3 0 0">
          <joint name="joint2" type="hinge" axis="0 1 0" pos="0 0 0"/>
          <geom type="capsule" fromto=".0 .0 .0 .3 .0 .0" size=".02" rgba="0 1 0 1"/>
        </body>
      </body>
       <body name="leg2" pos="-0.1 0 0">
        <joint name="joint3" type="hinge" axis="0 1 0" pos="0 0 0"/>
        <geom type="capsule" fromto=".0 .0 .0 -.3 .0 .0" size=".02" rgba="0 1 0 1"/>
        <body name="foot2" pos="-.3 0 0">
          <joint name="joint4" type="hinge" axis="0 1 0" pos="0 0 0"/>
          <geom type="capsule" fromto=".0 .0 .0 -.3 .0 .0" size=".02" rgba="0 1 0 1"/>
        </body>
      </body>
       <body name="leg3" pos="0 0.1 0">
        <joint name="joint5" type="hinge" axis="1 0 0" pos="0 0 0"/>
        <geom type="capsule" fromto=".0 .0 .0 .0 .3 .0" size=".02" rgba="0 1 0 1"/>
         <body name="foot3" pos="0 .3 0">
          <joint name="joint6" type="hinge" axis="1 0 0" pos="0 0 0"/>
          <geom type="capsule" fromto=".0 .0 .0 .0 .3 .0" size=".02" rgba="0 1 0 1"/>
        </body>
      </body>
       <body name="leg4" pos="0 -0.1 0">
        <joint name="joint7" type="hinge" axis="1 0 0" pos="0 0 0"/>
        <geom type="capsule" fromto=".0 .0 .0 .0 -.3 .0" size=".02" rgba="0 1 0 1"/>
        <body name="foot4" pos="0 -.3 0">
          <joint name="joint8" type="hinge" axis="1 0 0" pos="0 0 0"/>
          <geom type="capsule" fromto=".0 .0 .0 .0 -.3 .0" size=".02" rgba="0 1 0 1"/>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor name="act1" joint="joint1"/>
    <motor name="act2" joint="joint2"/>
    <motor name="act3" joint="joint3"/>
    <motor name="act4" joint="joint4"/>
    <motor name="act5" joint="joint5"/>
    <motor name="act6" joint="joint6"/>
    <motor name="act7" joint="joint7"/>
    <motor name="act8" joint="joint8"/>
  </actuator>
</mujoco>
"""

## Second Version of spider

In [None]:
spider_xml = """
<mujoco model="ant">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="front_right_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
        <body name="aux_2" pos="-0.2 0.2 0">
          <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 0.2 0">
            <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
        <body name="aux_3" pos="-0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 -0.2 0">
            <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>
</mujoco>
"""

## Spider Env

In [None]:

# HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'

class Humanoid(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=10.0,
      ctrl_cost_weight=0.1,
      healthy_reward=2.5,
      terminate_when_unhealthy=False, # Set to False to prevent early termination
      healthy_z_range=(0.1, 0.3), # Lowered healthy_z_range for a spider
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      episode_length: int = 1000, # <--- Added episode_length with default
      **kwargs,
  ):
#


    mj_model = mujoco.MjModel.from_xml_string(spider_xml)
    mj_data = mujoco.MjData(mj_model) # Create mj_data
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'
    # Store episode_length as an attribute of the Humanoid instance
    self.episode_length = episode_length
    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )
    self._torso_body_idx = mujoco.mj_name2id(
        self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'
    )
    # Store mj_data for access to qpos0 and qvel0
    self._mj_data = mj_data


  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = jp.asarray(self._mj_data.qpos) + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jp.asarray(self._mj_data.qvel) + jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Calculate forward velocity using the torso's velocity
    torso_velocity = data.cvel[self._torso_body_idx, 0]
    forward_reward = self._forward_reward_weight * torso_velocity


    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=data.xpos[self._torso_body_idx, 0],
        y_position=data.xpos[self._torso_body_idx, 1],
        distance_from_origin=jp.linalg.norm(data.xpos[self._torso_body_idx, :2]),
        x_velocity=torso_velocity,
        y_velocity=data.cvel[self._torso_body_idx, 1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    # Simplified observation: torso z-position and x-velocity
    # return jp.concatenate([
    #     data.qpos[2:3],  # Torso z-position
    #     data.cvel[self._torso_body_idx, 0:1], # Torso x-velocity
    # ])

    # More detailed observation for a spider: torso z-position, torso x-velocity,
    # and joint positions and velocities
    return jp.concatenate([
        data.qpos[2:3],  # Torso z-position
        data.cvel[self._torso_body_idx, 0:1], # Torso x-velocity
        data.qpos[7:], # Joint positions (excluding free joint)
        data.qvel[6:], # Joint velocities (excluding free joint)
    ])


envs.register_environment('humanoid', Humanoid)

# Add code to get and print the observation shape
env_test = Humanoid()
dummy_data = mujoco.MjData(env_test.sys.mj_model) # Use mj_data
dummy_obs = env_test._get_obs(env_test.pipeline_init(jp.asarray(dummy_data.qpos), jp.asarray(dummy_data.qvel)), jp.zeros(env_test.sys.nu))
print("Observation shape:", dummy_obs.shape)


##(Optional) Show Data class for inspection

In [None]:
# import jax
# import mujoco
# from brax import envs
# import numpy as np

# # Instantiate the environment
# env_name = 'humanoid'
# env = envs.get_environment(env_name)

# # Define the jit reset function
# jit_reset = jax.jit(env.reset)

# # Reset the environment to get an initial state
# rng = jax.random.PRNGKey(0)
# state = jit_reset(rng)

# # The mjx.Data object is stored in state.pipeline_state
# data = state.pipeline_state

# print("--- Contents of mjx.Data object ---")
# print(f"Type of data: {type(data)}\n")

# print("--- All attributes of mjx.Data object ---\n")
# for attr_name in sorted(dir(data)):
#     if not attr_name.startswith('_'): # Exclude private attributes
#         attr_value = getattr(data, attr_name)
#         attr_type = type(attr_value)
#         if isinstance(attr_value, (np.ndarray, jax.Array)): # Check if it's a JAX or NumPy array
#             print(f"  Attribute: {attr_name}, Type: {attr_type}, Shape: {attr_value.shape}")
#         else:
#             print(f"  Attribute: {attr_name}, Type: {attr_type}")

# print("\n--- Specific attributes shown before ---")
# print("1. Generalized positions (qpos):")
# print(f"  Shape: {data.qpos.shape}")
# print(f"  Value (first 10 elements): {data.qpos[:10]}\n")

# print("2. Generalized velocities (qvel):")
# print(f"  Shape: {data.qvel.shape}")
# print(f"  Value (first 10 elements): {data.qvel[:10]}\n")

# print("3. Cartesian position of bodies (xpos):")
# print(f"  Shape: {data.xpos.shape}")
# print(f"  Value (first 5 bodies):\n{data.xpos[:5]}\n")

# print("4. Cartesian orientation of bodies (xquat):")
# print(f"  Shape: {data.xquat.shape}")
# print(f"  Value (first 5 bodies):\n{data.xquat[:5]}\n")

# print("5. Center of mass velocity (cvel):")
# print(f"  Shape: {data.cvel.shape}")
# print(f"  Value (first 5 bodies):\n{data.cvel[:5]}\n")

# print("6. Body indices (for reference, not part of data object directly but useful):")
# print(f"  Torso body index: {env._torso_body_idx}")

## (Optional) Inspect action

In [None]:
# import jax
# from brax import envs
# import jax.numpy as jp
# from brax.training.agents.ppo import networks as ppo_networks
# from brax.training.agents.ppo import train as ppo
# from brax.io import model

# # Re-instantiate the environment if not already available in this scope
# env_name = 'humanoid'
# env = envs.get_environment(env_name)

# # Define and train a minimal policy if jit_inference_fn is not defined
# try:
#     # Attempt to use jit_inference_fn if already defined (e.g., from a previous run of training cells)
#     _ = jit_inference_fn
# except NameError:
#     print("jit_inference_fn not found, training a minimal policy...")
#     make_inference_fn, params, _ = ppo.train(
#         environment=env, num_timesteps=5_000, num_evals=1, episode_length=env.episode_length # Minimal training just for this demo
#     )
#     model_path = '/tmp/mjx_brax_policy'
#     model.save_params(model_path, params)
#     params = model.load_params(model_path)
#     inference_fn = make_inference_fn(params)
#     jit_inference_fn = jax.jit(inference_fn)
#     print("Minimal policy trained and jit_inference_fn defined.")

# # Initialize the state
# current_rng = jax.random.PRNGKey(1) # Use a new RNG key
# state = jax.jit(env.reset)(current_rng)

# # Generate an action using the inference function
# act_rng, current_rng = jax.random.split(current_rng)
# ctrl, _ = jit_inference_fn(state.obs, act_rng)

# print("--- Content of the 'action' variable (ctrl from policy) ---")
# print(f"Type of action: {type(ctrl)}")
# print(f"Shape of action: {ctrl.shape}")
# print(f"Value of action: {ctrl}")


## Visualize a Rollout

Let's instantiate the environment and visualize a short rollout.

NOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane. We turn off other contacts.

In [None]:
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)


In [None]:
# initialize the state
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# grab a trajectory
for i in range(100):
  ctrl = jp.zeros(env.sys.nu) # Set control input to zero for standing still
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(env.render(rollout), fps=1.0 / env.dt)

In [None]:
print("Observation space:", env.observation_size)

# Train Humanoid Policy


In [None]:
# Importujemy 'networks' bezpośrednio z modułu agenta PPO
from brax.training.agents.ppo import networks as ppo_networks
# Typ PPONetworks jest teraz w 'brax.training.types'
from brax.training.agents.ppo.networks import PPONetworks
from brax.io import html, model
import flax.linen as nn

def my_custom_network_factory(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn=None,
    hidden_layer_sizes=(512, 256),
    activation=nn.relu
) -> PPONetworks:
    """Tworzy sieć PPO o niestandardowej architekturze."""

    return ppo_networks.make_networks(
        observation_size=observation_size,
        action_size=action_size,
        preprocess_observations_fn=preprocess_observations_fn,
        hidden_layer_sizes=hidden_layer_sizes,
        activation=activation
    )
train_fn = functools.partial(
    ppo.train, num_timesteps=20_000_00, num_evals=5, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,
    batch_size=512, seed=0)

# starty, nagrody na wykresie
x_data = []
y_data = [] # eval/episode_reward
ydataerr = [] # eval/episode_reward_std

# New lists for other reward components
y_data_forward = []
ydataerr_forward = []
y_data_quadctrl = []
ydataerr_quadctrl = []
y_data_alive = []
ydataerr_alive = []

times = [datetime.now()]

max_y, min_y = 5000, -2000 # Adjusted max_y and min_y for potentially higher/lower rewards
def progress(num_steps, metrics):
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  # Append data for new metrics, using .get() with default 0.0 to handle initial missing keys
  y_data_forward.append(metrics.get('eval/episode_metrics/forward_reward', 0.0))
  ydataerr_forward.append(metrics.get('eval/episode_metrics/forward_reward_std', 0.0))
  y_data_quadctrl.append(metrics.get('eval/episode_metrics/reward_quadctrl', 0.0))
  ydataerr_quadctrl.append(metrics.get('eval/episode_metrics/reward_quadctrl_std', 0.0))
  y_data_alive.append(metrics.get('eval/episode_metrics/reward_alive', 0.0))
  ydataerr_alive.append(metrics.get('eval/episode_metrics/reward_alive_std', 0.0))

  plt.figure(figsize=(12, 16)) # Create a new figure for multiple subplots

  # Plot 1: Total Episode Reward
  plt.subplot(4, 1, 1) # 4 rows, 1 column, first plot
  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('Reward per episode')
  plt.title(f'Total Episode Reward: {y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr, label='Total Reward')
  plt.legend()

  # # Plot 2: Forward Reward
  # plt.subplot(4, 1, 2)
  # plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  # plt.xlabel('# environment steps')
  # plt.ylabel('Forward Reward')
  # plt.title(f'Forward Reward: {y_data_forward[-1]:.3f}')
  # plt.errorbar(x_data, y_data_forward, yerr=ydataerr_forward, label='Forward Reward', color='green')
  # plt.legend()

  # # Plot 3: Control Cost Reward
  # plt.subplot(4, 1, 3)
  # plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  # plt.xlabel('# environment steps')
  # plt.ylabel('Control Cost Reward')
  # plt.title(f'Control Cost Reward: {y_data_quadctrl[-1]:.3f}')
  # plt.errorbar(x_data, y_data_quadctrl, yerr=ydataerr_quadctrl, label='Control Cost', color='red')
  # plt.legend()

  # # Plot 4: Alive Reward
  # plt.subplot(4, 1, 4)
  # plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  # plt.xlabel('# environment steps')
  # plt.ylabel('Alive Reward')
  # plt.title(f'Alive Reward: {y_data_alive[-1]:.3f}')
  # plt.errorbar(x_data, y_data_alive, yerr=ydataerr_alive, label='Alive Reward', color='purple')
  # plt.legend()

  plt.tight_layout() # Adjust layout to prevent overlap
  plt.show()

print("Observation space size:", env.observation_size)
print("Action space size:", env.action_size)
make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

<!-- ## Save and Load Policy -->

We can save and load the policy using the brax model API.

In [None]:
#@title Save Model
model_path = '/tmp/mjx_brax_policy'
model.save_params(model_path, params)

In [None]:
#@title Load Model and Define Inference Function
params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

## Visualize Policy

Finally we can visualize the policy.

In [None]:
eval_env = envs.get_environment(env_name)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 500
render_every = 2

for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

  if state.done:
    break

media.show_video(env.render(rollout[::render_every]), fps=1.0 / env.dt / render_every)

# Task
Implement a custom Proximal Policy Optimization (PPO) algorithm using Flax and JAX. This involves defining the actor and critic neural network architectures, implementing Generalized Advantage Estimation (GAE), constructing the PPO loss function, and integrating these components into a PPO learner agent. Finally, develop a training loop to train this custom PPO agent to control the existing Brax `Humanoid` environment (configured as a spider) to achieve locomotion.

## Understand PPO Components

### Subtask:
Outline the theoretical components of the Proximal Policy Optimization (PPO) algorithm, including the actor (policy) network, critic (value) network, Generalized Advantage Estimation (GAE), and the PPO loss function (clipped surrogate objective, value function loss, entropy bonus).


### 1. Actor (Policy) Network

The **Actor network** in PPO is responsible for learning the policy, \(\pi(a|s)\), which maps states to a probability distribution over actions. It directly controls the agent's behavior by outputting the actions the agent should take in a given state. The goal of the actor network is to maximize the expected cumulative reward by selecting optimal actions. During training, the actor's parameters are updated to increase the probability of taking actions that lead to higher rewards, guided by the advantage estimates provided by the critic.

### 2. Critic (Value) Network

The **Critic network** in PPO estimates the value function, \(V(s)\), which predicts the expected cumulative reward from a given state \(s\) onwards. Unlike the actor, the critic does not directly influence the agent's actions; instead, it provides a baseline or a measure of how good a particular state is. This value estimate is crucial for calculating the advantage, which tells the agent how much better a taken action was compared to the average action in that state. The critic's parameters are updated to minimize the difference between its predicted value and the actual observed returns, typically using a mean-squared error loss.

### 3. Generalized Advantage Estimation (GAE)

**Generalized Advantage Estimation (GAE)** is a method used to estimate the advantage function, \(A(s, a) = Q(s, a) - V(s)\), which quantifies how much better a specific action \(a\) is than the average action at a given state \(s\). Instead of using raw Monte Carlo returns (which have high variance) or single-step TD errors (which have high bias), GAE provides a trade-off between bias and variance by using a weighted average of n-step returns. It combines these estimates with two hyperparameters: \(\lambda\) (lambda), which controls the trade-off between bias and variance, and \(\gamma\) (gamma), the discount factor. A common formulation for GAE is:

\[\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}\]

where \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) is the TD residual. GAE produces more stable and reliable advantage estimates, which are essential for guiding the policy updates in PPO more effectively.

### 4. PPO Loss Function

The PPO loss function is a combination of three main terms, each serving a specific purpose to ensure stable and efficient learning:

#### a. Clipped Surrogate Objective

The **clipped surrogate objective** is the core of PPO's policy update. It aims to maximize a measure of advantage while preventing large policy updates that could lead to instability. It is defined as:

\[L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t)]\]

where:
*   \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the ratio of the new policy's probability of taking action \(a_t\) in state \(s_t\) to the old policy's probability.
*   \(\hat{A}_t\) is the advantage estimate at time \(t\) (typically from GAE).
*   \(\epsilon\) is a hyperparameter (e.g., 0.1 or 0.2) that defines the clipping range.

The `clip` function ensures that the policy ratio \(r_t(\theta)\) stays within a small interval around 1. If the advantage \(\hat{A}_t\) is positive, the policy update is constrained to prevent the new policy from becoming too different from the old one. If \(\hat{A}_t\) is negative, the policy is prevented from shrinking the probability of a bad action too much.

#### b. Value Function Loss

The **value function loss** is a separate term that trains the critic network to accurately predict state values. It is typically a mean-squared error (MSE) between the critic's predicted value \(V_\phi(s_t)\) and the target value (e.g., the GAE return estimate or a discounted sum of rewards):

\[L^{VF}(\phi) = \mathbb{E}_t[(V_\phi(s_t) - V_t^{target})^2]\]

This loss function aims to minimize the error in the value predictions, ensuring the critic provides reliable advantage estimates for the actor.

#### c. Entropy Bonus

The **entropy bonus** is added to the objective function to encourage exploration. It's a term proportional to the entropy of the policy's action distribution:

\[L^{ENT}(\theta) = c \cdot \mathbb{E}_t[H(\pi_\theta(\cdot|s_t))]\]

where \(H(\pi_\theta(\cdot|s_t))\) is the entropy of the policy at state \(s_t\), and \(c\) is a coefficient that controls the strength of the entropy regularization. Maximizing entropy prevents the policy from converging prematurely to a suboptimal deterministic policy, thus promoting diverse actions and better exploration of the environment.

### 4. PPO Loss Function

The PPO loss function is a combination of three main terms, each serving a specific purpose to ensure stable and efficient learning:

#### a. Clipped Surrogate Objective

The **clipped surrogate objective** is the core of PPO's policy update. It aims to maximize a measure of advantage while preventing large policy updates that could lead to instability. It is defined as:

\[L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t)]\]

where:
*   \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the ratio of the new policy's probability of taking action \(a_t\) in state \(s_t\) to the old policy's probability.
*   \(\hat{A}_t\) is the advantage estimate at time \(t\) (typically from GAE).
*   \(\epsilon\) is a hyperparameter (e.g., 0.1 or 0.2) that defines the clipping range.

The `clip` function ensures that the policy ratio \(r_t(\theta)\) stays within a small interval around 1. If the advantage \(\hat{A}_t\) is positive, the policy update is constrained to prevent the new policy from becoming too different from the old one. If \(\hat{A}_t\) is negative, the policy is prevented from shrinking the probability of a bad action too much.

#### b. Value Function Loss

The **value function loss** is a separate term that trains the critic network to accurately predict state values. It is typically a mean-squared error (MSE) between the critic's predicted value \(V_\phi(s_t)\) and the target value (e.g., the GAE return estimate or a discounted sum of rewards):

\[L^{VF}(\phi) = \mathbb{E}_t[(V_\phi(s_t) - V_t^{target})^2]\]

This loss function aims to minimize the error in the value predictions, ensuring the critic provides reliable advantage estimates for the actor.

#### c. Entropy Bonus

The **entropy bonus** is added to the objective function to encourage exploration. It's a term proportional to the entropy of the policy's action distribution:

\[L^{ENT}(\theta) = c \cdot \mathbb{E}_t[H(\pi_\theta(\cdot|s_t))]\]

where \(H(\pi_\theta(\cdot|s_t))\) is the entropy of the policy at state \(s_t\), and \(c\) is a coefficient that controls the strength of the entropy regularization. Maximizing entropy prevents the policy from converging prematurely to a suboptimal deterministic policy, thus promoting diverse actions and better exploration of the environment.

### 4. PPO Loss Function

The PPO loss function is a combination of three main terms, each serving a specific purpose to ensure stable and efficient learning:

#### a. Clipped Surrogate Objective

The **clipped surrogate objective** is the core of PPO's policy update. It aims to maximize a measure of advantage while preventing large policy updates that could lead to instability. It is defined as:

\[L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t)]\]

where:
*   \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\) is the ratio of the new policy's probability of taking action \(a_t\) in state \(s_t\) to the old policy's probability.
*   \(\hat{A}_t\) is the advantage estimate at time \(t\) (typically from GAE).
*   \(\epsilon\) is a hyperparameter (e.g., 0.1 or 0.2) that defines the clipping range.

The `clip` function ensures that the policy ratio \(r_t(\theta)\) stays within a small interval around 1. If the advantage \(\hat{A}_t\) is positive, the policy update is constrained to prevent the new policy from becoming too different from the old one. If \(\hat{A}_t\) is negative, the policy is prevented from shrinking the probability of a bad action too much.

#### b. Value Function Loss

The **value function loss** is a separate term that trains the critic network to accurately predict state values. It is typically a mean-squared error (MSE) between the critic's predicted value \(V_\phi(s_t)\) and the target value (e.g., the GAE return estimate or a discounted sum of rewards):

\[L^{VF}(\phi) = \mathbb{E}_t[(V_\phi(s_t) - V_t^{target})^2]\]

This loss function aims to minimize the error in the value predictions, ensuring the critic provides reliable advantage estimates for the actor.

#### c. Entropy Bonus

The **entropy bonus** is added to the objective function to encourage exploration. It's a term proportional to the entropy of the policy's action distribution:

\[L^{ENT}(\theta) = c \cdot \mathbb{E}_t[H(\pi_\theta(\cdot|s_t))]\]

where \(H(\pi_\theta(\cdot|s_t))\) is the entropy of the policy at state \(s_t\), and \(c\) is a coefficient that controls the strength of the entropy regularization. Maximizing entropy prevents the policy from converging prematurely to a suboptimal deterministic policy, thus promoting diverse actions and better exploration of the environment.

## Define Actor and Critic Networks

### Subtask:
Implement the neural network architectures for the actor (policy) and critic (value) using Flax. This will involve defining `flax.linen.Module` classes for each, taking observation and action sizes into account. The actor will typically output parameters for a distribution (e.g., mean and log standard deviation for a Gaussian policy), while the critic will output a single value estimate.


**Reasoning**:
I need to define the Actor and Critic neural network architectures using Flax's `nn.Module` as instructed. This will involve creating two classes, `Actor` and `Critic`, each with an MLP structure and specific outputs.



In [None]:
import flax.linen as nn
import jax.numpy as jnp

# 1. Define a Flax nn.Module subclass named Actor
class Actor(nn.Module):
  action_size: int
  hidden_layer_sizes: tuple = (256, 256)

  # 2. Implement an MLP for the actor and a learnable log_std parameter
  @nn.compact
  def __call__(self, observation: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    x = observation
    for i, hidden_size in enumerate(self.hidden_layer_sizes):
      x = nn.Dense(features=hidden_size, name=f'actor_hidden_{i}')(x)
      x = nn.relu(x)

    # Output layer for mean
    mean = nn.Dense(features=self.action_size, name='actor_mean')(x)

    # Learnable parameter for log_std, initialized to zeros
    log_std = self.param('log_std', nn.initializers.zeros, (self.action_size,))

    return mean, log_std

# 4. Define a Flax nn.Module subclass named Critic
class Critic(nn.Module):
  hidden_layer_sizes: tuple = (256, 256)

  # 5. Implement an MLP for the critic
  @nn.compact
  def __call__(self, observation: jnp.ndarray) -> jnp.ndarray:
    x = observation
    for i, hidden_size in enumerate(self.hidden_layer_sizes):
      x = nn.Dense(features=hidden_size, name=f'critic_hidden_{i}')(x)
      x = nn.relu(x)

    # Output layer for value (single scalar)
    value = nn.Dense(features=1, name='critic_value')(x)

    return jnp.squeeze(value, axis=-1)

print("Actor and Critic network architectures defined using Flax.")

## Implement Generalized Advantage Estimation (GAE)

### Subtask:
Write a JAX function to calculate Generalized Advantage Estimation (GAE) given rewards, value predictions, and `done` flags from collected trajectories. This function is crucial for providing stable and effective gradient estimates for policy updates.


**Reasoning**:
To implement the Generalized Advantage Estimation (GAE) function, I will create a JAX-jittable Python function that iterates backward through the collected trajectories to compute advantages based on rewards, value predictions, and done flags, as specified in the instructions. This function will return the calculated GAE advantages.



In [None]:
import jax
import jax.numpy as jnp

# 1. Define a JAX-jittable function to calculate GAE
@jax.jit
def compute_gae(rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray, gamma: float, gae_lambda: float) -> jnp.ndarray:
  """Computes Generalized Advantage Estimation (GAE)."""
  # rewards: array of rewards of shape (T,)
  # values: array of value predictions of shape (T+1,)
  # dones: array of done flags of shape (T,)
  # gamma: discount factor
  # gae_lambda: GAE lambda parameter

  # T is the number of time steps (length of rewards/dones)
  T = rewards.shape[0]

  # 2. Initialize advantages as a JAX array of zeros
  advantages = jnp.zeros_like(rewards)
  last_gae_lam = 0.0

  # 3. Iterate backward through the trajectories
  for t in reversed(range(T)):
    # If the episode terminated at t, the next state's value is 0
    # (1 - dones[t]) ensures this, as dones[t] is 1 for a terminal state.
    next_value = values[t + 1] * (1 - dones[t])

    # 4. Calculate the TD error (delta)
    delta = rewards[t] + gamma * next_value - values[t]

    # 5. Compute the advantage for the current step
    # The last_gae_lam is equivalent to advantages[t+1] * (1 - dones[t])
    # if the episode continues. If it's done, it resets to 0.
    advantages = advantages.at[t].set(delta + gamma * gae_lambda * (1 - dones[t]) * last_gae_lam)
    last_gae_lam = advantages[t]

  # 6. Return the computed advantages array
  return advantages

print("JAX function 'compute_gae' for Generalized Advantage Estimation (GAE) has been defined.")

## Implement PPO Loss Function

### Subtask:
Develop the PPO loss function, which consists of three main parts: the clipped surrogate objective for the policy, the mean squared error for the value function, and an entropy bonus to encourage exploration. This function will take network parameters, collected data (observations, actions, advantages, returns, log probabilities), and hyperparameters as input.


**Reasoning**:
I need to implement the PPO loss function, which involves calculating the clipped surrogate objective, value function loss, and entropy bonus using JAX and Flax. I will create a jittable function that takes network parameters, observed data, and hyperparameters to compute the combined loss.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist

# 1. Define a JAX-jittable function, ppo_loss
@jax.jit
def ppo_loss(
    actor_params,
    critic_params,
    apply_fn_actor,
    apply_fn_critic,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    advantages: jnp.ndarray,
    returns: jnp.ndarray,
    old_log_probs: jnp.ndarray,
    clip_param: float,
    value_loss_coeff: float,
    entropy_coeff: float
):
  """Computes the PPO loss function."""

  # 2. Apply the actor network to get mean and log_std
  mean, log_std = apply_fn_actor(actor_params, observations)

  # 3. Create a tfp.distributions.Normal (or equivalent JAX distribution)
  std = jnp.exp(log_std)
  action_distribution = normal_dist.logpdf(x=actions, loc=mean, scale=std)

  # 4. Calculate the new_log_probs of the actions and the entropy
  # For a Gaussian policy, the log_prob of actions and entropy can be computed.
  # Here we use jax.scipy.stats.norm.logpdf for log_probs directly.
  # For entropy, we need to sum over the action dimensions.
  new_log_probs = jnp.sum(action_distribution, axis=-1)
  # The entropy of a multivariate Gaussian is sum(0.5 * (1 + log(2*pi*sigma^2)))
  entropy = jnp.sum(0.5 * (1.0 + jnp.log(2 * jnp.pi) + 2 * log_std), axis=-1)

  # 5. Compute the ratio of the new policy's probabilities to the old policy's probabilities
  ratio = jnp.exp(new_log_probs - old_log_probs)

  # 6. Calculate the clipped surrogate objective (policy loss)
  surr1 = ratio * advantages
  surr2 = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
  # Negate because we are using gradient descent (minimizing loss) to maximize objective
  policy_loss = -jnp.mean(jnp.minimum(surr1, surr2))

  # 7. Apply the critic network to get current_values
  current_values = apply_fn_critic(critic_params, observations)

  # 8. Calculate the value function loss as the mean squared error
  value_loss = jnp.mean(jnp.square(current_values - returns))

  # 9. Calculate the entropy bonus as the mean of the entropy
  entropy_loss = -jnp.mean(entropy) # Typically entropy is maximized, so we subtract it from the total loss

  # 10. Combine the three loss components into a total_loss
  total_loss = policy_loss + value_loss_coeff * value_loss + entropy_coeff * entropy_loss

  # 11. Return the total_loss and individual components
  return total_loss, {
      'policy_loss': policy_loss,
      'value_loss': value_loss,
      'entropy_loss': entropy_loss
  }

print("PPO loss function 'ppo_loss' has been defined.")

## Create a PPO Learner/Agent

### Subtask:
Encapsulate the actor and critic networks, their optimizers (e.g., Adam), and the PPO loss calculation into a PPO agent or learner class. This class will manage parameter updates and provide an interface for interacting with the environment.


**Reasoning**:
I need to define the `PPOAgent` class, which will encapsulate the actor and critic networks, their optimizers, and the PPO loss calculation. This involves initializing the networks and optimizers in `__init__`, implementing a `sample_action` method for generating actions, and an `update` method for performing a PPO training step, including gradient computation and parameter updates.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist
import optax
from flax.training import train_state

# 1. Define a new class, `PPOAgent`, that will serve as the PPO learner.
class PPOAgent:
  def __init__(
      self,
      rng: jax.random.PRNGKey,
      observation_size: int,
      action_size: int,
      actor_hidden_layer_sizes: tuple = (256, 256),
      critic_hidden_layer_sizes: tuple = (256, 256),
      actor_learning_rate: float = 3e-4,
      critic_learning_rate: float = 3e-4,
  ):
    # 2. In the `__init__` method, initialize the `Actor` and `Critic` networks
    #    and their respective optimizers using `optax.adam`.

    # Initialize Actor Network
    actor_rng, critic_rng = jax.random.split(rng)
    self.actor_network = Actor(action_size=action_size, hidden_layer_sizes=actor_hidden_layer_sizes)
    actor_params = self.actor_network.init(actor_rng, jnp.zeros(observation_size))['params']
    actor_optimizer = optax.adam(learning_rate=actor_learning_rate)
    self.actor_state = train_state.TrainState.create(
        apply_fn=self.actor_network.apply, params=actor_params, tx=actor_optimizer
    )

    # Initialize Critic Network
    self.critic_network = Critic(hidden_layer_sizes=critic_hidden_layer_sizes)
    critic_params = self.critic_network.init(critic_rng, jnp.zeros(observation_size))['params']
    critic_optimizer = optax.adam(learning_rate=critic_learning_rate)
    self.critic_state = train_state.TrainState.create(
        apply_fn=self.critic_network.apply, params=critic_params, tx=critic_optimizer
    )

  # 3. Implement a `sample_action` method
  @jax.jit
  def sample_action(
      self,
      actor_params: flax.core.FrozenDict,
      observation: jnp.ndarray,
      rng: jax.random.PRNGKey
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    mean, log_std = self.actor_network.apply({'params': actor_params}, observation)
    std = jnp.exp(log_std)

    # Create a JAX-compatible distribution for sampling
    # We need to reshape mean and std for batch sampling if observation is batched
    if observation.ndim == 1: # Single observation
        sample_rng, _ = jax.random.split(rng)
        action = mean + std * jax.random.normal(sample_rng, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std))
    else: # Batched observations
        sample_rngs = jax.random.split(rng, observation.shape[0])
        action = mean + std * jax.random.normal(sample_rngs, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std), axis=-1)

    return action, log_prob

  # 4. Implement an `update` method
  @jax.jit
  def update(
      self,
      actor_state: train_state.TrainState,
      critic_state: train_state.TrainState,
      observations: jnp.ndarray,
      actions: jnp.ndarray,
      advantages: jnp.ndarray,
      returns: jnp.ndarray,
      old_log_probs: jnp.ndarray,
      clip_param: float,
      value_loss_coeff: float,
      entropy_coeff: float
  ):
    # 5. Inside the `update` method, use `jax.value_and_grad`
    #    to compute the gradients of the `ppo_loss`.
    (loss, metrics), grads = jax.value_and_grad(ppo_loss, argnums=(0, 1), has_aux=True)(
        actor_state.params,
        critic_state.params,
        actor_state.apply_fn,
        critic_state.apply_fn,
        observations,
        actions,
        advantages,
        returns,
        old_log_probs,
        clip_param,
        value_loss_coeff,
        entropy_coeff,
    )

    actor_grads, critic_grads = grads

    # 6. Apply these gradients to update the actor and critic parameters.
    actor_state = actor_state.apply_gradients(grads=actor_grads)
    critic_state = critic_state.apply_gradients(grads=critic_grads)

    # 7. The `update` method should return the updated states and metrics.
    return actor_state, critic_state, metrics

print("PPOAgent class defined.")

**Reasoning**:
The previous code failed because `flax.core.FrozenDict` was used as a type hint without `flax.core` being explicitly imported. I need to add `from flax.core import FrozenDict` to resolve this `NameError`.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist
import optax
from flax.training import train_state
from flax.core import FrozenDict # Import FrozenDict explicitly

# 1. Define a new class, `PPOAgent`, that will serve as the PPO learner.
class PPOAgent:
  def __init__(
      self,
      rng: jax.random.PRNGKey,
      observation_size: int,
      action_size: int,
      actor_hidden_layer_sizes: tuple = (256, 256),
      critic_hidden_layer_sizes: tuple = (256, 256),
      actor_learning_rate: float = 3e-4,
      critic_learning_rate: float = 3e-4,
  ):
    # 2. In the `__init__` method, initialize the `Actor` and `Critic` networks
    #    and their respective optimizers using `optax.adam`.

    # Initialize Actor Network
    actor_rng, critic_rng = jax.random.split(rng)
    self.actor_network = Actor(action_size=action_size, hidden_layer_sizes=actor_hidden_layer_sizes)
    actor_params = self.actor_network.init(actor_rng, jnp.zeros(observation_size))['params']
    actor_optimizer = optax.adam(learning_rate=actor_learning_rate)
    self.actor_state = train_state.TrainState.create(
        apply_fn=self.actor_network.apply, params=actor_params, tx=actor_optimizer
    )

    # Initialize Critic Network
    self.critic_network = Critic(hidden_layer_sizes=critic_hidden_layer_sizes)
    critic_params = self.critic_network.init(critic_rng, jnp.zeros(observation_size))['params']
    critic_optimizer = optax.adam(learning_rate=critic_learning_rate)
    self.critic_state = train_state.TrainState.create(
        apply_fn=self.critic_network.apply, params=critic_params, tx=critic_optimizer
    )

  # 3. Implement a `sample_action` method
  @jax.jit
  def sample_action(
      self,
      actor_params: FrozenDict,
      observation: jnp.ndarray,
      rng: jax.random.PRNGKey
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    mean, log_std = self.actor_network.apply({'params': actor_params}, observation)
    std = jnp.exp(log_std)

    # Create a JAX-compatible distribution for sampling
    # We need to reshape mean and std for batch sampling if observation is batched
    if observation.ndim == 1: # Single observation
        sample_rng, _ = jax.random.split(rng)
        action = mean + std * jax.random.normal(sample_rng, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std))
    else: # Batched observations
        sample_rngs = jax.random.split(rng, observation.shape[0])
        action = mean + std * jax.random.normal(sample_rngs, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std), axis=-1)

    return action, log_prob

  # 4. Implement an `update` method
  @jax.jit
  def update(
      self,
      actor_state: train_state.TrainState,
      critic_state: train_state.TrainState,
      observations: jnp.ndarray,
      actions: jnp.ndarray,
      advantages: jnp.ndarray,
      returns: jnp.ndarray,
      old_log_probs: jnp.ndarray,
      clip_param: float,
      value_loss_coeff: float,
      entropy_coeff: float
  ):
    # 5. Inside the `update` method, use `jax.value_and_grad`
    #    to compute the gradients of the `ppo_loss`.
    (loss, metrics), grads = jax.value_and_grad(ppo_loss, argnums=(0, 1), has_aux=True)(
        actor_state.params,
        critic_state.params,
        actor_state.apply_fn,
        critic_state.apply_fn,
        observations,
        actions,
        advantages,
        returns,
        old_log_probs,
        clip_param,
        value_loss_coeff,
        entropy_coeff,
    )

    actor_grads, critic_grads = grads

    # 6. Apply these gradients to update the actor and critic parameters.
    actor_state = actor_state.apply_gradients(grads=actor_grads)
    critic_state = critic_state.apply_gradients(grads=critic_grads)

    # 7. The `update` method should return the updated states and metrics.
    return actor_state, critic_state, metrics

print("PPOAgent class defined.")

## Develop the Training Loop

### Subtask:
Construct the main training loop for the custom PPO agent. This loop will involve collecting rollouts, calculating rewards, returns, and advantages, iteratively updating network parameters, and periodically evaluating the policy's performance.


**Reasoning**:
I need to initialize the environment, JAX random key, and the PPO agent, along with defining hyperparameters for the training loop. This sets up the necessary components before starting the main training process.



In [None]:
import jax
import jax.numpy as jnp
from brax import envs
import optax

# 1. Initialize the environment, a JAX random key, and the PPO agent with appropriate hyperparameters
# Environment setup
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Hyperparameters
seed = 0
num_steps = 1_000_000 # Total environment steps
num_envs = 32         # Number of parallel environments for data collection
episode_length = env.episode_length # Max steps per episode

gamma = 0.99          # Discount factor
gae_lambda = 0.95     # GAE lambda parameter
clip_param = 0.2      # PPO clipping parameter
value_loss_coeff = 0.5 # Coefficient for value function loss
entropy_coeff = 0.01  # Coefficient for entropy bonus

epochs_per_update = 4 # Number of PPO epochs per data collection phase
num_minibatches = 4   # Number of minibatches per epoch
batch_size = num_envs * episode_length # Total samples collected per rollout
minibatch_size = batch_size // num_minibatches

actor_learning_rate = 3e-4
critic_learning_rate = 3e-4

rng = jax.random.PRNGKey(seed)
rng, agent_rng = jax.random.split(rng)

# Initialize PPO Agent
ppo_agent = PPOAgent(
    rng=agent_rng,
    observation_size=env.observation_size,
    action_size=env.action_size,
    actor_learning_rate=actor_learning_rate,
    critic_learning_rate=critic_learning_rate
)

# 2. Create JAX-jittable versions of the environment's `reset` and `step` functions.
jit_reset = jax.jit(jax.vmap(env.reset, in_axes=(0,)))
jit_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0)))

print("Environment, PPO agent, and jittable functions initialized.")

**Reasoning**:
The previous code block failed because the `brax` module was not found, despite `brax` being imported in earlier cells. To resolve this, I need to re-import `brax` along with `jax.numpy` and `optax` to ensure all necessary libraries are available for the PPO training loop setup.



In [None]:
import jax
import jax.numpy as jnp
from brax import envs
import optax

# 1. Initialize the environment, a JAX random key, and the PPO agent with appropriate hyperparameters
# Environment setup
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Hyperparameters
seed = 0
num_steps = 1_000_000 # Total environment steps
num_envs = 32         # Number of parallel environments for data collection
episode_length = env.episode_length # Max steps per episode

gamma = 0.99          # Discount factor
gae_lambda = 0.95     # GAE lambda parameter
clip_param = 0.2      # PPO clipping parameter
value_loss_coeff = 0.5 # Coefficient for value function loss
entropy_coeff = 0.01  # Coefficient for entropy bonus

epochs_per_update = 4 # Number of PPO epochs per data collection phase
num_minibatches = 4   # Number of minibatches per epoch
batch_size = num_envs * episode_length # Total samples collected per rollout
minibatch_size = batch_size // num_minibatches

actor_learning_rate = 3e-4
critic_learning_rate = 3e-4

rng = jax.random.PRNGKey(seed)
rng, agent_rng = jax.random.split(rng)

# Initialize PPO Agent
ppo_agent = PPOAgent(
    rng=agent_rng,
    observation_size=env.observation_size,
    action_size=env.action_size,
    actor_learning_rate=actor_learning_rate,
    critic_learning_rate=critic_learning_rate
)

# 2. Create JAX-jittable versions of the environment's `reset` and `step` functions.
jit_reset = jax.jit(jax.vmap(env.reset, in_axes=(0,)))
jit_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0)))

print("Environment, PPO agent, and jittable functions initialized.")

**Reasoning**:
The error `ModuleNotFoundError: No module named 'brax'` indicates that the brax library is not installed in the current runtime, even though it was imported in previous cells. I need to explicitly install it before proceeding with the code that uses it.



In [None]:
!pip install brax

**Reasoning**:
The previous error was due to `brax` not being installed. Now that `brax` has been installed, I need to re-execute the initialization code for the environment, PPO agent, and jittable functions to ensure everything is set up correctly.



In [None]:
import jax
import jax.numpy as jnp
from brax import envs
import optax

# 1. Initialize the environment, a JAX random key, and the PPO agent with appropriate hyperparameters
# Environment setup
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Hyperparameters
seed = 0
num_steps = 1_000_000 # Total environment steps
num_envs = 32         # Number of parallel environments for data collection
episode_length = env.episode_length # Max steps per episode

gamma = 0.99          # Discount factor
gae_lambda = 0.95     # GAE lambda parameter
clip_param = 0.2      # PPO clipping parameter
value_loss_coeff = 0.5 # Coefficient for value function loss
entropy_coeff = 0.01  # Coefficient for entropy bonus

epochs_per_update = 4 # Number of PPO epochs per data collection phase
num_minibatches = 4   # Number of minibatches per epoch
batch_size = num_envs * episode_length # Total samples collected per rollout
minibatch_size = batch_size // num_minibatches

actor_learning_rate = 3e-4
critic_learning_rate = 3e-4

rng = jax.random.PRNGKey(seed)
rng, agent_rng = jax.random.split(rng)

# Initialize PPO Agent
ppo_agent = PPOAgent(
    rng=agent_rng,
    observation_size=env.observation_size,
    action_size=env.action_size,
    actor_learning_rate=actor_learning_rate,
    critic_learning_rate=critic_learning_rate
)

# 2. Create JAX-jittable versions of the environment's `reset` and `step` functions.
jit_reset = jax.jit(jax.vmap(env.reset, in_axes=(0,)))
jit_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0)))

print("Environment, PPO agent, and jittable functions initialized.")

**Reasoning**:
The previous code failed because the `Humanoid` environment object, when created via `envs.get_environment`, did not expose the `episode_length` attribute directly, leading to an `AttributeError`. To fix this, I will explicitly set `episode_length` to its default value of 1000 as a hyperparameter, which is consistent with the `Humanoid` class definition.



In [None]:
import jax
import jax.numpy as jnp
from brax import envs
import optax

# 1. Initialize the environment, a JAX random key, and the PPO agent with appropriate hyperparameters
# Environment setup
env_name = 'humanoid'
env = envs.get_environment(env_name)

# Hyperparameters
seed = 0
num_steps = 1_000_000 # Total environment steps
num_envs = 32         # Number of parallel environments for data collection
episode_length = 1000 # Max steps per episode (Fixed: Directly set the default value)

gamma = 0.99          # Discount factor
gae_lambda = 0.95     # GAE lambda parameter
clip_param = 0.2      # PPO clipping parameter
value_loss_coeff = 0.5 # Coefficient for value function loss
entropy_coeff = 0.01  # Coefficient for entropy bonus

epochs_per_update = 4 # Number of PPO epochs per data collection phase
num_minibatches = 4   # Number of minibatches per epoch
batch_size = num_envs * episode_length # Total samples collected per rollout
minibatch_size = batch_size // num_minibatches

actor_learning_rate = 3e-4
critic_learning_rate = 3e-4

rng = jax.random.PRNGKey(seed)
rng, agent_rng = jax.random.split(rng)

# Initialize PPO Agent
ppo_agent = PPOAgent(
    rng=agent_rng,
    observation_size=env.observation_size,
    action_size=env.action_size,
    actor_learning_rate=actor_learning_rate,
    critic_learning_rate=critic_learning_rate
)

# 2. Create JAX-jittable versions of the environment's `reset` and `step` functions.
jit_reset = jax.jit(jax.vmap(env.reset, in_axes=(0,)))
jit_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0)))

print("Environment, PPO agent, and jittable functions initialized.")

**Reasoning**:
Now that the environment and PPO agent are initialized, I need to implement the data collection phase of the training loop. This involves defining a jittable function to simulate steps in the environment, collect trajectories (observations, actions, rewards, values, log probabilities, and done flags), and prepare them for PPO updates.



In [None]:
from typing import Any

@jax.jit
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng = carry
    step_rng, action_rng, value_rng = jax.random.split(step_rng, 3)

    # Sample action
    action, log_prob = ppo_agent.sample_action(actor_params, env_state.obs, action_rng)

    # Predict value
    value = ppo_agent.critic_network.apply({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng), transition

  # Reset environment for parallel collection
  rng, reset_rng = jax.random.split(rng)
  # We need to ensure that reset_rng is split for each parallel environment
  reset_rngs = jax.random.split(reset_rng, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Jax.scan to efficiently collect data for num_steps_to_collect
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, rng),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  final_value = ppo_agent.critic_network.apply({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined.")


**Reasoning**:
Now that the `collect_rollout` function is defined and necessary components are initialized, I need to implement the main PPO training loop. This involves iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, as specified in the instructions.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = shuffled_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")



**Reasoning**:
The `TypeError` indicates that the `ppo_agent` object, which is a Python class instance, was passed to the JAX-jitted `collect_rollout` function without being marked as a static argument. To resolve this, I need to modify the `collect_rollout` function's `jax.jit` decorator to include `static_argnums=1`, explicitly telling JAX that the `ppo_agent` argument should be treated as static (i.e., not traced).



In [None]:
from typing import Any

@jax.jit(static_argnums=1) # Mark ppo_agent as a static argument
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng = carry
    step_rng, action_rng, value_rng = jax.random.split(step_rng, 3)

    # Sample action
    action, log_prob = ppo_agent.sample_action(actor_params, env_state.obs, action_rng)

    # Predict value
    value = ppo_agent.critic_network.apply({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng), transition

  # Reset environment for parallel collection
  rng, reset_rng = jax.random.split(rng)
  # We need to ensure that reset_rng is split for each parallel environment
  reset_rngs = jax.random.split(reset_rng, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Jax.scan to efficiently collect data for num_steps_to_collect
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, rng),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  final_value = ppo_agent.critic_network.apply({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined.")

**Reasoning**:
The error `TypeError: jit() missing 1 required positional argument: 'fun'` arises because `jax.jit` when called with arguments like `static_argnums` returns a decorator function, but the original `@jax.jit(static_argnums=1)` syntax applies the result of `jax.jit(...)` directly without providing the function to be decorated. To fix this, I need to use `functools.partial` to correctly apply the `static_argnums` to the `jax.jit` decorator before it wraps the `collect_rollout` function.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=1) # Mark ppo_agent as a static argument
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng = carry
    step_rng, action_rng, value_rng = jax.random.split(step_rng, 3)

    # Sample action
    action, log_prob = ppo_agent.sample_action(actor_params, env_state.obs, action_rng)

    # Predict value
    value = ppo_agent.critic_network.apply({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng), transition

  # Reset environment for parallel collection
  rng, reset_rng = jax.random.split(rng)
  # We need to ensure that reset_rng is split for each parallel environment
  reset_rngs = jax.random.split(reset_rng, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Jax.scan to efficiently collect data for num_steps_to_collect
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, rng),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  final_value = ppo_agent.critic_network.apply({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined.")

**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, as specified in the instructions, and logging the training progress.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)i
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")


**Reasoning**:
The `ConcretizationTypeError` indicates that `num_steps_to_collect` is being treated as a traced JAX array instead of a static integer by `jax.lax.scan`. To fix this, I need to explicitly mark `num_steps_to_collect` (which is the 3rd argument, index 3) as a static argument in the `collect_rollout` function's `jax.jit` decorator, alongside the already static `ppo_agent` (index 1).



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3)) # Mark ppo_agent and num_steps_to_collect as static arguments
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng = carry
    step_rng, action_rng, value_rng = jax.random.split(step_rng, 3)

    # Sample action
    action, log_prob = ppo_agent.sample_action(actor_params, env_state.obs, action_rng)

    # Predict value
    value = ppo_agent.critic_network.apply({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng), transition

  # Reset environment for parallel collection
  rng, reset_rng = jax.random.split(rng)
  # We need to ensure that reset_rng is split for each parallel environment
  reset_rngs = jax.random.split(reset_rng, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Jax.scan to efficiently collect data for num_steps_to_collect
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, rng),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  final_value = ppo_agent.critic_network.apply({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined.")

**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, as specified in the instructions, and logging the training progress.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `TypeError` arises because the `sample_action` and `update` methods within `PPOAgent` were independently jitted using `@jax.jit`. When these jitted methods were called from within `collect_rollout` (which is also jitted and uses `jax.lax.scan`), JAX attempted to re-trace the `self` argument (the `PPOAgent` instance) as a dynamic JAX array, leading to a conflict. To resolve this, I need to remove the `@jax.jit` decorators from `sample_action` and `update` methods in `PPOAgent`, allowing the outer `collect_rollout`'s jit to handle the tracing of these calls seamlessly, as the relevant JAX arrays (parameters, observations, etc.) are already being passed correctly.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist
import optax
from flax.training import train_state
from flax.core import FrozenDict # Import FrozenDict explicitly

# 1. Define a new class, `PPOAgent`, that will serve as the PPO learner.
class PPOAgent:
  def __init__(
      self,
      rng: jax.random.PRNGKey,
      observation_size: int,
      action_size: int,
      actor_hidden_layer_sizes: tuple = (256, 256),
      critic_hidden_layer_sizes: tuple = (256, 256),
      actor_learning_rate: float = 3e-4,
      critic_learning_rate: float = 3e-4,
  ):
    # 2. In the `__init__` method, initialize the `Actor` and `Critic` networks
    #    and their respective optimizers using `optax.adam`.

    # Initialize Actor Network
    actor_rng, critic_rng = jax.random.split(rng)
    self.actor_network = Actor(action_size=action_size, hidden_layer_sizes=actor_hidden_layer_sizes)
    actor_params = self.actor_network.init(actor_rng, jnp.zeros(observation_size))['params']
    actor_optimizer = optax.adam(learning_rate=actor_learning_rate)
    self.actor_state = train_state.TrainState.create(
        apply_fn=self.actor_network.apply, params=actor_params, tx=actor_optimizer
    )

    # Initialize Critic Network
    self.critic_network = Critic(hidden_layer_sizes=critic_hidden_layer_sizes)
    critic_params = self.critic_network.init(critic_rng, jnp.zeros(observation_size))['params']
    critic_optimizer = optax.adam(learning_rate=critic_learning_rate)
    self.critic_state = train_state.TrainState.create(
        apply_fn=self.critic_network.apply, params=critic_params, tx=critic_optimizer
    )

  # 3. Implement a `sample_action` method
  # Removed @jax.jit from here
  def sample_action(
      self,
      actor_params: FrozenDict,
      observation: jnp.ndarray,
      rng: jax.random.PRNGKey
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    mean, log_std = self.actor_network.apply({'params': actor_params}, observation)
    std = jnp.exp(log_std)

    # Create a JAX-compatible distribution for sampling
    # We need to reshape mean and std for batch sampling if observation is batched
    if observation.ndim == 1: # Single observation
        sample_rng, _ = jax.random.split(rng)
        action = mean + std * jax.random.normal(sample_rng, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std))
    else: # Batched observations
        sample_rngs = jax.random.split(rng, observation.shape[0])
        action = mean + std * jax.random.normal(sample_rngs, mean.shape)
        log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std), axis=-1)

    return action, log_prob

  # 4. Implement an `update` method
  # Removed @jax.jit from here
  def update(
      self,
      actor_state: train_state.TrainState,
      critic_state: train_state.TrainState,
      observations: jnp.ndarray,
      actions: jnp.ndarray,
      advantages: jnp.ndarray,
      returns: jnp.ndarray,
      old_log_probs: jnp.ndarray,
      clip_param: float,
      value_loss_coeff: float,
      entropy_coeff: float
  ):
    # 5. Inside the `update` method, use `jax.value_and_grad`
    #    to compute the gradients of the `ppo_loss`.
    (loss, metrics), grads = jax.value_and_grad(ppo_loss, argnums=(0, 1), has_aux=True)(
        actor_state.params,
        critic_state.params,
        actor_state.apply_fn,
        critic_state.apply_fn,
        observations,
        actions,
        advantages,
        returns,
        old_log_probs,
        clip_param,
        value_loss_coeff,
        entropy_coeff,
    )

    actor_grads, critic_grads = grads

    # 6. Apply these gradients to update the actor and critic parameters.
    actor_state = actor_state.apply_gradients(grads=actor_grads)
    critic_state = critic_state.apply_gradients(grads=critic_grads)

    # 7. The `update` method should return the updated states and metrics.
    return actor_state, critic_state, metrics

print("PPOAgent class defined.")

**Reasoning**:
Now that the `PPOAgent` class has been correctly redefined with `jax.jit` decorators removed from its methods, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")


# Task
To address the `ValueError: normal accepts a single key, but was given a key array of shape (32, 2) != (). Use jax.vmap for batching.`, the random number generation and key splitting need to be vectorized correctly within the `collect_rollout` and `PPOAgent.sample_action` functions. This error indicates that batched random keys are being passed to functions that expect single keys, or vice-versa, due to the implicit vectorization (`vmap`) by `jax.lax.scan`.

Here's the plan to resolve the issue:

1.  **Adjust `step_rng` initialization in `collect_rollout`**: Before calling `jax.lax.scan` in `collect_rollout`, split the initial `rollout_rng` (a single key) into a batch of `num_envs` distinct keys. This batch of keys (`initial_step_rngs` with shape `(num_envs, 2)`) should then be passed as the random key component of the `init_carry` to `jax.lax.scan`.
2.  **Vectorize key splitting in `_env_step`**: Inside `_env_step`, `step_rng` will now correctly be a batch of keys (`(num_envs, 2)`). Replace `jax.random.split(step_rng, 3)` with `jax.vmap(jax.random.split, in_axes=(0, None))(step_rng, 3)`. This will split each of the `num_envs` keys independently, resulting in `new_step_rngs`, `action_rngs`, and `value_rngs`, each having a shape of `(num_envs, 2)`.
3.  **Modify `PPOAgent.sample_action` for batched random noise**: Since `action_rng` (passed as `rng` to `sample_action`) will be batched (`(num_envs, 2)`), modify the `sample_action` method to correctly generate batched action noise. The `if observation.ndim == 1` and `else` branches in `sample_action` should be removed, and `jax.vmap` should be used for `jax.random.normal`: `action_noise = jax.vmap(jax.random.normal, in_axes=(0, None))(rng, (mean.shape[-1],))` where `rng` is the input `jax.random.PRNGKey` to `sample_action`. Then, compute `action = mean + std * action_noise`.
4.  **Update `step_rng` in `_env_step` carry**: Ensure that the updated batch of random keys (`new_step_rngs`) is returned in the `carry` of the `_env_step` function for the next iteration of `jax.lax.scan`.

## Adjust `step_rng` initialization in `collect_rollout`

### Subtask:
Split the initial `rollout_rng` into a batch of `num_envs` distinct keys before calling `jax.lax.scan` in `collect_rollout`, and pass this batch of keys as the random key component of the `init_carry`.


**Reasoning**:
The previous error stemmed from `jax.random.normal` receiving a batch of random keys (`sample_rngs`) instead of a single key within `PPOAgent.sample_action` when processing batched observations. To resolve this, I will refine the `sample_action` method to exclusively handle single observations and single random keys. This will allow `collect_rollout` to then use `jax.vmap` effectively for batched processing.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist
import optax
from flax.training import train_state
from flax.core import FrozenDict # Import FrozenDict explicitly

# 1. Define a new class, `PPOAgent`, that will serve as the PPO learner.
class PPOAgent:
  def __init__(
      self,
      rng: jax.random.PRNGKey,
      observation_size: int,
      action_size: int,
      actor_hidden_layer_sizes: tuple = (256, 256),
      critic_hidden_layer_sizes: tuple = (256, 256),
      actor_learning_rate: float = 3e-4,
      critic_learning_rate: float = 3e-4,
  ):
    # 2. In the `__init__` method, initialize the `Actor` and `Critic` networks
    #    and their respective optimizers using `optax.adam`.

    # Initialize Actor Network
    actor_rng, critic_rng = jax.random.split(rng)
    self.actor_network = Actor(action_size=action_size, hidden_layer_sizes=actor_hidden_layer_sizes)
    actor_params = self.actor_network.init(actor_rng, jnp.zeros(observation_size))['params']
    actor_optimizer = optax.adam(learning_rate=actor_learning_rate)
    self.actor_state = train_state.TrainState.create(
        apply_fn=self.actor_network.apply, params=actor_params, tx=actor_optimizer
    )

    # Initialize Critic Network
    self.critic_network = Critic(hidden_layer_sizes=critic_hidden_layer_sizes)
    critic_params = self.critic_network.init(critic_rng, jnp.zeros(observation_size))['params']
    critic_optimizer = optax.adam(learning_rate=critic_learning_rate)
    self.critic_state = train_state.TrainState.create(
        apply_fn=self.critic_network.apply, params=critic_params, tx=critic_optimizer
    )

  # 3. Implement a `sample_action` method
  # Removed @jax.jit from here
  # Modified to expect a single observation and single RNG, removing internal batching logic.
  def sample_action(
      self,
      actor_params: FrozenDict,
      observation: jnp.ndarray, # Assumed to be a single observation
      rng: jax.random.PRNGKey   # Assumed to be a single random key
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    mean, log_std = self.actor_network.apply({'params': actor_params}, observation)
    std = jnp.exp(log_std)

    # Split rng for sampling a single action
    sample_rng, _ = jax.random.split(rng)
    action = mean + std * jax.random.normal(sample_rng, mean.shape)
    log_prob = jnp.sum(normal_dist.logpdf(x=action, loc=mean, scale=std))

    return action, log_prob

  # 4. Implement an `update` method
  # Removed @jax.jit from here
  def update(
      self,
      actor_state: train_state.TrainState,
      critic_state: train_state.TrainState,
      observations: jnp.ndarray,
      actions: jnp.ndarray,
      advantages: jnp.ndarray,
      returns: jnp.ndarray,
      old_log_probs: jnp.ndarray,
      clip_param: float,
      value_loss_coeff: float,
      entropy_coeff: float
  ):
    # 5. Inside the `update` method, use `jax.value_and_grad`
    #    to compute the gradients of the `ppo_loss`.
    (loss, metrics), grads = jax.value_and_grad(ppo_loss, argnums=(0, 1), has_aux=True)(
        actor_state.params,
        critic_state.params,
        actor_state.apply_fn,
        critic_state.apply_fn,
        observations,
        actions,
        advantages,
        returns,
        old_log_probs,
        clip_param,
        value_loss_coeff,
        entropy_coeff,
    )

    actor_grads, critic_grads = grads

    # 6. Apply these gradients to update the actor and critic parameters.
    actor_state = actor_state.apply_gradients(grads=actor_grads)
    critic_state = critic_state.apply_gradients(grads=critic_grads)

    # 7. The `update` method should return the updated states and metrics.
    return actor_state, critic_state, metrics

print("PPOAgent class defined.")

**Reasoning**:
The subtask requires splitting the `rollout_rng` into a batch of keys for each environment and passing it to `jax.lax.scan`. I need to modify the `collect_rollout` function to ensure that `step_rng` in the `init_carry` is a batch of `num_envs` random keys, and the internal `_env_step` function uses these batched keys correctly. Additionally, the `PPOAgent.sample_action` and `critic_network.apply` calls within `_env_step` need to be `jax.vmap`-ped because `env_state.obs` will be a batch of observations.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Split each key in the batch of step_rng_batch
    step_rng_batch, action_rng_batch, value_rng_batch = jax.vmap(jax.random.split, in_axes=(0, None), out_axes=(0, 0, 0))(step_rng_batch, 3)

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with batched RNG handling and vmap.")


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `ValueError` indicates a mismatch between the `out_axes` specification and the actual PyTree structure returned by `jax.random.split`. Since `jax.random.split(key, 3)` returns a tuple of three `PRNGKey`s, and each `PRNGKey` is itself a PyTree (a tuple of two arrays), the `out_axes` needs to be `((0, 0), (0, 0), (0, 0))` to correctly specify the batch dimension for each component of each key. I will update the `collect_rollout` function with this corrected `out_axes`.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Corrected out_axes for jax.vmap(jax.random.split)
    # Each PRNGKey is a tuple of 2 elements, so for 3 keys, it's ((0,0),(0,0),(0,0))
    split_keys = jax.vmap(
        jax.random.split,
        in_axes=(0, None),
        out_axes=((0, 0), (0, 0), (0, 0))  # This is the fix
    )(step_rng_batch, 3)

    step_rng_batch, action_rng_batch, value_rng_batch = split_keys[0], split_keys[1], split_keys[2]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with corrected out_axes.")


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The previous `ValueError` in `jax.vmap` for `jax.random.split` was due to an over-specified `out_axes` for `PRNGKey`s. A `PRNGKey` is internally treated as a PyTree by JAX, so `out_axes=(0, 0, 0)` is sufficient to indicate that the batch dimension should be the first axis for each of the three `PRNGKey`s returned by `jax.random.split`.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Corrected out_axes for jax.vmap(jax.random.split)
    # Each PRNGKey is a tuple of 2 elements, so for 3 keys, it's (0,0,0) (not ((0,0),(0,0),(0,0)))
    split_keys = jax.vmap(
        jax.random.split,
        in_axes=(0, None),
        out_axes=(0, 0, 0)  # This is the fix
    )(step_rng_batch, 3)

    step_rng_batch, action_rng_batch, value_rng_batch = split_keys[0], split_keys[1], split_keys[2]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with corrected out_axes.")


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `ValueError` in `jax.vmap` for `jax.random.split` indicates that the `out_axes=(0, 0, 0)` specification does not correctly match the nested PyTree structure of `PRNGKey` outputs. Each `PRNGKey` is a tuple of two arrays, so when `jax.random.split` returns three `PRNGKey`s, the `out_axes` needs to reflect this nesting as `((0, 0), (0, 0), (0, 0))` to properly specify the batching for each component of each key.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Corrected out_axes for jax.vmap(jax.random.split)
    # Each PRNGKey is a tuple of 2 elements, so for 3 keys, it's ((0,0),(0,0),(0,0))
    split_keys = jax.vmap(
        jax.random.split,
        in_axes=(0, None),
        out_axes=((0, 0), (0, 0), (0, 0))  # This is the fix
    )(step_rng_batch, 3)

    step_rng_batch, action_rng_batch, value_rng_batch = split_keys[0], split_keys[1], split_keys[2]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with corrected out_axes.")

**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `ValueError` indicates a persistent issue with `out_axes` specification for `jax.vmap` when used with `jax.random.split`, suggesting that JAX's internal tracing is having trouble reconciling the PyTree structure of `PRNGKey` outputs with explicit `out_axes`. The most robust solution is often to remove the `out_axes` argument and allow JAX to infer the batching structure, which it can typically do correctly for standard PyTrees like tuples of `PRNGKey`s.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Removed out_axes, allowing JAX to infer the batching for jax.random.split
    split_keys = jax.vmap(
        jax.random.split,
        in_axes=(0, None)
    )(step_rng_batch, 3)

    step_rng_batch, action_rng_batch, value_rng_batch = split_keys[0], split_keys[1], split_keys[2]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with inferred out_axes.")

**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `ValueError` arises because `jax.random.split` implicitly vectorizes when given a batch of keys (shape `(num_envs, 2)`), producing a single array of shape `(num_envs, num_splits, 2)`. The previous code incorrectly wrapped this already-vectorized call with `jax.vmap` and then tried to unpack the result by indexing the `num_envs` dimension as if it were the `num_splits` dimension, leading to `action_rng_batch` having an incorrect batch size of `3` instead of `32`. To fix this, I will remove the redundant `jax.vmap` around `jax.random.split` and unpack the resulting `(num_envs, 3, 2)` array by slicing along the second dimension.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch = carry

    # Fix: Remove the outer jax.vmap. jax.random.split implicitly handles batched keys.
    # The result will be an array of shape (num_envs, num_splits, 2).
    split_keys_array = jax.random.split(step_rng_batch, 3)

    # Fix: Unpack by slicing along the second (num_splits) dimension.
    step_rng_batch = split_keys_array[:, 0, :]
    action_rng_batch = split_keys_array[:, 1, :]
    value_rng_batch = split_keys_array[:, 2, :]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch), transition

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with correct RNG handling.")


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")


**Reasoning**:
The `ValueError` arises because `jax.random.split` expects a single `PRNGKey` but was implicitly receiving a batch of keys (`step_rng_batch`). To correctly split each key in the batch, `jax.vmap` must be used explicitly around `jax.random.split` within the `_env_step` function.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch_in = carry # Renamed incoming rng for clarity

    # FIX: Explicitly use jax.vmap to split the batch of keys
    # jax.random.split(key, 3) for a single key returns a tuple of 3 keys.
    # vmap will apply this to each key in step_rng_batch_in, resulting in a tuple of 3 arrays,
    # each of shape (num_envs, 2) when out_axes=(0,0,0).
    new_rng_keys_tuple = jax.vmap(jax.random.split, in_axes=(0, None), out_axes=(0,0,0))(step_rng_batch_in, 3)

    step_rng_batch_out = new_rng_keys_tuple[0] # Key for the next iteration of _env_step
    action_rng_batch = new_rng_keys_tuple[1]
    value_rng_batch = new_rng_keys_tuple[2]

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch_out), transition # Update carry with next key batch

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with correct RNG handling.")


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")


**Reasoning**:
The persistent `ValueError` from `jax.vmap` suggests that `jax.random.split` (when used with a `num` argument greater than 1) does not return a tuple of keys, but rather a single `PRNGKey` array where the `num` splits are along its leading dimension. Therefore, applying `jax.vmap` to `jax.random.split` on a batched input key will result in a single `PRNGKey` array with shape `(num_envs, num_splits, 2)`. I need to remove the `out_axes` from `jax.vmap` (as JAX can infer it) and then unpack this single resulting array by slicing along its second dimension (the `num_splits` dimension) to correctly extract the batched `step_rng_batch_out`, `action_rng_batch`, and `value_rng_batch` for each environment.



In [None]:
from typing import Any
from functools import partial

@partial(jax.jit, static_argnums=(1, 3))
def collect_rollout(rng: jax.random.PRNGKey, ppo_agent: PPOAgent, env_state: Any, num_steps_to_collect: int):
  """Collects a rollout of trajectories from the environment."""

  def _env_step(carry, _):
    env_state, actor_params, critic_params, step_rng_batch_in = carry # Renamed incoming rng for clarity

    # FIX: jax.random.split(key, num) returns a single key array of shape (num, 2) if key is (2,).
    # When vmapped over an input key of shape (num_envs, 2), it will return a single key array of shape (num_envs, num_splits, 2).
    all_split_keys_array = jax.vmap(jax.random.split, in_axes=(0, None))(step_rng_batch_in, 3)

    # Now slice this single array to get the individual batched keys
    step_rng_batch_out = all_split_keys_array[:, 0, :] # Shape (num_envs, 2)
    action_rng_batch = all_split_keys_array[:, 1, :]   # Shape (num_envs, 2)
    value_rng_batch = all_split_keys_array[:, 2, :]    # Shape (num_envs, 2)

    # Sample action (vmap over environments)
    action, log_prob = jax.vmap(ppo_agent.sample_action, in_axes=(None, 0, 0))(actor_params, env_state.obs, action_rng_batch)

    # Predict value (vmap over environments)
    value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, env_state.obs)

    # Step environment
    next_env_state = jit_step(env_state, action)

    # Store data
    transition = {
        'observations': env_state.obs,
        'actions': action,
        'rewards': next_env_state.reward,
        'dones': next_env_state.done,
        'values': value,
        'log_probs': log_prob,
    }

    return (next_env_state, actor_params, critic_params, step_rng_batch_out), transition # Update carry with next key batch

  # Reset environment for parallel collection
  rng, reset_rng_base = jax.random.split(rng)
  reset_rngs = jax.random.split(reset_rng_base, num_envs)
  env_state = jit_reset(reset_rngs)

  actor_params = ppo_agent.actor_state.params
  critic_params = ppo_agent.critic_state.params

  # Split the initial RNG for the scan carry into a batch of keys for each environment
  rng, initial_step_rng_base = jax.random.split(rng)
  initial_step_rng_batch = jax.random.split(initial_step_rng_base, num_envs) # (num_envs, 2) shaped keys

  # Jax.scan to efficiently collect data for num_steps_to_collect
  # The carry now includes a batch of RNG keys
  (final_env_state, _, _, _), traj_batch = jax.lax.scan(
      _env_step,
      (env_state, actor_params, critic_params, initial_step_rng_batch),
      None,
      length=num_steps_to_collect
  )

  # Get the final value for GAE calculation
  # vmap over environments since final_env_state.obs is batched
  final_value = jax.vmap(ppo_agent.critic_network.apply, in_axes=(None, 0))({'params': critic_params}, final_env_state.obs)

  return traj_batch, final_env_state, final_value

print("Data collection function 'collect_rollout' defined with correct RNG handling.")

## Develop the Training Loop

### Subtask:
Continue the main PPO training loop, which involves iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress.


**Reasoning**:
Now that the `collect_rollout` function is correctly defined and jitted with the necessary static arguments and batched RNG handling, I need to re-execute the main PPO training loop. This will involve iteratively collecting rollouts, computing GAE, normalizing advantages, and updating the actor and critic networks using mini-batches, and logging the training progress, which should now run without the previous tracing errors.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

**Reasoning**:
The `TypeError` arises because `ppo_loss` receives `apply_fn_actor` and `apply_fn_critic` (which are methods) as dynamic JAX arguments, which JAX cannot trace. To fix this, I need to mark these arguments as static using `functools.partial` with `static_argnums` in the `jax.jit` decorator for `ppo_loss`. Additionally, the way `actor_params` and `critic_params` are passed to `apply_fn_actor` and `apply_fn_critic` within `ppo_loss` is incorrect; they should be nested within a dictionary like `{'params': actor_params}` as expected by Flax's `apply` method.



In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as normal_dist
from functools import partial

# 1. Define a JAX-jittable function, ppo_loss
@partial(jax.jit, static_argnums=(2, 3)) # Mark apply_fn_actor (2) and apply_fn_critic (3) as static
def ppo_loss(
    actor_params,
    critic_params,
    apply_fn_actor,
    apply_fn_critic,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    advantages: jnp.ndarray,
    returns: jnp.ndarray,
    old_log_probs: jnp.ndarray,
    clip_param: float,
    value_loss_coeff: float,
    entropy_coeff: float
):
  """Computes the PPO loss function."""

  # 2. Apply the actor network to get mean and log_std
  # FIX: Pass parameters nested in a dictionary as expected by Flax apply_fn
  mean, log_std = apply_fn_actor({'params': actor_params}, observations)

  # 3. Create a tfp.distributions.Normal (or equivalent JAX distribution)
  std = jnp.exp(log_std)
  action_distribution = normal_dist.logpdf(x=actions, loc=mean, scale=std)

  # 4. Calculate the new_log_probs of the actions and the entropy
  # For a Gaussian policy, the log_prob of actions and entropy can be computed.
  # Here we use jax.scipy.stats.norm.logpdf for log_probs directly.
  # For entropy, we need to sum over the action dimensions.
  new_log_probs = jnp.sum(action_distribution, axis=-1)
  # The entropy of a multivariate Gaussian is sum(0.5 * (1 + log(2*pi*sigma^2)))
  entropy = jnp.sum(0.5 * (1.0 + jnp.log(2 * jnp.pi) + 2 * log_std), axis=-1)

  # 5. Compute the ratio of the new policy's probabilities to the old policy's probabilities
  ratio = jnp.exp(new_log_probs - old_log_probs)

  # 6. Calculate the clipped surrogate objective (policy loss)
  surr1 = ratio * advantages
  surr2 = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
  # Negate because we are using gradient descent (minimizing loss) to maximize objective
  policy_loss = -jnp.mean(jnp.minimum(surr1, surr2))

  # 7. Apply the critic network to get current_values
  # FIX: Pass parameters nested in a dictionary as expected by Flax apply_fn
  current_values = apply_fn_critic({'params': critic_params}, observations)

  # 8. Calculate the value function loss as the mean squared error
  value_loss = jnp.mean(jnp.square(current_values - returns))

  # 9. Calculate the entropy bonus as the mean of the entropy
  entropy_loss = -jnp.mean(entropy) # Typically entropy is maximized, so we subtract it from the total loss

  # 10. Combine the three loss components into a total_loss
  total_loss = policy_loss + value_loss_coeff * value_loss + entropy_coeff * entropy_loss

  # 11. Return the total_loss and individual components
  return total_loss, {
      'policy_loss': policy_loss,
      'value_loss': value_loss,
      'entropy_loss': entropy_loss
  }

print("PPO loss function 'ppo_loss' has been redefined with static arguments and corrected parameter passing.")

**Reasoning**:
The `ppo_loss` function has been redefined with the necessary fixes for static arguments and parameter passing. Now, I need to re-execute the main PPO training loop to ensure these changes are applied and to continue the training process, as instructed by the subtask.



In [None]:
import jax.numpy as jnp
import time

# Initialize environment state for parallel environments
rng, init_rng = jax.random.split(rng)
init_rngs = jax.random.split(init_rng, num_envs)
current_env_state = jit_reset(init_rngs)

# Initialize total steps and best reward for logging
total_steps = 0
best_reward = -jnp.inf

print("Starting PPO training loop...")

while total_steps < num_steps:
  # 4a. Collect a batch of trajectories (rollouts) from the environment
  rng, rollout_rng = jax.random.split(rng)
  traj_batch, current_env_state, final_value = collect_rollout(
      rollout_rng,
      ppo_agent,
      current_env_state,
      episode_length # Collect data for one full episode length per env
  )

  # Stack all collected data from parallel environments into a single batch
  observations = traj_batch['observations']
  actions = traj_batch['actions']
  rewards = traj_batch['rewards']
  dones = traj_batch['dones']
  values = traj_batch['values']
  log_probs = traj_batch['log_probs']

  # 4c. Calculate Generalized Advantage Estimation (GAE) and returns
  # The 'dones' array needs to be augmented with a 0 for the final state
  # so that the value of the final state (final_value) is not discounted
  # if it's not a terminal state.
  # For each parallel environment, we have (episode_length) steps.
  # values shape: (episode_length, num_envs)
  # rewards shape: (episode_length, num_envs)
  # dones shape: (episode_length, num_envs)
  # final_value shape: (num_envs)

  # Reshape to (num_envs, episode_length) for easier processing per environment
  rewards_reshaped = rewards.T
  dones_reshaped = dones.T
  values_reshaped = values.T

  all_advantages = []
  all_returns = []

  for i in range(num_envs):
    # Pad values with the final_value for each environment
    values_i = jnp.concatenate([values_reshaped[i], final_value[i:i+1]])

    # GAE expects (T,) arrays, so pass one environment's data at a time
    advantages_i = compute_gae(
        rewards_reshaped[i],
        values_i,
        dones_reshaped[i],
        gamma,
        gae_lambda
    )
    all_advantages.append(advantages_i)
    all_returns.append(advantages_i + values_reshaped[i]) # Returns are A + V

  advantages = jnp.stack(all_advantages).T # Shape back to (episode_length, num_envs)
  returns = jnp.stack(all_returns).T       # Shape back to (episode_length, num_envs)

  # 4d. Normalize the advantages
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

  # 4e. Reshape and flatten the collected data for mini-batch processing
  # All data needs to be flattened from (episode_length, num_envs, ...) to (batch_size, ...)
  flattened_observations = observations.reshape(-1, observations.shape[-1])
  flattened_actions = actions.reshape(-1, actions.shape[-1])
  flattened_advantages = advantages.flatten()
  flattened_returns = returns.flatten()
  flattened_old_log_probs = log_probs.flatten()

  # 4f. Iterate for a specified number of `epochs_per_update`
  for _ in range(epochs_per_update):
    rng, shuffle_rng = jax.random.split(rng)
    permutation = jax.random.permutation(shuffle_rng, batch_size)

    shuffled_observations = flattened_observations[permutation]
    shuffled_actions = flattened_actions[permutation]
    shuffled_advantages = flattened_advantages[permutation]
    shuffled_returns = flattened_returns[permutation]
    shuffled_old_log_probs = flattened_old_log_probs[permutation]

    for i in range(num_minibatches):
      start_idx = i * minibatch_size
      end_idx = (i + 1) * minibatch_size

      mb_observations = shuffled_observations[start_idx:end_idx]
      mb_actions = shuffled_actions[start_idx:end_idx]
      mb_advantages = shuffled_advantages[start_idx:end_idx]
      mb_returns = shuffled_returns[start_idx:end_idx]
      mb_old_log_probs = flattened_old_log_probs[start_idx:end_idx]

      # 4f.ii. For each mini-batch, call the PPO agent's `update` method
      ppo_agent.actor_state, ppo_agent.critic_state, metrics = ppo_agent.update(
          ppo_agent.actor_state,
          ppo_agent.critic_state,
          mb_observations,
          mb_actions,
          mb_advantages,
          mb_returns,
          mb_old_log_probs,
          clip_param,
          value_loss_coeff,
          entropy_coeff
      )

  total_steps += batch_size # Update total steps count

  # 4g. Periodically evaluate the agent's performance
  # For simplicity, we can just use the mean reward from the collected rollouts as an eval metric
  # A more rigorous evaluation would involve running dedicated evaluation episodes with an evaluation policy.
  avg_episode_reward = rewards.mean()

  if avg_episode_reward > best_reward:
      best_reward = avg_episode_reward

  # 4h. Print or log relevant training statistics
  print(f"Total Steps: {total_steps}, Avg Episode Reward: {avg_episode_reward:.2f}, Best Reward: {best_reward:.2f}")
  print(f"  Policy Loss: {metrics['policy_loss']:.4f}, Value Loss: {metrics['value_loss']:.4f}, Entropy Loss: {metrics['entropy_loss']:.4f}")

print("PPO training finished.")

# Optional: Save the final policy parameters
# model.save_params('/tmp/my_custom_ppo_policy', ppo_agent.actor_state.params)
# print("Saved final policy parameters.")

In [None]:
eval_env = envs.get_environment(env_name)

jit_reset_eval = jax.jit(eval_env.reset)
jit_step_eval = jax.jit(eval_env.step)

# Get the final actor parameters from the trained agent
final_actor_params = ppo_agent.actor_state.params

# Initialize the state for evaluation
rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
rollout_eval = [eval_state.pipeline_state]

# grab a trajectory using the trained policy
n_steps_eval = 500  # Number of steps for evaluation rollout
render_every = 2   # Render every N steps to speed up video generation

print("Generating rollout with trained policy...")
for i in range(n_steps_eval):
  act_rng, eval_rng = jax.random.split(eval_rng)
  # Use the ppo_agent's sample_action for inference
  # We need to explicitly pass the actor_params here
  action, _ = ppo_agent.sample_action(final_actor_params, eval_state.obs, act_rng)
  eval_state = jit_step_eval(eval_state, action)
  rollout_eval.append(eval_state.pipeline_state)

  if eval_state.done:
    print(f"Episode finished early at step {i+1}")
    break

print("Rendering video...")
media.show_video(eval_env.render(rollout_eval[::render_every]), fps=1.0 / eval_env.dt / render_every)

# PPO - github implementation

In [None]:
from abc import *

import torch
import torch.nn as nn
class NetworkBase(nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def __init__(self):
        super(NetworkBase, self).__init__()
    @abstractmethod
    def forward(self, x):
        return x

class Network(NetworkBase):
    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function = torch.relu,last_activation = None):
        super(Network, self).__init__()
        self.activation = activation_function
        self.last_activation = last_activation
        layers_unit = [input_dim]+ [hidden_dim]*(layer_num-1)
        layers = ([nn.Linear(layers_unit[idx],layers_unit[idx+1]) for idx in range(len(layers_unit)-1)])
        self.layers = nn.ModuleList(layers)
        self.last_layer = nn.Linear(layers_unit[-1],output_dim)
        self.network_init()
    def forward(self, x):
        return self._forward(x)
    def _forward(self, x):
        for layer in self.layers:
            x = self.activation(layer(x))
        x = self.last_layer(x)
        if self.last_activation != None:
            x = self.last_activation(x)
        return x
    def network_init(self):
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.orthogonal_(layer.weight)
                layer.bias.data.zero_()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(Network):
    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function = torch.tanh,last_activation = None, trainable_std = False):
        super(Actor, self).__init__(layer_num, input_dim, output_dim, hidden_dim, activation_function ,last_activation)
        self.trainable_std = trainable_std
        if self.trainable_std == True:
            self.logstd = nn.Parameter(torch.zeros(1, output_dim))
    def forward(self, x):
        mu = self._forward(x)
        if self.trainable_std == True:
            std = torch.exp(self.logstd)
        else:
            logstd = torch.zeros_like(mu)
            std = torch.exp(logstd)
        return mu,std

class Critic(Network):
    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function, last_activation = None):
        super(Critic, self).__init__(layer_num, input_dim, output_dim, hidden_dim, activation_function ,last_activation)

    def forward(self, *x):
        x = torch.cat(x,-1)
        return self._forward(x)


In [None]:
import numpy as np
import torch

class Dict(dict):
    def __init__(self,config,section_name,location = False):
        super(Dict,self).__init__()
        self.initialize(config, section_name,location)
    def initialize(self, config, section_name,location):
        for key,value in config.items(section_name):
            if location :
                self[key] = value
            else:
                self[key] = eval(value)
    def __getattr__(self,val):
        return self[val]

def make_transition(state,action,reward,next_state,done,log_prob=None):
    transition = {}
    transition['state'] = state
    transition['action'] = action
    transition['reward'] = reward
    transition['next_state'] = next_state
    transition['log_prob'] = log_prob
    transition['done'] = done
    return transition

def make_mini_batch(*value):
    mini_batch_size = value[0]
    full_batch_size = len(value[1])
    full_indices = np.arange(full_batch_size)
    np.random.shuffle(full_indices)
    for i in range(full_batch_size // mini_batch_size):
        indices = full_indices[mini_batch_size*i : mini_batch_size*(i+1)]
        yield [x[indices] for x in value[1:]]

def convert_to_tensor(*value):
    device = value[0]
    return [torch.tensor(x).float().to(device) for x in value[1:]]

class ReplayBuffer():
    def __init__(self, action_prob_exist, max_size, state_dim, num_action):
        self.max_size = max_size
        self.data_idx = 0
        self.action_prob_exist = action_prob_exist
        self.data = {}

        self.data['state'] = np.zeros((self.max_size, state_dim))
        self.data['action'] = np.zeros((self.max_size, num_action))
        self.data['reward'] = np.zeros((self.max_size, 1))
        self.data['next_state'] = np.zeros((self.max_size, state_dim))
        self.data['done'] = np.zeros((self.max_size, 1))
        if self.action_prob_exist :
            self.data['log_prob'] = np.zeros((self.max_size, 1))
    def put_data(self, transition):
        idx = self.data_idx % self.max_size
        self.data['state'][idx] = transition['state']
        self.data['action'][idx] = transition['action']
        self.data['reward'][idx] = transition['reward']
        self.data['next_state'][idx] = transition['next_state']
        self.data['done'][idx] = float(transition['done'])
        if self.action_prob_exist :
            self.data['log_prob'][idx] = transition['log_prob']

        self.data_idx += 1
    def sample(self, shuffle, batch_size = None):
        if shuffle :
            sample_num = min(self.max_size, self.data_idx)
            rand_idx = np.random.choice(sample_num, batch_size,replace=False)
            sampled_data = {}
            sampled_data['state'] = self.data['state'][rand_idx]
            sampled_data['action'] = self.data['action'][rand_idx]
            sampled_data['reward'] = self.data['reward'][rand_idx]
            sampled_data['next_state'] = self.data['next_state'][rand_idx]
            sampled_data['done'] = self.data['done'][rand_idx]
            if self.action_prob_exist :
                sampled_data['log_prob'] = self.data['log_prob'][rand_idx]
            return sampled_data
        else:
            return self.data
    def size(self):
        return min(self.max_size, self.data_idx)
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)


def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    def __init__(self, writer, device, state_dim, action_dim, args):
        super(PPO,self).__init__()
        self.args = args

        self.data = ReplayBuffer(action_prob_exist = True, max_size = self.args.traj_length, state_dim = state_dim, num_action = action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim, \
                           self.args.activation_function,self.args.last_activation,self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1, \
                             self.args.hidden_dim, self.args.activation_function,self.args.last_activation)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        self.device = device

    def get_action(self,x):
        mu,sigma = self.actor(x)
        return mu,sigma

    def v(self,x):
        return self.critic(x)

    def put_data(self,transition):
        self.data.put_data(transition)

    def get_gae(self, states, rewards, next_states, dones):
        values = self.v(states).detach()  # (T, 1)
        next_values = self.v(next_states).detach()  # (T, 1)

        # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        td_errors = rewards + self.args.gamma * next_values * (1 - dones) - values  # (T, 1)

        advantages = torch.zeros_like(rewards).to(self.device)  # (T, 1)
        last_gae_lam = 0.0

        # Iterate backward to compute GAE: A_t = delta_t + gamma * lambda * A_{t+1} * (1 - d_t)
        for t in reversed(range(len(td_errors))):
            # (1 - dones[t]) acts as a mask, setting future advantage to 0 if terminal
            advantages[t] = td_errors[t] + self.args.gamma * self.args.lambda_ * (1 - dones[t]) * last_gae_lam
            last_gae_lam = advantages[t] # A_t becomes A_{t+1} for the next step

        return values, advantages

    def train_net(self,n_epi):
        data = self.data.sample(shuffle = False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(self.device, data['state'], data['action'], data['reward'], data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        returns = advantages + old_values
        advantages = (advantages - advantages.mean())/(advantages.std()+1e-3)

        for i in range(self.args.train_epoch):
            for state,action,old_log_prob,advantage,return_,old_value \
            in make_mini_batch(self.args.batch_size, states, actions, \
                                           old_log_probs,advantages,returns,old_values):
                curr_mu,curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu,curr_sigma)
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1,keepdim = True)

                #policy clipping
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1-self.args.max_clip, 1+self.args.max_clip) * advantage
                actor_loss = (-torch.min(surr1, surr2) - entropy).mean()

                #value clipping (PPO2 technic)
                old_value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip,self.args.max_clip)
                value_loss = (value - return_.detach().float()).pow(2)
                value_loss_clipped = (old_value_clipped - return_.detach().float()).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss,value_loss_clipped).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                if self.writer != None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)


## PPO - training loop

In [None]:
# Import necessary libraries
import jax
import jax.numpy as jnp
from brax import envs
import torch
import numpy as np

# --- 1. Define Hyperparameters for the PyTorch PPO Agent ---
# Create a simple config class to hold hyperparameters
class PPOConfig:
    def __init__(self):
        self.traj_length = 2048 # Number of steps to collect before update
        self.layer_num = 2
        self.hidden_dim = 256
        self.activation_function = torch.tanh
        self.last_activation = None
        self.trainable_std = True
        self.actor_lr = 3e-4
        self.critic_lr = 3e-4
        self.train_epoch = 10 # Number of PPO epochs
        self.batch_size = 64 # Minibatch size for update
        self.gamma = 0.99
        self.lambda_ = 0.95
        self.max_clip = 0.2
        self.critic_coef = 0.5
        self.entropy_coef = 0.01
        self.max_grad_norm = 0.5

ppo_torch_args = PPOConfig()

# --- 2. Initialize Brax Environment (JAX) ---
env_name = 'humanoid'
env = envs.get_environment(env_name) # Removed episode_length=1000
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# --- DEBUG: Print environment action size ---
print(f"DEBUG: env.observation_size: {env.observation_size}")
print(f"DEBUG: env.action_size: {env.action_size}")
print(f"DEBUG: env.sys.nu (num actuators): {env.sys.nu}")

# --- 3. Initialize PyTorch PPO Agent ---
# Set device for PyTorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using PyTorch device: {device}")

# Dummy writer for now, as it's not provided in the notebook context
class DummyWriter:
    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        pass
writer = DummyWriter()

ppo_torch_agent = PPO(
    writer=writer,
    device=device,
    state_dim=env.observation_size,
    action_dim=env.action_size,
    args=ppo_torch_args
)
ppo_torch_agent.to(device)

print("PyTorch PPO agent initialized.")

# --- 4. Main Training Loop ---
num_total_steps_pytorch = 1_000_000 # Total environment steps for PyTorch PPO
current_total_steps = 0
episode_count = 0
rng = jax.random.PRNGKey(0) # JAX RNG for environment

print("Starting PyTorch PPO training loop...")

while current_total_steps < num_total_steps_pytorch:
    episode_count += 1
    rng, reset_rng = jax.random.split(rng)
    env_state = jit_reset(reset_rng)

    episode_reward = 0
    # Clear replay buffer for new trajectory collection (on-policy PPO)
    ppo_torch_agent.data = ReplayBuffer(
        action_prob_exist=True,
        max_size=ppo_torch_args.traj_length,
        state_dim=env.observation_size,
        num_action=env.action_size
    )

    for t in range(ppo_torch_args.traj_length): # Collect for traj_length steps
        # Convert JAX observation to PyTorch tensor and add batch dimension
        obs_torch = torch.from_numpy(np.array(env_state.obs)).float().to(device).unsqueeze(0)

        # Get action from PyTorch actor
        with torch.no_grad():
            mu, sigma = ppo_torch_agent.get_action(obs_torch)
            action_dist = torch.distributions.Normal(mu, sigma)
            action_torch = action_dist.sample()
            # --- DEBUG: Print action_torch shape ---
            # print(f"DEBUG: action_torch shape: {action_torch.shape}")
            log_prob_torch = action_dist.log_prob(action_torch).sum(dim=-1, keepdim=True)

        # Convert PyTorch action to JAX array and remove batch dimension for environment step
        action_jax = jnp.asarray(action_torch.squeeze(0).cpu().numpy())
        # --- DEBUG: Print action_jax shape ---
        # print(f"DEBUG: action_jax shape (after squeeze): {action_jax.shape}")

        # Step JAX environment
        rng, step_rng = jax.random.split(rng) # Need a new rng for each step if needed by brax (jit_step doesn't take it)
        next_env_state = jit_step(env_state, action_jax)

        # Prepare numpy arrays for ReplayBuffer (remove batch dimension where applicable)
        obs_np = obs_torch.squeeze(0).cpu().numpy()
        action_np = action_torch.squeeze(0).cpu().numpy()
        reward_np = np.array(next_env_state.reward).reshape(1) # Ensure (1,) shape
        next_obs_np = np.array(next_env_state.obs)
        done_np = np.array(next_env_state.done).reshape(1)     # Ensure (1,) shape
        log_prob_np = log_prob_torch.squeeze(0).cpu().numpy()

        # Store transition in PyTorch ReplayBuffer
        transition = make_transition(
            obs_np,
            action_np,
            reward_np,
            next_obs_np,
            done_np,
            log_prob_np
        )
        ppo_torch_agent.put_data(transition)

        env_state = next_env_state
        episode_reward += next_env_state.reward
        current_total_steps += 1

        if env_state.done:
            print(f"Episode {episode_count} finished early at step {t+1}. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")
            break
    else: # If loop completes without break
      print(f"Episode {episode_count} completed {ppo_torch_args.traj_length} steps. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")

    # Train PPO agent after collecting traj_length steps
    if ppo_torch_agent.data.size() >= ppo_torch_args.traj_length:
        ppo_torch_agent.train_net(episode_count)
        print(f"PPO agent trained for episode {episode_count}.")
        # Clear the buffer for next rollout (on-policy)
        ppo_torch_agent.data = ReplayBuffer(
            action_prob_exist=True,
            max_size=ppo_torch_args.traj_length,
            state_dim=env.observation_size,
            num_action=env.action_size
        )


print("\nPyTorch PPO training finished.")

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
print("\nGenerating rollout with trained PyTorch PPO policy...")
eval_env = envs.get_environment(env_name) # Create a new env for evaluation if needed
jit_reset_eval = jax.jit(eval_env.reset)
jit_step_eval = jax.jit(eval_env.step)

rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
rollout_eval_torch = [eval_state.pipeline_state]

n_steps_eval = 500
render_every = 2

for i in range(n_steps_eval):
    # Convert JAX observation to PyTorch tensor and add batch dimension
    obs_torch_eval = torch.from_numpy(np.array(eval_state.obs)).float().to(device).unsqueeze(0)

    # Get action from PyTorch actor (deterministic for evaluation)
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
        action_torch_eval = mu_eval # Use mean for deterministic action
        # --- DEBUG: Print action_torch_eval shape ---
        print(f"DEBUG: action_torch_eval shape: {action_torch_eval.shape}")

    # Convert PyTorch action to JAX array and remove batch dimension for environment step
    action_jax_eval = jnp.asarray(action_torch_eval.squeeze(0).cpu().numpy())
    # --- DEBUG: Print action_jax_eval shape ---
    print(f"DEBUG: action_jax_eval shape (after squeeze): {action_jax_eval.shape}")

    # Step JAX environment
    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    if eval_state.done:
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from PyTorch PPO trained policy...")
media.show_video(eval_env.render(rollout_eval_torch[::render_every]), fps=1.0 / eval_env.dt / render_every)