In [1]:
# ============================================================
#  Nexar Dashcam Crash Prediction - Hybrid VideoMAE + CNN-GRU
# ============================================================

import os
import gc
import math
import random
from pathlib import Path

import numpy as np
import pandas as pd
import imageio
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    average_precision_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    roc_curve,
    classification_report,
    confusion_matrix,
)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models import resnet18, ResNet18_Weights

from transformers import VideoMAEConfig, VideoMAEModel

from tqdm.auto import tqdm

# ============================================================
#  CONFIG
# ============================================================

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Clip configuration
NUM_FRAMES = 16       # T
FRAME_STRIDE = 3      # S (temporal stride in frames)
IMG_SIZE = 224

# Training configuration
BATCH_SIZE = 4
NUM_EPOCHS = 10
LR = 1e-5
WEIGHT_DECAY = 1e-4
EARLY_STOP_PATIENCE = 3   # early stopping patience (epochs without mAP improvement)

# Evaluation horizons (seconds before collision)
HORIZONS = [0.5, 1.0, 1.5]

# Data paths (must match competition files)
CSV_PATH = "/kaggle/input/nexar-collision-prediction/train.csv"
TRAIN_VIDEO_DIR = "/kaggle/input/nexar-256x256/train_resized"

# Where to save model & plots
OUTPUT_DIR = "./nexar_hybrid_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)


# ============================================================
#  REPRODUCIBILITY
# ============================================================

def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)


# ============================================================
#  DATA LOADING & SPLITTING (70/10/20)
# ============================================================

df = pd.read_csv(CSV_PATH)

# The Kaggle train.csv already has "target" (1 = collision / near-miss, 0 = normal).
assert "target" in df.columns, "Expected 'target' column in train.csv"
assert "id" in df.columns, "Expected 'id' column in train.csv"
assert "time_of_event" in df.columns, "Expected 'time_of_event' column in train.csv"

# 70/10/20 stratified split
train_val_df, test_df = train_test_split(
    df,
    test_size=0.20,
    stratify=df["target"],
    random_state=SEED,
)

train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.125,  # 0.125 * 0.8 = 0.10 of original
    stratify=train_val_df["target"],
    random_state=SEED,
)

train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

print("Total:", len(df))
print("Train:", len(train_df))
print("Val  :", len(val_df))
print("Test :", len(test_df))


# ============================================================
#  UTILS
# ============================================================

def id_to_path(sample_id: int) -> str:
    """
    Convert numeric id to zero-padded 5-digit string and build video path.
    """
    sample = str(int(sample_id))
    sample = "0" * (5 - len(sample)) + sample
    return f"{TRAIN_VIDEO_DIR}/{sample}_256x256.mp4"


def load_video_meta(reader):
    md = reader.get_meta_data()
    fps = md["fps"]
    duration = md["duration"]
    n_frames = int(round(fps * duration))
    return fps, duration, n_frames


def read_clip_frames(reader, frame_indices, expected_T):
    """
    Read a list of frame indices from a video and return [T, H, W, 3],
    padding at the beginning if not enough frames.
    """
    frames = []
    for idx in frame_indices:
        try:
            frames.append(reader.get_data(idx))
        except:
            break

    if len(frames) == 0:
        # Should not really happen; fall back to first frame repeated
        first = reader.get_data(0)
        frames = [first] * expected_T

    frames = np.stack(frames)  # [len, H, W, 3]

    # If we have fewer than expected_T, pad at the beginning with zeros.
    if frames.shape[0] < expected_T:
        pad_len = expected_T - frames.shape[0]
        pad = np.zeros((pad_len,) + frames.shape[1:], dtype=frames.dtype)
        frames = np.concatenate([pad, frames], axis=0)

    # If we have more (shouldn't), keep last expected_T frames
    if frames.shape[0] > expected_T:
        frames = frames[-expected_T:]

    return frames  # [T, H, W, 3]


