# Teacher-Student Training Script 

(In progress)

In [1]:
#!/usr/bin/env python3
"""Visualization script for Go1 with height scanner."""
import os
import subprocess
# Tell XLA to use Triton GEMM.
# os.environ persists for the lifetime of the kernel, so only append the flag
# when it is not already present; re-running this cell then stays idempotent.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += ' ' + _TRITON_GEMM_FLAG
os.environ['XLA_FLAGS'] = xla_flags
# Select the GL backend MuJoCo should use via the MUJOCO_GL variable.
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jp
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses

import mujoco
from mujoco_playground import wrapper
from mujoco_playground import registry
from mujoco_playground.config import locomotion_params
from custom_env import Joystick, default_config

from datetime import datetime
import mediapy as media
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
import functools
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import Image as IPyimage, display, HTML, clear_output

from utils import render_video_during_training, evaluate_policy

# Rendering options for the MuJoCo scene.
scene_option = mujoco.MjvOption()
for geom_group, visible in (
    (2, True),   # visual geoms
    (3, False),  # collision geoms (hidden)
    (5, True),   # annotated by the author as sites / height scanner visualization
):
    scene_option.geomgroup[geom_group] = visible
for vis_flag in (
    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,  # contact points
    mujoco.mjtVisFlag.mjVIS_RANGEFINDER,
):
    scene_option.flags[vis_flag] = True
print("Creating Visualization...")

# Build the environment from the custom scene ('custom_env_debug_wall.xml' is the debug alternative).
xml_path = 'custom_env.xml'
env = Joystick(xml_path=xml_path, config=default_config())

# JIT-compiled wrappers around the environment entry points.
jit_reset, jit_step, jit_terrain_height = (
    jax.jit(env_fn)
    for env_fn in (env.reset, env.step, env._get_torso_terrain_height)
)

# Seeded PRNG keys for the environment, evaluation and network initialisation.
seed = 1234
# NOTE(review): an empty shape is passed to jax.random.split below; presumably this
# yields a single unbatched key for one environment — confirm against the installed JAX.
num_envs = ()
key, key_env, eval_key, key_policy, key_value = jax.random.split(
    jax.random.PRNGKey(seed), 5
)
key_envs = jax.random.split(key_env, num_envs)
env_state = jit_reset(key_envs)

Creating Visualization...


2025-10-05 15:54:14.099551: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


# Teacher Policy

- pretrained inside train.ipynb
- we want to load the parameters

- Inputs: privileged_state with heightmap
- Output: action

### Load parameters for pre-trained teacher


In [2]:
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.acme import running_statistics


# Needs to match training
obs_shape = (96,)
action_size = env.action_size

# Load the pre-trained teacher parameters exactly once (the file was previously
# read and unpickled twice under two different names).
# NOTE: allow_pickle=True unpickles arbitrary objects — only load trusted files.
params = np.load("params.npy", allow_pickle=True)
loaded_params = params  # alias kept so existing references keep working

# Observation normalisation: element 0 of the loaded parameters is the normalizer state.
normalizer_params = params[0]
normalize = running_statistics.normalize

# Build the PPO networks with the same factory arguments as the training config.
ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory
    )

ppo_network = network_factory(
    obs_shape, action_size, preprocess_observations_fn=normalize
)

# Freshly initialised parameters (not the loaded ones); kept for reference.
init_params = ppo_losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)

# Create the deterministic teacher inference function from the loaded parameters.
make_policy = ppo_networks.make_inference_fn(ppo_network)
jit_inference_fn = jax.jit(make_policy(params, deterministic=True))


# Training

- The teacher is run for one episode at a time with privileged observations
- Non-privileged observations are saved as student inputs
- The teacher's action distribution parameters (logits) are saved as student targets
- The student network is trained on the data collected during each episode


In [3]:
from custom_ppo_train import _maybe_wrap_env

# --- Reproducibility ---
seed = 42

# --- Student network dimensions ---
# action_size is hard-coded here and replaces the value read from env.action_size earlier.
action_size, hidden_size = 12, 64

# --- Teacher rollouts used as training data ---
episodes = 1000  # previously 1000
envs_per_episode, action_repeat = 1, 1
episode_length = 1024  # previously 1024

# --- Optimisation ---
learning_rate = 1e-4  # previously 1e-5

# --- Debugging: render the teacher rollout after every episode ---
teacher_visualisation = False

# Student Policy

Experiment 1: training a newly initialized student network. Note: the `StudentPolicy` defined below is a feedforward MLP, not an LSTM / recurrent network.

In [4]:
# Policy Definition
import jax.numpy as jnp
import jax
import flax.linen as nn
import optax

