# MuJoCo Spider Notebook — contents:
  1. Installation section
  2. Environment definition with a rollout
  3. Training with a rollout

# Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax


In [None]:
#@title Check if MuJoCo installation was successful

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags


In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
import os
# os.environ['MUJOCO_GL'] = 'egl' # Ensure EGL rendering is used

from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

# Simple environment with a spider



## Spider XML definition

In [None]:
# MJCF (MuJoCo XML) model of the four-legged "spider". It is apparently derived
# from the MuJoCo ant model; the model name "ant" is left unchanged.
# Structure:
#   - "torso": a sphere with a free joint ("root"), starting at z = 0.75.
#   - Four legs, each with a hip hinge about z (range -30..30 deg) and an ankle
#     hinge (range 30..70 or -70..-30 deg, depending on the leg).
#   - Eight motors (gear 150, ctrlrange [-1, 1]) in actuator order:
#     hip_4, ankle_4, hip_1, ankle_1, hip_2, ankle_2, hip_3, ankle_3.
# Body geoms default to conaffinity="0" and the floor has conaffinity="1", so
# only geom-vs-floor contacts are generated (no self-collisions).
# NOTE(review): the <custom> "init_qpos" (torso z = 0.55) is not read by the
# Humanoid env below. Its reset starts from a fresh MjData's qpos, presumably
# MuJoCo's qpos0 (torso at the body pos z = 0.75). Confirm if 0.55 was intended.
spider_xml = """
<mujoco model="ant">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="front_right_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
        <body name="aux_2" pos="-0.2 0.2 0">
          <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 0.2 0">
            <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
        <body name="aux_3" pos="-0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
          <body pos="-0.2 -0.2 0">
            <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>
</mujoco>
"""

## Spider Env

In [None]:
class Humanoid(PipelineEnv):
  """Four-legged "spider" locomotion environment simulated with MJX.

  The class keeps the name `Humanoid` (and is registered as 'humanoid') for
  backward compatibility, but it loads the spider model from `spider_xml`.

  Reward per control step:
      north_reward_weight * v_y                          (move toward +y)
    + healthy_reward * is_healthy                        (torso z in range)
    - ctrl_cost_weight * sum(action**2)
    - sideways_cost_weight * |v_x|
    - orientation_cost_weight * (quat_x**2 + quat_y**2)  (roll/pitch)
    - z_angular_velocity_cost_weight * yaw_rate**2
  where v_x and v_y are the torso's world-frame linear velocities.
  """

  def __init__(
      self,
      forward_reward_weight=1.5,  # NOTE: not used by `step`; kept for API compatibility.
      north_reward_weight=10.0,  # Weight on the torso's +y ("north") velocity.
      sideways_cost_weight=0.1,  # Penalty on |x| torso velocity.
      ctrl_cost_weight=0.001,  # Penalty on squared actions.
      healthy_reward=5.0,  # Bonus per step while the torso height is in range.
      terminate_when_unhealthy=True,  # Set done=1 when the torso leaves the range.
      orientation_cost_weight=1.0,  # Penalty on torso tilt (roll/pitch).
      z_angular_velocity_cost_weight=0.1,  # Penalty on torso spin (yaw rate).
      healthy_z_range=(0.3, 1.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      episode_length: int = 1000,
      **kwargs,
  ):
    """Builds the MJX system for the spider and stores the reward settings.

    Extra `kwargs` are forwarded to `PipelineEnv`. `n_frames` defaults to 5
    physics steps per control step, and `backend` is always forced to 'mjx'.
    """
    mj_model = mujoco.MjModel.from_xml_string(spider_xml)
    mj_data = mujoco.MjData(mj_model)
    # Conjugate-gradient solver with few iterations: faster, less precise.
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    sys = mjcf.load_model(mj_model)

    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)
    kwargs['backend'] = 'mjx'

    self.episode_length = episode_length
    super().__init__(sys, **kwargs)

    # Stored for introspection only; `step` rewards north (+y) motion instead.
    self._forward_reward_weight = forward_reward_weight
    self._north_reward_weight = north_reward_weight
    self._sideways_cost_weight = sideways_cost_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._orientation_cost_weight = orientation_cost_weight
    self._z_angular_velocity_cost_weight = z_angular_velocity_cost_weight
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    # NOTE: `_get_obs` always omits the torso x/y position, regardless of this flag.
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )
    self._torso_body_idx = mujoco.mj_name2id(
        self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'
    )
    # Fresh MjData: its qpos/qvel are the defaults used as the reset pose.
    self._mj_data = mj_data

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to a slightly perturbed initial state.

    Uniform noise in [-reset_noise_scale, reset_noise_scale] is added to every
    entry of the default qpos and qvel. The root quaternion is included and is
    not re-normalised here.

    Args:
      rng: JAX PRNG key.

    Returns:
      Initial State with zero reward and done, and all metrics set to 0.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = jp.asarray(self._mj_data.qpos) + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jp.asarray(self._mj_data.qvel) + jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
        'north_reward': zero,
        'sideways_cost': zero,
        'orientation_cost': zero,
        'z_angular_cost': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step and scores it.

    Args:
      state: current environment state (from `reset` or a previous `step`).
      action: actuator commands, one per motor (8 for `spider_xml`), expected
        in [-1, 1] (the XML's ctrlrange).

    Returns:
      The next State with updated observation, reward, done flag and metrics.
    """
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Torso tilt: q[3:7] is the root free joint's quaternion (w, x, y, z). Its
    # x and y components are zero for any pure yaw and grow with roll/pitch.
    orientation_cost = self._orientation_cost_weight * jp.sum(
        jp.square(data.q[4:6]))

    # Torso linear velocity: for the free root joint, qvel[0:3] is the
    # world-frame linear velocity. Rows of `cvel` are laid out as
    # [angular(3); linear(3)], so cvel[:, 0:2] are angular rates, not x/y
    # velocities.
    torso_x_velocity = data.qvel[0]
    torso_y_velocity = data.qvel[1]

    # Yaw rate: world-frame angular velocity about z is cvel[torso, 2].
    z_angular_velocity = data.cvel[self._torso_body_idx, 2]
    z_angular_cost = self._z_angular_velocity_cost_weight * jp.square(
        z_angular_velocity)

    # Reward moving "north" (+y) and penalise sideways (x) motion.
    north_reward = self._north_reward_weight * torso_y_velocity
    sideways_cost = self._sideways_cost_weight * jp.abs(torso_x_velocity)

    # Healthy while the torso height stays inside healthy_z_range.
    min_z, max_z = self._healthy_z_range
    torso_z = data.q[2]
    is_healthy = jp.where((torso_z < min_z) | (torso_z > max_z), 0.0, 1.0)
    healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    if self._terminate_when_unhealthy:
      done = 1.0 - is_healthy
    else:
      # Keep `done` a JAX array (as in `reset`) rather than a Python float, so
      # the State pytree keeps the same leaf types across steps.
      done = jp.zeros_like(is_healthy)

    # Rewards are left unscaled; rescale here or in the PPO config if needed.
    reward = (north_reward + healthy_reward - ctrl_cost - sideways_cost
              - orientation_cost - z_angular_cost)

    obs = self._get_obs(data, action)

    state.metrics.update(
        forward_reward=north_reward,
        reward_linvel=north_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=data.xpos[self._torso_body_idx, 0],
        y_position=data.xpos[self._torso_body_idx, 1],
        distance_from_origin=jp.linalg.norm(data.xpos[self._torso_body_idx, :2]),
        x_velocity=torso_x_velocity,
        y_velocity=torso_y_velocity,
        north_reward=north_reward,
        sideways_cost=sideways_cost,
        orientation_cost=orientation_cost,
        z_angular_cost=z_angular_cost,
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Builds the policy observation from the simulator state.

    Layout:
      - torso z
      - torso x-velocity
      - torso y-velocity
      - joint positions (qpos[7:])
      - joint velocities (qvel[6:])

    This is 19 values for `spider_xml` (8 hinge joints). The torso x/y
    position is always omitted. `action` is unused.
    """
    return jp.concatenate([
        data.qpos[2:3],  # Torso z-position (height).
        data.qvel[0:1],  # Torso x-velocity (world frame, free joint).
        data.qvel[1:2],  # Torso y-velocity (world frame, free joint).
        data.qpos[7:],  # Hinge joint positions.
        data.qvel[6:],  # Hinge joint velocities.
    ])

# Re-register environment
envs.register_environment('humanoid', Humanoid)

# Sanity check: build one environment, initialise it from MuJoCo's default
# qpos/qvel, and print the resulting observation shape.
env_test = Humanoid()
dummy_data = mujoco.MjData(env_test.sys.mj_model)
dummy_obs = env_test._get_obs(
    env_test.pipeline_init(
        jp.asarray(dummy_data.qpos),
        jp.asarray(dummy_data.qvel),
    ),
    jp.zeros(env_test.sys.nu),
)
print("Observation shape:", dummy_obs.shape)


## (Optional) Inspect the `mjx.Data` object

In [None]:
# import jax
# import mujoco
# from brax import envs
# import numpy as np

# # Instantiate the environment
# env_name = 'humanoid'
# env = envs.get_environment(env_name)

# # Define the jit reset function
# jit_reset = jax.jit(env.reset)

# # Reset the environment to get an initial state
# rng = jax.random.PRNGKey(0)
# state = jit_reset(rng)

# # The mjx.Data object is stored in state.pipeline_state
# data = state.pipeline_state

# print("--- Contents of mjx.Data object ---")
# print(f"Type of data: {type(data)}\n")

# print("--- All attributes of mjx.Data object ---\n")
# for attr_name in sorted(dir(data)):
#     if not attr_name.startswith('_'): # Exclude private attributes
#         attr_value = getattr(data, attr_name)
#         attr_type = type(attr_value)
#         if isinstance(attr_value, (np.ndarray, jax.Array)): # Check if it's a JAX or NumPy array
#             print(f"  Attribute: {attr_name}, Type: {attr_type}, Shape: {attr_value.shape}")
#         else:
#             print(f"  Attribute: {attr_name}, Type: {attr_type}")

# print("\n--- Specific attributes shown before ---")
# print("1. Generalized positions (qpos):")
# print(f"  Shape: {data.qpos.shape}")
# print(f"  Value (first 10 elements): {data.qpos[:10]}\n")

# print("2. Generalized velocities (qvel):")
# print(f"  Shape: {data.qvel.shape}")
# print(f"  Value (first 10 elements): {data.qvel[:10]}\n")

# print("3. Cartesian position of bodies (xpos):")
# print(f"  Shape: {data.xpos.shape}")
# print(f"  Value (first 5 bodies):\n{data.xpos[:5]}\n")

# print("4. Cartesian orientation of bodies (xquat):")
# print(f"  Shape: {data.xquat.shape}")
# print(f"  Value (first 5 bodies):\n{data.xquat[:5]}\n")

# print("5. Center of mass velocity (cvel):")
# print(f"  Shape: {data.cvel.shape}")
# print(f"  Value (first 5 bodies):\n{data.cvel[:5]}\n")

# print("6. Body indices (for reference, not part of data object directly but useful):")
# print(f"  Torso body index: {env._torso_body_idx}")

## (Optional) Inspect action

In [None]:
# import jax
# from brax import envs
# import jax.numpy as jp
# from brax.training.agents.ppo import networks as ppo_networks
# from brax.training.agents.ppo import train as ppo
# from brax.io import model

# # Re-instantiate the environment if not already available in this scope
# env_name = 'humanoid'
# env = envs.get_environment(env_name)

# # Define and train a minimal policy if jit_inference_fn is not defined
# try:
#     # Attempt to use jit_inference_fn if already defined (e.g., from a previous run of training cells)
#     _ = jit_inference_fn
# except NameError:
#     print("jit_inference_fn not found, training a minimal policy...")
#     make_inference_fn, params, _ = ppo.train(
#         environment=env, num_timesteps=5_000, num_evals=1, episode_length=env.episode_length # Minimal training just for this demo
#     )
#     model_path = '/tmp/mjx_brax_policy'
#     model.save_params(model_path, params)
#     params = model.load_params(model_path)
#     inference_fn = make_inference_fn(params)
#     jit_inference_fn = jax.jit(inference_fn)
#     print("Minimal policy trained and jit_inference_fn defined.")

# # Initialize the state
# current_rng = jax.random.PRNGKey(1) # Use a new RNG key
# state = jax.jit(env.reset)(current_rng)

# # Generate an action using the inference function
# act_rng, current_rng = jax.random.split(current_rng)
# ctrl, _ = jit_inference_fn(state.obs, act_rng)

# print("--- Content of the 'action' variable (ctrl from policy) ---")
# print(f"Type of action: {type(ctrl)}")
# print(f"Shape of action: {ctrl.shape}")
# print(f"Value of action: {ctrl}")


## Visualize a Rollout

Let's instantiate the environment and visualize a short rollout.

NOTE: Episodes terminate early when the torso height leaves the healthy z-range. Only contacts between the spider's geoms and the floor are simulated. The XML gives body geoms `conaffinity="0"` by default and the floor `conaffinity="1"`, so self-collisions are disabled.

In [None]:
# Look up the spider environment by its registry name ('humanoid') and build
# an instance.
env_name = 'humanoid'
env = envs.get_environment(env_name)

# JIT-compile reset/step once; both are reused by the rollout below.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)


In [None]:
# Roll out a zero-action policy (all motors off) and render it as a video.
NUM_ROLLOUT_STEPS = 100

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]

