# H1-2 Collision Avoidance Training Notebook

This notebook:
1. Runs MVP tests to validate the environment
2. Trains an asymmetric actor-critic policy with PPO
3. Evaluates and visualizes the trained policy

In [1]:
# Automatically reload edited project modules (e.g. the environment code)
# before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# These variables are read when JAX / MuJoCo initialize, so this cell has to
# run before either library is imported.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"

xla_flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only when it is missing so re-running this cell is
# idempotent instead of growing XLA_FLAGS with duplicates on every run.
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Allocate GPU memory on demand rather than preallocating it up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Headless (EGL) OpenGL backend for MuJoCo rendering.
os.environ["MUJOCO_GL"] = "egl"

In [1]:
from mujoco_playground import locomotion, wrapper
from mujoco_playground.config import locomotion_params


In [2]:
# Build the environment once with its default configuration.
# (Section 1 below reloads it from the same config for the MVP tests.)
env_name = "H12SkinAvoid"
print("Loading environment...")
env_cfg = locomotion.get_default_config(env_name)
env = locomotion.load(env_name, config=env_cfg)

Loading environment...
Loaded 103 assets for H12 Skin.
Assets: ['h1_2_skin_mjx copy.xml', 'scene_mjx.xml', 'h1_2_skin_mjx.xml', 'h1_2_skin.xml', 'scene.xml', 'left_shoulder_pitch_link.STL', 'link18_R.STL', 'link19_R.STL', 'link12_L.STL', 'L_thumb_proximal.STL', 'left_ankle_B_link.STL', 'L_hand_base_link.STL', 'left_hip_pitch_link.STL', 'right_wrist_pitch_link.STL', 'right_shoulder_roll_link.STL', 'R_ring_intermediate.STL', 'link20_R.STL', 'left_ankle_A_rod_link.STL', 'link15_R.STL', 'link22_L.STL', 'left_knee_link.STL', 'R_ring_proximal.STL', 'link22_R.STL', 'L_thumb_intermediate.STL', 'left_wrist_roll_link.STL', 'right_ankle_A_rod_link.STL', 'right_ankle_pitch_link.STL', 'link21_L.STL', 'left_ankle_roll_link.STL', 'left_ankle_B_rod_link.STL', 'L_pinky_intermediate.STL', 'link16_R.STL', 'R_index_proximal.STL', 'R_pinky_proximal.STL', 'left_hip_roll_link.STL', 'link12_R.STL', 'link17_L.STL', 'wrist_yaw_link.STL', 'left_shoulder_roll_link.STL', 'R_pinky_intermediate.STL', 'right_pitch_li

W1204 18:51:24.678091   80589 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1204 18:51:24.681279   80464 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


Downsampling trajectories by 5x: 500.0Hz -> 100.0Hz
Loading 1600 trajectory files from /home/wxie/workspace/h1_mujoco/augmented...
Successfully loaded 1600 trajectories
Trajectory lengths: min=22, max=69, mean=34.7
Loaded trajectory database: {'num_trajectories': 1600, 'min_length': 22, 'max_length': 69, 'mean_length': 34.6875, 'std_length': 12.158940897545312, 'total_timesteps': 55500}
Found 63 skin sensors


In [3]:
import functools
import json
from datetime import datetime

import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp
from tqdm import tqdm

from mujoco_playground import locomotion, wrapper
from mujoco_playground.config import locomotion_params

# Persist compiled XLA executables on disk so later runs can reuse them.
_jax_cache_options = (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),  # no size floor for cached entries
    ("jax_persistent_cache_min_compile_time_secs", 0),  # cache even quick compiles
)
for _opt_name, _opt_value in _jax_cache_options:
    jax.config.update(_opt_name, _opt_value)

## Section 1: MVP Testing

Run basic tests to validate the environment before training.

In [4]:
# Configuration shared by the MVP tests and, later, training.
env_name = "H12SkinAvoid"
env_cfg = locomotion.get_default_config(env_name)

# Override trajectory directory if needed
# env_cfg.traj_dir = "./traj_logs"

print("Environment configuration:")
for _cfg_field in ("ctrl_dt", "episode_length", "history_len", "action_scale", "traj_dir"):
    print(f"  {_cfg_field}: {getattr(env_cfg, _cfg_field)}")

Environment configuration:
  ctrl_dt: 0.02
  episode_length: 1000
  history_len: 3
  action_scale: 0.6
  traj_dir: ./traj_logs


