In [None]:
import os
import torch
import h5py
os.environ["MUJOCO_GL"] = "egl"
import mujoco
import numpy as np
import mediapy as media
from robosuite.utils.transform_utils import convert_quat, quat2axisangle, axisangle2quat
from plato_copilot.kinodynamics.dynamics_model import DynamicsDeltaOutputModel, DynamicsPositionInputDeltaOutputModel, safe_device
import matplotlib.pyplot as plt
from collections import defaultdict

def replay_state(mj_model, mj_data, qpos, qvel, mocap_pos, mocap_quat):
    """Write a recorded state into ``mj_data`` and recompute derived quantities.

    Each state component is written with ``mujoco.mj_setState`` under its
    matching ``mjtState`` flag. The mocap arrays are squeezed first, which drops
    any singleton axes. Finally ``mujoco.mj_step1`` is run on the model/data
    pair; per the MuJoCo docs, that is the first (position/velocity-dependent)
    half of ``mj_step``.
    """
    components = (
        (qpos, mujoco.mjtState.mjSTATE_QPOS),
        (qvel, mujoco.mjtState.mjSTATE_QVEL),
        (np.squeeze(mocap_pos), mujoco.mjtState.mjSTATE_MOCAP_POS),
        (np.squeeze(mocap_quat), mujoco.mjtState.mjSTATE_MOCAP_QUAT),
    )
    for values, state_flag in components:
        mujoco.mj_setState(mj_model, mj_data, values, state_flag)
    mujoco.mj_step1(mj_model, mj_data)

def quaternion_geodesic_error(q1: np.ndarray, q2: np.ndarray) -> float:
    """Angle (radians) of the relative rotation between two quaternions.

    Both inputs are renormalised before comparison. Because q and -q encode the
    same rotation, the absolute value of their inner product is used, so the
    result lies in [0, pi].
    """
    # Renormalise so accumulated numerical drift cannot push |<q1, q2>| past 1.
    u1, u2 = (q / np.linalg.norm(q) for q in (q1, q2))
    half_angle_cos = np.clip(np.abs(np.dot(u1, u2)), 0.0, 1.0)
    return 2.0 * np.arccos(half_angle_cos)  # radians

def pose_axisangle_to_quat(pose_6d):
    """Split 2-D pose rows into positions and wxyz quaternions.

    Columns ``[:3]`` are returned unchanged as the position. Columns ``[3:]`` of
    each row are converted with robosuite's ``axisangle2quat`` and then
    reordered with ``convert_quat(..., to="wxyz")``.

    Returns:
        (positions, quaternions) where quaternions is a numpy array with one
        wxyz quaternion per input row.
    """
    positions = pose_6d[:, :3]
    quats_wxyz = [
        convert_quat(axisangle2quat(rotvec), to="wxyz")
        for rotvec in pose_6d[:, 3:]
    ]
    return positions, np.array(quats_wxyz)

In [None]:
dataset_path = "../datasets/force_vis"
experiment_path = "../experiments/force_train/run_010"

# Evaluation settings.
NUM_EVAL_ROLLOUTS = 50  # evaluate the first N trajectories of training_data.hdf5
STEPWISE_START = 60     # window start whose per-step errors are kept for the step-wise plot

# --- Force-feedback dynamics model ---
dynamics_model = safe_device(DynamicsDeltaOutputModel())
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the weights onto the model's own device.
dynamics_model.load_state_dict(
    torch.load(os.path.join(experiment_path, "dynamics_model_delta.pth"), map_location="cpu")
)
dynamics_model.eval()
device = dynamics_model.device

# Per-window summary errors keyed by window index (window i starts at
# seq_len + i * interval); each list gets one value per evaluated trajectory.
interval_errors_pos = defaultdict(list)
interval_errors_rot = defaultdict(list)

# Per-step error curves (one array per trajectory) for the window at STEPWISE_START.
stepwise_pos_errors = []
stepwise_rot_errors = []

# Rollout settings (loop-invariant, so set once).
seq_len = 10       # number of past steps fed to the model
interval = 50      # spacing between consecutive rollout start indices
rollout_len = 100  # number of autoregressive prediction steps per window

