In [None]:
#!/usr/bin/env python3
"""Visualization script for Go1 with height scanner."""
import os
import subprocess
# Tell XLA to use Triton GEMM. These environment variables are set before
# `jax` / `mujoco` are imported below.
# The flag is appended only when it is not already present, so re-running this
# cell does not keep growing XLA_FLAGS with duplicate entries.
xla_flags = os.environ.get('XLA_FLAGS', '')
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += ' ' + _TRITON_GEMM_FLAG
os.environ['XLA_FLAGS'] = xla_flags
# Select the EGL backend for MuJoCo rendering.
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jp
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses

import mujoco
from mujoco_playground import wrapper
from mujoco_playground import registry
from mujoco_playground.config import locomotion_params
from custom_env import Joystick, default_config

from datetime import datetime
import mediapy as media
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
import functools
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import Image as IPyimage, display, HTML, clear_output

from utils import render_video_during_training, evaluate_policy

# MuJoCo visualization options: which geom groups are drawn and which debug
# overlays are enabled.
scene_option = mujoco.MjvOption()
scene_option.geomgroup[2] = True   # Show visual geoms
scene_option.geomgroup[3] = False  # Hide collision geoms
scene_option.geomgroup[5] = True   # Show sites (including height scanner visualization)
scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True  # Show contact points
scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = True   # Show rangefinder visualization
print("Creating Visualization...")

xml_path = 'custom_env.xml' # 'custom_env_debug_wall.xml'
env = Joystick(xml_path=xml_path, config=default_config())

# JIT compile the functions for speed
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_terrain_height = jax.jit(env._get_torso_terrain_height)

seed = 1234
# An empty shape makes jax.random.split return a single, unbatched key, so the
# environment is reset without a leading batch axis.
num_envs = ()
key = jax.random.PRNGKey(seed)
key, key_env, eval_key, key_policy, key_value = jax.random.split(key, 5)
key_envs = jax.random.split(key_env, num_envs)
env_state = jit_reset(key_envs)

# Strip only the batch axes that actually exist. The previous `x.shape[1:]`
# assumed one leading batch axis and therefore collapsed every 1-D observation
# to `()` when the environment is unbatched (num_envs == ()).
batch_ndim = key_envs.ndim - key_env.ndim
obs_size = jax.tree.map(lambda x: x.shape[batch_ndim:], env_state.obs)
action_size = env.action_size

# Teacher Policy

- The teacher is pretrained in `train.ipynb`.
- Here we only load its saved parameters.

- Input: the `privileged_state` observation, which includes the heightmap
- Output: action

In [None]:
import functools, jax, jax.numpy as jnp, numpy as np
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from Teacher_Student.teacherOracle import TeacherOracle

# 1) Load saved teacher params (normalizer, policy, value).
#    The .npy file holds a pickled object; `.tolist()` turns the loaded array
#    back into a Python object that is unpacked into the three parts below.
#    NOTE(security): allow_pickle=True unpickles arbitrary objects, so only
#    load a file you produced yourself.
teacher_normalizer, teacher_policy_params, teacher_value_params = np.load(
    # Use teacher_params from train_teacher.ipynb since
    # the policy network is trained on privileged_state
    'teacher_params.npy', allow_pickle=True
).tolist()

# 2) Infer the teacher's hidden sizes from saved params.
#    Reads exactly the layers named hidden_0, hidden_1 and hidden_2; a
#    checkpoint with a different number of hidden layers needs this edited.
pp = teacher_policy_params
k0 = pp['params']['hidden_0']['kernel'].shape  # (in_dim, hidden0)
policy_hidden = (
    int(pp['params']['hidden_0']['kernel'].shape[1]),
    int(pp['params']['hidden_1']['kernel'].shape[1]),
    int(pp['params']['hidden_2']['kernel'].shape[1]),
)
# Input width of the first policy layer; used below to find the obs key.
in_dim = int(k0[0])

# 3) Build observation_size mapping from the normalizer (robust even if obs_size is ()):
#    This ensures networks are initialized with the exact trained feature dims.
teacher_obs_sizes = {k: tuple(teacher_normalizer.mean[k].shape)
                     for k in teacher_normalizer.mean}