In [5]:
# Build the environment from the configuration above and summarize its sizes.
print("Loading environment...")
env = locomotion.load(env_name, config=env_cfg)

_env_summary = (
    "✓ Environment loaded successfully",
    f"  Controlled joints: {env._num_controlled_joints}",
    f"  Skin sensors: {env._num_skin_sensors}",
    f"  Action size: {env.action_size}",
    f"  Trajectories loaded: {env._traj_db.num_trajectories}",
)
for _summary_line in _env_summary:
    print(_summary_line)

# Per-database statistics reported by the trajectory store.
stats = env._traj_db.get_stats()
print("\nTrajectory database stats:")
for stat_name, stat_value in stats.items():
    print(f"  {stat_name}: {stat_value}")

Loading environment...
mujoco_menagerie not found. Downloading...


Cloning mujoco_menagerie: ██████████| 100/100 [00:31<00:00]


Checking out commit e95616395529f0c4093d97a8759b0eb1160a95e6
Successfully downloaded mujoco_menagerie


ValueError: Error: Error opening file '../mujoco_menagerie/unitree_h1_2_skin/assets/skin/torso_skin.stl': No such file or directory

In [None]:
# Reset once and check both observation streams against the env's declared sizes.
print("Testing reset...")
rng = jax.random.PRNGKey(0)
state = env.reset(rng)

print("✓ Reset successful")
print(f"  Observation keys: {state.obs.keys()}")
print(f"  Actor obs shape: {state.obs['state'].shape}")
print(f"  Critic obs shape: {state.obs['privileged_state'].shape}")
print(f"  Trajectory length: {state.info['traj_length']}")

actor_size, critic_size = env._compute_obs_size()
actor_obs_dim = state.obs['state'].shape[0]
critic_obs_dim = state.obs['privileged_state'].shape[0]
print("\nComputed observation sizes:")
print(f"  Actor: {actor_size} (actual: {actor_obs_dim})")
print(f"  Critic: {critic_size} (actual: {critic_obs_dim})")

assert actor_obs_dim == actor_size, "Actor obs size mismatch!"
assert critic_obs_dim == critic_size, "Critic obs size mismatch!"
print("✓ Observation shapes validated")

In [None]:
# Take one step with a uniform random action in [-1, 1] and inspect the result.
print("Testing step...")
rng, action_rng = jax.random.split(rng)
action = jax.random.uniform(
    action_rng,
    shape=(env._num_controlled_joints,),
    minval=-1.0,
    maxval=1.0,
)
next_state = env.step(state, action)
step_info = next_state.info

print("✓ Step successful")
print(f"  Reward: {next_state.reward:.4f}")
print(f"  Done: {next_state.done}")
print(f"  Trajectory step: {step_info['traj_step']}/{step_info['traj_length']}")

# The ten reward terms with the largest magnitude.
print("\n  Top reward components:")
reward_items = sorted(
    ((name, float(val)) for name, val in next_state.metrics.items() if name.startswith('reward/')),
    key=lambda item: abs(item[1]),
    reverse=True,
)
for name, val in reward_items[:10]:
    print(f"    {name}: {val:.6f}")

In [None]:
# Summarize the per-sensor capacitance readings stored in the reset state.
print("Testing capacitance sensing...")
capacitances = state.info['capacitances']
is_detected = capacitances > 0.0

print(f"  Num sensors: {len(capacitances)}")
print(f"  Num detections: {jp.sum(is_detected):.0f}")
print(f"  Max capacitance: {jp.max(capacitances):.4f}")
# Non-detecting sensors are masked to +inf so they never win the min.
min_nonzero = jp.where(is_detected, capacitances, jp.inf).min()
if not jp.isfinite(min_nonzero):
    print("  No detections (obstacle far from robot)")
else:
    print(f"  Min (non-zero) capacitance: {min_nonzero:.4f}")

print("✓ Capacitance computation working")

In [None]:
# Compile reset/step once; the rollout test below reuses these callables.
print("Testing JIT compilation...")
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

rng = jax.random.PRNGKey(42)
action = jp.zeros(env._num_controlled_joints)
state = jit_step(jit_reset(rng), action)

print("✓ JIT compilation successful")

In [None]:
# Drive the env with uniform random actions for up to 50 steps, stopping early
# if the episode terminates, and tally reward and collision-flagged steps.
print("Testing 50-step rollout...")
rng = jax.random.PRNGKey(123)
state = jit_reset(rng)

