In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os

# Rendering-related environment variables. They are set before `mujoco` and
# `jax` are imported further down in this cell.
os.environ['__NV_PRIME_RENDER_OFFLOAD'] = '1'
os.environ['__GLX_VENDOR_LIBRARY_NAME'] = 'nvidia'
os.environ['MUJOCO_GL'] = 'egl'

# Tell XLA to use Triton GEMM; this improves steps/sec by ~30% on some GPUs.
# The flag is appended only when it is not already present, so re-running this
# cell in the same kernel does not keep growing XLA_FLAGS with duplicates.
triton_gemm_flag = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if triton_gemm_flag not in xla_flags.split():
  xla_flags += ' ' + triton_gemm_flag
os.environ['XLA_FLAGS'] = xla_flags

import mujoco
from datetime import datetime
from functools import partial
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

from IPython.display import  clear_output
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from tqdm import tqdm
from randomize import domain_randomize
np.set_printoptions(precision=3, suppress=True, linewidth=100)


In [2]:

import balance
# Instantiate the G1 environment with its constructor defaults and fetch the
# module's default config (used later for `episode_length`).
env = balance.G1Env()
env_cfg = balance.default_config()

# jit_reset = jax.jit(env.reset)
# jit_step = jax.jit(env.step)
# state = jit_reset(jax.random.PRNGKey(0))
# rollout = [state]

# f = 0.5

# for i in tqdm(range(200)):
#   action = []
#   for j in range(env.action_size):

#     if env.mj_model.actuator(j).name == "right_knee_joint" or env.mj_model.actuator(j).name == "left_shoulder_roll_joint":
#       value = jp.sin(
#             state.data.time * 2 * jp.pi * f 
#         ) * 1.
#     else:
#       value = 0.
      
#     action.append( value)
#   action = jp.array(action)
#   state = jit_step(state, action)
#   rollout.append(state)
# frames = env.render(rollout, camera="track")
# media.show_video(frames, fps=1.0 / env.dt)

In [None]:
from ml_collections import config_dict
# ppo_params= config_dict.create(
#     num_timesteps=1_000_000,
#     num_evals=10,
#     reward_scaling=10.0,
#     episode_length=env_cfg.episode_length,
#     normalize_observations=True,
#     action_repeat=1,
#     unroll_length=30,
#     num_minibatches=32,
#     num_updates_per_batch=16,
#     discounting=0.995,
#     learning_rate=1e-3,
#     entropy_cost=1e-2,
#     num_envs=2048,
#     batch_size=1024,
# )

from ml_collections import config_dict
# PPO hyperparameters. Every entry is later unpacked as a keyword argument of
# `ppo.train` when `train_fn` is built.
ppo_params = config_dict.create(
    # Run length and logging.
    num_timesteps=60_000_000,
    num_evals=0,
    log_training_metrics=True,
    # Rollout settings.
    episode_length=env_cfg.episode_length,
    action_repeat=1,
    unroll_length=30,
    num_envs=2048,
    # Update settings.
    batch_size=1024,
    num_minibatches=32,
    num_updates_per_batch=16,
    learning_rate=1e-3,
    entropy_cost=1e-2,
    discounting=0.995,
    # Scaling / normalization.
    reward_scaling=10.0,
    normalize_observations=True,
)
# Three independent, initially empty series buffers (nothing in the visible
# cells fills or reads them).
x_data = []
y_data = []
y_dataerr = []
# Wall-clock timestamps; seeded with the moment this cell runs.
times = [datetime.now()]


def progress_cli(num_steps, metrics):
  """Write a three-line progress report to stdout.

  Args:
    num_steps: Step count to report on the `Step:` line.
    metrics: Metrics object (a dict in the recorded output); its `str()` form
      is printed in full on the `Metrics:` line.
  """
  report_lines = (
      f"Step: {num_steps}",
      "Metrics: " + str(metrics),
      "-" * 20,  # visual separator between consecutive reports
  )
  print(*report_lines, sep="\n")

# Plain-dict copy of the hyperparameters, to be unpacked into `ppo.train`.
ppo_training_params = dict(ppo_params)

# An optional "network_factory" entry holds keyword overrides for the network
# builder rather than a `ppo.train` argument, so it is taken out of the
# training kwargs and bound onto `make_ppo_networks` instead.
if "network_factory" in ppo_params:
  network_overrides = ppo_training_params.pop("network_factory")
  network_factory = partial(
      ppo_networks.make_ppo_networks, **network_overrides
  )
else:
  network_factory = ppo_networks.make_ppo_networks

# Pre-bind everything except the environment-specific arguments, which are
# supplied when `train_fn` is called.
train_fn = partial(
    ppo.train,
    network_factory=network_factory,
    progress_fn=progress_cli,
    **ppo_training_params,
)

In [None]:
from mujoco_playground import wrapper

# Run training. The result is unpacked into a factory that builds the policy
# inference function, the trained parameters, and the returned metrics.
train_kwargs = dict(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
    randomization_fn=domain_randomize,
)
make_inference_fn, params, metrics = train_fn(**train_kwargs)
# print(f"time to jit: {times[1] - times[0]}")
# print(f"time to train: {times[-1] - times[1]}")
# # ~6m11s

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Step: 245760
Metrics: {'episode/length': np.float64(48.56), 'episode/reward/dof_pos_limits': np.float64(-1.1564310322701932), 'episode/reward/height': np.float64(3.138432297061663), 'episode/reward/orientation': np.float64(14.137529747486115), 'episode/reward/pose': np.float64(-5.937100163698196), 'episode/sum_reward': np.float64(0.20364861333742737)}
--------------------
Step: 491520
Metrics: {'episode/length': np.float64(42.27), 'episode/reward/dof_pos_limits': np.float64(-0.6049358803220093), 'episode/reward/height': np.float64(1.5855725535817329), 'episode/reward/orientation': np.float64(10.892314279079438), 'episode/reward/pose': np.float64(-2.643221758008003), 'episode/sum_reward': np.float64(0.18459457702934742)}
--------------------
Step: 737280
Metrics: {'episode/length': np.float64(43.74), 'episode/reward/dof_pos_limits': np.float64(-0.6541939178109168), 'episode/reward/height': np.float64(1.80118247309234), 'episode/reward/orientation': np.float64(11.423776710033417), 'episo

In [5]:
# Roll out the trained policy by stepping `env` directly (not through the
# training wrapper handed to `train_fn`) and render the trajectory as a video.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1

for _episode in range(n_episodes):
  # Use a dedicated sub-key for the reset. Passing `rng` itself to the reset
  # and then splitting it again for action keys would reuse the same key.
  rng, reset_rng = jax.random.split(rng)
  state = jit_reset(reset_rng)
  rollout.append(state)
  for _step in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
    # Stop at termination: nothing in this loop resets the env, so further
    # steps would only add (and render) post-termination frames.
    if state.done:
      break

# Render every `render_every`-th state; the fps is scaled down to match so the
# video keeps the same playback speed.
render_every = 1
frames = env.render(rollout[::render_every])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)
# ~11s

 26%|██▋       | 263/1001 [00:19<00:54, 13.45it/s]


KeyboardInterrupt: 