# Select which key the teacher used by matching input dim:
# the first 1-D entry whose length equals the first layer's input width wins.
teacher_obs_key = None
for k, shape in teacher_obs_sizes.items():
    if len(shape) == 1 and int(shape[0]) == in_dim:
        teacher_obs_key = k
        break
if teacher_obs_key is None:
    # Fallback: prefer 'state' if present, otherwise take the first key.
    teacher_obs_key = 'state' if 'state' in teacher_obs_sizes else next(iter(teacher_obs_sizes.keys()))

print('Teacher obs_key:', teacher_obs_key, 'hidden:', policy_hidden)

# 4) Rebuild teacher networks to match saved shapes.
#    Only the policy hidden sizes are passed explicitly; the value network is
#    built with make_ppo_networks' defaults.
#    NOTE(review): confirm those defaults match teacher_value_params if the
#    value network is ever applied.
teacher_nets = ppo_networks.make_ppo_networks(
    observation_size=teacher_obs_sizes,          # use sizes from normalizer, not env obs
    action_size=env.action_size,
    preprocess_observations_fn=running_statistics.normalize,
    policy_obs_key=teacher_obs_key,
    value_obs_key=teacher_obs_key,
    policy_hidden_layer_sizes=policy_hidden,
    distribution_type='tanh_normal',
    noise_std_type='log',
)
make_teacher_policy = ppo_networks.make_inference_fn(teacher_nets)

# 5) Create the oracle. Per the original author's note it uses the teacher's
#    stochastic policy by default; TeacherOracle is defined outside this
#    notebook, so verify there.
teacher = TeacherOracle(
    make_policy=make_teacher_policy,
    normalizer_params=teacher_normalizer,
    policy_params=teacher_policy_params,
    value_params=teacher_value_params,
)

# Quick check: query the oracle once on the reset observation and report the
# action shape and whether the extras contain 'raw_action'.
a, ex = teacher.act(env_state.obs, jax.random.PRNGKey(0))
print('action shape:', a.shape, 'has raw_action:', 'raw_action' in ex)


# Student Policy

The student policy is built with `Teacher_Student/studentPolicy.py` and uses a tanh-Normal action head.


In [None]:
# Student policy built via Teacher_Student.studentPolicy
import jax, jax.numpy as jnp
from flax import linen
from brax.training.acme import running_statistics
from Teacher_Student.studentPolicy import make_student_policy, init_student_normalizer

# The feature width is read from the trailing axis, so this works whether or
# not the reset observation carries a leading batch axis.
student_obs_dim = int(env_state.obs['state'].shape[-1])
student_obs_sizes = dict(state=(student_obs_dim,))

# Student network: a two-layer (128, 128) swish MLP over the 'state' observation.
student = make_student_policy(
    obs_size=student_obs_sizes,
    action_size=env.action_size,
    preprocess_fn=running_statistics.normalize,
    hidden=(128, 128),
    activation=linen.swish,
    obs_key='state',
)

# Fresh normalizer statistics and randomly initialised policy weights.
key, key_student = jax.random.split(key)
student_normalizer_params = init_student_normalizer(student_obs_dim)
student_policy_params = student.policy.init(key_student)

def make_student_inference_fn(student):
    """Return a ``make_policy`` factory for the given student.

    Args:
        student: object exposing ``policy.apply(normalizer_params,
            policy_params, observations)`` and a ``parametric`` action
            distribution with ``mode``, ``sample_no_postprocessing``,
            ``log_prob`` and ``postprocess``.

    Returns:
        ``make_policy(params, deterministic=False)``, where ``params[0]`` are
        the normalizer parameters and ``params[1]`` the policy parameters. It
        returns ``policy(observations, key_sample) -> (action, extras)``.
    """

    def make_policy(params, deterministic=False):
        """Bind parameters and sampling mode into a callable policy."""

        def policy(observations, key_sample):
            """Map observations to an action plus an extras dict."""
            dist = student.parametric
            dist_params = student.policy.apply(params[0], params[1], observations)

            if not deterministic:
                # Sample in raw (pre-postprocess) space and score that sample
                # before mapping it to the final action.
                raw_action = dist.sample_no_postprocessing(dist_params, key_sample)
                extras = {
                    'log_prob': dist.log_prob(dist_params, raw_action),
                    'raw_action': raw_action,
                }
                return dist.postprocess(raw_action), extras

            # Deterministic mode: distribution mode, no extras.
            return dist.mode(dist_params), {}

        return policy

    return make_policy

