In [1]:
# Setup CUDA library paths for JAX GPU support
import os
import sys
import ctypes
import glob

# Locate the pip-installed NVIDIA CUDA wheels and make their shared libraries
# loadable before JAX is imported. The cell is safe to re-run: it never adds
# the same directory to LD_LIBRARY_PATH twice.

# Libraries to preload, dependencies first (libcublas links against
# libcublasLt, so the latter is loaded before it).
CUDA_PRELOAD_LIBS = ('libcudart.so.12', 'libcublasLt.so.12', 'libcublas.so.12')


def find_nvidia_base(search_paths):
    """Return the first ``<entry>/nvidia`` directory found, or None.

    Args:
        search_paths: iterable of directory paths to probe (e.g. ``sys.path``).
    """
    for entry in search_paths:
        candidate = os.path.join(entry, 'nvidia')
        if os.path.isdir(candidate):
            return candidate
    return None


def merge_ld_library_path(new_dirs, current):
    """Prepend ``new_dirs`` to a colon-separated search path without duplicates.

    Args:
        new_dirs: directories that should come first in the result.
        current: existing ``LD_LIBRARY_PATH`` value ('' when unset).

    Returns:
        The merged colon-separated string. Every entry appears once, at the
        position of its first occurrence, so merging the same directories
        again returns the input unchanged.
    """
    existing = current.split(':') if current else []
    merged = []
    seen = set()
    for entry in list(new_dirs) + existing:
        if entry not in seen:
            seen.add(entry)
            merged.append(entry)
    return ':'.join(merged)


def preload_cuda_libs(lib_dirs, lib_names=CUDA_PRELOAD_LIBS):
    """Load each library in ``lib_names`` found under ``lib_dirs`` with RTLD_GLOBAL.

    Loading is best-effort: a failure is reported with a warning and the
    remaining libraries are still attempted.

    Returns:
        List of library names that were loaded successfully.
    """
    loaded = []
    # Iterate names in the outer loop so the load order follows `lib_names`
    # (dependency order) rather than the directory listing order.
    for lib_name in lib_names:
        for lib_dir in lib_dirs:
            lib_path = os.path.join(lib_dir, lib_name)
            if not os.path.exists(lib_path):
                continue
            try:
                ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
            except Exception as e:
                print(f"Warning loading {lib_name}: {e}")
            else:
                loaded.append(lib_name)
    return loaded


# Find the nvidia CUDA packages - search all site-packages directories
nvidia_base = find_nvidia_base(sys.path)

if nvidia_base:
    print(f"Found nvidia packages at: {nvidia_base}")

    # Find all lib directories under nvidia packages (sorted for a stable order)
    lib_dirs = sorted(glob.glob(f"{nvidia_base}/*/lib"))

    if lib_dirs:
        # Set LD_LIBRARY_PATH. NOTE(review): the dynamic loader normally reads
        # this variable at process start, so the ctypes preload below is what
        # makes the libraries visible to this kernel - confirm on your platform.
        os.environ['LD_LIBRARY_PATH'] = merge_ld_library_path(
            lib_dirs, os.environ.get('LD_LIBRARY_PATH', '')
        )

        # CRITICAL: Preload CUDA libraries using ctypes before JAX loads
        preloaded = preload_cuda_libs(lib_dirs)

        print(f"✓ Set LD_LIBRARY_PATH with {len(lib_dirs)} CUDA directories")
        print(f"✓ Preloaded {len(set(preloaded))} CUDA libraries: {set(preloaded)}")
    else:
        print("⚠ Found nvidia directory but no lib subdirectories")
else:
    print("⚠ Could not find nvidia CUDA packages in sys.path")
    print("sys.path entries:", sys.path[:3])

Found nvidia packages at: /home/maklam/projects/braxphysics/venv_gpu/lib/python3.12/site-packages/nvidia
✓ Set LD_LIBRARY_PATH with 11 CUDA directories
✓ Preloaded 3 CUDA libraries: {'libcublasLt.so.12', 'libcublas.so.12', 'libcudart.so.12'}


In [9]:
# Check GPU availability
import jax
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

# Read the backend name once so it can be both reported and checked.
backend = jax.default_backend()
print("Default backend:", backend)

# Make a missing GPU impossible to overlook: the rest of the notebook would
# still run on another backend, only far more slowly. This is a warning rather
# than an error so a deliberate CPU run remains possible.
if backend != "gpu":
    print(f"⚠ JAX is not using a GPU (backend: {backend!r}); training will be very slow.")

JAX version: 0.8.2
Available devices: [CudaDevice(id=0)]
Default backend: gpu


In [3]:
# Imports

from jax import numpy as jp
from brax import envs
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks

Failed to import warp: No module named 'warp'
Failed to import mujoco_warp: No module named 'warp'


