# Conflict Resolution and Management (CRM) Analysis

This notebook demonstrates conflict detection and resolution analysis using the generative flight trajectory prediction model. It loads pre-trained models, analyzes conflicting flight pairs, and evaluates collision risk.


## Setup and Configuration


In [1]:
# Core imports
import sys
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import pathlib

# Add project root to Python path (BEFORE other imports)
project_root = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Project imports (now that path is set up)
from model import load_model_checkpoint
from utils.generate_predictions import (
    predict_trajectories,
    calculate_trajectory_spacings,
)
from utils.utils import cache_paths

# External libraries
from traffic.core import Traffic
from traffic.data import opensky
import plotly.graph_objects as go
import importlib

# Configuration
PARQUET_PATH = "../trajs_LSAS_filtered.parquet"
CACHE_DIR = pathlib.Path("../dataset_cache")
CKPT_PATH = "../models/model_1min.pt"
CACHE_KEY_FILE = "ecec4b007a021fa3.key.json"
OUTPUT_STRIDE_SECONDS = 5

# Device configuration
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


## Data Loading


In [2]:
# Load cached dataset artifacts
def load_cache_key():
    """Load dataset cache key file."""
    if CACHE_KEY_FILE:
        key_path = Path(CACHE_KEY_FILE)
        if not key_path.is_absolute():
            key_path = CACHE_DIR / key_path
        if not key_path.exists():
            raise FileNotFoundError(
                f"Specified cache key file not found: {key_path}. "
                f"Make sure it exists under {CACHE_DIR}."
            )
    else:
        key_files = sorted(
            CACHE_DIR.glob("*.key.json"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        if not key_files:
            raise FileNotFoundError(
                "No dataset_cache key files found. Run the training notebook first."
            )
        key_path = key_files[0]
    return key_path


# Load cache key and dataset information
key_path = load_cache_key()
key_info = json.loads(key_path.read_text())
print(f"[cache] using key: {key_path.name}")

dset_key = key_info["dataset_key"]
stats_key = key_info["stats_key"]
paths = cache_paths(dset_key, stats_key)

# Load dataset arrays (using memory-mapped files for efficiency)
X_train = np.load(paths["x_tr"], mmap_mode="r")
Y_train = np.load(paths["y_tr"], mmap_mode="r")
C_train = np.load(paths["c_tr"], mmap_mode="r")
X_val = np.load(paths["x_va"], mmap_mode="r")
Y_val = np.load(paths["y_va"], mmap_mode="r")
C_val = np.load(paths["c_va"], mmap_mode="r")
X_test = np.load(paths["x_te"], mmap_mode="r")
Y_test = np.load(paths["y_te"], mmap_mode="r")
C_test = np.load(paths["c_te"], mmap_mode="r")

# Load normalization statistics and metadata
norm_stats = json.loads(paths["stats"].read_text())
meta_train = pd.read_parquet(paths["meta_tr"])  # per-window metadata
meta_val = pd.read_parquet(paths["meta_va"])
meta_test = pd.read_parquet(paths["meta_te"])
manifest = json.loads(paths["manifest"].read_text())
summary = json.loads(paths["summary"].read_text())

# Extract normalization parameters
feat_mean = norm_stats["feat_mean"]
feat_std = norm_stats["feat_std"]
ctx_mean = norm_stats["ctx_mean"]
ctx_std = norm_stats["ctx_std"]

# Display dataset sizes
dataset_sizes = {
    k: (int(v) if isinstance(v, (int, np.integer)) else v)
    for k, v in summary.get("sizes", {}).items()
}
print("Dataset sizes:", dataset_sizes)

# Load trajectory data
trajs = Traffic.from_file(PARQUET_PATH)
print(f"Loaded {trajs.data.flight_id.nunique()} trajectories from {PARQUET_PATH}")

[cache] using key: ecec4b007a021fa3.key.json
Dataset sizes: {'train': 1000000, 'val': 200000, 'test': 200000}
Loaded 178947 trajectories from ../trajs_LSAS_filtered.parquet


## Model Architecture

Import the Flow Matching Model components for trajectory prediction.


In [3]:
# Import model architecture from dedicated module


## Model Loading


In [4]:
# Load the pre-trained model
model = load_model_checkpoint(CKPT_PATH, DEVICE)
print(f"Loaded checkpoint: {CKPT_PATH}")

Loaded checkpoint: ../models/model_1min.pt




## Conflict Analysis Setup

Load and prepare conflicting flight trajectory data for analysis.


In [5]:


conflicting_pair = (
    opensky.history(
        "2024-09-16 07:40:00",
        "2024-09-16 07:42:00",
        callsign=["EAU24K", "EZY59KU"],
        # bounds=(47.37, 12.35, 47.39, 12.45)
    )
    .assign_id()
    .eval()
)
conflicting_pair.data["timestamp"] = pd.to_datetime(
    conflicting_pair.data["last_position"], unit="s", utc=True
)
cols2keep = [
    "flight_id",
    "timestamp",
    "latitude",
    "longitude",
    "altitude",
    "track",
    "groundspeed",
    "vertical_rate",
    "icao24",
    "callsign",
]
conflicting_pair.data = conflicting_pair.data[cols2keep]
conflicting_pair.scatter_map(hover_data=["timestamp"])

2025-10-26 16:05:29,459 INFO Processing query SELECT state_vectors_data4.time, state_vectors_data4.icao24, state_vectors_data4.lat, state_vectors_data4.lon, state_vectors_data4.velocity, state_vectors_data4.heading, state_vectors_data4.vertrate, state_vectors_data4.callsign, state_vectors_data4.onground, state_vectors_data4.alert, state_vectors_data4.spi, state_vectors_data4.squawk, state_vectors_data4.baroaltitude, state_vectors_data4.geoaltitude, state_vectors_data4.lastposupdate, state_vectors_data4.lastcontact, state_vectors_data4.serials, state_vectors_data4.hour 
FROM state_vectors_data4 
WHERE state_vectors_data4.callsign IN (__[POSTCOMPILE_callsign_1]) AND state_vectors_data4.time >= :time_1 AND state_vectors_data4.time <= :time_2 AND state_vectors_data4.hour >= :hour_1 AND state_vectors_data4.hour < :hour_2
{'callsign_1': ['EAU24K', 'EZY59KU'], 'time_1': Timestamp('2024-09-16 07:40:00+0000', tz='UTC'), 'time_2': Timestamp('2024-09-16 07:42:00+0000', tz='UTC'), 'hour_1': Timest

In [6]:
import numpy as np
import pandas as pd


def resample_flight_1hz(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy().sort_values("timestamp").set_index("timestamp")

    # Ensure UTC tz-aware
    if df.index.tz is None:
        df.index = df.index.tz_localize("UTC")
    else:
        df.index = df.index.tz_convert("UTC")

    # --- Collapse duplicate timestamps (same second) ---
    # Define columns
    num_cols = [
        c
        for c in ["latitude", "longitude", "altitude", "groundspeed", "vertical_rate"]
        if c in df.columns
    ]
    id_cols = [c for c in ["flight_id", "icao24", "callsign"] if c in df.columns]

    def circ_mean_deg(s: pd.Series):
        s = s.dropna()
        if s.empty:
            return np.nan
        rad = np.deg2rad(s.values)
        return (
            np.degrees(np.arctan2(np.sin(rad).mean(), np.cos(rad).mean())) + 360.0
        ) % 360.0

    if not df.index.is_unique:
        aggs = {c: "mean" for c in num_cols}
        if "track" in df.columns:
            aggs["track"] = circ_mean_deg
        for c in id_cols:
            # keep the last non-null label observed at that timestamp
            aggs[c] = lambda s: s.dropna().iloc[-1] if s.dropna().size else np.nan
        df = df.groupby(level=0, sort=True).agg(aggs)

    # Build 1 Hz grid
    start = df.index[0].ceil("s")
    end = df.index[-1].floor("s")
    if end < start:
        # No whole-second span; return any rows already on whole seconds
        out = df[df.index == df.index.round("S")].reset_index()
        out.rename(columns={"index": "timestamp"}, inplace=True)
        return out

    grid = pd.date_range(start, end, freq="1s", tz="UTC")

    # Interpolate over the union, then select grid
    union = df.index.union(grid, sort=False)
    w = df.reindex(union)

    # Time-based interpolation for numeric
    for c in num_cols:
        w[c] = w[c].interpolate(method="time")

    # Track via vector interpolation
    if "track" in w.columns:
        rad = np.deg2rad(w["track"])
        x = np.cos(rad).interpolate(method="time")
        y = np.sin(rad).interpolate(method="time")
        w["track"] = (np.degrees(np.arctan2(y, x)) + 360.0) % 360.0

    # IDs forward/back fill
    for c in id_cols:
        w[c] = w[c].ffill().bfill()

    # Keep only the grid and tidy columns
    w = w.reindex(grid)
    cols2keep = [
        "flight_id",
        "latitude",
        "longitude",
        "altitude",
        "track",
        "groundspeed",
        "vertical_rate",
        "icao24",
        "callsign",
    ]
    cols2keep = [c for c in cols2keep if c in w.columns]
    w = w[cols2keep].reset_index().rename(columns={"index": "timestamp"})
    return w


# Usage
l_rs = []
for f in conflicting_pair:
    df_rs = resample_flight_1hz(f.data)
    l_rs.append(df_rs)

df_all_1hz = pd.concat(l_rs, ignore_index=True)
conflicting_pair = Traffic(df_all_1hz).assign_id().eval()
conflicting_pair

Unnamed: 0_level_0,count
flight_id,Unnamed: 1_level_1
EZY59KU_001,120
EAU24K_000,119


In [7]:
fig = conflicting_pair.first("1m").eval().scatter_map(color="altitude")
fig.update_layout(
    map_style="carto-positron",
)
fig.write_html("../figures/conflicting_pair.html")
fig.show()


In [8]:
f1_df = conflicting_pair[0].first("60s").data
f2_df = conflicting_pair[1].first("60s").data
gt1 = conflicting_pair[0].last("60s").data
gt2 = conflicting_pair[1].last("60s").data

In [9]:
cols2keep = [
    "timestamp",
    "latitude",
    "longitude",
    "altitude",
    "track",
    "groundspeed",
    "vertical_rate",
    "icao24",
    "callsign",
]
f1_df = f1_df[cols2keep]
f2_df = f2_df[cols2keep]
f1_df.shape

(60, 9)

In [10]:
n_samps = 20
sampling_rate = 1  # s
f1_latlonalt = predict_trajectories(
    f1_df,
    model,
    feat_mean,
    feat_std,
    ctx_mean,
    ctx_std,
    DEVICE,
    n_samples=n_samps,
    sampling_rate=sampling_rate,
)
f2_latlonalt = predict_trajectories(
    f2_df,
    model,
    feat_mean,
    feat_std,
    ctx_mean,
    ctx_std,
    DEVICE,
    n_samples=n_samps,
    sampling_rate=sampling_rate,
)

In [11]:
# === save_geojson.py ===
import json
import numpy as np
import pandas as pd

# assumes f1_df, f2_df, f1_latlonalt, f2_latlonalt already defined (like in your notebook)

OUTDIR = Path("../webdata")
OUTDIR.mkdir(parents=True, exist_ok=True)


def to_feature_line(coords, props, altitudes=None):
    feature = {
        "type": "Feature",
        "geometry": {"type": "LineString", "coordinates": coords},
        "properties": props,
    }
    if altitudes is not None:
        feature["properties"]["altitudes_ft"] = altitudes
    return feature


# Histories (one LineString per flight). Mapbox needs [lon, lat] order.
hist_features = []
for flight_id, df in [("f1", f1_df), ("f2", f2_df)]:
    coords = list(zip(df["longitude"].astype(float), df["latitude"].astype(float)))
    altitudes = (
        df["altitude"].astype(float).tolist() if "altitude" in df.columns else None
    )
    # epoch seconds per vertex (handy if you later want time-accurate pacing in JS)
    ts = pd.to_datetime(df["timestamp"]).astype("int64") // 10**9
    hist_features.append(
        to_feature_line(
            coords,
            {"flight": flight_id, "timestamps": ts.tolist()},
            altitudes=altitudes,
        )
    )

with open(OUTDIR / "histories.geojson", "w") as f:
    json.dump({"type": "FeatureCollection", "features": hist_features}, f)


# Futures (many LineStrings per flight: one per sample)
def futures_fc(latlonalt_list, flight_id, n_samples=20, T=None):
    feats = []
    n = min(n_samples, len(latlonalt_list))
    for i in range(n):
        arr = latlonalt_list[i]
        if T is not None:
            arr = arr[:T]
        # arr rows are [lat, lon, alt]
        coords = [(float(lon), float(lat)) for (lat, lon, *_rest) in arr]
        altitudes = [float(alt) for (*_, alt) in arr] if arr.shape[1] >= 3 else None
        feats.append(
            to_feature_line(
                coords, {"flight": flight_id, "sample": i}, altitudes=altitudes
            )
        )
    return {"type": "FeatureCollection", "features": feats}


# choose a light number of samples for performance
N_SAMPLES = 20
with open(OUTDIR / "futures_f1.geojson", "w") as f:
    json.dump(futures_fc(f1_latlonalt, "f1", n_samples=N_SAMPLES), f)
with open(OUTDIR / "futures_f2.geojson", "w") as f:
    json.dump(futures_fc(f2_latlonalt, "f2", n_samples=N_SAMPLES), f)

# Useful viewport metadata (bounds)
all_hist_ll = np.vstack(
    [
        np.column_stack([f1_df["longitude"].to_numpy(), f1_df["latitude"].to_numpy()]),
        np.column_stack([f2_df["longitude"].to_numpy(), f2_df["latitude"].to_numpy()]),
    ]
)
bounds = {
    "minLon": float(np.nanmin(all_hist_ll[:, 0])),
    "minLat": float(np.nanmin(all_hist_ll[:, 1])),
    "maxLon": float(np.nanmax(all_hist_ll[:, 0])),
    "maxLat": float(np.nanmax(all_hist_ll[:, 1])),
}
with open(OUTDIR / "meta.json", "w") as f:
    json.dump(bounds, f, indent=2)

print(
    "Wrote:",
    [
        str(p)
        for p in (
            OUTDIR / "histories.geojson",
            OUTDIR / "futures_f1.geojson",
            OUTDIR / "futures_f2.geojson",
            OUTDIR / "meta.json",
        )
    ],
)

Wrote: ['../webdata/histories.geojson', '../webdata/futures_f1.geojson', '../webdata/futures_f2.geojson', '../webdata/meta.json']


In [12]:
spacing_df = calculate_trajectory_spacings(f1_latlonalt, f2_latlonalt)

n_los = spacing_df.query("vert_spacing_ft < 1000 and hori_spacing_NM < 5").shape[0]
n_cols = spacing_df.query("vert_spacing_ft < 55 and hori_spacing_NM < 0.03").shape[0]

frac_los = n_los / spacing_df.shape[0]
frac_cols = n_cols / spacing_df.shape[0]

print("\n==== LOSS OF SEPARATION/COLLISION SUMMARY ====")
print(f"Total trajectory samples compared: {spacing_df.shape[0]:,}")
print(f"Number of Loss of Separation (LOS) (<1000ft & <5NM): {n_los} ({frac_los:.2%})")
print(f"Number of Collisions (<55ft & <0.03NM):        {n_cols} ({frac_cols:.2%})\n")



==== LOSS OF SEPARATION/COLLISION SUMMARY ====
Total trajectory samples compared: 400
Number of Loss of Separation (LOS) (<1000ft & <5NM): 288 (72.00%)
Number of Collisions (<55ft & <0.03NM):        2 (0.50%)



In [13]:
# px.histogram(
#     spacing_df,
#     x='CPA_3D_meter'
# ).show()
# fig = px.scatter(
#     spacing_df,
#     x='hori_spacing_NM',
#     y='vert_spacing_ft'
# )
# # add a collision box with a vertical line at 55ft and an horizontal line at 0.03NM
# fig.add_vrect(x0=0, x1=0.03, y0=0, y1=55, fillcolor="LightSalmon", opacity=0.2, line_width=0)

In [14]:
COLOR_HISTORY = "black"
COLOR_GT = "red"
COLOR_PRED = "#1f77b4"
COLOR_PRED2 = "#ff0000"
COLOR_MEAN = "#e19f20"

fig = go.Figure()

# Flight 1 predicted samples (blue, semi-transparent spaghetti)
for i, f in enumerate(f1_latlonalt[:100]):
    fig.add_trace(
        go.Scattermap(
            lon=f[:, 1],
            lat=f[:, 0],
            mode="lines",
            line=dict(width=1, color=COLOR_PRED),
            opacity=0.5,
            name=f"{100} samples Flight 1",
            showlegend=i == 0,
        )
    )


# # Flight 2 predicted samples (blue, semi-transparent spaghetti)
for i, f in enumerate(f2_latlonalt[:100]):
    fig.add_trace(
        go.Scattermap(
            lon=f[:, 1],
            lat=f[:, 0],
            mode="lines",
            line=dict(width=1, color=COLOR_PRED2),
            opacity=0.5,
            name=f"{100} samples Flight 2",
            showlegend=i == 0,
        )
    )


# History (observations) - black
fig.add_trace(
    go.Scattermap(
        lon=f1_df["longitude"],
        lat=f1_df["latitude"],
        mode="lines",
        line=dict(width=1.5, color=COLOR_HISTORY),
        name="Flight 1 History",
    )
)

fig.add_trace(
    go.Scattermap(
        lon=f2_df["longitude"],
        lat=f2_df["latitude"],
        mode="lines",
        line=dict(width=1.5, color=COLOR_HISTORY),
        name="Flight 2 History",
    )
)

# # Map settings - center on the trajectories
all_ll = np.vstack(
    [
        np.column_stack([f1_df["longitude"], f1_df["latitude"]]),
        np.column_stack([f2_df["longitude"], f2_df["latitude"]]),
    ]
)
lon_center = float((np.nanmin(all_ll[:, 0]) + np.nanmax(all_ll[:, 0])) * 0.5)
lat_center = float((np.nanmin(all_ll[:, 1]) + np.nanmax(all_ll[:, 1])) * 0.5)

fig.update_layout(
    map=dict(
        style="carto-positron",
        center=dict(lon=lon_center, lat=lat_center),
        zoom=10.5,
    ),
    margin=dict(l=10, r=10, t=0, b=10),
    legend=dict(orientation="h", yanchor="bottom", y=-0.05, xanchor="center", x=0.5),
    # height=800
)
# fig.write_html('../figures/conflicting_pair_predictions.html')
fig.show()
