In [1]:
# ============================= Cell 1: Imports =============================
import sys
import os

# Make the project root (the parent of the notebook's working directory, e.g. EAI2025_RL_FINAL)
# importable. Guarded so re-running this cell does not keep appending duplicate entries.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# These environment variables must be set BEFORE the libraries that read them are first imported:
#   - XLA_FLAGS is picked up when JAX initialises its backend;
#   - MUJOCO_GL='egl' selects headless offscreen rendering for MuJoCo.
# NOTE: if jax/mujoco were already imported in this kernel, changing them here has no effect.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    # Append only once so repeated runs do not grow XLA_FLAGS with duplicate flags.
    xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ['XLA_FLAGS'] = xla_flags
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jnp
import numpy as np
import functools
import flax.linen as nn
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.acme import running_statistics
import mujoco
from mujoco_playground.config import locomotion_params
from environments.custom_env import Joystick, default_config
from mujoco_playground._src.gait import draw_joystick_command
from IPython.display import HTML, display
import imageio

print("✓ All imports successful")

✓ All imports successful


In [2]:
# ============================= Cell 2: Variables + Environment Setup =============================
# Policy selection
STUDENT_DATA_COLLECTION = "dagger"  # "bc" or "dagger"
STUDENT_LOSS_FUNCTION = "mse"      # "mse" or "kl"
STUDENT_OBSERVATION_KEY = "student_state"  # "student_state" or "state" or "privileged_state"

# Evaluation settings
EVAL_SEED = 42
EVAL_STEPS = 450
EVAL_COMMAND = jnp.array([1.0, 0.5, 0.5])  # [x_vel, y_vel, yaw_vel]

# Rendering settings
RENDER_EVERY = 4  # keep one simulation state out of every RENDER_EVERY for the video
VIDEO_WIDTH = 640
VIDEO_HEIGHT = 480
# FPS is derived from env.dt further below, once the environment exists.

# Output settings
experiment_name = f"student_{STUDENT_DATA_COLLECTION}_{STUDENT_LOSS_FUNCTION}_{STUDENT_OBSERVATION_KEY}"
results_path = os.path.join(parent_dir, 'results', experiment_name)
output_gif_path = os.path.join(results_path, 'teacher_vs_student_comparison.gif')
os.makedirs(results_path, exist_ok=True)

print(f"Experiment: {experiment_name}")
print(f"Output will be saved to: {output_gif_path}")
print(f"Evaluation command: x={EVAL_COMMAND[0]}, y={EVAL_COMMAND[1]}, yaw={EVAL_COMMAND[2]}")

# ============================= ENVIRONMENT SETUP =============================
# Anchor the XML path at the project root, the same way `results_path` is anchored.
xml_path = os.path.join(parent_dir, 'environments', 'custom_env.xml')
env_cfg = default_config()
# Perturbations stay enabled, but velocity_kick is [0.0, 0.0]. Presumably that is the kick-magnitude
# range, so no real push is applied and the teacher/student comparison stays clean.
# NOTE(review): if the intent is to turn the perturbation code path off entirely, use enable = False.
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [0.0, 0.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2 * jnp.pi]

env = Joystick(xml_path=xml_path, config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Discover observation/action shapes from a throwaway reset.
_dummy_key = jax.random.PRNGKey(0)
dummy_state = jit_reset(_dummy_key)
if STUDENT_OBSERVATION_KEY not in dummy_state.obs:
    # Fail early with the valid choices instead of a bare KeyError deep in a later cell.
    raise KeyError(
        f"STUDENT_OBSERVATION_KEY={STUDENT_OBSERVATION_KEY!r} is not an observation key; "
        f"available keys: {list(dummy_state.obs.keys())}"
    )
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, dummy_state.obs)
action_size = env.action_size
student_obs_dim = int(dummy_state.obs[STUDENT_OBSERVATION_KEY].shape[0])

