In [4]:
# ============================= Cell 1: Imports =============================
import sys
import os

# Project root (EAI2025_RL_FINAL): this notebook is expected to live one directory below it.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Make project packages (e.g. `environments`) importable. Guarded so that
# re-running this cell does not append duplicate entries to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Set environment variables BEFORE importing mujoco/jax below:
# MUJOCO_GL=egl selects headless offscreen rendering, and XLA picks up
# XLA_FLAGS when its backend is initialised. The flag is only added if it is
# not already present, so re-running the cell keeps XLA_FLAGS stable.
_TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'
xla_flags = os.environ.get('XLA_FLAGS', '')
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f'{xla_flags} {_TRITON_GEMM_FLAG}'.strip()
os.environ['XLA_FLAGS'] = xla_flags
os.environ['MUJOCO_GL'] = 'egl'

import jax
import jax.numpy as jnp
import numpy as np
import functools
import flax.linen as nn
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.acme import running_statistics
import mujoco
from mujoco_playground.config import locomotion_params
from environments.custom_env import Joystick, default_config
from mujoco_playground._src.gait import draw_joystick_command
from IPython.display import HTML, display
import imageio

print("✓ All imports successful")

✓ All imports successful


In [5]:
# ============================= Cell 2: Variables + Environment Setup =============================
# Policy selection
STUDENT_DATA_COLLECTION = "dagger"  # "bc" or "dagger"
STUDENT_LOSS_FUNCTION = "mse"      # "mse" or "kl"
STUDENT_OBSERVATION_KEY = "student_state"  # "student_state" or "state" or "privileged_state"

# Evaluation settings
EVAL_SEED = 42
EVAL_STEPS = 400
# NOTE: only printed below; the rollout cell drives the robot with its own
# x_vels / y_vels / yaw_vels lists.
EVAL_COMMAND = jnp.array([1.0, 0.5, 0.5])  # [x_vel, y_vel, yaw_vel]

# Rendering settings
RENDER_EVERY = 3    # keep every N-th simulation step in the video
VIDEO_WIDTH = 430   # 640
VIDEO_HEIGHT = 320  # 480
# FPS is derived from env.dt further down, once the environment exists.

# Output settings
experiment_name = f"student_{STUDENT_DATA_COLLECTION}_{STUDENT_LOSS_FUNCTION}_{STUDENT_OBSERVATION_KEY}"
results_path = os.path.join(parent_dir, 'results', experiment_name)
output_gif_path = os.path.join(results_path, 'stairs_env.gif')
os.makedirs(results_path, exist_ok=True)

print(f"Experiment: {experiment_name}")
print(f"Output will be saved to: {output_gif_path}")
print(f"Evaluation command: x={EVAL_COMMAND[0]}, y={EVAL_COMMAND[1]}, yaw={EVAL_COMMAND[2]}")

# ============================= ENVIRONMENT SETUP =============================
# Resolve the scene from the project root (like results_path) instead of the
# current working directory, so the path stays valid if the cwd changes.
xml_path = os.path.join(parent_dir, 'environments', 'stairs_env.xml')
env_cfg = default_config()
# Perturbations stay ENABLED, but the velocity-kick range is zeroed, so kicks
# should carry no velocity (TODO confirm in environments.custom_env).
# Set `enable = False` to skip the perturbation code path entirely.
env_cfg.pert_config.enable = True
env_cfg.pert_config.velocity_kick = [0.0, 0.0]
env_cfg.pert_config.kick_wait_times = [5.0, 15.0]
env_cfg.command_config.a = [1.5, 0.8, 2 * jnp.pi]

