In [9]:
#!/usr/bin/env python3
"""Visualization script for Go1 with height scanner."""
import os
import subprocess
# Tell XLA to use Triton GEMM
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jp
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses

import mujoco
from mujoco_playground import wrapper
from mujoco_playground import registry
from mujoco_playground.config import locomotion_params
from custom_env import Joystick, default_config

from datetime import datetime
import mediapy as media
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
import functools
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import Image as IPyimage, display, HTML, clear_output

from utils import render_video_during_training, evaluate_policy

# Rendering options for the visualizations in this notebook.
scene_option = mujoco.MjvOption()

# Geom-group visibility, applied in order as (group index, shown?):
#   2 -> visual geoms (shown)
#   3 -> collision geoms (hidden)
#   5 -> described by the original author as "sites (including height scanner
#        visualization)". NOTE(review): `geomgroup` indexes geom groups, so
#        confirm the scanner markers really live in geom group 5.
for _group_index, _is_shown in ((2, True), (3, False), (5, True)):
    scene_option.geomgroup[_group_index] = _is_shown

# Extra overlays: contact points and rangefinder rays.
for _vis_flag in (
    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,
    mujoco.mjtVisFlag.mjVIS_RANGEFINDER,
):
    scene_option.flags[_vis_flag] = True
print("Creating Visualization...")

# Build the custom Joystick environment from its MuJoCo XML description.
xml_path = 'custom_env.xml' # 'custom_env_debug_wall.xml'
env = Joystick(xml_path=xml_path, config=default_config())

# JIT compile the functions for speed
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_terrain_height = jax.jit(env._get_torso_terrain_height)

seed = 1234
# Empty tuple => no batch axis: `jax.random.split(key_env, ())` yields a single
# key, so the environment is reset as one unbatched instance.
num_envs = ()
key = jax.random.PRNGKey(seed)
key, key_env, eval_key, key_policy, key_value = jax.random.split(key, 5)
key_envs = jax.random.split(key_env, num_envs)
env_state = jit_reset(key_envs)

# Per-environment observation shapes: strip exactly as many leading batch axes
# as `num_envs` introduces. With `num_envs = ()` nothing is stripped; the
# previous hard-coded `shape[1:]` removed the feature axis of the unbatched
# observations and produced `()` for every 1-D entry.
_num_batch_dims = len(num_envs)
obs_size = jax.tree.map(lambda x: x.shape[_num_batch_dims:], env_state.obs)
action_size = env.action_size

Creating Visualization...


# Teacher Policy

- Pretrained in `train.ipynb`.
- Here we only load the saved parameters.

- Input: `privileged_state` observation, including the heightmap
- Output: action

In [None]:
import functools, jax, jax.numpy as jnp, numpy as np
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from Teacher_Student.teacherOracle import TeacherOracle

# 1) Load saved teacher params (normalizer, policy, value)
# SECURITY NOTE: allow_pickle=True makes np.load unpickle arbitrary Python
# objects, which can execute code. Only load a teacher_params.npy that you
# produced yourself or otherwise trust.
# NOTE(review): the markdown above says the teacher was trained in train.ipynb
# while the comment below says train_teacher.ipynb; confirm which notebook
# wrote this file.
teacher_normalizer, teacher_policy_params, teacher_value_params = np.load(
    # Use teacher_params from train_teacher.ipynb since
    # policy network is trained on privileged_state
    'teacher_params.npy', allow_pickle=True
).tolist()

# 2) Infer the teacher's hidden sizes from saved params.
#    Every `hidden_<i>` entry is collected in index order rather than assuming
#    exactly three hidden layers. The highest-indexed entry is treated as the
#    output projection (assumes Brax's MLP naming, where the final Dense layer
#    is also called `hidden_<i>`), so it is excluded from the hidden sizes.
pp = teacher_policy_params
_policy_layers = pp['params']
_dense_names = sorted(
    (name for name in _policy_layers
     if name.startswith('hidden_') and name[len('hidden_'):].isdigit()),
    key=lambda name: int(name[len('hidden_'):]),
)
if not _dense_names:
    raise ValueError(
        "teacher policy params contain no 'hidden_<i>' layers; "
        "cannot infer the network architecture"
    )
k0 = _policy_layers[_dense_names[0]]['kernel'].shape  # (in_dim, hidden0)
policy_hidden = tuple(
    int(_policy_layers[name]['kernel'].shape[1]) for name in _dense_names[:-1]
)
in_dim = int(k0[0])

# 3) Build observation_size mapping from the normalizer (robust even if obs_size is ()):
#    This ensures networks are initialized with the exact trained feature dims.
teacher_obs_sizes = {k: tuple(teacher_normalizer.mean[k].shape)
                     for k in teacher_normalizer.mean}

# Select which key the teacher used: the first 1-D observation whose length
# matches the policy's input dimension.
teacher_obs_key = next(
    (obs_name for obs_name, obs_shape_ in teacher_obs_sizes.items()
     if len(obs_shape_) == 1 and int(obs_shape_[0]) == in_dim),
    None,
)
if teacher_obs_key is None:
    # Fallback: prefer 'state' if present, otherwise the first available key
    teacher_obs_key = 'state' if 'state' in teacher_obs_sizes else next(iter(teacher_obs_sizes.keys()))