# The zero control vector is loop-invariant, so build it once.
ctrl = jp.zeros(env.sys.nu)
for i in range(NUM_ROLLOUT_STEPS):
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

# One rendered frame per control step; env.dt is the control-step duration.
# Pass fps as a plain Python float rather than a JAX scalar.
media.show_video(env.render(rollout), fps=float(1.0 / env.dt))

In [None]:
print(f"Observation space: {env.observation_size}")

# PPO: PyTorch implementation (adapted from GitHub)

In [None]:
from abc import *

import torch
import torch.nn as nn
class NetworkBase(nn.Module, metaclass=ABCMeta):
    """Abstract base for the PPO networks; subclasses define `__init__` and `forward`."""

    @abstractmethod
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        return x


class Network(NetworkBase):
    """Plain MLP with orthogonal weight init and zero biases.

    Args:
        layer_num: total number of Linear layers (layer_num - 1 hidden layers
            of width `hidden_dim`, followed by the output layer).
        input_dim: input feature size.
        output_dim: output feature size.
        hidden_dim: hidden layer width.
        activation_function: callable applied after every hidden layer.
        last_activation: optional callable applied to the output layer.
    """

    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function=torch.relu, last_activation=None):
        super().__init__()
        self.activation = activation_function
        self.last_activation = last_activation
        layers_unit = [input_dim] + [hidden_dim] * (layer_num - 1)
        self.layers = nn.ModuleList(
            nn.Linear(in_dim, out_dim)
            for in_dim, out_dim in zip(layers_unit[:-1], layers_unit[1:])
        )
        self.last_layer = nn.Linear(layers_unit[-1], output_dim)
        self.network_init()

    def forward(self, x):
        return self._forward(x)

    def _forward(self, x):
        """Hidden layers with `activation`, then the output layer (+ optional last_activation)."""
        for layer in self.layers:
            x = self.activation(layer(x))
        x = self.last_layer(x)
        if self.last_activation is not None:
            x = self.last_activation(x)
        return x

    def network_init(self):
        """Re-initialises every Linear submodule: orthogonal weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight)
                nn.init.zeros_(module.bias)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(Network):
    """Gaussian policy network returning the action mean and standard deviation.

    Args (in addition to `Network`'s):
        trainable_std: if truthy, the log-std is a learned parameter of shape
            (1, output_dim) shared across all states; otherwise std is fixed at 1.
    """

    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function=torch.tanh, last_activation=None, trainable_std=False):
        super().__init__(layer_num, input_dim, output_dim, hidden_dim, activation_function, last_activation)
        self.trainable_std = trainable_std
        if self.trainable_std:
            self.logstd = nn.Parameter(torch.zeros(1, output_dim))

    def forward(self, x):
        """Returns (mu, std); a (1, output_dim) std broadcasts against mu."""
        mu = self._forward(x)
        if self.trainable_std:
            std = torch.exp(self.logstd)
        else:
            std = torch.ones_like(mu)  # exp(0) == 1: fixed unit std
        return mu, std


class Critic(Network):
    """Value network; `forward` concatenates its inputs along the last dim."""

    def __init__(self, layer_num, input_dim, output_dim, hidden_dim, activation_function, last_activation=None):
        super().__init__(layer_num, input_dim, output_dim, hidden_dim, activation_function, last_activation)

    def forward(self, *x):
        x = torch.cat(x, -1)
        return self._forward(x)


In [None]:
import numpy as np
import torch

class Dict(dict):
    """dict that loads one section of a ConfigParser-like object, with attribute access.

    NOTE(review): this class shadows `typing.Dict` imported earlier in the
    notebook; any later `Dict[...]` annotation will refer to this class.

    Args:
        config: object with an `items(section_name)` method that yields
            (key, value) string pairs, e.g. `configparser.ConfigParser`.
        section_name: section to load.
        location: if True, store the raw strings; otherwise each value is
            passed through `eval`.
    """

    def __init__(self, config, section_name, location=False):
        super().__init__()
        self.initialize(config, section_name, location)

    def initialize(self, config, section_name, location):
        for key, value in config.items(section_name):
            if location:
                self[key] = value
            else:
                # SECURITY: `eval` executes arbitrary code from the config
                # file; only use this with trusted configs.
                self[key] = eval(value)

    def __getattr__(self, val):
        # Attribute protocol requires AttributeError, so that hasattr(),
        # getattr(obj, name, default) and copy/pickle work on this class.
        try:
            return self[val]
        except KeyError:
            raise AttributeError(val) from None

def make_transition(state, action, reward, next_state, done, log_prob=None):
    """Packs one environment step into the dict layout `ReplayBuffer.put_data` reads."""
    return {
        'state': state,
        'action': action,
        'reward': reward,
        'next_state': next_state,
        'log_prob': log_prob,
        'done': done,
    }

def make_mini_batch(*value):
    """Yields shuffled mini-batches drawn from parallel arrays.

    Args:
        *value: the first item is the mini-batch size; the remaining items are
            indexable arrays/tensors sharing the leading length of the first one.

    Yields:
        A list with one slice per input array, all indexed by the same random
        indices. Trailing samples that don't fill a whole mini-batch are dropped.
    """
    batch_size = value[0]
    arrays = value[1:]
    # Legacy global RNG: permutation(n) == shuffled arange(n).
    order = np.random.permutation(len(arrays[0]))
    for batch_idx in range(len(order) // batch_size):
        picked = order[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        yield [arr[picked] for arr in arrays]

def convert_to_tensor(*value):
    """Converts each array after the first argument (a torch device) to a float32 tensor on that device."""
    device = value[0]
    tensors = []
    for array in value[1:]:
        tensors.append(torch.tensor(array).float().to(device))
    return tensors

class ReplayBuffer():
    """Fixed-size circular buffer of transitions stored as numpy arrays.

    Args:
        action_prob_exist: if True, also store a 'log_prob' column.
        max_size: capacity; once full, the oldest rows are overwritten.
        state_dim: width of the 'state' / 'next_state' rows.
        num_action: width of the 'action' rows.
    """

    def __init__(self, action_prob_exist, max_size, state_dim, num_action):
        self.max_size = max_size
        self.data_idx = 0
        self.action_prob_exist = action_prob_exist
        self.data = {
            'state': np.zeros((self.max_size, state_dim)),
            'action': np.zeros((self.max_size, num_action)),
            'reward': np.zeros((self.max_size, 1)),
            'next_state': np.zeros((self.max_size, state_dim)),
            'done': np.zeros((self.max_size, 1)),
        }
        if self.action_prob_exist:
            self.data['log_prob'] = np.zeros((self.max_size, 1))

    def put_data(self, transition):
        """Writes one transition dict (see `make_transition`) into the next circular slot."""
        idx = self.data_idx % self.max_size
        self.data['state'][idx] = transition['state']
        self.data['action'][idx] = transition['action']
        self.data['reward'][idx] = transition['reward']
        self.data['next_state'][idx] = transition['next_state']
        self.data['done'][idx] = float(transition['done'])
        if self.action_prob_exist:
            self.data['log_prob'][idx] = transition['log_prob']
        self.data_idx += 1

    def sample(self, shuffle, batch_size=None):
        """Returns stored transitions as a dict of arrays.

        shuffle=True: `batch_size` rows drawn without replacement from the
            filled part of the buffer (copies).
        shuffle=False: every filled row in storage order (views into the
            buffer). Once the buffer has wrapped, storage order is circular.
        """
        filled = self.size()
        if shuffle:
            rand_idx = np.random.choice(filled, batch_size, replace=False)
            return {key: column[rand_idx] for key, column in self.data.items()}
        # Only return rows that were actually written; before the buffer is
        # full, the remaining rows are zero padding and must not be trained on.
        return {key: column[:filled] for key, column in self.data.items()}

    def size(self):
        """Number of valid (written) rows, capped at max_size."""
        return min(self.max_size, self.data_idx)
class RunningMeanStd:
    """Tracks a running mean/variance by merging per-batch moments."""

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        # Tiny pseudo-count so the first merge never divides by zero.
        self.count = epsilon

    def update(self, x):
        """Merges the moments of batch `x` (samples along axis 0)."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
        self.mean, self.var, self.count = merged


def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Combines two sets of (mean, population variance, count) into one."""
    delta = batch_mean - mean
    total = count + batch_count
    merged_mean = mean + delta * batch_count / total
    # Sum of squared deviations of both sets, plus the correction for the
    # shift between their means.
    m2 = var * count + batch_var * batch_count + np.square(delta) * count * batch_count / total
    return merged_mean, m2 / total, total

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    """PPO agent with separate actor/critic MLPs, GAE, and clipped objectives.

    Args:
        writer: object with `add_scalar(tag, value, step)` (e.g. a TensorBoard
            SummaryWriter), or None to disable logging.
        device: torch device used for the training tensors.
        state_dim: observation size.
        action_dim: action size.
        args: hyperparameter container. Attributes read by this class:
            - traj_length, layer_num, hidden_dim
            - activation_function, last_activation, trainable_std
            - actor_lr, critic_lr, gamma, lambda_
            - train_epoch, batch_size, entropy_coef
            - max_clip, critic_coef, max_grad_norm
    """

    def __init__(self, writer, device, state_dim, action_dim, args):
        super().__init__()
        self.args = args

        # On-policy rollout storage sized to one trajectory of `traj_length` steps.
        self.data = ReplayBuffer(action_prob_exist=True, max_size=self.args.traj_length,
                                 state_dim=state_dim, num_action=action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation,
                           self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1, self.args.hidden_dim,
                             self.args.activation_function, self.args.last_activation)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        self.device = device

    def get_action(self, x):
        """Returns the policy's Gaussian parameters (mu, sigma) for observations `x`."""
        mu, sigma = self.actor(x)
        return mu, sigma

    def v(self, x):
        """Returns the critic's state-value estimate for `x`."""
        return self.critic(x)

    def put_data(self, transition):
        """Stores one transition dict in the rollout buffer."""
        self.data.put_data(transition)

    def get_gae(self, states, rewards, next_states, dones):
        """Computes Generalized Advantage Estimation over a time-ordered rollout.

        Returns:
            (values, advantages), both shaped like `rewards` (T, 1) and detached.
        """
        values = self.v(states).detach()  # (T, 1)
        next_values = self.v(next_states).detach()  # (T, 1)

        # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        td_errors = rewards + self.args.gamma * next_values * (1 - dones) - values  # (T, 1)

        advantages = torch.zeros_like(rewards).to(self.device)  # (T, 1)
        last_gae_lam = 0.0

        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}.
        # (1 - dones[t]) cuts the recursion at terminal steps.
        for t in reversed(range(len(td_errors))):
            advantages[t] = td_errors[t] + self.args.gamma * self.args.lambda_ * (1 - dones[t]) * last_gae_lam
            last_gae_lam = advantages[t]

        return values, advantages

    def train_net(self, n_epi):
        """Runs `train_epoch` passes of clipped PPO updates over the buffered rollout.

        Args:
            n_epi: step index passed to the writer for logging.
        """
        data = self.data.sample(shuffle=False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(
            self.device, data['state'], data['action'], data['reward'],
            data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        # Value targets use the un-normalised advantages.
        returns = advantages + old_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-3)

        for _ in range(self.args.train_epoch):
            for state, action, old_log_prob, advantage, return_, old_value in make_mini_batch(
                    self.args.batch_size, states, actions, old_log_probs,
                    advantages, returns, old_values):
                curr_mu, curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
                # Per-dimension entropy (B, action_dim); it broadcasts against
                # the (B, 1) surrogate before the mean below.
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)

                # Policy clipping.
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.args.max_clip, 1 + self.args.max_clip) * advantage
                actor_loss = (-torch.min(surr1, surr2) - entropy).mean()

                # Value clipping (PPO2 technique).
                target = return_.detach().float()
                old_value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip, self.args.max_clip)
                value_loss = (value - target).pow(2)
                value_loss_clipped = (old_value_clipped - target).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss, value_loss_clipped).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                if self.writer is not None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)


