In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os

os.environ['__NV_PRIME_RENDER_OFFLOAD'] = '1'
os.environ['__GLX_VENDOR_LIBRARY_NAME'] = 'nvidia'
os.environ['MUJOCO_GL'] = 'egl'

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

from functools import partial
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
import wandb


import jax
import mediapy as media
from randomize import domain_randomize
np.set_printoptions(precision=3, suppress=True, linewidth=100)

import balance
env = balance.G1Env()
eval_env = balance.G1Env()
env_cfg = balance.default_config()


from datetime import datetime

env_name = "g1_balance"
now = datetime.now()
timestamp = now.strftime("%Y%m%d-%H%M%S")
exp_name = f"{env_name}-{timestamp}"
import os

ckpt_path = os.path.abspath(os.path.join(".", "checkpoints", exp_name))
os.makedirs(ckpt_path, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")# media.show_video(frames, fps=1.0 / env.dt)


wandb.init(project="mjxrl", config=env_cfg)
wandb.config.update({
    "env_name": env_name,
})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Checkpoint path: /home/logan/Projects/g1_mjx_helloworld/balance_experiment/checkpoints/g1_balance-20250415-202615


[34m[1mwandb[0m: Currently logged in as: [33mjloganolson[0m ([33mjloganolson-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Jitted reset/step of `eval_env`, used to roll out and render policies.
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

In [None]:
from ml_collections import config_dict

# PPO hyper-parameters for brax's ppo.train.
# Env steps gathered per training step = batch_size * unroll_length *
# num_minibatches * action_repeat (here 1024 * 32 * 32 = 1,048,576), and
# batch_size * num_minibatches must be divisible by num_envs.
ppo_params = config_dict.create(
    num_timesteps=100_000_000,
    reward_scaling=0.1,
    episode_length=env_cfg.episode_length,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=32,
    num_minibatches=32,
    num_updates_per_batch=5,
    discounting=0.98,
    learning_rate=1e-4,
    entropy_cost=1e-2,
    num_envs=32768,
    batch_size=1024,
    num_evals=16,
    log_training_metrics=True,
    # Keyword arguments for ppo_networks.make_ppo_networks (split off below).
    network_factory=config_dict.create(
        policy_hidden_layer_sizes=(512, 256, 64),
        value_hidden_layer_sizes=(256, 256, 256, 256),
        # policy_obs_key is left at brax's default; only the critic is pointed
        # at the privileged observation.
        value_obs_key="privileged_state",
    ),
)

# Progress trackers; note progress_cli does not update these — only times[0]
# (this cell's start time) is ever set.
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress_cli(num_steps, metrics):
  """Brax progress callback: log metrics to wandb and echo them to stdout.

  Args:
    num_steps: step count reported by the trainer; used as the wandb step.
    metrics: the metrics mapping the trainer passes in, logged verbatim.
  """
  wandb.log(metrics, step=num_steps)

  # Step number, the full metrics dict (handy for debugging), then a separator
  # so consecutive reports are easy to tell apart.
  report_lines = (
      f"Step: {num_steps}",
      f"Metrics: {metrics}",
      "-" * 20,
  )
  print(*report_lines, sep="\n")
  # # Assuming `jit_reset` and `jit_step` are already JIT-compiled



def policy_params_fn(current_step, make_policy, params):
  """Roll out the current policy deterministically and log a video to wandb.

  Meant to be passed to brax's ``ppo.train`` as ``policy_params_fn``.

  Args:
    current_step: step count at which the callback fires; used as the wandb
      step for the logged video.
    make_policy: inference-function factory supplied by the trainer; called
      as ``make_policy(params, deterministic=True)``.
    params: policy parameters supplied alongside ``make_policy``.
  """
  print("Policy params fn")
  # Use the factory handed to the callback: the global `make_inference_fn` is
  # only bound once train_fn returns, so it is undefined (or stale) here.
  jit_inference_fn = jax.jit(make_policy(params, deterministic=True))
  rng = jax.random.PRNGKey(42)
  state = jit_reset(rng)
  # Start the video at the reset state, matching the evaluation cell.
  rollout = [state]

  # `done` from the unwrapped env may never become true (e.g. if the robot
  # stays up), so cap the rollout at the configured episode length.
  for _step in range(env_cfg.episode_length):
    if state.done:
      break
    rng, act_rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

  frames = eval_env.render(rollout, camera="track")
  frames_np = np.array(frames)
  # Reorder to (time, channel, height, width) for wandb.Video; assumes render
  # returns (H, W, C) frames — TODO confirm.
  frames_np_rearranged = np.transpose(frames_np, (0, 3, 1, 2))
  wandb.log(
      {"video": wandb.Video(frames_np_rearranged, fps=1.0 / eval_env.dt, format="gif")},
      step=current_step,
  )


from mujoco_playground import wrapper

# ppo.train takes the network builder as a callable, so the optional nested
# `network_factory` config is pulled out of the flat training kwargs and bound
# onto make_ppo_networks instead.
ppo_training_params = dict(ppo_params)
if "network_factory" in ppo_training_params:
  network_factory = partial(
      ppo_networks.make_ppo_networks,
      **ppo_training_params.pop("network_factory"),
  )
else:
  network_factory = ppo_networks.make_ppo_networks

train_fn = partial(
    ppo.train,
    **ppo_training_params,
    network_factory=network_factory,
    progress_fn=progress_cli,
    # policy_params_fn=policy_params_fn,
    randomization_fn=domain_randomize,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    save_checkpoint_path=ckpt_path,
)

In [5]:
# Train on fresh env instances each run instead of the module-level
# `env`/`eval_env`. NOTE(review): train_fn wraps the envs it receives for
# domain randomization (wrap_for_brax_training + randomization_fn), and the
# recorded run failed with an UnexpectedTracerError (a BatchTracer escaping a
# transformation). Reusing env objects across re-runs of this cell is a likely
# source of such stale state — confirm. If it persists, re-run under
# `jax.checking_leaks()` to locate the leak.
make_inference_fn, params, metrics = train_fn(
    environment=balance.G1Env(),
    eval_env=balance.G1Env(),
)

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7c53af85d570>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7c53aeffb4f0, raw_cell="make_inference_fn, params, metrics = train_fn(
   .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/logan/Projects/g1_mjx_helloworld/balance_experiment/main.ipynb#W4sZmlsZQ%3D%3D>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[5,5] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7c53af85d570>> (for post_run_cell), with arguments args (<ExecutionResult object at 7c53aeffae60, execution_count=5 error_before_exec=None error_in_exec=Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[5,5] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError info=<ExecutionInfo object at 7c53aeffb4f0, raw_cell="make_inference_fn, params, metrics = train_fn(
   .." store_history=True silent=False shell_futures=True c

BrokenPipeError: [Errno 32] Broken pipe

In [None]:


# Evaluate the trained policy and log the rollout video to wandb.
# Roll out on a fresh env instance rather than `eval_env`, which was handed to
# train_fn and wrapped for domain randomization there.
# NOTE(review): the wrapper may leave training-time state on that object — confirm.
vis_env = balance.G1Env()
vis_reset = jax.jit(vis_env.reset)
vis_step = jax.jit(vis_env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))

rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1

for episode_idx in range(n_episodes):
  state = vis_reset(rng)
  rollout.append(state)
  for _step in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = vis_step(state, ctrl)
    rollout.append(state)
    # Stop at termination so the video does not keep rendering post-fall steps.
    if state.done:
      break

frames = vis_env.render(rollout, camera="track")
frames_np = np.array(frames)
# Reorder to (time, channel, height, width) for wandb.Video; assumes render
# returns (H, W, C) frames — TODO confirm.
frames_np_rearranged = np.transpose(frames_np, (0, 3, 1, 2))
wandb.log({"video": wandb.Video(frames_np_rearranged, fps=1.0 / vis_env.dt, format="gif")})