def spatial_augment_train(clip):
    """
    clip: torch tensor [T, 3, H, W] in [0,1]
    Apply simple random crop + horizontal flip + brightness/contrast jitter.
    """
    T, C, H, W = clip.shape

    # Random crop to IMG_SIZE
    if H > IMG_SIZE and W > IMG_SIZE:
        top = np.random.randint(0, H - IMG_SIZE + 1)
        left = np.random.randint(0, W - IMG_SIZE + 1)
        clip = clip[:, :, top:top + IMG_SIZE, left:left + IMG_SIZE]
    else:
        # Just center crop-ish
        top = max((H - IMG_SIZE) // 2, 0)
        left = max((W - IMG_SIZE) // 2, 0)
        bottom = top + IMG_SIZE
        right = left + IMG_SIZE
        clip = clip[:, :, top:bottom, left:right]

    # Random horizontal flip
    if np.random.rand() < 0.5:
        clip = torch.flip(clip, dims=[-1])

    # Brightness & contrast jitter
    brightness = np.random.normal(0.0, 0.05)
    contrast = np.random.normal(1.0, 0.1)
    clip = clip * contrast + brightness
    clip = torch.clamp(clip, 0.0, 1.0)

    return clip


def spatial_center_crop_eval(clip):
    """
    clip: torch tensor [T, 3, H, W] in [0,1]
    Center crop to IMG_SIZE.
    """
    T, C, H, W = clip.shape
    top = max((H - IMG_SIZE) // 2, 0)
    left = max((W - IMG_SIZE) // 2, 0)
    bottom = top + IMG_SIZE
    right = left + IMG_SIZE
    clip = clip[:, :, top:bottom, left:right]
    return clip


# ============================================================
#  DATASETS
# ============================================================

class NexarTrainDataset(Dataset):
    """
    Training dataset:
    - Positive videos: sample random windows in [0.5, 1.5] sec before time_of_event.
    - Negative videos: sample random window anywhere.
    """

    def __init__(self, df, num_frames=NUM_FRAMES, frame_stride=FRAME_STRIDE):
        self.df = df.reset_index(drop=True)
        self.num_frames = num_frames
        self.frame_stride = frame_stride

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        vid_path = id_to_path(row["id"])
        target = int(row["target"])
        time_of_event = row["time_of_event"]

        reader = imageio.get_reader(vid_path, "ffmpeg")
        fps, duration, n_frames = load_video_meta(reader)

        total_needed = self.num_frames * self.frame_stride

        if target == 1:
            # Positive: sample a random horizon in [0.5, 1.5] seconds before collision
            horizon = np.random.uniform(0.5, 1.5)
            end_time = max(time_of_event - horizon, 0.0)
            end_frame = int(round(end_time * fps))

            if end_frame <= 0:
                # If very early, just take earliest frames
                start_frame = 0
                end_frame = min(total_needed, n_frames)
            else:
                start_frame = end_frame - total_needed
                if start_frame < 0:
                    start_frame = 0
        else:
            # Negative: random window anywhere in the video
            if n_frames > total_needed:
                start_frame = np.random.randint(0, n_frames - total_needed + 1)
                end_frame = start_frame + total_needed
            else:
                start_frame = 0
                end_frame = min(total_needed, n_frames)

        frame_indices = list(range(start_frame, end_frame, self.frame_stride))
        clip_np = read_clip_frames(reader, frame_indices, self.num_frames)
        reader.close()

        # [T, H, W, 3] -> [T, 3, H, W]
        clip = torch.from_numpy(clip_np).float().permute(0, 3, 1, 2) / 255.0

        # Augment
        clip = spatial_augment_train(clip)

        # [T, 3, H, W]
        label = torch.tensor(target, dtype=torch.float32)

        return clip, label


class NexarEvalDataset(Dataset):
    """
    Evaluation dataset for a given horizon (seconds before event):
    - For positive videos: clips end at (time_of_event - horizon).
    - For negative videos: clips end at the last frame (near end of video).
    """

    def __init__(self, df, horizon, num_frames=NUM_FRAMES, frame_stride=FRAME_STRIDE):
        self.df = df.reset_index(drop=True)
        self.horizon = float(horizon)
        self.num_frames = num_frames
        self.frame_stride = frame_stride

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        vid_path = id_to_path(row["id"])
        target = int(row["target"])
        time_of_event = row["time_of_event"]

        reader = imageio.get_reader(vid_path, "ffmpeg")
        fps, duration, n_frames = load_video_meta(reader)

        total_needed = self.num_frames * self.frame_stride

        if target == 1:
            # Positive: deterministic window relative to event time
            end_time = max(time_of_event - self.horizon, 0.0)
            end_frame = int(round(end_time * fps))
            if end_frame <= 0:
                start_frame = 0
                end_frame = min(total_needed, n_frames)
            else:
                start_frame = end_frame - total_needed
                if start_frame < 0:
                    start_frame = 0
        else:
            # Negative: use last part of video
            end_frame = n_frames
            start_frame = max(0, end_frame - total_needed)

        frame_indices = list(range(start_frame, end_frame, self.frame_stride))
        clip_np = read_clip_frames(reader, frame_indices, self.num_frames)
        reader.close()

        # [T, H, W, 3] -> [T, 3, H, W]
        clip = torch.from_numpy(clip_np).float().permute(0, 3, 1, 2) / 255.0

        # Center crop only
        clip = spatial_center_crop_eval(clip)

        label = torch.tensor(target, dtype=torch.float32)
        return clip, label


# ============================================================
#  MODEL: Hybrid VideoMAE + CNN-GRU
# ============================================================

class HybridVideoMAEGRU(nn.Module):
    """
    Two-branch architecture:

    1. CNN-GRU branch:
       - Frame-wise ResNet18 backbone (IMAGENET pre-trained)
       - BiGRU over time

    2. VideoMAE branch:
       - Pretrained VideoMAE backbone (no classifier)
       - Mean pooling over patch tokens

    Fused features -> MLP head -> single logit (binary collision / near-miss).
    """

    def __init__(
        self,
        videomae_name: str = "MCG-NJU/videomae-base",
        gru_hidden_size: int = 256,
        fusion_hidden_size: int = 256,
    ):
        super().__init__()

        # --- CNN backbone (2D ResNet18) ---
        self.cnn = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.cnn.fc = nn.Identity()  # remove classifier (512-dim features)
        cnn_feat_dim = 512

        # --- Temporal GRU head ---
        self.gru = nn.GRU(
            input_size=cnn_feat_dim,
            hidden_size=gru_hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        gru_out_dim = gru_hidden_size * 2  # bidirectional

        # --- VideoMAE branch ---
        config = VideoMAEConfig.from_pretrained(videomae_name)
        self.videomae = VideoMAEModel.from_pretrained(videomae_name, config=config)
        videomae_hidden_size = config.hidden_size

        # --- Fusion head ---
        fusion_dim = gru_out_dim + videomae_hidden_size

        self.head = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, fusion_hidden_size),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(fusion_hidden_size, 1),  # single logit
        )

        # Cache normalization constants
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1)
        self.register_buffer("pixel_mean", mean)
        self.register_buffer("pixel_std", std)

    def forward(self, x):
        """
        x: [B, T, 3, H, W] in [0,1]
        """
        B, T, C, H, W = x.shape

        # Normalize
        x = (x - self.pixel_mean) / self.pixel_std  # broadcast over batch & time

        # ----- Branch 1: CNN-GRU -----
        # [B, T, C, H, W] -> [B*T, C, H, W]
        x_flat = x.view(B * T, C, H, W)
        cnn_feats = self.cnn(x_flat)  # [B*T, 512]
        cnn_feats = cnn_feats.view(B, T, -1)  # [B, T, 512]

        gru_out, _ = self.gru(cnn_feats)  # [B, T, 2*hidden]
        # Use last time step as summary
        gru_feat = gru_out[:, -1, :]  # [B, gru_out_dim]

        # ----- Branch 2: VideoMAE -----
        video_outputs = self.videomae(pixel_values=x)  # [B, num_patches, hidden]
        videomae_tokens = video_outputs.last_hidden_state
        videomae_feat = videomae_tokens.mean(dim=1)  # global average over tokens

        # ----- Fusion -----
        fused = torch.cat([gru_feat, videomae_feat], dim=-1)  # [B, fusion_dim]
        logit = self.head(fused).squeeze(-1)  # [B]

        return logit


# ============================================================
#  METRICS & EVALUATION
# ============================================================

def compute_binary_metrics(y_true, y_score, threshold=0.5):
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score).astype(float)
    y_pred = (y_score >= threshold).astype(int)

    metrics = {}

    try:
        metrics["precision"] = precision_score(y_true, y_pred)
    except Exception:
        metrics["precision"] = float("nan")

    try:
        metrics["recall"] = recall_score(y_true, y_pred)
    except Exception:
        metrics["recall"] = float("nan")

    try:
        metrics["f1"] = f1_score(y_true, y_pred)
    except Exception:
        metrics["f1"] = float("nan")

    try:
        metrics["ap"] = average_precision_score(y_true, y_score)
    except Exception:
        metrics["ap"] = float("nan")

    try:
        metrics["roc_auc"] = roc_auc_score(y_true, y_score)
    except Exception:
        metrics["roc_auc"] = float("nan")

    try:
        fpr, tpr, _ = roc_curve(y_true, y_score)
    except Exception:
        fpr, tpr = np.array([0.0, 1.0]), np.array([0.0, 1.0])

    metrics["fpr"] = fpr
    metrics["tpr"] = tpr

    return metrics


def evaluate_model(
    model,
    df_split,
    horizons=HORIZONS,
    batch_size=BATCH_SIZE,
    num_workers=4,
    desc_prefix="VAL",
):
    """
    Evaluate model on df_split for each horizon separately.
    Returns:
        mean_ap: mean of AP across horizons
        horizon_metrics: dict[horizon] -> dict of metrics
    """
    model.eval()
    horizon_metrics = {}
    aps = []

    for h in horizons:
        dataset = NexarEvalDataset(df_split, horizon=h)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

        all_scores = []
        all_labels = []

        with torch.no_grad():
            for clips, labels in tqdm(loader, desc=f"{desc_prefix} horizon={h}s"):
                clips = clips.to(DEVICE)  # [B, T, 3, H, W]
                labels = labels.to(DEVICE)

                logits = model(clips)
                probs = torch.sigmoid(logits).detach().cpu().numpy()
                all_scores.append(probs)
                all_labels.append(labels.detach().cpu().numpy())

        y_score = np.concatenate(all_scores)
        y_true = np.concatenate(all_labels)

        m = compute_binary_metrics(y_true, y_score, threshold=0.5)
        horizon_metrics[h] = m
        aps.append(m["ap"])

    mean_ap = float(np.nanmean(aps))
    return mean_ap, horizon_metrics


# ============================================================
#  TRAINING LOOP (with EARLY STOPPING)
# ============================================================

def train_model(
    model,
    train_df,
    val_df,
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    early_stop_patience=EARLY_STOP_PATIENCE,
):
    train_dataset = NexarTrainDataset(train_df)
    val_dataset = NexarTrainDataset(val_df)  # train-style sampling but no horizon constraint
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.BCEWithLogitsLoss()

    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_map": [],
    }

    best_val_map = -1.0
    best_state = None
    epochs_without_improve = 0

    for epoch in range(1, num_epochs + 1):
        print(f"\n===== Epoch {epoch}/{num_epochs} =====")

        # ------------------
        # Train
        # ------------------
        model.train()
        running_loss = 0.0
        n_samples = 0

        pbar = tqdm(train_loader, desc="TRAIN")
        for clips, labels in pbar:
            clips = clips.to(DEVICE)  # [B, T, 3, H, W]
            labels = labels.to(DEVICE)

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
                logits = model(clips)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_size_curr = clips.size(0)
            running_loss += loss.item() * batch_size_curr
            n_samples += batch_size_curr
            pbar.set_postfix({"loss": running_loss / max(1, n_samples)})

        train_loss = running_loss / max(1, n_samples)

        # ------------------
        # Validation loss (on val_loader, same sampling style as train)
        # ------------------
        model.eval()
        val_running_loss = 0.0
        val_samples = 0
        with torch.no_grad():
            for clips, labels in tqdm(val_loader, desc="VAL (loss)"):
                clips = clips.to(DEVICE)
                labels = labels.to(DEVICE)
                logits = model(clips)
                loss = criterion(logits, labels)
                bs = clips.size(0)
                val_running_loss += loss.item() * bs
                val_samples += bs
        val_loss = val_running_loss / max(1, val_samples)

        # ------------------
        # Validation mAP (using horizon-based eval on val_df)
        # ------------------
        val_map, _ = evaluate_model(
            model,
            val_df,
            horizons=HORIZONS,
            batch_size=batch_size,
            desc_prefix="VAL (mAP)",
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_map"].append(val_map)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_mAP={val_map:.4f}")

        # Scheduler
        scheduler.step()

        # Save best model & handle early stopping
        if val_map > best_val_map:
            best_val_map = val_map
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(best_state, os.path.join(OUTPUT_DIR, "hybrid_videomae_gru_best.pth"))
            print(f"*** New best val mAP: {best_val_map:.4f} (model saved)")
            epochs_without_improve = 0
        else:
            epochs_without_improve += 1
            print(f"No improvement in val mAP for {epochs_without_improve} epoch(s).")
            if epochs_without_improve >= early_stop_patience:
                print(f"Early stopping triggered (patience={early_stop_patience}).")
                break

        gc.collect()
        torch.cuda.empty_cache()

    # Load best state back to model
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history, best_val_map