## PPO - training loop

In [None]:
# Import necessary libraries
import jax
import jax.numpy as jnp
from brax import envs
import torch
import numpy as np

# --- 1. Define Hyperparameters for the PyTorch PPO Agent ---
# Plain attribute container for every knob the PPO agent and training loop read.
class PPOConfig:
    """Hyper-parameters for the PyTorch PPO agent.

    Every value becomes a plain instance attribute, set in the order listed.
    """

    def __init__(self):
        self.__dict__.update({
            # Steps collected before each PPO update.
            'traj_length': 2048,
            # Actor / critic network architecture.
            'layer_num': 2,
            'hidden_dim': 256,
            'activation_function': torch.tanh,
            'last_activation': None,
            'trainable_std': True,
            # Optimizer learning rates.
            'actor_lr': 3e-4,
            'critic_lr': 3e-4,
            # Update schedule: PPO epochs per rollout and minibatch size.
            'train_epoch': 1,
            'batch_size': 64,
            # Discount factor and GAE lambda.
            'gamma': 0.99,
            'lambda_': 0.95,
            # Clip range (used for both policy ratio and value), loss weights, grad clipping.
            'max_clip': 0.2,
            'critic_coef': 0.5,
            'entropy_coef': 0.01,
            'max_grad_norm': 0.5,
        })

ppo_torch_args = PPOConfig()

# --- 2. Initialize Brax Environment (JAX) ---
env_name = 'humanoid'
# Built without an episode_length argument (an earlier version passed episode_length=1000).
env = envs.get_environment(env_name)
# Wrap reset/step with jax.jit (compiled on first call, then reused).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

# --- DEBUG: Print environment action size ---
print(f"DEBUG: env.observation_size: {env.observation_size}")
print(f"DEBUG: env.action_size: {env.action_size}")
print(f"DEBUG: env.sys.nu (num actuators): {env.sys.nu}")

# --- 3. Initialize PyTorch PPO Agent ---
# Prefer the GPU for the PyTorch networks whenever one is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using PyTorch device: {device}")
# Dummy writer for now, as it's not provided in the notebook context.
class DummyWriter:
    """No-op logger exposing an ``add_scalar`` method; every call is discarded."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Accept and ignore one scalar logging call."""
        return None


writer = DummyWriter()

# Build the agent from the config and the environment's dimensions. nn.Module.to()
# moves the parameters to `device` and returns the same module object.
ppo_torch_agent = PPO(
    writer=writer,
    device=device,
    state_dim=env.observation_size,
    action_dim=env.action_size,
    args=ppo_torch_args,
).to(device)

print("PyTorch PPO agent initialized.")

# --- 4. Main Training Loop ---
# PPO is on-policy: transitions are accumulated ACROSS episodes until the buffer
# holds `traj_length` steps, then one update runs and the buffer is emptied.
# (Previously the buffer was recreated at the start of every episode, and an update
# only ran after a single episode that itself lasted `traj_length` steps. Because
# episodes terminate early, the agent effectively never trained.)
num_episodes = 300                               # number of episodes to run
max_episode_steps = ppo_torch_args.traj_length   # per-episode step cap
current_total_steps = 0
episode_count = 0
num_updates = 0
rng = jax.random.PRNGKey(0)  # JAX RNG for environment resets


def new_rollout_buffer():
    """Return an empty on-policy buffer sized for exactly one PPO update."""
    return ReplayBuffer(
        action_prob_exist=True,
        max_size=ppo_torch_args.traj_length,
        state_dim=env.observation_size,
        num_action=env.action_size,
    )


print("Starting PyTorch PPO training loop...")
ppo_torch_agent.data = new_rollout_buffer()

while episode_count < num_episodes:
    episode_count += 1
    rng, reset_rng = jax.random.split(rng)
    env_state = jit_reset(reset_rng)

    episode_reward = 0.0
    episode_north_reward = 0.0
    episode_healthy_reward = 0.0
    episode_ctrl_cost = 0.0
    episode_sideways_cost = 0.0

    for t in range(max_episode_steps):
        # JAX observation -> PyTorch tensor with a leading batch dimension of 1.
        obs_torch = torch.from_numpy(np.array(env_state.obs)).float().to(device).unsqueeze(0)

        # Sample a stochastic action and keep its log-probability for the PPO ratio.
        with torch.no_grad():
            mu, sigma = ppo_torch_agent.get_action(obs_torch)
            action_dist = torch.distributions.Normal(mu, sigma)
            action_torch = action_dist.sample()
            log_prob_torch = action_dist.log_prob(action_torch).sum(dim=-1, keepdim=True)

        # Drop the batch dimension; the same array feeds the env and the buffer.
        action_np = action_torch.squeeze(0).cpu().numpy()
        next_env_state = jit_step(env_state, jnp.asarray(action_np))  # Brax step takes no RNG

        done_np = np.array(next_env_state.done).reshape(1)
        if t == max_episode_steps - 1:
            # Time-limit truncation is stored as terminal so GAE does not bleed into the
            # next episode's transitions, which follow in the same buffer.
            done_np = np.ones_like(done_np)

        ppo_torch_agent.put_data(make_transition(
            obs_torch.squeeze(0).cpu().numpy(),
            action_np,
            np.array(next_env_state.reward).reshape(1),
            np.array(next_env_state.obs),
            done_np,
            log_prob_torch.squeeze(0).cpu().numpy(),
        ))

        metrics = next_env_state.metrics
        episode_reward += float(next_env_state.reward)
        episode_north_reward += float(metrics['north_reward'])
        episode_healthy_reward += float(metrics['reward_alive'])
        episode_ctrl_cost += -float(metrics['reward_quadctrl'])  # ctrl_cost is stored as negative in metrics
        episode_sideways_cost += float(metrics['sideways_cost'])

        env_state = next_env_state
        current_total_steps += 1

        # Update as soon as a full rollout is available, even mid-episode. Unless marked
        # done, the last stored transition is bootstrapped from V(next_state) in get_gae.
        if ppo_torch_agent.data.size() >= ppo_torch_args.traj_length:
            ppo_torch_agent.train_net(episode_count)
            num_updates += 1
            print(f"PPO update {num_updates} done during episode {episode_count} (total steps: {current_total_steps}).")
            ppo_torch_agent.data = new_rollout_buffer()

        if env_state.done:
            print(f"Episode {episode_count} finished early at step {t+1}. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")
            print(f"  Reward Components: North Reward: {episode_north_reward:.2f}, Healthy Reward: {episode_healthy_reward:.2f}, Ctrl Cost: {episode_ctrl_cost:.2f}, Sideways Cost: {episode_sideways_cost:.2f}")
            break
    else:  # episode hit the step cap without terminating
        print(f"Episode {episode_count} completed {max_episode_steps} steps. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")
        print(f"  Reward Components: North Reward: {episode_north_reward:.2f}, Healthy Reward: {episode_healthy_reward:.2f}, Ctrl Cost: {episode_ctrl_cost:.2f}, Sideways Cost: {episode_sideways_cost:.2f}")


print("\nPyTorch PPO training finished.")

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
print("\nGenerating rollout with trained PyTorch PPO policy...")
# Fresh environment instance plus jitted reset/step dedicated to evaluation.
eval_env = envs.get_environment(env_name)
jit_reset_eval, jit_step_eval = jax.jit(eval_env.reset), jax.jit(eval_env.step)

rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
rollout_eval_torch = [eval_state.pipeline_state]  # states collected for rendering

n_steps_eval, render_every = 500, 2

for i in range(n_steps_eval):
    # Deterministic policy: act with the Gaussian mean and ignore sigma.
    obs_torch_eval = torch.from_numpy(np.array(eval_state.obs)).float().to(device).unsqueeze(0)
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
    action_torch_eval = mu_eval

    # Remove the batch dimension before handing the action to Brax.
    action_jax_eval = jnp.asarray(action_torch_eval.squeeze(0).cpu().numpy())

    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    if eval_state.done:
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from PyTorch PPO trained policy...")
frames = eval_env.render(rollout_eval_torch[::render_every])
media.show_video(frames, fps=1.0 / eval_env.dt / render_every)


In [None]:

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
# Deterministic evaluation rollout with a one-time shape check of the action handed
# to Brax. The check used to print two DEBUG lines on every step, which flooded the
# output with up to 2 * n_steps_eval lines.
print("\nGenerating rollout with trained PyTorch PPO policy...")
eval_env = envs.get_environment(env_name)  # fresh environment instance for evaluation
jit_reset_eval = jax.jit(eval_env.reset)
jit_step_eval = jax.jit(eval_env.step)

rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
rollout_eval_torch = [eval_state.pipeline_state]

n_steps_eval = 500
render_every = 2  # render every 2nd collected state