with h5py.File(os.path.join(dataset_path, "training_data.hdf5"), "r") as f_train, \
     h5py.File(os.path.join(dataset_path, "data.hdf5"), "r") as f_data:

    for key in list(f_train["data"].keys())[:NUM_EVAL_ROLLOUTS]:
        print(f"\n=== Rollout {key} ===")

        # Load data
        demo = f_train["data"][key]
        obj_pose = demo['object_pose'][()]
        tool_pose = demo['mocap_pose'][()]
        force_fb  = demo['force_feedback'][()]
        actions   = demo['actions'][()]
        traj_idx  = demo.attrs["traj_idx"]
        xml       = f_data['data'][traj_idx].attrs['xml']
        block_name = f_data['data'][traj_idx].attrs['block_name']

        # NOTE(review): nothing in this cell reads the MuJoCo objects or raw
        # trajectory handles set up below; presumably they are kept for
        # visualisation with replay_state. TODO confirm and drop if unused,
        # since the model is rebuilt from XML for every trajectory.
        mj_model = mujoco.MjModel.from_xml_string(xml)
        mj_data = mujoco.MjData(mj_model)
        mujoco.mj_resetData(mj_model, mj_data)
        mujoco.mj_step(mj_model, mj_data)

        joint_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, f"{block_name}_freejoint")
        tgt_joint_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, "vis_block_freejoint")
        qpos_addr_tuple = (mj_model.jnt_qposadr[joint_id], mj_model.jnt_qposadr[joint_id] + 7)
        tgt_qpos_addr_tuple = (mj_model.jnt_qposadr[tgt_joint_id], mj_model.jnt_qposadr[tgt_joint_id] + 7)

        qpos_list = f_data['data'][traj_idx]['qpos']
        qvel_list = f_data['data'][traj_idx]['qvel']
        mocap_pos_list = f_data['data'][traj_idx]['mocap_pos']
        mocap_quat_list = f_data['data'][traj_idx]['mocap_quat']

        # === Model rollout and MSE ===
        rollout_mse_list = []

        for start in range(seq_len, obj_pose.shape[0] - rollout_len, interval):
            # Seed the history with ground truth from the seq_len steps before `start`.
            block_hist = [obj_pose[i] for i in range(start - seq_len, start)]
            tool_hist = [tool_pose[i] for i in range(start - seq_len, start)]
            force_hist = [force_fb[i] for i in range(start - seq_len, start)]
            action_hist = [actions[i] for i in range(start - seq_len, start)]

            pred_seq = []

            for step in range(rollout_len):
                t = start + step
                x_state = np.concatenate([
                    np.stack(block_hist),
                    np.stack(tool_hist),
                    np.stack(force_hist),
                ], axis=1)
                a_seq = np.stack(action_hist)

                data = {
                    "x_seq": torch.from_numpy(x_state[None]).float().to(device),
                    "a_seq": torch.from_numpy(a_seq[None]).float().to(device),
                }

                with torch.no_grad():
                    pred = dynamics_model.predict(data).cpu().numpy()[0]
                pred_seq.append(pred)

                # Slide the window. The object pose is the model's own prediction
                # (autoregressive); tool pose, force feedback and action at step t
                # come from the recorded trajectory.
                block_hist.pop(0)
                block_hist.append(pred)
                tool_hist.pop(0)
                tool_hist.append(tool_pose[t])
                force_hist.pop(0)
                force_hist.append(force_fb[t])
                action_hist.pop(0)
                action_hist.append(actions[t])

            pred_seq = np.stack(pred_seq)  # (rollout_len, 6)
            gt_seq = obj_pose[start:start + rollout_len]  # (rollout_len, 6)

            pos_pred, quat_pred = pose_axisangle_to_quat(pred_seq)
            pos_gt, quat_gt = pose_axisangle_to_quat(gt_seq)

            rot_errors_rad = np.array([
                quaternion_geodesic_error(qp, qg) for qp, qg in zip(quat_pred, quat_gt)
            ])
            rot_errors_deg = np.degrees(rot_errors_rad)

            # Window summary: position MSE over all steps and xyz, mean rotation error.
            pos_mse = np.mean((pos_pred - pos_gt) ** 2)
            rot_deg_mean = np.mean(rot_errors_deg)

            print(f"  t={start:4d} → t={start + rollout_len:4d} | pos MSE = {pos_mse:.6f}, rot err = {rot_deg_mean:.2f}°")
            rollout_mse_list.append((start, pos_mse, rot_deg_mean))

            if start == STEPWISE_START:
                # Per-step curves: position MSE over xyz at each step, and the
                # per-step rotation error already computed above.
                pos_errors = np.mean((pos_pred - pos_gt) ** 2, axis=1)
                rot_errors = rot_errors_deg
                stepwise_pos_errors.append(pos_errors)
                stepwise_rot_errors.append(rot_errors)

        for idx, (start, pos_mse, rot_deg_mean) in enumerate(rollout_mse_list):
            interval_errors_pos[idx].append(pos_mse)
            interval_errors_rot[idx].append(rot_deg_mean)