class StudentPolicy(nn.Module):
    """Two-layer feedforward student that imitates the teacher's action distribution.

    Attributes:
        action_size: Number of action dimensions; the output width is twice this.
        hidden_size: Number of units in the single hidden layer.
    """
    action_size: int
    hidden_size: int

    @nn.compact
    def __call__(self, x):
        """Map observations of shape (batch_size, obs_dim) to distribution logits."""
        # Hidden layer followed by a ReLU non-linearity.
        hidden = nn.relu(nn.Dense(self.hidden_size)(x))
        # Output layer: 2 * action_size values (per the author: actions + log_stds).
        return nn.Dense(features=2 * self.action_size)(hidden)

# Initialize student network
student_obs_dim = 52
batch_size = 128
episode_length = 1024
# Number of full mini-batches per episode; steps that do not fill a batch are dropped.
batches = episode_length // batch_size

# Create student network
student_net = StudentPolicy(action_size=action_size, hidden_size=hidden_size)

# Initialize with dummy data of the shape used during training
dummy_input = jnp.ones((batch_size, student_obs_dim))
key_student = jax.random.PRNGKey(42)
student_params = student_net.init(key_student, dummy_input)

# The network emits 2 * action_size values per sample; derive the expected
# shape from the configuration instead of hard-coding it.
expected_output_shape = (batch_size, 2 * action_size)

print("Student network initialized!")
print(f"Input shape: {dummy_input.shape}")
print(f"Expected output shape: (batch_size, {2 * action_size}) for logits")

# Test student network and fail loudly if the output shape is not what the
# training loop expects.
test_output = student_net.apply(student_params, dummy_input)
print(f"Test output shape: {test_output.shape}")
assert test_output.shape == expected_output_shape, (
    f"student output shape {test_output.shape} != expected {expected_output_shape}"
)

2025-10-05 15:55:15.061710: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-10-05 15:55:15.986165: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Student network initialized!
Input shape: (128, 52)
Expected output shape: (batch_size, 24) for logits
Test output shape: (128, 24)


# Helper Functions for Teacher-Student Training
Define the functions needed to extract logits from the teacher network.

In [5]:
# Teacher forward pass that returns the raw policy-network output
@jax.jit
def get_teacher_logits(params, observations):
    """Run the teacher policy network on `observations` and return its output.

    Args:
        params: Indexable teacher parameters; element 0 (normalizer) and
            element 1 (policy) are forwarded to the policy network.
        observations: Observations passed unchanged to the policy network.

    Returns:
        Whatever `ppo_network.policy_network.apply` returns for these inputs.
    """
    normalizer_params, policy_params = params[0], params[1]
    return ppo_network.policy_network.apply(
        normalizer_params, policy_params, observations
    )

# Adam optimiser and its state for the student parameters
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(student_params)


# Single optimisation step for the student
@jax.jit
def train_step(params, opt_state, inputs, targets):
    """Apply one optimiser update that minimises the MSE between student output and targets.

    Args:
        params: Current student parameters.
        opt_state: Current optimiser state.
        inputs: Batch of student observations.
        targets: Batch of target values with the same shape as the student output.

    Returns:
        Tuple of (updated params, updated optimiser state, scalar loss).
    """
    def mean_squared_error(candidate_params):
        residual = student_net.apply(candidate_params, inputs) - targets
        return jnp.mean(residual ** 2)

    loss, grads = jax.value_and_grad(mean_squared_error)(params)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss


## Training Loop



In [6]:
from mujoco_playground._src.gait import draw_joystick_command
# Configuration used in this cell for command scaling.
# NOTE(review): `env` was constructed earlier from its own default_config(); in this
# cell env_cfg is only read for command_config.a — confirm the perturbation
# settings below are meant to reach the environment.
env_cfg = default_config()
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [0.0, 0.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi]

# Jit the teacher's policy-network forward pass once instead of running it
# op-by-op on every environment step. The normalizer/policy parameters are
# indexed out here and passed as regular arguments.
teacher_param_subset = (params[0], params[1])  # normalizer and policy params
jit_teacher_logits = jax.jit(ppo_network.policy_network.apply)

# The joystick command is identical for every episode, so build it once.
raw_command = jp.array([0.5, 0.0, 0.0])
command = jp.array([
    raw_command[0] * env_cfg.command_config.a[0],
    raw_command[1] * env_cfg.command_config.a[1],
    raw_command[2] * env_cfg.command_config.a[2]
])