for i in range(n_steps_eval):
    # Convert JAX observation to PyTorch tensor and add batch dimension
    obs_torch_eval = torch.from_numpy(np.array(eval_state.obs)).float().to(device).unsqueeze(0)

    # Deterministic policy: use the Gaussian mean, ignore sigma.
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
        action_torch_eval = mu_eval

    # Remove the batch dimension before stepping the environment.
    action_jax_eval = jnp.asarray(action_torch_eval.squeeze(0).cpu().numpy())

    if i == 0:
        # --- DEBUG: observation shape is fixed, so the action shape only needs checking once ---
        print(f"DEBUG: action_torch_eval shape: {action_torch_eval.shape}")
        print(f"DEBUG: action_jax_eval shape (after squeeze): {action_jax_eval.shape}")
        print(f"DEBUG: expected env.action_size: {eval_env.action_size}")

    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    if eval_state.done:
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from PyTorch PPO trained policy...")
media.show_video(eval_env.render(rollout_eval_torch[::render_every]), fps=1.0 / eval_env.dt / render_every)

# Task
I will proceed with the following steps:

1.  **Modify `PPO.train_net` (cell `SYyu-Ijsdp_q`)**:
    *   Separate the `actor_loss` into `policy_loss_raw` (policy gradient part) and `entropy_loss_term` (entropy regularization part).
    *   Ensure `critic_loss` (value loss part) is correctly captured.
    *   Modify the `train_net` method to return these three average loss values: `policy_loss_raw.item()`, `entropy_loss_term.item()`, and `critic_loss.item()`.

2.  **Modify the training loop (cell `uvqNqIbid8KS`)**:
    *   Add `import csv` and `from datetime import datetime`.
    *   Generate a unique CSV filename using the current timestamp.
    *   Open the CSV file and write the header row, including: `episode`, `total_reward`, `north_reward`, `healthy_reward`, `ctrl_cost`, `sideways_cost`, `policy_gradient_loss`, `entropy_loss`, `value_loss`, and `total_steps`.
    *   In the training loop, after `ppo_torch_agent.train_net` is called, capture the returned `policy_loss`, `entropy_loss`, and `value_loss`.
    *   Append a new row to the CSV file with the `episode_count`, `episode_reward`, component rewards, the captured detailed loss values, and `current_total_steps`.
    *   Make sure to handle the `ctrl_cost` correctly, as it's stored as negative in metrics.

This will ensure the detailed loss components are captured and logged for each training episode.

```python
# cell_id: SYyu-Ijsdp_q

import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    """Proximal Policy Optimization agent with a Gaussian actor and a value critic.

    Args:
        writer: logger exposing ``add_scalar(tag, value, step)``, or ``None`` to disable logging.
        device: torch device used for tensors built in ``get_gae`` / ``train_net``.
        state_dim: observation dimensionality.
        action_dim: action dimensionality.
        args: hyper-parameter object (e.g. ``PPOConfig``) providing traj_length, layer_num,
            hidden_dim, activation_function, last_activation, trainable_std, actor_lr,
            critic_lr, train_epoch, batch_size, gamma, lambda_, max_clip, critic_coef,
            entropy_coef and max_grad_norm.
    """

    def __init__(self, writer, device, state_dim, action_dim, args):
        super(PPO, self).__init__()
        self.args = args

        # On-policy rollout storage, sized to one update's worth of transitions.
        self.data = ReplayBuffer(action_prob_exist=True, max_size=self.args.traj_length,
                                 state_dim=state_dim, num_action=action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation,
                           self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1, self.args.hidden_dim,
                             self.args.activation_function, self.args.last_activation)

        # Separate optimizers so actor and critic keep independent learning rates.
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        self.device = device

    def get_action(self, x):
        """Return ``(mu, sigma)`` of the actor's Gaussian policy for observation(s) ``x``."""
        mu, sigma = self.actor(x)
        return mu, sigma

    def v(self, x):
        """Return the critic's state-value estimate for ``x``."""
        return self.critic(x)

    def put_data(self, transition):
        """Append one transition to the rollout buffer."""
        self.data.put_data(transition)

    def get_gae(self, states, rewards, next_states, dones):
        """Compute Generalized Advantage Estimates over a time-ordered rollout.

        Inputs are assumed to be column tensors of shape (T, 1) as produced from the
        buffer (TODO confirm against ``convert_to_tensor``).

        Returns:
            ``(values, advantages)`` where ``values`` is the detached V(states).
        """
        # Advantage targets need no gradients; skip building an autograd graph.
        with torch.no_grad():
            values = self.v(states).detach()  # (T, 1)
            next_values = self.v(next_states).detach()  # (T, 1)

        # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        td_errors = rewards + self.args.gamma * next_values * (1 - dones) - values  # (T, 1)

        advantages = torch.zeros_like(rewards).to(self.device)  # (T, 1)
        last_gae_lam = 0.0

        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
        for t in reversed(range(len(td_errors))):
            # (1 - dones[t]) cuts the recursion at episode boundaries.
            advantages[t] = td_errors[t] + self.args.gamma * self.args.lambda_ * (1 - dones[t]) * last_gae_lam
            last_gae_lam = advantages[t]

        return values, advantages

    def train_net(self, n_epi):
        """Run ``train_epoch`` passes of clipped-PPO minibatch updates over the buffer.

        Args:
            n_epi: step index passed to the writer when logging.

        Returns:
            ``(policy_loss, entropy_loss, critic_loss)`` averaged over every minibatch
            update; each is 0.0 when no minibatch was produced.
        """
        data = self.data.sample(shuffle=False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(
            self.device, data['state'], data['action'], data['reward'],
            data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        # Value targets use the raw advantages; only the policy term sees normalized ones.
        returns = advantages + old_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-3)

        policy_losses, entropy_losses, critic_losses = [], [], []

        for _ in range(self.args.train_epoch):
            minibatches = make_mini_batch(self.args.batch_size, states, actions,
                                          old_log_probs, advantages, returns, old_values)
            for state, action, old_log_prob, advantage, return_, old_value in minibatches:
                curr_mu, curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)

                # Clipped surrogate policy objective.
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.args.max_clip, 1 + self.args.max_clip) * advantage

                policy_loss_raw = (-torch.min(surr1, surr2)).mean()  # policy-gradient part
                entropy_loss_term = (-entropy).mean()  # entropy regularization part
                actor_loss = policy_loss_raw + entropy_loss_term

                # Value clipping (PPO2 technique): penalize the worse of clipped/unclipped error.
                old_value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip, self.args.max_clip)
                value_loss = (value - return_.detach().float()).pow(2)
                value_loss_clipped = (old_value_clipped - return_.detach().float()).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss, value_loss_clipped).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                policy_losses.append(policy_loss_raw.item())
                entropy_losses.append(entropy_loss_term.item())
                critic_losses.append(critic_loss.item())

                if self.writer is not None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)

        def _mean(values_list):
            # Guard against an empty list (e.g. no minibatch was yielded).
            return sum(values_list) / len(values_list) if values_list else 0.0

        return _mean(policy_losses), _mean(entropy_losses), _mean(critic_losses)

```