In [None]:
dataset_path = "../datasets/force_vis"
experiment_path = "../experiments/curved_train/run_001"

# Evaluation settings.
NUM_EVAL_ROLLOUTS = 50  # evaluate the first N trajectories of training_data.hdf5
STEPWISE_START = 60     # window start whose per-step errors are kept for the step-wise plot

# --- Position-input baseline dynamics model ---
pos_dynamics_model = safe_device(DynamicsPositionInputDeltaOutputModel())
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict then copies the weights onto the model's own device.
pos_dynamics_model.load_state_dict(
    torch.load(os.path.join(experiment_path, "dynamics_model_delta.pth"), map_location="cpu")
)
pos_dynamics_model.eval()
device = pos_dynamics_model.device

# Per-window summary errors keyed by window index (window i starts at
# seq_len + i * interval); each list gets one value per evaluated trajectory.
pos_interval_errors_pos = defaultdict(list)
pos_interval_errors_rot = defaultdict(list)

# Per-step error curves (one array per trajectory) for the window at STEPWISE_START.
pos_stepwise_pos_errors = []
pos_stepwise_rot_errors = []

# Rollout settings (loop-invariant, so set once).
seq_len = 10       # number of past steps fed to the model
interval = 50      # spacing between consecutive rollout start indices
rollout_len = 100  # number of autoregressive prediction steps per window