total_reward = 0.0
collision_count = 0
num_actions = env._num_controlled_joints

for step in range(50):
    rng, action_rng = jax.random.split(rng)
    action = jax.random.uniform(action_rng, shape=(num_actions,), minval=-1.0, maxval=1.0)
    state = jit_step(state, action)
    total_reward += float(state.reward)
    collision_count += int(float(state.metrics['collision_detected']) > 0.5)

    if float(state.done) > 0.5:
        print(f"  Episode terminated at step {step+1}")
        break

steps_taken = step + 1
print("✓ Rollout completed")
print(f"  Steps: {steps_taken}")
print(f"  Total reward: {total_reward:.2f}")
print(f"  Avg reward: {total_reward/steps_taken:.4f}")
print(f"  Collisions: {collision_count}")

In [None]:
# Sanity-check the sign of each penalty scale present in the reward config.
# NOTE(review): this treats a negative penalty scale as an error, which assumes
# the env applies the negative sign itself; confirm against the reward code.
print("Checking reward scale signs...")
scales = env._config.reward_config.scales

penalty_keys = ['collision_penalty', 'proximity_penalty', 'action_rate', 'torque', 'energy', 'joint_limit']
issues = []

for key in (k for k in penalty_keys if k in scales):
    if scales[key] >= 0:
        print(f"  ✓ {key} = {scales[key]}")
    else:
        issues.append(f"  ✗ {key} = {scales[key]} (should be positive)")

if not issues:
    print("\n✓ All reward scales correct")
else:
    print("\n⚠ WARNING: Issues found:")
    print(*issues, sep="\n")

## Section 2: Training Setup

Configure PPO hyperparameters, W&B logging, and checkpointing (including optional resume from a previous run).

In [None]:
# PPO hyperparameters for the asymmetric actor-critic, defined inline.
# TODO: move these into locomotion_params.py once stable (see next cell).
from pprint import pprint

from ml_collections import config_dict

ppo_params = config_dict.ConfigDict()
ppo_params.num_timesteps = 100_000_000  # 100M steps
ppo_params.num_evals = 20
ppo_params.reward_scaling = 1.0
ppo_params.episode_length = env_cfg.episode_length
ppo_params.normalize_observations = True
ppo_params.action_repeat = 1
ppo_params.unroll_length = 10
ppo_params.num_minibatches = 32
ppo_params.num_updates_per_batch = 4
ppo_params.discounting = 0.97
ppo_params.learning_rate = 3e-4
ppo_params.entropy_cost = 1e-3
ppo_params.num_envs = 4096
ppo_params.batch_size = 4096
ppo_params.seed = 0

# Network architecture. These keys are forwarded as keyword arguments to
# brax's ppo_networks.make_ppo_networks when train_fn is built.
ppo_params.network_factory = config_dict.ConfigDict()
ppo_params.network_factory.policy_hidden_layer_sizes = (512, 256, 128)
ppo_params.network_factory.value_hidden_layer_sizes = (512, 256, 128)
# Asymmetric actor-critic: the policy reads only the "state" observation,
# the critic reads "privileged_state". make_ppo_networks defaults both keys
# to "state", which would silently make the critic non-privileged.
ppo_params.network_factory.policy_obs_key = "state"
ppo_params.network_factory.value_obs_key = "privileged_state"
# The activation is deliberately left unset: make_ppo_networks calls its
# `activation` argument as a function (default: flax.linen.swish), so the
# string 'swish' would fail at network construction time.

print("PPO Parameters:")
pprint(ppo_params.to_dict())

In [None]:
# Alternatively, if you've added to locomotion_params.py:
# ppo_params = locomotion_params.brax_ppo_config(env_name)

In [None]:
# Weights & Biases logging; set USE_WANDB = False to run without it.
USE_WANDB = True

if USE_WANDB:
    wandb.init(project="mjxrl-avoid", config=env_cfg)
    # Record the env name and all PPO hyperparameters alongside the env config.
    for _extra_config in ({"env_name": env_name}, ppo_params.to_dict()):
        wandb.config.update(_extra_config)

In [None]:
# Setup checkpointing
SUFFIX = None  # Optional suffix appended to the experiment name
FINETUNE_PATH = None  # Set to an experiment checkpoint folder to resume training

# Generate a unique, timestamped experiment name.
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
if SUFFIX is not None:
    exp_name += f"-{SUFFIX}"
