In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import os
os.getcwd()

'/content'

In [None]:
import os
import re
import torch
from typing import Tuple, Dict
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
from pathlib import Path


def parse_frame_filename(stem: str) -> Dict[str, str]:

    # Split the frame index
    parts = stem.rsplit("_", 1)
    if len(parts) == 2:
        video_id, frame_part = parts
        frame_idx = int(frame_part)
    else:
        video_id = stem
        frame_idx = 0

    # Extract subject_id (first 4 chars)
    subject_id = video_id[:4]

    # Extract the Task
    video_parts = video_id.split("_")
    if len(video_parts) > 2:
        # Grab everything between index 2 (after Rep ID) and the last element (Suffix)
        task = "_".join(video_parts[2:-1]) if len(video_parts) > 3 else video_parts[2]
    else:
        task = "unknown"

    filename = video_id.split(".")[0]

    return {
        "subject_id": subject_id,
        "video_id": video_id,
        "task": task,
        "frame_idx": frame_idx,
        "filename": filename,
    }


class PreprocessedImageFolder(torch.utils.data.Dataset):

    def __init__(self, root, transform=None):
        self.image_folder = datasets.ImageFolder(root, transform=transform)
        self.transform = transform
        self.samples = self.image_folder.samples
        self.classes = self.image_folder.classes
        self.class_to_idx = self.image_folder.class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]

        # Parse filename to extract subject_id
        stem = Path(path).stem
        parsed = parse_frame_filename(stem)
        subject_id = parsed["subject_id"]

        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label, subject_id


def get_train_transforms(image_size: int) -> transforms.Compose:

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.1, contrast=0.1)], p=0.5
        ),
        transforms.RandomAffine(
            degrees=5,
            translate=(0.02, 0.02),
            scale=(0.95, 1.05)
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def get_eval_transforms(image_size: int) -> transforms.Compose:

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def collate_with_subject_id(batch):

    images = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    subject_ids = [item[2] for item in batch]
    return images, labels, subject_ids


def get_dataloaders(
    data_root: str,
    image_size: int = 224,
    batch_size: int = 32,
    num_workers: int = 2,
    pin_memory: bool = True,
    seed: int = 42
):

    train_transform = get_train_transforms(image_size)
    eval_transform = get_eval_transforms(image_size)

    train_dir = os.path.join(data_root, "train")
    val_dir = os.path.join(data_root, "val")
    test_dir = os.path.join(data_root, "test")

    if not (os.path.exists(train_dir) and os.path.exists(val_dir) and os.path.exists(test_dir)):
        raise FileNotFoundError(
            f"Expected train/val/test folders in {data_root}. "
            f"Please ensure data is preprocessed with subject-level splits."
        )

    train_dataset = PreprocessedImageFolder(train_dir, transform=train_transform)
    val_dataset = PreprocessedImageFolder(val_dir, transform=eval_transform)
    test_dataset = PreprocessedImageFolder(test_dir, transform=eval_transform)

    # Verify subject-level splitting
    train_subjects = set()
    val_subjects = set()
    test_subjects = set()

    for _, _, subj_id in train_dataset:
        train_subjects.add(subj_id)
    for _, _, subj_id in val_dataset:
        val_subjects.add(subj_id)
    for _, _, subj_id in test_dataset:
        test_subjects.add(subj_id)

    # Check for subject leakage
    train_val_overlap = train_subjects & val_subjects
    train_test_overlap = train_subjects & test_subjects
    val_test_overlap = val_subjects & test_subjects

    if train_val_overlap or train_test_overlap or val_test_overlap:
        print("WARNING: Subject leakage detected!")
        if train_val_overlap:
            print(f"  Train/Val overlap: {len(train_val_overlap)} subjects")
        if train_test_overlap:
            print(f"  Train/Test overlap: {len(train_test_overlap)} subjects")
        if val_test_overlap:
            print(f"  Val/Test overlap: {len(val_test_overlap)} subjects")
    else:
        print("Subject-level split verified: No subject leakage detected")

    print(f"\nDataset Statistics:")
    print(f"  Train: {len(train_dataset)} frames, {len(train_subjects)} subjects")
    print(f"  Val:   {len(val_dataset)} frames, {len(val_subjects)} subjects")
    print(f"  Test:  {len(test_dataset)} frames, {len(test_subjects)} subjects")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_with_subject_id
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_with_subject_id
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_with_subject_id
    )

    classes = train_dataset.classes

    return train_loader, val_loader, test_loader, train_dataset, classes

In [None]:
import torch
import torch.nn as nn
from torchvision import models

NUM_CLASSES = 3