env = Joystick(xml_path=xml_path, config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Discover observation/action shapes from a throwaway reset.
_dummy_key = jax.random.PRNGKey(0)
dummy_state = jit_reset(_dummy_key)
obs_shape = jax.tree_util.tree_map(lambda x: x.shape, dummy_state.obs)
action_size = env.action_size
if STUDENT_OBSERVATION_KEY not in dummy_state.obs:
    raise KeyError(
        f"Observation key {STUDENT_OBSERVATION_KEY!r} not produced by the env; "
        f"available keys: {list(dummy_state.obs.keys())}"
    )
student_obs_dim = int(dummy_state.obs[STUDENT_OBSERVATION_KEY].shape[0])

# Video frame rate: one rendered frame per RENDER_EVERY control steps of length env.dt.
FPS = int((1.0 / env.dt) / RENDER_EVERY)

print(f"\n✓ Environment initialized")
print(f"  Action size: {action_size}")
print(f"  Student observation dim: {student_obs_dim}")
print(f"  Available observation keys: {list(dummy_state.obs.keys())}")



Experiment: student_dagger_mse_student_state
Output will be saved to: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/stairs_env.gif
Evaluation command: x=1.0, y=0.5, yaw=0.5
Evaluation command: x=1.0, y=0.5, yaw=0.5

✓ Environment initialized
  Action size: 12
  Student observation dim: 48
  Available observation keys: ['privileged_state', 'state', 'student_state']

✓ Environment initialized
  Action size: 12
  Student observation dim: 48
  Available observation keys: ['privileged_state', 'state', 'student_state']


In [6]:
# ============================= LOAD STUDENT POLICY =============================
# Student network definition (same as training)
class StudentPolicy(nn.Module):
    """Two-hidden-layer ReLU MLP producing ``2 * action_size`` outputs.

    Must match the architecture used at training time, because the saved
    parameter tree is keyed by Flax's auto-generated layer names, which are
    assigned in creation order (Dense_0, Dense_1, Dense_2). The evaluation
    loop reads the first ``action_size`` outputs as the action mean.
    """
    action_size: int
    hidden_size: int = 256

    @nn.compact
    def __call__(self, x):
        # Hidden widths: hidden_size, then hidden_size // 2, each followed by ReLU.
        hidden = x
        for width in (self.hidden_size, self.hidden_size // 2):
            hidden = nn.relu(nn.Dense(features=width)(hidden))
        # Output head (created last, so it is still named Dense_2).
        return nn.Dense(features=2 * self.action_size)(hidden)

# Network skeleton only; the trained weights are read from disk below.
student_net = StudentPolicy(action_size=action_size)

# Parameter file written by the training notebook for this data-collection / loss combo.
student_params_filename = "student_params_{}_{}.npy".format(
    STUDENT_DATA_COLLECTION, STUDENT_LOSS_FUNCTION
)
student_params_path = os.path.join(results_path, student_params_filename)

if not os.path.exists(student_params_path):
    missing_params_msg = (
        f"Student parameters not found at: {student_params_path}\n"
        f"Please train the student policy first using training/teacher_student_MLP.ipynb"
    )
    raise FileNotFoundError(missing_params_msg)

# The .npy stores a pickled Python object (presumably the Flax params dict),
# hence allow_pickle=True and .item(). SECURITY: unpickling can run arbitrary
# code, so only load parameter files you produced yourself.
loaded_params = np.load(student_params_path, allow_pickle=True).item()
student_params = jax.tree.map(jnp.array, loaded_params)


def _student_forward(obs, rng):
    """Deterministic forward pass; ``rng`` is accepted for API symmetry but unused."""
    return student_net.apply(student_params, obs)


student_policy_fn = jax.jit(_student_forward)

print(f"✓ Student policy loaded from: {student_params_path}")
print(f"\n✓ Student policy ready for evaluation")

✓ Student policy loaded from: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/student_params_dagger_mse.npy

✓ Student policy ready for evaluation


In [7]:
# ============================= Sequence of commands → single concatenated GIF for Student =============================
# Runs each command from x_vels / y_vels / yaw_vels (currently three) sequentially
# with the student policy, and concatenates all runs into a single GIF saved at `output_gif_path`.
import base64
from PIL import Image, ImageDraw, ImageFont

# Run i commands (x_vels[i], y_vels[i], yaw_vels[i]).
x_vels = [1.0, 0.0, 0.5]
y_vels = [0.0, 1.0, 0.5]
yaw_vels = [0.0, 0.0, 0.1]

# Overlaid frames of every run, appended in run order.
frames_all = []
# Reproducible base key; each run splits its own reset key off this one.
base_key = jax.random.PRNGKey(EVAL_SEED)

# Helper to render a rollout to frames
def render_rollout(env, traj, mod_fns, width=VIDEO_WIDTH, height=VIDEO_HEIGHT):
    """Render a list of env states into uint8 frames using the "track" camera.

    Args:
        env: environment exposing ``render(...)``.
        traj: sequence of states to render.
        mod_fns: per-state scene-modifier callables, passed as ``modify_scene_fns``.
        width, height: output frame size in pixels.

    Returns:
        A list of ``np.uint8`` arrays, one per rendered frame.
    """
    opt = mujoco.MjvOption()
    # Geom group 2 visible, group 3 hidden.
    opt.geomgroup[2] = True
    opt.geomgroup[3] = False
    visual_flags = (
        (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
        (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
        (mujoco.mjtVisFlag.mjVIS_PERTFORCE, True),
    )
    for flag, enabled in visual_flags:
        opt.flags[flag] = enabled

    rendered = env.render(
        traj,
        camera="track",
        scene_option=opt,
        width=width,
        height=height,
        modify_scene_fns=mod_fns,
    )
    as_uint8 = lambda frame: np.asarray(frame, dtype=np.uint8)
    return list(map(as_uint8, rendered))

# Helpers that draw the command text (top) and a label (bottom) on a frame (numpy uint8 HxWx3).
@functools.lru_cache(maxsize=None)
def _load_overlay_fonts():
    """Return ``(command_font, label_font)``, loaded once and then cached.

    Falls back to Pillow's built-in default font when the DejaVu TrueType font
    cannot be loaded (e.g. file missing or no FreeType support).
    """
    try:
        return (
            ImageFont.truetype("DejaVuSans-Bold.ttf", 18),
            ImageFont.truetype("DejaVuSans-Bold.ttf", 20),
        )
    except Exception:
        default_font = ImageFont.load_default()
        return default_font, default_font


def draw_overlays(frame_np, command_text, label):
    """Draw centred white text with a black outline on top of a rendered frame.

    Args:
        frame_np: uint8 image array (HxWx3 as produced by render_rollout).
        command_text: text centred along the top edge.
        label: text centred along the bottom edge.

    Returns:
        A new numpy uint8 array containing the frame with both overlays.
    """
    img = Image.fromarray(frame_np)
    draw = ImageDraw.Draw(img)
    # Fonts are cached: reading the TTF from disk for every frame was wasted work.
    font, label_font = _load_overlay_fonts()

    W, H = img.size
    margin = 6
    stroke = 2
    text_style = dict(fill=(255, 255, 255), stroke_width=stroke, stroke_fill=(0, 0, 0))

    # Top command (centered). Measure with the same stroke so the outline is
    # part of the extent; subtracting left+right centres the visible glyphs
    # even when the bbox does not start at x=0.
    left, _, right, _ = draw.textbbox((0, 0), command_text, font=font, stroke_width=stroke)
    top_x = (W - left - right) // 2
    draw.text((top_x, margin), command_text, font=font, **text_style)

    # Bottom label (centered). Place it by the bbox *bottom* edge rather than
    # its height: the glyph offset below the draw origin plus the stroke
    # would otherwise push the text past the bottom of the frame.
    left, _, right, bottom = draw.textbbox((0, 0), label, font=label_font, stroke_width=stroke)
    bottom_x = (W - left - right) // 2
    bottom_y = H - margin - bottom
    draw.text((bottom_x, bottom_y), label, font=label_font, **text_style)

    return np.array(img)

# zip() would silently drop trailing entries of mismatched command lists, so fail loudly.
if not (len(x_vels) == len(y_vels) == len(yaw_vels)):
    raise ValueError(
        "x_vels, y_vels and yaw_vels must have the same length, got "
        f"{len(x_vels)}, {len(y_vels)}, {len(yaw_vels)}"
    )

for run_id, (xv, yv, yawv) in enumerate(zip(x_vels, y_vels, yaw_vels)):
    print(f"Run {run_id+1}/{len(x_vels)} — command: x={xv}, y={yv}, yaw={yawv}")
    command = jnp.array([xv, yv, yawv])
    # Top overlay text; constant for the whole run.
    command_text = f"Command: x={xv:.2f}, y={yv:.2f}, yaw={yawv:.2f}"

    # Reset env with a fresh key split off the reproducible base key.
    base_key, key_env = jax.random.split(base_key)
    state_student = jit_reset(key_env)
    # Force the evaluation command. NOTE(review): the observation returned by
    # reset was computed before this override, so the first student input may
    # still encode the reset-sampled command — confirm in environments.custom_env.
    state_student.info["command"] = command

    rollout_student = []
    modify_scene_fns_student = []

    # Per-run key for action sampling. student_policy_fn ignores its rng
    # argument and the action below is tanh(mean), so the rollout is
    # deterministic; the key only keeps the call signature stable.
    key_s = jax.random.PRNGKey(EVAL_SEED + run_id + 1)

    for _step in range(EVAL_STEPS):
        # Student step: the first action_size network outputs are the action mean.
        key_s, act_rng_s = jax.random.split(key_s)
        student_obs = state_student.obs[STUDENT_OBSERVATION_KEY].reshape(1, -1)
        student_logits = student_policy_fn(student_obs, act_rng_s)
        mu = student_logits[0, :action_size]
        ctrl_student = jnp.tanh(mu)
        state_student = jit_step(state_student, ctrl_student)
        # Re-apply the command after every step in case step() changed it.
        state_student.info["command"] = command
        rollout_student.append(state_student)

        # Visual overlays: command marker 0.2 above the torso, oriented by
        # the (negated) heading of the torso's x-axis.
        xyz = np.array(state_student.data.xpos[env._torso_body_id])
        xyz += np.array([0, 0, 0.2])
        x_axis = state_student.data.xmat[env._torso_body_id, 0]
        yaw = -np.arctan2(x_axis[1], x_axis[0])
        modify_scene_fns_student.append(
            functools.partial(
                draw_joystick_command,
                cmd=state_student.info["command"],
                xyz=xyz,
                theta=yaw,
                scl=abs(state_student.info["command"][0]) / env_cfg.command_config.a[0],
            )
        )

    # Subsample and render this run
    traj_s = rollout_student[::RENDER_EVERY]
    mod_s = modify_scene_fns_student[::RENDER_EVERY]
    frames_s = render_rollout(env, traj_s, mod_s, width=VIDEO_WIDTH, height=VIDEO_HEIGHT)

    # Release the full per-step lists before the skip check so the `continue`
    # path frees them too (previously they leaked until the next run).
    del rollout_student, modify_scene_fns_student

    if len(frames_s) == 0:
        print(f"Warning: no frames produced for run {run_id + 1}. Skipping.")
        continue

    frames_all.extend(
        draw_overlays(frame, command_text, label=experiment_name) for frame in frames_s
    )
    # No extra reset is needed here: the next iteration calls jit_reset itself.

# Save the concatenated GIF for all runs
if len(frames_all) == 0:
    raise RuntimeError("No frames were collected from any runs. Check rendering / env configuration.")

fps = int((1.0 / env.dt) / RENDER_EVERY)
imageio.mimsave(output_gif_path, frames_all, fps=fps)
print(f"✓ Saved concatenated student evaluation GIF to: {output_gif_path}")

Run 1/3 — command: x=1.0, y=0.0, yaw=0.0


100%|██████████| 134/134 [00:01<00:00, 82.71it/s] 



Run 2/3 — command: x=0.0, y=1.0, yaw=0.0


100%|██████████| 134/134 [00:01<00:00, 83.05it/s]



Run 3/3 — command: x=0.5, y=0.5, yaw=0.1


100%|██████████| 134/134 [00:01<00:00, 73.75it/s]



✓ Saved concatenated student evaluation GIF to: /home/arito/robot-learning/EAI2025_RL_FINAL/results/student_dagger_mse_student_state/stairs_env.gif