print(f"Experiment name: {exp_name}")

# Possibly restore from the most recent checkpoint in FINETUNE_PATH.
restore_checkpoint_path = None
if FINETUNE_PATH is not None:
    FINETUNE_PATH = epath.Path(FINETUNE_PATH)
    # policy_params_fn writes one directory per step, named by the step count.
    # Skip anything else (e.g. in-progress/temporary checkpoint directories)
    # instead of crashing on int() of a non-numeric name.
    step_dirs = [
        ckpt for ckpt in FINETUNE_PATH.glob("*")
        if ckpt.is_dir() and ckpt.name.isdigit()
    ]
    if not step_dirs:
        raise FileNotFoundError(
            f"No step-numbered checkpoint directories found in {FINETUNE_PATH}"
        )
    restore_checkpoint_path = max(step_dirs, key=lambda ckpt: int(ckpt.name))
    print(f"Restoring from: {restore_checkpoint_path}")

In [None]:
# Per-experiment output directory; checkpoints, params and eval results go here.
ckpt_path = epath.Path("checkpoints").resolve() / exp_name
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")

# Snapshot the environment config next to the checkpoints.
with open(ckpt_path / "config.json", "w") as cfg_file:
    json.dump(env_cfg.to_dict(), cfg_file, indent=4)

## Section 3: Training

Train the asymmetric actor-critic policy with PPO.

In [None]:
# Progress tracking shared by the training callbacks and later plotting cells.
x_data, y_data, y_dataerr = [], [], []
# times[0] marks the start; progress() appends one timestamp per evaluation.
times = [datetime.now()]


def progress(num_steps, metrics):
    """Log metrics and redraw the eval-reward curve after each evaluation.

    Args:
      num_steps: Environment steps taken so far (x-axis value).
      metrics: Metrics dict from the trainer; must contain
        "eval/episode_reward" and "eval/episode_reward_std".
    """
    if USE_WANDB:
        wandb.log(metrics, step=num_steps)

    clear_output(wait=True)
    times.append(datetime.now())
    x_data.append(num_steps)
    y_data.append(metrics["eval/episode_reward"])
    y_dataerr.append(metrics["eval/episode_reward_std"])

    # Draw on a fresh figure and close it once displayed. Reusing the implicit
    # current figure stacks another errorbar artist on every call and leaves an
    # open figure behind after training.
    fig, ax = plt.subplots()
    ax.set_xlim([0, ppo_params.num_timesteps * 1.25])
    # ax.set_ylim([0, 100])  # Adjust based on your expected reward range
    ax.set_xlabel("# environment steps")
    ax.set_ylabel("reward per episode")
    ax.set_title(f"y={y_data[-1]:.3f}")
    ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

    display(fig)
    plt.close(fig)


def policy_params_fn(current_step, make_policy, params):
    """Checkpoint callback: save `params` with orbax to ckpt_path/<current_step>."""
    del make_policy  # Unused
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(params)
    path = ckpt_path / f"{current_step}"
    orbax_checkpointer.save(path, params, force=True, save_args=save_args)

In [None]:
# Setup training function.
# Domain randomization is optional; fall back to None when unavailable.
try:
    randomizer = locomotion.get_domain_randomizer(env_name)
except Exception as exc:  # not a bare except: keep Ctrl-C / SystemExit working
    print(f"Domain randomizer lookup failed: {exc!r}")
    randomizer = None
if randomizer is None:
    print("No domain randomizer found, using None")

# network_factory holds make_ppo_networks kwargs, so it is replaced by the
# bound partial passed explicitly below.
training_params = dict(ppo_params)
del training_params["network_factory"]

train_fn = functools.partial(
    ppo.train,
    **training_params,
    network_factory=functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    ),
    restore_checkpoint_path=restore_checkpoint_path,
    progress_fn=progress,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    policy_params_fn=policy_params_fn,
    randomization_fn=randomizer,
)

In [None]:
# Train!
print("Starting training...")
print(f"Total timesteps: {ppo_params.num_timesteps:,}")
print(f"Number of environments: {ppo_params.num_envs}")
print(f"Batch size: {ppo_params.batch_size}")

# Separate environment instances for training and evaluation.
env, eval_env = (locomotion.load(env_name, config=env_cfg) for _ in range(2))
make_inference_fn, params, _ = train_fn(environment=env, eval_env=eval_env)