with h5py.File(os.path.join(dataset_path, "training_data.hdf5"), "r") as f_train, \
     h5py.File(os.path.join(dataset_path, "data.hdf5"), "r") as f_data:

    for key in list(f_train["data"].keys())[:NUM_EVAL_ROLLOUTS]:
        print(f"\n=== Rollout {key} ===")

        # Load data (this model uses neither tool pose nor force feedback)
        demo = f_train["data"][key]
        obj_pose = demo['object_pose'][()]
        actions   = demo['actions'][()]
        traj_idx  = demo.attrs["traj_idx"]
        xml       = f_data['data'][traj_idx].attrs['xml']
        block_name = f_data['data'][traj_idx].attrs['block_name']

        # NOTE(review): nothing in this cell reads the MuJoCo objects or raw
        # trajectory handles set up below; presumably they are kept for
        # visualisation with replay_state. TODO confirm and drop if unused,
        # since the model is rebuilt from XML for every trajectory.
        mj_model = mujoco.MjModel.from_xml_string(xml)
        mj_data = mujoco.MjData(mj_model)
        mujoco.mj_resetData(mj_model, mj_data)
        mujoco.mj_step(mj_model, mj_data)

        joint_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, f"{block_name}_freejoint")
        tgt_joint_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, "vis_block_freejoint")
        qpos_addr_tuple = (mj_model.jnt_qposadr[joint_id], mj_model.jnt_qposadr[joint_id] + 7)
        tgt_qpos_addr_tuple = (mj_model.jnt_qposadr[tgt_joint_id], mj_model.jnt_qposadr[tgt_joint_id] + 7)

        qpos_list = f_data['data'][traj_idx]['qpos']
        qvel_list = f_data['data'][traj_idx]['qvel']
        mocap_pos_list = f_data['data'][traj_idx]['mocap_pos']
        mocap_quat_list = f_data['data'][traj_idx]['mocap_quat']

        # === Model rollout and MSE ===
        rollout_mse_list = []

        for start in range(seq_len, obj_pose.shape[0] - rollout_len, interval):
            # Seed the history with ground truth from the seq_len steps before `start`.
            block_hist = [obj_pose[i] for i in range(start - seq_len, start)]
            # Only the first three action components are fed to this model.
            action_hist = [actions[i][:3] for i in range(start - seq_len, start)]

            pred_seq = []

            for step in range(rollout_len):
                t = start + step
                # Add a leading batch axis to both inputs. The action array also
                # gets a singleton axis before its feature axis; presumably the
                # layout this model's predict() expects (TODO confirm against the class).
                x_state = np.stack(block_hist)[None, :, :]
                a_seq = np.stack(action_hist)[None, :, None, :]

                data = {
                    "x_seq": torch.from_numpy(x_state).float().to(device),
                    "a_seq": torch.from_numpy(a_seq).float().to(device),
                }

                with torch.no_grad():
                    pred = pos_dynamics_model.predict(data).cpu().numpy()[0][0]
                pred_seq.append(pred)

                # Slide the window. The object pose is the model's own prediction
                # (autoregressive); the action at step t comes from the recording.
                block_hist.pop(0)
                block_hist.append(pred)
                action_hist.pop(0)
                action_hist.append(actions[t][:3])

            pred_seq = np.stack(pred_seq)  # (rollout_len, 6)
            gt_seq = obj_pose[start:start + rollout_len]  # (rollout_len, 6)

            pos_pred, quat_pred = pose_axisangle_to_quat(pred_seq)
            pos_gt, quat_gt = pose_axisangle_to_quat(gt_seq)

            rot_errors_rad = np.array([
                quaternion_geodesic_error(qp, qg) for qp, qg in zip(quat_pred, quat_gt)
            ])
            rot_errors_deg = np.degrees(rot_errors_rad)

            # Window summary: position MSE over all steps and xyz, mean rotation error.
            pos_mse = np.mean((pos_pred - pos_gt) ** 2)
            rot_deg_mean = np.mean(rot_errors_deg)

            print(f"  t={start:4d} → t={start + rollout_len:4d} | pos MSE = {pos_mse:.6f}, rot err = {rot_deg_mean:.2f}°")
            rollout_mse_list.append((start, pos_mse, rot_deg_mean))

            if start == STEPWISE_START:
                # Per-step curves: position MSE over xyz at each step, and the
                # per-step rotation error already computed above.
                pos_errors = np.mean((pos_pred - pos_gt) ** 2, axis=1)
                rot_errors = rot_errors_deg
                pos_stepwise_pos_errors.append(pos_errors)
                pos_stepwise_rot_errors.append(rot_errors)

        for idx, (start, pos_mse, rot_deg_mean) in enumerate(rollout_mse_list):
            pos_interval_errors_pos[idx].append(pos_mse)
            pos_interval_errors_rot[idx].append(rot_deg_mean)

In [None]:
# Prepare data.
# Window index i covers the rollout starting at t = seq_len + i * interval
# (seq_len / interval are set by the evaluation cells). Tick positions are
# derived from the indices, not from `rollout_mse_list`: that list only holds
# the windows of the *last* trajectory evaluated. It can be shorter than the
# per-window statistics, which would make errorbar fail on mismatched x/y sizes.
interval_indices = sorted(interval_errors_pos.keys())
xticks = [seq_len + idx * interval for idx in interval_indices]

# Force-feedback model: mean/std across trajectories for each window.
mean_pos_errors = [np.mean(interval_errors_pos[i]) for i in interval_indices]
std_pos_errors  = [np.std(interval_errors_pos[i])  for i in interval_indices]
mean_rot_errors = [np.mean(interval_errors_rot[i]) for i in interval_indices]
std_rot_errors  = [np.std(interval_errors_rot[i])  for i in interval_indices]

