# Training & Evaluation — Multi-Modal Ordinal Alzheimer's Pipeline

Run the full training and evaluation pipeline on a Google Colab GPU runtime.

**Experiments covered:**
1. Unimodal MRI model
2. Unimodal Audio model *(planned — this notebook does not yet include a unimodal audio experiment cell)*
3. Multimodal fusion (concat, gated, attention)
4. CORAL ordinal vs cross-entropy comparison
5. Data fraction experiments (25%, 50%, 75%, 100%)
6. Hyperparameter sweeps
7. 5-fold cross-validation

---

**Prerequisites:** Run notebooks 01 and 02 first to extract embeddings.

In [None]:
# Colab setup: mount Google Drive (PROJECT_DIR below lives on it), fetch the
# code repository, and install its Python requirements.
from google.colab import drive
drive.mount('/content/drive')

import os
# Drive folder that later cells read embeddings from (data_embeddings/) and
# write experiment outputs to (experiment_results/).
PROJECT_DIR = '/content/drive/MyDrive/alzheimer-research'

# Local checkout of the code: clone it on the first run, otherwise pull the
# latest changes into the existing checkout.
# NOTE(review): replace YOUR_USERNAME with the real GitHub account before running.
REPO_DIR = '/content/alzheimer-research'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/YOUR_USERNAME/alzheimer-research.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Work from the repo root so the relative requirements.txt path resolves.
os.chdir(REPO_DIR)
!pip install -q -r requirements.txt

In [None]:
# Core imports, device selection/reporting, project config, and RNG seeding.
import sys
# Put the repo root first on sys.path so `config` and `experiments` import.
sys.path.insert(0, REPO_DIR)

import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from config import Config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    # The device-properties field is `total_memory` (in bytes); there is no
    # `total_mem` attribute, so the old spelling raised AttributeError.
    print(f'Memory: {gpu_props.total_memory / 1024**3:.1f} GB')

config = Config()
config.device = device
config.ensure_dirs()

# Set seeds for reproducibility
torch.manual_seed(config.seed)
np.random.seed(config.seed)

## 1. Load Embeddings

Load pre-extracted embeddings from Google Drive. If you haven't extracted them yet, keep `USE_SYNTHETIC = True` (the default in the cell below) to test the pipeline end-to-end on synthetic data.

In [None]:
# Choose between real pre-extracted embeddings on Drive and synthetic data.
# Always defines EMBED_DIM and N_SAMPLES, which later experiment cells use.
EMB_DIR = os.path.join(PROJECT_DIR, 'data_embeddings')
USE_SYNTHETIC = True  # Set to False once you have real embeddings

if not USE_SYNTHETIC:
    import pandas as pd

    mri_emb_path = os.path.join(EMB_DIR, 'mri_embeddings.npz')
    labels_path = os.path.join(EMB_DIR, 'labels.csv')

    # Raise explicitly rather than `assert`, which is skipped under `python -O`.
    if not os.path.exists(mri_emb_path):
        raise FileNotFoundError(f'MRI embeddings not found at {mri_emb_path}')
    if not os.path.exists(labels_path):
        raise FileNotFoundError(f'Labels not found at {labels_path}')

    # An .npz archive returned by np.load keeps its file open; the context
    # manager closes it once the array has been copied out by astype().
    with np.load(mri_emb_path) as mri_data:
        mri_embeddings = mri_data['embeddings'].astype(np.float32)
    labels_df = pd.read_csv(labels_path)
    labels = labels_df['label'].values

    # Later cells (loss comparison, data fraction, fusion, sweep) read these
    # two names unconditionally; defining them only in the synthetic branch
    # made real-data runs fail with NameError. Their synthetic datasets are
    # now sized to match the real embeddings.
    EMBED_DIM = mri_embeddings.shape[1]
    N_SAMPLES = len(labels)

    print(f'MRI embeddings: {mri_embeddings.shape}')
    print('Label distribution:')
    for name, idx in zip(config.class_names, range(config.num_classes)):
        print(f'  {name}: {(labels == idx).sum()}')
else:
    print('Using synthetic data for pipeline testing...')
    EMBED_DIM = 256
    N_SAMPLES = 2000

## 2. Experiment: Unimodal MRI Training

In [None]:
# Build the unimodal MRI dataset: synthetic for pipeline testing, otherwise
# the real embeddings/labels loaded from Drive.
from experiments.train_unimodal import SyntheticEmbeddingDataset, train_model, run_cross_validation
from experiments.evaluation import plot_learning_curves, plot_confusion_matrix

if USE_SYNTHETIC:
    # Seed explicitly, as the data-fraction, fusion and sweep cells do, rather
    # than relying on the constructor's default seed.
    mri_dataset = SyntheticEmbeddingDataset(
        n_samples=N_SAMPLES,
        embed_dim=EMBED_DIM,
        num_classes=config.num_classes,
        seed=config.seed,
    )