# ============================================================
#  PLOTTING FUNCTIONS
# ============================================================

def plot_training_curves(history, out_dir=OUTPUT_DIR):
    epochs = np.arange(1, len(history["train_loss"]) + 1)

    # --- Loss curve (single figure) ---
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, history["train_loss"], label="Train Loss")
    plt.plot(epochs, history["val_loss"], label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    loss_path = os.path.join(out_dir, "loss_curve.png")
    plt.savefig(loss_path)
    plt.close()
    print(f"Saved: {loss_path}")

    # --- mAP curve (single figure) ---
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, history["val_map"], label="Val mAP (mean AP over horizons)")
    plt.xlabel("Epoch")
    plt.ylabel("mAP")
    plt.title("Validation mAP across training")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    map_path = os.path.join(out_dir, "map_curve.png")
    plt.savefig(map_path)
    plt.close()
    print(f"Saved: {map_path}")

    # --- Combined Loss & mAP side-by-side figure (NEW) ---
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Loss
    axes[0].plot(epochs, history["train_loss"], label="Train Loss")
    axes[0].plot(epochs, history["val_loss"], label="Val Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Train vs Val Loss")
    axes[0].legend()
    axes[0].grid(True)

    # Right: mAP
    axes[1].plot(epochs, history["val_map"], label="Val mAP")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("mAP")
    axes[1].set_title("Validation mAP")
    axes[1].legend()
    axes[1].grid(True)

    fig.suptitle("Training Dynamics: Loss & mAP")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    side_path = os.path.join(out_dir, "loss_map_side_by_side.png")
    fig.savefig(side_path)
    plt.close(fig)
    print(f"Saved: {side_path}")