# Video frame rate: one rendered frame per RENDER_EVERY environment steps of length env.dt.
FPS = int((1.0 / env.dt) / RENDER_EVERY)

print(f"\n✓ Environment initialized")
print(f"  Action size: {action_size}")
print(f"  Student observation dim: {student_obs_dim}")
print(f"  Available observation keys: {list(dummy_state.obs.keys())}")



Experiment: student_dagger_mse_student_state
Output will be saved to: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/teacher_vs_student_comparison.gif
Evaluation command: x=1.0, y=0.5, yaw=0.5

✓ Environment initialized
  Action size: 12
  Student observation dim: 48
  Available observation keys: ['privileged_state', 'state', 'student_state']


In [3]:
# ============================= LOAD TEACHER POLICY =============================
teacher_params_path = "../parameters/params_with_height_and_knee.npy"
# The checkpoint holds a pickled (normalizer, policy, value) triple. allow_pickle can execute
# arbitrary code while loading, so only point this at checkpoints you produced yourself.
_loaded = np.load(teacher_params_path, allow_pickle=True)
# A 0-d object array wraps the whole triple; otherwise the three parts are stored element-wise.
_triple = _loaded.item() if getattr(_loaded, 'ndim', 1) == 0 else tuple(_loaded.tolist())
normalizer_params, policy_params, value_params = _triple
teacher_params = (normalizer_params, policy_params, value_params)

# Rebuild the teacher network from the Go1JoystickRoughTerrain PPO config
# (assumed to match the settings the teacher was trained with).
normalize = running_statistics.normalize
ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
network_factory = (
    functools.partial(ppo_networks.make_ppo_networks, **ppo_params.network_factory)
    if hasattr(ppo_params, 'network_factory')
    else ppo_networks.make_ppo_networks
)
ppo_network = network_factory(obs_shape, action_size, preprocess_observations_fn=normalize)
make_policy = ppo_networks.make_inference_fn(ppo_network)
# deterministic=True asks brax for the non-sampled action; the policy still takes (obs, rng).
teacher_policy_fn = jax.jit(make_policy(teacher_params, deterministic=True))

print(f"✓ Teacher policy loaded from: {teacher_params_path}")

