# Flight Trajectory Prediction Results Analysis

This notebook evaluates the performance of the generative flight trajectory prediction model through various metrics, visualizations, and calibration analysis.

## Setup and Configuration

In [None]:
# Core imports
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import pathlib

# External libraries
from traffic.core import Traffic
import pyproj
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Project modules
from utils.utils import cache_paths
from utils.inference_utils import sample_many, denorm_seq_to_global
from utils.metrics import pit_values

# Configuration: input trajectories, dataset cache, checkpoint and output cadence
PARQUET_PATH = "trajs_LSAS_filtered.parquet"
CACHE_DIR = pathlib.Path("./dataset_cache")
CKPT_PATH = "models/model_1min.pt"
# Set to "" / None to fall back to the newest *.key.json inside CACHE_DIR
CACHE_KEY_FILE = "ecec4b007a021fa3.key.json"
# Seconds between consecutive predicted steps
OUTPUT_STRIDE_SECONDS = 5

# Device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using device: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


## Data Loading

In [2]:
# Load cached dataset artifacts
def load_cache_key():
    """Return the path of the dataset-cache key file to use.

    If ``CACHE_KEY_FILE`` is set, it is resolved against ``CACHE_DIR`` (unless it is
    already absolute) and must exist. Otherwise the most recently modified
    ``*.key.json`` file in ``CACHE_DIR`` is returned.

    Raises:
        FileNotFoundError: if the configured key file is missing, or no key file
            exists in ``CACHE_DIR`` when none is configured.
    """
    if not CACHE_KEY_FILE:
        newest = max(
            CACHE_DIR.glob("*.key.json"),
            key=lambda candidate: candidate.stat().st_mtime,
            default=None,
        )
        if newest is None:
            raise FileNotFoundError(
                "No dataset_cache key files found. Run the training notebook first."
            )
        return newest

    configured = Path(CACHE_KEY_FILE)
    resolved = configured if configured.is_absolute() else CACHE_DIR / configured
    if not resolved.exists():
        raise FileNotFoundError(
            f"Specified cache key file not found: {resolved}. "
            f"Make sure it exists under {CACHE_DIR}."
        )
    return resolved

# Resolve the cache key, then derive every artifact path from the keys it records
key_path = load_cache_key()
key_info = json.loads(key_path.read_text())
print(f"[cache] using key: {key_path.name}")

dset_key, stats_key = key_info["dataset_key"], key_info["stats_key"]
paths = cache_paths(dset_key, stats_key)

# Dataset arrays are opened read-only as memory maps (mmap_mode="r")
X_train, Y_train, C_train = (np.load(paths[k], mmap_mode="r") for k in ("x_tr", "y_tr", "c_tr"))
X_val, Y_val, C_val = (np.load(paths[k], mmap_mode="r") for k in ("x_va", "y_va", "c_va"))
X_test, Y_test, C_test = (np.load(paths[k], mmap_mode="r") for k in ("x_te", "y_te", "c_te"))

# Normalization statistics, per-window metadata and build descriptors
norm_stats = json.loads(paths["stats"].read_text())
meta_train, meta_val, meta_test = (
    pd.read_parquet(paths[k]) for k in ("meta_tr", "meta_va", "meta_te")
)
manifest = json.loads(paths["manifest"].read_text())
summary = json.loads(paths["summary"].read_text())

# Normalization parameters used by every denormalization call below
feat_mean, feat_std, ctx_mean, ctx_std = (
    norm_stats[k] for k in ("feat_mean", "feat_std", "ctx_mean", "ctx_std")
)

# Report split sizes, coercing numpy integers to plain ints for a clean printout
dataset_sizes = {
    split: int(n) if isinstance(n, (int, np.integer)) else n
    for split, n in summary.get("sizes", {}).items()
}
print("Dataset sizes:", dataset_sizes)

# Raw trajectories
trajs = Traffic.from_file(PARQUET_PATH)
print(f"Loaded {trajs.data.flight_id.nunique()} trajectories from {PARQUET_PATH}")

[cache] using key: ecec4b007a021fa3.key.json
Dataset sizes: {'train': 1000000, 'val': 200000, 'test': 200000}
Loaded 178947 trajectories from trajs_LSAS_filtered.parquet


## Model Architecture

Import the Flow Matching Model components for trajectory prediction.

In [3]:
# Import model architecture from dedicated module
from model import FlowMatchingModel, load_model_checkpoint

## Model Loading

In [4]:
# Restore the trained flow-matching model and place it on DEVICE
model = load_model_checkpoint(CKPT_PATH, DEVICE)
print("Loaded checkpoint:", CKPT_PATH)



Loaded checkpoint: models/model_1min.pt


## Sample Generation and Visualization

Generate trajectory predictions and create visualizations.

In [5]:
# Choose trajectories for analysis
CASE_LIST = [6841, 10000, 12345]   # specific trajectory indices into the test split
N_TRAJS = len(CASE_LIST)
N_SAMPLES = 50
TIMESTEPS = int(Y_test.shape[1])
SEED = 0  # fixes the sampler noise so the figures below are reproducible on re-run

torch.manual_seed(SEED)

# Prepare input data (fancy indexing on the memmap copies only the selected windows)
x_hist_batch = torch.from_numpy(np.array(X_test[CASE_LIST])).to(DEVICE).contiguous()
ctx_batch = torch.from_numpy(np.array(C_test[CASE_LIST])).to(DEVICE).contiguous()
y_true_batch = torch.from_numpy(np.array(Y_test[CASE_LIST])).to(DEVICE)

print(f"Processing {N_TRAJS} trajectories with {N_SAMPLES} samples each")

# Generate future trajectory samples (normalized space)
y_norm_all = sample_many(
    model,
    x_hist_batch,
    ctx_batch,
    T_out=TIMESTEPS,
    n_steps=64,
    G=1.0,
    n_samples=N_SAMPLES,
    chunk=128,
)
# The reshape below assumes sample_many returns N_SAMPLES * N_TRAJS rows ordered
# sample-major (all cases for sample 0, then sample 1, ...), which is also what
# ctx_batch.repeat(N_SAMPLES, 1) assumes. TODO confirm the ordering in sample_many.
assert y_norm_all.shape[0] == N_SAMPLES * N_TRAJS, (
    f"expected {N_SAMPLES * N_TRAJS} samples, got {y_norm_all.shape[0]}"
)

# Convert back to global coordinates
y_glob_all = (
    denorm_seq_to_global(
        y_norm_all,
        ctx_batch.repeat(N_SAMPLES, 1),
        feat_mean, feat_std, ctx_mean, ctx_std,
    )
    .cpu()
    .numpy()
    .reshape(N_SAMPLES, N_TRAJS, TIMESTEPS, -1)
)

# History and ground truth in global coordinates (first three channels only)
x_hist_glob = (
    denorm_seq_to_global(
        x_hist_batch, ctx_batch, feat_mean, feat_std, ctx_mean, ctx_std
    )
    .cpu()
    .numpy()[:, :, :3]
)

y_true_glob = (
    denorm_seq_to_global(
        y_true_batch, ctx_batch, feat_mean, feat_std, ctx_mean, ctx_std,
    )
    .cpu()
    .numpy()[:, :, :3]
)

print(f"Generated predictions with shape: {y_glob_all.shape}")
print(f"History shape: {x_hist_glob.shape}, Ground truth shape: {y_true_glob.shape}")

Processing 3 trajectories with 50 samples each
Generated predictions with shape: (50, 3, 12, 7)
History shape: (3, 60, 3), Ground truth shape: (3, 12, 3)


## Coordinate System Conversion

Helpers that project Swiss LV95 (EPSG:2056) planar coordinates to WGS84 longitude/latitude (EPSG:4326) so trajectories can be drawn on web maps.

In [6]:
# Coordinate transformation setup
# Source frame: Swiss LV95 (EPSG:2056) planar XY; target: WGS84 (EPSG:4326) lon/lat for map tiles
crs_lv95, crs_wgs84 = pyproj.CRS.from_epsg(2056), pyproj.CRS.from_epsg(4326)
# always_xy=True pins the axis order to (x, y) in and (lon, lat) out
to_wgs84 = pyproj.Transformer.from_crs(crs_lv95, crs_wgs84, always_xy=True)

def xy_to_lonlat_arr(xy: np.ndarray) -> np.ndarray:
    """Project LV95 points to WGS84.

    Only the first two columns of ``xy`` (x, y) are used; the result is an
    ``(N, 2)`` array whose columns are longitude and latitude.
    """
    easting, northing = xy[:, 0], xy[:, 1]
    lons, lats = to_wgs84.transform(easting, northing)
    return np.column_stack((lons, lats))

def concat_polyline_lonlat(list_of_ll_arrays):
    """Join several ``(N_i, 2)`` lon/lat arrays into a single Plotly polyline.

    Consecutive arrays are separated by one ``None`` so Plotly draws them as
    disconnected lines inside one trace; there is no trailing separator.

    Returns:
        (lons, lats): two equal-length Python lists.
    """
    lons, lats = [], []
    for n, ll in enumerate(list_of_ll_arrays):
        if n > 0:
            # Gap between two polylines
            lons.append(None)
            lats.append(None)
        lons += ll[:, 0].tolist()
        lats += ll[:, 1].tolist()
    return lons, lats

## Spaghetti Plot Visualization

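For each selected case, overlay all sampled future trajectories (blue), their pointwise mean (orange), the observed history (black), and the ground truth (red) on a map.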

In [7]:
# Color scheme for plots
COLOR_HISTORY = "black"
COLOR_GT = "red"
COLOR_PRED = "#1f77b4"
COLOR_MEAN = "#e19f20"

# Layout: a fixed row of N_PANELS map panels; cases beyond N_PANELS are not drawn
N_PANELS = 3
FIGURES_DIR = Path("figures")

cols = min(N_PANELS, N_TRAJS)
if N_TRAJS > N_PANELS:
    print(f"Only the first {N_PANELS} of {N_TRAJS} cases are plotted.")

# Panel titles use the second '_'-separated token of each case's flight_id
fig_spaghetti = make_subplots(
    rows=1, cols=N_PANELS,
    specs=[[{"type": "map"} for _ in range(N_PANELS)]],
    subplot_titles=[
        f"{meta_test.loc[CASE_LIST[i]]['flight_id'].split('_')[1]}" if i < N_TRAJS else "—"
        for i in range(N_PANELS)
    ],
    horizontal_spacing=0.02,
)

for b in range(cols):
    # Convert coordinates to lon/lat
    hist_xy = x_hist_glob[b, :, :2]
    true_xy = y_true_glob[b, :, :2]
    samples_xy = [y_glob_all[s, b, :, :2] for s in range(N_SAMPLES)]

    hist_ll = xy_to_lonlat_arr(hist_xy)
    true_ll = xy_to_lonlat_arr(true_xy)
    samples_ll = [xy_to_lonlat_arr(sxy) for sxy in samples_xy]

    # Center map on the bounding box of history + ground truth
    all_ll = np.vstack([hist_ll, true_ll])
    lon_center = float((np.nanmin(all_ll[:, 0]) + np.nanmax(all_ll[:, 0])) * 0.5)
    lat_center = float((np.nanmin(all_ll[:, 1]) + np.nanmax(all_ll[:, 1])) * 0.5)

    # Concatenate samples into one polyline trace
    samp_lon, samp_lat = concat_polyline_lonlat(samples_ll)

    # Only the first panel contributes legend entries
    showleg = (b == 0)

    # Predicted samples (blue, semi-transparent spaghetti)
    fig_spaghetti.add_trace(
        go.Scattermap(
            lon=samp_lon, lat=samp_lat,
            mode="lines",
            line=dict(width=1, color=COLOR_PRED),
            opacity=0.5,
            name=f"{N_SAMPLES} samples", showlegend=showleg
        ),
        row=1, col=b+1
    )

    # Sample mean (orange), averaged in lon/lat space
    mean_ll = np.mean(np.stack(samples_ll, axis=0), axis=0)
    fig_spaghetti.add_trace(
        go.Scattermap(
            lon=mean_ll[:, 0], lat=mean_ll[:, 1],
            mode="lines",
            line=dict(width=3, color=COLOR_MEAN),
            name="Sample mean", showlegend=showleg
        ),
        row=1, col=b+1
    )

    # History (black)
    fig_spaghetti.add_trace(
        go.Scattermap(
            lon=hist_ll[:, 0], lat=hist_ll[:, 1],
            mode="lines",
            line=dict(width=1.5, color=COLOR_HISTORY),
            name="History", showlegend=showleg
        ),
        row=1, col=b+1
    )

    # Ground Truth (red, on top)
    fig_spaghetti.add_trace(
        go.Scattermap(
            lon=true_ll[:, 0], lat=true_ll[:, 1],
            mode="lines+markers",
            line=dict(width=1, color=COLOR_GT),
            marker=dict(size=6, color=COLOR_GT),
            name="Ground Truth", showlegend=showleg
        ),
        row=1, col=b+1
    )

    # Map settings: the first panel's layout key is "map", later ones "map2", "map3", ...
    key = "" if b == 0 else str(b+1)
    fig_spaghetti.update_layout({
        f"map{key}": dict(
            style="carto-positron",
            center=dict(lon=lon_center, lat=lat_center),
            zoom=10.5,
        )
    })

# Give unused panels a basemap style so they do not render blank
for k in range(2, N_PANELS + 1):
    if k > cols:
        fig_spaghetti.update_layout({f"map{k}": dict(style="carto-positron")})

fig_spaghetti.update_layout(
    margin=dict(l=10, r=10, t=50, b=10),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.15,
        xanchor="center",
        x=0.5
    )
)

# Save and display; write_html does not create missing parent directories
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
fig_spaghetti.write_html(str(FIGURES_DIR / "spaghetti.html"))
fig_spaghetti.show()

## Flow Visualization

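Animate how the flow-matching sampler turns Gaussian noise into a trajectory. At every integration step we record three things: the predicted positions, how far each point moved since the previous frame, and the learned velocity field evaluated on a regular grid around the flight.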

In [16]:
# === Helpers =================================================================
def _array_len(arr) -> int:
    if torch.is_tensor(arr):
        return int(arr.shape[0])
    return int(np.array(arr).shape[0])


def _safe_select_case(arr, idx, device):
    """Return row ``idx`` of ``arr`` as a batch-of-one tensor on ``device``.

    ``idx`` is wrapped modulo the array length. Accepts torch tensors and
    numpy-compatible arrays (including read-only memmaps); only the selected row is
    copied, never the whole array.

    Raises:
        AssertionError: if ``arr`` is empty.
    """
    n = _array_len(arr)
    assert n > 0, "Empty input array"
    idx = idx % n
    if torch.is_tensor(arr):
        return arr[idx : idx + 1].to(device)
    # Slice first, then copy: np.array(arr) on a large memmap would read all of it.
    # The copy also yields a writable array, which torch.from_numpy expects.
    return torch.from_numpy(np.array(arr[idx : idx + 1])).to(device)


def _percentile_scale(
    vec_xy: np.ndarray, gvec_xy: np.ndarray, narrow: float, grid_factor=0.6, pct=95.0
):
    """Map 95th percentile magnitude to `narrow` for samples and `grid_factor*narrow` for grid."""
    vmag = np.hypot(vec_xy[:, 0], vec_xy[:, 1]) if vec_xy.size else np.array([0.0])
    gvmag = np.hypot(gvec_xy[:, 0], gvec_xy[:, 1]) if gvec_xy.size else np.array([0.0])
    vmax = np.nanpercentile(np.concatenate([vmag, gvmag]), pct) + 1e-9
    return (narrow / vmax), (grid_factor * narrow / vmax)


@torch.no_grad()
def collect_frames_with_grid(
    model,
    x_hist,
    ctx,
    T_out=12,
    n_steps=128,
    G=1.0,
    use_autocast=True,
    frame_stride=1,
    grid_xy_global=None,
    true_xy_global=None,
    grid_chunk=512,
):
    """
    Integrate the sampler for one case and record animation frames.

    Starting from Gaussian noise, the normalized future ``x`` is advanced from t=0
    to t=1 with ``n_steps`` Heun steps of ``dx/dt = G * model(x_hist, x, t, ctx)``.
    A frame is recorded every ``frame_stride`` steps and after the final step.

    Args:
        model: called as ``model(x_hist, x_t, t, ctx)``; its output is added to ``x_t``
            (presumably a velocity of the same shape as ``x_t``).
        x_hist: normalized history, rank-3 ``(B, L, D)``; only B == 1 is supported.
        ctx: normalized context, ``(B, C)``.
        T_out: number of future steps to generate.
        n_steps: number of Heun integration steps.
        G: scale applied to the model output.
        use_autocast: enable bfloat16 autocast (only takes effect on CUDA).
        frame_stride: record a frame every ``frame_stride`` steps (must be >= 1).
        grid_xy_global: (M, 2) grid anchor points in global XY. When None, a 15x15 grid
            covering the history (and ``true_xy_global`` if given) plus a 5% margin is built.
        true_xy_global: optional (T, 2) ground-truth XY, only used to size the default grid.
        grid_chunk: maximum number of grid points per model call (None/0 = all at once).

    Returns per-frame lists:
      - frames_pos_xy: list[np.ndarray (T,2)] predicted positions in global XY
      - frames_vec_xy: list[np.ndarray (T,2)] average velocity of each predicted point
        since the previous frame (global XY units per unit of flow time)
      - frames_grid_vec_xy: list[np.ndarray (M,2)] model output at the grid points, rotated
        to global XY, evaluated at the start time of the step that produced the frame
      - grid_xy_global: np.ndarray (M,2) grid anchor points that were used

    Raises:
        AssertionError: if the batch size is not 1.
        ValueError: if ``frame_stride`` < 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    model.eval()
    device_local = x_hist.device
    B, _, D = x_hist.shape
    assert B == 1, "B must be 1 for this animation."

    # Context: de-normalize to get the reference position (cols 0-1) and the rotation
    # coefficients (cols 3-4, used as cos/sin below — presumably of the heading).
    C = ctx.size(-1)
    cm = torch.as_tensor(ctx_mean[:C], dtype=ctx.dtype, device=device_local).view(1, C)
    cs = torch.as_tensor(ctx_std[:C], dtype=ctx.dtype, device=device_local).view(1, C)
    ctx_raw = ctx * cs + cm
    c = ctx_raw[:, 3:4]
    s = ctx_raw[:, 4:5]
    ref_xy = ctx_raw[:, :2]

    fm = torch.as_tensor(feat_mean, dtype=x_hist.dtype, device=device_local).view(
        1, 1, -1
    )
    fs = torch.as_tensor(feat_std, dtype=x_hist.dtype, device=device_local).view(
        1, 1, -1
    )

    # Build grid if needed
    if grid_xy_global is None:
        hist_xy = (
            denorm_seq_to_global(x_hist, ctx, feat_mean, feat_std, ctx_mean, ctx_std)[
                0, :, :2
            ]
            .detach()
            .cpu()
            .numpy()
        )
        all_pos = (
            np.vstack([hist_xy, true_xy_global])
            if true_xy_global is not None
            else hist_xy
        )
        xmin, xmax = float(np.nanmin(all_pos[:, 0])), float(np.nanmax(all_pos[:, 0]))
        ymin, ymax = float(np.nanmin(all_pos[:, 1])), float(np.nanmax(all_pos[:, 1]))
        mx = 0.05 * max(1.0, xmax - xmin)
        my = 0.05 * max(1.0, ymax - ymin)
        gx = np.linspace(xmin - mx, xmax + mx, 15)
        gy = np.linspace(ymin - my, ymax + my, 15)
        GX, GY = np.meshgrid(gx, gy)
        grid_xy_global = np.stack([GX.reshape(-1), GY.reshape(-1)], axis=-1)

    # Grid anchors are fixed, so rotate them into the local frame and normalize once.
    grid_xy_t = torch.as_tensor(grid_xy_global, dtype=x_hist.dtype, device=device_local)
    gdx = grid_xy_t[:, 0:1] - ref_xy[:, 0:1]
    gdy = grid_xy_t[:, 1:2] - ref_xy[:, 1:2]
    grid_local = torch.cat([c * gdx - s * gdy, s * gdx + c * gdy], dim=-1)
    grid_local_norm = (grid_local - fm[..., :2].reshape(1, 2)) / fs[..., :2].reshape(1, 2)
    M = int(grid_local_norm.shape[0])
    chunk = max(1, int(grid_chunk) if grid_chunk else M)

    amp_enabled = (device_local.type == "cuda") and use_autocast
    amp_dtype = torch.bfloat16

    def compute_grid_vectors(t_scalar: float) -> np.ndarray:
        """Model output at every grid anchor (as a one-step future), rotated to global XY."""
        parts = []
        for g0 in range(0, M, chunk):
            g_local = grid_local_norm[g0 : g0 + chunk]
            m = int(g_local.shape[0])
            xt = torch.zeros((m, 1, D), dtype=x_hist.dtype, device=device_local)
            xt[:, 0, 0:2] = g_local
            tt = torch.full(
                (m, 1), float(t_scalar), dtype=x_hist.dtype, device=device_local
            )
            with torch.amp.autocast(
                device_type="cuda", dtype=amp_dtype, enabled=amp_enabled
            ):
                v = model(x_hist.repeat(m, 1, 1), xt, tt, ctx.repeat(m, 1))  # (m,1,D)
            # Undo feature scaling (no mean shift for a rate), then apply the inverse rotation.
            v_xy_local = v[:, 0:1, 0:2] * fs[:, :, 0:2]
            vx_g = c * v_xy_local[:, :, 0:1] + s * v_xy_local[:, :, 1:2]
            vy_g = -s * v_xy_local[:, :, 0:1] + c * v_xy_local[:, :, 1:2]
            parts.append(torch.cat([vx_g, vy_g], dim=-1)[:, 0, :].detach().cpu().numpy())
        if not parts:
            return np.zeros((0, 2), dtype=np.float32)
        return np.concatenate(parts, axis=0)

    # Heun (explicit trapezoidal) integration of dx/dt = G * v(x, t) from t=0 to t=1
    x = torch.randn(B, T_out, D, device=device_local, dtype=x_hist.dtype)
    dt = 1.0 / n_steps

    frames_pos_xy, frames_vec_xy, frames_grid_vec_xy = [], [], []

    x_prev = x.clone()
    steps_since_frame = 0
    for k in range(n_steps):
        # Times are clamped to just below 1.0
        t0v = min(k * dt, 1.0 - 1e-6)
        t1v = min((k + 1) * dt, 1.0 - 1e-6)
        t0 = torch.full((B, 1), t0v, device=device_local, dtype=x_hist.dtype)
        t1 = torch.full((B, 1), t1v, device=device_local, dtype=x_hist.dtype)

        with torch.amp.autocast(
            device_type="cuda", dtype=amp_dtype, enabled=amp_enabled
        ):
            v1 = model(x_hist, x, t0, ctx)
            x_pred = x + (G * v1) * dt
            v2 = model(x_hist, x_pred, t1, ctx)
            x = x + 0.5 * (G * v1 + G * v2) * dt
        steps_since_frame += 1

        if (k % frame_stride == 0) or (k == n_steps - 1):
            pos_glob = denorm_seq_to_global(
                x, ctx, feat_mean, feat_std, ctx_mean, ctx_std
            )
            pos_xy = pos_glob[0, :, :2].detach().cpu().numpy()
            pos_glob_prev = denorm_seq_to_global(
                x_prev, ctx, feat_mean, feat_std, ctx_mean, ctx_std
            )
            # x_prev lags x by `steps_since_frame` steps, so divide by that elapsed
            # flow time (a single dt would overstate velocities when frame_stride > 1).
            vel_xy = (
                pos_glob[0, :, :2] - pos_glob_prev[0, :, :2]
            ).detach().cpu().numpy() / (steps_since_frame * dt)
            frames_pos_xy.append(pos_xy)
            frames_vec_xy.append(vel_xy)
            frames_grid_vec_xy.append(compute_grid_vectors(t0v))
            x_prev = x.clone()
            steps_since_frame = 0

    return frames_pos_xy, frames_vec_xy, frames_grid_vec_xy, grid_xy_global


In [27]:
# === Single-case animation (XY + map) ========================================
# Pick a single case index, wrapped modulo the test-set size (same rule as _safe_select_case)
CASE_SINGLE = 6841 % int(X_test.shape[0])
x_hist_case_t = torch.tensor(X_test[CASE_SINGLE], device=DEVICE).unsqueeze(0)
ctx_case_t = torch.tensor(C_test[CASE_SINGLE], device=DEVICE).unsqueeze(0)

# Compute GT for the same single case
y_true_case_t = torch.tensor(Y_test[CASE_SINGLE], device=DEVICE).unsqueeze(0)
x_hist_glob_1 = denorm_seq_to_global(
    x_hist_case_t, ctx_case_t, feat_mean, feat_std, ctx_mean, ctx_std
)
y_true_glob_1 = denorm_seq_to_global(
    y_true_case_t, ctx_case_t, feat_mean, feat_std, ctx_mean, ctx_std
)

hist_xy = x_hist_glob_1[0, :, :2].detach().cpu().numpy()
true_xy = y_true_glob_1[0, :, :2].detach().cpu().numpy()
all_pos = np.vstack([hist_xy, true_xy])

# Small grid_res x grid_res grid around history + GT with a 5% margin on each side
xmin, xmax = float(np.nanmin(all_pos[:, 0])), float(np.nanmax(all_pos[:, 0]))
ymin, ymax = float(np.nanmin(all_pos[:, 1])), float(np.nanmax(all_pos[:, 1]))
mx = 0.05 * max(1.0, xmax - xmin)
my = 0.05 * max(1.0, ymax - ymin)
grid_res = 10
GX, GY = np.meshgrid(
    np.linspace(xmin - mx, xmax + mx, grid_res),
    np.linspace(ymin - my, ymax + my, grid_res),
)
grid_xy = np.stack([GX.reshape(-1), GY.reshape(-1)], axis=-1)

# Collect frames with grid vectors
frames_pos_xy_g, frames_vec_xy_g, frames_grid_vec_xy, grid_xy = (
    collect_frames_with_grid(
        model,
        x_hist_case_t,
        ctx_case_t,
        T_out=TIMESTEPS,
        n_steps=64,
        G=1.0,
        use_autocast=True,
        frame_stride=1,
        grid_xy_global=grid_xy,
        true_xy_global=true_xy,
    )
)

# Arrow baseline (meters in LV95): 5% of the larger extent of history + GT
rx = float(np.nanmax(all_pos[:, 0]) - np.nanmin(all_pos[:, 0]) + 1e-6)
ry = float(np.nanmax(all_pos[:, 1]) - np.nanmin(all_pos[:, 1]) + 1e-6)
narrow = 0.05 * max(rx, ry)

# LV95 -> WGS84 projection below uses `to_wgs84` from the
# "Coordinate System Conversion" section.


def xy_to_lonlat_arr(xy: np.ndarray) -> np.ndarray:
    """Project the first two columns (LV95 x, y) of ``xy`` to an (N, 2) lon/lat array."""
    projected = to_wgs84.transform(xy[:, 0], xy[:, 1])
    return np.stack(projected, axis=1)


def seg_xy_lists_to_lonlat_lists(xs: list, ys: list):
    """Convert parallel LV95 x/y lists to lon/lat lists, passing ``None`` gaps through.

    If either input holds ``None`` at a position, both outputs get ``None`` there;
    Plotly uses these gaps to split one trace into separate segments. Iteration
    stops at the shorter input.
    """
    out_lon, out_lat = [], []
    for px, py in zip(xs, ys):
        lon = lat = None
        if px is not None and py is not None:
            lon, lat = to_wgs84.transform(px, py)
        out_lon.append(lon)
        out_lat.append(lat)
    return out_lon, out_lat


# --- Map figure (Scattermap) ----------------------------------------------
hist_ll = xy_to_lonlat_arr(hist_xy)
true_ll = xy_to_lonlat_arr(true_xy)
all_ll = np.vstack([hist_ll, true_ll])
lon_center = float((np.nanmin(all_ll[:, 0]) + np.nanmax(all_ll[:, 0])) * 0.5)
lat_center = float((np.nanmin(all_ll[:, 1]) + np.nanmax(all_ll[:, 1])) * 0.5)
map_zoom = 10.5


def _arrow_segments_xy(origins_xy, vectors_xy, scale):
    """Return None-separated x/y lists with one segment origin -> origin + scale * vector per row."""
    seg_x, seg_y = [], []
    for j in range(origins_xy.shape[0]):
        x0, y0 = float(origins_xy[j, 0]), float(origins_xy[j, 1])
        dx, dy = scale * float(vectors_xy[j, 0]), scale * float(vectors_xy[j, 1])
        seg_x += [x0, x0 + dx, None]
        seg_y += [y0, y0 + dy, None]
    return seg_x, seg_y


def _dynamic_traces(pos_xy, vec_xy, gvec_xy):
    """Build the per-frame [Sample vectors, Grid field, Predicted] traces for one denoising step."""
    k_samp, k_grid = _percentile_scale(
        vec_xy, gvec_xy, narrow, grid_factor=0.6, pct=95.0
    )
    seg_lon, seg_lat = seg_xy_lists_to_lonlat_lists(
        *_arrow_segments_xy(pos_xy, vec_xy, k_samp)
    )
    gseg_lon, gseg_lat = seg_xy_lists_to_lonlat_lists(
        *_arrow_segments_xy(grid_xy, gvec_xy, k_grid)
    )
    pos_ll = xy_to_lonlat_arr(pos_xy)
    return [
        go.Scattermap(
            lon=seg_lon,
            lat=seg_lat,
            mode="lines",
            line=dict(width=2, color="#ff7f0e"),
            opacity=0.9,
            name="Sample vectors",
        ),
        go.Scattermap(
            lon=gseg_lon,
            lat=gseg_lat,
            mode="lines",
            line=dict(width=1.5, color="purple"),
            opacity=0.6,
            name="Grid field",
        ),
        go.Scattermap(
            lon=pos_ll[:, 0],
            lat=pos_ll[:, 1],
            mode="markers+lines",
            line=dict(width=0.5, color="#1f77b4"),
            marker=dict(size=8, color="#1f77b4"),
            name="Predicted",
        ),
    ]


# Static traces (identical in every frame) come first, so every frame and the
# initial figure share one trace order: History, t=0, GT, Sample vectors, Grid field, Predicted.
static_traces = [
    go.Scattermap(
        lon=hist_ll[:, 0],
        lat=hist_ll[:, 1],
        mode="lines",
        line=dict(width=1, color="rgba(0,0,0,0.8)"),
        name="History",
    ),
    go.Scattermap(
        lon=[hist_ll[-1, 0]],
        lat=[hist_ll[-1, 1]],
        mode="markers",
        marker=dict(size=8, color="black"),
        name="t=0",
    ),
    go.Scattermap(
        lon=true_ll[:, 0],
        lat=true_ll[:, 1],
        mode="markers",
        marker=dict(size=8, color="red"),
        name="Ground Truth",
    ),
]

frames_map = [
    go.Frame(
        name=str(i),
        data=static_traces + _dynamic_traces(pos_xy, vec_xy, gvec_xy),
    )
    for i, (pos_xy, vec_xy, gvec_xy) in enumerate(
        zip(frames_pos_xy_g, frames_vec_xy_g, frames_grid_vec_xy)
    )
]

# The initial view is frame 0 (or just the static traces if no frame was recorded)
fig_map = go.Figure(
    data=frames_map[0].data if frames_map else static_traces,
    frames=frames_map,
)
fig_map.update_layout(
    title=f"Noise → Trajectory on map (case {CASE_SINGLE})",
    height=800,
    map=dict(
        style="carto-positron",
        center=dict(lon=lon_center, lat=lat_center),
        zoom=map_zoom,
    ),
    margin=dict(l=10, r=10, t=50, b=10),
    updatemenus=[
        {
            "type": "buttons",
            "showactive": False,
            "x": 0.05,
            "y": 1.150,
            "xanchor": "left",
            "yanchor": "top",
            "direction": "left",
            "buttons": [
                {
                    "label": "Play",
                    "method": "animate",
                    "args": [
                        None,
                        {
                            "fromcurrent": True,
                            "frame": {"duration": 80, "redraw": True},
                            "transition": {"duration": 0},
                        },
                    ],
                },
                {
                    "label": "Pause",
                    "method": "animate",
                    "args": [
                        [None],
                        {
                            "mode": "immediate",
                            "frame": {"duration": 0, "redraw": False},
                            "transition": {"duration": 0},
                        },
                    ],
                },
            ],
        }
    ],
    sliders=[
        {
            "active": 0,
            "x": 0.05,
            "y": 0.05,
            "len": 0.9,
            "xanchor": "left",
            "yanchor": "top",
            "currentvalue": {"prefix": "Step: "},
            "steps": [
                {
                    "label": str(i),
                    "method": "animate",
                    "args": [
                        [str(i)],
                        {
                            "mode": "immediate",
                            "frame": {"duration": 0, "redraw": True},
                            "transition": {"duration": 0},
                        },
                    ],
                }
                for i in range(len(frames_map))
            ],
        }
    ],
)

fig_map.show()

## Model Quality Metrics

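Evaluate model performance with RMSE and MAE as a function of prediction horizon, for 3D, horizontal (2D) and vertical errors. Each metric is reported for the sample mean, the best sample, and a constant-velocity extrapolation baseline.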

In [15]:
from tqdm.auto import tqdm

@torch.no_grad()
def constant_velocity_extrapolation(
    X_test,
    C_test,
    feat_mean,
    feat_std,
    ctx_mean,
    ctx_std,
    device,
    T_out: int,
    dt_seconds: float = OUTPUT_STRIDE_SECONDS,
):
    """Constant-velocity baseline forecast, returned in global coordinates.

    The last history state (after ``denorm_seq_to_global``) is held fixed:
    channels 0-2 are extrapolated as ``pos + k * dt_seconds * vel`` for
    k = 1..T_out, channels 3-5 repeat the last velocity, and any further channels
    are zero. The velocity is read from channels 3-5 only when the history has at
    least 6 channels; otherwise it is zero (a pure persistence forecast).

    Returns:
        np.ndarray of shape (N, T_out, D) with the dtype of the de-normalized history.
    """
    hist_t = torch.from_numpy(np.array(X_test)).to(device).contiguous()
    ctx_t = torch.from_numpy(np.array(C_test)).to(device).contiguous()
    hist_glob = (
        denorm_seq_to_global(hist_t, ctx_t, feat_mean, feat_std, ctx_mean, ctx_std)
        .cpu()
        .numpy()
    )  # (N, L, D)
    n_cases, _, n_ch = hist_glob.shape

    last_state = hist_glob[:, -1]
    start = last_state[:, :3]
    if n_ch >= 6:
        # Assumed to be vx, vy, vz in m/s — TODO confirm against denorm_seq_to_global
        velocity = last_state[:, 3:6]
    else:
        velocity = np.zeros((n_cases, 3), dtype=start.dtype)

    lead_s = np.arange(1, T_out + 1, dtype=start.dtype) * dt_seconds  # seconds ahead
    offset = lead_s[None, :, None] * velocity[:, None, :]  # (N, T_out, 3)

    # Starts as all zeros, so channels beyond 6 stay zero
    forecast = np.zeros((n_cases, T_out, n_ch), dtype=hist_glob.dtype)
    forecast[:, :, 0] = start[:, None, 0] + offset[:, :, 0]
    forecast[:, :, 1] = start[:, None, 1] + offset[:, :, 1]
    if n_ch >= 3:
        forecast[:, :, 2] = start[:, None, 2] + offset[:, :, 2]

    # Hold the last velocity constant in whichever of channels 3-5 exist
    n_vel = min(n_ch, 6) - 3
    if n_vel > 0:
        forecast[:, :, 3 : 3 + n_vel] = velocity[:, None, :n_vel]

    return forecast  # global coords


def _position_sq_errors(pred, true):
    """Return squared (3D, 2D-horizontal, vertical) position errors.

    Channels 0, 1, 2 of the last axis are treated as x, y, z. ``pred`` and
    ``true`` only need to broadcast against each other (e.g. (S,B,T,D) vs
    (1,B,T,D)). Each returned tensor has the broadcast shape with the channel
    axis removed.
    """
    dx = pred[..., 0] - true[..., 0]
    dy = pred[..., 1] - true[..., 1]
    dz = pred[..., 2] - true[..., 2]
    horiz_sq = dx**2 + dy**2
    vert_sq = dz**2
    return horiz_sq + vert_sq, horiz_sq, vert_sq


def _safe_sqrt(sq):
    """Element-wise sqrt with a tiny epsilon (kept so reported values match earlier runs)."""
    return torch.sqrt(sq + 1e-12)


def _horizon_seconds(T, dt_seconds):
    """Lead time in seconds of each output step: dt, 2*dt, ..., T*dt.

    Returned as ints when every horizon is a whole number of seconds, and as
    floats otherwise, so a fractional stride is not truncated (e.g. 2.5 s -> 2).
    """
    horizons = np.arange(1, T + 1) * dt_seconds
    if np.allclose(horizons, np.round(horizons)):
        return np.round(horizons).astype(int)
    return horizons.astype(float)


@torch.no_grad()
def rmse_mae_3d_2d_vert_vs_horizon(
    model,
    X,
    Y,
    C,
    feat_mean,
    feat_std,
    ctx_mean,
    ctx_std,
    device,
    n_eval: int = 2000,
    n_samples: int = 32,
    n_steps: int = 64,
    batch_size: int = 128,
    dt_seconds: float = OUTPUT_STRIDE_SECONDS,
    progress: bool = True,
    desc: str = "Evaluating batches",
):
    """Single-pass, horizon-wise RMSE and MAE for the model and a constant-velocity baseline.

    The evaluation covers the first ``min(len(X), n_eval)`` cases, processed in
    batches. For each batch:

    1. ``sample_many`` draws ``n_samples`` trajectories of length ``T = Y.shape[1]``.
    2. Samples, ground truth and the baseline are mapped to global coordinates.
    3. Errors on channels 0-2 (x, y, z) are accumulated for three categories:
       "3d" (x, y, z), "2d" (x, y) and "vert" (z).

    Three estimators are scored:

    * "mean": the average of the samples;
    * "best": the smallest error over the samples, chosen independently for
      every (case, horizon step, category). This is an optimistic oracle bound,
      not the error of a single best trajectory.
    * "cv": ``constant_velocity_extrapolation``.

    MAE is the mean Euclidean error (mean |dz| for "vert"). RMSE is the square
    root of the mean squared error.

    Returns:
        dict with keys
        - "horizons_s": lead time of each step, ``dt_seconds * (1..T)``
          (ints when all are whole seconds, floats otherwise);
        - "rmse"[cat][est], "mae"[cat][est]: numpy arrays of length T;
        - "fde"[cat][est]: {"rmse": ..., "mae": ...} at the final step.
    """
    N = min(int(X.shape[0]), int(n_eval))
    T = int(Y.shape[1])
    categories = ("3d", "2d", "vert")
    estimators = ("mean", "best", "cv")

    # Per-horizon sums over evaluated cases. Squared errors feed RMSE,
    # distances feed MAE; both are averaged over num_cases at the end.
    acc = {
        metric: {
            cat: {est: np.zeros(T, dtype=np.float64) for est in estimators}
            for cat in categories
        }
        for metric in ("rmse", "mae")
    }
    num_cases = 0

    # Progress tracking
    num_batches = (N + batch_size - 1) // batch_size
    iterator = range(0, N, batch_size)
    if progress:
        iterator = tqdm(iterator, total=num_batches, desc=desc, leave=True)

    for i0 in iterator:
        i1 = min(N, i0 + batch_size)

        # Tensors on device
        x_hist = torch.from_numpy(np.array(X[i0:i1])).to(device).contiguous()
        ctx = torch.from_numpy(np.array(C[i0:i1])).to(device).contiguous()
        y_true = torch.from_numpy(np.array(Y[i0:i1])).to(device).contiguous()
        B = int(x_hist.shape[0])

        # Samples (normalized) -> global
        y_norm_all = sample_many(
            model, x_hist, ctx,
            T_out=T, n_steps=n_steps, G=1.0,
            n_samples=n_samples, chunk=128,
        )  # (S*B, T, D_norm)

        # Assumes sample_many returns rows sample-major (all B cases of sample 0,
        # then sample 1, ...), matching ctx.repeat(n_samples, 1) and the view
        # below. TODO confirm against sample_many.
        y_s_glob = denorm_seq_to_global(
            y_norm_all, ctx.repeat(n_samples, 1),
            feat_mean, feat_std, ctx_mean, ctx_std
        ).view(n_samples, B, T, -1)  # (S,B,T,D)

        # Ground truth global
        y_true_glob = denorm_seq_to_global(
            y_true, ctx, feat_mean, feat_std, ctx_mean, ctx_std
        )  # (B,T,D)

        # "mean": error of the sample-averaged trajectory
        mean_sq = _position_sq_errors(y_s_glob.mean(dim=0), y_true_glob)

        # "best": per (case, step), min over samples, taken separately per category
        all_sq = _position_sq_errors(y_s_glob, y_true_glob.unsqueeze(0))  # each (S,B,T)
        best_sq = tuple(sq.min(dim=0).values for sq in all_sq)

        # "cv": constant-velocity baseline
        cv_pred_glob = constant_velocity_extrapolation(
            X[i0:i1], C[i0:i1],
            feat_mean, feat_std, ctx_mean, ctx_std,
            device=device, T_out=T, dt_seconds=dt_seconds,
        )  # (B,T,D)
        cv = torch.from_numpy(cv_pred_glob).to(device)
        cv_sq = _position_sq_errors(cv, y_true_glob)

        # Accumulate (sum over batch; average over num_cases later)
        for est, sq_triplet in (("mean", mean_sq), ("best", best_sq), ("cv", cv_sq)):
            for cat, sq_t in zip(categories, sq_triplet):
                acc["rmse"][cat][est] += sq_t.detach().cpu().numpy().sum(axis=0)
                acc["mae"][cat][est] += _safe_sqrt(sq_t).detach().cpu().numpy().sum(axis=0)

        num_cases += B

        # Progress update
        if progress:
            iterator.set_postfix_str(f"cases={num_cases}, B={B}")

    # Finalize
    denom = max(1, num_cases)
    out = {"horizons_s": _horizon_seconds(T, dt_seconds), "rmse": {}, "mae": {}, "fde": {}}
    for cat in categories:
        out["rmse"][cat] = {}
        out["mae"][cat] = {}
        out["fde"][cat] = {}
        for est in estimators:
            rmse_curve = np.sqrt(acc["rmse"][cat][est] / denom)
            mae_curve = acc["mae"][cat][est] / denom
            out["rmse"][cat][est] = rmse_curve
            out["mae"][cat][est] = mae_curve
            out["fde"][cat][est] = {
                "rmse": rmse_curve[-1],
                "mae": mae_curve[-1],
            }
    return out

# Settings for the horizon-wise error evaluation. n_eval is kept small because
# sampling dominates the runtime (several minutes for 64 cases in the run below).
METRIC_EVAL_KWARGS = dict(
    n_eval=64,
    n_samples=64,
    n_steps=32,
    batch_size=16,
    dt_seconds=OUTPUT_STRIDE_SECONDS,
)

# Run evaluation metrics
res = rmse_mae_3d_2d_vert_vs_horizon(
    model,
    X_test, Y_test, C_test,
    feat_mean, feat_std, ctx_mean, ctx_std,
    device=DEVICE,
    **METRIC_EVAL_KWARGS,
)

print("Evaluation completed successfully")

Evaluating batches: 100%|██████████| 4/4 [06:12<00:00, 93.14s/it, cases=64, B=16]

Evaluation completed successfully





In [13]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

H = res["horizons_s"]

# --- build 3-row × 2-column subplot: MAE in the left column, RMSE in the right ---
fig = make_subplots(
    rows=3, cols=2,
    shared_xaxes=True,
    subplot_titles=[
        "MAE — 3D (x,y,z)", "RMSE — 3D (x,y,z)",
        "MAE — 2D horizontal (x,y)", "RMSE — 2D horizontal (x,y)",
        "MAE — Vertical (z)", "RMSE — Vertical (z)"
    ],
    vertical_spacing=0.08,
    horizontal_spacing=0.08
)

# === color scheme ===
COLOR_MEAN = "blue"
COLOR_BEST = "green"
COLOR_CV   = "red"

# One entry per estimator: (key in res[metric][category], legend label, line style)
ESTIMATOR_STYLES = [
    ("mean", "Model (mean)", dict(width=2, color=COLOR_MEAN)),
    ("best", "Model (best-of-S)", dict(width=2, dash="dash", color=COLOR_BEST)),
    ("cv", "Constant velocity", dict(width=2, dash="dot", color=COLOR_CV)),
]


def add_metric_traces(cat_name, row_mae, row_rmse, show_legend=False):
    """Add the MAE (column 1) and RMSE (column 2) curves of one error category to ``fig``.

    Args:
        cat_name: category key in ``res`` ("3d", "2d" or "vert").
        row_mae: subplot row for the MAE panel.
        row_rmse: subplot row for the RMSE panel.
        show_legend: if True, this call's MAE traces get legend entries.

    Traces of the same estimator share a ``legendgroup``. A single legend click
    therefore toggles that estimator in every panel, not only in the panel
    that owns the legend entry.
    """
    for metric, row, col in (("mae", row_mae, 1), ("rmse", row_rmse, 2)):
        for est, label, line_style in ESTIMATOR_STYLES:
            fig.add_trace(go.Scatter(
                x=H, y=res[metric][cat_name][est],
                name=label, mode="lines+markers",
                line=line_style,
                legendgroup=label,
                showlegend=show_legend and metric == "mae",
            ), row=row, col=col)


# Add traces — legend only once (for first row)
add_metric_traces("3d",   1, 1, show_legend=True)
add_metric_traces("2d",   2, 2, show_legend=False)
add_metric_traces("vert", 3, 3, show_legend=False)

# --- layout ---
fig.update_layout(
    template="plotly_white",
    height=900, width=1100,
    legend=dict(
        orientation="h",
        yanchor="bottom", y=-0.1,
        xanchor="center", x=0.5,
    )
)

# Axis labels. x-axes are shared within each column, so the horizon title is
# only placed on the bottom row instead of between stacked panels.
for col in (1, 2):
    fig.update_xaxes(title_text="Horizon [s]", row=3, col=col)
for r in range(1, 4):
    for col in (1, 2):
        fig.update_yaxes(title_text="Error [m]", row=r, col=col)

fig.show()

## Probabilistic Calibration Analysis

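Assess the calibration of the probabilistic predictions with Probability Integral Transform (PIT) histograms, computed separately for the x, y, and z coordinates. For a well-calibrated model the PIT values are uniformly distributed on [0, 1]. A U-shaped histogram indicates over-confident (under-dispersed) samples, and a hump-shaped histogram indicates under-confident (over-dispersed) samples.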

In [9]:
@torch.no_grad()
def collect_samples_for_calibration(
    model,
    X,
    Y,
    C,
    feat_mean,
    feat_std,
    ctx_mean,
    ctx_std,
    device,
    n_eval: int = 2000,
    n_samples: int = 64,
    n_steps: int = 64,
    batch_size: int = 128,
):
    """Draw model samples and ground truth, both as global x/y/z positions.

    The first ``N = min(len(X), n_eval)`` cases are processed in batches of
    ``batch_size``. For each batch:

    1. ``sample_many`` draws ``n_samples`` normalized trajectories of length
       ``T = Y.shape[1]``.
    2. Samples and ground truth are mapped to global coordinates with
       ``denorm_seq_to_global``.
    3. Only channels 0-2 are kept.

    Returns:
        (ysamp_g, ytrue_g): tensors of shape (n_samples, N, T, 3) and (N, T, 3).
        They are not moved off the device they were produced on.

    Raises:
        ValueError: if ``batch_size < 1`` or there are no cases to evaluate.
            Without this check the failure would surface as an opaque error
            from ``range`` or from ``torch.cat`` on an empty list.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    N = min(int(X.shape[0]), int(n_eval))
    if N <= 0:
        raise ValueError(
            f"Nothing to evaluate: len(X)={int(X.shape[0])}, n_eval={n_eval}"
        )
    T = int(Y.shape[1])
    pos_samples = []
    pos_truth = []
    for i0 in tqdm(range(0, N, batch_size)):
        i1 = min(N, i0 + batch_size)
        x_hist = torch.from_numpy(np.array(X[i0:i1])).to(device).contiguous()
        ctx = torch.from_numpy(np.array(C[i0:i1])).to(device).contiguous()
        y_true = torch.from_numpy(np.array(Y[i0:i1])).to(device).contiguous()
        B = int(x_hist.shape[0])
        y_norm_all = sample_many(
            model,
            x_hist,
            ctx,
            T_out=T,
            n_steps=n_steps,
            G=1.0,
            n_samples=n_samples,
            chunk=128,
        )  # (S*B, T, Dn)
        # Assumes sample_many returns rows sample-major (all B cases of sample 0,
        # then sample 1, ...), matching ctx.repeat(n_samples, 1) and the view
        # below. TODO confirm against sample_many.
        y_glob_all = denorm_seq_to_global(
            y_norm_all, ctx.repeat(n_samples, 1), feat_mean, feat_std, ctx_mean, ctx_std
        ).view(n_samples, B, T, -1)[..., :3]  # (S,B,T,3)
        y_true_g = denorm_seq_to_global(
            y_true, ctx, feat_mean, feat_std, ctx_mean, ctx_std
        )[..., :3]  # (B,T,3)
        pos_samples.append(y_glob_all)
        pos_truth.append(y_true_g)
    ysamp_g = torch.cat(pos_samples, dim=1)  # (S,N,T,3)
    ytrue_g = torch.cat(pos_truth, dim=0)  # (N,T,3)
    return ysamp_g, ytrue_g

# Sampling settings for the calibration study.
# NOTE(review): only 12 trajectories are used, so the PIT histograms below
# will be noisy; increase n_eval for a more reliable calibration estimate.
CALIB_EVAL_KWARGS = dict(n_eval=12, n_samples=64, n_steps=32, batch_size=4)

# Draw samples for calibration analysis
ysamp_g, ytrue_g = collect_samples_for_calibration(
    model,
    X_test, Y_test, C_test,
    feat_mean, feat_std, ctx_mean, ctx_std,
    device=DEVICE,
    **CALIB_EVAL_KWARGS,
)

n_draws, n_trajectories = ysamp_g.shape[0], ysamp_g.shape[1]
print(f"Collected {n_draws} samples for {n_trajectories} trajectories")

100%|██████████| 3/3 [01:08<00:00, 22.97s/it]

Collected 64 samples for 12 trajectories





## PIT Histogram Visualization

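Plot one Probability Integral Transform (PIT) histogram per coordinate axis, pooled over all trajectories and horizon steps. Each histogram is normalized to a density so it can be compared directly with the uniform reference density of 1.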

In [10]:
# Calculate PIT values
pits_btd = pit_values(ysamp_g, ytrue_g)  # (N,T,3)
pits_flat = pits_btd.reshape(-1, 3).detach().cpu().numpy()

# Setup for plotting
axis_names = ["x", "y", "z"]
axis_colors = ["black", "red", "blue"]
uniform_line_color = "gray"

# Parameters
nbins = 30
PIT_Y_TOP_DEFAULT = 4.5  # minimum upper y-limit; raised automatically if a bin is taller
FIGURES_DIR = Path("figures")


def pit_density(u, n_bins):
    """Histogram PIT values into ``n_bins`` equal-width bins spanning exactly [0, 1].

    Returns ``(centers, density)``. ``density`` integrates to 1 over [0, 1],
    which is the same normalization as Plotly's ``histnorm="probability density"``,
    so it can be compared directly with the uniform density of 1. Values outside
    [0, 1] are not counted. If nothing is counted, the density is all zeros.
    """
    counts, edges = np.histogram(u, bins=n_bins, range=(0.0, 1.0))
    total = counts.sum()
    if total > 0:
        density = counts / (total * np.diff(edges))
    else:
        density = np.zeros(n_bins, dtype=float)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, density


# Create subplots: 1 row x 3 cols (histograms)
fig = make_subplots(
    rows=1, cols=3,
    shared_xaxes=False,
    shared_yaxes=False,
    horizontal_spacing=0.07,
    subplot_titles=[f"PIT histogram — {axis}" for axis in axis_names],
)

for d, (axis, color) in enumerate(zip(axis_names, axis_colors)):
    col = d + 1

    # Bins are computed explicitly rather than through go.Histogram's nbinsx.
    # Plotly treats nbinsx only as a maximum and picks its own bin size. Here
    # there are exactly nbins bins aligned to [0, 1], and np.histogram's last
    # bin also includes PIT values equal to 1.
    centers, density = pit_density(pits_flat[:, d], nbins)

    # Histogram
    fig.add_trace(
        go.Bar(
            x=centers,
            y=density,
            marker=dict(color=color),
            name=axis,
            showlegend=False,
            opacity=0.75,
        ),
        row=1, col=col
    )

    # Uniform PDF reference line (y=1 on [0,1])
    fig.add_trace(
        go.Scatter(
            x=[0, 1],
            y=[1, 1],
            mode="lines",
            line=dict(dash="dash", color=uniform_line_color),
            name="Uniform PDF (y=1)" if d == 0 else None,
            showlegend=(d == 0),
        ),
        row=1, col=col
    )

    # Axes formatting. Keep the usual upper limit, but never clip a taller bin;
    # tall edge bins are exactly the signal of miscalibration.
    y_top = max(PIT_Y_TOP_DEFAULT, 1.1 * float(density.max()))
    fig.update_xaxes(title_text="PIT value", range=[0, 1], row=1, col=col)
    fig.update_yaxes(title_text="Density", range=[0, y_top], row=1, col=col)

# Layout
fig.update_layout(
    height=400, width=800,
    template="plotly_white",
    title="Calibration diagnostics: PIT histograms",
    bargap=0.05,
    legend=dict(
        orientation="h",
        yanchor="bottom", y=-0.3,
        xanchor="center", x=0.5
    ),
    margin=dict(l=10, r=10, t=60, b=60),
)

# Save and display. Create the output directory first: writing into a
# missing directory raises FileNotFoundError.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
fig.write_html(str(FIGURES_DIR / "pit_histograms.html"))
fig.show()