# Position-input baseline, aligned to the same window indices. A window index
# missing from the baseline dicts yields an empty list here, so mean/std are nan.
mean_pos_baseline = [np.mean(pos_interval_errors_pos[i]) for i in interval_indices]
std_pos_baseline  = [np.std(pos_interval_errors_pos[i])  for i in interval_indices]
mean_rot_baseline = [np.mean(pos_interval_errors_rot[i]) for i in interval_indices]
std_rot_baseline  = [np.std(pos_interval_errors_rot[i])  for i in interval_indices]

# Plot
plt.figure(figsize=(10, 5))

# Position error
plt.subplot(1, 2, 1)
plt.ylim(0, 0.1)
plt.errorbar(xticks, mean_pos_baseline, yerr=std_pos_baseline,
             fmt='o', capsize=4, markersize=5, label="Position Input Model")
plt.errorbar(xticks, mean_pos_errors, yerr=std_pos_errors,
             fmt='*', color='red', capsize=4, markersize=10, label="Force Feedback Model")
plt.title("Average Position Error in 100-Step Predictions")
plt.xlabel("Autoregressive Rollout Start Step")
plt.ylabel("Mean Squared Error")
plt.xticks(xticks)
plt.legend()

# Rotation error
plt.subplot(1, 2, 2)
plt.errorbar(xticks, mean_rot_baseline, yerr=std_rot_baseline,
             fmt='o', capsize=4, markersize=5, label="Position Input Model")
plt.errorbar(xticks, mean_rot_errors, yerr=std_rot_errors,
             fmt='*', color='red', capsize=4, markersize=10, label="Force Feedback Model")
plt.title("Average Rotation Error in 100-Step Predictions")
plt.xlabel("Autoregressive Rollout Start Step")
plt.ylabel("Angular Error (°)")
plt.xticks(xticks)
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# === Per-step error curves for the rollout window starting at t=60 ===
# Each list entry is one trajectory's per-step error curve for that window;
# stacking gives arrays with one row per trajectory.
model_pos_arr, model_rot_arr = np.stack(stepwise_pos_errors), np.stack(stepwise_rot_errors)
baseline_pos_arr, baseline_rot_arr = np.stack(pos_stepwise_pos_errors), np.stack(pos_stepwise_rot_errors)

steps = np.arange(model_pos_arr.shape[1])

# Mean and (population) std across trajectories at every prediction step.
mean_pos_model, std_pos_model = model_pos_arr.mean(axis=0), model_pos_arr.std(axis=0)
mean_rot_model, std_rot_model = model_rot_arr.mean(axis=0), model_rot_arr.std(axis=0)
mean_pos_baseline, std_pos_baseline = baseline_pos_arr.mean(axis=0), baseline_pos_arr.std(axis=0)
mean_rot_baseline, std_rot_baseline = baseline_rot_arr.mean(axis=0), baseline_rot_arr.std(axis=0)

# One entry per subplot: y-limits, baseline mean/std, model mean/std, title, y-label.
panels = [
    ((0, 0.1), mean_pos_baseline, std_pos_baseline, mean_pos_model, std_pos_model,
     "Position Error (Start @ t=60)", "Mean Squared Error"),
    ((0, 140), mean_rot_baseline, std_rot_baseline, mean_rot_model, std_rot_model,
     "Rotation Error (Degrees, Start @ t=60)", "Angular Error (°)"),
]

plt.figure(figsize=(10, 5))
for panel_no, (y_limits, base_mean, base_std, model_mean, model_std, panel_title, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.ylim(*y_limits)
    # Baseline: dashed blue line with a +/- 1 std band.
    plt.plot(steps, base_mean, label="Position Input Model", color="blue", linestyle="--")
    plt.fill_between(steps, base_mean - base_std, base_mean + base_std,
                     color="blue", alpha=0.3)
    # Force-feedback model: solid red line with a +/- 1 std band.
    plt.plot(steps, model_mean, label="Force Feedback Model", color="red")
    plt.fill_between(steps, model_mean - model_std, model_mean + model_std,
                     color="red", alpha=0.3)
    plt.title(panel_title)
    plt.xlabel("Prediction Step")
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()