def plot_roc_curves(horizon_metrics, out_dir=OUTPUT_DIR, split_name="test"):
    plt.figure(figsize=(7, 7))
    for h, m in horizon_metrics.items():
        fpr = m["fpr"]
        tpr = m["tpr"]
        auc = m["roc_auc"]
        label = f"{h:.1f}s (AUC={auc:.3f})"
        plt.plot(fpr, tpr, label=label)

    plt.plot([0, 1], [0, 1], "k--", label="Random")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curves ({split_name} set) for different horizons")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()

    roc_path = os.path.join(out_dir, f"roc_curves_{split_name}.png")
    plt.savefig(roc_path)
    plt.close()
    print(f"Saved: {roc_path}")


def plot_confusion_matrix(cm, classes, title, out_path):
    """
    Simple confusion matrix heatmap with annotations.
    cm: 2x2 numpy array
    """
    plt.figure(figsize=(4, 4))
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    thresh = cm.max() / 2.0 if cm.max() > 0 else 0.5
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(
                j, i, format(cm[i, j], 'd'),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black",
            )

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print(f"Saved: {out_path}")


# ============================================================
#  MAIN: TRAIN + EVAL ON TEST
# ============================================================

if __name__ == "__main__":
    # Instantiate model
    model = HybridVideoMAEGRU(
        videomae_name="MCG-NJU/videomae-base",  # you can switch to -large/-giant if you have more GPU
        gru_hidden_size=256,
        fusion_hidden_size=256,
    )

    # 1) Train on train_df, validate on val_df (with early stopping)
    model, history, best_val_map = train_model(
        model,
        train_df=train_df,
        val_df=val_df,
        num_epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        early_stop_patience=EARLY_STOP_PATIENCE,
    )

    print(f"\nBest validation mAP: {best_val_map:.4f}")

    # 2) Plot training curves (separate + combined)
    plot_training_curves(history, out_dir=OUTPUT_DIR)

    # 3) Final evaluation on held-out TEST (70/10/20 split)
    print("\n===== Final evaluation on TEST split =====")
    test_map, test_horizon_metrics = evaluate_model(
        model,
        test_df,
        horizons=HORIZONS,
        batch_size=BATCH_SIZE,
        desc_prefix="TEST",
    )

    print(f"\nTEST mean AP across horizons (0.5s, 1.0s, 1.5s): {test_map:.4f}")
    for h in HORIZONS:
        m = test_horizon_metrics[h]
        print(
            f"Horizon {h:.1f}s -> "
            f"AP={m['ap']:.4f}, "
            f"Precision={m['precision']:.4f}, "
            f"Recall={m['recall']:.4f}, "
            f"F1={m['f1']:.4f}, "
            f"ROC-AUC={m['roc_auc']:.4f}"
        )

    # 4) Plot ROC curves for test set (multi-horizon)
    plot_roc_curves(test_horizon_metrics, out_dir=OUTPUT_DIR, split_name="test")

    # 5) Save test metrics to CSV for paper tables
    rows = []
    for h in HORIZONS:
        m = test_horizon_metrics[h]
        rows.append(
            {
                "horizon_sec": h,
                "AP": m["ap"],
                "precision": m["precision"],
                "recall": m["recall"],
                "f1": m["f1"],
                "roc_auc": m["roc_auc"],
            }
        )
    metrics_df = pd.DataFrame(rows)
    metrics_path = os.path.join(OUTPUT_DIR, "test_metrics_by_horizon.csv")
    metrics_df.to_csv(metrics_path, index=False)
    print(f"\nSaved detailed test metrics: {metrics_path}")

    # 6) EXTRA: Classification report & confusion matrix per horizon
    print("\n===== Detailed Classification Reports & Confusion Matrices (TEST) =====")
    class_names = ["Non-collision", "Collision/Near-miss"]

    for h in HORIZONS:
        print(f"\n--- Horizon {h:.1f} seconds ---")
        dataset = NexarEvalDataset(test_df, horizon=h)
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

        all_scores = []
        all_labels = []
        with torch.no_grad():
            model.eval()
            for clips, labels in tqdm(loader, desc=f"TEST raw preds horizon={h}s"):
                clips = clips.to(DEVICE)
                labels = labels.to(DEVICE)
                logits = model(clips)
                probs = torch.sigmoid(logits).detach().cpu().numpy()
                all_scores.append(probs)
                all_labels.append(labels.detach().cpu().numpy())

        y_score = np.concatenate(all_scores)
        y_true = np.concatenate(all_labels).astype(int)
        y_pred = (y_score >= 0.5).astype(int)

        # Classification report (4 decimal places)
        report = classification_report(
            y_true,
            y_pred,
            target_names=class_names,
            digits=4,
        )
        print(report)

        # Save classification report to text file
        report_path = os.path.join(OUTPUT_DIR, f"classification_report_test_{h:.1f}s.txt")
        with open(report_path, "w") as f:
            f.write(f"Horizon {h:.1f} seconds\n")
            f.write(report)
        print(f"Saved classification report: {report_path}")

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        print("Confusion Matrix:")
        print(cm)

        # Save confusion matrix figure
        cm_title = f"Confusion Matrix (Test, horizon={h:.1f}s)"
        cm_path = os.path.join(OUTPUT_DIR, f"confusion_matrix_test_{h:.1f}s.png")
        plot_confusion_matrix(cm, class_names, cm_title, cm_path)


