In [1]:
#!/usr/bin/env python3
"""Visualization script for Go1 with height scanner."""
import os
import subprocess
# Tell XLA to use Triton GEMM.
# These variables are set ahead of the `import jax` / `import mujoco` lines
# below so the libraries see them when they are first imported.
# The membership check keeps the cell idempotent: re-running it in a live
# kernel used to append another copy of the flag to XLA_FLAGS every time.
xla_flags = os.environ.get('XLA_FLAGS', '')
if '--xla_gpu_triton_gemm_any=True' not in xla_flags.split():
    xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
# Select MuJoCo's EGL rendering backend (off-screen / headless rendering).
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jp
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses

import mujoco
from mujoco_playground import wrapper
from mujoco_playground import registry
from mujoco_playground.config import locomotion_params
from custom_env import Joystick, default_config

from datetime import datetime
import mediapy as media
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
import functools
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import Image as IPyimage, display, HTML, clear_output

from utils import render_video_during_training, evaluate_policy

# Visualization options: which geom groups and which debug overlays are drawn.
# NOTE(review): `scene_option` is not consumed by any cell visible here;
# presumably it is passed to a render call elsewhere — confirm.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True   # Show visual geoms
scene_option.geomgroup[3] = False  # Hide collision geoms
# NOTE(review): `geomgroup` toggles *geoms* in group 5; sites are controlled by
# `sitegroup`. Presumably the height-scanner markers are geoms assigned to
# group 5 in the XML — confirm, or switch to `sitegroup` if they are sites.
scene_option.geomgroup[5] = True   # Show geom group 5 (height scanner visualization)
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True  # Show contact points
scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = True  # Enable the rangefinder visualization flag
print("Creating Visualization...")

# Scene description to load; swap in 'custom_env_debug_wall.xml' for the
# debug-wall variant.
xml_path = 'custom_env.xml'
env = Joystick(config=default_config(), xml_path=xml_path)

# Compile the environment entry points once; later cells reuse these handles.
jit_reset, jit_step, jit_terrain_height = (
    jax.jit(fn)
    for fn in (env.reset, env.step, env._get_torso_terrain_height)
)

seed = 1234
# Empty shape: the split below then returns one unbatched key, so the
# environment is reset as a single (non-vectorised) instance.
num_envs = ()

# One root key fans out into independent streams for the env, evaluation,
# and the policy / value network initialisers.
key, key_env, eval_key, key_policy, key_value = jax.random.split(
    jax.random.PRNGKey(seed), 5
)
key_envs = jax.random.split(key_env, num_envs)
env_state = jit_reset(key_envs)

Creating Visualization...


2025-09-30 15:35:36.596825: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


# Teacher Policy

- Pretrained in `train.ipynb`.
- Here we only load its saved parameters.

- Inputs: privileged_state with heightmap
- Output: action

In [3]:
import flax
import optax
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.acme import specs
from brax.training.acme import running_statistics
from brax.training import types

# Rebuild the teacher's PPO networks and wrap the saved parameters in a jitted,
# deterministic inference function (`jit_inference_fn`).

# NOTE(review): presumably the length of the policy's `state` observation
# vector — confirm against the environment instead of hard-coding it.
obs_shape = (52,)
action_size = env.action_size

ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
ppo_training_params = dict(ppo_params)

ppo_training_params["num_timesteps"] = 1
print(ppo_training_params)

# Observation preprocessing must match what the teacher was trained with.
# The config loaded above reports `normalize_observations: True`; under that
# setting Brax PPO feeds the networks running-statistics-normalised
# observations, so an identity function here would hand the teacher raw
# observations it never saw during training.
# NOTE(review): this assumes train.ipynb used this config unchanged. If the
# teacher was trained without normalisation, keep the identity function.
normalize = lambda x, y: x
if ppo_training_params.get("normalize_observations", False):
    normalize = running_statistics.normalize

network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    # `network_factory` holds constructor kwargs, not a training argument.
    del ppo_training_params["network_factory"]
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )

ppo_network = network_factory(
    obs_shape, action_size, preprocess_observations_fn=normalize
)

# Freshly initialised parameters. Not used by the inference path below, which
# runs on the parameters loaded from disk.
init_params = ppo_losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)

make_policy = ppo_networks.make_inference_fn(ppo_network)

# SECURITY: allow_pickle=True unpickles arbitrary Python objects and can run
# code embedded in the file. Only load `params.npy` files you produced yourself.
params = np.load("params.npy", allow_pickle=True)
# np.save wraps a non-array object in a 0-d object array; unwrap it so the
# inference function can index the parameter tuple.
if isinstance(params, np.ndarray) and params.dtype == object and params.ndim == 0:
    params = params.item()

jit_inference_fn = jax.jit(make_policy(params, deterministic=True))


