In [None]:
#!/usr/bin/env python3
"""Training and visualization script for Go1 with height scanner, including student distillation."""
import os
import jax
import jax.numpy as jnp
import numpy as np
import functools
import optax
import flax.linen as nn
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import losses as ppo_losses
from brax.training.acme import running_statistics
import mujoco
from mujoco_playground import wrapper
from mujoco_playground.config import locomotion_params
from custom_env import Joystick, default_config
from mujoco_playground._src.gait import draw_joystick_command
from IPython.display import HTML, display
import mediapy as media
import imageio
import base64

# Set environment variables for GPU usage.
# Append to any XLA_FLAGS the caller already exported instead of overwriting them.
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
# NOTE(review): `mujoco` is already imported above this line. If it selects its GL
# backend at import time, this assignment comes too late to have any effect —
# confirm, and if so move the two os.environ assignments above the imports.
os.environ['MUJOCO_GL'] = 'egl'

# Verify GPU usage
print("JAX Devices:", jax.devices())

# Environment configuration.
# Built BEFORE the environment so the overrides below are the settings the
# environment actually runs with (the same object is read later for the
# joystick command scale).
env_cfg = default_config()
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [0.0, 0.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2 * jnp.pi]

# Environment setup
xml_path = 'custom_env.xml'
env = Joystick(xml_path=xml_path, config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Training configuration
seed = 42
num_envs = 1  # Single environment for simplicity
episode_length = 128
action_repeat = 1
episodes = 100
batch_size = 32  # Adjusted for multiple batches
batches = episode_length // batch_size  # full mini-batches per episode; any remainder is dropped
learning_rate = 1e-5
student_obs_dim = 52  # From state.obs['state']
obs_shape = (student_obs_dim,)  # Teacher policy input; derived so the two cannot drift apart
action_size = env.action_size

# Teacher network setup
# SECURITY: allow_pickle=True unpickles arbitrary Python objects from the file.
# Only load a params.npy that comes from a trusted source.
loaded_params = np.load("params.npy", allow_pickle=True)
normalizer_params = loaded_params[0]
policy_params = loaded_params[1]
value_params = loaded_params[2]
teacher_params = (normalizer_params, policy_params, value_params)
normalize = running_statistics.normalize
ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
# Use the network hyper-parameters from the PPO config when it provides them,
# otherwise fall back to the library defaults.
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
    network_factory = functools.partial(ppo_networks.make_ppo_networks, **ppo_params.network_factory)
ppo_network = network_factory(obs_shape, action_size, preprocess_observations_fn=normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)
# Teacher inference expects flat state vector (52,)
jit_inference_fn = jax.jit(make_policy(teacher_params, deterministic=True))

JAX Devices: [CudaDevice(id=0)]


In [None]:
# Student network definition
class StudentPolicy(nn.Module):
    """Small MLP student: two ReLU hidden layers followed by a linear output head.

    The head emits ``2 * action_size`` values per input row. Callers in this
    notebook read the first ``action_size`` entries as pre-tanh action means;
    the whole vector is regressed onto the teacher policy's logits.
    """
    # Dimensionality of the action space; the output head is twice this wide.
    action_size: int
    # Width of each of the two hidden layers.
    hidden_size: int = 100

    @nn.compact
    def __call__(self, x):
        """Map observations of shape (..., obs_dim) to logits of shape (..., 2 * action_size)."""
        hidden = x
        # Two identical Dense -> ReLU stages.
        for _ in range(2):
            hidden = nn.relu(nn.Dense(features=self.hidden_size)(hidden))
        # Linear head, no activation: raw distribution parameters.
        return nn.Dense(features=2 * self.action_size)(hidden)

# Build the student network, its parameters and the optimizer state.
student_net = StudentPolicy(action_size=action_size)

# A batch of ones is only used to give the parameter initialisation its input shape.
key_student = jax.random.PRNGKey(42)
dummy_input = jnp.ones(shape=(batch_size, student_obs_dim))
student_params = student_net.init(key_student, dummy_input)

# AdamW over the student parameters.
optimizer = optax.adamw(learning_rate=learning_rate)
opt_state = optimizer.init(student_params)

print(f"Student network initialized! Input shape: {dummy_input.shape}, Output shape: {(batch_size, 2 * action_size)}")

# Evaluation function
def evaluate_policy(env, policy_fn, key, steps=episode_length):
    """Roll out ``policy_fn`` for ``steps`` environment steps and return the summed reward.

    Args:
        env: Kept for interface compatibility. The rollout itself goes through the
            module-level ``jit_reset`` / ``jit_step`` functions.
        policy_fn: Callable ``(normalized_obs_batch, rng) -> logits``. It receives a
            batch of one normalized observation; the first ``action_size`` entries
            of row 0 of its output are read as pre-tanh action means.
        key: JAX PRNG key for the reset and the per-step policy keys.
        steps: Number of environment steps to run. ``state.done`` is not checked.

    Returns:
        The sum of ``state.reward`` over all steps, as a Python float.
    """
    # Use a dedicated key for the reset so it is not reused for action sampling.
    key, reset_key = jax.random.split(key)
    state = jit_reset(reset_key)
    total_reward = 0.0
    for _ in range(steps):
        key, act_key = jax.random.split(key)
        # Add batch dimension for the network
        obs_batch = state.obs['state'].reshape(1, -1)
        # Normalize with the teacher's statistics to mirror teacher preprocessing.
        # The call order is normalize(batch, mean_std): data first, statistics second.
        # NOTE(review): assumes normalizer_params holds statistics for the flat
        # 'state' vector (not a dict of observations) — confirm against params.npy.
        obs_batch = running_statistics.normalize(obs_batch, normalizer_params)
        logits = policy_fn(obs_batch, act_key)
        # Deterministic action: tanh of the mean half of the logits.
        actions = jnp.tanh(logits[0, :action_size])
        state = jit_step(state, actions)
        total_reward += state.reward
    return float(total_reward)

2025-10-02 17:58:30.836145: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-10-02 17:58:31.740986: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-10-02 17:58:32.560608: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Student network initialized! Input shape: (32, 52), Output shape: (32, 24)


In [None]:
# Comparison visualization function
def compare_teacher_student_gifs(
    env,
    jit_reset,
    jit_step,
    teacher_policy_fn,
    student_policy_fn,
    student_params,
    episode_length,
    command,
    seed,
    width=640,
    height=480,
    fps=30,
    render_every=2,
):
    """Roll out teacher and student from the same start and show both as GIFs side by side.

    Both rollouts are reset with the same environment key and are forced to follow
    ``command`` at every step. The rendered frames are written to temporary GIF
    files, embedded into an HTML snippet as base64 data URIs, displayed, and the
    files are removed again.

    Args:
        env: Environment used for rendering; also supplies ``action_size`` and
            ``_torso_body_id``.
        jit_reset: Jitted ``env.reset``.
        jit_step: Jitted ``env.step``.
        teacher_policy_fn: Callable ``(obs, rng) -> (action, extras)`` fed the raw
            ``state.obs['state']`` vector.
        student_policy_fn: Callable ``(normalized_obs_batch, rng) -> logits``.
        student_params: Unused here (the student callable already carries its
            parameters); kept so existing callers keep working.
        episode_length: Number of environment steps per rollout.
        command: Joystick command written into ``state.info["command"]`` each step.
        seed: Integer seed for the PRNG key all rollout randomness derives from.
        width: Frame width in pixels.
        height: Frame height in pixels.
        fps: Playback frame rate of the written GIFs.
        render_every: Keep every ``render_every``-th state for rendering.

    Returns:
        None. The comparison is shown through ``IPython.display``.
    """
    scene_option = mujoco.MjvOption()
    scene_option.geomgroup[2] = True
    scene_option.geomgroup[3] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True

    def _command_overlay(state):
        """Build the scene callback that draws the joystick command above the torso."""
        xyz = np.array(state.data.xpos[env._torso_body_id])
        xyz += np.array([0, 0, 0.2])  # lift the marker above the torso
        x_axis = state.data.xmat[env._torso_body_id, 0]
        yaw = -np.arctan2(x_axis[1], x_axis[0])
        return functools.partial(
            draw_joystick_command,
            cmd=state.info["command"],
            xyz=xyz,
            theta=yaw,
            scl=abs(state.info["command"][0]) / env_cfg.command_config.a[0],
        )

    key = jax.random.PRNGKey(seed)
    key_teacher, key_student, key_env = jax.random.split(key, 3)

    # Same reset key on purpose: both policies start from the same initial state.
    state_teacher = jit_reset(key_env)
    state_student = jit_reset(key_env)
    state_teacher.info["command"] = command
    state_student.info["command"] = command

    rollout_teacher = []
    rollout_student = []
    modify_scene_fns_teacher = []
    modify_scene_fns_student = []

    for _ in range(episode_length):
        # Teacher
        act_rng_teacher, key_teacher = jax.random.split(key_teacher)
        # Pass only 'state' observation to teacher inference (matches obs_shape=(52,))
        ctrl_teacher, _ = teacher_policy_fn(state_teacher.obs['state'], act_rng_teacher)
        state_teacher = jit_step(state_teacher, ctrl_teacher)
        state_teacher.info["command"] = command
        rollout_teacher.append(state_teacher)

        # Student
        act_rng_student, key_student = jax.random.split(key_student)
        student_obs = state_student.obs['state'].reshape(1, -1)
        # Normalize student inputs with the teacher's statistics for consistency.
        # The call order is normalize(batch, mean_std): data first, statistics second.
        # NOTE(review): assumes normalizer_params holds statistics for the flat
        # 'state' vector (not a dict of observations) — confirm against params.npy.
        student_obs = running_statistics.normalize(student_obs, normalizer_params)
        student_logits = student_policy_fn(student_obs, act_rng_student)
        # Deterministic action: tanh of the mean half of the logits.
        ctrl_student = jnp.tanh(student_logits[0, :env.action_size])
        state_student = jit_step(state_student, ctrl_student)
        state_student.info["command"] = command
        rollout_student.append(state_student)

        modify_scene_fns_teacher.append(_command_overlay(state_teacher))
        modify_scene_fns_student.append(_command_overlay(state_student))

    # Sub-sample states and their overlays identically so they stay aligned.
    traj_teacher = rollout_teacher[::render_every]
    traj_student = rollout_student[::render_every]
    mod_fns_teacher = modify_scene_fns_teacher[::render_every]
    mod_fns_student = modify_scene_fns_student[::render_every]

    frames_teacher = env.render(
        traj_teacher,
        camera="track",
        scene_option=scene_option,
        width=width,
        height=height,
        modify_scene_fns=mod_fns_teacher,
    )
    frames_student = env.render(
        traj_student,
        camera="track",
        scene_option=scene_option,
        width=width,
        height=height,
        modify_scene_fns=mod_fns_student,
    )

    teacher_gif_path = "teacher_policy.gif"
    student_gif_path = "student_policy.gif"

    def gif_to_base64(gif_path):
        """Read a GIF file and return it as a base64 data URI for an <img> tag."""
        with open(gif_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("ascii")
        return f"data:image/gif;base64,{encoded}"

    try:
        # mediapy takes the output path first and the frames second. The codec is
        # set explicitly because the default video codec is not GIF.
        media.write_video(teacher_gif_path, frames_teacher, fps=fps, codec="gif")
        media.write_video(student_gif_path, frames_student, fps=fps, codec="gif")

        teacher_base64 = gif_to_base64(teacher_gif_path)
        student_base64 = gif_to_base64(student_gif_path)
        html = f"""
    <div style="display: flex; justify-content: center;">
        <div style="margin-right: 10px; text-align: center;">
            <h3>Teacher Policy</h3>
            <img src="{teacher_base64}" width="{width}" height="{height}"/>
        </div>
        <div style="text-align: center;">
            <h3>Student Policy</h3>
            <img src="{student_base64}" width="{width}" height="{height}"/>
        </div>
    </div>
    """
        display(HTML(html))
    finally:
        # Remove the temporary files even if writing or displaying failed.
        for gif_path in (teacher_gif_path, student_gif_path):
            if os.path.exists(gif_path):
                os.remove(gif_path)

In [None]:
# Function to get teacher logits
@jax.jit
def get_teacher_logits(observations):
    """Return the teacher policy network's raw output for ``observations``.

    Only the first two entries of ``teacher_params`` are passed to the policy
    network; the third (value-network) entry is not used.
    """
    preprocessor_params, policy_weights = teacher_params[:2]
    return ppo_network.policy_network.apply(
        preprocessor_params, policy_weights, observations
    )

# Training function with MSE loss
@jax.jit
def train_step(params, opt_state, inputs, targets):
    """Run one optimizer update of the student on a single mini-batch.

    Args:
        params: Current student parameters.
        opt_state: Current optimizer state.
        inputs: Batch of student inputs.
        targets: Batch of regression targets with the same shape as the student output.

    Returns:
        ``(new_params, new_opt_state, loss)`` where ``loss`` is the mean squared
        error between student output and targets, evaluated before the update.
    """
    def mean_squared_error(candidate_params):
        residual = student_net.apply(candidate_params, inputs) - targets
        return jnp.mean(residual ** 2)

    loss, grads = jax.value_and_grad(mean_squared_error)(params)
    # `params` is passed to update() as well, which weight-decay optimizers need.
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss

# Training loop
# Each episode: roll out the deterministic teacher under one random command,
# record (normalized observation, teacher logits) pairs, fit the student on them
# in consecutive mini-batches, then evaluate the student in the environment.
training_losses = []

assert batches > 0, "episode_length must be at least batch_size"

# Jit the student forward pass once. Parameters are passed as an argument, so
# updated parameters do not trigger a recompilation every episode.
student_apply = jax.jit(student_net.apply)


def _student_forward(params, obs, rng):
    """Student logits for already-normalized observations; ``rng`` is accepted but unused."""
    return student_apply(params, obs)


key = jax.random.PRNGKey(seed)
for episode in range(episodes):
    print(f"\nEpisode {episode + 1}/{episodes}")

    # One independent key per purpose; a key is never both consumed and split.
    key, env_key, cmd_key, act_key, eval_key = jax.random.split(key, 5)
    state = jit_reset(env_key)

    # Sample a command uniformly in [0, a_i) per component.
    raw_command = jax.random.uniform(cmd_key, shape=(3,), minval=0.0, maxval=1.0)
    command = jnp.array([
        raw_command[0] * env_cfg.command_config.a[0],
        raw_command[1] * env_cfg.command_config.a[1],
        raw_command[2] * env_cfg.command_config.a[2]
    ])

    state.info["command"] = command
    # Collected per step and stacked once afterwards, instead of copying a full
    # buffer on every step with .at[].set().
    episode_inputs = []
    episode_targets = []
    rollout = []
    modify_scene_fns = []

    for step in range(episode_length):
        act_rng, act_key = jax.random.split(act_key)
        state_flat = state.obs['state']
        # Teacher deterministic action using 'state' only
        ctrl, _ = jit_inference_fn(state_flat, act_rng)
        # Teacher logits target using 'state' only to match policy net input
        logits = get_teacher_logits(state_flat)
        # Normalize student inputs like the teacher does.
        # The call order is normalize(batch, mean_std): data first, statistics second.
        # NOTE(review): assumes normalizer_params holds statistics for the flat
        # 'state' vector (not a dict of observations) — confirm against params.npy.
        state_flat_norm = running_statistics.normalize(state_flat, normalizer_params)
        episode_inputs.append(state_flat_norm)
        episode_targets.append(logits)
        state = jit_step(state, ctrl)
        # Pin the command so the whole episode follows the sampled one.
        state.info["command"] = command
        rollout.append(state)
        xyz = np.array(state.data.xpos[env._torso_body_id])
        xyz += np.array([0, 0, 0.2])
        x_axis = state.data.xmat[env._torso_body_id, 0]
        yaw = -np.arctan2(x_axis[1], x_axis[0])
        modify_scene_fns.append(
            functools.partial(
                draw_joystick_command,
                cmd=state.info["command"],
                xyz=xyz,
                theta=yaw,
                scl=abs(state.info["command"][0]) / env_cfg.command_config.a[0],
            )
        )

    student_inputs = jnp.stack(episode_inputs)    # (episode_length, student_obs_dim)
    student_targets = jnp.stack(episode_targets)  # (episode_length, 2 * action_size)

    total_loss = 0.0
    for batch_idx in range(batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        batch_inputs = student_inputs[start_idx:end_idx]
        batch_targets = student_targets[start_idx:end_idx]
        student_params, opt_state, loss = train_step(student_params, opt_state, batch_inputs, batch_targets)
        total_loss += loss
    # Store a plain Python float rather than a device array.
    avg_loss = float(total_loss) / batches
    training_losses.append(avg_loss)
    print(f"Training Loss: {avg_loss}")
    # Student inference function: expects normalized input. Bound to the
    # parameters as they are after this episode's updates.
    student_policy_fn = functools.partial(_student_forward, student_params)
    # Evaluate with normalization + tanh(mean)
    eval_reward = evaluate_policy(env, student_policy_fn, eval_key)
    print(f"Student Eval Reward: {eval_reward}")


Episode 1/100


2025-10-02 17:58:37.908916: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


In [None]:
# Show the teacher next to the student as trained so far. `command` and `episode`
# are the values left behind by the last iteration of the training loop.
compare_teacher_student_gifs(
    env,
    jit_reset,
    jit_step,
    jit_inference_fn,
    student_policy_fn,
    student_params,
    episode_length,
    command,
    seed + episode,
    width=640,
    height=480,
    # Every second state is rendered, so play back at half the simulation rate.
    fps=int(1.0 / env.dt / 2),
    render_every=2,
)