Total: 1500
Train: 1050
Val  : 150
Test : 300


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 195MB/s]


config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/377M [00:00<?, ?B/s]


===== Epoch 1/10 =====


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'


VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 1: train_loss=0.6384 | val_loss=0.5639 | val_mAP=0.8388
*** New best val mAP: 0.8388 (model saved)

===== Epoch 2/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 2: train_loss=0.5291 | val_loss=0.4905 | val_mAP=0.8847
*** New best val mAP: 0.8847 (model saved)

===== Epoch 3/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 3: train_loss=0.4796 | val_loss=0.4951 | val_mAP=0.8620
No improvement in val mAP for 1 epoch(s).

===== Epoch 4/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 4: train_loss=0.4330 | val_loss=0.4170 | val_mAP=0.9047
*** New best val mAP: 0.9047 (model saved)

===== Epoch 5/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 5: train_loss=0.3739 | val_loss=0.3890 | val_mAP=0.9388
*** New best val mAP: 0.9388 (model saved)

===== Epoch 6/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 6: train_loss=0.3766 | val_loss=0.3402 | val_mAP=0.9256
No improvement in val mAP for 1 epoch(s).

===== Epoch 7/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        self._shutdown_workers()
self._shutdown_workers()  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

      File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30><function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only te

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 7: train_loss=0.3390 | val_loss=0.3499 | val_mAP=0.9392
*** New best val mAP: 0.9392 (model saved)