{'action_repeat': 1, 'batch_size': 256, 'discounting': 0.97, 'entropy_cost': 0.01, 'episode_length': 1000, 'learning_rate': 0.0003, 'max_grad_norm': 1.0, 'network_factory': policy_hidden_layer_sizes: &id001 !!python/tuple
- 512
- 256
- 128
policy_obs_key: state
value_hidden_layer_sizes: *id001
value_obs_key: privileged_state
, 'normalize_observations': True, 'num_envs': 8192, 'num_evals': 10, 'num_minibatches': 32, 'num_resets_per_eval': 1, 'num_timesteps': 1, 'num_updates_per_batch': 4, 'reward_scaling': 1.0, 'unroll_length': 20}


# Student Policy

Experiment 1: train a recurrent student policy built around a newly initialized LSTM.


In [4]:
# Policy Definition
import jax.numpy as jnp
import jax
import flax.linen as nn

# Recurrent core of the student: an LSTM cell (64 features) unrolled over the
# time axis by flax's RNN wrapper.
lstm = nn.RNN(cell=nn.LSTMCell(features=64))
# Dummy input used only to initialise the variables: (batch, time, features).
x = jnp.ones(shape=(10, 50, 52))
variables = lstm.init(jax.random.PRNGKey(0), x)


# Training


In [19]:
from custom_ppo_train import _maybe_wrap_env

# Hyper-parameters for the student-training experiment below.
seed = 1234  # base PRNG seed (same value as the seed set in the first cell)

episodes = 1  # number of outer training-loop iterations
envs_per_episode = 1  # forwarded to _maybe_wrap_env as num_envs
episode_length = 10  # rollout steps per episode; also forwarded to _maybe_wrap_env
action_repeat = 1  # forwarded to _maybe_wrap_env

learning_rate = 1e-5  # so far only referenced by the commented-out optax.adam line


## Training Loop

In [27]:

# Teacher rollout loop: for each episode, sample a command, reset the
# environment and step it with the teacher policy. The student-update part is
# still a commented-out scaffold at the bottom.

# Seed once, outside the loop. Re-creating PRNGKey(seed) inside the loop made
# every episode draw exactly the same keys, i.e. replay the same rollout.
key = jax.random.PRNGKey(seed)

# Loop-invariant: jitted reset of the raw (unwrapped) environment.
reset_fn = jax.jit(env.reset)

# Number of leading batch axes that `key_envs` adds to the reset observations:
# none when num_envs is the empty shape `()`, one when it is an int.
num_batch_dims = len(num_envs) if isinstance(num_envs, (tuple, list)) else 1

for episode in range(episodes):
    key, key_env, eval_key, key_policy, key_value = jax.random.split(key, 5)

    # NOTE(review): `wrapper_env` is built but the rollout below still drives
    # the raw `env` through jit_reset / jit_step — confirm which is intended.
    wrapper_env = _maybe_wrap_env(
        env,
        wrap_env=True,
        num_envs=envs_per_episode,
        episode_length=episode_length,
        action_repeat=action_repeat,
        key_env=key_env,
    )

    key_envs = jax.random.split(key_env, num_envs)
    env_state = reset_fn(key_envs)
    # print(f"ENV STATE: {env_state.done}")
    # Strip only the batch axes that are really present; slicing `[1:]`
    # unconditionally dropped the feature axis for an unbatched reset.
    obs_shape = jax.tree_util.tree_map(
        lambda x: x.shape[num_batch_dims:], env_state.obs
    )
    # print(f"OBS STATE: {env_state.obs}")

    # Independent keys for the command sample, the reset and the action
    # stream; a JAX key must not be consumed by more than one random call.
    key, rng = jax.random.split(key)
    rng, command_rng, reset_rng = jax.random.split(rng, 3)
    # `shape` must be a tuple: `(3)` is just the int 3.
    command = jax.random.uniform(command_rng, shape=(3,), minval=0.0, maxval=1.0)
    state = jit_reset(reset_rng)
    state.info["command"] = command

    for _step in range(episode_length):
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        # Re-apply the fixed command after every step so it stays constant
        # for the whole rollout.
        state.info["command"] = command
        # Stop once the (single, unbatched) episode has terminated instead of
        # stepping past a terminal state.
        if bool(state.done):
            break

    # TODO: student update (still to be written):
    # if normalize_observations:
    #     normalize = running_statistics.normalize

    # optimizer = optax.adam(learning_rate=learning_rate)

    # loss_fn = #

    # gradient_update_fn = #

In [21]:
# Debug check: echo the jitted teacher-policy callable to confirm it is defined in the kernel.
jit_inference_fn

<PjitFunction of <function make_inference_fn.<locals>.make_policy.<locals>.policy at 0x7f0fc81071c0>>

In [None]:
 done = False
    
while not done:
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    done = True