def get_model(model_name: str, pretrained: bool = True):

   # Choose what you want

    weights = "IMAGENET1K_V1" if pretrained else None

    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=weights)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)

    elif model_name == "efficientnet_b3":
        model = models.efficientnet_b3(weights=weights)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)

    elif model_name == "efficientnet_b5":
        model = models.efficientnet_b5(weights=weights)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)

    elif model_name == "convnext_tiny":
        model = models.convnext_tiny(weights=weights)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, NUM_CLASSES)

    elif model_name == "convnext_small":
        model = models.convnext_small(weights=weights)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, NUM_CLASSES)

    elif model_name == "convnext_base":
        model = models.convnext_base(weights=weights)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, NUM_CLASSES)

    elif model_name == "swin_t":
        model = models.swin_t(weights=weights)
        model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

    elif model_name == "swin_s":
        model = models.swin_s(weights=weights)
        model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

    elif model_name == "swin_b":
        model = models.swin_b(weights=weights)
        model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=weights)
        model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)

    elif model_name == "vit_l_16":
        model = models.vit_l_16(weights=weights)
        model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)

    elif model_name == "mobilenet_v3_large":
        model = models.mobilenet_v3_large(weights=weights)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, NUM_CLASSES)

    elif model_name == "mobilenet_v3_small":
        model = models.mobilenet_v3_small(weights=weights)
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, NUM_CLASSES)

    elif model_name == "resnet18":
        model = models.resnet18(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    elif model_name == "resnet34":
        model = models.resnet34(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    elif model_name == "resnet50":
        model = models.resnet50(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    elif model_name == "resnet101":
        model = models.resnet101(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    elif model_name == "resnet152":
        model = models.resnet152(weights=weights)
        model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    else:
        raise ValueError(f"Unsupported model: {model_name}")

    return model

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support
)


def aggregate_subject_probas(
    y_proba: np.ndarray,
    y_true: np.ndarray,
    subject_ids: np.ndarray,
    label_names: list,
):

    num_classes = y_proba.shape[1]
    df = pd.DataFrame({
        "subject_id": subject_ids,
        "y_true": y_true,
    })

    # Add prob columns
    for c in range(num_classes):
        df[f"p_{c}"] = y_proba[:, c]

    # Group by subject, average probabilities
    grouped = df.groupby("subject_id")
    agg_records = []
    for subject_id, g in grouped:
        # Average probs
        mean_probas = g[[f"p_{c}" for c in range(num_classes)]].mean(axis=0).to_numpy()

        # True label: all rows for a subject should share the same label
        y_true_subj = int(g["y_true"].iloc[0])
        y_pred_subj = int(np.argmax(mean_probas))

        rec = {
            "subject_id": subject_id,
            "y_true": y_true_subj,
            "y_pred": y_pred_subj,
        }

        for c in range(num_classes):
            rec[f"p_{c}"] = float(mean_probas[c])
        agg_records.append(rec)

    subj_df = pd.DataFrame.from_records(agg_records)

    # Compute metrics at subject-level
    y_true_subj = subj_df["y_true"].to_numpy()
    y_pred_subj = subj_df["y_pred"].to_numpy()
    prob_cols = [f"p_{c}" for c in range(num_classes)]
    y_proba_subj = subj_df[prob_cols].to_numpy()

    metrics = compute_classification_metrics_core(
        y_true=y_true_subj,
        y_pred=y_pred_subj,
        y_proba=y_proba_subj,
        label_names=label_names,
    )

    return {
        "subject_df": subj_df,
        "metrics": metrics,
    }


def compute_classification_metrics_core(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: np.ndarray,
    label_names: list,
):
    """Core metric computation used for both task-level and subject-level"""
    metrics = {}

    metrics["accuracy"] = float(accuracy_score(y_true, y_pred))
    metrics["macro_f1"] = float(f1_score(y_true, y_pred, average="macro"))
    metrics["weighted_f1"] = float(f1_score(y_true, y_pred, average="weighted"))

    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(len(label_names)), zero_division=0
    )

    metrics["precision"] = float(precision_score(y_true, y_pred, average="weighted", zero_division=0))
    metrics["recall"] = float(recall_score(y_true, y_pred, average="weighted", zero_division=0))

    metrics["per_class"] = {
        label_names[i]: {
            "precision": float(precision[i]),
            "recall": float(recall[i]),
            "f1": float(f1[i]),
            "support": int(support[i]),
        }
        for i in range(len(label_names))
    }

    if y_proba is not None and y_proba.ndim == 2 and y_proba.shape[1] == len(label_names):
        try:
            roc_auc = roc_auc_score(y_true, y_proba, multi_class="ovr", average="weighted")
            metrics["roc_auc_ovr"] = float(roc_auc)
        except ValueError:
            metrics["roc_auc_ovr"] = None
    else:
        metrics["roc_auc_ovr"] = None

    return metrics


def evaluate(model, data_loader, device, return_details=False):
    # Frame-Level Eval
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    all_subject_ids = []

    with torch.no_grad():
        for batch in data_loader:
            images, labels, subject_ids = batch
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_probs.extend(probs.cpu().numpy())
            all_subject_ids.extend(subject_ids)

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_subject_ids = np.array(all_subject_ids)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    try:
        auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='weighted')
    except ValueError:
        auc = 0.0

    if return_details:
        return {
            'accuracy': accuracy,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'auc': auc,
            'predictions': all_preds,
            'labels': all_labels,
            'probabilities': all_probs,
            'subject_ids': all_subject_ids
        }

    return accuracy, f1


def print_confusion_matrix(labels, predictions, class_names):

    cm = confusion_matrix(labels, predictions)

    print("\nConfusion Matrix:")
    print("-" * 60)

    max_name_len = max(len(name) for name in class_names)
    header_padding = " " * (max_name_len + 2)

    header = header_padding + "Predicted"
    print(header)
    header_labels = header_padding + " ".join(["{:>10}".format(name) for name in class_names])
    print(header_labels)
    print("-" * 60)

    for i, row in enumerate(cm):
        row_str = " ".join(["{:>10}".format(val) for val in row])
        label_str = "{:>{}}".format(class_names[i], max_name_len)
        print(label_str + "  " + row_str)

    print("-" * 60)

    print("\nPer-Class Metrics:")
    print("-" * 60)
    class_header = "{:<15} {:<12} {:<12} {:<12} {:<10}".format('Class', 'Precision', 'Recall', 'F1-Score', 'Support')
    print(class_header)
    print("-" * 60)

    report = classification_report(labels, predictions, target_names=class_names, output_dict=True, zero_division=0)
    for class_name in class_names:
        metrics = report[class_name]
        row_str = "{:<15} {:<12.4f} {:<12.4f} {:<12.4f} {:<10}".format(
            class_name,
            metrics['precision'],
            metrics['recall'],
            metrics['f1-score'],
            int(metrics['support'])
        )
        print(row_str)

    print("-" * 60)


def get_full_metrics(model, data_loader, device, class_names):

    # Get frame-level predictions
    frame_results = evaluate(model, data_loader, device, return_details=True)

    # Aggregate to subject-level
    subject_results = aggregate_subject_probas(
        y_proba=frame_results['probabilities'],
        y_true=frame_results['labels'],
        subject_ids=frame_results['subject_ids'],
        label_names=class_names
    )

    print("\n" + "="*60)
    print("FRAME-LEVEL METRICS")
    print("="*60)
    print("  Accuracy:  {:.4f}".format(frame_results['accuracy']))
    print("  F1 Score:  {:.4f}".format(frame_results['f1']))
    print("  Precision: {:.4f}".format(frame_results['precision']))
    print("  Recall:    {:.4f}".format(frame_results['recall']))
    print("  AUC-ROC:   {:.4f}".format(frame_results['auc']))

    print("\n" + "="*60)
    print("SUBJECT-LEVEL METRICS (Primary)")
    print("="*60)
    subj_metrics = subject_results['metrics']
    print("  Accuracy:  {:.4f}".format(subj_metrics['accuracy']))
    print("  F1 Score:  {:.4f}".format(subj_metrics['weighted_f1']))
    print("  Precision: {:.4f}".format(subj_metrics['precision']))
    print("  Recall:    {:.4f}".format(subj_metrics['recall']))
    print("  AUC-ROC:   {:.4f}".format(subj_metrics.get('roc_auc_ovr', 0.0) or 0.0))
    print("  Total Subjects: {}".format(len(subject_results['subject_df'])))

    subj_df = subject_results['subject_df']
    print_confusion_matrix(subj_df['y_true'].tolist(), subj_df['y_pred'].tolist(), class_names)

    return {
        'accuracy': subj_metrics['accuracy'],
        'f1': subj_metrics['weighted_f1'],
        'precision': subj_metrics['precision'],
        'recall': subj_metrics['recall'],
        'auc': subj_metrics.get('roc_auc_ovr', 0.0) or 0.0,
        'predictions': subj_df['y_pred'].tolist(),
        'labels': subj_df['y_true'].tolist(),
        'frame_metrics': {
            'accuracy': frame_results['accuracy'],
            'f1': frame_results['f1'],
            'precision': frame_results['precision'],
            'recall': frame_results['recall'],
            'auc': frame_results['auc']
        }
    }

In [None]:
import os
from collections import Counter
import torch
import torch.nn as nn
from tqdm.notebook import tqdm


def setup_device():

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        device = torch.device("cpu")
        print("CUDA not available, using CPU")
    return device


def train_one_epoch(model, data_loader, criterion, optimizer, device):
   # Trains model epoch-by-epoch
    model.train()
    running_loss = 0.0

    for batch in tqdm(data_loader, leave=False):
        images, labels, subject_ids = batch
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    return running_loss / len(data_loader.dataset)


def train_model(
    model_name: str,
    train_loader,
    val_loader,
    test_loader,
    class_names,
    device,
    num_epochs: int,
    patience: int,
    learning_rate: float,
    label_smoothing: float,
    class_weights: torch.Tensor = None,
    seed: int = 42
):
    """Trains a model with early stopping and label smoothing."""
    print(f"\n{'='*70}")
    print(f"Training: {model_name.upper()} | LR: {learning_rate}")
    print(f"{'='*70}")

    torch.cuda.empty_cache()

    # Set CUDA random seed before model initialization
    torch.cuda.manual_seed(seed)

    model = get_model(model_name=model_name, pretrained=True)
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total Parameters: {total_params:,}")
    print(f"{'='*70}\n")

    if class_weights is not None:
        criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    best_val_f1 = 0.0
    best_state_dict = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_metrics = evaluate(model, val_loader, device, return_details=True)
        scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.4f} | Val F1: {val_metrics['f1']:.4f} | "
              f"Val AUC: {val_metrics['auc']:.4f}")

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            best_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_no_improve = 0
            print(f"  New best F1: {best_val_f1:.4f}")
        else:
            epochs_no_improve += 1
            print(f"  No improvement: {epochs_no_improve}/{patience}")

        if epochs_no_improve >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    if best_state_dict:
        model.load_state_dict(best_state_dict)
        model = model.to(device)

    print(f"\n{'='*70}")
    print(f"{model_name.upper()} | Test Results")
    print(f"{'='*70}")

    test_metrics = get_full_metrics(model, test_loader, device, class_names)

    print(f"\nBest Val F1: {best_val_f1:.4f}")
    print(f"{'='*70}\n")

    save_dir = "/content/drive/MyDrive/NeuroFace/src"
    os.makedirs(save_dir, exist_ok=True)
    lr_str = str(learning_rate).replace('.', '_')
    save_path = f"{save_dir}/{model_name}_lr{lr_str}_neuroface_best_model.pth"
    torch.save(model.state_dict(), save_path)
    print(f"Saved: {save_path}\n")

    del model
    torch.cuda.empty_cache()

    return {
        'accuracy': test_metrics['accuracy'],
        'f1': test_metrics['f1'],
        'precision': test_metrics['precision'],
        'recall': test_metrics['recall'],
        'auc': test_metrics['auc'],
        'best_val_f1': best_val_f1,
        'learning_rate': learning_rate,
        'predictions': test_metrics['predictions'],
        'labels': test_metrics['labels'],
        'frame_metrics': test_metrics.get('frame_metrics', {})
    }

