---
## Table of Contents

1. [Configuration & Setup](#1-configuration--setup)
2. [Feature Definitions](#2-feature-definitions)
3. [Data Preparation](#3-data-preparation)
4. [Model Architecture](#4-model-architecture)
5. [Training Utilities](#5-training-utilities)
6. [Hyperparameter Tuning](#6-hyperparameter-tuning)
7. [Final Training](#7-final-training)
8. [Threshold Optimization](#8-threshold-optimization)
9. [Inference & Results](#9-inference--results)
10. [Summary & Conclusions](#10-summary--conclusions)

---
## 1. Configuration & Setup

In [None]:
# =============================================================================
# IMPORTS
# =============================================================================

import os
import sys
import gc
import time
import copy
import uuid
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd

# PySpark
import pyspark.sql.functions as sf
from pyspark.sql import Window
from pyspark.sql.types import FloatType, IntegerType
from pyspark.ml.feature import StringIndexer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Scikit-learn
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import fbeta_score, mean_absolute_error

# MLflow & Optuna
import mlflow
from mlflow.models import infer_signature
import optuna

# Enable Arrow for faster Spark-to-Pandas conversion
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

print(f"PyTorch version: {torch.__version__}")
print(f"MLflow version: {mlflow.__version__}")

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# Paths
BASE_PATH = "dbfs:/student-groups/Group_2_2/5_year_custom_joined"
DATA_PATH = f"{BASE_PATH}/fe_graph_and_holiday_nnfeat/training_splits"
CV_DATA_PATH = f"{BASE_PATH}/fe_graph_and_holiday_nnfeat/cv_splits"
PREDICTIONS_PATH = f"{BASE_PATH}/nn_predictions_final"

# MLflow
EXPERIMENT_NAME = "/Shared/team_2_2/mlflow-nn-tower-final"
mlflow.set_experiment(EXPERIMENT_NAME)

# Device Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Best Hyperparameters (from Optuna tuning)
BEST_PARAMS = {
    "lr": 0.0001556,
    "batch_size": 4096,
    "alpha": 0.342,           # Loss weighting: α*MAE + (1-α)*BCE
    "time_dim": 16,           # Time2Vec dimensions
    "emb_drop": 0.046,        # Embedding dropout
    "num_drop": 0.324,        # Numerical tower dropout
    "final_drop": 0.100       # Final layer dropout
}

# Training Configuration
NUM_EPOCHS = 10
PATIENCE = 4                  # Early stopping patience
DELAY_THRESHOLD = 15.0        # Minutes - defines "delayed" flight
OPTIMAL_THRESHOLD = 0.36      # Classification threshold (from optimization)

---
## 2. Feature Definitions

The model uses three types of features:
- **Categorical**: Encoded as learnable embeddings
- **Numerical**: Standardized and processed through residual blocks
- **Temporal**: Departure time encoded via Time2Vec
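
Time2Vec turns a scalar time τ into a k-dimensional vector whose first component is linear (ω₀τ + φ₀) and whose remaining components are periodic (sin(ωᵢτ + φᵢ)). The NumPy sketch below shows only that mapping, with random weights standing in for the learned ω and φ; `time2vec` is a hypothetical helper, not the notebook's model code. Because `FlightDataset` hands `CRS_DEP_MINUTES` to the model without passing it through the `StandardScaler`, the example feeds raw values, assumed here to be minutes after midnight.

```python
import numpy as np

def time2vec(tau: np.ndarray, omega: np.ndarray, phi: np.ndarray) -> np.ndarray:
    """Time2Vec encoding of a batch of scalar times.

    tau:   shape (batch, 1), e.g. scheduled departure time in minutes
    omega: shape (k,) frequencies; omega[0] is the slope of the linear term
    phi:   shape (k,) phase shifts
    Returns shape (batch, k): column 0 is linear, columns 1..k-1 are sin(.).
    """
    z = tau * omega + phi            # broadcast (batch, 1) * (k,) -> (batch, k)
    out = np.sin(z)                  # periodic components
    out[:, 0] = z[:, 0]              # first component stays linear (non-periodic)
    return out

rng = np.random.default_rng(0)
k = 16                               # same size as BEST_PARAMS["time_dim"]
omega = rng.normal(size=k)           # learned in a real model; random here
phi = rng.normal(size=k)

dep_minutes = np.array([[0.0], [480.0], [1020.0]])   # assumed 00:00, 08:00, 17:00
enc = time2vec(dep_minutes, omega, phi)
print(enc.shape)                     # (3, 16)
```

The linear component lets the model capture monotone effects of departure time (delays accumulating through the day), while the sine components can fit repeating patterns.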

In [None]:
# =============================================================================
# FEATURE DEFINITIONS
# =============================================================================

CATEGORICAL_COLS = [
    "OP_UNIQUE_CARRIER",       # Airline carrier code
    "ORIGIN_AIRPORT_SEQ_ID",   # Origin airport identifier
    "DEST_AIRPORT_SEQ_ID",     # Destination airport identifier
    "route",                   # Origin-Destination pair
    "AIRPORT_HUB_CLASS",       # Hub classification
    "AIRLINE_CATEGORY"         # Airline category
]

NUMERICAL_COLS = [
    # Flight characteristics
    "DISTANCE", "CRS_ELAPSED_TIME",
    
    # Historical delay features
    "prev_flight_delay_in_minutes", "origin_delays_4h",
    "delay_origin_7d", "delay_origin_carrier_7d", "delay_route_7d",
    "flight_count_24h", "AVG_TAXI_OUT_ORIGIN", "AVG_ARR_DELAY_ORIGIN",
    
    # Graph-based features
    "in_degree", "out_degree", "weighted_in_degree", "weighted_out_degree",
    "betweenness", "closeness",
    
    # Airport features
    "N_RUNWAYS",
    
    # Weather features
    "HourlyVisibility", "HourlyStationPressure", "HourlyWindSpeed",
    "HourlyDryBulbTemperature", "HourlyDewPointTemperature",
    "HourlyRelativeHumidity", "HourlyAltimeterSetting",
    "HourlyWetBulbTemperature", "HourlyPrecipitation",
    "HourlyCloudCoverage", "HourlyCloudElevation",
    
    # Traffic features
    "ground_flights_last_hour", "arrivals_last_hour",
    
    # Cyclical time encodings
    "dow_sin", "dow_cos",  # Day of week
    "doy_sin", "doy_cos"   # Day of year
]

# Ensure no duplicates
NUMERICAL_COLS = list(dict.fromkeys(NUMERICAL_COLS))

TIME_COL = "CRS_DEP_MINUTES"
TARGET_COL = "DEP_DELAY_NEW"

print(f"Categorical features: {len(CATEGORICAL_COLS)}")
print(f"Numerical features: {len(NUMERICAL_COLS)}")
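
The `dow_sin`/`dow_cos` and `doy_sin`/`doy_cos` columns arrive precomputed in the feature table. The usual construction places a cyclic index on the unit circle so that the end of a cycle sits next to its start. The sketch below assumes periods of 7 (day of week) and 366 (day of year); the periods and index bases actually used upstream are not stated in this notebook, and `cyclical_encode` is a hypothetical helper.

```python
import math

def cyclical_encode(value: float, period: float) -> tuple:
    """Map a cyclic quantity onto the unit circle as (sin, cos)."""
    angle = 2.0 * math.pi * value / period
    return math.sin(angle), math.cos(angle)

# Assumed periods: 7 for day of week, 366 for day of year (covers leap years).
dow_sin, dow_cos = cyclical_encode(3, 7)
doy_sin, doy_cos = cyclical_encode(200, 366)

# Days 6 and 0 are far apart as raw integers but adjacent on the circle.
a = cyclical_encode(6, 7)
b = cyclical_encode(0, 7)
c = cyclical_encode(3, 7)
dist_wrap = math.dist(a, b)   # neighbours across the week boundary
dist_mid = math.dist(a, c)    # half a week apart
print(round(dist_wrap, 3), round(dist_mid, 3))
```

Using both sine and cosine is what makes the encoding unambiguous: a sine value alone maps two different days to the same number.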

---
## 3. Data Preparation

Functions for loading Spark DataFrames, applying string indexing, scaling numerical features, and converting to PyTorch datasets.
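
Both `prepare_data_splits` and `prepare_fold_data` fit the `StandardScaler` on the training split only and reuse those statistics for validation and test, so held-out data never influences preprocessing. A minimal NumPy sketch of that fit/transform split, with hypothetical helper names; `StandardScaler` likewise uses the per-column mean and population standard deviation and leaves zero-variance columns unscaled:

```python
import numpy as np

def fit_standardizer(train: np.ndarray):
    """Per-column mean and population std (ddof=0), computed on train only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0.0] = 1.0          # constant columns: avoid division by zero
    return mean, std

def apply_standardizer(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    return (x - mean) / std

train = np.array([[100.0, 1.0], [300.0, 3.0], [500.0, 5.0]])
val = np.array([[700.0, 3.0]])

mean, std = fit_standardizer(train)
train_z = apply_standardizer(train, mean, std)
val_z = apply_standardizer(val, mean, std)   # uses TRAIN statistics, not val's
```

The validation value 700 lies outside the training range, so its z-score (about 2.45) exceeds anything seen in training; that is expected and is exactly what the model will face at inference time.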

In [None]:
# =============================================================================
# PYTORCH DATASET
# =============================================================================

class FlightDataset(Dataset):
    """
    PyTorch Dataset for flight delay prediction.
    
    Separates features into categorical (for embeddings), numerical (for MLP),
    and temporal (for Time2Vec) components.
    """
    
    def __init__(
        self, 
        df: pd.DataFrame,
        cat_cols: List[str] = CATEGORICAL_COLS,
        num_cols: List[str] = NUMERICAL_COLS,
        time_col: str = TIME_COL,
        target_col: str = TARGET_COL,
        id_col: Optional[str] = "flight_uid"
    ):
        self.cat = torch.tensor(df[cat_cols].values, dtype=torch.long)
        self.num = torch.tensor(df[num_cols].values, dtype=torch.float32)
        self.time = torch.tensor(df[time_col].values, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(df[target_col].values, dtype=torch.float32).unsqueeze(1)
        self.ids = df[id_col].values if id_col in df.columns else np.arange(len(df))
    
    def __len__(self) -> int:
        return len(self.y)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, ...]:
        return self.cat[idx], self.num[idx], self.time[idx], self.y[idx], self.ids[idx]

In [None]:
# =============================================================================
# DATA PREPARATION FUNCTIONS
# =============================================================================

def spark_to_pandas(
    df, 
    indexers: List[StringIndexer],
    cat_cols: List[str] = CATEGORICAL_COLS,
    num_cols: List[str] = NUMERICAL_COLS,
    time_col: str = TIME_COL,
    target_col: str = TARGET_COL
) -> pd.DataFrame:
    """
    Transform Spark DataFrame to Pandas with proper type casting.
    
    Args:
        df: Spark DataFrame
        indexers: Fitted StringIndexer models for categorical columns
        cat_cols: Categorical column names
        num_cols: Numerical column names
        time_col: Time feature column name
        target_col: Target column name
    
    Returns:
        Pandas DataFrame with indexed categoricals and float numericals
    """
    # Apply string indexers
    for indexer in indexers:
        df = indexer.transform(df)
    
    # Build select expression with proper types
    select_expr = (
        [sf.col(f"{c}_idx").cast(IntegerType()).alias(c) for c in cat_cols] +
        [sf.col(c).cast(FloatType()) for c in num_cols] +
        [sf.col(time_col).cast(FloatType()), 
         sf.col(target_col).cast(FloatType()),
         sf.col("flight_uid")]
    )
    
    return df.select(*select_expr).toPandas()


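def prepare_data_splits(
    train_path: str,
    val_path: str,
    test_path: Optional[str] = None
) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame], List[int], List[int], List, StandardScaler]:
    """
    Load and prepare train/val/test splits with consistent preprocessing.
    
    String indexers and the scaler are fit on the training split only and
    then applied unchanged to the validation and test splits.
    
    Returns:
        train_pd, val_pd, test_pd (None if no test_path), cat_dims, emb_dims,
        indexers (fitted StringIndexer models), scaler
    """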
    print("Loading data...")
    train_spark = spark.read.parquet(train_path)
    val_spark = spark.read.parquet(val_path)
    test_spark = spark.read.parquet(test_path) if test_path else None
    
    # Fit string indexers on training data
    print("Fitting string indexers...")
    indexers = [
        StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
        .fit(train_spark) 
        for c in CATEGORICAL_COLS
    ]
    
    # Convert to Pandas
    print("Converting to Pandas...")
    train_pd = spark_to_pandas(train_spark, indexers)
    val_pd = spark_to_pandas(val_spark, indexers)
    test_pd = spark_to_pandas(test_spark, indexers) if test_spark is not None else None
    
    # Fit scaler on training data
    print("Fitting scaler...")
    scaler = StandardScaler()
    train_pd[NUMERICAL_COLS] = scaler.fit_transform(train_pd[NUMERICAL_COLS])
    val_pd[NUMERICAL_COLS] = scaler.transform(val_pd[NUMERICAL_COLS])
    if test_pd is not None:
        test_pd[NUMERICAL_COLS] = scaler.transform(test_pd[NUMERICAL_COLS])
    
    # Calculate embedding dimensions
    cat_dims = [int(train_pd[c].max() + 2) for c in CATEGORICAL_COLS]
    emb_dims = [min(64, int(n**0.3)) for n in cat_dims]
    
    print(f"Train: {len(train_pd):,} | Val: {len(val_pd):,}", end="")
    if test_pd is not None:
        print(f" | Test: {len(test_pd):,}")
    else:
        print()
    
    return train_pd, val_pd, test_pd, cat_dims, emb_dims, indexers, scaler
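
The `+ 2` in `cat_dims` follows from how `StringIndexer(handleInvalid="keep")` assigns indices: training labels receive 0 to numLabels - 1, and any invalid value at transform time (unseen label or null) goes to the extra index numLabels. The largest training index is therefore numLabels - 1, and the embedding table needs numLabels + 1 = max + 2 rows. The pure-Python sketch below mimics that behaviour for unseen labels with a hypothetical `fit_keep_indexer` (frequency-descending order with alphabetical tie-breaks, as in Spark's default `frequencyDesc`), then applies the notebook's `min(64, int(n ** 0.3))` sizing rule.

```python
from collections import Counter

def fit_keep_indexer(train_values):
    """Hypothetical stand-in for StringIndexer(handleInvalid="keep").

    Labels are ordered by descending frequency, ties broken alphabetically.
    Returns (mapping, num_labels); unseen values map to index num_labels.
    """
    counts = Counter(train_values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {v: i for i, v in enumerate(ordered)}, len(ordered)

def transform(values, mapping, num_labels):
    return [mapping.get(v, num_labels) for v in values]

train_carriers = ["AA", "DL", "AA", "UA", "DL", "AA", "WN"]
mapping, n_labels = fit_keep_indexer(train_carriers)
train_idx = transform(train_carriers, mapping, n_labels)
test_idx = transform(["DL", "B6"], mapping, n_labels)   # "B6" unseen in train

cat_dim = max(train_idx) + 2          # same rule as prepare_data_splits
emb_dim = min(64, int(cat_dim ** 0.3))
print(mapping, test_idx, cat_dim, emb_dim)

# The n ** 0.3 rule grows slowly; the cap of 64 only binds near 2**20 categories.
for n in (5, 400, 5000):
    print(n, min(64, int(n ** 0.3)))
```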

---
## 4. Model Architecture

The **ResFiLM-MLP** architecture consists of:

1. **Embedding Tower**: Learns dense representations for categorical features
2. **Numerical Tower**: 4 residual blocks with LayerNorm and GELU activations
3. **FiLM Layer**: Feature-wise Linear Modulation - uses numerical features to generate γ (scale) and β (shift) parameters that modulate the embeddings
4. **Time2Vec**: Learnable periodic encoding for departure time
5. **Dual Prediction Heads**: Separate heads for regression (delay minutes) and classification (delayed/not delayed)
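
FiLM is a per-feature affine transform whose parameters are predicted from a conditioning input: here the numerical features produce γ and β, and an embedding vector h becomes γ ⊙ h + β. The NumPy sketch below illustrates only that step under assumed shapes; it is not the `ResFiLMMLP` implementation. The helper `film`, the sizes `num_hidden` and `emb_width`, the single linear projection per parameter, and treating h as the concatenated embeddings are all assumptions.

```python
import numpy as np

def film(h: np.ndarray, cond: np.ndarray, w_gamma, b_gamma, w_beta, b_beta) -> np.ndarray:
    """Feature-wise Linear Modulation: h' = gamma(cond) * h + beta(cond).

    h:    (batch, emb_width)  categorical embeddings (assumed concatenated)
    cond: (batch, num_hidden) numerical-tower representation
    """
    gamma = cond @ w_gamma + b_gamma      # (batch, emb_width) per-feature scale
    beta = cond @ w_beta + b_beta         # (batch, emb_width) per-feature shift
    return gamma * h + beta

rng = np.random.default_rng(42)
batch, num_hidden, emb_width = 4, 8, 6
h = rng.normal(size=(batch, emb_width))
cond = rng.normal(size=(batch, num_hidden))

# Random projections stand in for learned linear layers (assumption).
w_gamma = rng.normal(scale=0.1, size=(num_hidden, emb_width))
w_beta = rng.normal(scale=0.1, size=(num_hidden, emb_width))
b_gamma = np.ones(emb_width)              # gamma centred on 1: near-identity start
b_beta = np.zeros(emb_width)

out = film(h, cond, w_gamma, b_gamma, w_beta, b_beta)
print(out.shape)                          # (4, 6)

# With zero projection weights, FiLM reduces to the identity on h.
identity = film(h, cond, np.zeros_like(w_gamma), b_gamma, np.zeros_like(w_beta), b_beta)
assert np.allclose(identity, h)
```

Because γ and β differ per row, two flights from the same airport and carrier (identical h) can end up with different modulated embeddings when their weather or recent-delay features differ.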

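---
## 5. Training Utilities

Helpers shared by tuning and final training: one training epoch with the combined loss, evaluation (F2 and MAE), and generating and saving predictions.

In [None]: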
# =============================================================================
# TRAINING UTILITIES
# =============================================================================

def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion_reg: nn.Module,
    criterion_clf: nn.Module,
    alpha: float,
    device: torch.device = DEVICE
) -> float:
    """
    Train model for one epoch.
    
    Loss = α * MAE_loss + (1-α) * BCE_loss
    
    Args:
        model: Neural network model
        loader: Training data loader
        optimizer: Optimizer
        criterion_reg: Regression loss (L1/MAE)
        criterion_clf: Classification loss (BCE)
        alpha: Weight for regression loss
        device: Compute device
    
    Returns:
        Average training loss
    """
    model.train()
    total_loss = 0.0
    
    for cat, num, time_feat, y, _ in loader:
        cat = cat.to(device)
        num = num.to(device)
        time_feat = time_feat.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
        
        reg_out, clf_out = model(cat, num, time_feat)
        y_class = (y >= DELAY_THRESHOLD).float()
        
        loss = alpha * criterion_reg(reg_out, y) + (1 - alpha) * criterion_clf(clf_out, y_class)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)


@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    loader: DataLoader,
    threshold: float = 0.5,
    device: torch.device = DEVICE
) -> Tuple[float, float]:
    """
    Evaluate model on a dataset.
    
    Args:
        model: Neural network model
        loader: Data loader
        threshold: Classification threshold
        device: Compute device
    
    Returns:
        f2_score, mae
    """
    model.eval()
    all_y_true, all_y_reg, all_y_prob = [], [], []
    
    for cat, num, time_feat, y, _ in loader:
        reg_out, clf_out = model(
            cat.to(device), num.to(device), time_feat.to(device)
        )
        
        all_y_true.append(y.cpu())
        all_y_reg.append(reg_out.cpu())
        all_y_prob.append(torch.sigmoid(clf_out).cpu())
    
    y_true = torch.cat(all_y_true).numpy().flatten()
    y_reg = torch.cat(all_y_reg).numpy().flatten()
    y_prob = torch.cat(all_y_prob).numpy().flatten()
    
    # Metrics
    mae = mean_absolute_error(y_true, y_reg)
    y_true_class = (y_true >= DELAY_THRESHOLD).astype(int)
    y_pred_class = (y_prob >= threshold).astype(int)
    f2 = fbeta_score(y_true_class, y_pred_class, beta=2, zero_division=0)
    
    return f2, mae
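
A NumPy sketch of the per-batch objective in `train_one_epoch`, assuming the same `pos_weight=4.0` used by the tuning and training code: `nn.L1Loss` on delay minutes blended with `nn.BCEWithLogitsLoss` on the ≥ 15-minute label. With `pos_weight`, each positive example's log-loss term is multiplied by 4, which up-weights delayed flights. The toy delays and logits are illustrative values, and the helper names are hypothetical.

```python
import numpy as np

DELAY_THRESHOLD = 15.0   # minutes, as in the configuration cell

def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean of -[pw * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    log_p = -np.logaddexp(0.0, -logits)       # log(sigmoid(x)), numerically stable
    log_not_p = -np.logaddexp(0.0, logits)    # log(1 - sigmoid(x))
    return float(np.mean(-(pos_weight * targets * log_p + (1 - targets) * log_not_p)))

def combined_loss(reg_out, clf_logits, y_minutes, alpha, pos_weight=4.0):
    """alpha * MAE + (1 - alpha) * weighted BCE, mirroring train_one_epoch."""
    y_class = (y_minutes >= DELAY_THRESHOLD).astype(float)
    mae = float(np.mean(np.abs(reg_out - y_minutes)))
    bce = bce_with_logits(clf_logits, y_class, pos_weight)
    return alpha * mae + (1 - alpha) * bce, mae, bce

y = np.array([0.0, 5.0, 20.0, 60.0])       # true delays (minutes)
reg = np.array([2.0, 4.0, 10.0, 45.0])     # predicted delays
logits = np.array([-2.0, -1.0, 0.5, 2.0])  # classifier logits
loss, mae, bce = combined_loss(reg, logits, y, alpha=0.342)
print(round(mae, 2), round(bce, 3), round(loss, 3))
```

Here the MAE is 7.0 minutes while the BCE term is below 1. Since MAE is measured in minutes, the two terms sit on different scales: α sets the weight on each term, not its share of the total loss value.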

In [None]:
# =============================================================================
# PREDICTION & SAVING UTILITIES
# =============================================================================

@torch.no_grad()
def generate_predictions(
    model: nn.Module,
    loader: DataLoader,
    threshold: float = OPTIMAL_THRESHOLD,
    device: torch.device = DEVICE
) -> pd.DataFrame:
    """
    Generate predictions for all samples in a data loader.
    
    Args:
        model: Trained model
        loader: Data loader
        threshold: Classification threshold
        device: Compute device
    
    Returns:
        DataFrame with flight_uid, targets, and predictions
    """
    model.eval()
    results = {
        "flight_uid": [],
        "target_delay": [],
        "pred_delay": [],
        "target_class": [],
        "pred_prob": [],
        "pred_class": []
    }
    
    for cat, num, time_feat, y, ids in loader:
        reg_out, clf_out = model(
            cat.to(device), num.to(device), time_feat.to(device)
        )
        
        y_prob = torch.sigmoid(clf_out).cpu().numpy().flatten()
        
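        # default_collate turns numeric ids into a tensor; string ids stay a list
        results["flight_uid"].extend(ids.tolist() if torch.is_tensor(ids) else list(ids))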
        results["target_delay"].extend(y.numpy().flatten())
        results["pred_delay"].extend(reg_out.cpu().numpy().flatten())
        results["target_class"].extend((y >= DELAY_THRESHOLD).float().numpy().flatten())
        results["pred_prob"].extend(y_prob)
        results["pred_class"].extend((y_prob >= threshold).astype(int))
    
    return pd.DataFrame(results)


def save_predictions(
    predictions_df: pd.DataFrame,
    split_name: str,
    save_path: str = PREDICTIONS_PATH
) -> str:
    """
    Save predictions to DBFS as Parquet.
    
    Args:
        predictions_df: DataFrame with predictions
        split_name: Name of the split (e.g., "test", "validation")
        save_path: Base path for saving
    
    Returns:
        Full path where predictions were saved
    """
    unique_id = str(uuid.uuid4())[:8]
    full_path = f"{save_path}/{split_name}_{unique_id}"
    
    spark.createDataFrame(predictions_df).write.mode("overwrite").parquet(full_path)
    print(f"✓ Saved predictions to: {full_path}")
    
    return full_path

---
## 6. Hyperparameter Tuning

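Optuna-based hyperparameter search. To keep trials affordable, each trial trains for 3 epochs on three of the available CV folds (the first, middle, and last) rather than running full cross-validation, and is scored by the mean validation F2 at a 0.5 threshold. Trials can be pruned after any fold.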
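
The fold-selection rule in `optuna_objective` can be checked in isolation. The sketch below reproduces it with hypothetical fold ids; `pick_tuning_folds` is an illustrative name. It shows that with fewer than three folds the rule evaluates one fold twice, which the tuning code does not guard against.

```python
def pick_tuning_folds(folds):
    """Same rule as optuna_objective: first, middle and last of the sorted fold ids."""
    folds = sorted(folds)
    return [folds[0], folds[len(folds) // 2], folds[-1]]

print(pick_tuning_folds([0, 1, 2, 3, 4]))   # [0, 2, 4]
print(pick_tuning_folds([0, 1, 2, 3]))      # [0, 2, 3]
print(pick_tuning_folds([0, 1]))            # [0, 1, 1]  (fold 1 evaluated twice)
```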

In [None]:
# =============================================================================
# HYPERPARAMETER TUNING (Optuna)
# =============================================================================

def prepare_fold_data(fold_df, fold_id: int):
    """Prepare data for a specific CV fold."""
    train_fe = fold_df.filter(sf.col("split_type") == "train")
    val_fe = fold_df.filter(sf.col("split_type") == "validation")
    
    # Cast types
    for c in NUMERICAL_COLS + [TIME_COL, TARGET_COL]:
        train_fe = train_fe.withColumn(c, sf.col(c).cast(FloatType()))
        val_fe = val_fe.withColumn(c, sf.col(c).cast(FloatType()))
    
    # String indexing
    indexers = []
    for c in CATEGORICAL_COLS:
        indexer = StringIndexer(
            inputCol=c, outputCol=f"{c}_idx",
            stringOrderType="alphabetAsc", handleInvalid="keep"
        )
        model = indexer.fit(train_fe)
        train_fe = model.transform(train_fe)
        val_fe = model.transform(val_fe)
        indexers.append(model)
    
    # Select and cast
    final_cols = [f"{c}_idx" for c in CATEGORICAL_COLS] + NUMERICAL_COLS + [TIME_COL, TARGET_COL]
    for c in CATEGORICAL_COLS:
        train_fe = train_fe.withColumn(f"{c}_idx", sf.col(f"{c}_idx").cast(IntegerType()))
        val_fe = val_fe.withColumn(f"{c}_idx", sf.col(f"{c}_idx").cast(IntegerType()))
    
    # Convert to Pandas
    train_pd = train_fe.select(final_cols).toPandas()
    val_pd = val_fe.select(final_cols).toPandas()
    
    # Rename columns
    rename_map = {f"{c}_idx": c for c in CATEGORICAL_COLS}
    train_pd = train_pd.rename(columns=rename_map)
    val_pd = val_pd.rename(columns=rename_map)
    
    # Scale
    scaler = StandardScaler()
    train_pd[NUMERICAL_COLS] = scaler.fit_transform(train_pd[NUMERICAL_COLS])
    val_pd[NUMERICAL_COLS] = scaler.transform(val_pd[NUMERICAL_COLS])
    
    # Embedding dimensions
    cat_dims = [int(train_pd[c].max() + 2) for c in CATEGORICAL_COLS]
    emb_dims = [min(64, int(n**0.3)) for n in cat_dims]
    
    return train_pd, val_pd, cat_dims, emb_dims


def optuna_objective(trial, cv_full_df, folds):
    """Optuna objective function for hyperparameter optimization."""
    params = {
        "lr": trial.suggest_float("lr", 1e-4, 5e-3, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [1024, 2048, 4096]),
        "alpha": trial.suggest_float("alpha", 0.3, 0.7),
        "time_dim": trial.suggest_categorical("time_dim", [4, 8, 16]),
        "emb_drop": trial.suggest_float("emb_drop", 0.0, 0.4),
        "num_drop": trial.suggest_float("num_drop", 0.0, 0.4),
        "final_drop": trial.suggest_float("final_drop", 0.0, 0.4)
    }
    
    # Use 3 folds for efficiency: first, middle, last
    tuning_folds = [folds[0], folds[len(folds)//2], folds[-1]]
    val_f2_scores = []
    
    print(f"\n[Trial {trial.number}] Params: batch={params['batch_size']}, lr={params['lr']:.5f}")
    
    for i, fold_id in enumerate(tuning_folds):
        fold_df = cv_full_df.filter(sf.col("fold_id") == fold_id)
        train_pd, val_pd, cat_dims, emb_dims = prepare_fold_data(fold_df, fold_id)
        
        # Create datasets and loaders
        train_ds = FlightDataset(train_pd, id_col=None)
        val_ds = FlightDataset(val_pd, id_col=None)
        train_dl = DataLoader(train_ds, batch_size=params["batch_size"], shuffle=True, num_workers=0)
        val_dl = DataLoader(val_ds, batch_size=params["batch_size"], num_workers=0)
        
        # Create model
        model = ResFiLMMLP(
            cat_dims, emb_dims, len(NUMERICAL_COLS),
            time_dim=params["time_dim"],
            emb_dropout=params["emb_drop"],
            num_dropout=params["num_drop"],
            final_dropout=params["final_drop"]
        ).to(DEVICE)
        
        optimizer = optim.AdamW(model.parameters(), lr=params["lr"])
        criterion_reg = nn.L1Loss()
        criterion_clf = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]).to(DEVICE))
        
        # Train for 3 epochs
        for _ in range(3):
            train_one_epoch(model, train_dl, optimizer, criterion_reg, criterion_clf, params["alpha"])
        
        # Evaluate
        val_f2, val_mae = evaluate_model(model, val_dl)
        val_f2_scores.append(val_f2)
        print(f"  Fold {fold_id}: Val F2={val_f2:.3f}")
        
        # Pruning
        trial.report(np.mean(val_f2_scores), i)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    
    mean_f2 = np.mean(val_f2_scores)
    
    # Log to MLflow
    with mlflow.start_run(nested=True, run_name=f"Trial_{trial.number}"):
        mlflow.log_params(params)
        mlflow.log_metric("val_f2", mean_f2)
    
    return mean_f2

In [None]:
# =============================================================================
# RUN HYPERPARAMETER TUNING
# =============================================================================

# Note: Set RUN_TUNING = True to execute hyperparameter search
# This takes ~8 hours on Databricks with CPU

RUN_TUNING = False

if RUN_TUNING:
    # Load CV data
    cv_full_df = spark.read.parquet(CV_DATA_PATH)
    folds = sorted([row['fold_id'] for row in cv_full_df.select("fold_id").distinct().collect()])
    
    print(f"Starting Optuna Tuning (8 Trials)")
    print(f"Folds available: {folds}")
    
    mlflow.set_experiment("/Shared/team_2_2/mlflow-nn-tower-tuned")
    
    with mlflow.start_run(run_name="Hyperparameter_Tuning"):
        study = optuna.create_study(direction="maximize")
        study.optimize(
            lambda trial: optuna_objective(trial, cv_full_df, folds),
            n_trials=8
        )
    
    print(f"\n{'='*60}")
    print(f"BEST PARAMETERS: {study.best_params}")
    print(f"BEST F2 SCORE: {study.best_value:.4f}")
    print(f"{'='*60}")
else:
    print("Hyperparameter tuning skipped. Using pre-tuned parameters:")
    print(f"  {BEST_PARAMS}")

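---
## 7. Final Training

Train on the full training split with `BEST_PARAMS`, using early stopping on validation F2 and restoring the best checkpoint, then log the model to MLflow.

In [None]: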
# =============================================================================
# FINAL TRAINING FUNCTION
# =============================================================================

def train_final_model(
    train_dl: DataLoader,
    val_dl: DataLoader,
    cat_dims: List[int],
    emb_dims: List[int],
    params: Dict = BEST_PARAMS,
    num_epochs: int = NUM_EPOCHS,
    patience: int = PATIENCE,
    device: torch.device = DEVICE
) -> Tuple[nn.Module, Dict]:
    """
    Train final model with early stopping.
    
    Args:
        train_dl: Training data loader
        val_dl: Validation data loader
        cat_dims: Categorical feature dimensions
        emb_dims: Embedding dimensions
        params: Hyperparameters
        num_epochs: Maximum training epochs
        patience: Early stopping patience
        device: Compute device
    
    Returns:
        Trained model and training history
    """
    # Initialize model
    model = ResFiLMMLP(
        cat_dims, emb_dims, len(NUMERICAL_COLS),
        time_dim=params["time_dim"],
        emb_dropout=params["emb_drop"],
        num_dropout=params["num_drop"],
        final_dropout=params["final_drop"]
    ).to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=params["lr"])
    criterion_reg = nn.L1Loss()
    criterion_clf = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4.0]).to(device))
    
    # Training state
    best_f2 = -1.0
    best_state = None
    patience_counter = 0
    history = {"train_f2": [], "train_mae": [], "val_f2": [], "val_mae": []}
    
    print(f"\n{'='*60}")
    print(f"TRAINING: {num_epochs} epochs, patience={patience}")
    print(f"{'='*60}")
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_tp, train_fp, train_fn = 0, 0, 0
        train_ae_sum = 0.0
        total_samples = 0
        
        for i, (cat, num, time_feat, y, _) in enumerate(train_dl):
            cat, num, time_feat, y = cat.to(device), num.to(device), time_feat.to(device), y.to(device)
            
            optimizer.zero_grad()
            reg_out, clf_out = model(cat, num, time_feat)
            y_class = (y >= DELAY_THRESHOLD).float()
            
            loss = params["alpha"] * criterion_reg(reg_out, y) + (1 - params["alpha"]) * criterion_clf(clf_out, y_class)
            loss.backward()
            optimizer.step()
            
            # Track metrics
            with torch.no_grad():
                train_loss += loss.item()
                train_ae_sum += torch.sum(torch.abs(reg_out - y)).item()
                y_pred = (torch.sigmoid(clf_out) > 0.5).long()
                y_true = y_class.long()
                train_tp += ((y_true == 1) & (y_pred == 1)).sum().item()
                train_fp += ((y_true == 0) & (y_pred == 1)).sum().item()
                train_fn += ((y_true == 1) & (y_pred == 0)).sum().item()
                total_samples += y.size(0)
            
            if (i + 1) % 500 == 0:
                print(f"  Epoch {epoch} - Batch {i+1}/{len(train_dl)}", end="\r")
        
        # Calculate training metrics
        train_mae = train_ae_sum / total_samples
        precision = train_tp / (train_tp + train_fp + 1e-8)
        recall = train_tp / (train_tp + train_fn + 1e-8)
        train_f2 = (1 + 4) * (precision * recall) / (4 * precision + recall + 1e-8)
        
        # Validation
        val_f2, val_mae = evaluate_model(model, val_dl)
        
        # Store history
        history["train_f2"].append(train_f2)
        history["train_mae"].append(train_mae)
        history["val_f2"].append(val_f2)
        history["val_mae"].append(val_mae)
        
        print(f"  Epoch {epoch}: Val F2={val_f2:.4f} MAE={val_mae:.2f} | Train F2={train_f2:.4f} MAE={train_mae:.2f}")
        
        # Early stopping check
        if val_f2 > best_f2:
            best_f2 = val_f2
            best_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"    >> No improvement. Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("    >> Early stopping triggered!")
                break
    
    # Load best model
    model.load_state_dict(best_state)
    print(f"\n✓ Training complete. Best Val F2: {best_f2:.4f}")
    
    return model, history
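
`train_final_model` accumulates TP, FP and FN over all batches and computes the training F2 once per epoch, adding small epsilons to the denominators to avoid division by zero. The general F-beta score is (1 + β²)·P·R / (β²·P + R); with β = 2 this yields the `(1 + 4)` and `4 * precision` terms in the loop. A pure-Python version with explicit zero handling and a worked example (the counts are illustrative; `fbeta_from_counts` is a hypothetical helper):

```python
def fbeta_from_counts(tp: int, fp: int, fn: int, beta: float = 2.0) -> float:
    """F-beta from confusion counts; returns 0.0 when precision and recall are both 0."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0

# 60 delayed flights caught, 40 false alarms, 20 delayed flights missed:
# precision = 0.6, recall = 0.75
f2 = fbeta_from_counts(60, 40, 20)
f1 = fbeta_from_counts(60, 40, 20, beta=1.0)
print(round(f2, 4), round(f1, 4))   # 0.7143 0.6667
```

F2 exceeds F1 in this example because recall (0.75) is higher than precision (0.6), and F2 weights recall four times as heavily as precision in the denominator.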

In [None]:
# =============================================================================
# RUN FINAL TRAINING
# =============================================================================

# Note: Set RUN_TRAINING = True to execute training
# This takes ~3-4 hours on Databricks with CPU

RUN_TRAINING = False

if RUN_TRAINING:
    # Prepare data
    train_pd, val_pd, _, cat_dims, emb_dims, indexers, scaler = prepare_data_splits(
        f"{DATA_PATH}/train.parquet",
        f"{DATA_PATH}/val.parquet"
    )
    
    # Create datasets
    train_ds = FlightDataset(train_pd)
    val_ds = FlightDataset(val_pd)
    
    train_dl = DataLoader(train_ds, batch_size=BEST_PARAMS["batch_size"], shuffle=True, num_workers=0)
    val_dl = DataLoader(val_ds, batch_size=BEST_PARAMS["batch_size"], num_workers=0)
    
    # Clean up
    del train_pd, val_pd
    gc.collect()
    
    # Train with MLflow logging
    with mlflow.start_run(run_name="Final_Production_Training"):
        mlflow.log_params(BEST_PARAMS)
        
        model, history = train_final_model(train_dl, val_dl, cat_dims, emb_dims)
        
        # Log final metrics
        mlflow.log_metric("best_val_f2", max(history["val_f2"]))
        mlflow.log_metric("best_val_mae", min(history["val_mae"]))
        
        # Save model
        mlflow.pytorch.log_model(model, "model_final")
        
    print("✓ Model saved to MLflow")
else:
    print("Training skipped. Load pre-trained model from MLflow for inference.")
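
The early-stopping rule in `train_final_model` can be traced on a made-up validation-F2 history. The sketch below replays it (strict improvement resets the counter; `patience` consecutive non-improving epochs stop training) and reports which epoch's weights would be restored. `early_stopping_trace` and the history values are hypothetical.

```python
def early_stopping_trace(val_f2_history, patience):
    """Replay train_final_model's rule.

    Returns (restored_epoch, stop_epoch): the first epoch with the highest F2
    seen so far, and the epoch after which training ends.
    """
    best_f2, best_epoch, bad_epochs = -1.0, None, 0
    for epoch, f2 in enumerate(val_f2_history):
        if f2 > best_f2:                  # strict improvement only; ties do not count
            best_f2, best_epoch, bad_epochs = f2, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_f2_history) - 1

history = [0.58, 0.60, 0.61, 0.61, 0.60, 0.605, 0.60, 0.62]
print(early_stopping_trace(history, patience=4))   # (2, 6)
```

The tie at epoch 3 counts as "no improvement", so training stops at epoch 6 and never reaches the better score at epoch 7; a larger `patience` trades training time for that possibility.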

---
## 8. Threshold Optimization

Find the optimal classification threshold by sweeping thresholds and maximizing F2 score on the validation set.

In [None]:
# =============================================================================
# THRESHOLD OPTIMIZATION
# =============================================================================

@torch.no_grad()
def find_optimal_threshold(
    model: nn.Module,
    loader: DataLoader,
    threshold_range: np.ndarray = np.arange(0.05, 0.95, 0.01),
    device: torch.device = DEVICE
) -> Tuple[float, float, pd.DataFrame]:
    """
    Find optimal classification threshold by maximizing F2 score.
    
    Args:
        model: Trained model
        loader: Validation data loader
        threshold_range: Thresholds to evaluate
        device: Compute device
    
    Returns:
        optimal_threshold, best_f2, threshold_curve_df
    """
    print("Generating validation probabilities...")
    model.eval()
    
    all_y_true, all_y_prob = [], []
    for cat, num, time_feat, y, _ in loader:
        _, clf_out = model(cat.to(device), num.to(device), time_feat.to(device))
        all_y_true.append((y >= DELAY_THRESHOLD).long().cpu().numpy().flatten())
        all_y_prob.append(torch.sigmoid(clf_out).cpu().numpy().flatten())
    
    y_true = np.concatenate(all_y_true)
    y_prob = np.concatenate(all_y_prob)
    
    print("Sweeping thresholds...")
    best_f2, optimal_threshold = -1.0, 0.5
    f2_scores = []
    
    for threshold in threshold_range:
        y_pred = (y_prob >= threshold).astype(int)
        f2 = fbeta_score(y_true, y_pred, beta=2, zero_division=0)
        f2_scores.append(f2)
        
        if f2 > best_f2:
            best_f2 = f2
            optimal_threshold = threshold
    
    curve_df = pd.DataFrame({"threshold": threshold_range, "f2_score": f2_scores})
    
    print(f"\n{'='*50}")
    print(f"OPTIMAL THRESHOLD: {optimal_threshold:.3f}")
    print(f"BEST F2 SCORE: {best_f2:.5f}")
    print(f"{'='*50}")
    
    return optimal_threshold, best_f2, curve_df


# Example usage (requires trained model and data):
# optimal_threshold, best_f2, curve_df = find_optimal_threshold(model, val_dl)
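
The sweep in `find_optimal_threshold` can be reproduced without a model. The NumPy sketch below computes F2 directly from confusion counts, using the identity F2 = 5·TP / (5·TP + 4·FN + FP), over the same 0.05 to 0.94 grid (built from integers so the grid values are exact decimals). The labels and probabilities are made-up toy values, and `f2_at` is a hypothetical helper.

```python
import numpy as np

def f2_at(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> float:
    """F2 at one threshold, computed from confusion counts."""
    y_pred = y_prob >= threshold
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))
    denom = 5 * tp + 4 * fn + fp          # F2 = 5TP / (5TP + 4FN + FP)
    return 5 * tp / denom if denom else 0.0

y_true = np.array([1, 1, 1, 0, 0, 0, 0, 0])
y_prob = np.array([0.9, 0.45, 0.38, 0.40, 0.30, 0.20, 0.10, 0.05])

thresholds = np.arange(5, 95) / 100       # 0.05, 0.06, ..., 0.94
scores = np.array([f2_at(y_true, y_prob, t) for t in thresholds])
best = thresholds[np.argmax(scores)]      # first (lowest) threshold reaching the max
print(best, scores.max(), f2_at(y_true, y_prob, 0.5))
```

On this toy data the best threshold is 0.31 (F2 = 15/16), while 0.5 gives only 5/13: catching the positive scored 0.38 is worth one extra false positive because F2 penalises a missed delay four times as much as a false alarm. Like `np.argmax`, `find_optimal_threshold` keeps the lowest threshold among ties, since it only updates on strict improvement.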

---
## 9. Inference & Results

Load the trained model and generate predictions on test data.

In [None]:
# =============================================================================
# INFERENCE ON TEST SET
# =============================================================================

# MLflow Run ID for the trained model
MLFLOW_RUN_ID = "8706b956e0bd4ed681234979ad86206b"

RUN_INFERENCE = False

if RUN_INFERENCE:
    print("Loading trained model from MLflow...")
    model = mlflow.pytorch.load_model(f"runs:/{MLFLOW_RUN_ID}/model_final")
    model.eval()
    
    # Prepare test data
    print("Preparing test data...")
    train_spark = spark.read.parquet(f"{DATA_PATH}/train.parquet")
    test_spark = spark.read.parquet(f"{DATA_PATH}/test.parquet")
    
    # Fit indexers on train (must match training)
    indexers = [
        StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep")
        .fit(train_spark) 
        for c in CATEGORICAL_COLS
    ]
    
    # Convert test to Pandas
    test_pd = spark_to_pandas(test_spark, indexers)
    
    # Fit scaler on train, transform test
    train_pd_for_scaler = spark_to_pandas(train_spark, indexers)
    scaler = StandardScaler()
    scaler.fit(train_pd_for_scaler[NUMERICAL_COLS])
    test_pd[NUMERICAL_COLS] = scaler.transform(test_pd[NUMERICAL_COLS])
    
    del train_spark, test_spark, train_pd_for_scaler
    gc.collect()
    
    # Create test loader
    test_ds = FlightDataset(test_pd)
    test_dl = DataLoader(test_ds, batch_size=BEST_PARAMS["batch_size"], num_workers=0)
    
    # Generate predictions
    print("Generating predictions...")
    start_time = time.time()
    predictions_df = generate_predictions(model, test_dl, threshold=OPTIMAL_THRESHOLD)
    inference_time = time.time() - start_time
    
    # Calculate metrics
    test_f2 = fbeta_score(
        predictions_df["target_class"],
        predictions_df["pred_class"],
        beta=2
    )
    test_mae = mean_absolute_error(
        predictions_df["target_delay"],
        predictions_df["pred_delay"]
    )
    
    print(f"\n{'='*60}")
    print(f"TEST SET RESULTS (Threshold: {OPTIMAL_THRESHOLD})")
    print(f"{'='*60}")
    print(f"  F2 Score: {test_f2:.5f}")
    print(f"  MAE: {test_mae:.4f} minutes")
    print(f"  Inference Time: {inference_time:.2f} seconds")
    print(f"  Samples: {len(predictions_df):,}")
    print(f"{'='*60}")
    
    # Save predictions
    save_predictions(predictions_df, "FINAL_TEST_OPTIMIZED")
else:
    print("Inference skipped. Set RUN_INFERENCE = True to generate predictions.")

---
## 10. Summary & Conclusions

### Model Performance

| Dataset | F2 Score | MAE (minutes) |
|---------|----------|---------------|
| Train | 0.621 | 9.53 |
| Validation | 0.626 | 10.84 |
| **Test** | **0.619** | **11.34** |

### Key Findings

1. **Multi-Task Learning**: The dual-head architecture with combined regression and classification loss effectively handles both delay prediction and delay classification tasks.

2. **FiLM Modulation**: Feature-wise Linear Modulation allows the numerical features (weather, historical delays) to contextualize the categorical embeddings (airports, carriers), improving feature interaction modeling.

3. **Class Imbalance**: The 4x positive weighting in BCE loss significantly improved recall for delayed flights, which is critical given the F2 metric's emphasis on recall.

4. **Threshold Optimization**: Moving from the default 0.5 threshold to 0.36 improved F2 from ~0.60 to ~0.63, demonstrating the importance of threshold tuning for imbalanced classification.

### Hyperparameter Insights

| Parameter | Optimal Value | Impact |
|-----------|---------------|--------|
| Learning Rate | 1.56e-4 | Lower LR with AdamW prevented overfitting |
| Batch Size | 4096 | Larger batches stabilized training |
| α (Loss Weight) | 0.342 | BCE term weighted 1-α = 0.658, favoring classification |
| Time Dim | 16 | Rich temporal encoding |
| Dropout | 0.05-0.32 | Moderate regularization |

### Future Improvements

1. **GPU Training**: Tuning and training were run on CPU (the code already uses CUDA when available); a GPU would allow larger models and more epochs
2. **Attention Mechanisms**: Self-attention over temporal sequences could capture longer-range dependencies
3. **Ensemble**: Combining with XGBoost predictions could improve robustness