In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os

# Rendering-related environment variables; they are set before `jax` is
# imported further down in this cell.
# NOTE(review): presumably these route GL rendering to the NVIDIA GPU and
# select MuJoCo's EGL backend -- confirm for your machine.
os.environ['__NV_PRIME_RENDER_OFFLOAD'] = '1'
os.environ['__GLX_VENDOR_LIBRARY_NAME'] = 'nvidia'
os.environ['MUJOCO_GL'] = 'egl'

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
# The flag is only appended when it is not already present, so re-running
# this cell does not keep stacking duplicates onto XLA_FLAGS.
triton_gemm_flag = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if triton_gemm_flag not in xla_flags.split():
  xla_flags = f'{xla_flags} {triton_gemm_flag}'.strip()
os.environ['XLA_FLAGS'] = xla_flags

# import mujoco
from datetime import datetime
from functools import partial
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import checkpoint as ppo_checkpoint

from IPython.display import  clear_output
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from tqdm import tqdm
# Compact NumPy array printing for the whole notebook: 3 decimal places,
# scientific notation suppressed, lines wrapped at 100 characters.
np.set_printoptions(precision=3, suppress=True, linewidth=100)


In [None]:

import g1hello
# Default configuration and a fresh G1 (23-DoF) environment from the local
# `g1hello` module.
env_cfg = g1hello.default_config()
env = g1hello.G1_23dof()

# JIT-compiled reset/step entry points, reused by the cells below.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:

# state = jit_reset(jax.random.PRNGKey(0))
# rollout = [state]

# f = 0.5

# for i in tqdm(range(200)):
#   action = []
#   for j in range(env.action_size):

#     if env.mj_model.actuator(j).name == "right_knee_joint" or env.mj_model.actuator(j).name == "left_shoulder_roll_joint":
#       value = jp.sin(
#             state.data.time * 2 * jp.pi * f 
#         ) * 1.
#     else:
#       value = 0.
      
#     action.append( value)
#   action = jp.array(action)
#   state = jit_step(state, action)
#   rollout.append(state)
# frames = env.render(rollout)
# media.show_video(frames, fps=1.0 / env.dt)

In [None]:

from datetime import datetime
import os

# Every run gets its own checkpoint directory named
# "<env_name>-<YYYYmmdd-HHMMSS>" under ./checkpoints.
env_name = "arm_hello_world"
now = datetime.now()
timestamp = f"{now:%Y%m%d-%H%M%S}"
exp_name = "-".join((env_name, timestamp))

# Absolute path, created up front so it exists before training starts.
ckpt_path = os.path.abspath(os.path.join(os.curdir, "checkpoints", exp_name))
os.makedirs(ckpt_path, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")

In [None]:
from ml_collections import config_dict
# PPO hyperparameters, collected in a plain dict first and then wrapped in a
# ConfigDict. The episode length is taken from the environment config.
ppo_settings = {
    "num_timesteps": 1_000_000,
    "reward_scaling": 10.0,
    "episode_length": env_cfg.episode_length,
    "normalize_observations": True,
    "action_repeat": 1,
    "unroll_length": 30,
    "num_minibatches": 32,
    "num_updates_per_batch": 16,
    "discounting": 0.995,
    "learning_rate": 1e-3,
    "entropy_cost": 1e-2,
    # made these smaller just to see some logging feedback
    "num_envs": 512,
    "batch_size": 256,
    "num_evals": 0,
}
ppo_params = config_dict.create(**ppo_settings)

# History shared with the `progress` callback: environment steps, episode
# reward, reward std, and the wall-clock time of every callback invocation.
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  """Plotting progress callback: redraws the evaluation reward curve.

  Args:
    num_steps: Number of environment steps completed so far.
    metrics: Mapping of metric names to values. Plotting needs both
      "eval/episode_reward" and "eval/episode_reward_std".

  Every call is timestamped in `times`. A metrics mapping without the
  evaluation entries (e.g. training-only logs) is not plotted and leaves the
  current output untouched, instead of raising KeyError.
  """
  times.append(datetime.now())

  # This notebook sets num_evals=0, so evaluation metrics may be absent.
  if ("eval/episode_reward" not in metrics
      or "eval/episode_reward_std" not in metrics):
    return

  clear_output(wait=True)

  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])

  # Axis limits: 25% headroom past the step budget, fixed reward range.
  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.ylim([0, 1100])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

  # `display` is not imported in this notebook; it is expected to come from
  # the IPython/Jupyter runtime.
  display(plt.gcf())

