# 3. Train AD/CN Classification Model with Hyperparameter Tuning

This enhanced version includes systematic hyperparameter tuning using Optuna, similar to the Keras Tuner approach in the original implementation. The model architecture and key hyperparameters are optimized automatically.

### Inputs and Outputs

**Inputs:**
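- `ad_cn_train.pkl`, `ad_cn_val.pkl`: Pickled dictionaries with `"images"` and `"labels"` entries holding the preprocessed training and validation data (an optional `ad_cn_test.pkl` with the same layout is used for the final test-set evaluation).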

**Outputs:**
- `hyperparameter_study.pkl`: A joblib file containing the results of the Optuna hyperparameter search.
- `ad_cn_model_best_tuned.pth`: The PyTorch state dictionary for the best performing model.
- **W&B Artifacts:**
  - Logs of hyperparameters and metrics for each tuning trial (currently disabled: the per-trial `wandb` calls in the Optuna objective are commented out).
  - Logs for the final model training run, including loss/accuracy curves.
  - The final trained model file (`.pth`) saved as a W&B Artifact.
- A `reports/figures/augmentation_examples/` directory for data-augmentation visualizations (the directory is created during setup; the training cells do not currently write images to it).


In [None]:
%pip install optuna wandb monai

In [None]:
!pip install optuna-integration[pytorch_lightning]

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
from pathlib import Path
from scipy import ndimage
import random
from tqdm.notebook import tqdm
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import joblib
import wandb
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive inside the Colab runtime so that the data and output
# paths defined in the next cell are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

### Define Paths and Parameters

In [None]:
# Root of the project data on the mounted Drive -- replace the placeholder.
drive_base_path = Path("PATH_TO_DATA")
drive_save_path = drive_base_path / "model_outputs"

# Preprocessed AD/CN splits on the mounted Google Drive.
train_path = drive_base_path / "data" / "fdg" / "ad_cn_train.pkl"
val_path = drive_base_path / "data" / "fdg" / "ad_cn_val.pkl"

# Directory for the tuned model checkpoint and the pickled Optuna study.
model_save_path = drive_save_path / "saved_models"
# parents=True: `model_outputs/` itself may not exist yet on a fresh Drive,
# in which case a plain mkdir(exist_ok=True) raises FileNotFoundError.
model_save_path.mkdir(parents=True, exist_ok=True)

# Create visualization directory for augmentation examples
figures_path = Path("./reports/figures/augmentation_examples/")
figures_path.mkdir(parents=True, exist_ok=True)

# W&B Login
#wandb.login()

# is_available must be *called*: the bare function object is always truthy,
# which previously selected "cuda" even on CPU-only machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_TRIALS = 30  # Number of hyperparameter combinations to try
EPOCHS_PER_TRIAL = 35  # Maximum epochs for each trial (early stopping may end sooner)
BATCH_SIZE = 8

print(f"Training on {DEVICE}")
print(f"Will run {N_TRIALS} hyperparameter optimization trials")

### Dataset and Augmentation

In [None]:
from monai.transforms import (
    Compose,
    RandAffine,
)

class ADNIDataset(Dataset):
    """In-memory dataset of 3D volumes and binary labels read from a pickle file.

    The pickle must contain a dict with ``"images"`` and ``"labels"``
    sequences of equal length. The whole file is loaded eagerly in
    ``__init__``; items are converted to float tensors on access.

    NOTE(security): ``pickle.load`` can execute arbitrary code -- only point
    this at trusted, locally produced files.
    """

    def __init__(self, pkl_file):
        """Load all samples from ``pkl_file``.

        Raises:
            ValueError: If the number of images and labels differ.
        """
        with open(pkl_file, 'rb') as f:
            data_dict = pickle.load(f)
        self.images = data_dict["images"]
        self.labels = data_dict["labels"]
        self.num_samples = len(self.images)
        if len(self.labels) != self.num_samples:
            raise ValueError(
                f"images/labels length mismatch in {pkl_file}: "
                f"{self.num_samples} images vs {len(self.labels)} labels"
            )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, label)`` as float32 tensors.

        A 3D image ``(D, H, W)`` gets a leading channel axis, giving
        ``(1, D, H, W)``; images that already have 4 dims are returned as-is.
        """
        # torch.as_tensor is a no-op for tensors, and it also accepts numpy arrays
        # and Python/numpy scalars, so both storage formats work.
        image = torch.as_tensor(self.images[idx])
        label = torch.as_tensor(self.labels[idx])
        if image.ndim == 3:
            image = image.unsqueeze(0)
        return image.float(), label.float()

# Training-time augmentation pipeline executed on DEVICE and applied to one
# (C, D, H, W) volume per call. A single RandAffine fuses random rotation and
# translation into one resampling step, and it fires on 50% of calls.
gpu_augmentations = Compose(
    [
        RandAffine(
            prob=0.5,
            rotate_range=(0.349,) * 3,   # radians (~20 degrees) about each spatial axis
            translate_range=(10,) * 3,   # voxel shift range along each spatial axis
            padding_mode="zeros",        # voxels shifted in from outside are zero-filled
            device=DEVICE,
        ),
    ]
)

### Tunable Model Architecture

In [None]:
class TunableCNN3D(nn.Module):
    """3D CNN binary classifier whose shape is sampled from an Optuna-style trial.

    The trial is queried for ``n_layers``, ``base_filters``, ``dropout_rate``
    and ``dense_units``, in that order. The network expects single-channel
    volumes ``(N, 1, D, H, W)`` and returns ``(N, 1)`` probabilities, because
    the classifier ends in ``nn.Sigmoid``.
    """

    def __init__(self, trial):
        super().__init__()

        # Sampled architecture hyperparameters (order kept stable for samplers).
        depth = trial.suggest_int('n_layers', 4, 6)
        width = trial.suggest_int('base_filters', 8, 16)
        p_drop = trial.suggest_float('dropout_rate', 0.2, 0.5)
        hidden = trial.suggest_int('dense_units', 256, 1024, step=256)

        stages = []
        channels = 1
        for level in range(depth):
            level_width = width * 2 ** level
            is_stem = level == 0
            # The stem conv is unpadded; deeper convs keep spatial size with 'same'.
            conv_kwargs = {} if is_stem else {'padding': 'same'}
            stages.extend([
                nn.Conv3d(channels, level_width, kernel_size=3, **conv_kwargs),
                nn.ReLU(),
                nn.MaxPool3d(2),
                nn.BatchNorm3d(level_width),
            ])
            # Every stage except the stem ends with dropout.
            if not is_stem:
                stages.append(nn.Dropout(p_drop))
            channels = level_width

        self.features = nn.Sequential(*stages)

        # Global-average-pooled head -> dense layer -> single probability.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to ``(N, 1)`` probabilities."""
        return self.classifier(self.features(x))

### Hyperparameter Optimization

The Optuna objective below minimizes validation loss: each trial returns the lowest validation BCE loss it reached.

In [None]:
# Read both splits into memory a single time at module level; every Optuna
# trial then reuses these objects instead of re-reading the pickles from Drive.
print("Loading datasets into memory for Optuna...")
train_dataset, val_dataset = ADNIDataset(train_path), ADNIDataset(val_path)
print("Datasets loaded.")

def objective(trial):
    """Optuna objective: train one sampled configuration and return its best validation loss.

    The learning rate is sampled here, and the architecture hyperparameters are
    sampled inside ``TunableCNN3D``. Training reuses the module-level
    ``train_dataset`` / ``val_dataset`` and applies ``gpu_augmentations`` to
    training images only. A ReduceLROnPlateau scheduler halves the LR when the
    validation loss plateaus, and training stops early after ``patience``
    epochs without improvement.

    Args:
        trial: Optuna trial providing hyperparameter suggestions and pruning hooks.

    Returns:
        The lowest epoch-level validation BCE loss reached during the trial.

    Raises:
        optuna.exceptions.TrialPruned: If the study's pruner stops the trial.
    """
    # W&B Init for each trial
    # run = wandb.init(
    #     project="AD_CN_FDG",
    #     group="Hyperparameter-Tuning",
    #     config=trial.params,
    #     reinit=True
    # )

    # Suggest hyperparameters
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    #weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)

    # Architecture hyperparameters are sampled inside the constructor.
    model = TunableCNN3D(trial).to(DEVICE)

    # Data loaders over the pre-loaded global datasets.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2)

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)  # weight_decay=weight_decay
    criterion = nn.BCELoss()  # model ends in Sigmoid, so it outputs probabilities
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',  # Use 'min' for loss
                                                     factor=0.5,
                                                     patience=3,
                                                     )

    best_val_loss = float('inf')  # minimized objective value
    patience = 7  # Patience for early stopping in trials
    epochs_no_improve = 0

    for epoch in range(EPOCHS_PER_TRIAL):
        # ---- Training ----
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # RandAffine operates on single (C, D, H, W) volumes, so augment per image.
            images = torch.stack([gpu_augmentations(img) for img in images])

            optimizer.zero_grad()
            # Flatten to matching (N,) vectors; outputs are already probabilities.
            outputs = model(images).view(-1)
            targets = labels.view(-1).float()
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)

            predicted = (outputs > 0.5).float()
            total_train += targets.size(0)
            correct_train += (predicted == targets).sum().item()

        epoch_train_loss = train_loss / len(train_loader.dataset)
        epoch_train_accuracy = correct_train / total_train

        # ---- Validation ----
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                # view(-1) instead of squeeze(): squeeze() turns a final batch of size 1
                # into a 0-d tensor whose size no longer matches the (1,)-shaped labels,
                # and BCELoss raises on that mismatch.
                outputs = model(images).view(-1)
                targets = labels.view(-1).float()
                loss = criterion(outputs, targets)
                val_loss += loss.item() * images.size(0)

                predicted = (outputs > 0.5).float()
                total_val += targets.size(0)
                correct_val += (predicted == targets).sum().item()

        epoch_val_loss = val_loss / len(val_loader.dataset)
        epoch_val_accuracy = correct_val / total_val

        print(f"Trial {trial.number}, Epoch {epoch+1}/{EPOCHS_PER_TRIAL}, Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_accuracy:.4f}")

        # Log metrics to W&B
        # wandb.log({
        #     "epoch": epoch + 1,
        #     "train_loss": epoch_train_loss,
        #     "train_accuracy": epoch_train_accuracy,
        #     "val_loss": epoch_val_loss,
        #     "val_accuracy": epoch_val_accuracy
        # })

        scheduler.step(epoch_val_loss)

        # Report the intermediate loss so the pruner can compare trials per epoch.
        trial.report(epoch_val_loss, epoch)
        if trial.should_prune():
            # wandb.finish()
            raise optuna.exceptions.TrialPruned()

        # Early stopping on validation loss.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve == patience:
            print(f"Trial {trial.number}: Early stopping after {epoch+1} epochs.")
            break

    # wandb.finish()
    return best_val_loss

### Run Hyperparameter Optimization

In [None]:
# Create study
study = optuna.create_study(
    direction='minimize',  # the objective returns the best validation loss
    # Median pruning: no trial is pruned until 8 trials have finished, and each
    # trial gets 5 warm-up epochs before it can be pruned.
    pruner=optuna.pruners.MedianPruner(n_startup_trials=8, n_warmup_steps=5)
)


def save_best_study_callback(study, trial):
    """Persist the study whenever ``trial`` has just become the best completed trial.

    Args:
        study: The running Optuna study.
        trial: The trial that just finished (complete, pruned, or failed).
    """
    # Only completed trials can be the best one. Checking the state first also
    # avoids study.best_trial raising ValueError while no trial has completed.
    if trial.state != optuna.trial.TrialState.COMPLETE:
        return
    if study.best_trial.number == trial.number:
        joblib.dump(study, model_save_path / "hyperparameter_study.pkl")
        print(f"New best trial found! Study saved to {model_save_path / 'hyperparameter_study.pkl'}")


# Run optimization
study.optimize(objective, n_trials=N_TRIALS, callbacks=[save_best_study_callback])

# Print results
print("\nOptimization completed!")
print(f"Best trial: {study.best_trial.number}")
print(f"Best validation loss: {study.best_value:.4f}")
print("\nBest hyperparameters:")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")

# Save the full study (including pruned and non-best trials)
joblib.dump(study, model_save_path / "hyperparameter_study.pkl")
print(f"\nStudy saved to {model_save_path / 'hyperparameter_study.pkl'}")

### Train Final Model with Best Hyperparameters

In [None]:
import joblib

# Restore the persisted Optuna study so the final-training cells can run in a
# fresh runtime without repeating the search.
study = joblib.load(model_save_path / "hyperparameter_study.pkl")

# Sanity-check what was restored.
print(f"Best trial number: {study.best_trial.number}")
print(f"Best validation loss: {study.best_value:.4f}")
print(f"Best params: {study.best_params}")

In [None]:
# Create final model with best hyperparameters
print("\nTraining final model...")
from sklearn.metrics import roc_auc_score
import numpy as np

# --- Parameter Configuration ---
# To bypass the Optuna results, assign a dict to `manual_params`;
# keep it as None to train with the study's best parameters.
manual_params = None
# Example manual override:
# manual_params = {
#     'n_layers': 5,
#     'base_filters': 15,
#     'dropout_rate': 0.30103587368703894,
#     'dense_units': 1024,
#     'lr': 0.0008165831666899182,
# }

# Fail fast when there is neither a manual override nor a (non-None) study.
if manual_params is None and globals().get('study') is None:
    raise ValueError("No hyperparameters available. Please define 'manual_params' or run the Optuna study.")

if manual_params is None:
    print("Using OPTUNA BEST hyperparameters.")
    final_params = study.best_params
else:
    print("Using MANUAL hyperparameters.")
    final_params = manual_params

print(f"Final Parameters: {final_params}")

# Dedicated W&B run for the final (non-tuning) training, configured with the chosen params.
final_run = wandb.init(
    config=final_params,
    project="AD_CN_Classification",
    name="Final-Model-Training",
    job_type="train",
)

# Trial stand-in used to rebuild TunableCNN3D from a fixed parameter dict.
class MockTrial:
    """Replay fixed hyperparameter values through Optuna's ``suggest_*`` API.

    Each ``suggest_*`` call returns ``params[name]``. The search-space
    arguments are accepted only for signature compatibility and are ignored.
    A missing name raises ``KeyError``.
    """

    def __init__(self, params):
        self.params = params

    def _lookup(self, name):
        """Return the stored value for hyperparameter ``name``."""
        return self.params[name]

    def suggest_int(self, name, low, high, step=1):
        return self._lookup(name)

    def suggest_float(self, name, low, high, log=False):
        return self._lookup(name)

# Assuming 'study' object is available from the previous cell execution
# Rebuild the architecture from `final_params` through the MockTrial adapter.
mock_trial = MockTrial(final_params)
final_model = TunableCNN3D(mock_trial).to(DEVICE)

# Training setup
# NOTE: ADNIDataset reads each pickle fully into memory (it is not lazy).
# Re-creating the datasets keeps this cell runnable without the tuning cells.
train_dataset = ADNIDataset(train_path)
val_dataset = ADNIDataset(val_path)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2)

optimizer = optim.Adam(
    final_model.parameters(),
    lr=final_params['lr'],
    weight_decay=final_params.get('weight_decay', 0)
)
# Same plateau schedule as the Optuna trials, so the tuned learning rate is
# applied under the conditions in which it was selected.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)
# Using BCELoss because the model ends with nn.Sigmoid()
criterion = nn.BCELoss()

# Extended training for final model
FINAL_EPOCHS = 100
best_val_loss = float('inf')
patience = 10  # Epochs without validation-loss improvement before early stopping
epochs_no_improve = 0
best_model_file = model_save_path / "ad_cn_model_best_tuned.pth"

# Lists to store metrics for plotting
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
train_aucs = []
val_aucs = []


def _safe_auc(targets, probs):
    """Return the ROC AUC of ``probs`` against ``targets``, or 0.5 if it is undefined.

    ``roc_auc_score`` raises ValueError, for example when only one class is
    present in ``targets``. In that case chance level (0.5) is reported.
    """
    try:
        return roc_auc_score(targets, probs)
    except ValueError:
        return 0.5


for epoch in range(FINAL_EPOCHS):
    # ---- Training ----
    final_model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    train_probs_all = []
    train_targets_all = []

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINAL_EPOCHS} [Train]"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # RandAffine operates on single (C, D, H, W) volumes, so augment per image.
        images = torch.stack([gpu_augmentations(img) for img in images])

        optimizer.zero_grad()
        outputs = final_model(images)

        # Flatten (B, 1) probabilities and labels to matching (B,) vectors for BCELoss.
        outputs_flat = outputs.view(-1)
        labels_flat = labels.view(-1).float()

        loss = criterion(outputs_flat, labels_flat)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)

        # Model outputs are already probabilities (Sigmoid); threshold at 0.5.
        predicted = (outputs_flat > 0.5).float()
        total_train += labels.size(0)
        correct_train += (predicted == labels_flat).sum().item()

        train_probs_all.extend(outputs_flat.detach().cpu().numpy())
        train_targets_all.extend(labels_flat.cpu().numpy())

    epoch_train_loss = running_loss / len(train_loader.dataset)
    epoch_train_accuracy = correct_train / total_train
    epoch_train_auc = _safe_auc(train_targets_all, train_probs_all)

    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_accuracy)
    train_aucs.append(epoch_train_auc)

    # ---- Validation ----
    final_model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    val_probs_all = []
    val_targets_all = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = final_model(images)

            outputs_flat = outputs.view(-1)
            labels_flat = labels.view(-1).float()

            loss = criterion(outputs_flat, labels_flat)
            val_loss += loss.item() * images.size(0)

            predicted = (outputs_flat > 0.5).float()
            total_val += labels.size(0)
            correct_val += (predicted == labels_flat).sum().item()

            val_probs_all.extend(outputs_flat.cpu().numpy())
            val_targets_all.extend(labels_flat.cpu().numpy())

    val_loss /= len(val_loader.dataset)
    epoch_val_accuracy = correct_val / total_val
    epoch_val_auc = _safe_auc(val_targets_all, val_probs_all)

    val_losses.append(val_loss)
    val_accuracies.append(epoch_val_accuracy)
    val_aucs.append(epoch_val_auc)

    print(f"Epoch {epoch+1}/{FINAL_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.4f}, Train AUC: {epoch_train_auc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {epoch_val_accuracy:.4f}, Val AUC: {epoch_val_auc:.4f}")

    # Capture the LR used in this epoch before the scheduler possibly lowers it.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(final_model.state_dict(), best_model_file)
        print(f"New best model saved with validation loss: {val_loss:.4f}")
        epochs_no_improve = 0
        wandb.run.summary["best_val_loss"] = best_val_loss
        wandb.run.summary["best_epoch"] = epoch + 1
    else:
        epochs_no_improve += 1

    # Logged after the checkpoint check so "best_val_loss" includes this epoch
    # (logging before it reported inf on the first epoch).
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": epoch_train_loss,
        "train_accuracy": epoch_train_accuracy,
        "train_auc": epoch_train_auc,
        "val_loss": val_loss,
        "val_accuracy": epoch_val_accuracy,
        "val_auc": epoch_val_auc,
        "best_val_loss": best_val_loss,
        "lr": current_lr,
    })

    if epochs_no_improve == patience:
        print(f"Early stopping after {epoch+1} epochs due to no improvement in validation loss for {patience} epochs.")
        break

print("\nFinal training completed!")

# Upload only a checkpoint written by this run. best_val_loss stays inf if the
# validation loss never improved (e.g. NaN losses); any file on disk would then
# be missing or stale.
if best_val_loss < float('inf'):
    model_artifact = wandb.Artifact(
        name="ad_cn_classifier",
        type="model",
        description="Best 3D CNN model for AD/CN classification after hyperparameter tuning."
    )
    model_artifact.add_file(str(best_model_file))
    final_run.log_artifact(model_artifact)
    print(f"Best model saved to {model_save_path / 'ad_cn_model_best_tuned.pth'}")
else:
    print("No checkpoint was saved during this run; skipping W&B artifact upload.")
wandb.finish()

### Plots

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch

# Learning curves for the final training run, in three side-by-side panels:
# loss, accuracy and AUC (train vs. validation, one point per epoch).
plt.figure(figsize=(18, 5))

for panel, (metric, train_curve, val_curve) in enumerate(
    [
        ('Loss', train_losses, val_losses),
        ('Accuracy', train_accuracies, val_accuracies),
        ('AUC', train_aucs, val_aucs),
    ],
    start=1,
):
    plt.subplot(1, 3, panel)
    plt.plot(train_curve, label=f'Train {metric}')
    plt.plot(val_curve, label=f'Validation {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.title(f'Training and Validation {metric}')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

# --- ROC Curve for Best Model ---
print("\nEvaluating best model for ROC Curve...")

# Restore the checkpoint with the lowest validation loss before scoring.
best_model_path = model_save_path / "ad_cn_model_best_tuned.pth"
final_model.load_state_dict(torch.load(best_model_path))
final_model.eval()

# Collect ground-truth labels and predicted probabilities over the validation set.
y_true, y_scores = [], []
with torch.no_grad():
    for batch_images, batch_labels in val_loader:
        batch_probs = final_model(batch_images.to(DEVICE)).view(-1)
        y_true.extend(batch_labels.cpu().numpy())
        y_scores.extend(batch_probs.cpu().numpy())

fpr, tpr, thresholds = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (Best Validation Model)')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

## Evaluate on test dataset

In [None]:
# --- Evaluation on Test Dataset ---
# Define the path to your test data (modify as needed)
test_path = drive_base_path / "data" / "fdg" / "ad_cn_test.pkl" # UPDATE THIS PATH

if test_path.exists():
    print("\nEvaluating the best model on the test dataset...")
    from sklearn.metrics import roc_auc_score

    # ADNIDataset loads the whole pickle into memory.
    test_dataset = ADNIDataset(test_path)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=2)

    best_model_path = model_save_path / "ad_cn_model_best_tuned.pth"
    if best_model_path.exists():
        # Minimal trial stand-in so TunableCNN3D can be rebuilt from a fixed dict
        # (redefined here so this cell also works in a fresh runtime).
        class MockTrial:
            def __init__(self, params):
                self.params = params

            def suggest_int(self, name, low, high, step=1):
                return self.params[name]

            def suggest_float(self, name, low, high, log=False):
                return self.params[name]

        # Prefer the params the checkpoint was trained with; fall back to the study.
        if 'final_params' in globals():
            print("Using 'final_params' from training session to rebuild model.")
            rebuild_params = final_params
        elif 'study' in globals() and study is not None:
            print("Using 'study.best_params' to rebuild model.")
            rebuild_params = study.best_params
        else:
            raise ValueError("Cannot rebuild model: 'final_params' and 'study' are missing. Please define architecture parameters.")

        mock_trial = MockTrial(rebuild_params)
        test_model = TunableCNN3D(mock_trial).to(DEVICE)
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # session) load into the current runtime.
        test_model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
        test_model.eval()

        criterion = nn.BCELoss()  # matches training (model ends in Sigmoid)

        test_loss = 0.0
        correct_test = 0
        total_test = 0
        test_probs_all = []
        test_targets_all = []

        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc="Evaluating Test Set"):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = test_model(images)

                outputs_flat = outputs.view(-1)
                labels_flat = labels.view(-1).float()

                loss = criterion(outputs_flat, labels_flat)
                test_loss += loss.item() * images.size(0)

                # outputs are probabilities (0-1)
                predicted = (outputs_flat > 0.5).float()
                total_test += labels.size(0)
                correct_test += (predicted == labels_flat).sum().item()

                test_probs_all.extend(outputs_flat.cpu().numpy())
                test_targets_all.extend(labels_flat.cpu().numpy())

        test_loss /= len(test_loader.dataset)
        test_accuracy = correct_test / total_test
        try:
            test_auc = roc_auc_score(test_targets_all, test_probs_all)
        except ValueError:
            # AUC is undefined, e.g. when the test labels contain a single class.
            test_auc = float('nan')

        print(f"\nTest Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Test AUC: {test_auc:.4f}")

    else:
        print(f"Best model file not found at {best_model_path}. Please ensure the final training completed successfully.")

else:
    print(f"Test data file not found at {test_path}. Please update the path or ensure the file exists.")