```python
# cell_id: uvqNqIbid8KS

# Import necessary libraries
import jax
import jax.numpy as jnp
from brax import envs
import torch
import numpy as np
import csv # Added for CSV logging
from datetime import datetime # Added for unique filename

# --- 1. Define Hyperparameters for the PyTorch PPO Agent ---
# Plain attribute container for every knob the PPO agent and training loop read.
class PPOConfig:
    """Hyper-parameters for the PyTorch PPO agent.

    Every value becomes a plain instance attribute, set in the order listed.
    """

    def __init__(self):
        self.__dict__.update({
            # Steps collected before each PPO update.
            'traj_length': 2048,
            # Actor / critic network architecture.
            'layer_num': 2,
            'hidden_dim': 256,
            'activation_function': torch.tanh,
            'last_activation': None,
            'trainable_std': True,
            # Optimizer learning rates.
            'actor_lr': 3e-4,
            'critic_lr': 3e-4,
            # Update schedule: PPO epochs per rollout and minibatch size.
            'train_epoch': 1,
            'batch_size': 64,
            # Discount factor and GAE lambda.
            'gamma': 0.99,
            'lambda_': 0.95,
            # Clip range (used for both policy ratio and value), loss weights, grad clipping.
            'max_clip': 0.2,
            'critic_coef': 0.5,
            'entropy_coef': 0.01,
            'max_grad_norm': 0.5,
        })

ppo_torch_args = PPOConfig()

# --- 2. Initialize Brax Environment (JAX) ---
env_name = 'humanoid'
# Built without an episode_length argument (an earlier version passed episode_length=1000).
env = envs.get_environment(env_name)
# Wrap reset/step with jax.jit (compiled on first call, then reused).
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

# --- DEBUG: Print environment action size ---
print(f"DEBUG: env.observation_size: {env.observation_size}")
print(f"DEBUG: env.action_size: {env.action_size}")
print(f"DEBUG: env.sys.nu (num actuators): {env.sys.nu}")

# --- 3. Initialize PyTorch PPO Agent ---
# Prefer the GPU for the PyTorch networks whenever one is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using PyTorch device: {device}")

# Dummy writer for now, as it's not provided in the notebook context.
class DummyWriter:
    """No-op logger exposing an ``add_scalar`` method; every call is discarded."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Accept and ignore one scalar logging call."""
        return None


writer = DummyWriter()

# Build the agent from the config and the environment's dimensions. nn.Module.to()
# moves the parameters to `device` and returns the same module object.
ppo_torch_agent = PPO(
    writer=writer,
    device=device,
    state_dim=env.observation_size,
    action_dim=env.action_size,
    args=ppo_torch_args,
).to(device)

print("PyTorch PPO agent initialized.")

# --- CSV Logging Setup ---
# One row per episode. The loss columns hold the mean over the PPO updates that ran
# during that episode, or NaN when none completed in it. Updates happen every
# `traj_length` collected steps, which can span several episodes.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_filename = f"ppo_training_log_{timestamp}.csv"
header = [
    'episode', 'total_reward', 'north_reward', 'healthy_reward',
    'ctrl_cost', 'sideways_cost', 'policy_gradient_loss',
    'entropy_loss', 'value_loss', 'total_steps'
]

# --- 4. Main Training Loop ---
# PPO is on-policy: transitions are accumulated ACROSS episodes until the buffer
# holds `traj_length` steps, then one update runs and the buffer is emptied.
# (Previously the buffer was recreated at the start of every episode, and an update
# only ran after a single episode that itself lasted `traj_length` steps. Because
# episodes terminate early, the agent effectively never trained.)
num_episodes = 300                               # number of episodes to run
max_episode_steps = ppo_torch_args.traj_length   # per-episode step cap
current_total_steps = 0
episode_count = 0
num_updates = 0
rng = jax.random.PRNGKey(0)  # JAX RNG for environment resets


def new_rollout_buffer():
    """Return an empty on-policy buffer sized for exactly one PPO update."""
    return ReplayBuffer(
        action_prob_exist=True,
        max_size=ppo_torch_args.traj_length,
        state_dim=env.observation_size,
        num_action=env.action_size,
    )


print("Starting PyTorch PPO training loop...")
ppo_torch_agent.data = new_rollout_buffer()

# `with` closes the CSV file even if training raises or is interrupted.
with open(csv_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)
    print(f"Logging training data to {csv_filename}")

    while episode_count < num_episodes:
        episode_count += 1
        rng, reset_rng = jax.random.split(rng)
        env_state = jit_reset(reset_rng)

        episode_reward = 0.0
        episode_north_reward = 0.0
        episode_healthy_reward = 0.0
        episode_ctrl_cost = 0.0
        episode_sideways_cost = 0.0
        episode_losses = []  # (policy, entropy, value) loss tuples from updates in this episode

        for t in range(max_episode_steps):
            # JAX observation -> PyTorch tensor with a leading batch dimension of 1.
            obs_torch = torch.from_numpy(np.array(env_state.obs)).float().to(device).unsqueeze(0)

            # Sample a stochastic action and keep its log-probability for the PPO ratio.
            with torch.no_grad():
                mu, sigma = ppo_torch_agent.get_action(obs_torch)
                action_dist = torch.distributions.Normal(mu, sigma)
                action_torch = action_dist.sample()
                log_prob_torch = action_dist.log_prob(action_torch).sum(dim=-1, keepdim=True)

            # Drop the batch dimension; the same array feeds the env and the buffer.
            action_np = action_torch.squeeze(0).cpu().numpy()
            next_env_state = jit_step(env_state, jnp.asarray(action_np))  # Brax step takes no RNG

            done_np = np.array(next_env_state.done).reshape(1)
            if t == max_episode_steps - 1:
                # Time-limit truncation is stored as terminal so GAE does not bleed into
                # the next episode's transitions, which follow in the same buffer.
                done_np = np.ones_like(done_np)

            ppo_torch_agent.put_data(make_transition(
                obs_torch.squeeze(0).cpu().numpy(),
                action_np,
                np.array(next_env_state.reward).reshape(1),
                np.array(next_env_state.obs),
                done_np,
                log_prob_torch.squeeze(0).cpu().numpy(),
            ))

            metrics = next_env_state.metrics
            episode_reward += float(next_env_state.reward)
            episode_north_reward += float(metrics['north_reward'])
            episode_healthy_reward += float(metrics['reward_alive'])
            episode_ctrl_cost += -float(metrics['reward_quadctrl'])  # ctrl_cost is stored as negative in metrics
            episode_sideways_cost += float(metrics['sideways_cost'])

            env_state = next_env_state
            current_total_steps += 1

            # Update as soon as a full rollout is available, even mid-episode.
            if ppo_torch_agent.data.size() >= ppo_torch_args.traj_length:
                policy_pg_loss, entropy_term_loss, value_fn_loss = ppo_torch_agent.train_net(episode_count)
                episode_losses.append((policy_pg_loss, entropy_term_loss, value_fn_loss))
                num_updates += 1
                print(f"PPO update {num_updates} during episode {episode_count}. Losses: PG={policy_pg_loss:.4f}, Ent={entropy_term_loss:.4f}, Val={value_fn_loss:.4f}")
                ppo_torch_agent.data = new_rollout_buffer()

            if env_state.done:
                print(f"Episode {episode_count} finished early at step {t+1}. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")
                print(f"  Reward Components: North Reward: {episode_north_reward:.2f}, Healthy Reward: {episode_healthy_reward:.2f}, Ctrl Cost: {episode_ctrl_cost:.2f}, Sideways Cost: {episode_sideways_cost:.2f}")
                break
        else:  # episode hit the step cap without terminating
            print(f"Episode {episode_count} completed {max_episode_steps} steps. Total Reward: {episode_reward:.2f}. Total steps: {current_total_steps}")
            print(f"  Reward Components: North Reward: {episode_north_reward:.2f}, Healthy Reward: {episode_healthy_reward:.2f}, Ctrl Cost: {episode_ctrl_cost:.2f}, Sideways Cost: {episode_sideways_cost:.2f}")

        if episode_losses:
            policy_pg_loss, entropy_term_loss, value_fn_loss = (
                float(np.mean(column)) for column in zip(*episode_losses)
            )
        else:
            policy_pg_loss = entropy_term_loss = value_fn_loss = float('nan')

        # --- Log to CSV ---
        csv_writer.writerow([
            episode_count,
            episode_reward,
            episode_north_reward,
            episode_healthy_reward,
            episode_ctrl_cost,
            episode_sideways_cost,
            policy_pg_loss,
            entropy_term_loss,
            value_fn_loss,
            current_total_steps,
        ])
        csv_file.flush()  # make each row visible on disk immediately

print(f"Training data saved to {csv_filename}")

print("\nPyTorch PPO training finished.")

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
print("\nGenerating rollout with trained PyTorch PPO policy...")
# Fresh environment instance plus jitted reset/step dedicated to evaluation.
eval_env = envs.get_environment(env_name)
jit_reset_eval, jit_step_eval = jax.jit(eval_env.reset), jax.jit(eval_env.step)

rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
rollout_eval_torch = [eval_state.pipeline_state]  # states collected for rendering

n_steps_eval, render_every = 500, 2

for i in range(n_steps_eval):
    # Deterministic policy: act with the Gaussian mean and ignore sigma.
    obs_torch_eval = torch.from_numpy(np.array(eval_state.obs)).float().to(device).unsqueeze(0)
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
    action_torch_eval = mu_eval

    # Remove the batch dimension before handing the action to Brax.
    action_jax_eval = jnp.asarray(action_torch_eval.squeeze(0).cpu().numpy())

    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    if eval_state.done:
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from PyTorch PPO trained policy...")
frames = eval_env.render(rollout_eval_torch[::render_every])
media.show_video(frames, fps=1.0 / eval_env.dt / render_every)
```

## Modify PPO.train_net to return detailed losses

### Subtask:
Adjust the `train_net` method in the `PPO` class (cell `SYyu-Ijsdp_q`) to calculate and return average `policy_loss_raw` (policy gradient part), `entropy_loss_term` (entropy regularization part), and the final `critic_loss` (value loss part, including clipping and scaling) over the training epochs.