# Policy factory bound to the student built above; the evaluation cell calls it
# with (normalizer_params, policy_params) to obtain the student's inference function.
make_student_policy_inference = make_student_inference_fn(student)


In [None]:
import optax
from brax.training.acme import running_statistics

# Adam optimizer for the student policy parameters; its state is initialised
# from the current student_policy_params.
optimizer = optax.adam(learning_rate=1e-4)
opt_state = optimizer.init(student_policy_params)
# Weight of the student-entropy term that is subtracted from the distillation loss.
entropy_coef = 1e-3

# Teacher action distribution; used to sample raw actions and to score them.
teacher_dist = teacher_nets.parametric_action_distribution

def distill_step_loss(student_params, student_norm, obs, rng):
    """Distillation loss of the student against the teacher on ``obs``.

    A raw (pre-postprocess) action is sampled from the teacher distribution and
    scored under both policies. The mean of ``log p_teacher - log p_student``
    is reported as ``kl``. The student's entropy, weighted by the module-level
    ``entropy_coef``, is subtracted as a bonus.

    Args:
        student_params: student policy parameters (the differentiated argument).
        student_norm: student normalizer parameters.
        obs: observations passed to both the teacher and the student networks.
        rng: PRNG key; split internally for action sampling and entropy.

    Returns:
        ``(loss, metrics)`` where ``metrics`` has keys ``'kl'`` and ``'entropy'``.
    """
    # Three-way split kept so the derived sub-keys are unchanged; the first is unused.
    _, sample_key, entropy_key = jax.random.split(rng, 3)

    teacher_logits = teacher_nets.policy_network.apply(
        teacher_normalizer, teacher_policy_params, obs
    )
    student_logits = student.policy.apply(student_norm, student_params, obs)

    # Score one teacher-sampled raw action under both distributions.
    teacher_raw_action = teacher_dist.sample_no_postprocessing(teacher_logits, sample_key)
    teacher_log_prob = teacher_dist.log_prob(teacher_logits, teacher_raw_action)
    student_log_prob = student.parametric.log_prob(student_logits, teacher_raw_action)

    kl = jnp.mean(teacher_log_prob - student_log_prob)
    student_entropy = jnp.mean(student.parametric.entropy(student_logits, entropy_key))

    metrics = {'kl': kl, 'entropy': student_entropy}
    return kl - entropy_coef * student_entropy, metrics

@jax.jit
def distill_step(student_params, student_norm, obs, rng, opt_state):
    """Apply one optimizer update of the student on ``distill_step_loss``.

    Args:
        student_params: current student policy parameters.
        student_norm: student normalizer parameters (not updated here).
        obs: observations for this step.
        rng: PRNG key forwarded to the loss.
        opt_state: current optimizer state.

    Returns:
        ``(new_params, new_opt_state, metrics)``; ``metrics`` is the aux dict
        produced by ``distill_step_loss``.
    """
    # has_aux=True yields ((loss, aux), grads); only aux and grads are used.
    loss_and_grad_fn = jax.value_and_grad(distill_step_loss, has_aux=True)
    (_, metrics), grads = loss_and_grad_fn(student_params, student_norm, obs, rng)

    updates, next_opt_state = optimizer.update(grads, opt_state, student_params)
    return optax.apply_updates(student_params, updates), next_opt_state, metrics

# Training


In [None]:
from custom_ppo_train import _maybe_wrap_env

# Seed for the PRNG key created in the training-loop cell.
seed = 1234

# NOTE(review): `episodes` is reassigned at the top of the training-loop cell,
# so the value set here is not the one the loop uses.
episodes = 1
# NOTE(review): envs_per_episode, episode_length and action_repeat are not read
# by any other cell shown in this notebook; confirm whether they are still needed.
envs_per_episode = 1
episode_length = 10
action_repeat = 1