# ============================= LOAD STUDENT POLICY =============================
# Student network definition — must stay identical to the training notebook, because the saved
# parameter tree is keyed by flax's auto-generated layer names (Dense_0, Dense_1, Dense_2).
class StudentPolicy(nn.Module):
    """Two-hidden-layer ReLU MLP that outputs 2 * action_size values.

    Downstream code treats the first action_size outputs as the action mean (mu).
    """
    action_size: int
    hidden_size: int = 256

    @nn.compact
    def __call__(self, x):
        # Hidden widths are hidden_size then hidden_size // 2. The Dense layers are created in this
        # order, so flax names them Dense_0 and Dense_1, and the output head becomes Dense_2.
        h = x
        for width in (self.hidden_size, self.hidden_size // 2):
            h = nn.relu(nn.Dense(features=width)(h))
        return nn.Dense(features=2 * self.action_size)(h)

# Build the (parameter-free) student module; weights come from disk below.
student_net = StudentPolicy(action_size=action_size)

# Locate the parameters saved by the training notebook for this experiment configuration.
student_params_filename = f"student_params_{STUDENT_DATA_COLLECTION}_{STUDENT_LOSS_FUNCTION}.npy"
student_params_path = os.path.join(results_path, student_params_filename)

if not os.path.exists(student_params_path):
    raise FileNotFoundError(
        f"Student parameters not found at: {student_params_path}\n"
        f"Please train the student policy first using training/teacher_student_MLP.ipynb"
    )

# The .npy file stores a pickled parameter dict; convert every leaf to a JAX array.
loaded_params = np.load(student_params_path, allow_pickle=True).item()
student_params = jax.tree.map(jnp.array, loaded_params)


def _student_forward(obs, rng):
    """Raw student outputs for `obs`; `rng` is accepted only to mirror the teacher's (obs, rng) call."""
    del rng  # the student MLP is deterministic
    return student_net.apply(student_params, obs)


student_policy_fn = jax.jit(_student_forward)

print(f"✓ Student policy loaded from: {student_params_path}")
print(f"\n✓ Both policies ready for evaluation")

✓ Teacher policy loaded from: ../parameters/params_with_height_and_knee.npy
✓ Student policy loaded from: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/student_params_dagger_mse.npy

✓ Both policies ready for evaluation


# SINGLE COMMAND EVALUATION
# ============================= Cell 4: Side-by-Side Evaluation and GIF Generation =============================
# Runs teacher and student on the same fixed command for EVAL_STEPS and saves a side-by-side GIF.
# The multi-command cell below writes to `output_gif_path`; this cell uses its own file so that
# "Run All" does not silently overwrite the single-command result.
single_command_gif_path = os.path.join(results_path, 'teacher_vs_student_single_command.gif')


def render_rollout(env, traj, mod_fns, width=VIDEO_WIDTH, height=VIDEO_HEIGHT):
    """Render a list of env states to uint8 RGB frames using the "track" camera.

    Defined here because this cell runs before the multi-command cell that also defines it;
    without this, a fresh kernel raises NameError at the render call below.
    """
    scene_option = mujoco.MjvOption()
    scene_option.geomgroup[2] = True
    scene_option.geomgroup[3] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True
    frames = env.render(
        traj,
        camera="track",
        scene_option=scene_option,
        width=width,
        height=height,
        modify_scene_fns=mod_fns,
    )
    return [np.asarray(f, dtype=np.uint8) for f in frames]


def _joystick_overlay(st):
    """Return a scene modifier that draws st's command arrow 0.2 above the torso, rotated by its yaw."""
    xyz = np.array(st.data.xpos[env._torso_body_id])
    xyz += np.array([0, 0, 0.2])
    x_axis = st.data.xmat[env._torso_body_id, 0]
    yaw = -np.arctan2(x_axis[1], x_axis[0])
    return functools.partial(
        draw_joystick_command,
        cmd=st.info["command"],
        xyz=xyz,
        theta=yaw,
        scl=abs(st.info["command"][0]) / env_cfg.command_config.a[0],
    )


# Roll out both policies
key = jax.random.PRNGKey(EVAL_SEED)
key_teacher, key_student, key_env = jax.random.split(key, 3)

# Same reset key for both, so teacher and student start from the same initial state.
state_teacher = jit_reset(key_env)
state_student = jit_reset(key_env)
# NOTE(review): overriding info["command"] after reset presumably does not rebuild state.obs, so the
# first action may be computed from an observation built with the reset-time command — confirm.
state_teacher.info["command"] = EVAL_COMMAND
state_student.info["command"] = EVAL_COMMAND

rollout_teacher = []
rollout_student = []
modify_scene_fns_teacher = []
modify_scene_fns_student = []

for _ in range(EVAL_STEPS):
    # Teacher step (expects full observation tree)
    act_rng_teacher, key_teacher = jax.random.split(key_teacher)
    ctrl_teacher, _ = teacher_policy_fn(state_teacher.obs, act_rng_teacher)
    state_teacher = jit_step(state_teacher, ctrl_teacher)
    state_teacher.info["command"] = EVAL_COMMAND  # keep the command fixed for the whole rollout
    rollout_teacher.append(state_teacher)

    # Student step: raw outputs -> tanh(mu), where mu is the first action_size outputs
    act_rng_student, key_student = jax.random.split(key_student)
    student_obs = state_student.obs[STUDENT_OBSERVATION_KEY].reshape(1, -1)
    student_logits = student_policy_fn(student_obs, act_rng_student)
    mu = student_logits[0, :action_size]
    ctrl_student = jnp.tanh(mu)
    state_student = jit_step(state_student, ctrl_student)
    state_student.info["command"] = EVAL_COMMAND
    rollout_student.append(state_student)

    # Visual command overlays
    modify_scene_fns_teacher.append(_joystick_overlay(state_teacher))
    modify_scene_fns_student.append(_joystick_overlay(state_student))

# Subsample for rendering
traj_teacher = rollout_teacher[::RENDER_EVERY]
traj_student = rollout_student[::RENDER_EVERY]
mod_fns_teacher = modify_scene_fns_teacher[::RENDER_EVERY]
mod_fns_student = modify_scene_fns_student[::RENDER_EVERY]

# Render frames
frames_teacher = render_rollout(env, traj_teacher, mod_fns_teacher)
frames_student = render_rollout(env, traj_student, mod_fns_student)

# Teacher on the left, student on the right; zip stops at the shorter frame list.
combined_frames = [
    np.concatenate([t, s], axis=1)
    for t, s in zip(frames_teacher, frames_student)
]
if not combined_frames:
    raise RuntimeError("No frames were rendered for the single-command comparison.")

# FPS (Cell 2) is derived from env.dt and RENDER_EVERY.
imageio.mimsave(single_command_gif_path, combined_frames, fps=FPS)
print(f"✓ Saved comparison GIF to: {single_command_gif_path}")

In [4]:
# ============================= EXTRA: Sequence of commands → single concatenated GIF =============================
# Runs each (x, y, yaw) command below sequentially, renders teacher (left) and student (right) for each,
# and concatenates all runs into a single GIF saved at `output_gif_path`.
import base64
from PIL import Image, ImageDraw, ImageFont

# Run i uses the command (x_vels[i], y_vels[i], yaw_vels[i]).
x_vels =   [0.7, 0.0, 0.5]
y_vels =   [0.0, 0.7, 0.5]
yaw_vels = [0.0, 0.0, 0.1]
# zip() would silently drop trailing commands if these lists had different lengths.
if not (len(x_vels) == len(y_vels) == len(yaw_vels)):
    raise ValueError(
        f"x_vels, y_vels and yaw_vels must have equal lengths, got "
        f"{len(x_vels)}, {len(y_vels)}, {len(yaw_vels)}"
    )

frames_all = []
# Reproducible base key; it is split once per run, and teacher/student reset from the same sub-key.
base_key = jax.random.PRNGKey(EVAL_SEED)

# Helper: turn a list of env states into rendered RGB frames.
def render_rollout(env, traj, mod_fns, width=VIDEO_WIDTH, height=VIDEO_HEIGHT):
    """Render `traj` with the "track" camera and return the frames as uint8 numpy arrays.

    `mod_fns` holds one scene-modifier per state and is forwarded to env.render.
    """
    opts = mujoco.MjvOption()
    # Show geom group 2, hide geom group 3.
    opts.geomgroup[2], opts.geomgroup[3] = True, False
    visual_flags = {
        mujoco.mjtVisFlag.mjVIS_CONTACTPOINT: True,
        mujoco.mjtVisFlag.mjVIS_TRANSPARENT: False,
        mujoco.mjtVisFlag.mjVIS_PERTFORCE: True,
    }
    for flag, enabled in visual_flags.items():
        opts.flags[flag] = enabled

    rendered = env.render(
        traj,
        camera="track",
        scene_option=opts,
        width=width,
        height=height,
        modify_scene_fns=mod_fns,
    )
    return [np.asarray(image, dtype=np.uint8) for image in rendered]

# Helper: draw the command (top, centered) and panel labels (bottom) on a combined frame (numpy uint8 HxWx3).
@functools.lru_cache(maxsize=None)
def _overlay_fonts(command_size=18, label_size=20):
    """Load (command_font, label_font) once instead of on every frame.

    Falls back to PIL's default font when DejaVuSans-Bold cannot be loaded.
    """
    try:
        return (
            ImageFont.truetype("DejaVuSans-Bold.ttf", command_size),
            ImageFont.truetype("DejaVuSans-Bold.ttf", label_size),
        )
    except Exception:
        # Best-effort: rendering labels must never abort GIF generation.
        fallback = ImageFont.load_default()
        return fallback, fallback


def draw_overlays(frame_np, command_text, left_label="TEACHER", right_label="STUDENT"):
    """Return a copy of `frame_np` with `command_text` on top and panel labels at the bottom.

    Args:
        frame_np: uint8 image array (H x W x 3) made of two equal-width panels side by side.
        command_text: text centered across the top of the whole frame.
        left_label: text centered under the left panel.
        right_label: text centered under the right panel.

    Returns:
        A new numpy array of the same size; text is white with a black outline.
    """
    margin = 6   # pixels between text ink (including outline) and the frame edge
    stroke = 2   # outline width; included in measurements so it is not clipped
    font, label_font = _overlay_fonts()

    img = Image.fromarray(frame_np)
    draw = ImageDraw.Draw(img)
    W, H = img.size
    half_W = W // 2
    text_style = dict(fill=(255, 255, 255), stroke_width=stroke, stroke_fill=(0, 0, 0))

    def _centered_x(text, fnt, center_x):
        # textbbox is measured from the draw origin, so bbox[0] may be non-zero; center the ink itself.
        x0, _, x1, _ = draw.textbbox((0, 0), text, font=fnt, stroke_width=stroke)
        return center_x - (x0 + x1) // 2

    # Top command, centered over the full width.
    draw.text((_centered_x(command_text, font, W // 2), margin), command_text, font=font, **text_style)

    # Bottom labels share one baseline. Text is drawn from its ascender origin, so the lowest ink pixel
    # is at y + bbox[3]; choose y so that it sits `margin` px above the bottom edge.
    ink_bottom = max(
        draw.textbbox((0, 0), label, font=label_font, stroke_width=stroke)[3]
        for label in (left_label, right_label)
    )
    bottom_y = H - margin - ink_bottom
    draw.text((_centered_x(left_label, label_font, half_W // 2), bottom_y),
              left_label, font=label_font, **text_style)
    draw.text((_centered_x(right_label, label_font, half_W + half_W // 2), bottom_y),
              right_label, font=label_font, **text_style)

    return np.array(img)

def _command_overlay(st):
    """Return a scene modifier that draws st's command arrow 0.2 above the torso, rotated by its yaw."""
    xyz = np.array(st.data.xpos[env._torso_body_id])
    xyz += np.array([0, 0, 0.2])
    x_axis = st.data.xmat[env._torso_body_id, 0]
    yaw = -np.arctan2(x_axis[1], x_axis[0])
    return functools.partial(
        draw_joystick_command,
        cmd=st.info["command"],
        xyz=xyz,
        theta=yaw,
        scl=abs(st.info["command"][0]) / env_cfg.command_config.a[0],
    )


def _rollout_frames(command, key_env, action_seed):
    """Roll out teacher and student from the same reset under a fixed `command` and render both.

    Returns (teacher_frames, student_frames), rendered every RENDER_EVERY steps. The per-run state
    lists are local, so they become garbage-collectable as soon as this function returns.
    """
    # Same reset key for both, so each run starts both policies from the same initial state.
    state_teacher = jit_reset(key_env)
    state_student = jit_reset(key_env)
    state_teacher.info["command"] = command
    state_student.info["command"] = command

    rollout_teacher, rollout_student = [], []
    mods_teacher, mods_student = [], []

    # Per-run RNGs for action sampling
    key_t = jax.random.PRNGKey(action_seed)
    key_s = jax.random.PRNGKey(action_seed)

    for _step in range(EVAL_STEPS):
        # Teacher step (full observation tree)
        key_t, act_rng_t = jax.random.split(key_t)
        ctrl_teacher, _ = teacher_policy_fn(state_teacher.obs, act_rng_t)
        state_teacher = jit_step(state_teacher, ctrl_teacher)
        state_teacher.info["command"] = command  # keep the command fixed for the whole run
        rollout_teacher.append(state_teacher)

        # Student step: tanh of the first action_size outputs
        key_s, act_rng_s = jax.random.split(key_s)
        student_obs = state_student.obs[STUDENT_OBSERVATION_KEY].reshape(1, -1)
        student_logits = student_policy_fn(student_obs, act_rng_s)
        ctrl_student = jnp.tanh(student_logits[0, :action_size])
        state_student = jit_step(state_student, ctrl_student)
        state_student.info["command"] = command
        rollout_student.append(state_student)

        mods_teacher.append(_command_overlay(state_teacher))
        mods_student.append(_command_overlay(state_student))

    frames_t = render_rollout(env, rollout_teacher[::RENDER_EVERY], mods_teacher[::RENDER_EVERY],
                              width=VIDEO_WIDTH, height=VIDEO_HEIGHT)
    frames_s = render_rollout(env, rollout_student[::RENDER_EVERY], mods_student[::RENDER_EVERY],
                              width=VIDEO_WIDTH, height=VIDEO_HEIGHT)
    return frames_t, frames_s


for run_id, (xv, yv, yawv) in enumerate(zip(x_vels, y_vels, yaw_vels)):
    print(f"Run {run_id+1}/{len(x_vels)} — command: x={xv}, y={yv}, yaw={yawv}")

    # A fresh reset sub-key per run. Each run calls jit_reset itself, so no extra reset is needed between runs.
    base_key, key_env = jax.random.split(base_key)
    frames_t, frames_s = _rollout_frames(
        jnp.array([xv, yv, yawv]), key_env, action_seed=EVAL_SEED + run_id + 1
    )

    if not frames_t or not frames_s:
        print(f"Warning: no frames produced for run {run_id + 1}. Skipping.")
        continue

    command_text = f"Command: x={xv:.2f}, y={yv:.2f}, yaw={yawv:.2f}"
    # Teacher on the left, student on the right; zip stops at the shorter frame list.
    for frame_t, frame_s in zip(frames_t, frames_s):
        combined = np.concatenate([frame_t, frame_s], axis=1)
        frames_all.append(
            draw_overlays(combined, command_text, left_label="TEACHER", right_label="STUDENT")
        )

# Save the concatenated GIF for all runs
if len(frames_all) == 0:
    raise RuntimeError("No frames were collected from any runs. Check rendering / env configuration.")

fps = int((1.0 / env.dt) / RENDER_EVERY)
imageio.mimsave(output_gif_path, frames_all, fps=fps)
print(f"✓ Saved concatenated comparison GIF to: {output_gif_path}")

# Display inline (base64) so notebook shows the full GIF
# with open(output_gif_path, "rb") as f:
#     encoded = base64.b64encode(f.read()).decode("ascii")
# html = f'<div style="text-align:center"><h3>Teacher (left) vs Student (right) — concatenated runs</h3><img src="data:image/gif;base64,{encoded}" width="{VIDEO_WIDTH*2}" height="{VIDEO_HEIGHT}"/></div>'

# display(HTML(html))

Run 1/3 — command: x=0.7, y=0.0, yaw=0.0


100%|██████████| 113/113 [00:01<00:00, 106.44it/s]
100%|██████████| 113/113 [00:00<00:00, 116.04it/s]


Run 2/3 — command: x=0.0, y=0.7, yaw=0.0


100%|██████████| 113/113 [00:01<00:00, 112.57it/s]
100%|██████████| 113/113 [00:00<00:00, 114.18it/s]


Run 3/3 — command: x=0.5, y=0.5, yaw=0.1


100%|██████████| 113/113 [00:01<00:00, 103.93it/s]
100%|██████████| 113/113 [00:01<00:00, 104.21it/s]


✓ Saved concatenated comparison GIF to: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/teacher_vs_student_comparison.gif