# times[1] is stamped by the first progress() call, i.e. once compilation is done.
if len(times) > 1:
    first_eval_time = times[1]
    print(f"\nTime to jit: {first_eval_time - times[0]}")
    print(f"Time to train: {times[-1] - first_eval_time}")

In [None]:
# Final plot of reward vs wallclock time.
# times[0] is the start stamp taken before training; progress() appends
# times[i + 1] at the same moment it records y_data[i]. Each reward therefore
# belongs at times[1:], not times[:-1] (which shifts every point one eval early
# and puts the first eval at t = 0).
if y_data:
    elapsed_s = [(t - times[0]).total_seconds() for t in times[1:]]
    fig, ax = plt.subplots()
    ax.set_xlabel("wallclock time (s)")
    ax.set_ylabel("reward per episode")
    ax.set_title(f"y={y_data[-1]:.3f}")
    ax.errorbar(elapsed_s, y_data, yerr=y_dataerr, color="blue")
    plt.show()
else:
    print("No evaluation results were recorded; nothing to plot.")

In [None]:
# Save normalizer, policy and value params from ppo.train as a pickle.
import pickle

normalizer_params, policy_params, value_params = params
params_file = ckpt_path / "params.pkl"
param_bundle = dict(
    normalizer_params=normalizer_params,
    policy_params=policy_params,
    value_params=value_params,
)
with open(params_file, "wb") as f:
    pickle.dump(param_bundle, f)

print(f"Parameters saved to {params_file}")

## Section 4: Evaluation

Evaluate the trained policy and visualize rollouts.

In [None]:
# Deterministic policy plus a fresh evaluation env, each compiled with jax.jit.
inference_fn = make_inference_fn(params, deterministic=True)
jit_inference_fn = jax.jit(inference_fn)

eval_env = locomotion.load(env_name, config=env_cfg)
jit_reset, jit_step = jax.jit(eval_env.reset), jax.jit(eval_env.step)

In [None]:
# Roll out the trained policy for several episodes. All per-step logs are flat
# lists concatenated across episodes.
print("Running evaluation rollouts...")
rng = jax.random.PRNGKey(12345)
num_episodes = 10

rollout, rewards = [], []
tracking_rewards, recovery_rewards, collision_rewards = [], [], []
capacitances_max, tracking_errors = [], []

for ep in tqdm(range(num_episodes)):
    rng, reset_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)

    for i in range(env_cfg.episode_length):
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        step_metrics = state.metrics

        rollout.append(state)
        # Every "reward/<name>" metric, keyed by <name>.
        rewards.append({
            name[len("reward/"):]: float(val)
            for name, val in step_metrics.items()
            if name.startswith("reward/")
        })
        tracking_rewards.append(float(step_metrics['reward/joint_pos_tracking']))
        recovery_rewards.append(float(step_metrics['reward/recovery_upright']))
        collision_rewards.append(float(step_metrics['reward/collision_penalty']))
        capacitances_max.append(float(jp.max(state.info['capacitances'])))
        tracking_errors.append(float(step_metrics['tracking_error']))

        if state.done:
            print(f"  Episode {ep} terminated at step {i}")
            break

print(f"\nCollected {len(rollout)} timesteps from {num_episodes} episodes")

In [None]:
# Plot reward components over time in a 2x2 grid.
def _plot_series(ax, series, title, ylabel, hline_y=None, hline_label=None):
    """Plot one per-timestep series on `ax`, optionally with a dashed red reference line."""
    ax.plot(series)
    ax.set_title(title)
    ax.set_xlabel('Timestep')
    ax.set_ylabel(ylabel)
    if hline_y is not None:
        ax.axhline(y=hline_y, color='r', linestyle='--', label=hline_label)
        ax.legend()
    ax.grid(True)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
_plot_series(axes[0, 0], tracking_rewards, 'Joint Position Tracking Reward', 'Reward')
_plot_series(axes[0, 1], recovery_rewards, 'Recovery Upright Reward', 'Reward')
_plot_series(
    axes[1, 0], collision_rewards, 'Collision Penalty', 'Penalty',
    hline_y=0, hline_label='No collision',
)
_plot_series(
    axes[1, 1], capacitances_max, 'Max Capacitance (Proximity)', 'Capacitance',
    hline_y=env._cap_collision_threshold, hline_label='Collision threshold',
)

plt.tight_layout()
plt.show()