===== Epoch 8/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 8: train_loss=0.3200 | val_loss=0.4858 | val_mAP=0.9295
No improvement in val mAP for 1 epoch(s).

===== Epoch 9/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30><function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>
Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Traceback (most recent call last):
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        self._shutdown_workers()if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

      File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
assert self._parent_pid == os.getpid(), 'can only test a child process'
    AssertionErrorif w.is_alive():: 
  File "/usr/lib/python3.10/multiprocessing/pro

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7a467a54fa30>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 9: train_loss=0.2797 | val_loss=0.3949 | val_mAP=0.9323
No improvement in val mAP for 2 epoch(s).

===== Epoch 10/10 =====


TRAIN:   0%|          | 0/262 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):


VAL (loss):   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=0.5s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.0s:   0%|          | 0/38 [00:00<?, ?it/s]

VAL (mAP) horizon=1.5s:   0%|          | 0/38 [00:00<?, ?it/s]

Epoch 10: train_loss=0.2848 | val_loss=0.3924 | val_mAP=0.9328
No improvement in val mAP for 3 epoch(s).
Early stopping triggered (patience=3).

Best validation mAP: 0.9392
Saved: ./nexar_hybrid_outputs/loss_curve.png
Saved: ./nexar_hybrid_outputs/map_curve.png
Saved: ./nexar_hybrid_outputs/loss_map_side_by_side.png

