# 02 - XML 載入測試

測試 MJX 能否載入 `booster_lower_t1.xml` 並進行 GPU 模擬。

**前置條件**：已完成 `01_environment_validation.ipynb`

In [None]:
# Cell 1: Environment variables
# NOTE: keep this assignment above `import mujoco` — the rendering backend is
# presumably chosen when mujoco is imported, and this notebook runs headless.
import os

os.environ.update({"MUJOCO_GL": "disabled"})

import jax
import jax.numpy as jnp
import mujoco
import numpy as np
from mujoco import mjx

# Show which accelerators JAX can see (MJX is meant to run on GPU here).
print(f"JAX devices: {jax.devices()}")

In [None]:
# Cell 2: Load the official XML
# NOTE: this relative path may need adjusting to the Databricks Workspace layout.
xml_path = "../booster_soccer_showdown/mimic/assets/booster_t1/booster_lower_t1.xml"

# Inside Databricks, something like this may be required instead:
# xml_path = "/Workspace/Users/<email>/booster-soccer/booster_soccer_showdown/mimic/assets/booster_t1/booster_lower_t1.xml"

try:
    mj_model = mujoco.MjModel.from_xml_path(xml_path)
    mj_data = mujoco.MjData(mj_model)
except Exception as load_err:
    # Print troubleshooting hints, then re-raise so the notebook stops here.
    print(f"❌ XML 載入失敗: {load_err}")
    for hint in ("\n請檢查：", "1. xml_path 是否正確", "2. <include> 引用的文件是否存在"):
        print(hint)
    raise
else:
    print("✅ XML 載入成功")

In [None]:
# Cell 3: Model information
SEPARATOR = "=" * 50

print(SEPARATOR)
print("模型基本資訊")
print(SEPARATOR)

model_stats = (
    ("Bodies", mj_model.nbody),
    ("Joints", mj_model.njnt),
    ("DOF (nv)", mj_model.nv),
    ("qpos dim (nq)", mj_model.nq),
    ("Actuators", mj_model.nu),
    ("Sensors", mj_model.nsensor),
)
for label, value in model_stats:
    print(f"{label}: {value}")

print("\n" + SEPARATOR)
print("預期值檢查")
print(SEPARATOR)

# Expected sizes for the lower-body T1 model.
expected_nu = 12  # 12 lower-limb actuators
expected_nq = 19  # 7 (freejoint) + 12 (actuated joints)

# A mismatch is only reported as a warning; the notebook keeps going.
for label, actual, expected in (
    ("Actuators", mj_model.nu, expected_nu),
    ("qpos dim", mj_model.nq, expected_nq),
):
    status = "✅" if actual == expected else "⚠️"
    print(f"{status} {label}: {actual} (預期: {expected})")

In [None]:
# Cell 4: List all joint names
print("=" * 50)
print("Joint 列表")
print("=" * 50)

# Display names indexed by mjtJoint value:
# mjJNT_FREE=0, mjJNT_BALL=1, mjJNT_SLIDE=2, mjJNT_HINGE=3.
JOINT_TYPE_NAMES = ("free", "ball", "slide", "hinge")

for i in range(mj_model.njnt):
    name = mj_model.joint(i).name
    # Named access (`mj_model.joint(i).type`) yields a shape-(1,) array, which
    # cannot be used as a list index; read the scalar from the flat array.
    jnt_type = int(mj_model.jnt_type[i])
    type_name = JOINT_TYPE_NAMES[jnt_type]
    print(f"  {i}: {name} ({type_name})")

In [None]:
# Cell 5: List all actuator names with their control ranges
print("=" * 50)
print("Actuator 列表")
print("=" * 50)

for act_idx in range(mj_model.nu):
    actuator = mj_model.actuator(act_idx)
    lo, hi = actuator.ctrlrange  # [min, max] control value
    print(f"  {act_idx}: {actuator.name} (range: [{lo:.1f}, {hi:.1f}])")

In [None]:
# Cell 6: Body ID lookup demo (use mj_name2id — never hard-code IDs!)
print("=" * 50)
print("Body ID 示範（使用 mj_name2id）")
print("=" * 50)

# Correct approach: resolve IDs by name so they stay valid if the XML changes.
trunk_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "Trunk")
left_foot_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "left_foot_link")
right_foot_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "right_foot_link")

print(f"Trunk body ID: {trunk_id}")
print(f"Left foot body ID: {left_foot_id}")
print(f"Right foot body ID: {right_foot_id}")

# Site ID (used by sensors)
imu_site_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_SITE, "imu")
print(f"IMU site ID: {imu_site_id}")