In [None]:
# Plot tracking error over time, then report its mean and std.
tracking_err_arr = jp.array(tracking_errors)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(tracking_errors)
ax.set_title('Joint Position Tracking Error')
ax.set_xlabel('Timestep')
ax.set_ylabel('L2 Error')
ax.grid(True)
plt.show()

print(f"Mean tracking error: {jp.mean(tracking_err_arr):.4f}")
print(f"Std tracking error: {jp.std(tracking_err_arr):.4f}")

In [None]:
# Analyze collision statistics.
# A step counts as a collision when its collision-penalty term is negative.
# NOTE(review): this relies on the penalty being reported with a negative sign;
# confirm against the env's reward implementation.
num_collisions = sum(r < 0 for r in collision_rewards)
num_eval_steps = len(collision_rewards)
# Guard the empty case (no rollout steps collected) instead of dividing by zero.
collision_rate = num_collisions / num_eval_steps if num_eval_steps else 0.0
max_cap_observed = max(capacitances_max, default=0.0)

print("Collision Statistics:")
print(f"  Total timesteps: {num_eval_steps}")
print(f"  Collisions detected: {num_collisions}")
print(f"  Collision rate: {collision_rate*100:.2f}%")
print(f"  Max capacitance observed: {max_cap_observed:.4f}")

In [None]:
# Render every other recorded state (all evaluation episodes back to back)
# at the matching playback rate.
print("Rendering video...")
render_every = 2
traj = rollout[::render_every]
fps = 1.0 / eval_env.dt / render_every

scene_option = mujoco.MjvOption()
# Show geom group 2, hide geom group 3.
scene_option.geomgroup[2] = True
scene_option.geomgroup[3] = False
_vis_settings = (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_CONTACTFORCE, False),
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
)
for _vis_flag, _enabled in _vis_settings:
    scene_option.flags[_vis_flag] = _enabled

frames = eval_env.render(
    traj, camera="side", scene_option=scene_option, height=480, width=640
)
media.show_video(frames, fps=fps, loop=False)

# Optionally save video
# media.write_video(f"{env_name}.mp4", frames, fps=fps, qp=18)

## Section 5: Analysis

Per-component reward statistics, the distribution of capacitance readings, and a JSON summary of the evaluation metrics.

In [None]:
# Mean / std / min / max of every reward component across all evaluation steps.
if rewards:
    # Union of component names seen at any step; missing entries count as 0.0.
    reward_keys = sorted(set().union(*(step_rewards.keys() for step_rewards in rewards)))

    print("Reward Component Statistics:")
    print(f"{'Component':<30} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
    print("-" * 78)

    for key in reward_keys:
        values = jp.array([step_rewards.get(key, 0.0) for step_rewards in rewards])
        print(f"{key:<30} {jp.mean(values):>12.4f} {jp.std(values):>12.4f} {jp.min(values):>12.4f} {jp.max(values):>12.4f}")

In [None]:
# Histogram of per-step max capacitance, ignoring steps with no detection.
plt.figure(figsize=(10, 4))
cap_nonzero = list(filter(lambda cap: cap > 0, capacitances_max))
if not cap_nonzero:
    print("No non-zero capacitance readings detected")
else:
    plt.hist(cap_nonzero, bins=50)
    plt.axvline(x=env._cap_collision_threshold, color='r', linestyle='--', label='Collision threshold')
    plt.xlabel('Capacitance')
    plt.ylabel('Frequency')
    plt.title('Distribution of Max Capacitance Readings (non-zero only)')
    plt.legend()
    plt.show()

In [None]:
# Write a JSON summary of the evaluation next to the checkpoints and echo it.
tracking_err_arr = jp.array(tracking_errors)
analysis_results = dict(
    num_episodes=num_episodes,
    total_timesteps=len(rollout),
    collision_rate=float(collision_rate),
    mean_tracking_error=float(jp.mean(tracking_err_arr)),
    std_tracking_error=float(jp.std(tracking_err_arr)),
    max_capacitance=float(max(capacitances_max)) if capacitances_max else 0.0,
)

results_file = ckpt_path / "eval_results.json"
with open(results_file, "w") as f:
    json.dump(analysis_results, f, indent=4)

print(f"\nAnalysis results saved to {results_file}")
print("\nEvaluation Summary:")
for result_name, result_value in analysis_results.items():
    print(f"  {result_name}: {result_value}")

In [None]:
# End the W&B run opened in the setup section (skipped when logging is off).
if USE_WANDB:
    wandb.finish()
    print("W&B run finished")