===== Final evaluation on TEST split =====


TEST horizon=0.5s:   0%|          | 0/75 [00:00<?, ?it/s]

TEST horizon=1.0s:   0%|          | 0/75 [00:00<?, ?it/s]

TEST horizon=1.5s:   0%|          | 0/75 [00:00<?, ?it/s]


TEST mean AP across horizons (0.5s, 1.0s, 1.5s): 0.8977
Horizon 0.5s -> AP=0.9438, Precision=0.8742, Recall=0.8800, F1=0.8771, ROC-AUC=0.9428
Horizon 1.0s -> AP=0.9040, Precision=0.8538, Recall=0.7400, F1=0.7929, ROC-AUC=0.9037
Horizon 1.5s -> AP=0.8452, Precision=0.8348, Recall=0.6400, F1=0.7245, ROC-AUC=0.8488
Saved: ./nexar_hybrid_outputs/roc_curves_test.png

Saved detailed test metrics: ./nexar_hybrid_outputs/test_metrics_by_horizon.csv

===== Detailed Classification Reports & Confusion Matrices (TEST) =====

--- Horizon 0.5 seconds ---


TEST raw preds horizon=0.5s:   0%|          | 0/75 [00:00<?, ?it/s]

                     precision    recall  f1-score   support

      Non-collision     0.8792    0.8733    0.8763       150
