# 03 — Train and Evaluate the Multi-Task System

Two-phase training:
- **Phase 1**: MRI encoder pretraining (ordinal CDR only, cross-sectional)
- **Phase 2**: Full multi-task (ordinal + survival + alignment) with longitudinal data

Supports checkpoint-resume across Colab sessions: the `fit(...)` calls below take a `resume` flag, which is set to `False` in this notebook.

In [None]:
# Mount Google Drive at /content/drive so the project directory below is reachable.
from google.colab import drive
drive.mount('/content/drive')

import os, sys
# Project location on the mounted drive.
PROJECT_DIR = '/content/drive/MyDrive/alzheimer-research'
# Insert at position 0 so PROJECT_DIR is searched first when importing;
# presumably the project-local packages imported below live there — confirm.
sys.path.insert(0, PROJECT_DIR)

!pip install -q nibabel

import torch
import numpy as np
from torch.utils.data import DataLoader

from config import Config
from models import AlzheimerMultiTaskModel
from data.nacc_dataset import NACCMRIDataset, NACCLongitudinalDataset
from data.speech_dataset import SpeechEmbeddingDataset, SyntheticSpeechDataset
from data.preprocessing import build_mri_augmentation
from training import Phase1Trainer, Phase2Trainer

# Shared configuration object read by every cell below (paths, batch sizes, shapes).
cfg = Config()
# NOTE(review): presumably creates the directories the config points at — confirm in config.py.
cfg.ensure_dirs()

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name()}')
    # The device-properties field is `total_memory` (bytes); `total_mem` does
    # not exist and raised AttributeError on GPU runtimes.
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

## 1. Load Data

Set `USE_REAL_DATA = True` once the manifest CSV is ready.

In [None]:
# Master switch for the data-dependent cells; they only print a skip message
# while this is False.
USE_REAL_DATA = False  # flip to True when data is ready

# Both input files are looked up inside the configured embedding directory.
_embedding_dir = cfg.embedding_dir
MANIFEST_CSV = str(_embedding_dir / 'nacc_mri_manifest.csv')
SPEECH_NPZ = str(_embedding_dir / 'speech_features.npz')

## 2. Build Model

In [None]:
# Build the multi-task model from the shared configuration.
model = AlzheimerMultiTaskModel.from_config(cfg)

# Tally element counts in a single pass over the model's parameters.
total_params = 0
trainable = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable += n_elements

print(f'Total parameters:     {total_params:,}')
print(f'Trainable parameters: {trainable:,}')
# Two bytes per parameter, reported in megabytes.
print(f'Model size (fp16):    {total_params * 2 / 1e6:.1f} MB')

## 3. Phase 1 — MRI Ordinal Pretraining

In [None]:
if not USE_REAL_DATA:
    print('Skipping Phase 1 — set USE_REAL_DATA = True when manifest is ready')
else:
    # Subject-ID splits stored next to the manifest. allow_pickle=True
    # unpickles object arrays, so only load split files you produced yourself.
    train_ids = np.load(cfg.embedding_dir / 'train_subject_ids.npy', allow_pickle=True)
    val_ids = np.load(cfg.embedding_dir / 'val_subject_ids.npy', allow_pickle=True)

    # Augmentation is passed to the training dataset only.
    aug = build_mri_augmentation(cfg)
    train_ds = NACCMRIDataset(
        MANIFEST_CSV, cfg.mri_volume_shape, transform=aug, subject_ids=train_ids,
    )
    val_ds = NACCMRIDataset(
        MANIFEST_CSV, cfg.mri_volume_shape, subject_ids=val_ids,
    )

    # Loader settings shared by both splits; only `shuffle` differs.
    phase1_loader_kwargs = dict(
        batch_size=cfg.phase1_batch_size,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **phase1_loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **phase1_loader_kwargs)

    p1_trainer = Phase1Trainer(model, train_loader, val_loader, cfg, DEVICE)
    p1_results = p1_trainer.fit(resume=False)
    print(f'Phase 1 best QWK: {p1_results["best_qwk"]:.4f}')