**Reasoning**:
The subtask requires modifying the `train_net` method in the `PPO` class to calculate and return average detailed loss components. I will update the code in cell `SYyu-Ijsdp_q` to implement this functionality.



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    """Proximal Policy Optimization agent with a Gaussian actor and a value critic.

    Args:
        writer: logger exposing ``add_scalar(tag, value, step)``, or ``None`` to disable logging.
        device: torch device used for tensors built in ``get_gae`` / ``train_net``.
        state_dim: observation dimensionality.
        action_dim: action dimensionality.
        args: hyper-parameter object (e.g. ``PPOConfig``) providing traj_length, layer_num,
            hidden_dim, activation_function, last_activation, trainable_std, actor_lr,
            critic_lr, train_epoch, batch_size, gamma, lambda_, max_clip, critic_coef,
            entropy_coef and max_grad_norm.
    """

    def __init__(self, writer, device, state_dim, action_dim, args):
        super(PPO, self).__init__()
        self.args = args

        # On-policy rollout storage, sized to one update's worth of transitions.
        self.data = ReplayBuffer(action_prob_exist=True, max_size=self.args.traj_length,
                                 state_dim=state_dim, num_action=action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation,
                           self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1, self.args.hidden_dim,
                             self.args.activation_function, self.args.last_activation)

        # Separate optimizers so actor and critic keep independent learning rates.
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        self.device = device

    def get_action(self, x):
        """Return ``(mu, sigma)`` of the actor's Gaussian policy for observation(s) ``x``."""
        mu, sigma = self.actor(x)
        return mu, sigma

    def v(self, x):
        """Return the critic's state-value estimate for ``x``."""
        return self.critic(x)

    def put_data(self, transition):
        """Append one transition to the rollout buffer."""
        self.data.put_data(transition)

    def get_gae(self, states, rewards, next_states, dones):
        """Compute Generalized Advantage Estimates over a time-ordered rollout.

        Inputs are assumed to be column tensors of shape (T, 1) as produced from the
        buffer (TODO confirm against ``convert_to_tensor``).

        Returns:
            ``(values, advantages)`` where ``values`` is the detached V(states).
        """
        # Advantage targets need no gradients; skip building an autograd graph.
        with torch.no_grad():
            values = self.v(states).detach()  # (T, 1)
            next_values = self.v(next_states).detach()  # (T, 1)

        # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        td_errors = rewards + self.args.gamma * next_values * (1 - dones) - values  # (T, 1)

        advantages = torch.zeros_like(rewards).to(self.device)  # (T, 1)
        last_gae_lam = 0.0

        # Backward recursion: A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}
        for t in reversed(range(len(td_errors))):
            # (1 - dones[t]) cuts the recursion at episode boundaries.
            advantages[t] = td_errors[t] + self.args.gamma * self.args.lambda_ * (1 - dones[t]) * last_gae_lam
            last_gae_lam = advantages[t]

        return values, advantages

    def train_net(self, n_epi):
        """Run ``train_epoch`` passes of clipped-PPO minibatch updates over the buffer.

        Args:
            n_epi: step index passed to the writer when logging.

        Returns:
            ``(policy_loss, entropy_loss, critic_loss)`` averaged over every minibatch
            update; each is 0.0 when no minibatch was produced (previously this case
            raised ZeroDivisionError).
        """
        data = self.data.sample(shuffle=False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(
            self.device, data['state'], data['action'], data['reward'],
            data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        # Value targets use the raw advantages; only the policy term sees normalized ones.
        returns = advantages + old_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-3)

        policy_losses, entropy_losses, critic_losses = [], [], []

        for _ in range(self.args.train_epoch):
            minibatches = make_mini_batch(self.args.batch_size, states, actions,
                                          old_log_probs, advantages, returns, old_values)
            for state, action, old_log_prob, advantage, return_, old_value in minibatches:
                curr_mu, curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)

                # Clipped surrogate policy objective.
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.args.max_clip, 1 + self.args.max_clip) * advantage

                policy_loss_raw = (-torch.min(surr1, surr2)).mean()  # policy-gradient part
                entropy_loss_term = (-entropy).mean()  # entropy regularization part
                actor_loss = policy_loss_raw + entropy_loss_term

                # Value clipping (PPO2 technique): penalize the worse of clipped/unclipped error.
                old_value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip, self.args.max_clip)
                value_loss = (value - return_.detach().float()).pow(2)
                value_loss_clipped = (old_value_clipped - return_.detach().float()).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss, value_loss_clipped).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                policy_losses.append(policy_loss_raw.item())
                entropy_losses.append(entropy_loss_term.item())
                critic_losses.append(critic_loss.item())

                if self.writer is not None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)

        def _mean(values_list):
            # Guard against an empty list (e.g. no minibatch was yielded).
            return sum(values_list) / len(values_list) if values_list else 0.0

        return _mean(policy_losses), _mean(entropy_losses), _mean(critic_losses)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    """PPO agent with separate actor / critic networks and clipped objectives.

    Transitions are accumulated in ``self.data`` (a ``ReplayBuffer``) via
    :meth:`put_data` and consumed by :meth:`train_net`, which runs
    ``args.train_epoch`` passes of minibatch updates over the rollout.

    Args:
        writer: object with ``add_scalar(tag, value, step)``, or None to
            disable scalar logging.
        device: torch device used for the GAE buffers.
        state_dim: observation size.
        action_dim: action size.
        args: hyperparameter container (see the ``PPOConfig`` cells). When it
            exposes ``num_envs``, the buffer is treated as ``num_envs``
            trajectories interleaved step by step (see :meth:`get_gae`).
    """

    def __init__(self, writer, device, state_dim, action_dim, args):
        super(PPO, self).__init__()
        self.args = args

        self.data = ReplayBuffer(action_prob_exist=True, max_size=self.args.traj_length,
                                 state_dim=state_dim, num_action=action_dim)
        self.actor = Actor(self.args.layer_num, state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation,
                           self.args.trainable_std)
        self.critic = Critic(self.args.layer_num, state_dim, 1, self.args.hidden_dim,
                             self.args.activation_function, self.args.last_activation)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.args.critic_lr)

        self.writer = writer
        self.device = device

    def get_action(self, x):
        """Return the policy's Gaussian parameters ``(mu, sigma)`` for states ``x``."""
        mu, sigma = self.actor(x)
        return mu, sigma

    def v(self, x):
        """Return the critic's state-value estimate for states ``x``."""
        return self.critic(x)

    def put_data(self, transition):
        """Append one transition to the rollout buffer."""
        self.data.put_data(transition)

    def get_gae(self, states, rewards, next_states, dones):
        """Compute critic values and GAE advantages for the buffered rollout.

        The training loops store one transition per environment per step, so
        the flat buffer is ordered ``[t0/env0, t0/env1, ..., t1/env0, ...]``.
        The GAE recursion must run along each environment's own time axis.
        Running it along the flat buffer would make env ``i`` bootstrap from
        env ``i + 1``'s TD error. When ``args.num_envs`` is absent or does not
        divide the buffer length, the buffer is treated as one trajectory.

        Args:
            states, next_states: batched states (assumed (T * N, state_dim)).
            rewards, dones: assumed (T * N, 1); ``dones`` must be numeric 0/1.

        Returns:
            ``(values, advantages)``; values are detached critic outputs.
        """
        values = self.v(states).detach()
        next_values = self.v(next_states).detach()

        not_done = 1 - dones
        # TD errors: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        td_errors = rewards + self.args.gamma * next_values * not_done - values

        num_envs = self._buffer_num_envs(td_errors.shape[0])
        advantages = self._gae_backward(
            td_errors, not_done, self.args.gamma * self.args.lambda_, num_envs)
        return values, advantages

    def _buffer_num_envs(self, total):
        """Number of interleaved trajectories in a flat buffer of ``total`` rows.

        Falls back to 1 (single trajectory) when ``args`` has no usable
        ``num_envs`` or it does not divide ``total`` evenly.
        """
        num_envs = getattr(self.args, "num_envs", 1) or 1
        if num_envs < 1 or total % num_envs != 0:
            return 1
        return int(num_envs)

    @staticmethod
    def _gae_backward(td_errors, not_done, decay, num_envs=1):
        """Backward recursion A_t = delta_t + decay * (1 - d_t) * A_{t+1}.

        ``td_errors`` / ``not_done`` are flat and time-major with ``num_envs``
        environments interleaved per step. The recursion loops over time only
        and is vectorised across environments; the result has the shape of
        ``td_errors``.
        """
        total = td_errors.shape[0]
        if total == 0:
            return torch.zeros_like(td_errors)
        steps = total // num_envs
        deltas = td_errors.reshape(steps, num_envs, -1)
        masks = not_done.reshape(steps, num_envs, -1)
        advantages = torch.zeros_like(deltas)
        running = torch.zeros_like(deltas[0])  # A_{t+1}; zero past the last step
        for t in reversed(range(steps)):
            # (1 - d_t) cuts the recursion at episode boundaries.
            running = deltas[t] + decay * masks[t] * running
            advantages[t] = running
        return advantages.reshape(td_errors.shape)

    def train_net(self, n_epi):
        """Run ``args.train_epoch`` epochs of clipped-PPO minibatch updates.

        Args:
            n_epi: step index used for the ``writer`` scalars.

        Returns:
            ``(avg_policy_loss, avg_entropy_loss, avg_critic_loss)`` averaged
            over every minibatch of every epoch; each is NaN when no minibatch
            was produced.
        """
        data = self.data.sample(shuffle=False)
        states, actions, rewards, next_states, dones, old_log_probs = convert_to_tensor(
            self.device, data['state'], data['action'], data['reward'],
            data['next_state'], data['done'], data['log_prob'])

        old_values, advantages = self.get_gae(states, rewards, next_states, dones)
        # Critic targets use the raw advantages; only the policy sees normalised ones.
        returns = advantages + old_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-3)

        policy_losses_raw_epoch = []
        entropy_losses_term_epoch = []
        critic_losses_epoch = []

        for _ in range(self.args.train_epoch):
            for state, action, old_log_prob, advantage, return_, old_value in make_mini_batch(
                    self.args.batch_size, states, actions, old_log_probs,
                    advantages, returns, old_values):
                curr_mu, curr_sigma = self.get_action(state)
                value = self.v(state).float()
                curr_dist = torch.distributions.Normal(curr_mu, curr_sigma)
                entropy = curr_dist.entropy() * self.args.entropy_coef
                curr_log_prob = curr_dist.log_prob(action).sum(1, keepdim=True)

                # Clipped surrogate policy objective.
                ratio = torch.exp(curr_log_prob - old_log_prob.detach())
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.args.max_clip, 1 + self.args.max_clip) * advantage
                policy_loss_raw = (-torch.min(surr1, surr2)).mean()  # policy-gradient part
                entropy_loss_term = (-entropy).mean()  # entropy regularisation part
                actor_loss = policy_loss_raw + entropy_loss_term

                # Clipped value loss (PPO2 style); max_clip doubles as the value clip range.
                old_value_clipped = old_value + (value - old_value).clamp(-self.args.max_clip, self.args.max_clip)
                target = return_.detach().float()
                value_loss = (value - target).pow(2)
                value_loss_clipped = (old_value_clipped - target).pow(2)
                critic_loss = 0.5 * self.args.critic_coef * torch.max(value_loss, value_loss_clipped).mean()

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.max_grad_norm)
                self.critic_optimizer.step()

                policy_losses_raw_epoch.append(policy_loss_raw.item())
                entropy_losses_term_epoch.append(entropy_loss_term.item())
                critic_losses_epoch.append(critic_loss.item())

                if self.writer is not None:
                    self.writer.add_scalar("loss/actor_loss", actor_loss.item(), n_epi)
                    self.writer.add_scalar("loss/critic_loss", critic_loss.item(), n_epi)

        return (self._mean(policy_losses_raw_epoch),
                self._mean(entropy_losses_term_epoch),
                self._mean(critic_losses_epoch))

    @staticmethod
    def _mean(values):
        """Arithmetic mean of a list of floats, or NaN for an empty list."""
        return sum(values) / len(values) if values else float('nan')


# Train with parallel environments

In [None]:
# Import necessary libraries
import jax
import jax.numpy as jnp
from brax import envs
import torch
import numpy as np
import csv # Added for CSV logging
from datetime import datetime # Added for unique filename
from brax import envs

class ScaleRewardWrapper(envs.Wrapper):
    """Brax wrapper that multiplies every step reward by a constant factor.

    ``reset`` is passed through unchanged. ``step`` rescales only
    ``state.reward``; metrics and all other state fields are left as the
    wrapped environment produced them.

    Args:
        env: the Brax environment to wrap.
        scale: multiplier applied to each step reward (0.1 shrinks it).
    """

    def __init__(self, env, scale=0.1):
        super().__init__(env)
        self.scale = scale

    def reset(self, rng):
        return super().reset(rng)

    def step(self, state, action):
        stepped = super().step(state, action)
        # Scale (shrink) the reward; everything else is kept as-is.
        scaled_reward = stepped.reward * self.scale
        return stepped.replace(reward=scaled_reward)


# --- 1. Define Hyperparameters for the PyTorch PPO Agent ---
# Simple container for the PyTorch PPO hyperparameters.
class PPOConfig:
    """Hyperparameters for the PyTorch PPO agent and the parallel rollout.

    The defaults are the values used by this training cell. Any field can be
    overridden by keyword, e.g. ``PPOConfig(num_envs=1024, num_updates=10)``.
    Unknown keywords raise ``TypeError``, so a typo cannot silently fall back
    to the default.
    """

    def __init__(self, **overrides):
        self.traj_length = 64  # rollout steps per environment before each update
        self.layer_num = 2
        self.hidden_dim = 256
        self.activation_function = torch.tanh
        self.last_activation = None
        self.trainable_std = True
        self.actor_lr = 3e-4
        self.critic_lr = 1e-3
        self.train_epoch = 1  # passes over the collected rollout per update
        self.batch_size = 1024  # minibatch size for the PPO update
        self.gamma = 0.99
        self.lambda_ = 0.95
        self.max_clip = 0.2
        self.critic_coef = 0.5
        self.entropy_coef = 0.2
        self.max_grad_norm = 0.5
        self.num_envs = 6144  # number of parallel Brax environments
        self.num_updates = 100  # number of PPO updates (outer training iterations)

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f"Unknown PPOConfig field(s): {', '.join(sorted(unknown))}")
        for name, value in overrides.items():
            setattr(self, name, value)


ppo_torch_args = PPOConfig()

# --- 2. Initialize Brax Environment (JAX) ---
env_name = 'humanoid'

# One batched Brax env holding `num_envs` parallel copies; every step reward
# is multiplied by 0.1 through ScaleRewardWrapper.
env = ScaleRewardWrapper(
    envs.create(
        env_name=env_name,
        episode_length=2000,
        batch_size=ppo_torch_args.num_envs,
    ),
    scale=0.1,
)
# Compile reset/step of the batched env once.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

# --- DEBUG: environment sizes ---
print(f"DEBUG: env.observation_size: {env.observation_size}")
print(f"DEBUG: env.action_size: {env.action_size}")
print(f"DEBUG: env.sys.nu (num actuators): {env.sys.nu}")
print(f"DEBUG: Number of environments (batch_size): {env.batch_size}")

# --- 3. Initialize PyTorch PPO Agent ---
# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using PyTorch device: {device}")