Collision/Near-miss     0.8742    0.8800    0.8771       150

           accuracy                         0.8767       300
          macro avg     0.8767    0.8767    0.8767       300
       weighted avg     0.8767    0.8767    0.8767       300

Saved classification report: ./nexar_hybrid_outputs/classification_report_test_0.5s.txt
Confusion Matrix:
[[131  19]
 [ 18 132]]
Saved: ./nexar_hybrid_outputs/confusion_matrix_test_0.5s.png

--- Horizon 1.0 seconds ---


TEST raw preds horizon=1.0s:   0%|          | 0/75 [00:00<?, ?it/s]

                     precision    recall  f1-score   support

      Non-collision     0.7706    0.8733    0.8187       150
Collision/Near-miss     0.8538    0.7400    0.7929       150

           accuracy                         0.8067       300
          macro avg     0.8122    0.8067    0.8058       300
       weighted avg     0.8122    0.8067    0.8058       300

Saved classification report: ./nexar_hybrid_outputs/classification_report_test_1.0s.txt
Confusion Matrix:
[[131  19]
 [ 39 111]]
Saved: ./nexar_hybrid_outputs/confusion_matrix_test_1.0s.png

--- Horizon 1.5 seconds ---


TEST raw preds horizon=1.5s:   0%|          | 0/75 [00:00<?, ?it/s]

                     precision    recall  f1-score   support

      Non-collision     0.7081    0.8733    0.7821       150
Collision/Near-miss     0.8348    0.6400    0.7245       150

           accuracy                         0.7567       300
          macro avg     0.7714    0.7567    0.7533       300
       weighted avg     0.7714    0.7567    0.7533       300

Saved classification report: ./nexar_hybrid_outputs/classification_report_test_1.5s.txt
Confusion Matrix:
[[131  19]
 [ 54  96]]
Saved: ./nexar_hybrid_outputs/confusion_matrix_test_1.5s.png