## 4. Phase 2 — Multi-Task + Longitudinal + Alignment

In [None]:
if not USE_REAL_DATA:
    print('Skipping Phase 2 — set USE_REAL_DATA = True when data is ready')
else:
    from pathlib import Path

    # Longitudinal datasets reuse `aug`, `train_ids` and `val_ids` from the
    # Phase 1 cell, so that cell must have run in this session.
    mri_train_ds = NACCLongitudinalDataset(
        MANIFEST_CSV, cfg.mri_volume_shape, transform=aug, subject_ids=train_ids,
    )
    mri_val_ds = NACCLongitudinalDataset(
        MANIFEST_CSV, cfg.mri_volume_shape, subject_ids=val_ids,
    )

    # Loader settings shared by both MRI splits; only `shuffle` differs.
    mri_loader_kwargs = dict(
        batch_size=cfg.phase2_mri_batch_size,
        num_workers=cfg.num_workers,
        pin_memory=cfg.pin_memory,
    )
    mri_train_loader = DataLoader(mri_train_ds, shuffle=True, **mri_loader_kwargs)
    mri_val_loader = DataLoader(mri_val_ds, shuffle=False, **mri_loader_kwargs)

    # Speech (optional): both loaders stay None unless the NPZ file exists,
    # and no speech validation loader is ever built here.
    speech_train_loader = None
    speech_val_loader = None
    if Path(SPEECH_NPZ).exists():
        speech_ds = SpeechEmbeddingDataset(SPEECH_NPZ)
        speech_train_loader = DataLoader(
            speech_ds,
            batch_size=cfg.phase2_speech_batch_size,
            shuffle=True,
            drop_last=True,
        )

    p2_trainer = Phase2Trainer(
        model,
        mri_train_loader,
        mri_val_loader,
        speech_train_loader,
        speech_val_loader,
        cfg,
        DEVICE,
    )
    p2_results = p2_trainer.fit(resume=False)
    print(f'Phase 2 best QWK: {p2_results["best_qwk"]:.4f}')

## 5. Evaluation

In [None]:
if USE_REAL_DATA:
    from training.callbacks import CheckpointManager
    from evaluation.metrics import compute_all_ordinal_metrics, optimize_thresholds
    from evaluation.visualization import (
        plot_learning_curves, plot_confusion_matrix, plot_calibration,
    )
    from sklearn.metrics import confusion_matrix as cm_func

    # Load the best Phase 2 checkpoint from disk and switch to inference mode.
    ckpt_mgr = CheckpointManager(cfg.checkpoint_dir, prefix='phase2')
    ckpt = ckpt_mgr.load(which='best')
    model.load_state_dict(ckpt['model_state'])
    model.to(DEVICE).eval()

    # Test set evaluation
    test_ids = np.load(cfg.embedding_dir / 'test_subject_ids.npy', allow_pickle=True)
    test_ds = NACCLongitudinalDataset(
        MANIFEST_CSV, cfg.mri_volume_shape, subject_ids=test_ids,
    )
    test_loader = DataLoader(test_ds, batch_size=cfg.phase2_mri_batch_size)

    all_preds, all_labels, all_severity = [], [], []
    with torch.no_grad():
        for batch in test_loader:
            out = model.forward_mri_multitask(
                batch['volumes'].to(DEVICE),
                batch['time_deltas'].to(DEVICE),
                batch['lengths'].to(DEVICE),
            )
            # Predicted class = number of cumulative logits above zero.
            preds = (out['ord_cum_logits'] > 0).sum(dim=1).cpu()
            all_preds.append(preds)
            all_labels.append(batch['label'])
            all_severity.append(out['ord_severity'].cpu())

    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    severity = torch.cat(all_severity).squeeze().numpy()
    if severity.ndim == 0:
        # squeeze() collapses a single-sample test set to a 0-d array; keep it
        # 1-D so it still lines up with `labels`.
        severity = severity[None]

    # Post-hoc threshold optimisation
    # NOTE(review): thresholds are tuned on the same test labels they are
    # scored on, so this QWK is optimistic; consider fitting them on the
    # validation split instead.
    best_thresholds, opt_qwk = optimize_thresholds(severity, labels)
    print(f'Optimised thresholds: {best_thresholds}')
    print(f'Post-hoc QWK: {opt_qwk:.4f}')

    # Full metrics
    metrics = compute_all_ordinal_metrics(labels, preds)
    for k, v in metrics.items():
        print(f'  {k}: {v:.4f}')

    # Plots
    if 'p2_results' in globals():
        plot_learning_curves(p2_results['history'], cfg.results_dir / 'learning_curves.png')
    else:
        # Evaluating from a saved checkpoint in a fresh session: no in-memory
        # training history exists, so skip the curves instead of raising NameError.
        print('No Phase 2 history in this session — skipping learning curves')

    class_names = list(cfg.class_names)
    # Fix the label set so the matrix always has one row/column per class name,
    # even when a class is missing from the test set.
    # Assumes classes are encoded as 0..len(class_names)-1 — TODO confirm.
    cm = cm_func(labels, preds, labels=list(range(len(class_names))))
    plot_confusion_matrix(cm, class_names, cfg.results_dir / 'confusion_matrix.png')
    print(f'Plots saved to {cfg.results_dir}')