# Loop over episodes
training_losses = []
test_losses = []
for episode in range(episodes):
    print(f"\n Episode {episode + 1}/{episodes} ")

    # Each episode's rollout is seeded by its index.
    rng = jax.random.PRNGKey(episode)
    state = jit_reset(rng)
    state.info["command"] = command

    # Visualisation storing (only filled when teacher_visualisation is enabled)
    rollout = []
    modify_scene_fns = []

    # Training data storing
    student_inputs = []
    student_targets = []

    for step in range(episode_length):
        # Teacher action used to drive the environment
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)

        # Teacher logits (distribution parameters) for the same observation
        logits = jit_teacher_logits(*teacher_param_subset, state.obs)

        # Store data for student training
        student_inputs.append(state.obs['state'])  # Non-privileged observations
        student_targets.append(logits)  # Teacher's action distribution logits

        # Step environment and re-apply the fixed command
        state = jit_step(state, ctrl)
        state.info["command"] = command

        # Only keep per-step states and copy torso data to the host when a video
        # will actually be rendered; otherwise this is wasted memory and time.
        if teacher_visualisation:
            rollout.append(state)
            xyz = np.array(state.data.xpos[env._torso_body_id])
            xyz += np.array([0, 0, 0.2])
            x_axis = state.data.xmat[env._torso_body_id, 0]
            yaw = -np.arctan2(x_axis[1], x_axis[0])
            modify_scene_fns.append(
                functools.partial(
                    draw_joystick_command,
                    cmd=state.info["command"],
                    xyz=xyz,
                    theta=yaw,
                    scl=abs(state.info["command"][0]) / env_cfg.command_config.a[0],
                )
            )

    # Prepare training data (expected shapes per the author: (1024, 52) and (1024, 24))
    student_inputs_array = jnp.stack(student_inputs)
    student_targets_array = jnp.stack(student_targets)

    # Train student on consecutive mini-batches of this episode
    total_loss = 0
    for batch_idx in range(batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        batch_inputs = student_inputs_array[start_idx:end_idx]
        batch_targets = student_targets_array[start_idx:end_idx]

        student_params, opt_state, loss = train_step(
            student_params, opt_state, batch_inputs, batch_targets
        )
        total_loss += loss
    # Store a plain Python float so the history is cheap to plot and serialise.
    training_losses.append(float(total_loss / batches))

    # Optional visualization
    if teacher_visualisation:
        render_every = 2
        fps = 1.0 / env.dt / render_every
        traj = rollout[::render_every]
        mod_fns = modify_scene_fns[::render_every]

        scene_option = mujoco.MjvOption()
        scene_option.geomgroup[2] = True
        scene_option.geomgroup[3] = False
        scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
        scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
        scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

        frames = env.render(
            traj,
            camera="track",
            scene_option=scene_option,
            width=640,
            height=480,
            modify_scene_fns=mod_fns,
        )
        media.show_video(frames, fps=fps)




 Episode 1/1000 


2025-10-05 11:23:22.641427: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-10-05 11:23:22.641459: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-10-05 11:23:22.641468: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.



 Episode 2/1000 

 Episode 3/1000 

 Episode 4/1000 

 Episode 5/1000 

 Episode 6/1000 

 Episode 7/1000 

 Episode 8/1000 

 Episode 9/1000 

 Episode 10/1000 

 Episode 11/1000 

 Episode 12/1000 

 Episode 13/1000 

 Episode 14/1000 

 Episode 15/1000 

 Episode 16/1000 

 Episode 17/1000 

 Episode 18/1000 

 Episode 19/1000 

 Episode 20/1000 

 Episode 21/1000 

 Episode 22/1000 

 Episode 23/1000 

 Episode 24/1000 

 Episode 25/1000 

 Episode 26/1000 

 Episode 27/1000 

 Episode 28/1000 

 Episode 29/1000 

 Episode 30/1000 

 Episode 31/1000 

 Episode 32/1000 


### Plot training metrics


In [None]:
import matplotlib.pyplot as plt

# Plot the mean MSE loss recorded for each training episode.
plt.figure(figsize=(12,6))
# Attach the label to the line itself: passing the bare string "Training Loss" to
# plt.legend() treats it as a sequence of characters and labels the line "T".
plt.plot(range(1, len(training_losses) + 1), training_losses, label="Training Loss")
plt.xlabel("Episode")
plt.ylabel("MSE Loss")
plt.legend()
plt.show()

NameError: name 'training_losses' is not defined

<Figure size 1200x600 with 0 Axes>

## Save Student Parameters


In [None]:
# Persist the trained student parameters to disk with pickle
import pickle

student_params_path = 'student_params_MLP_MSE_1000episodes.pkl'
with open(student_params_path, 'wb') as params_file:
    pickle.dump(student_params, params_file)
print(f"Student saved to {student_params_path}")

Student RNN saved to student_params.pkl


### (OPTIONAL): Load Trained Student Parameters


In [6]:
import pickle
# Restore previously saved student parameters.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
student_params_path = 'student_params_MLP_MSE_1000episodes.pkl'
with open(student_params_path, 'rb') as params_file:
    student_params = pickle.load(params_file)
print(f"Student loaded from {student_params_path}")

Student loaded from student_params_MLP_MSE_1000episodes.pkl


## Evaluate Student Policy

### Evaluation Environment Config


In [7]:
# Env config for student evaluation (a fresh copy of the default config).
# NOTE(review): in the visible cells this config is only read for
# command_config.a; `env` itself was built from a separate default_config() —
# confirm whether the perturbation settings are meant to reach the environment.
student_env_cfg = default_config()
student_env_cfg.pert_config.enable = True
student_env_cfg.pert_config.velocity_kick = [0.0, 0.0]
student_env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
student_env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi] # Max command values