else:
    from experiments.train_unimodal import EmbeddingDataset
    mri_dataset = EmbeddingDataset(mri_emb_path, labels_path, modality='mri')

print(f'Dataset size: {len(mri_dataset)}')

In [None]:
# Seeded 80/20 train/val split of the MRI dataset, then a CORAL-loss run.
from torch.utils.data import Subset

n_total = len(mri_dataset)
n_train = int(0.8 * n_total)
indices = np.random.RandomState(config.seed).permutation(n_total)
train_idx, val_idx = indices[:n_train], indices[n_train:]
train_set, val_set = Subset(mri_dataset, train_idx), Subset(mri_dataset, val_idx)

print(f'Train: {len(train_set)}, Val: {len(val_set)}')

# Up to 50 epochs with an early-stopping patience of 10.
config.num_epochs, config.early_stopping_patience = 50, 10

mri_embed_dim = EMBED_DIM if USE_SYNTHETIC else mri_embeddings.shape[1]
model_mri, head_mri, metrics_mri, history_mri = train_model(
    train_set, val_set, config, embed_dim=mri_embed_dim, loss_type='coral',
)

print("\nBest MRI Results:")
for metric_label, metric_key in (
    ('Accuracy', 'accuracy'),
    ('QWK', 'qwk'),
    ('MAE', 'mae'),
    ('Off-by-1', 'off_by_1'),
):
    print(f"  {metric_label}: {metrics_mri[metric_key]:.4f}")

In [None]:
# Diagnostics for the MRI run above: learning curves, then the confusion
# matrix stored in its metrics.
plot_prefix = 'Unimodal MRI'

plot_learning_curves(history_mri, title=f'{plot_prefix} - Learning Curves')
plt.show()

plot_confusion_matrix(
    metrics_mri['confusion_matrix'], config.class_names,
    title=f'{plot_prefix} - Confusion Matrix',
)
plt.show()

## 3. Experiment: 5-Fold Cross-Validation

In [None]:
# 5-fold cross-validation of the CORAL MRI model over the whole dataset.
config.n_folds, config.num_epochs = 5, 50

cv_embed_dim = mri_embeddings.shape[1] if not USE_SYNTHETIC else EMBED_DIM
fold_metrics_mri, avg_metrics_mri = run_cross_validation(
    mri_dataset, config, embed_dim=cv_embed_dim, loss_type='coral',
)

## 4. Experiment: CORAL vs Cross-Entropy

In [None]:
# CORAL ordinal loss vs. plain cross-entropy.
# NOTE(review): no dataset is passed in, so the helper presumably builds its
# own (synthetic?) data even when USE_SYNTHETIC is False — confirm.
from experiments.evaluation import run_loss_comparison

loss_results = run_loss_comparison(
    config,
    embed_dim=EMBED_DIM,
)

## 5. Experiment: Data Fraction

In [None]:
# Data-fraction experiment on synthetic embeddings, followed by its plot.
from experiments.evaluation import run_data_fraction_experiment, plot_data_fraction_results

# Constructor kwargs for SyntheticEmbeddingDataset — presumably the helper
# instantiates the dataset class with these; confirm in experiments.evaluation.
frac_dataset_kwargs = dict(n_samples=N_SAMPLES, embed_dim=EMBED_DIM, seed=config.seed)

frac_results = run_data_fraction_experiment(
    config, SyntheticEmbeddingDataset, frac_dataset_kwargs, embed_dim=EMBED_DIM,
)

plot_data_fraction_results(frac_results)
plt.show()

## 6. Experiment: Multimodal Fusion

In [None]:
# Train one CORAL model per fusion strategy on synthetic multimodal data and
# collect each run's metrics in `fusion_results`, keyed by fusion type.
from torch.utils.data import Subset  # local import: don't depend on cell order

from experiments.train_multimodal import (
    SyntheticMultimodalDataset,
    train_model as train_multimodal,
    run_cross_validation as run_cv_multi,
)

multi_dataset = SyntheticMultimodalDataset(
    n_samples=N_SAMPLES, embed_dim=EMBED_DIM, seed=config.seed
)

# One seeded 80/20 split shared by every fusion variant. It does not depend
# on the loop variable, so build it once instead of once per fusion type.
n_total = len(multi_dataset)
n_train = int(0.8 * n_total)
indices = np.random.RandomState(config.seed).permutation(n_total)
train_set = Subset(multi_dataset, indices[:n_train])
val_set = Subset(multi_dataset, indices[n_train:])

fusion_results = {}