class DummyWriter:
    """No-op stand-in for a TensorBoard SummaryWriter: every scalar is dropped."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        return None


writer = DummyWriter()

ppo_torch_agent = PPO(writer=writer, device=device,
                      state_dim=env.observation_size, action_dim=env.action_size,
                      args=ppo_torch_args)
ppo_torch_agent.to(device)
print("PyTorch PPO agent initialized.")

# --- CSV Logging Setup ---
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_filename = f"ppo_training_log_{timestamp}.csv"

# One row per PPO update; reward/cost columns are per-transition averages.
header = [
    'update_idx', 'avg_total_reward', 'avg_north_reward', 'avg_healthy_reward',
    'avg_ctrl_cost', 'avg_sideways_cost', 'policy_gradient_loss',
    'entropy_loss', 'value_loss', 'total_steps'
]

# `with` guarantees the log file is closed even if training raises midway.
with open(csv_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)
    print(f"Logging training data to {csv_filename}")
    # --- End CSV Logging Setup ---

    # --- 4. Main Training Loop ---
    current_total_steps = 0
    rng = jax.random.PRNGKey(0)

    print("Starting PyTorch PPO training loop...")

    # On-policy buffer: exactly one rollout of traj_length steps from every env.
    num_collected_transitions = ppo_torch_args.traj_length * ppo_torch_args.num_envs
    ppo_torch_agent.data = ReplayBuffer(
        action_prob_exist=True,
        max_size=num_collected_transitions,
        state_dim=env.observation_size,
        num_action=env.action_size
    )

    # Reset the batched env ONCE, before the first update. Brax's envs.create
    # enables auto-reset by default (auto_reset=True), so finished episodes
    # restart in place and rollouts continue across updates. Resetting inside
    # the update loop would restart every episode each update, and the policy
    # would never see states beyond the first `traj_length` steps.
    # NOTE(review): hitting episode_length is presumably reported as done and
    # is then treated as a true termination by GAE (no truncation bootstrap).
    rng, reset_rng = jax.random.split(rng)
    env_state = jit_reset(reset_rng)

    for update_idx in range(ppo_torch_args.num_updates):
        # Per-update sums over every (step, env) transition, used for logging.
        total_sum_rewards_collected = 0.0
        total_sum_north_rewards_collected = 0.0
        total_sum_healthy_rewards_collected = 0.0
        total_sum_ctrl_costs_collected = 0.0
        total_sum_sideways_costs_collected = 0.0

        for t in range(ppo_torch_args.traj_length):
            # Zero-copy JAX -> PyTorch transfer of the batched observation.
            obs_torch = torch.utils.dlpack.from_dlpack(env_state.obs).to(device)

            # Sample a stochastic action for every env from the current policy.
            with torch.no_grad():
                mu, sigma = ppo_torch_agent.get_action(obs_torch)
                action_dist = torch.distributions.Normal(mu, sigma)
                action_torch = action_dist.sample()
                log_prob_torch = action_dist.log_prob(action_torch).sum(dim=-1, keepdim=True)

            # Zero-copy PyTorch -> JAX transfer, then step every env at once.
            action_jax = jax.dlpack.from_dlpack(action_torch)
            next_env_state = jit_step(env_state, action_jax)

            current_obs_batch_np = obs_torch.cpu().numpy()
            current_action_batch_np = action_torch.cpu().numpy()
            current_log_prob_batch_np = log_prob_torch.cpu().numpy()

            next_obs_batch_np = np.array(next_env_state.obs)
            reward_batch_np = np.array(next_env_state.reward)
            done_batch_np = np.array(next_env_state.done)

            # Transitions are stored step-major: env 0..N-1 for step t, then step t+1.
            for env_idx in range(ppo_torch_args.num_envs):
                ppo_torch_agent.put_data(make_transition(
                    current_obs_batch_np[env_idx],
                    current_action_batch_np[env_idx],
                    reward_batch_np[env_idx].reshape(1),
                    next_obs_batch_np[env_idx],
                    done_batch_np[env_idx].reshape(1),
                    current_log_prob_batch_np[env_idx]
                ))

            # Vectorised metric sums (float64 accumulator to limit float32 drift).
            metrics = next_env_state.metrics
            total_sum_rewards_collected += float(np.sum(reward_batch_np, dtype=np.float64))
            total_sum_north_rewards_collected += float(np.sum(np.asarray(metrics['north_reward']), dtype=np.float64))
            total_sum_healthy_rewards_collected += float(np.sum(np.asarray(metrics['reward_alive']), dtype=np.float64))
            # ctrl cost is stored as a negative reward in the metrics.
            total_sum_ctrl_costs_collected -= float(np.sum(np.asarray(metrics['reward_quadctrl']), dtype=np.float64))
            total_sum_sideways_costs_collected += float(np.sum(np.asarray(metrics['sideways_cost']), dtype=np.float64))

            env_state = next_env_state
            current_total_steps += ppo_torch_args.num_envs  # one step per parallel env

        # Per-transition averages over this update's rollout.
        avg_total_reward = total_sum_rewards_collected / num_collected_transitions
        avg_north_reward = total_sum_north_rewards_collected / num_collected_transitions
        avg_healthy_reward = total_sum_healthy_rewards_collected / num_collected_transitions
        avg_ctrl_cost = total_sum_ctrl_costs_collected / num_collected_transitions
        avg_sideways_cost = total_sum_sideways_costs_collected / num_collected_transitions

        print(f"Update {update_idx + 1} completed. Avg Total Reward: {avg_total_reward:.2f}. Total steps: {current_total_steps}")
        print(f"  Avg Reward Components: North Reward: {avg_north_reward:.2f}, Healthy Reward: {avg_healthy_reward:.2f}, Ctrl Cost: {avg_ctrl_cost:.2f}, Sideways Cost: {avg_sideways_cost:.2f}")

        # Train once the full traj_length * num_envs rollout is buffered.
        if ppo_torch_agent.data.size() >= num_collected_transitions:
            policy_pg_loss, entropy_term_loss, value_fn_loss = ppo_torch_agent.train_net(update_idx)
            print(f"PPO agent trained for update {update_idx + 1}. Losses: PG={policy_pg_loss:.4f}, Ent={entropy_term_loss:.4f}, Val={value_fn_loss:.4f}")

            csv_writer.writerow([
                update_idx + 1,
                float(avg_total_reward),
                float(avg_north_reward),
                float(avg_healthy_reward),
                float(avg_ctrl_cost),
                float(avg_sideways_cost),
                policy_pg_loss,
                entropy_term_loss,
                value_fn_loss,
                current_total_steps
            ])
            csv_file.flush()  # keep the on-disk log current if the run is interrupted

            # On-policy: discard the rollout that was just trained on.
            ppo_torch_agent.data = ReplayBuffer(
                action_prob_exist=True,
                max_size=num_collected_transitions,
                state_dim=env.observation_size,
                num_action=env.action_size
            )

print(f"Training data saved to {csv_filename}")

print("\nPyTorch PPO training finished.")

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
# Roll out the trained policy deterministically (action = policy mean) in a
# separate evaluation env and render the trajectory as a video.
print("\nGenerating rollout with trained PyTorch PPO policy...")
# Evaluation env with batch_size=1. It is NOT wrapped in ScaleRewardWrapper,
# so its rewards are unscaled (they are not used below).
eval_env = envs.create(env_name=env_name, episode_length=2000, batch_size=1) # Create a single env for evaluation
jit_reset_eval = jax.jit(eval_env.reset)
jit_step_eval = jax.jit(eval_env.step)

# Derive the evaluation reset key from the running training RNG.
rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
# Physics states collected for rendering, starting with the reset state.
rollout_eval_torch = [eval_state.pipeline_state]

n_steps_eval = 1000  # upper bound on evaluation steps
render_every = 2  # keep every 2nd state as a video frame

for i in range(n_steps_eval):
    # Zero-copy JAX -> PyTorch transfer via DLPack.
    # NOTE(review): eval_env is batched (batch_size=1), so obs presumably already
    # has a leading batch dim; unsqueeze(0) adds a second one, which squeeze(0)
    # below removes again. Assumes the Actor accepts the extra leading dim.
    obs_torch_eval = torch.utils.dlpack.from_dlpack(eval_state.obs).unsqueeze(0).to(device)

    # Deterministic evaluation: use the policy mean and ignore sigma.
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
        action_torch_eval = mu_eval # Use mean for deterministic action

    # Zero-copy PyTorch -> JAX transfer; squeeze(0) undoes the unsqueeze above.
    action_jax_eval = jax.dlpack.from_dlpack(action_torch_eval.squeeze(0))

    # Step the JAX environment and record the new physics state for rendering.
    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    if eval_state.done[0]: # Check done for the first (and only) environment in eval batch
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from PyTorch PPO trained policy...")
# fps = control rate (1 / dt) divided by the frame subsampling factor.
# NOTE(review): `media` is not imported in this cell (presumably mediapy from an
# earlier cell). The collected pipeline states carry the batch_size=1 leading
# dim, so confirm eval_env.render accepts batched states.
media.show_video(eval_env.render(rollout_eval_torch[::render_every]), fps=1.0 / eval_env.dt / render_every)


# Save the trained model to disk

In [None]:
import torch

# Checkpoint path; the resume cell below loads from the same file name.
model_save_path = "trained_ppo_model.pth"

# Persist the module's state_dict (parameters/buffers only, not optimizer state).
torch.save(obj=ppo_torch_agent.state_dict(), f=model_save_path)
print(f"Trained PPO model saved to {model_save_path}")

In [None]:
# Import necessary libraries (assuming they are already imported in previous cells)
import jax
import jax.numpy as jnp
from brax import envs
import torch
import numpy as np
import csv # Added for CSV logging
from datetime import datetime # Added for unique filename

# --- 1. Define Hyperparameters for the PyTorch PPO Agent ---
# Re-defined here so this cell does not depend on the earlier training cell.
class PPOConfig:
    """Hyperparameters for the resumed PyTorch PPO training session.

    The defaults are this cell's values. Any field can be overridden by
    keyword, e.g. ``PPOConfig(num_updates=10)``. Unknown keywords raise
    ``TypeError``, so a typo cannot silently fall back to the default.
    """

    def __init__(self, **overrides):
        self.traj_length = 8  # rollout steps per environment before each update
        self.layer_num = 2
        self.hidden_dim = 256
        self.activation_function = torch.tanh
        self.last_activation = None
        self.trainable_std = True
        self.actor_lr = 3e-4
        self.critic_lr = 1e-3
        self.train_epoch = 4  # passes over the collected rollout per update
        self.batch_size = 1024  # minibatch size for the PPO update
        self.gamma = 0.99
        self.lambda_ = 0.95
        self.max_clip = 0.2
        self.critic_coef = 0.5
        self.entropy_coef = 0.2
        self.max_grad_norm = 0.5
        self.num_envs = 4096  # number of parallel Brax environments
        self.num_updates = 100  # number of new PPO updates

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f"Unknown PPOConfig field(s): {', '.join(sorted(unknown))}")
        for name, value in overrides.items():
            setattr(self, name, value)


ppo_torch_args = PPOConfig()

# --- 2. Initialize Brax Environment (JAX) ---
env_name = 'humanoid'


# Same reward-scaling wrapper as the training cell, redefined so this cell
# does not rely on that cell having run.
class ScaleRewardWrapper(envs.Wrapper):
    """Brax wrapper that multiplies every step reward by ``scale``.

    ``reset`` is passed through unchanged; ``step`` rescales only
    ``state.reward`` and leaves every other field as produced by ``env``.
    """

    def __init__(self, env, scale=0.1):
        super().__init__(env)
        self.scale = scale

    def reset(self, rng):
        return super().reset(rng)

    def step(self, state, action):
        stepped = super().step(state, action)
        return stepped.replace(reward=self.scale * stepped.reward)

# Batched env with `num_envs` copies and 500-step episodes; rewards are scaled
# by 0.1 for this training session.
env = ScaleRewardWrapper(
    envs.create(env_name=env_name, episode_length=500, batch_size=ppo_torch_args.num_envs),
    scale=0.1,
)
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

print(f"DEBUG: env.observation_size: {env.observation_size}")
print(f"DEBUG: env.action_size: {env.action_size}")
print(f"DEBUG: env.sys.nu (num actuators): {env.sys.nu}")
print(f"DEBUG: Number of environments (batch_size): {env.batch_size}")

# --- 3. Initialize PyTorch PPO Agent ---
# Use the GPU when PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using PyTorch device: {device}")


class DummyWriter:
    """No-op scalar logger used in place of a TensorBoard SummaryWriter."""

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        pass


writer = DummyWriter()

ppo_torch_agent = PPO(
    writer=writer,
    device=device,
    state_dim=env.observation_size,
    action_dim=env.action_size,
    args=ppo_torch_args
)
ppo_torch_agent.to(device)

# Resume from the checkpoint written by the "Save model" cell.
# map_location remaps tensors saved on a GPU runtime onto `device`. Without it,
# loading a CUDA checkpoint on a CPU-only runtime raises, and the generic
# handler below silently restarts training from scratch.
model_save_path = "trained_ppo_model.pth"
try:
    ppo_torch_agent.load_state_dict(torch.load(model_save_path, map_location=device))
    print(f"Successfully loaded model from {model_save_path}")
except FileNotFoundError:
    print(f"Warning: Model file not found at {model_save_path}. Starting training from scratch.")
except Exception as e:
    # Best-effort resume, e.g. a shape mismatch after changing hidden_dim lands here.
    print(f"Error loading model: {e}. Starting training from scratch.")

print("PyTorch PPO agent initialized (or loaded).")

# --- CSV Logging Setup for new training session ---
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_filename = f"ppo_training_log_resumed_{timestamp}.csv"

# One row per PPO update. Reward/cost columns are per-transition averages;
# avg_episode_length averages over episodes that ended during the update.
header = [
    'update_idx', 'avg_total_reward', 'avg_north_reward', 'avg_healthy_reward',
    'avg_ctrl_cost', 'avg_sideways_cost', 'avg_orientation_cost', 'avg_z_angular_cost',
    'policy_gradient_loss',
    'entropy_loss', 'value_loss', 'total_steps', 'avg_episode_length'
]

# `with` guarantees the log file is closed even if training raises midway.
with open(csv_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)
    print(f"Logging new training data to {csv_filename}")

    # --- 4. Main Training Loop ---
    current_total_steps = 0
    rng = jax.random.PRNGKey(0)  # Re-initialize RNG for this training session

    print("Starting new PyTorch PPO training loop...")

    # On-policy buffer: exactly one rollout of traj_length steps from every env.
    num_collected_transitions = ppo_torch_args.traj_length * ppo_torch_args.num_envs
    ppo_torch_agent.data = ReplayBuffer(
        action_prob_exist=True,
        max_size=num_collected_transitions,
        state_dim=env.observation_size,
        num_action=env.action_size
    )

    # Reset the batched env ONCE, before the first update. Brax's envs.create
    # enables auto-reset by default (auto_reset=True), so finished episodes
    # restart in place and rollouts continue across updates. Resetting inside
    # the update loop would restart every episode each update: the policy would
    # only see the first `traj_length` steps, and episode_step_counts would drift
    # out of sync with the envs.
    # NOTE(review): hitting episode_length is presumably reported as done and
    # is then treated as a true termination by GAE (no truncation bootstrap).
    rng, reset_rng = jax.random.split(rng)
    env_state = jit_reset(reset_rng)
    # Steps taken so far in each env's current episode (int32, one per env).
    episode_step_counts = jnp.zeros(ppo_torch_args.num_envs, dtype=jnp.int32)

    for update_idx in range(ppo_torch_args.num_updates):
        # Per-update sums over every (step, env) transition, used for logging.
        total_sum_rewards_collected = 0.0
        total_sum_north_rewards_collected = 0.0
        total_sum_healthy_rewards_collected = 0.0
        total_sum_ctrl_costs_collected = 0.0
        total_sum_sideways_costs_collected = 0.0
        total_sum_orientation_costs_collected = 0.0
        total_sum_z_angular_costs_collected = 0.0
        # Episode-length statistics for episodes that end during this update.
        total_episode_lengths_collected = 0
        terminated_episodes_count = 0

        for t in range(ppo_torch_args.traj_length):
            # Zero-copy JAX -> PyTorch transfer of the batched observation.
            obs_torch = torch.utils.dlpack.from_dlpack(env_state.obs).to(device)

            # Sample a stochastic action for every env from the current policy.
            with torch.no_grad():
                mu, sigma = ppo_torch_agent.get_action(obs_torch)
                action_dist = torch.distributions.Normal(mu, sigma)
                action_torch = action_dist.sample()
                log_prob_torch = action_dist.log_prob(action_torch).sum(dim=-1, keepdim=True)

            # Zero-copy PyTorch -> JAX transfer, then step every env at once.
            action_jax = jax.dlpack.from_dlpack(action_torch)
            next_env_state = jit_step(env_state, action_jax)

            # Episode-length bookkeeping: every env advanced one step. Envs that
            # report done contribute their finished length and restart at 0.
            episode_step_counts = episode_step_counts + 1
            terminated_mask = next_env_state.done > 0
            total_episode_lengths_collected += int(jnp.sum(jnp.where(terminated_mask, episode_step_counts, 0)))
            terminated_episodes_count += int(jnp.sum(terminated_mask))
            episode_step_counts = jnp.where(terminated_mask, 0, episode_step_counts)

            current_obs_batch_np = obs_torch.cpu().numpy()
            current_action_batch_np = action_torch.cpu().numpy()
            current_log_prob_batch_np = log_prob_torch.cpu().numpy()

            next_obs_batch_np = np.array(next_env_state.obs)
            reward_batch_np = np.array(next_env_state.reward)
            done_batch_np = np.array(next_env_state.done)

            # Transitions are stored step-major: env 0..N-1 for step t, then step t+1.
            for env_idx in range(ppo_torch_args.num_envs):
                ppo_torch_agent.put_data(make_transition(
                    current_obs_batch_np[env_idx],
                    current_action_batch_np[env_idx],
                    reward_batch_np[env_idx].reshape(1),
                    next_obs_batch_np[env_idx],
                    done_batch_np[env_idx].reshape(1),
                    current_log_prob_batch_np[env_idx]
                ))

            # Vectorised metric sums (float64 accumulator to limit float32 drift).
            metrics = next_env_state.metrics
            total_sum_rewards_collected += float(np.sum(reward_batch_np, dtype=np.float64))
            total_sum_north_rewards_collected += float(np.sum(np.asarray(metrics['north_reward']), dtype=np.float64))
            total_sum_healthy_rewards_collected += float(np.sum(np.asarray(metrics['reward_alive']), dtype=np.float64))
            # ctrl cost is stored as a negative reward in the metrics.
            total_sum_ctrl_costs_collected -= float(np.sum(np.asarray(metrics['reward_quadctrl']), dtype=np.float64))
            total_sum_sideways_costs_collected += float(np.sum(np.asarray(metrics['sideways_cost']), dtype=np.float64))
            total_sum_orientation_costs_collected += float(np.sum(np.asarray(metrics['orientation_cost']), dtype=np.float64))
            total_sum_z_angular_costs_collected += float(np.sum(np.asarray(metrics['z_angular_cost']), dtype=np.float64))

            env_state = next_env_state
            current_total_steps += ppo_torch_args.num_envs  # one step per parallel env

        # Per-transition averages over this update's rollout.
        avg_total_reward = total_sum_rewards_collected / num_collected_transitions
        avg_north_reward = total_sum_north_rewards_collected / num_collected_transitions
        avg_healthy_reward = total_sum_healthy_rewards_collected / num_collected_transitions
        avg_ctrl_cost = total_sum_ctrl_costs_collected / num_collected_transitions
        avg_sideways_cost = total_sum_sideways_costs_collected / num_collected_transitions
        avg_orientation_cost = total_sum_orientation_costs_collected / num_collected_transitions
        avg_z_angular_cost = total_sum_z_angular_costs_collected / num_collected_transitions

        # Average length of episodes that ended during this update (0 if none did).
        avg_episode_length = total_episode_lengths_collected / terminated_episodes_count if terminated_episodes_count > 0 else 0.0

        print(f"Update {update_idx + 1} completed. Avg Total Reward: {avg_total_reward:.2f}. Total steps: {current_total_steps}")
        print(f"  Avg Reward Components: North Reward: {avg_north_reward:.2f}, Healthy Reward: {avg_healthy_reward:.2f}, Ctrl Cost: {avg_ctrl_cost:.2f}, Sideways Cost: {avg_sideways_cost:.2f}")
        print(f"  Added Costs: Orientation Cost: {avg_orientation_cost:.2f}, Z Angular Cost: {avg_z_angular_cost:.2f}")
        print(f"  Avg Episode Length: {avg_episode_length:.2f} (from {terminated_episodes_count} terminations)")

        # Train once the full traj_length * num_envs rollout is buffered.
        if ppo_torch_agent.data.size() >= num_collected_transitions:
            policy_pg_loss, entropy_term_loss, value_fn_loss = ppo_torch_agent.train_net(update_idx)
            print(f"PPO agent trained for update {update_idx + 1}. Losses: PG={policy_pg_loss:.4f}, Ent={entropy_term_loss:.4f}, Val={value_fn_loss:.4f}")

            csv_writer.writerow([
                update_idx + 1,
                float(avg_total_reward),
                float(avg_north_reward),
                float(avg_healthy_reward),
                float(avg_ctrl_cost),
                float(avg_sideways_cost),
                float(avg_orientation_cost),
                float(avg_z_angular_cost),
                policy_pg_loss,
                entropy_term_loss,
                value_fn_loss,
                current_total_steps,
                float(avg_episode_length)
            ])
            csv_file.flush()  # keep the on-disk log current if the run is interrupted

            # On-policy: discard the rollout that was just trained on.
            ppo_torch_agent.data = ReplayBuffer(
                action_prob_exist=True,
                max_size=num_collected_transitions,
                state_dim=env.observation_size,
                num_action=env.action_size
            )

print(f"New training data saved to {csv_filename}")

print("\nPyTorch PPO training finished.")

# --- 5. Visualize Trained Policy (PyTorch PPO) ---
# Roll out the newly trained policy deterministically (action = policy mean)
# in a separate evaluation env and render the trajectory as a video.
print("\nGenerating rollout with newly trained PyTorch PPO policy...")
# Evaluation env with batch_size=1, no ScaleRewardWrapper, and 2000-step
# episodes (training in this cell used 500-step episodes).
eval_env = envs.create(env_name=env_name, episode_length=2000, batch_size=1)
jit_reset_eval = jax.jit(eval_env.reset)
jit_step_eval = jax.jit(eval_env.step)

# Derive the evaluation reset key from the running training RNG.
rng, eval_rng = jax.random.split(rng)
eval_state = jit_reset_eval(eval_rng)
# Physics states collected for rendering, starting with the reset state.
rollout_eval_torch = [eval_state.pipeline_state]

n_steps_eval = 1000  # upper bound on evaluation steps
render_every = 2  # keep every 2nd state as a video frame

for i in range(n_steps_eval):
    # Zero-copy JAX -> PyTorch transfer via DLPack.
    # NOTE(review): eval_env is batched (batch_size=1), so obs presumably already
    # has a leading batch dim; unsqueeze(0) adds a second one, which squeeze(0)
    # below removes again. Assumes the Actor accepts the extra leading dim.
    obs_torch_eval = torch.utils.dlpack.from_dlpack(eval_state.obs).unsqueeze(0).to(device)

    # Deterministic evaluation: use the policy mean and ignore sigma.
    with torch.no_grad():
        mu_eval, _ = ppo_torch_agent.get_action(obs_torch_eval)
        action_torch_eval = mu_eval

    # Zero-copy PyTorch -> JAX transfer; squeeze(0) undoes the unsqueeze above.
    action_jax_eval = jax.dlpack.from_dlpack(action_torch_eval.squeeze(0))

    # Step the env and record the new physics state for rendering.
    eval_state = jit_step_eval(eval_state, action_jax_eval)
    rollout_eval_torch.append(eval_state.pipeline_state)

    # Stop when the single eval environment reports done.
    if eval_state.done[0]:
        print(f"Evaluation episode finished early at step {i+1}")
        break

print("Rendering video from new PyTorch PPO trained policy...")
# fps = control rate (1 / dt) divided by the frame subsampling factor.
# NOTE(review): `media` is not imported in this cell (presumably mediapy from an
# earlier cell). The collected pipeline states carry the batch_size=1 leading
# dim, so confirm eval_env.render accepts batched states.
media.show_video(eval_env.render(rollout_eval_torch[::render_every]), fps=1.0 / eval_env.dt / render_every)
