# Alpamayo VLA — Comprehensive Visualization Dashboard

Visualizes all outputs from Alpamayo-R1-10B inference:
1. **BEV trajectory plot** (predicted vs ground truth)
2. **Trajectory overlay on camera images**
3. **Multi-sample uncertainty fan**
4. **Animated trajectory video**
5. **Heading arrows on BEV**
6. **Reasoning trace display**
7. **Multi-camera grid with trajectory**
8. **Per-video dashboard**
9. **Aggregate metrics charts**

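**Usage:** Set `RESULTS_DIR` (in the configuration cell below) to the inference output directory. Results are discovered in its run and per-video subdirectories, each of which contains `*_inference.json` and `*_vis_data.npz` files.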

## 0 — Setup & Data Loading

In [None]:
import json
import glob
import os
from pathlib import Path
from typing import Optional

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib import cm
from matplotlib.gridspec import GridSpec
import matplotlib.animation as animation
from IPython.display import HTML, display, Markdown
import textwrap

%matplotlib inline
# Notebook-wide matplotlib defaults: sharper inline figures, compact fonts.
plt.rcParams["figure.dpi"] = 120
plt.rcParams["font.size"] = 10
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["axes.labelsize"] = 10

In [None]:
# ──────────────────────────────────────────────────────────────────────
# CONFIGURE THIS: path to the inference output directory
# ──────────────────────────────────────────────────────────────────────
RESULTS_DIR = "/tmp/alpamayo_output"  # <-- change to your output path

# Auto-discover all run directories. Only sub-directories count as runs:
# plain files sitting in RESULTS_DIR (e.g. result JSON/NPZ files) must not be
# reported as runs, otherwise the fallback below could never trigger.
run_dirs = sorted(
    entry
    for entry in glob.glob(os.path.join(RESULTS_DIR, "*"))
    if os.path.isdir(entry)
)
if not run_dirs:
    # No sub-directories: maybe RESULTS_DIR itself is the run directory
    run_dirs = [RESULTS_DIR]
print(f"Found {len(run_dirs)} run(s): {[os.path.basename(d) for d in run_dirs]}")

In [None]:
# Camera name lookup (index -> name). A camera's position in the tuple below
# is its index, so the order of the entries is significant.
INDEX_TO_CAMERA = dict(enumerate((
    "cross_left_120fov",
    "front_wide_120fov",
    "cross_right_120fov",
    "rear_left_70fov",
    "rear_tele_30fov",
    "rear_right_70fov",
    "front_tele_30fov",
)))


def load_video_result(video_dir: str) -> dict:
    """Load JSON metadata + npz tensor data for one video.

    Looks for ``*_inference.json`` / ``*_vis_data.npz`` directly inside
    ``video_dir``; if no JSON is found there, both patterns are searched
    recursively below ``video_dir`` instead.

    When several files match, the alphabetically first JSON is used (glob
    order is otherwise arbitrary), and the NPZ sharing its filename prefix is
    preferred so metadata and tensors belong to the same video.

    Args:
        video_dir: Directory to search.

    Returns:
        A dict with key ``"meta"`` (parsed JSON) if a JSON file was found and
        key ``"vis"`` (dict of array name -> array) if an NPZ file was found.
        Empty dict when neither exists.
    """
    json_suffix = "_inference.json"
    npz_suffix = "_vis_data.npz"

    json_files = sorted(glob.glob(os.path.join(video_dir, "*" + json_suffix)))
    npz_files = sorted(glob.glob(os.path.join(video_dir, "*" + npz_suffix)))

    if not json_files:
        # Try one level deeper (and below)
        json_files = sorted(
            glob.glob(os.path.join(video_dir, "**", "*" + json_suffix), recursive=True)
        )
        npz_files = sorted(
            glob.glob(os.path.join(video_dir, "**", "*" + npz_suffix), recursive=True)
        )

    result = {}
    npz_path = npz_files[0] if npz_files else None
    if json_files:
        json_path = json_files[0]
        with open(json_path, encoding="utf-8") as f:
            result["meta"] = json.load(f)
        # Prefer the tensor file that sits next to the chosen JSON.
        paired = json_path[: -len(json_suffix)] + npz_suffix
        if paired in npz_files:
            npz_path = paired
    if npz_path is not None:
        # NOTE(security): allow_pickle=True can execute arbitrary code when
        # loading a malicious file — only point this at trusted output.
        # The context manager closes the underlying file handle; arrays are
        # fully materialised into the dict before it is closed.
        with np.load(npz_path, allow_pickle=True) as archive:
            result["vis"] = {key: archive[key] for key in archive.files}
    return result


def discover_all_videos(results_dir: str) -> list[dict]:
    """Discover all video results across all runs.

    Supported layouts (checked in this order):
      * ``results_dir/*_inference.json``               (results_dir is a video dir)
      * ``results_dir/<run>/*_inference.json``         (run dir holds one result)
      * ``results_dir/<run>/<video_id>/...``           (one sub-directory per video)

    A run directory is only treated as a single result when it *directly*
    contains an inference JSON. ``load_video_result`` searches recursively, so
    calling it on a run directory first would return just one of its videos
    and hide all the others.

    Args:
        results_dir: Root of the inference output.

    Returns:
        One ``load_video_result`` dict per discovered video that has
        non-empty metadata, in sorted directory order.
    """
    def _has_direct_result(directory: str) -> bool:
        """True if `directory` itself (not a sub-directory) holds a result JSON."""
        return bool(glob.glob(os.path.join(directory, "*_inference.json")))

    def _sorted_subdirs(directory: str) -> list[str]:
        """Sub-directories of `directory`, sorted by path."""
        return sorted(
            entry
            for entry in glob.glob(os.path.join(directory, "*"))
            if os.path.isdir(entry)
        )

    all_videos = []

    # results_dir itself may be a single run/video directory.
    if _has_direct_result(results_dir):
        result = load_video_result(results_dir)
        if result.get("meta"):
            all_videos.append(result)

    for run_dir in _sorted_subdirs(results_dir):
        if _has_direct_result(run_dir):
            candidates = [run_dir]
        else:
            # Otherwise look one level deeper (run_dir/video_id/...)
            candidates = _sorted_subdirs(run_dir)
        for video_dir in candidates:
            result = load_video_result(video_dir)
            if result.get("meta"):
                all_videos.append(result)
    return all_videos


all_videos = discover_all_videos(RESULTS_DIR)
print(f"Loaded {len(all_videos)} video result(s)")
for v in all_videos:
    m = v["meta"]
    has_vis = "vis" in v
    min_ade = m.get("min_ade_meters")
    # A float format spec on a missing (or null) value raises, so only apply
    # it to real numbers and fall back to a plain "N/A" otherwise.
    if isinstance(min_ade, (int, float)) and not isinstance(min_ade, bool):
        min_ade_text = f"{min_ade:.4f}m"
    else:
        min_ade_text = "N/A"
    print(f"  • {m.get('video_id', 'unknown')}  minADE={min_ade_text}  "
          f"vis_data={'✓' if has_vis else '✗'}  success={m.get('success')}")

---
## 1 — BEV Trajectory Plot (Predicted vs Ground Truth)

In [None]:
def plot_bev_trajectory(
    pred_xyz: np.ndarray,
    gt_xyz: np.ndarray,
    history_xyz: Optional[np.ndarray] = None,
    title: str = "BEV Trajectory",
    ax: Optional[plt.Axes] = None,
    show_time_coloring: bool = True,
) -> plt.Axes:
    """Bird's-eye-view trajectory plot.

    Coordinate convention for display: rotate 90° CCW so forward (x) points up.
      display_x = -y_ego,  display_y = x_ego

    Args:
        pred_xyz: (S, T, 3) predicted trajectories (T = 64 in this notebook).
        gt_xyz:   (T, 3) ground truth.
        history_xyz: (H, 3) optional ego history (H = 16 in this notebook).
        title: Axes title.
        ax: Axes to draw on; a new 7x7 figure is created when None.
        show_time_coloring: If True, colour the ground truth by elapsed time
            ("Reds" colormap); otherwise draw it as a plain red line.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 7))

    num_samples = pred_xyz.shape[0]
    # Waypoints are spaced 0.1 s apart (10 Hz), so 64 waypoints span 6.4 s.
    # Deriving the horizon from the data keeps the colour array the same
    # length as the segment list for any number of waypoints.
    num_gt = gt_xyz.shape[0]
    horizon_s = 0.1 * num_gt
    t_future = np.linspace(0, horizon_s, num_gt)  # seconds
    norm = Normalize(vmin=0, vmax=horizon_s)

    # History
    if history_xyz is not None:
        hx, hy = -history_xyz[:, 1], history_xyz[:, 0]
        ax.plot(hx, hy, "s-", color="gray", markersize=3, linewidth=1.5,
                label="Ego History (1.6s)", alpha=0.7, zorder=2)

    # Ground truth
    gt_dx, gt_dy = -gt_xyz[:, 1], gt_xyz[:, 0]
    if show_time_coloring:
        # One short segment per consecutive waypoint pair, coloured by the
        # time of the segment's first waypoint.
        points = np.column_stack([gt_dx, gt_dy]).reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments, cmap="Reds", norm=norm, linewidth=3, zorder=3)
        lc.set_array(t_future[:-1])
        ax.add_collection(lc)
        # Empty proxy line so the collection gets a legend entry.
        ax.plot([], [], "r-", linewidth=3, label=f"Ground Truth ({horizon_s:.1f}s)")
    else:
        ax.plot(gt_dx, gt_dy, "r-", linewidth=3, label="Ground Truth", zorder=3)

    # Predicted trajectories
    colors = cm.tab10(np.linspace(0, 1, max(num_samples, 1)))
    for s in range(num_samples):
        px, py = -pred_xyz[s, :, 1], pred_xyz[s, :, 0]
        ax.plot(px, py, "o-", color=colors[s % len(colors)], markersize=2,
                linewidth=1.5, alpha=0.8, label=f"Predicted #{s+1}", zorder=4)

    # Origin marker (ego at t0)
    ax.plot(0, 0, "*", color="black", markersize=14, zorder=5, label="Ego at t₀")

    ax.set_xlabel("Lateral (m)")
    ax.set_ylabel("Forward (m)")
    ax.set_title(title)
    ax.set_aspect("equal")
    ax.legend(fontsize=8, loc="best")
    ax.grid(True, alpha=0.3)
    return ax


# One BEV figure per video; videos without visualization tensors are skipped.
for v in all_videos:
    if "vis" in v:
        vis = v["vis"]
        fig, ax = plt.subplots(figsize=(7, 7))
        plot_bev_trajectory(
            vis["pred_xyz"],
            vis["gt_future_xyz"],
            vis["ego_history_xyz"],
            f"BEV Trajectory — {v['meta']['video_id']}",
            ax,
        )
        plt.tight_layout()
        plt.show()

---
## 2 — Trajectory Overlay on Front Camera Image

In [None]:
def project_trajectory_to_image(
    traj_xyz: np.ndarray,
    img_h: int,
    img_w: int,
    fov_deg: float = 120.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Project 3D ego-frame trajectory onto a front-facing camera image (pinhole approx).

    The ego frame has x=forward, y=left, z=up.
    The camera frame has z=forward, x=right, y=down.
    The camera is modelled at the ego origin with no lens distortion, so the
    result is only an approximation.

    Args:
        traj_xyz: (N, 3) ego-frame points. Integer arrays are accepted and
            promoted to float so pixel coordinates are not truncated.
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        fov_deg: Horizontal field of view used to approximate the focal length.

    Returns:
        (u, v, mask) — pixel coordinates and a boolean mask for points
        that are in front of the camera and within the image bounds.
        Points not in front of the camera get u = v = -1.
    """
    traj_xyz = np.asarray(traj_xyz)
    if not np.issubdtype(traj_xyz.dtype, np.floating):
        # full_like below inherits the dtype; with ints the projected pixel
        # coordinates would be silently rounded toward zero.
        traj_xyz = traj_xyz.astype(float)

    # Approximate intrinsics from horizontal FOV
    fov_rad = np.deg2rad(fov_deg)
    fx = (img_w / 2.0) / np.tan(fov_rad / 2.0)
    fy = fx  # square pixels
    cx, cy = img_w / 2.0, img_h / 2.0

    # Ego -> camera: x_cam = -y_ego, y_cam = -z_ego, z_cam = x_ego
    x_cam = -traj_xyz[:, 1]
    y_cam = -traj_xyz[:, 2]
    z_cam = traj_xyz[:, 0]   # forward

    # Only keep points in front of camera (0.5 m margin avoids dividing by
    # near-zero depth)
    in_front = z_cam > 0.5

    u = np.full_like(z_cam, -1.0)
    v = np.full_like(z_cam, -1.0)
    u[in_front] = fx * x_cam[in_front] / z_cam[in_front] + cx
    v[in_front] = fy * y_cam[in_front] / z_cam[in_front] + cy

    in_bounds = in_front & (u >= 0) & (u < img_w) & (v >= 0) & (v < img_h)
    return u, v, in_bounds


def overlay_trajectory_on_image(
    img: np.ndarray,
    pred_xyz: np.ndarray,
    gt_xyz: np.ndarray,
    fov_deg: float = 120.0,
    title: str = "",
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
    """Overlay predicted + GT trajectories on a camera image.

    Args:
        img: Image array; its first two dimensions are used as (height, width).
        pred_xyz: (S, T, 3) predicted ego-frame trajectories.
        gt_xyz: (T, 3) ground-truth ego-frame trajectory.
        fov_deg: Horizontal FOV passed to ``project_trajectory_to_image``.
        title: Axes title; "Trajectory Overlay" is used when empty.
        ax: Axes to draw on; a new 12x7 figure is created when None.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 7))
    h, w = img.shape[:2]
    ax.imshow(img)

    def _plot_visible(traj_xyz: np.ndarray, **style) -> None:
        """Project one trajectory and draw only its in-image portion."""
        u, v, visible = project_trajectory_to_image(traj_xyz, h, w, fov_deg)
        if not visible.any():
            return
        # Replace hidden points with NaN instead of dropping them: matplotlib
        # breaks the line at NaN, so points separated by an off-screen stretch
        # are not joined by a misleading straight segment.
        ax.plot(np.where(visible, u, np.nan), np.where(visible, v, np.nan), **style)

    # Ground truth
    _plot_visible(gt_xyz, marker="o", linestyle="-", color="red", markersize=4,
                  linewidth=2, label="Ground Truth", zorder=3)

    # Predictions
    num_samples = pred_xyz.shape[0]
    colors = cm.tab10(np.linspace(0, 1, max(num_samples, 1)))
    for s in range(num_samples):
        _plot_visible(pred_xyz[s], marker="o", linestyle="-",
                      color=colors[s % len(colors)], markersize=3, linewidth=1.5,
                      alpha=0.9, label=f"Predicted #{s+1}", zorder=4)

    ax.set_title(title or "Trajectory Overlay")
    # legend() warns when there is nothing to show (no trajectory visible).
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=8, loc="lower right")
    ax.axis("off")
    return ax


# Plot for each video
for v in all_videos:
    if "vis" not in v:
        continue
    vis = v["vis"]
    cam_indices = vis["camera_indices"]  # (N_cam,)
    num_cams = len(cam_indices)
    # Guard the integer division and the last-frame lookup below.
    num_frames = vis["image_frames"].shape[0] // num_cams if num_cams else 0
    if num_frames == 0:
        print(f"No camera frames for {v['meta'].get('video_id', 'unknown')}; skipping overlay")
        continue

    # Find front_wide camera (index 1), use last frame
    front_pos = np.where(cam_indices == 1)[0]
    cam_pos = int(front_pos[0]) if len(front_pos) else 0  # fallback to first camera
    frame_idx = cam_pos * num_frames + (num_frames - 1)  # last frame
    front_img = vis["image_frames"][frame_idx]

    # Determine FOV from the camera that is actually displayed (its name ends
    # in "<degrees>fov"), rather than always assuming the 120° front-wide lens.
    cam_name = INDEX_TO_CAMERA.get(
        int(cam_indices[cam_pos]),
        v["meta"].get("camera_name") or "camera_front_wide_120fov",
    )
    try:
        fov = float(cam_name.rsplit("_", 1)[-1].removesuffix("fov"))
    except ValueError:
        fov = 120.0  # unknown naming scheme: keep the front-wide default

    fig, ax = plt.subplots(figsize=(12, 7))
    overlay_trajectory_on_image(
        img=front_img,
        pred_xyz=vis["pred_xyz"],
        gt_xyz=vis["gt_future_xyz"],
        fov_deg=fov,
        title=f"Trajectory Overlay — {v['meta']['video_id']}",
        ax=ax,
    )
    plt.tight_layout()
    plt.show()

---
## 3 — Multi-Sample Uncertainty Fan

In [None]:
def plot_uncertainty_fan(
    pred_xyz: np.ndarray,
    gt_xyz: np.ndarray,
    history_xyz: Optional[np.ndarray] = None,
    title: str = "Trajectory Uncertainty Fan",
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
    """Plot all trajectory samples with a filled uncertainty region.

    The shaded bands show the per-timestep *lateral* spread of the samples
    (±1σ and ±2σ around the mean), drawn along the forward axis. The bands
    and the mean line are only drawn when there is more than one sample.

    Args:
        pred_xyz: (S, T, 3) — multiple trajectory samples.
        gt_xyz:   (T, 3) ground truth.
        history_xyz: (H, 3) optional ego history.
        title: Axes title prefix; the sample count is appended.
        ax: Axes to draw on; a new 7x7 figure is created when None.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 7))

    S = pred_xyz.shape[0]
    # Rotate to display frame: dx = -y, dy = x
    pred_dx = -pred_xyz[:, :, 1]  # (S, T)
    pred_dy = pred_xyz[:, :, 0]   # (S, T)

    gt_dx = -gt_xyz[:, 1]
    gt_dy = gt_xyz[:, 0]

    if S > 1:
        # Compute per-timestep mean and lateral spread
        mean_dx = pred_dx.mean(axis=0)
        mean_dy = pred_dy.mean(axis=0)
        std_dx = pred_dx.std(axis=0)

        # The trajectory runs along the display y (forward) axis, so the band
        # must be filled *horizontally* around the mean: fill_betweenx over
        # forward distance. fill_between over lateral x would degenerate for
        # a vehicle driving roughly straight (x barely changes).
        ax.fill_betweenx(
            mean_dy,
            mean_dx - 2 * std_dx,
            mean_dx + 2 * std_dx,
            alpha=0.15, color="blue", label="±2σ spread", zorder=1,
        )
        ax.fill_betweenx(
            mean_dy,
            mean_dx - std_dx,
            mean_dx + std_dx,
            alpha=0.25, color="blue", label="±1σ spread", zorder=1,
        )
        # Mean trajectory
        ax.plot(mean_dx, mean_dy, "-", color="blue", linewidth=2,
                label="Mean Prediction", zorder=4)

    # Individual samples
    colors = cm.cool(np.linspace(0.2, 0.8, S))
    for s in range(S):
        ax.plot(pred_dx[s], pred_dy[s], "-", color=colors[s], alpha=0.5,
                linewidth=0.8, zorder=3)

    # Ground truth
    ax.plot(gt_dx, gt_dy, "r-", linewidth=3, label="Ground Truth", zorder=5)

    # History
    if history_xyz is not None:
        hx, hy = -history_xyz[:, 1], history_xyz[:, 0]
        ax.plot(hx, hy, "s-", color="gray", markersize=3, linewidth=1.5,
                label="History", alpha=0.7, zorder=2)

    # Ego position at t0
    ax.plot(0, 0, "*", color="black", markersize=14, zorder=6)
    ax.set_xlabel("Lateral (m)")
    ax.set_ylabel("Forward (m)")
    ax.set_title(f"{title}  (S={S} samples)")
    ax.set_aspect("equal")
    ax.legend(fontsize=8, loc="best")
    ax.grid(True, alpha=0.3)
    return ax


# One uncertainty-fan figure per video that has visualization tensors.
for v in all_videos:
    if "vis" in v:
        vis = v["vis"]
        fig, ax = plt.subplots(figsize=(7, 7))
        plot_uncertainty_fan(
            vis["pred_xyz"],
            vis["gt_future_xyz"],
            vis["ego_history_xyz"],
            f"Uncertainty — {v['meta']['video_id']}",
            ax,
        )
        plt.tight_layout()
        plt.show()

---
## 4 — Animated Trajectory Video (6.4s rollout)

In [None]:
def create_trajectory_animation(
    pred_xyz: np.ndarray,
    gt_xyz: np.ndarray,
    history_xyz: Optional[np.ndarray] = None,
    title: str = "Trajectory Rollout",
    interval_ms: int = 100,
) -> animation.FuncAnimation:
    """Create frame-by-frame animation of the trajectory unfolding.

    Each animation frame reveals one more waypoint (at 10 Hz → 64 frames for
    the 64-waypoint trajectories used in this notebook).

    Args:
        pred_xyz: (S, T, 3) predicted trajectories; T sets the frame count.
        gt_xyz: (T, 3) ground-truth trajectory.
        history_xyz: (H, 3) optional ego history, drawn in full on every frame.
        title: Title prefix; the elapsed/total time is appended per frame.
        interval_ms: Delay between animation frames in milliseconds.

    Returns:
        The animation. The backing figure is closed so the notebook does not
        also render it as a static image.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    num_steps = pred_xyz.shape[1]  # 64
    S = pred_xyz.shape[0]
    step_s = 0.1  # waypoint spacing in seconds (10 Hz)
    total_s = num_steps * step_s

    # Compute fixed axis limits from all data so the view does not jump
    all_x = np.concatenate([-pred_xyz[:, :, 1].flatten(), -gt_xyz[:, 1]])
    all_y = np.concatenate([pred_xyz[:, :, 0].flatten(), gt_xyz[:, 0]])
    if history_xyz is not None:
        all_x = np.concatenate([all_x, -history_xyz[:, 1]])
        all_y = np.concatenate([all_y, history_xyz[:, 0]])
    margin = 5
    xlim = (all_x.min() - margin, all_x.max() + margin)
    ylim = (all_y.min() - margin, all_y.max() + margin)

    colors = cm.tab10(np.linspace(0, 1, max(S, 1)))

    def init():
        ax.clear()
        return []

    def update(frame_idx):
        # Everything is redrawn from scratch each frame (blit is disabled).
        ax.clear()
        t = frame_idx + 1  # number of waypoints to show (always >= 1)

        # History (always full)
        if history_xyz is not None:
            hx, hy = -history_xyz[:, 1], history_xyz[:, 0]
            ax.plot(hx, hy, "s-", color="gray", markersize=3,
                    linewidth=1.5, alpha=0.7, label="History")

        # Ground truth up to t, with a marker on the newest waypoint
        gt_dx, gt_dy = -gt_xyz[:t, 1], gt_xyz[:t, 0]
        ax.plot(gt_dx, gt_dy, "r-", linewidth=3, label="Ground Truth")
        if len(gt_dx):
            ax.plot(gt_dx[-1], gt_dy[-1], "ro", markersize=8)

        # Predicted up to t, with a marker on the newest waypoint
        for s in range(S):
            color = colors[s % len(colors)]
            px, py = -pred_xyz[s, :t, 1], pred_xyz[s, :t, 0]
            ax.plot(px, py, "o-", color=color, markersize=2,
                    linewidth=1.5, alpha=0.8, label=f"Pred #{s+1}")
            ax.plot(px[-1], py[-1], "o", color=color, markersize=6)

        ax.plot(0, 0, "*", color="black", markersize=14)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_aspect("equal")
        ax.set_xlabel("Lateral (m)")
        ax.set_ylabel("Forward (m)")
        ax.set_title(f"{title}  t={t*step_s:.1f}s / {total_s:.1f}s")
        ax.legend(fontsize=7, loc="upper left")
        ax.grid(True, alpha=0.3)
        return []

    anim = animation.FuncAnimation(
        fig, update, init_func=init,
        frames=num_steps, interval=interval_ms, blit=False,
    )
    plt.close(fig)
    return anim


# Show animation for the first video that has visualization data. Filtering
# before slicing means a leading video without vis data no longer suppresses
# the animation entirely.
for v in [candidate for candidate in all_videos if "vis" in candidate][:1]:
    vis = v["vis"]
    anim = create_trajectory_animation(
        pred_xyz=vis["pred_xyz"],
        gt_xyz=vis["gt_future_xyz"],
        history_xyz=vis["ego_history_xyz"],
        title=v["meta"]["video_id"],
    )
    display(HTML(anim.to_jshtml()))

---
## 5 — Heading Arrows on BEV

In [None]:
def extract_yaw_from_rotation_matrix(rot: np.ndarray) -> np.ndarray:
    """Return the yaw (rotation about z) of one or more rotation matrices.

    Args:
        rot: array of shape (..., 3, 3).

    Returns:
        Yaw in radians with shape (...,), computed as atan2(R[1,0], R[0,0]).
    """
    sin_component = rot[..., 1, 0]
    cos_component = rot[..., 0, 0]
    yaw = np.arctan2(sin_component, cos_component)
    return yaw


def _draw_heading_arrows(ax, xs, ys, rot, every, length, arrowprops, zorder=None):
    """Draw a yaw arrow at every `every`-th point of a display-frame polyline."""
    yaw = np.atleast_1d(extract_yaw_from_rotation_matrix(rot))
    # Display frame is the ego frame rotated 90° CCW, so the ego heading
    # (cos yaw, sin yaw) becomes (-sin yaw, cos yaw) on screen.
    dxs = -np.sin(yaw) * length
    dys = np.cos(yaw) * length
    extra = {} if zorder is None else {"zorder": zorder}
    # Stop at the shorter of positions/rotations so a length mismatch cannot
    # raise an IndexError.
    for i in range(0, min(len(xs), len(yaw)), every):
        ax.annotate("", xy=(xs[i] + dxs[i], ys[i] + dys[i]), xytext=(xs[i], ys[i]),
                    arrowprops=dict(arrowprops), **extra)


def plot_bev_with_heading(
    pred_xyz: np.ndarray,
    pred_rot: np.ndarray,
    gt_xyz: np.ndarray,
    gt_rot: np.ndarray,
    history_xyz: Optional[np.ndarray] = None,
    history_rot: Optional[np.ndarray] = None,
    title: str = "BEV + Heading Arrows",
    arrow_every: int = 8,
    arrow_length: float = 1.5,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
    """BEV plot with yaw heading arrows at regular intervals.

    Uses the same display frame as the other BEV plots
    (display_x = -y_ego, display_y = x_ego).

    Args:
        pred_xyz: (S, T, 3) predicted positions.
        pred_rot: (S, T, 3, 3) predicted rotation matrices.
        gt_xyz: (T, 3) ground-truth positions.
        gt_rot: (T, 3, 3) ground-truth rotation matrices.
        history_xyz: (H, 3) optional ego history positions.
        history_rot: (H, 3, 3) optional ego history rotations; arrows are
            only drawn for the history when both history arrays are given.
        title: Axes title.
        arrow_every: Draw an arrow on every N-th waypoint (values below 1
            are treated as 1).
        arrow_length: Arrow length in metres.
        ax: Axes to draw on; a new 8x8 figure is created when None.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    S = pred_xyz.shape[0]
    step = max(int(arrow_every), 1)  # range() rejects a zero step

    # History
    if history_xyz is not None:
        hx, hy = -history_xyz[:, 1], history_xyz[:, 0]
        ax.plot(hx, hy, "s-", color="gray", markersize=3, linewidth=1.5,
                label="History", alpha=0.7)
        if history_rot is not None:
            _draw_heading_arrows(
                ax, hx, hy, history_rot, step, arrow_length,
                dict(arrowstyle="->", color="gray", lw=1.5),
            )

    # Ground truth with arrows
    gt_dx, gt_dy = -gt_xyz[:, 1], gt_xyz[:, 0]
    ax.plot(gt_dx, gt_dy, "r-", linewidth=3, label="Ground Truth", zorder=3)
    _draw_heading_arrows(
        ax, gt_dx, gt_dy, gt_rot, step, arrow_length,
        dict(arrowstyle="->", color="red", lw=2), zorder=5,
    )

    # Predicted with arrows
    colors = cm.tab10(np.linspace(0, 1, max(S, 1)))
    for s in range(S):
        color = colors[s % len(colors)]
        px, py = -pred_xyz[s, :, 1], pred_xyz[s, :, 0]
        ax.plot(px, py, "o-", color=color, markersize=2,
                linewidth=1.5, alpha=0.8, label=f"Pred #{s+1}", zorder=4)
        _draw_heading_arrows(
            ax, px, py, pred_rot[s], step, arrow_length,
            dict(arrowstyle="->", color=color, lw=1.5, alpha=0.8), zorder=5,
        )

    # Ego position at t0
    ax.plot(0, 0, "*", color="black", markersize=14, zorder=6)
    ax.set_xlabel("Lateral (m)")
    ax.set_ylabel("Forward (m)")
    ax.set_title(title)
    ax.set_aspect("equal")
    ax.legend(fontsize=8, loc="best")
    ax.grid(True, alpha=0.3)
    return ax


# One heading-arrow BEV figure per video that has visualization tensors.
for v in all_videos:
    if "vis" in v:
        vis = v["vis"]
        fig, ax = plt.subplots(figsize=(8, 8))
        plot_bev_with_heading(
            vis["pred_xyz"],
            vis["pred_rot"],
            vis["gt_future_xyz"],
            vis["gt_future_rot"],
            vis["ego_history_xyz"],
            vis["ego_history_rot"],
            title=f"Heading Arrows — {v['meta']['video_id']}",
            ax=ax,
        )
        plt.tight_layout()
        plt.show()

---
## 6 — Reasoning Trace (Chain-of-Causation) Display

In [None]:
def display_reasoning_traces(video_result: dict):
    """Show the reasoning traces of one video, then print them in full.

    With visualization data present the text panel sits to the right of the
    last frame of the front-wide camera (index 1, or the first camera when
    index 1 is absent); otherwise only the text panel is drawn. Traces that
    do not fit into the panel are cut off in the figure but are always part
    of the printed output.
    """
    meta = video_result["meta"]
    traces = meta.get("reasoning_traces", [])

    if "vis" not in video_result:
        fig, ax_txt = plt.subplots(figsize=(10, 5))
    else:
        vis = video_result["vis"]
        cam_indices = vis["camera_indices"]
        frames_per_cam = vis["image_frames"].shape[0] // len(cam_indices)

        # Slot of the front-wide camera, defaulting to the first slot
        matches = np.where(cam_indices == 1)[0]
        cam_slot = 0 if len(matches) == 0 else matches[0]
        last_frame = cam_slot * frames_per_cam + (frames_per_cam - 1)
        front_img = vis["image_frames"][last_frame]

        fig, (ax_img, ax_txt) = plt.subplots(
            1, 2, figsize=(16, 5), gridspec_kw={"width_ratios": [1, 1]}
        )
        ax_img.imshow(front_img)
        ax_img.set_title(f"Front Camera — {meta['video_id']}")
        ax_img.axis("off")

    # Text panel
    ax_txt.axis("off")
    ax_txt.set_title("Chain-of-Causation Reasoning", fontsize=12, fontweight="bold")

    if traces:
        cursor = 0.95  # vertical text position in axes coordinates
        for number, trace in enumerate(traces, start=1):
            body = textwrap.fill(trace.strip(), width=60)
            ax_txt.text(0.02, cursor, f"Trajectory #{number}:",
                        transform=ax_txt.transAxes, fontsize=9,
                        fontweight="bold", va="top", fontfamily="monospace")
            cursor -= 0.06
            ax_txt.text(0.02, cursor, body, transform=ax_txt.transAxes,
                        fontsize=8, va="top", fontfamily="serif",
                        linespacing=1.4)
            # Rough estimate of the height taken by the wrapped paragraph
            line_count = body.count("\n") + 1
            cursor -= 0.05 * line_count + 0.04
            if cursor < 0.05:
                break  # panel is full
    else:
        ax_txt.text(0.05, 0.5, "No reasoning traces available.",
                    transform=ax_txt.transAxes, fontsize=10, va="center")

    plt.tight_layout()
    plt.show()

    # Plain-text copy of every trace, convenient for copy-paste
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Reasoning Traces for: {meta['video_id']}")
    print(rule)
    for number, trace in enumerate(traces, start=1):
        print(f"\n--- Trajectory #{number} ---")
        print(trace.strip())


# Render the reasoning-trace panel (and print the full trace text) for every
# loaded video; the function itself handles videos without visualization data.
for v in all_videos:
    display_reasoning_traces(v)

---
## 7 — Multi-Camera Grid with BEV Trajectory

In [None]:
def plot_multicam_grid_with_bev(video_result: dict):
    """Tile the last frame of every camera above a full-width BEV plot.

    Prints a notice and returns without plotting when the video has no
    visualization data.
    """
    if "vis" not in video_result:
        print(f"No visualization data for {video_result['meta']['video_id']}")
        return

    meta, vis = video_result["meta"], video_result["vis"]
    cam_indices = vis["camera_indices"]
    num_cams = len(cam_indices)
    frames = vis["image_frames"]
    per_cam = frames.shape[0] // num_cams

    # Two rows: one column per camera on top, the BEV plot underneath
    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(2, num_cams, figure=fig, height_ratios=[1, 1.2], hspace=0.25, wspace=0.05)

    # Camera row — frames are indexed camera-major, so the last frame of
    # camera `col` sits at the end of its block of `per_cam` frames
    for col, cam_id in enumerate(cam_indices):
        ax_cam = fig.add_subplot(gs[0, col])
        ax_cam.imshow(frames[col * per_cam + (per_cam - 1)])
        ax_cam.set_title(INDEX_TO_CAMERA.get(int(cam_id), f"cam_{cam_id}"), fontsize=9)
        ax_cam.axis("off")

    # BEV row, spanning every column
    plot_bev_trajectory(
        vis["pred_xyz"],
        vis["gt_future_xyz"],
        vis["ego_history_xyz"],
        title=f"BEV Trajectory — {meta['video_id']}",
        ax=fig.add_subplot(gs[1, :]),
    )

    fig.suptitle(f"Multi-Camera View + BEV — {meta['video_id']}", fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()


# Draw the camera grid + BEV figure for every loaded video; videos without
# visualization data only produce a printed notice.
for v in all_videos:
    plot_multicam_grid_with_bev(v)

---
## 8 — Per-Video Dashboard (cameras + BEV + reasoning + metrics)

In [None]:
def plot_full_dashboard(video_result: dict):
    """Complete per-video dashboard combining all visual outputs.

    Layout (3 rows x 4 columns):
      Row 0: last frame of up to 4 cameras (the first 4 in the data)
      Row 1: [BEV trajectory] [BEV with heading arrows]
      Row 2: [Reasoning text, first 3 traces] [Metrics table]

    Metrics that are missing or not numeric are shown as "N/A" rather than
    as zeros, so an absent value cannot be mistaken for a perfect score.

    Prints a notice and returns without plotting when the video has no
    visualization data.
    """
    meta = video_result["meta"]
    has_vis = "vis" in video_result

    if not has_vis:
        print(f"No visualization data for {meta['video_id']}")
        return

    def _fmt(value, spec: str, unit: str = "") -> str:
        """Format a numeric metric, or return "N/A" for missing/non-numeric."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:{spec}}{unit}"
        return "N/A"

    vis = video_result["vis"]
    cam_indices = vis["camera_indices"]
    num_cams = len(cam_indices)
    num_frames_per_cam = vis["image_frames"].shape[0] // num_cams

    fig = plt.figure(figsize=(20, 18))
    gs = GridSpec(3, 4, figure=fig, height_ratios=[0.8, 1.2, 0.8],
                  hspace=0.3, wspace=0.3)

    # ── Row 0: Camera views (last frame of each, at most 4) ──
    for c in range(min(num_cams, 4)):
        ax = fig.add_subplot(gs[0, c])
        frame_idx = c * num_frames_per_cam + (num_frames_per_cam - 1)
        ax.imshow(vis["image_frames"][frame_idx])
        cam_name = INDEX_TO_CAMERA.get(int(cam_indices[c]), f"cam_{cam_indices[c]}")
        ax.set_title(cam_name, fontsize=9)
        ax.axis("off")

    # ── Row 1 left: BEV trajectory ──
    ax_bev = fig.add_subplot(gs[1, :2])
    plot_bev_trajectory(
        pred_xyz=vis["pred_xyz"],
        gt_xyz=vis["gt_future_xyz"],
        history_xyz=vis["ego_history_xyz"],
        title="BEV Trajectory",
        ax=ax_bev,
    )

    # ── Row 1 right: BEV with heading arrows ──
    ax_heading = fig.add_subplot(gs[1, 2:])
    plot_bev_with_heading(
        pred_xyz=vis["pred_xyz"],
        pred_rot=vis["pred_rot"],
        gt_xyz=vis["gt_future_xyz"],
        gt_rot=vis["gt_future_rot"],
        history_xyz=vis["ego_history_xyz"],
        history_rot=vis["ego_history_rot"],
        title="BEV + Heading",
        ax=ax_heading,
    )

    # ── Row 2 left: Reasoning trace ──
    ax_cot = fig.add_subplot(gs[2, :2])
    ax_cot.axis("off")
    ax_cot.set_title("Chain-of-Causation Reasoning", fontsize=11, fontweight="bold",
                     loc="left")
    traces = meta.get("reasoning_traces", [])
    if traces:
        y_pos = 0.95
        for i, trace in enumerate(traces[:3]):  # max 3 traces
            wrapped = textwrap.fill(trace.strip(), width=80)
            ax_cot.text(0.02, y_pos, f"Traj #{i+1}: {wrapped}",
                        transform=ax_cot.transAxes, fontsize=7.5, va="top",
                        fontfamily="serif", linespacing=1.3)
            # Rough estimate of the vertical space the paragraph used
            num_lines = wrapped.count("\n") + 1
            y_pos -= 0.04 * num_lines + 0.06
    else:
        ax_cot.text(0.02, 0.5, "No reasoning traces.",
                    transform=ax_cot.transAxes, fontsize=10, va="center")

    # ── Row 2 right: Metrics table ──
    ax_metrics = fig.add_subplot(gs[2, 2:])
    ax_metrics.axis("off")
    ax_metrics.set_title("Inference Metrics", fontsize=11, fontweight="bold", loc="left")

    # `or {}` also covers keys that are present but null in the JSON.
    metrics = meta.get("metrics") or {}
    temporal_config = meta.get("temporal_config") or {}
    table_data = [
        ["Video ID", str(meta.get("video_id", "N/A"))],
        ["Clip ID", str(meta.get("clip_id", "N/A"))],
        ["Camera", str(meta.get("camera_name", "N/A"))],
        ["minADE", _fmt(meta.get("min_ade_meters"), ".4f", " m")],
        ["# Trajectories", str(meta.get("num_trajectories", "N/A"))],
        ["Inference Time", _fmt(metrics.get("inference_time_seconds"), ".1f", " s")],
        ["GPU Peak", _fmt(metrics.get("gpu_memory_peak_gb"), ".2f", " GB")],
        ["RAM Peak", _fmt(metrics.get("ram_peak_mb"), ".0f", " MB")],
        ["FPS (video)", f"{temporal_config.get('vp_video_fps', 'N/A')}"],
        ["Status", "SUCCESS" if meta.get("success") else "FAILED"],
    ]
    table = ax_metrics.table(
        cellText=table_data,
        colLabels=["Metric", "Value"],
        loc="center",
        cellLoc="left",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 1.4)
    # Style header (row 0 holds the column labels)
    for j in range(2):
        table[0, j].set_facecolor("#4472C4")
        table[0, j].set_text_props(color="white", fontweight="bold")

    fig.suptitle(
        f"Alpamayo VLA Dashboard — {meta['video_id']}",
        fontsize=16, fontweight="bold", y=1.01,
    )
    plt.tight_layout()
    plt.show()


# Build the combined dashboard for every loaded video; videos without
# visualization data only produce a printed notice.
for v in all_videos:
    plot_full_dashboard(v)

---
## 9 — Aggregate Metrics Charts (across all videos)

In [None]:
def plot_aggregate_metrics(all_videos: list[dict]):
    """Bar/box plots of minADE, inference time, GPU memory across all videos.

    Draws a 2x2 figure (minADE bars coloured by success, inference-time bars,
    GPU-peak bars, box plots of all metrics) and prints a text summary.

    Missing or null metric values are tolerated: minADE becomes NaN and is
    excluded from the statistics, the other metrics default to 0. With an
    empty input a notice is printed and nothing is plotted.

    Args:
        all_videos: Result dicts, each with a ``"meta"`` entry.
    """
    if not all_videos:
        # max()/nanmin() below cannot handle empty sequences.
        print("No video results to aggregate.")
        return

    def _as_float(value, default: float) -> float:
        """Coerce a JSON value to float; None / non-numeric -> default."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    # Gather metrics
    video_ids = []
    min_ades = []
    inf_times = []
    gpu_peaks = []
    ram_peaks = []
    succeeded = []

    for v in all_videos:
        m = v["meta"]
        metrics = m.get("metrics") or {}
        vid = str(m.get("video_id", "unknown"))
        # Truncate long names for display
        short_id = vid[:25] + "…" if len(vid) > 25 else vid

        video_ids.append(short_id)
        min_ades.append(_as_float(m.get("min_ade_meters"), float("nan")))
        inf_times.append(_as_float(metrics.get("inference_time_seconds"), 0.0))
        gpu_peaks.append(_as_float(metrics.get("gpu_memory_peak_gb"), 0.0))
        ram_peaks.append(_as_float(metrics.get("ram_peak_mb"), 0.0))
        succeeded.append(bool(m.get("success")))

    statuses = ["green" if ok else "red" for ok in succeeded]
    valid_ades = [a for a in min_ades if not np.isnan(a)]

    n = len(video_ids)
    x = np.arange(n)

    fig, axes = plt.subplots(2, 2, figsize=(16, 10))

    # 1. minADE bar chart
    ax = axes[0, 0]
    ax.bar(x, min_ades, color=statuses, alpha=0.8, edgecolor="black", linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(video_ids, rotation=45, ha="right", fontsize=7)
    ax.set_ylabel("minADE (meters)")
    ax.set_title("minADE per Video")
    if valid_ades:
        mean_ade = float(np.mean(valid_ades))
        ax.axhline(mean_ade, color="blue", linestyle="--", linewidth=1,
                   label=f"Mean: {mean_ade:.4f}m")
        ax.legend(fontsize=8)
    ax.grid(axis="y", alpha=0.3)

    # 2. Inference time bar chart
    ax = axes[0, 1]
    ax.bar(x, inf_times, color="#4472C4", alpha=0.8, edgecolor="black", linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(video_ids, rotation=45, ha="right", fontsize=7)
    ax.set_ylabel("Time (seconds)")
    ax.set_title("Inference Time per Video")
    ax.axhline(np.mean(inf_times), color="red", linestyle="--", linewidth=1,
               label=f"Mean: {np.mean(inf_times):.1f}s")
    ax.legend(fontsize=8)
    ax.grid(axis="y", alpha=0.3)

    # 3. GPU memory peak bar chart
    ax = axes[1, 0]
    ax.bar(x, gpu_peaks, color="#ED7D31", alpha=0.8, edgecolor="black", linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(video_ids, rotation=45, ha="right", fontsize=7)
    ax.set_ylabel("GPU Memory Peak (GB)")
    ax.set_title("GPU Memory Peak per Video")
    ax.axhline(np.mean(gpu_peaks), color="red", linestyle="--", linewidth=1,
               label=f"Mean: {np.mean(gpu_peaks):.2f} GB")
    ax.legend(fontsize=8)
    ax.grid(axis="y", alpha=0.3)

    # 4. Box plot summary (minADE + inference time + GPU + RAM)
    ax = axes[1, 1]
    box_data = []
    box_labels = []
    if valid_ades:
        box_data.append(valid_ades)
        box_labels.append("minADE\n(m)")
    box_data.append(inf_times)
    box_labels.append("Inf Time\n(s)")
    box_data.append(gpu_peaks)
    box_labels.append("GPU Peak\n(GB)")
    # Convert to GB for better display range
    box_data.append([r / 1024 for r in ram_peaks])
    box_labels.append("RAM Peak\n(GB)")

    bp = ax.boxplot(box_data, patch_artist=True,
                    medianprops={"color": "red", "linewidth": 2})
    # Label the boxes through the axis instead of boxplot's `labels=` keyword,
    # which newer Matplotlib deprecates in favour of `tick_labels=` (a keyword
    # older versions do not know). Boxes sit at positions 1..N.
    ax.set_xticks(range(1, len(box_labels) + 1))
    ax.set_xticklabels(box_labels)
    box_colors = ["#4472C4", "#ED7D31", "#A5A5A5", "#70AD47"]
    for patch, color in zip(bp["boxes"], box_colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.6)
    ax.set_title("Metric Distributions")
    ax.grid(axis="y", alpha=0.3)

    fig.suptitle(
        f"Aggregate Metrics — {n} video(s)",
        fontsize=14, fontweight="bold",
    )
    plt.tight_layout()
    plt.show()

    # Print summary table
    print(f"\n{'='*70}")
    print(f"AGGREGATE SUMMARY ({n} videos)")
    print(f"{'='*70}")
    if valid_ades:
        print(f"  minADE — mean: {np.mean(valid_ades):.4f}m, "
              f"median: {np.median(valid_ades):.4f}m, "
              f"best: {min(valid_ades):.4f}m, "
              f"worst: {max(valid_ades):.4f}m")
    else:
        print("  minADE — N/A (no video reported a value)")
    print(f"  Time  — mean: {np.mean(inf_times):.1f}s, "
          f"total: {sum(inf_times):.1f}s ({sum(inf_times)/60:.1f}min)")
    print(f"  GPU   — mean peak: {np.mean(gpu_peaks):.2f} GB, "
          f"max peak: {max(gpu_peaks):.2f} GB")
    print(f"  RAM   — mean peak: {np.mean(ram_peaks):.0f} MB, "
          f"max peak: {max(ram_peaks):.0f} MB")
    success_count = sum(succeeded)
    print(f"  Status: {success_count}/{n} successful")


# Plot and print the cross-video summary for all loaded video results.
plot_aggregate_metrics(all_videos)

---
## Save Per-Video Figures to Disk (BEV, heading, uncertainty)

In [None]:
# Optional: save the per-video BEV, heading and uncertainty figures as PNG
# for reports
SAVE_DIR = os.path.join(RESULTS_DIR, "visualizations")
os.makedirs(SAVE_DIR, exist_ok=True)

for v in all_videos:
    if "vis" not in v:
        continue
    meta = v["meta"]
    vis = v["vis"]
    vid = meta["video_id"]
    # A video ID containing a path separator would make savefig target a
    # sub-directory of SAVE_DIR that does not exist; flatten it so every file
    # lands directly in SAVE_DIR. Titles keep the original ID.
    safe_vid = str(vid).replace("/", "_").replace(os.sep, "_")

    # BEV
    fig, ax = plt.subplots(figsize=(7, 7))
    plot_bev_trajectory(vis["pred_xyz"], vis["gt_future_xyz"],
                        vis["ego_history_xyz"], title=f"BEV — {vid}", ax=ax)
    fig.savefig(os.path.join(SAVE_DIR, f"{safe_vid}_bev.png"), dpi=150, bbox_inches="tight")
    plt.close(fig)  # close so figures are not also rendered inline / kept in memory

    # BEV + heading
    fig, ax = plt.subplots(figsize=(8, 8))
    plot_bev_with_heading(
        vis["pred_xyz"], vis["pred_rot"],
        vis["gt_future_xyz"], vis["gt_future_rot"],
        vis["ego_history_xyz"], vis["ego_history_rot"],
        title=f"Heading — {vid}", ax=ax,
    )
    fig.savefig(os.path.join(SAVE_DIR, f"{safe_vid}_heading.png"), dpi=150, bbox_inches="tight")
    plt.close(fig)

    # Uncertainty fan
    fig, ax = plt.subplots(figsize=(7, 7))
    plot_uncertainty_fan(vis["pred_xyz"], vis["gt_future_xyz"],
                         vis["ego_history_xyz"], title=f"Uncertainty — {vid}", ax=ax)
    fig.savefig(os.path.join(SAVE_DIR, f"{safe_vid}_uncertainty.png"), dpi=150, bbox_inches="tight")
    plt.close(fig)

print(f"Saved visualizations to: {SAVE_DIR}")
print(f"Files: {sorted(os.listdir(SAVE_DIR))}")