for fusion_type in ['concat', 'gated', 'attention']:
    print(f'\n{"="*50}')
    print(f'Fusion Type: {fusion_type}')
    print(f'{"="*50}')

    model_f, head_f, metrics_f, history_f = train_multimodal(
        train_set, val_set, config,
        embed_dim=EMBED_DIM, fusion_type=fusion_type, loss_type='coral',
    )

    fusion_results[fusion_type] = metrics_f

    print(f'  Accuracy: {metrics_f["accuracy"]:.4f}')
    print(f'  QWK: {metrics_f["qwk"]:.4f}')
    print(f'  MAE: {metrics_f["mae"]:.4f}')

In [None]:
# Side-by-side table of the unimodal MRI run and every multimodal fusion run.
table_metric_keys = ('accuracy', 'qwk', 'mae', 'off_by_1')
table_rows = [('Unimodal MRI', metrics_mri)]
table_rows += [(f'Multimodal ({ft})', m) for ft, m in fusion_results.items()]

print('\n' + '=' * 70)
print(f'{"Model":<25} {"Accuracy":>10} {"QWK":>10} {"MAE":>10} {"Off-by-1":>10}')
print('-' * 70)
for row_label, row_metrics in table_rows:
    metric_cells = ' '.join(f'{row_metrics[key]:>10.4f}' for key in table_metric_keys)
    print(f'{row_label:<25} {metric_cells}')
print('=' * 70)

## 7. Hyperparameter Sweep

In [None]:
# Grid over learning rate, synthetic embedding dim and dropout for the
# unimodal CORAL model; the best combination is picked by QWK.
# NOTE(review): `embed_dim` regenerates the synthetic dataset itself, so runs
# at different dims are scored on different data. Also confirm the unimodal
# `train_model` reads `fusion_dropout`; if it doesn't, that axis is a no-op.
import copy

from torch.utils.data import Subset  # local import: don't depend on cell order

# Apply overrides to a shallow copy so they don't leak into the shared
# `config`: mutating it directly left the last grid point (lr=1e-4,
# dropout=0.4, 30 epochs) in effect for any cell run afterwards.
sweep_config = copy.copy(config)
sweep_config.num_epochs = 30  # Shorter for sweep; same for every grid point

sweep_results = []

for lr in [1e-2, 1e-3, 1e-4]:
    for embed_dim in [128, 256]:
        for dropout in [0.2, 0.4]:
            sweep_config.learning_rate = lr
            sweep_config.fusion_dropout = dropout

            sweep_dataset = SyntheticEmbeddingDataset(
                n_samples=N_SAMPLES, embed_dim=embed_dim, seed=config.seed
            )
            n_total = len(sweep_dataset)
            n_train = int(0.8 * n_total)
            indices = np.random.RandomState(config.seed).permutation(n_total)
            train_set = Subset(sweep_dataset, indices[:n_train])
            val_set = Subset(sweep_dataset, indices[n_train:])

            _, _, metrics, _ = train_model(
                train_set, val_set, sweep_config,
                embed_dim=embed_dim, loss_type='coral', verbose=False,
            )

            result = {
                'lr': lr, 'embed_dim': embed_dim, 'dropout': dropout,
                'accuracy': metrics['accuracy'], 'qwk': metrics['qwk'],
                'mae': metrics['mae'],
            }
            sweep_results.append(result)
            print(f'lr={lr}, dim={embed_dim}, dropout={dropout} → '
                  f'Acc={metrics["accuracy"]:.4f}, QWK={metrics["qwk"]:.4f}')

# Find best config
best = max(sweep_results, key=lambda x: x['qwk'])
print(f'\nBest config: lr={best["lr"]}, dim={best["embed_dim"]}, '
      f'dropout={best["dropout"]} → QWK={best["qwk"]:.4f}')

## 8. Save Results

In [None]:
# Persist the sweep grid and the model comparison as CSVs on Google Drive.
import pandas as pd

RESULTS_DIR = os.path.join(PROJECT_DIR, 'experiment_results')
os.makedirs(RESULTS_DIR, exist_ok=True)

# One row per sweep grid point.
sweep_df = pd.DataFrame(sweep_results)
sweep_df.to_csv(os.path.join(RESULTS_DIR, 'hyperparameter_sweep.csv'), index=False)

# One row per model: its label followed by the four headline metrics.
summary_keys = ['accuracy', 'qwk', 'mae', 'off_by_1']
labelled_metrics = [('Unimodal MRI', metrics_mri)]
labelled_metrics += [(f'Multimodal ({ft})', m) for ft, m in fusion_results.items()]
comparison_rows = [
    {'model': label, **{key: scores[key] for key in summary_keys}}
    for label, scores in labelled_metrics
]
comp_df = pd.DataFrame(comparison_rows)
comp_df.to_csv(os.path.join(RESULTS_DIR, 'model_comparison.csv'), index=False)

print('Results saved to Google Drive:')
print(f'  {RESULTS_DIR}/hyperparameter_sweep.csv')
print(f'  {RESULTS_DIR}/model_comparison.csv')
print('\nDone!')