# Base seed for evaluation
seed = 42

# Number of evaluation rollouts and the number of environment steps in each.
# Note: this overwrites the training-time episode_length defined earlier.
num_episodes = 5
episode_length = 500 # previously 500


In [9]:
# Student eval
from mujoco_playground._src.gait import draw_joystick_command


@jax.jit
def student_act(student_params, obs):
    """Return the student's deterministic action for one non-privileged observation.

    The first `student_obs_dim` entries of `obs` are fed to the student network
    as a batch of one; the resulting logits are converted to an action with the
    mode of the teacher's parametric action distribution.
    """
    obs_batch = obs[:student_obs_dim].reshape(1, -1)
    logits = student_net.apply(student_params, obs_batch).squeeze(0)
    return ppo_network.parametric_action_distribution.mode(logits)


# The joystick command is identical for every episode (hard coded for testing).
raw_command = jp.array([0.5, 0.0, 0.0])
command = jp.array([
    raw_command[0] * student_env_cfg.command_config.a[0],
    raw_command[1] * student_env_cfg.command_config.a[1],
    raw_command[2] * student_env_cfg.command_config.a[2]
])

for episode in range(num_episodes):
    # Each rollout is seeded by the episode index.
    # NOTE(review): the training cell seeds its rollouts the same way, so these
    # resets coincide with the first training episodes — confirm this is intended.
    rng = jax.random.PRNGKey(episode)
    state = jit_reset(rng)
    state.info["command"] = command

    # Visualisation storing
    rollout = []
    modify_scene_fns = []

    for step in range(episode_length):
        # Feed non-privileged observations to the student (jitted once above
        # instead of running the network op-by-op on every step).
        ctrl = student_act(student_params, state.obs['state'])

        # Take step and re-apply the fixed command
        state = jit_step(state, ctrl)
        state.info["command"] = command

        # Collect per-step data for the joystick-command overlay
        rollout.append(state)
        xyz = np.array(state.data.xpos[env._torso_body_id])
        xyz += np.array([0, 0, 0.2])
        x_axis = state.data.xmat[env._torso_body_id, 0]
        yaw = -np.arctan2(x_axis[1], x_axis[0])
        modify_scene_fns.append(
            functools.partial(
                draw_joystick_command,
                cmd=state.info["command"],
                xyz=xyz,
                theta=yaw,
                scl=abs(state.info["command"][0]) / student_env_cfg.command_config.a[0],
            )
        )

    # Render every second frame of the rollout as a video
    render_every = 2
    fps = 1.0 / env.dt / render_every
    traj = rollout[::render_every]
    mod_fns = modify_scene_fns[::render_every]

    scene_option = mujoco.MjvOption()
    scene_option.geomgroup[2] = True
    scene_option.geomgroup[3] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

    frames = env.render(
        traj,
        camera="track",
        scene_option=scene_option,
        width=640,
        height=480,
        modify_scene_fns=mod_fns,
    )
    media.show_video(frames, fps=fps)

100%|██████████| 250/250 [00:46<00:00,  5.34it/s]


0
This browser does not support the video tag.


100%|██████████| 250/250 [00:50<00:00,  4.93it/s]


0
This browser does not support the video tag.


100%|██████████| 250/250 [00:50<00:00,  4.92it/s]


0
This browser does not support the video tag.


100%|██████████| 250/250 [00:51<00:00,  4.90it/s]


0
This browser does not support the video tag.


100%|██████████| 250/250 [00:51<00:00,  4.82it/s]


0
This browser does not support the video tag.
