# Flow Matching Model Training

This notebook trains a conditional flow matching (CFM) model for generative flight trajectory prediction. Given a window of historical flight data and contextual features, the model learns to generate future aircraft trajectories.

## Setup and Configuration

In [None]:
# Core imports
import os
import math
import time
import random
from collections import OrderedDict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset
from traffic.core import Traffic

# Quiet optional logs for cleaner output.
# Turn off Triton's autotuning printouts.
os.environ["TRITON_PRINT_AUTOTUNING"] = "0"
# Drop any inherited TORCH_LOGS setting; pop() with a default is already a
# no-op when the variable is absent, so no membership check is needed.
os.environ.pop("TORCH_LOGS", None)

# Import project utilities
from utils import (
    load_and_engineer,
    WindowParams,
    SplitConfig,
    SamplingConfig,
    StatsConfig,
    CFMDataset,
    make_loader,
    build_or_load_dataset,
    TurnSampling
)

# Seed every RNG the notebook relies on so runs are repeatable.
# NOTE: this does not make GPU runs bit-for-bit deterministic — the TF32 and
# cudnn.benchmark settings enabled below trade reproducibility for speed.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # per the PyTorch docs, this also seeds CUDA generators

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Enable CUDA speed-ups: TF32 matmuls/convolutions and cuDNN autotuning.
if device.type == "cuda":
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    except (AttributeError, RuntimeError) as err:
        # Best-effort: keep going without the optimizations, but say so
        # instead of silently swallowing the failure.
        print(f"Warning: could not enable CUDA optimizations: {err}")

## Data Loading and Preparation

In [None]:
# Filtered trajectory parquet file to train on (path relative to the notebook).
INPUT_PARQUET = "trajs_LSAS_filtered.parquet"

# Read the trajectories and apply the project's feature engineering in one step.
print("Loading and engineering trajectory data...")
df = load_and_engineer(INPUT_PARQUET)

# Report how much data came back.
print(f"Loaded dataset with {len(df)} trajectory segments")

In [None]:
# Dataset configuration.
# NOTE(review): the window lengths below are counted in time *steps*; their
# wall-clock duration depends on the sampling rate of the engineered data,
# which is not visible here. The checkpoint name "model_1min.pt" used for
# training suggests 1 Hz data (60 steps = 1 minute) — TODO confirm.
wparams = WindowParams(
    input_len=60,       # history window length, in time steps
    output_horizon=60,  # prediction horizon, in time steps
    output_stride=5,    # spacing between predicted future steps, in time steps
    overlap=False,      # no overlap between windows
)

# Train/validation/test split configuration (fractions of the data).
scfg = SplitConfig(
    train_frac=0.8,  # 80% training
    val_frac=0.1,    # 10% validation (presumably the remaining 10% is test)
    split_seed=42,   # reproducible splits
)

# Turn-detection settings shared by every split, defined once so the splits
# cannot drift apart; only min_turn_frac differs (0.30 for train/val, 0.0 for test).
TURN_KWARGS = dict(turn_thr=0.01, consec=3, consider_hist=True, consider_future=True)

# Sampling configuration for trajectory selection
samp = SamplingConfig(
    n_train=1_000_000,  # 1M training samples
    n_val=200_000,      # 200K validation samples
    n_test=200_000,     # 200K test samples
    train_turn=TurnSampling(min_turn_frac=0.30, **TURN_KWARGS),
    val_turn=TurnSampling(min_turn_frac=0.30, **TURN_KWARGS),
    test_turn=TurnSampling(min_turn_frac=0.0, **TURN_KWARGS),
)

# Normalization statistics configuration
stats_cfg = StatsConfig(
    stats_seed=1234,
    stats_sample_size=2_000_000,  # sample size for computing normalization stats
)

# Build or load processed datasets
print("Building datasets...")
(X_train, Y_train, C_train,
 X_val, Y_val, C_val,
 X_test, Y_test, C_test,
 norm_stats, meta_train, meta_val, meta_test,
 manifest, summary) = build_or_load_dataset(df, wparams, scfg, samp, stats_cfg)

# Create dataset objects
train_ds = CFMDataset(X_train, Y_train, C_train)
val_ds = CFMDataset(X_val, Y_val, C_val)
test_ds = CFMDataset(X_test, Y_test, C_test)

print(f"Dataset sizes: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

# Create data loaders. One named batch size keeps the three loaders consistent.
# NOTE(review): train_cfm is handed the datasets and its own batch_size, so
# these loaders are presumably for evaluation/inspection — confirm.
LOADER_BATCH_SIZE = 2048
train_loader = make_loader(train_ds, bs=LOADER_BATCH_SIZE, shuffle=True)
val_loader = make_loader(val_ds, bs=LOADER_BATCH_SIZE, shuffle=False)
test_loader = make_loader(test_ds, bs=LOADER_BATCH_SIZE, shuffle=False)

print(f"Batch counts: train={len(train_loader)}, val={len(val_loader)}, test={len(test_loader)}")

## Model Architecture

Import the Flow Matching Model components for conditional generative trajectory prediction.

In [None]:
# Import model architecture and utilities from dedicated module
from model import FlowMatchingModel, sample_xt_and_target

## Training Utilities

Import the training loop (`train_cfm`), which takes learning-rate warmup and EMA (exponential moving average of the weights) settings as arguments.

In [None]:
# Import training utilities from dedicated module
from training_utils import train_cfm

## Model Training

Execute the training process with the configured hyperparameters.

In [None]:
# Extract normalization statistics.
# NOTE(review): these are not passed to train_cfm below; presumably later
# cells use them (e.g. to de-normalize predictions) — confirm.
feat_mean = norm_stats["feat_mean"]
feat_std = norm_stats["feat_std"]
ctx_mean = norm_stats["ctx_mean"]
ctx_std = norm_stats["ctx_std"]

# Model configuration
D_MODEL = 512
model_cfg = dict(
    d_model=D_MODEL,   # model dimension
    nhead=8,           # number of attention heads
    enc_layers=6,      # encoder layers
    dec_layers=8,      # decoder layers
    ff=4 * D_MODEL,    # feed-forward dimension, tied to d_model
    dropout=0.1,       # dropout rate
    # NOTE(review): the two dims below are hard-coded; they must match the
    # trailing feature dims of X_* and C_* built above — TODO confirm.
    in_dim=7,          # input feature dimension
    context_dim=8,     # context feature dimension
)

# Checkpoint path. Create its directory up front, in case train_cfm does not,
# so a missing "models/" folder cannot make the save fail after training.
checkpoint_path = "models/model_1min.pt"
ckpt_dir = os.path.dirname(checkpoint_path)
if ckpt_dir:
    os.makedirs(ckpt_dir, exist_ok=True)

# Train the model
trained_model = train_cfm(
    train_ds,
    val_ds,
    model_cfg=model_cfg,
    ckpt_path=checkpoint_path,
    # Adaptive batch size: len(train_ds)//8, clamped to the range [64, 2048].
    batch_size=min(2048, max(64, len(train_ds) // 8)),
    lr=3e-4,           # learning rate
    epochs=400,        # maximum epochs
    patience=100,      # early stopping patience
    warmup_steps=2000, # learning rate warmup steps
    ema_decay=0.9997,  # EMA decay rate
)

print("Training completed successfully!")
print(f"Best model saved to: {checkpoint_path}")