else:
    print('Skipping evaluation — set USE_REAL_DATA = True when data is ready')

## 6. Sanity Check (Synthetic Forward Pass)

Verify the entire model runs end-to-end with random tensors.

In [None]:
print('=== Sanity check: forward pass with random data ===')
model_test = AlzheimerMultiTaskModel.from_config(cfg).to(DEVICE)

# Mixed precision only applies when requested in the config and running on CUDA.
amp_enabled = cfg.use_amp and DEVICE.type == 'cuda'

# Only output keys and shapes are inspected, so run without autograd: the
# outputs then hold no graph, and deleting the model below actually lets its
# memory be released.
with torch.no_grad():
    # Phase 1 forward
    fake_vol = torch.randn(2, 1, *cfg.mri_volume_shape, device=DEVICE)
    with torch.amp.autocast('cuda', enabled=amp_enabled):
        out1 = model_test.forward_phase1(fake_vol)
    print(f'Phase 1 output keys: {list(out1.keys())}')
    print(f'  severity shape: {out1["severity"].shape}')
    print(f'  cum_logits shape: {out1["cum_logits"].shape}')

    # Phase 2 forward (longitudinal): two sequences of up to three visits;
    # the second has length 2, so its third time delta is padding.
    fake_seq = torch.randn(2, 3, 1, *cfg.mri_volume_shape, device=DEVICE)
    fake_dt = torch.tensor([[0, 12, 24], [0, 6, 0]], dtype=torch.float32, device=DEVICE)
    fake_lengths = torch.tensor([3, 2], device=DEVICE)

    with torch.amp.autocast('cuda', enabled=amp_enabled):
        out2 = model_test.forward_mri_multitask(
            fake_seq, fake_dt, fake_lengths, run_survival=True,
        )
    print(f'\nPhase 2 output keys: {list(out2.keys())}')
    print(f'  embedding shape: {out2["embedding"].shape}')
    print(f'  survival probs shape: {out2["surv_survival"].shape}')

    # Speech forward
    fake_speech = torch.randn(4, cfg.speech_input_dim, device=DEVICE)
    with torch.amp.autocast('cuda', enabled=amp_enabled):
        out3 = model_test.forward_speech(fake_speech)
    print(f'\nSpeech output keys: {list(out3.keys())}')
    print(f'  embedding shape: {out3["embedding"].shape}')

# Drop the throwaway model and return cached GPU blocks to the driver.
del model_test
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()
print('\nSanity check passed!')