# NOTE(review): the optimizer cell constructs optax.adam with a hard-coded
# learning_rate=1e-4 and never reads this variable.
learning_rate = 1e-5


## Training Loop

In [None]:

episodes = 25
# Upper bound on steps per episode. An episode also ends as soon as the
# environment reports termination (see the `done` check below).
steps_per_episode = 100000000
# Report progress about ten times over the run, plus the final episode.
# (The previous `ep % 1000` check only ever printed episode 0 for 25 episodes.)
log_every = max(1, episodes // 10)

rng = jax.random.PRNGKey(seed)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# One entry per episode; plotted in the next cell.
episode_mean_kls = []
episode_mean_entropy = []

for ep in range(episodes):
    rng, key_env = jax.random.split(rng)
    state = reset_fn(key_env)
    # Sample one command per episode and pin it in the state info for the
    # whole rollout.
    rng, key_cmd = jax.random.split(rng)
    command = jax.random.uniform(key_cmd, shape=(3,), minval=0.0, maxval=1.0)
    state.info['command'] = command

    # Running sums instead of per-step lists, so memory stays constant no
    # matter how many steps the episode runs.
    kl_sum = 0.0
    ent_sum = 0.0
    steps_done = 0
    for t in range(steps_per_episode):
        rng, key_act, key_distill = jax.random.split(rng, 3)
        # The teacher drives the rollout; the student is trained on the
        # observations the teacher visits.
        action, _ = teacher.act(state.obs, key_act)
        next_state = step_fn(state, action)
        next_state.info['command'] = command
        # Update the student's normalizer statistics with the current
        # observation before the gradient step.
        student_normalizer_params = running_statistics.update(
            student_normalizer_params, {'state': state.obs['state']}
        )
        student_policy_params, opt_state, m = distill_step(
            student_policy_params, student_normalizer_params, state.obs, key_distill, opt_state
        )
        kl_sum = kl_sum + m['kl']
        ent_sum = ent_sum + m['entropy']
        steps_done += 1
        state = next_state
        # Stop once the environment signals termination: stepping past `done`
        # without a reset would distil on post-termination states.
        if bool(state.done):
            break

    # max(..., 1) guards the degenerate steps_per_episode == 0 case.
    mean_kl = float(kl_sum / max(steps_done, 1))
    mean_ent = float(ent_sum / max(steps_done, 1))
    episode_mean_kls.append(mean_kl)
    episode_mean_entropy.append(mean_ent)
    if ep % log_every == 0 or ep == episodes - 1:
        print(f'Episode {ep}: mean KL={mean_kl:.4f} | mean entropy={mean_ent:.4f}')


In [None]:
# Plot mean KL and entropy per episode
import matplotlib.pyplot as plt
# Explicit figure/axes interface: both per-episode curves share one axes.
fig, ax = plt.subplots(figsize=(6, 4))
for series, series_label in (
    (episode_mean_kls, 'Mean KL per episode'),
    (episode_mean_entropy, 'Mean entropy per episode'),
):
    ax.plot(series, label=series_label)
ax.set_xlabel('Episode')
ax.set_ylabel('Value')
ax.set_title('Distillation Progress')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()


In [None]:
# After training, evaluate the comparison
from utils import evaluate_policy_comparison

# Create inference functions
# Both policies are built with deterministic=True and then jitted.
teacher_eval_params = (teacher_normalizer, teacher_policy_params)
student_eval_params = (student_normalizer_params, student_policy_params)
teacher_inference_fn = jax.jit(
    teacher._make_policy(teacher_eval_params, deterministic=True)
)
student_inference_fn = jax.jit(
    make_student_policy_inference(student_eval_params, deterministic=True)
)

# Hand both inference functions to the comparison helper, reusing the jitted
# reset/step functions and the environment created in the first cell.
evaluate_policy_comparison(
    env=env,
    teacher_inference_fn=teacher_inference_fn,
    student_inference_fn=student_inference_fn,
    jit_step=jit_step,
    jit_reset=jit_reset,
    env_cfg=default_config(),
    eval_env=env,
)