In [4]:
# Build the training environment

# 'humanoid' is the pre-made bipedal robot that ships with Brax;
# its physics are simulated with brax.generalized.
env_name = 'humanoid'
env = envs.get_environment(env_name, backend='generalized')



In [7]:
# Define the PPO (Proximal Policy Optimization) network factory

def _identity_preprocessor(obs, processor_params=None):
    """Return ``obs`` unchanged, ignoring any preprocessing parameters."""
    return obs


def make_networks_factory(obs_shape, action_size,
                          preprocess_observations_fn=_identity_preprocessor):
    """Build the PPO policy and value networks used for training.

    Both networks use four hidden layers of 128 units each.

    Args:
        obs_shape: observation size/shape, forwarded to ``make_ppo_networks``.
        action_size: number of action dimensions.
        preprocess_observations_fn: callable applied to observations before
            they reach the networks. Brax invokes it with two arguments,
            ``(observations, processor_params)``, so the default accepts both
            and passes the observations through untouched.

    Returns:
        Whatever ``ppo_networks.make_ppo_networks`` returns for this
        configuration.
    """
    return ppo_networks.make_ppo_networks(
        obs_shape, action_size,
        preprocess_observations_fn=preprocess_observations_fn,
        policy_hidden_layer_sizes=(128, 128, 128, 128),
        value_hidden_layer_sizes=(128, 128, 128, 128),
    )

In [12]:
# Train the PPO policy

print("Training started...")


def train_fn():
    """Run PPO on `env` with the fixed hyperparameters below and return its result."""
    hyperparams = dict(
        num_timesteps=50_000_000,  # total simulation steps
        num_evals=10,
        reward_scaling=0.1,
        episode_length=1000,
        normalize_observations=True,
        action_repeat=1,
        unroll_length=20,
        num_minibatches=32,
        num_updates_per_batch=4,
        discounting=0.97,
        learning_rate=3e-4,
        entropy_cost=1e-2,
        num_envs=2048,
        batch_size=1024,
    )
    return ppo.train(
        environment=env,
        network_factory=make_networks_factory,
        **hyperparams,
    )


# Launch training and unpack the three values it returns
inference_fn, params, metrics = train_fn()
print("Training Complete!")

Training started...
Training Complete!


In [19]:
# Visualize

# Number of environment steps to record for the animation
num_vis_steps = 500

# Initialize a fresh environment for playback
env_vis = envs.create(env_name=env_name, backend='generalized')
jit_reset = jax.jit(env_vis.reset)
jit_step = jax.jit(env_vis.step)

# Get the policy function from inference_fn
policy_fn = inference_fn(params, deterministic=True)


def get_action(obs, key):
    """Return only the action chosen by the trained policy.

    Args:
        obs: observation taken from the current environment state.
        key: JAX PRNG key forwarded to the policy.

    Returns:
        The first element when the policy returns a tuple, otherwise the
        policy's return value itself.
    """
    result = policy_fn(obs, key)
    # Policy returns (action, ...) tuple, we need just the action
    if isinstance(result, tuple):
        return result[0]
    return result


jit_policy = jax.jit(get_action)

# Start the robot. The reset gets its own sub-key so that no PRNG key is
# consumed twice (a key must not be reused after it has been passed on).
rng = jax.random.PRNGKey(0)
rng, reset_key = jax.random.split(rng)
state = jit_reset(reset_key)
states = []

# Roll the policy out, recording the physics state after every step
print("Generating visualization...")
for i in range(num_vis_steps):
    if i % 100 == 0:
        print(f"  Step {i}/{num_vis_steps}")

    # Generate a fresh key for the policy
    rng, key = jax.random.split(rng)

    # Ask the policy what to do based on the current observation
    act = jit_policy(state.obs, key)

    # Step the physics
    state = jit_step(state, act)
    states.append(state.pipeline_state)

# One frame is recorded per environment step, so render with the environment's
# step duration (env_vis.dt) rather than the system's physics timestep to get
# real-time playback. This mirrors the Brax training notebook.
# NOTE(review): confirm `tree_replace({'opt.timestep': ...})` against the installed Brax version.
render_sys = env_vis.sys.tree_replace({'opt.timestep': env_vis.dt})

# Save to HTML
print("Rendering HTML...")
with open("humanoid_walk.html", "w", encoding="utf-8") as f:
    f.write(html.render(render_sys, states))

print("\n✓ Saved visualization to 'humanoid_walk.html'")
print("Open the file in a browser to see your trained humanoid walk!")



Generating visualization...
  Step 0/500
  Step 100/500
  Step 200/500
  Step 300/500
  Step 400/500
Rendering HTML...

✓ Saved visualization to 'humanoid_walk.html'
Open the file in a browser to see your trained humanoid walk!