def progress_cli(num_steps, metrics):
  """Console progress callback.

  Prints three lines per call: the current step count, the complete metrics
  object (unfiltered, handy for debugging), and a dashed separator that keeps
  consecutive reports apart.

  Args:
    num_steps: Number of environment steps completed so far.
    metrics: Metrics object handed over by the trainer; printed as-is.
  """
  separator = "-" * 20
  report = (
      (f"Step: {num_steps}",),
      ("Metrics:", metrics),
      (separator,),
  )
  for line_parts in report:
    print(*line_parts)

# Plain-dict copy of the hyperparameters. An optional "network_factory" entry
# is not a trainer argument: it is taken out and bound onto the PPO network
# constructor as keyword overrides.
ppo_training_params = dict(ppo_params)
if "network_factory" in ppo_training_params:
  network_factory = partial(
      ppo_networks.make_ppo_networks,
      **ppo_training_params.pop("network_factory")
  )
else:
  network_factory = ppo_networks.make_ppo_networks

# Trainer with everything pre-bound except the environment-specific arguments
# supplied at call time.
train_fn = partial(
    ppo.train,
    **ppo_training_params,
    network_factory=network_factory,
    progress_fn=progress_cli,
    # policy_params_fn=policy_params_fn,
    log_training_metrics=True,
    save_checkpoint_path=ckpt_path
)

In [None]:
from mujoco_playground import wrapper

# Run PPO on the environment, wrapped for Brax training, then unpack the
# three results: the inference-function factory, the trained parameters and
# the final metrics.
train_outputs = train_fn(
    environment=env, wrap_env_fn=wrapper.wrap_for_brax_training
)
make_inference_fn, params, metrics = train_outputs

# ~6m11s

In [None]:
# Roll out the freshly trained policy (built with deterministic=True) and
# show the result as a video.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1

for _episode in range(n_episodes):
  # Split off a dedicated reset key so the key given to `reset` is never
  # split again afterwards (JAX PRNG keys should not be reused).
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  rollout.append(state)
  for _step in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

# Render every `render_every`-th state; the fps is scaled by the same factor
# so playback speed still follows env.dt.
render_every = 1
frames = env.render(rollout[::render_every])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)
# ~11s

In [None]:
## load policy from scratch
import g1hello
from functools import partial
import mediapy as media
from brax.training.agents.ppo import checkpoint as ppo_checkpoint
import jax
from pathlib import Path

# Checkpoint to replay. NOTE: this is hard-coded to one specific training run
# and step; point it at a directory printed by the "Checkpoint path" cell
# (plus the step sub-folder) to replay a different run.
relative_path = Path("./checkpoints/arm_hello_world-20250409-202838/000001228800")
checkpoint_dir = relative_path.resolve()
if not checkpoint_dir.exists():
  # Fail early with an actionable message instead of an error from deep
  # inside the checkpoint loader.
  raise FileNotFoundError(
      f"Checkpoint not found: {checkpoint_dir}. "
      "Update `relative_path` to an existing checkpoint directory."
  )
policy_fn = ppo_checkpoint.load_policy(checkpoint_dir)

# Fresh environment and JIT-compiled entry points, so this cell does not
# depend on the training cells above.
env = g1hello.G1_23dof()
env_cfg = g1hello.default_config()
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(policy_fn)
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1
n_steps = 200  # control steps rolled out per episode

for _episode in range(n_episodes):
  # Split off a dedicated reset key so the key given to `reset` is never
  # split again afterwards (JAX PRNG keys should not be reused).
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  rollout.append(state)
  for _step in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)

# Render every `render_every`-th state; the fps is scaled by the same factor
# so playback speed still follows env.dt.
render_every = 1
frames = env.render(rollout[::render_every])
rewards = [s.reward for s in rollout]


media.show_video(frames, fps=1.0 / env.dt / render_every)
# ~11s