In [None]:
import random
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seed for reproducibility
RANDOM_SEED = 42

def set_seed(seed):
    """Set random seed for reproducibility across all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to: {seed}")

set_seed(RANDOM_SEED)

# Hyperparameters
IMAGE_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 12
NUM_EPOCHS = 20
PATIENCE = 3
LEARNING_RATE = 0.001

MODELS_TO_TRAIN = [
    "efficientnet_b0",
    "efficientnet_b3",
    "efficientnet_b5",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "swin_t",
    "swin_s",
    "swin_b",
    "vit_b_16",
    "vit_l_16",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
    "resnet101",
    "resnet152",
]

device = setup_device()

train_loader, val_loader, test_loader, train_dataset, classes = get_dataloaders(
    data_root="/content/drive/MyDrive/NeuroFace/processed_frames",
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    seed=RANDOM_SEED
)

print("\nClasses: {}".format(classes))
print("Number of classes: {}".format(len(classes)))
print("\n" + "="*120)
print("TRAINING ALL MODELS - SUBJECT-LEVEL CLASSIFICATION")
print("="*120 + "\n")

results = {}

for model_name in MODELS_TO_TRAIN:
    try:
        results[model_name] = train_model(
            model_name=model_name,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            class_names=classes,
            device=device,
            num_epochs=NUM_EPOCHS,
            patience=PATIENCE,
            learning_rate=LEARNING_RATE,
            label_smoothing=0.0,
            class_weights=None,
            seed=RANDOM_SEED
        )
    except Exception as e:
        print("Error training {}: {}".format(model_name, e))
        import traceback
        traceback.print_exc()
        continue


# Create results dataframe with consistent column names
results_data = []
for model_name, metrics in results.items():
    results_data.append({
        'Model': model_name,
        'Accuracy': metrics['accuracy'],
        'Precision': metrics['precision'],
        'Recall': metrics['recall'],
        'F1': metrics['f1'],
        'AUC': metrics['auc']
    })

df_results = pd.DataFrame(results_data)
df_results = df_results.sort_values('F1', ascending=False).reset_index(drop=True)
df_results.index = df_results.index + 1

print("Table 1: Performance Comparison of Deep Learning Models (Subject-Level)")
print("-" * 100)
print(df_results.to_string(index=True, float_format=lambda x: '{:.4f}'.format(x)))
print("-" * 100)
print("\nNote: All models trained with learning rate = {}, batch size = {}, early stopping patience = {}".format(
    LEARNING_RATE, BATCH_SIZE, PATIENCE))
print("Dataset split: Preprocessed train/val/test with subject-level separation")
print("Random seed: {}".format(RANDOM_SEED))
print("Metrics computed at SUBJECT-LEVEL (frame predictions aggregated by mean probability)\n")

# Save results to CSV
csv_path = "/content/drive/MyDrive/NeuroFace/src/model_comparison_results_subject_level.csv"
df_results.to_csv(csv_path, index=True, float_format='%.4f')
print("Results saved to: {}\n".format(csv_path))

# Print best model
best_model = df_results.iloc[0]['Model']
print("="*120)
print("BEST PERFORMING MODEL: {}".format(best_model.upper()))
print("="*120)
print("  Accuracy:  {:.4f}".format(df_results.iloc[0]['Accuracy']))
print("  Precision: {:.4f}".format(df_results.iloc[0]['Precision']))
print("  Recall:    {:.4f}".format(df_results.iloc[0]['Recall']))
print("  F1 Score:  {:.4f}".format(df_results.iloc[0]['F1']))
print("  AUC-ROC:   {:.4f}".format(df_results.iloc[0]['AUC']))
print("="*120 + "\n")

# Generate publication-ready confusion matrices for all models
print("\n" + "="*120)
print("CONFUSION MATRICES FOR ALL MODELS (Publication Format - Subject-Level)")
print("="*120 + "\n")

fig_dir = "/content/drive/MyDrive/NeuroFace/src/confusion_matrices_subject_level"
os.makedirs(fig_dir, exist_ok=True)

for idx, row in df_results.iterrows():
    model_name = row['Model']
    print("\n" + "="*100)
    print("MODEL: {} (Rank #{})".format(model_name.upper(), idx))
    print("Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f} | F1: {:.4f} | AUC: {:.4f}".format(
        row['Accuracy'], row['Precision'], row['Recall'], row['F1'], row['AUC']))
    print("="*100)

    # Get confusion matrix data
    cm = confusion_matrix(results[model_name]['labels'], results[model_name]['predictions'])

    # Print text version
    print("\nConfusion Matrix (Subject-Level):")
    print("-" * 60)
    max_name_len = max(len(name) for name in classes)
    header_padding = " " * (max_name_len + 2)
    header = header_padding + "Predicted"
    print(header)
    header_labels = header_padding + " ".join(["{:>10}".format(name) for name in classes])
    print(header_labels)
    print("-" * 60)

    for i, row_data in enumerate(cm):
        row_str = " ".join(["{:>10}".format(val) for val in row_data])
        label_str = "{:>{}}".format(classes[i], max_name_len)
        print(label_str + "  " + row_str)

    print("-" * 60)

    # Print frame-level metrics if available
    if 'frame_metrics' in results[model_name] and results[model_name]['frame_metrics']:
        fm = results[model_name]['frame_metrics']
        print("\nFrame-Level Metrics (for reference):")
        print("  Accuracy:  {:.4f}".format(fm.get('accuracy', 0.0)))
        print("  F1 Score:  {:.4f}".format(fm.get('f1', 0.0)))
        print("  Precision: {:.4f}".format(fm.get('precision', 0.0)))
        print("  Recall:    {:.4f}".format(fm.get('recall', 0.0)))
        print("  AUC-ROC:   {:.4f}".format(fm.get('auc', 0.0)))

    # Create visual confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes,
                cbar_kws={'label': 'Subject Count'})
    plt.title('Confusion Matrix (Subject-Level): {}\nF1-Score: {:.4f}'.format(model_name.upper(), row['F1']),
              fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    fig_path = "{}/{}_confusion_matrix_subject_level.png".format(fig_dir, model_name)
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print("\nConfusion matrix figure saved: {}".format(fig_path))
    plt.close()


print("Figures saved to: {}".format(fig_dir))


# Create summary statistics table
print("\n" + "="*120)
print("SUMMARY STATISTICS")
print("="*120 + "\n")

summary_stats = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1', 'AUC'],
    'Mean': [
        df_results['Accuracy'].mean(),
        df_results['Precision'].mean(),
        df_results['Recall'].mean(),
        df_results['F1'].mean(),
        df_results['AUC'].mean()
    ],
    'Std': [
        df_results['Accuracy'].std(),
        df_results['Precision'].std(),
        df_results['Recall'].std(),
        df_results['F1'].std(),
        df_results['AUC'].std()
    ],
    'Min': [
        df_results['Accuracy'].min(),
        df_results['Precision'].min(),
        df_results['Recall'].min(),
        df_results['F1'].min(),
        df_results['AUC'].min()
    ],
    'Max': [
        df_results['Accuracy'].max(),
        df_results['Precision'].max(),
        df_results['Recall'].max(),
        df_results['F1'].max(),
        df_results['AUC'].max()
    ]
})

print("Table 2: Summary Statistics Across All Models (n={})".format(len(df_results)))
print("-" * 80)
print(summary_stats.to_string(index=False, float_format=lambda x: '{:.4f}'.format(x)))
print("-" * 80)

# Save summary to CSV
summary_path = "/content/drive/MyDrive/NeuroFace/src/summary_statistics_subject_level.csv"
summary_stats.to_csv(summary_path, index=False, float_format='%.4f')
print("\nSummary statistics saved to: {}\n".format(summary_path))

# Create LaTeX table
print("\n" + "="*120)
print("LATEX TABLE (For Direct Copy-Paste into Publications)")
print("="*120 + "\n")

# Rename columns for LaTeX output
df_latex = df_results.copy()
df_latex.columns = ['Model', 'Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC-ROC']

latex_table = df_latex.to_latex(
    index=True,
    float_format="%.4f",
    caption="Performance comparison of deep learning models for subject-level classification (Random Seed: {}).".format(RANDOM_SEED),
    label="tab:model_comparison_subject",
    column_format='l|lrrrrr'
)
print(latex_table)

latex_path = "/content/drive/MyDrive/NeuroFace/src/results_table_subject_level.tex"
with open(latex_path, 'w') as f:
    f.write(latex_table)
print("LaTeX table saved to: {}\n".format(latex_path))

print("Random Seed: {}".format(RANDOM_SEED))


Random seed set to: 42
Using GPU: Tesla T4
Memory: 15.83 GB




  Train/Val overlap: 253 subjects
  Train/Test overlap: 221 subjects
  Val/Test overlap: 107 subjects

Dataset Statistics:
  Train: 2574 frames, 714 subjects
  Val:   387 frames, 272 subjects
  Test:  345 frames, 267 subjects

Classes: ['ALS', 'HC', 'PS']
Number of classes: 3

TRAINING ALL MODELS - SUBJECT-LEVEL CLASSIFICATION


Training: EFFICIENTNET_B0 | LR: 0.001
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 173MB/s]


Total Parameters: 4,011,391



  0%|          | 0/161 [00:00<?, ?it/s]

Epoch 1/20 | Loss: 0.1467 | Val Acc: 0.6047 | Val F1: 0.5482 | Val AUC: 0.8729
  New best F1: 0.5482


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 0.0516 | Val Acc: 0.7080 | Val F1: 0.6972 | Val AUC: 0.9167
  New best F1: 0.6972


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 0.0547 | Val Acc: 0.4134 | Val F1: 0.4268 | Val AUC: 0.6061
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 0.0126 | Val Acc: 0.4910 | Val F1: 0.5121 | Val AUC: 0.6783
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 5/20 | Loss: 0.0032 | Val Acc: 0.5168 | Val F1: 0.5151 | Val AUC: 0.7290
  No improvement: 3/3

Early stopping at epoch 5

EFFICIENTNET_B0 | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5072
  F1 Score:  0.5263
  Precision: 0.6851
  Recall:    0.5072
  AUC-ROC:   0.6383

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.5323
  Precision: 0.6567
  Recall:    0.5019
  AUC-ROC:   0.6823
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS          31         62          1
 HC           8         30          1
 PS          47         14         73
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.3605       0.3298       0.3444       94        
HC              0.2830       0.7692       0.4138       39        
PS             

100%|██████████| 47.2M/47.2M [00:00<00:00, 116MB/s]


Total Parameters: 10,700,843



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 0.1550 | Val Acc: 0.6124 | Val F1: 0.5676 | Val AUC: 0.8207
  New best F1: 0.5676


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 0.0662 | Val Acc: 0.8450 | Val F1: 0.8349 | Val AUC: 0.9278
  New best F1: 0.8349


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 0.0161 | Val Acc: 0.6176 | Val F1: 0.5914 | Val AUC: 0.7971
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 0.0043 | Val Acc: 0.7829 | Val F1: 0.7818 | Val AUC: 0.9283
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 5/20 | Loss: 0.0018 | Val Acc: 0.7313 | Val F1: 0.7274 | Val AUC: 0.9186
  No improvement: 3/3

Early stopping at epoch 5

EFFICIENTNET_B3 | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.4754
  F1 Score:  0.4868
  Precision: 0.6804
  Recall:    0.4754
  AUC-ROC:   0.7069

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.4682
  F1 Score:  0.4611
  Precision: 0.6258
  Recall:    0.4682
  AUC-ROC:   0.7434
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           3         88          3
 HC           5         30          4
 PS           0         42         92
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.3750       0.0319       0.0588       94        
HC              0.1875       0.7692       0.3015       39        
PS             

100%|██████████| 117M/117M [00:00<00:00, 202MB/s]


Total Parameters: 28,346,931



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 0.1957 | Val Acc: 0.3669 | Val F1: 0.2843 | Val AUC: 0.6591
  New best F1: 0.2843


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 0.0484 | Val Acc: 0.4599 | Val F1: 0.2898 | Val AUC: 0.6157
  New best F1: 0.2898


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 0.0962 | Val Acc: 0.8346 | Val F1: 0.8215 | Val AUC: 0.9476
  New best F1: 0.8215


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 0.0241 | Val Acc: 0.7080 | Val F1: 0.6120 | Val AUC: 0.8164
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 5/20 | Loss: 0.0063 | Val Acc: 0.6873 | Val F1: 0.5930 | Val AUC: 0.8142
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 6/20 | Loss: 0.0046 | Val Acc: 0.6848 | Val F1: 0.5911 | Val AUC: 0.8081
  No improvement: 3/3

Early stopping at epoch 6

EFFICIENTNET_B5 | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.7101
  F1 Score:  0.7027
  Precision: 0.6981
  Recall:    0.7101
  AUC-ROC:   0.8749

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.7154
  F1 Score:  0.7002
  Precision: 0.7109
  Recall:    0.7154
  AUC-ROC:   0.8852
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS          57         24         13
 HC           4          7         28
 PS           6          1        127
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.8507       0.6064       0.7081       94        
HC              0.2188       0.1795       0.1972       39        
PS             

100%|██████████| 109M/109M [00:00<00:00, 160MB/s]


Total Parameters: 27,822,435



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1049 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3452
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0886 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3404
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0886 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3928
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0836 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3816
  No improvement: 3/3

Early stopping at epoch 4

CONVNEXT_TINY | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5391
  F1 Score:  0.3777
  Precision: 0.2907
  Recall:    0.5391
  AUC-ROC:   0.6418

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.3354
  Precision: 0.2519
  Recall:    0.5019
  AUC-ROC:   0.6095
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0          0         94
 HC           0          0         39
 PS           0          0        134
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.0000       0.0000       0.0000       39        
PS             

100%|██████████| 192M/192M [00:01<00:00, 163MB/s]


Total Parameters: 49,456,995



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1214 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.6651
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0999 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3992
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0868 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.4685
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0829 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3966
  No improvement: 3/3

Early stopping at epoch 4

CONVNEXT_SMALL | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5391
  F1 Score:  0.3777
  Precision: 0.2907
  Recall:    0.5391
  AUC-ROC:   0.9301

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.3354
  Precision: 0.2519
  Recall:    0.5019
  AUC-ROC:   0.9002
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0          0         94
 HC           0          0         39
 PS           0          0        134
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.0000       0.0000       0.0000       39        
PS             

100%|██████████| 338M/338M [00:02<00:00, 137MB/s]


Total Parameters: 87,569,539



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1106 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3383
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0864 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3384
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0872 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3376
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0820 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3376
  No improvement: 3/3

Early stopping at epoch 4

CONVNEXT_BASE | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5391
  F1 Score:  0.3777
  Precision: 0.2907
  Recall:    0.5391
  AUC-ROC:   0.6411

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.3354
  Precision: 0.2519
  Recall:    0.5019
  AUC-ROC:   0.6080
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0          0         94
 HC           0          0         39
 PS           0          0        134
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.0000       0.0000       0.0000       39        
PS             

100%|██████████| 108M/108M [00:00<00:00, 117MB/s]


Total Parameters: 27,521,661



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1386 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.2129
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0510 | Val Acc: 0.2713 | Val F1: 0.1307 | Val AUC: 0.5164
  New best F1: 0.1307


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0842 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.7933
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0686 | Val Acc: 0.4367 | Val F1: 0.3764 | Val AUC: 0.7633
  New best F1: 0.3764


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 5/20 | Loss: 1.0542 | Val Acc: 0.4393 | Val F1: 0.3795 | Val AUC: 0.7530
  New best F1: 0.3795


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 6/20 | Loss: 1.0468 | Val Acc: 0.4083 | Val F1: 0.3408 | Val AUC: 0.7223
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 7/20 | Loss: 1.0561 | Val Acc: 0.4160 | Val F1: 0.3508 | Val AUC: 0.7320
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 8/20 | Loss: 1.0479 | Val Acc: 0.4083 | Val F1: 0.3408 | Val AUC: 0.7391
  No improvement: 3/3

Early stopping at epoch 8

SWIN_T | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.6609
  F1 Score:  0.5915
  Precision: 0.5643
  Recall:    0.6609
  AUC-ROC:   0.7508

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5693
  F1 Score:  0.4765
  Precision: 0.4169
  Recall:    0.5693
  AUC-ROC:   0.6889
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0         71         23
 HC           0         22         17
 PS           0          4        130
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.2268       0.5641       0.3235       39        
PS             

100%|██████████| 190M/190M [00:01<00:00, 125MB/s]


Total Parameters: 48,839,565



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1265 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.3181
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0922 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.2241
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0915 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.0873
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0834 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.2433
  No improvement: 3/3

Early stopping at epoch 4

SWIN_S | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5391
  F1 Score:  0.3777
  Precision: 0.2907
  Recall:    0.5391
  AUC-ROC:   0.1241

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.3354
  Precision: 0.2519
  Recall:    0.5019
  AUC-ROC:   0.1412
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0          0         94
 HC           0          0         39
 PS           0          0        134
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.0000       0.0000       0.0000       39        
PS             

100%|██████████| 335M/335M [00:05<00:00, 67.8MB/s]


Total Parameters: 86,746,299



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1232 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.6061
  New best F1: 0.1158


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 1.0919 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.5992
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 1.0887 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.4304
  No improvement: 2/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 1.0831 | Val Acc: 0.2713 | Val F1: 0.1158 | Val AUC: 0.5266
  No improvement: 3/3

Early stopping at epoch 4

SWIN_B | Test Results





FRAME-LEVEL METRICS
  Accuracy:  0.5391
  F1 Score:  0.3777
  Precision: 0.2907
  Recall:    0.5391
  AUC-ROC:   0.9408

SUBJECT-LEVEL METRICS (Primary)
  Accuracy:  0.5019
  F1 Score:  0.3354
  Precision: 0.2519
  Recall:    0.5019
  AUC-ROC:   0.9304
  Total Subjects: 267

Confusion Matrix:
------------------------------------------------------------
     Predicted
            ALS         HC         PS
------------------------------------------------------------
ALS           0          0         94
 HC           0          0         39
 PS           0          0        134
------------------------------------------------------------

Per-Class Metrics:
------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
------------------------------------------------------------
ALS             0.0000       0.0000       0.0000       94        
HC              0.0000       0.0000       0.0000       39        
PS             

100%|██████████| 330M/330M [00:05<00:00, 64.8MB/s]


Total Parameters: 85,800,963



  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 1/20 | Loss: 1.1010 | Val Acc: 0.2817 | Val F1: 0.2407 | Val AUC: 0.5493
  New best F1: 0.2407


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 2/20 | Loss: 0.7952 | Val Acc: 0.5814 | Val F1: 0.4908 | Val AUC: 0.8726
  New best F1: 0.4908


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 3/20 | Loss: 0.6765 | Val Acc: 0.3514 | Val F1: 0.2657 | Val AUC: 0.5107
  No improvement: 1/3


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 4/20 | Loss: 0.5953 | Val Acc: 0.4858 | Val F1: 0.4935 | Val AUC: 0.6977
  New best F1: 0.4935


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 5/20 | Loss: 0.4995 | Val Acc: 0.6615 | Val F1: 0.6533 | Val AUC: 0.7375
  New best F1: 0.6533


  0%|          | 0/161 [00:00<?, ?it/s]



Epoch 6/20 | Loss: 0.4015 | Val Acc: 0.7287 | Val F1: 0.7209 | Val AUC: 0.7654
  New best F1: 0.7209


  0%|          | 0/161 [00:00<?, ?it/s]



In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import os


def score_frames_by_confidence(model, data_loader, device):
    """Score all frames by prediction confidence."""
    model.eval()
    frame_scores = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Scoring frames"):
            images, labels, subject_ids = batch
            images = images.to(device, non_blocking=True)

            outputs = model(images)
            probs = F.softmax(outputs, dim=1)
            max_probs, predictions = torch.max(probs, dim=1)
            entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)

            for i in range(len(labels)):
                frame_scores.append({
                    'subject_id': subject_ids[i],
                    'true_label': labels[i].item(),
                    'predicted_label': predictions[i].item(),
                    'confidence': max_probs[i].item(),
                    'entropy': entropy[i].item(),
                    'correct': predictions[i].item() == labels[i].item()
                })

    return pd.DataFrame(frame_scores)


def filter_frames_by_strategy(df_scores, strategy='temporal_consistent', k=0.7, min_frames=5):
    """Filter frames based on strategy."""
    kept_indices = set()

    if strategy == 'top_k_confident':
        for subject_id, group in df_scores.groupby('subject_id'):
            n_keep = max(min_frames, int(len(group) * k))
            top_frames = group.nlargest(n_keep, 'confidence')
            for idx in top_frames.index:
                kept_indices.add(idx)

    elif strategy == 'temporal_consistent':
        for subject_id, group in df_scores.groupby('subject_id'):
            majority_pred = group['predicted_label'].mode()[0]
            consistent_frames = group[group['predicted_label'] == majority_pred]

            if len(consistent_frames) < min_frames:
                consistent_frames = group.nlargest(min_frames, 'confidence')

            n_keep = max(min_frames, int(len(consistent_frames) * k))
            top_consistent = consistent_frames.nlargest(n_keep, 'confidence')

            for idx in top_consistent.index:
                kept_indices.add(idx)

    return kept_indices


class FilteredDataset(torch.utils.data.Dataset):
    """Dataset that filters frames based on indices."""
    def __init__(self, base_dataset, indices_to_keep):
        self.base_dataset = base_dataset
        self.indices_to_keep = sorted(list(indices_to_keep))
        self.classes = base_dataset.classes
        self.class_to_idx = base_dataset.class_to_idx

    def __len__(self):
        return len(self.indices_to_keep)

    def __getitem__(self, idx):
        original_idx = self.indices_to_keep[idx]
        return self.base_dataset[original_idx]


# ============================================================================
# MAIN EXECUTION
# ============================================================================

print("="*120)
print("FRAME FILTERING AND SELECTIVE MODEL TRAINING")
print("="*120 + "\n")

# Configuration
FILTER_MODEL = 'resnet101'
FILTER_STRATEGY = 'temporal_consistent'
FILTER_K = 0.7
MIN_FRAMES_PER_SUBJECT = 5

MODELS_TO_TRAIN_FILTERED = [
    "resnet101",
    "mobilenet_v3_small",
    "efficientnet_b5"
]

# Use EXISTING dataloaders from Block 5 (don't recreate!)
print("Using existing dataloaders from memory...")

# If dataloaders not in memory, create them WITHOUT subject check
if 'train_loader' not in globals():
    print("Creating new dataloaders (fast version)...")

    IMAGE_SIZE = 224
    BATCH_SIZE = 16
    NUM_WORKERS = 12
    RANDOM_SEED = 42

    train_transform = get_train_transforms(IMAGE_SIZE)
    eval_transform = get_eval_transforms(IMAGE_SIZE)

    train_dir = "/content/drive/MyDrive/NeuroFace/processed_frames/train"
    val_dir = "/content/drive/MyDrive/NeuroFace/processed_frames/val"
    test_dir = "/content/drive/MyDrive/NeuroFace/processed_frames/test"

    train_dataset = PreprocessedImageFolder(train_dir, transform=train_transform)
    val_dataset = PreprocessedImageFolder(val_dir, transform=eval_transform)
    test_dataset = PreprocessedImageFolder(test_dir, transform=eval_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
    )
    classes = train_dataset.classes

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

print("\n" + "="*80)
print(f"FRAME FILTERING PIPELINE")
print(f"Filter Model: {FILTER_MODEL}")
print(f"Strategy: {FILTER_STRATEGY}")
print(f"Keep: {FILTER_K*100:.0f}% of frames")
print("="*80 + "\n")

# Load filter model
print(f"Loading {FILTER_MODEL}...")
torch.cuda.manual_seed(RANDOM_SEED)
filter_model = get_model(model_name=FILTER_MODEL, pretrained=True)

save_dir = "/content/drive/MyDrive/NeuroFace/src"
model_path = f"{save_dir}/{FILTER_MODEL}_lr0_001_neuroface_best_model.pth"

if os.path.exists(model_path):
    print(f"Loading trained weights: {model_path}")
    filter_model.load_state_dict(torch.load(model_path))
else:
    print(f"Warning: Trained model not found, using ImageNet weights")

filter_model = filter_model.to(device)

# Score frames
print("\nScoring frames...")
train_scores = score_frames_by_confidence(filter_model, train_loader, device)
val_scores = score_frames_by_confidence(filter_model, val_loader, device)
test_scores = score_frames_by_confidence(filter_model, test_loader, device)

del filter_model
torch.cuda.empty_cache()

# Filter frames
print(f"\nApplying '{FILTER_STRATEGY}' filtering...")
train_keep = filter_frames_by_strategy(train_scores, FILTER_STRATEGY, FILTER_K, MIN_FRAMES_PER_SUBJECT)
val_keep = filter_frames_by_strategy(val_scores, FILTER_STRATEGY, FILTER_K, MIN_FRAMES_PER_SUBJECT)
test_keep = filter_frames_by_strategy(test_scores, FILTER_STRATEGY, FILTER_K, MIN_FRAMES_PER_SUBJECT)

# Create filtered datasets
filtered_train = FilteredDataset(train_loader.dataset, train_keep)
filtered_val = FilteredDataset(val_loader.dataset, val_keep)
filtered_test = FilteredDataset(test_loader.dataset, test_keep)

filtered_train_loader = torch.utils.data.DataLoader(
    filtered_train, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
)
filtered_val_loader = torch.utils.data.DataLoader(
    filtered_val, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
)
filtered_test_loader = torch.utils.data.DataLoader(
    filtered_test, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_with_subject_id
)

# Print statistics
print("\n" + "="*80)
print("FILTERING STATISTICS")
print("="*80)
print(f"Train: {len(train_loader.dataset)} → {len(filtered_train)} ({len(filtered_train)/len(train_loader.dataset):.1%})")
print(f"Val:   {len(val_loader.dataset)} → {len(filtered_val)} ({len(filtered_val)/len(val_loader.dataset):.1%})")
print(f"Test:  {len(test_loader.dataset)} → {len(filtered_test)} ({len(filtered_test)/len(test_loader.dataset):.1%})")
print("="*80 + "\n")

# Train models on filtered data
print("="*120)
print("TRAINING MODELS ON FILTERED DATA")
print("="*120 + "\n")

filtered_results = {}

for model_name in MODELS_TO_TRAIN_FILTERED:
    try:
        print(f"\n{'='*100}")
        print(f"Training {model_name.upper()} on filtered dataset")
        print(f"{'='*100}\n")

        filtered_results[model_name] = train_model(
            model_name=model_name,
            train_loader=filtered_train_loader,
            val_loader=filtered_val_loader,
            test_loader=filtered_test_loader,
            class_names=classes,
            device=device,
            num_epochs=NUM_EPOCHS,
            patience=PATIENCE,
            learning_rate=LEARNING_RATE,
            label_smoothing=0.0,
            class_weights=None,
            seed=RANDOM_SEED
        )

    except Exception as e:
        print(f"Error training {model_name}: {e}")
        import traceback
        traceback.print_exc()
        continue

# Compare results if available
print("\n" + "="*120)
print("RESULTS SUMMARY")
print("="*120 + "\n")

if 'results' in globals():
    print("Comparison: Original vs Filtered\n")
    for model_name in MODELS_TO_TRAIN_FILTERED:
        if model_name in results and model_name in filtered_results:
            orig = results[model_name]
            filt = filtered_results[model_name]
            print(f"{model_name.upper()}:")
            print(f"  F1:  {orig['f1']:.4f} → {filt['f1']:.4f} ({filt['f1']-orig['f1']:+.4f})")
            print(f"  Acc: {orig['accuracy']:.4f} → {filt['accuracy']:.4f} ({filt['accuracy']-orig['accuracy']:+.4f})")
            print(f"  AUC: {orig['auc']:.4f} → {filt['auc']:.4f} ({filt['auc']-orig['auc']:+.4f})\n")
else:
    print("Filtered Results Only:\n")
    for model_name, metrics in filtered_results.items():
        print(f"{model_name.upper()}:")
        print(f"  F1:  {metrics['f1']:.4f}")
        print(f"  Acc: {metrics['accuracy']:.4f}")
        print(f"  AUC: {metrics['auc']:.4f}\n")

print("="*120)
print(f"Strategy: {FILTER_STRATEGY} | Retention: {len(filtered_train)/len(train_loader.dataset):.1%}")
print("="*120)