# mj_name2id returns -1 for an unknown name instead of raising. A -1 used as an
# array index would silently read the LAST body/site, so flag misses explicitly.
missing_names = [
    label
    for label, obj_id in (
        ("body 'Trunk'", trunk_id),
        ("body 'left_foot_link'", left_foot_id),
        ("body 'right_foot_link'", right_foot_id),
        ("site 'imu'", imu_site_id),
    )
    if obj_id == -1
]
if missing_names:
    print(f"⚠️ 找不到以下名稱（mj_name2id 回傳 -1）: {', '.join(missing_names)}")

In [None]:
# Cell 7: MJX compilation
print("=" * 50)
print("MJX 編譯測試")
print("=" * 50)

try:
    # Convert the host-side MjModel / MjData into MJX structures.
    mjx_model = mjx.put_model(mj_model)
    mjx_data = mjx.put_data(mj_model, mj_data)
    print("✅ MJX 編譯成功")
    print(f"\nmjx_data.qpos shape: {mjx_data.qpos.shape}")
    for field in ("qvel", "ctrl"):
        print(f"mjx_data.{field} shape: {getattr(mjx_data, field).shape}")
except Exception as compile_err:
    # Print likely causes, then re-raise so the notebook stops here.
    print(f"❌ MJX 編譯失敗: {compile_err}")
    for reason in (
        "\n可能原因：",
        "- XML 包含 MJX 不支援的功能（<equality>, <tendon> 等）",
        "- 需要簡化 XML",
    ):
        print(reason)
    raise

In [None]:
# Cell 8: MJX step test
import time

print("=" * 50)
print("MJX Step 測試")
print("=" * 50)

N_STEPS = 1000  # number of timed steps after the JIT warm-up step


@jax.jit
def mjx_step(mjx_model, mjx_data):
    """Advance the MJX simulation by one timestep (JIT-compiled)."""
    return mjx.step(mjx_model, mjx_data)


# First step includes JIT tracing/compilation, so it is timed separately.
# JAX dispatch is asynchronous: block_until_ready() makes the timer wait for
# the device result rather than only the dispatch. perf_counter() is monotonic
# and high-resolution, unlike time.time().
t0 = time.perf_counter()
mjx_data_new = mjx_step(mjx_model, mjx_data)
mjx_data_new.qpos.block_until_ready()
t1 = time.perf_counter()
print(f"First step (with JIT): {(t1-t0)*1000:.2f} ms")

# N_STEPS consecutive steps using the already-compiled function
t0 = time.perf_counter()
for _ in range(N_STEPS):
    mjx_data_new = mjx_step(mjx_model, mjx_data_new)
mjx_data_new.qpos.block_until_ready()
t1 = time.perf_counter()
total_ms = (t1 - t0) * 1000
print(f"{N_STEPS} steps: {total_ms:.2f} ms ({total_ms / N_STEPS:.4f} ms/step)")

print(f"\n最終 qpos[:7] (freejoint): {mjx_data_new.qpos[:7]}")
print("✅ MJX step 成功")

In [None]:
# Cell 9: Read sensor data
print("=" * 50)
print("Sensor 數據")
print("=" * 50)

# Run one forward pass first so sensordata is populated for the current state.
mujoco.mj_forward(mj_model, mj_data)

print(f"sensordata shape: {mj_data.sensordata.shape}")
print("\nSensor 列表：")

for i in range(mj_model.nsensor):
    name = mj_model.sensor(i).name
    # Named access (`mj_model.sensor(i).dim`) yields a shape-(1,) array, which
    # is not a valid slice bound. Use the model's flat per-sensor start address
    # and size instead of accumulating a running offset.
    adr = int(mj_model.sensor_adr[i])
    dim = int(mj_model.sensor_dim[i])
    data = mj_data.sensordata[adr:adr + dim]
    print(f"  {name}: {data} (dim={dim})")

## 下一步

如果所有測試通過：
1. ✅ XML 可以被 MJX 載入
2. ✅ GPU step 正常運作
3. → 可以開始實現 `soccer_env.xml`（加入 ball、goal）

### MJX 限制提醒

如果編譯失敗，可能需要移除或簡化 MJX 尚未支援的功能（支援範圍會隨 MuJoCo 版本變動，請以下方文檔的 feature parity 表為準）：
- 部分 `<equality>` 約束類型
- 部分 `<tendon>` 類型
- 某些不支援的 contact（碰撞幾何組合）類型

參考：[MJX 文檔](https://mujoco.readthedocs.io/en/latest/mjx.html#feature-parity)