print('Teacher obs_key:', teacher_obs_key, 'hidden:', policy_hidden)

# 4) Rebuild teacher networks to match saved shapes
teacher_nets = ppo_networks.make_ppo_networks(
    observation_size=teacher_obs_sizes,          # use sizes from normalizer, not env obs
    action_size=env.action_size,
    preprocess_observations_fn=running_statistics.normalize,
    policy_obs_key=teacher_obs_key,
    value_obs_key=teacher_obs_key,
    policy_hidden_layer_sizes=policy_hidden,
    distribution_type='tanh_normal',
    noise_std_type='log',
)
make_teacher_policy = ppo_networks.make_inference_fn(teacher_nets)

# 5) Create the oracle (uses the teacher's stochastic policy by default)
teacher = TeacherOracle(
    make_policy=make_teacher_policy,
    normalizer_params=teacher_normalizer,
    policy_params=teacher_policy_params,
    value_params=teacher_value_params,
)

# Quick check: run the teacher once on the observation from the initial reset
a, ex = teacher.act(env_state.obs, jax.random.PRNGKey(0))
print('action shape:', a.shape, 'has raw_action:', 'raw_action' in ex)


Teacher obs_key: privileged_state hidden: (512, 256, 128)
action shape: (12,) has raw_action: False


# Student Policy

Experiment 1: train a student policy built on a newly initialized LSTM (a recurrent neural network).


In [None]:
# Policy Definition
import jax.numpy as jnp
import jax
import flax.linen as nn

# Stand-alone LSTM: a 64-feature LSTM cell unrolled over time by `nn.RNN`,
# initialised against a dummy all-ones input.
_init_key = jax.random.PRNGKey(0)
x = jnp.ones(shape=(10, 50, 52))  # (batch, time, features)

_lstm_cell = nn.LSTMCell(features=64)
lstm = nn.RNN(_lstm_cell)
variables = lstm.init(_init_key, x)


# Training


In [None]:
from custom_ppo_train import _maybe_wrap_env

# --- Reproducibility ---------------------------------------------------------
seed = 1234              # base seed for the PRNG keys used by the training loop

# --- Rollout schedule --------------------------------------------------------
episodes = 1             # iterations of the outer training loop
envs_per_episode = 1     # forwarded as `num_envs` to the environment wrapper
episode_length = 10      # environment steps taken per rollout
action_repeat = 1        # forwarded as `action_repeat` to the environment wrapper

# --- Optimisation ------------------------------------------------------------
learning_rate = 1e-5     # reserved for the optimizer (not wired up yet)


## Training Loop

In [None]:

# Teacher rollout loop: for every episode, reset the environment, sample one
# joystick command, and let the teacher policy drive the robot.
#
# Built once, outside the loop: the base key and the jitted reset do not change
# between episodes.
_base_key = jax.random.PRNGKey(seed)
reset_fn = jax.jit(env.reset)

for episode in range(episodes):
    # Fold the episode index into the base key so each episode gets its own
    # random stream (re-seeding with `seed` alone made every episode identical).
    key = jax.random.fold_in(_base_key, episode)
    key, key_env, eval_key, key_policy, key_value, rng = jax.random.split(key, 6)

    # NOTE(review): `wrapper_env` is created but the rollout below still steps
    # the raw `env` through `jit_reset` / `jit_step`.
    wrapper_env = _maybe_wrap_env(
        env,
        wrap_env=True,
        num_envs=envs_per_episode,
        episode_length=episode_length,
        action_repeat=action_repeat,
        key_env=key_env,
    )

    key_envs = jax.random.split(key_env, num_envs)
    env_state = reset_fn(key_envs)
    # Strip only the batch axes that `num_envs` introduces (none for `()`), so
    # the per-environment feature shape is preserved.
    obs_shape = jax.tree_util.tree_map(
        lambda x: x.shape[len(num_envs):], env_state.obs
    )

    # Use independent keys for the command sample and the reset; a PRNG key
    # must not be consumed twice.
    rng, command_rng, reset_rng = jax.random.split(rng, 3)
    # `shape` must be a tuple: `(3)` is just the integer 3.
    command = jax.random.uniform(command_rng, shape=(3,), minval=0.0, maxval=1.0)
    state = jit_reset(reset_rng)
    state.info["command"] = command

    for _step in range(episode_length):
        act_rng, rng = jax.random.split(rng)
        action, _ = teacher.act(state.obs, act_rng)
        state = jit_step(state, action)
        # Re-apply the fixed command after every step so it stays constant
        # for the whole rollout.
        state.info["command"] = command
        # Stop once the environment reports termination instead of stepping a
        # finished episode (assumes an unbatched scalar `done`, which matches
        # the single-key reset above).
        if bool(state.done):
            break

    # TODO: student training is not implemented yet. Planned pieces:
    # if normalize_observations:
    #     normalize = running_statistics.normalize

    # optimizer = optax.adam(learning_rate=learning_rate)

    # loss_fn = #

    # gradient_update_fn = #

In [None]:
# NOTE(review): `jit_inference_fn` is not defined in any cell visible in this
# notebook, so on a fresh kernel this cell raises NameError unless it is
# defined elsewhere. This bare expression only displays the object's repr.
jit_inference_fn

In [None]:
 done = False
    
while not done:
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    done = True