# 6. Train MCI Conversion Model with Transfer Learning

This notebook trains 3D CNN models to predict MCI to AD conversion using transfer learning from the pre-trained AD/CN model. The training is performed using 5-fold cross-validation for robust evaluation.

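Key steps:
1. Load the pre-trained AD/CN model as a starting point
2. Fine-tune the model on each fold of the MCI conversion data in two stages: first train only a re-initialized classifier head with the feature extractor frozen, then fine-tune the whole network at a lower learning rate
3. Save the best model (highest validation AUC) from each fold for later uncertainty analysis
4. Track training metrics across all folds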

### Inputs and Outputs

**Inputs:**
- K-fold data splits (`train_fold_*.pkl`, `val_fold_*.pkl`) from notebook 05.
- The pre-trained AD/CN classifier model (`ad_cn_model_best_tuned.pth`) from notebook 03.
- The hyperparameter study file (`hyperparameter_study.pkl`) to define the model architecture.

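**Outputs:**
- The best performing model (highest validation AUC) for each of the 5 folds (`mci_model_fold_*_best.pth`).
- `training_summary.csv` with performance metrics for each fold, and `training_history_full.csv` with per-epoch metrics.
- Grad-CAM figures and NIfTI exports produced by the explainability cells at the end of the notebook.
- **W&B Artifacts** (planned; this version of the notebook makes no W&B calls):
  - A run for each fold, logging metrics and saving the best model as an artifact.
  - A final summary run logging the cross-validation results table and plots.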


In [None]:
%pip install optuna monai captum nibabel "numpy>=2.0"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
from pathlib import Path
from scipy import ndimage
import random
from tqdm.notebook import tqdm
import pandas as pd
from sklearn.metrics import roc_auc_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from monai.transforms import (
    Compose,
    RandAffine,
)

In [None]:
# Mount Google Drive so that the /content/drive/MyDrive/... paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

### Define Paths and Parameters

In [None]:
# Paths tau
# kfold_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/data/tau/kfold/")
# pretrained_model_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/ad_cn_model_best_tuned.pth")
# output_model_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion_tau/")
# C_LEARNING_RATE = 0.004391419283753976
# FT_LEARNING_RATE = C_LEARNING_RATE/20
# DROPOUT = 0.284169899466295

# Paths fdg
kfold_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/data/fdg/kfold/")
pretrained_model_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/ad_cn_model_best_tuned.pth")
output_model_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion/")
# Learning rate for stage 1 (classifier head only, feature extractor frozen).
C_LEARNING_RATE = 0.004391419283753976
# Learning rate for stage 2 (whole network fine-tuned); deliberately 20x smaller.
FT_LEARNING_RATE = C_LEARNING_RATE/20
# Dropout rate passed to TunableCNN3D (overrides any value stored in the study).
DROPOUT = 0.284169899466295

# parents=True so the output folder can be created even when its saved_models/
# parent does not exist yet (e.g. training from scratch without the AD/CN checkpoint).
output_model_path.mkdir(parents=True, exist_ok=True)

# Training parameters
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_FOLDS = 5
EPOCHS = 30
BATCH_SIZE = 8  # Smaller batch size for MCI data


print(f"Training on {DEVICE}")
print(f"Will train {NUM_FOLDS} models for {EPOCHS} epochs each")

### Dataset and Model Definitions

In [None]:
class ADNIDataset(Dataset):
    """In-memory dataset backed by a pickled ``{"images": ..., "labels": ...}`` dict.

    Each item is returned as a pair of float32 tensors ``(image, label)``. A
    channel dimension is prepended to 3-D images so they become ``(1, D, H, W)``.
    Images and labels may be stored as torch tensors, NumPy arrays or Python
    scalars; they are converted with ``torch.as_tensor`` (tensors pass through
    unchanged).
    """

    def __init__(self, pkl_file):
        # Load the whole split once; items are then served from memory.
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(pkl_file, 'rb') as f:
            data_dict = pickle.load(f)
        self.images = data_dict["images"]
        self.labels = data_dict["labels"]
        self.num_samples = len(self.images)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # as_tensor is a no-op for tensors and makes NumPy arrays / scalars usable too.
        image = torch.as_tensor(self.images[idx])
        label = torch.as_tensor(self.labels[idx])
        # Ensure a channel dim: (D, H, W) -> (1, D, H, W).
        if image.ndim == 3:
            image = image.unsqueeze(0)
        return image.float(), label.float()

# GPU-side augmentation pipeline, applied one sample at a time inside train_epoch.
# A single RandAffine fuses rotation and translation into one affine warp;
# prob=0.5 means each call applies the random warp roughly half of the time.
gpu_augmentations = Compose([
    RandAffine(
        prob=0.5,
        rotate_range=(0.349,) * 3,     # ~20 degrees (in radians) around each axis
        translate_range=(10,) * 3,     # shift range in voxels along each axis
        padding_mode="zeros",
        device=DEVICE,
    ),
])

class TunableCNN3D(nn.Module):
    """3-D CNN binary classifier whose depth and width come from the hyper-parameter search.

    ``features``: ``n_layers`` stages of Conv3d(k=3) -> ReLU -> MaxPool3d(2) ->
    BatchNorm3d. The channel count starts at ``base_filters`` and doubles at each
    stage. The first stage uses no padding and no dropout; later stages use
    ``padding='same'`` and end with Dropout(``dropout_rate``).

    ``classifier``: global average pool -> flatten -> Linear(``dense_units``) ->
    ReLU -> Dropout -> Linear(1) -> Sigmoid, so the output is a probability of
    shape (N, 1).

    The module order inside both ``nn.Sequential`` containers must not change:
    saved checkpoints are keyed by these positions (e.g. ``features.3.weight``).
    """

    def __init__(self, n_layers, base_filters, dropout_rate, dense_units):
        super().__init__()

        stages = []
        channels = 1
        for depth in range(n_layers):
            width = base_filters * 2 ** depth
            is_first = depth == 0
            stage = [
                nn.Conv3d(channels, width, kernel_size=3, padding=0 if is_first else 'same'),
                nn.ReLU(),
                nn.MaxPool3d(2),
                nn.BatchNorm3d(width),
            ]
            if not is_first:
                stage.append(nn.Dropout(dropout_rate))
            stages += stage
            channels = width

        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(channels, dense_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dense_units, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

### Model Explainability with Grad-CAM

In [None]:
from captum.attr import LayerGradCam, LayerAttribution
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter, binary_opening, binary_closing

class LogitWrapper(nn.Module):
    """Expose the pre-sigmoid logit of a model with ``features`` + ``classifier``.

    Grad-CAM is run on the logit rather than the probability, because sigmoid
    gradients shrink toward zero as the probability saturates. The wrapped
    model's ``classifier`` must be an ``nn.Sequential`` whose last entry is the
    Sigmoid to be skipped.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, x):
        hidden = self.m.features(x)
        # Slicing an nn.Sequential returns a new Sequential sharing the same layers.
        head_without_sigmoid = self.m.classifier[:-1]
        return head_without_sigmoid(hidden)  # logits, (N, 1) for TunableCNN3D

def compute_3d_gradcam(model, input_tensor, target_class=1, layer_index=-3):
    """Compute a 3-D Grad-CAM map on the model's logit (not its sigmoid output).

    Args:
        model: A TunableCNN3D-style model (``features`` + ``classifier`` ending in
            Sigmoid). Note that it is switched to eval mode as a side effect.
        input_tensor: Input batch, expected shape (1, C, D, H, W).
        target_class: 1 keeps positive evidence (pushes the logit up, towards
            "Converter"); any other value keeps negative evidence (pushes the logit
            down, towards "Non-Converter").
        layer_index: Python-style index (negative allowed) into the Conv3d layers
            of ``model.features``. Out-of-range values fall back to the last conv
            layer with a printed warning.

    Returns:
        numpy array: non-negative attribution map, trilinearly upsampled to the
        spatial size of ``input_tensor`` and squeezed (so (D, H, W) for a
        single-sample, single-channel map).

    Raises:
        ValueError: if ``model.features`` contains no Conv3d layer.
    """
    # 1. Use wrapper to get logits. .eval() on the wrapper also sets `model` to eval.
    logit_model = LogitWrapper(model).eval()

    # 2. Pick the target Conv3d layer.
    convs = [m for m in model.features.modules() if isinstance(m, nn.Conv3d)]
    if not convs:
        raise ValueError("compute_3d_gradcam: model.features contains no Conv3d layer.")
    # Accept exactly the valid Python indices. The earlier check
    # `len(convs) >= abs(layer_index)` let layer_index == len(convs) through,
    # which then raised IndexError instead of falling back.
    if -len(convs) <= layer_index < len(convs):
        target_layer = convs[layer_index]
    else:
        target_layer = convs[-1]
        print(f"Warning: Requested layer index {layer_index} out of bounds. Using last conv layer.")

    # 3. Grad-CAM on the single scalar output (column 0 of the (N, 1) logit).
    lgc = LayerGradCam(logit_model, target_layer)

    # relu_attributions=False keeps signed values so either direction of evidence can be selected.
    attr_raw = lgc.attribute(input_tensor, target=0, relu_attributions=False)

    if target_class == 1:
        # Evidence FOR class 1: positive contribution to the logit.
        attr = torch.relu(attr_raw)
    else:
        # Evidence FOR class 0: a negative contribution lowers the class-1 probability.
        attr = torch.relu(-attr_raw)

    # 4. Upsample to the input's spatial size; F.interpolate expects (N, C, D, H, W).
    # (Assumes LayerGradCam returns a channel-summed (N, 1, d, h, w) map.)
    attr = F.interpolate(attr, size=input_tensor.shape[2:], mode="trilinear", align_corners=False)

    return attr.squeeze().detach().cpu().numpy()

def get_vis_settings(modality):
    """Return the Grad-CAM visualisation preset for ``modality``.

    Keys:
      * ``mask_thr``: brain-mask threshold on the normalised image.
      * ``cam_thr_p``: percentile below which CAM values are zeroed.
      * ``sigma``: Gaussian smoothing of the CAM.
      * ``alpha``: overlay opacity.

    Any modality other than 'fdg' gets the 'tau' preset. A fresh dict is built
    on every call, so callers may mutate the result.
    """
    tau = {'mask_thr': 0.08, 'cam_thr_p': 90, 'sigma': 0.8, 'alpha': 0.45}
    fdg = {'mask_thr': 0.08, 'cam_thr_p': 95, 'sigma': 0.6, 'alpha': 0.35}
    by_name = {'tau': tau, 'fdg': fdg}
    try:
        return by_name[modality]
    except KeyError:
        return tau

def process_cam_image(img, hmap, settings):
    """Normalise an image volume and its Grad-CAM map for display / NIfTI export.

    Steps:
      1. Clip ``img`` to its 1st-99th percentile range and rescale to about [0, 1].
      2. Build ``brain_mask``. By default it covers the whole volume, which is what
         this pipeline has always used. Set ``settings['apply_brain_mask'] = True``
         to use the intensity mask ``imgw > settings['mask_thr']``, cleaned with a
         binary opening (1 iteration) and closing (2 iterations).
      3. Rectify the CAM, smooth it with a Gaussian (``settings['sigma']``), zero it
         outside the mask, and zero values below the ``settings['cam_thr_p']``
         percentile of the in-mask values.
      4. Divide by the 99th percentile of the in-mask values, clipping to [0, 1].

    Args:
        img: image volume (numpy array).
        hmap: Grad-CAM volume with the same shape as ``img``; not modified.
        settings: dict as returned by ``get_vis_settings``. The optional
            ``apply_brain_mask`` key defaults to False.

    Returns:
        (imgw, hmap_norm, brain_mask): normalised image, CAM in [0, 1], bool mask.
    """
    # 1. Image normalisation (percentiles make it robust to outliers).
    p1, p99 = np.percentile(img, (1, 99))
    imgw = np.clip(img, p1, p99)
    imgw = (imgw - p1) / (p99 - p1 + 1e-8)

    # 2. Brain masking. Previously the morphological mask was computed and then
    # unconditionally overwritten with an all-True mask, so `mask_thr` had no effect.
    # The all-True result is kept as the default; the real mask is opt-in.
    if settings.get('apply_brain_mask', False):
        brain_mask = imgw > settings['mask_thr']
        brain_mask = binary_closing(binary_opening(brain_mask, iterations=1), iterations=2)
    else:
        brain_mask = np.ones_like(imgw, dtype=bool)

    # 3. Prepare CAM (Rectify -> Smooth -> Mask -> Threshold).
    hmap_rect = np.maximum(hmap, 0)  # new array: the caller's hmap is untouched
    hmap_rect = gaussian_filter(hmap_rect, sigma=settings['sigma'])
    hmap_rect[~brain_mask] = 0

    vals = hmap_rect[brain_mask]
    if vals.size > 0:
        thr = np.percentile(vals, settings['cam_thr_p'])
        hmap_rect[hmap_rect < thr] = 0

    # 4. Normalise using ONLY masked voxels.
    vals = hmap_rect[brain_mask]
    cap = np.percentile(vals, 99) if vals.size > 0 else (np.max(hmap_rect) + 1e-8)
    hmap_norm = np.clip(hmap_rect, 0, cap) / (cap + 1e-8)

    return imgw, hmap_norm, brain_mask

def plot_explainability(image_tensor, heatmap, title="Grad-CAM", save_path=None, modality='tau'):
    """Show a 3-panel Grad-CAM overlay (axial / sagittal / coronal) for one volume.

    The axial panel uses the slice with the largest summed processed CAM. The
    sagittal and coronal panels use the middle slice along axes 2 and 1. Display
    parameters come from ``get_vis_settings(modality)``. When ``save_path`` is
    given, the figure is written there at 300 dpi, creating parent folders as
    needed, before being shown.
    """
    vis = get_vis_settings(modality)

    if torch.is_tensor(image_tensor):
        volume = image_tensor.squeeze().cpu().numpy()
    else:
        volume = image_tensor.squeeze()
    cam = heatmap.squeeze()

    imgw, hmap_norm, _ = process_cam_image(volume, cam, vis)

    best_axial = int(np.argmax(hmap_norm.sum(axis=(1, 2))))
    mid_sagittal = volume.shape[2] // 2
    mid_coronal = volume.shape[1] // 2

    # (panel title, index expression selecting the 2-D slice)
    panels = [
        (f'Axial (Slice {best_axial})', np.s_[best_axial, :, :]),
        (f'Sagittal (Slice {mid_sagittal})', np.s_[:, :, mid_sagittal]),
        (f'Coronal (Slice {mid_coronal})', np.s_[:, mid_coronal, :]),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"{title} (Modality: {modality}, Best Slice: {best_axial})", fontsize=14)

    for ax, (panel_title, where) in zip(axes, panels):
        ax.imshow(imgw[where], cmap='gray')
        ax.imshow(hmap_norm[where], cmap='hot', alpha=vis['alpha'])
        ax.set_title(panel_title)
        ax.axis('off')

    plt.tight_layout()

    if save_path is not None:
        target = Path(save_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(target, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {target}")

    plt.show()

### Training and Validation Functions

In [None]:
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score

def set_bn_eval(m):
    """Put ``m`` in eval mode if it is a BatchNorm layer; leave other modules alone.

    Intended for ``model.apply(set_bn_eval)`` right after ``model.train()``. In
    eval mode BatchNorm normalises with its stored running mean/variance and
    stops updating them, so fine-tuning keeps the pre-trained statistics instead
    of noisy small-batch estimates.
    """
    if 'BatchNorm' in type(m).__name__:
        m.eval()

def train_epoch(model, train_loader, optimizer, criterion, device, freeze_bn=False):
    """Run one training epoch and return the mean per-sample training loss.

    Each batch is moved to ``device`` and augmented sample by sample with the
    global ``gpu_augmentations`` pipeline. It is then used for one optimizer step.
    ``criterion`` receives the model output squeezed to (N,) and the labels.

    Args:
        model: network to train (put in train mode here).
        train_loader: DataLoader yielding (images, labels) batches.
        optimizer: optimizer over the parameters that should be updated.
        criterion: loss function; assumed to use mean reduction (BCELoss default).
        device: device the batches are moved to.
        freeze_bn: if True, BatchNorm layers are switched back to eval mode after
            ``model.train()`` (via ``set_bn_eval``), so their running statistics
            are neither used from the batch nor updated.

    Returns:
        float: loss averaged over the samples actually processed this epoch
        (0.0 if the loader produced no batches).
    """
    model.train()
    if freeze_bn:
        model.apply(set_bn_eval)

    running_loss = 0.0
    n_seen = 0

    for images, labels in train_loader:
        # Move data to the target device as early as possible.
        images, labels = images.to(device), labels.to(device)

        # RandAffine expects a single (C, D, H, W) sample, so augment each image
        # separately and re-stack the batch.
        images = torch.stack([gpu_augmentations(img) for img in images])

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.squeeze(1), labels)
        loss.backward()
        optimizer.step()

        batch_n = images.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n

    # Divide by the samples actually seen. The training loader uses drop_last=True,
    # so len(train_loader.dataset) would over-count and bias the loss downwards.
    return running_loss / max(n_seen, 1)

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradients or augmentation.

    Args:
        model: binary classifier whose output is a probability of shape (N, 1).
        val_loader: DataLoader yielding (images, labels) batches.
        criterion: loss applied to the outputs squeezed to (N,) and the labels.
        device: device the batches are moved to.

    Returns:
        tuple: (val_loss, auc, accuracy, balanced_acc).
          * val_loss: mean per-sample loss.
          * auc: ROC AUC on the raw outputs.
          * accuracy, balanced_acc: computed after thresholding the outputs at 0.5.

    Raises:
        ValueError: propagated from ``roc_auc_score`` when the labels contain a
            single class (AUC is undefined in that case).
    """
    # eval() makes BatchNorm use its running statistics (and disables Dropout) for
    # the whole pass, so size-1 batches need no special handling.
    model.eval()
    running_loss = 0.0
    n_seen = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)

            batch_n = images.size(0)
            running_loss += loss.item() * batch_n
            n_seen += batch_n
            all_predictions.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Mean over evaluated samples (equals len(dataset) for a loader without drop_last).
    val_loss = running_loss / max(n_seen, 1)

    # Calculate metrics
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    binary_predictions = (all_predictions > 0.5).astype(int)

    auc = roc_auc_score(all_labels, all_predictions)
    accuracy = accuracy_score(all_labels, binary_predictions)
    balanced_acc = balanced_accuracy_score(all_labels, binary_predictions)

    return val_loss, auc, accuracy, balanced_acc

### K-Fold Training Loop

In [None]:
import joblib
study = joblib.load("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/hyperparameter_study.pkl")

# Architecture hyper-parameters (n_layers, base_filters, dense_units) come from the
# Optuna study. Learning rates and dropout come from the configuration cell
# (C_LEARNING_RATE, FT_LEARNING_RATE, DROPOUT) rather than from the study.
best_params = study.best_trial.params
# NOTE: OPTIMAL_LR is not used by the two-stage schedule below; kept for reference only.
OPTIMAL_LR = best_params["lr"]
#print(f"Using optimal learning rate from study: {OPTIMAL_LR}")

# Store results for all folds
fold_results = []        # one summary dict per fold -> training_summary.csv
all_training_logs = []   # one dict per fine-tuning epoch -> training_history_full.csv

for fold in range(1, NUM_FOLDS + 1):
    print(f"\n{'='*50}")
    print(f"Training Fold {fold}/{NUM_FOLDS}")
    print(f"{'='*50}")

    # Load data for this fold
    train_dataset = ADNIDataset(kfold_path / f"train_fold_{fold}.pkl")
    val_dataset = ADNIDataset(kfold_path / f"val_fold_{fold}.pkl")

    # drop_last=True drops the final incomplete training batch; validation keeps every sample.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Instantiate the same architecture as the pre-trained AD/CN model so that its
    # state_dict loads without key/shape mismatches.
    model = TunableCNN3D(
        n_layers=best_params["n_layers"],
        base_filters=best_params["base_filters"],
        dropout_rate=DROPOUT,
        dense_units=best_params["dense_units"]
    ).to(DEVICE)


    # Load pre-trained AD/CN model if available
    if pretrained_model_path.exists():
        print("Loading pre-trained AD/CN model...")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
        model.load_state_dict(torch.load(pretrained_model_path, map_location=DEVICE))
    else:
        print("Pre-trained model not found. Training from scratch.")

    # --- RE-INITIALIZE CLASSIFIER HEAD ---
    # The task is different (MCI conversion vs AD/CN), so starting from
    # random weights for the classifier helps avoid getting stuck.
    print("Re-initializing classifier head...")
    for layer in model.classifier:
        if isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    # --- STAGE 1: FREEZE FEATURE EXTRACTOR AND TRAIN CLASSIFIER ---
    print("Stage 1: Training the classifier head only...")

    # Freeze convolutional layers
    for param in model.features.parameters():
        param.requires_grad = False

    # Optimizer for the classifier ONLY
    optimizer = optim.Adam(model.classifier.parameters(), lr=C_LEARNING_RATE)
    criterion = nn.BCELoss()  # model already ends in Sigmoid, so plain BCE on probabilities

    # Warm up the new classifier head for WARMUP_EPOCHS epochs.
    WARMUP_EPOCHS = 20
    for epoch in range(WARMUP_EPOCHS):
        # Freeze BN even in stage 1
        train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE, freeze_bn=True)
        # Periodic validation, for logging only
        if (epoch + 1) % 5 == 0:
            val_loss, val_auc, val_accuracy, val_bacc = validate_epoch(model, val_loader, criterion, DEVICE)
            print(f"  Warm-up Epoch {epoch+1}/{WARMUP_EPOCHS}, Train Loss: {train_loss:.4f}, Val AUC: {val_auc:.4f}, Val B-Acc: {val_bacc:.4f}")

    # --- STAGE 2: UNFREEZE ALL LAYERS AND FINE-TUNE ---
    print("Stage 2: Fine-tuning the entire model...")

    # Unfreeze all layers
    for param in model.parameters():
        param.requires_grad = True

    # Re-create the optimizer for the whole model with a VERY LOW learning rate
    optimizer = optim.Adam(model.parameters(), lr=FT_LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=10, factor=0.5)  # Note: mode='max' for AUC

    # Training tracking
    best_val_auc = 0.0
    best_epoch = 0
    best_metrics = {}  # validation metrics of the epoch whose weights were checkpointed
    fold_history = []

    # Training loop
    for epoch in range(EPOCHS):
        # Freeze BN during fine-tuning (Crucial!)
        train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE, freeze_bn=True)
        val_loss, val_auc, val_accuracy, val_bacc = validate_epoch(model, val_loader, criterion, DEVICE)

        scheduler.step(val_auc)  # Use AUC for scheduler

        # Save epoch results
        epoch_results = {
            'fold': fold,
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_auc': val_auc,
            'val_accuracy': val_accuracy,
            'val_balanced_accuracy': val_bacc
        }
        fold_history.append(epoch_results)

        # Save best model (selection criterion: validation AUC)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch + 1
            best_metrics = {
                'best_val_loss': val_loss,
                'best_val_accuracy': val_accuracy,
                'best_val_balanced_accuracy': val_bacc,
            }
            torch.save(model.state_dict(), output_model_path / f"mci_model_fold_{fold}_best.pth")

        # Print progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:2d}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Val AUC: {val_auc:.4f}, Val Acc: {val_accuracy:.4f}, Val B-Acc: {val_bacc:.4f}")

    if best_epoch == 0:
        print(f"Warning: validation AUC never exceeded 0.0 in fold {fold}; no checkpoint was saved.")

    # Add fold history to global logs
    all_training_logs.extend(fold_history)

    # Store fold results. 'final_*' are last-epoch metrics; 'best_*' describe the
    # saved (best-AUC) checkpoint, so accuracy and AUC refer to the same model.
    fold_summary = {
        'fold': fold,
        'best_epoch': best_epoch,
        'best_val_auc': best_val_auc,
        'final_train_loss': train_loss,
        'final_val_loss': val_loss,
        'final_val_accuracy': val_accuracy,
        'final_val_balanced_accuracy': val_bacc,
        'best_val_loss': best_metrics.get('best_val_loss', float('nan')),
        'best_val_accuracy': best_metrics.get('best_val_accuracy', float('nan')),
        'best_val_balanced_accuracy': best_metrics.get('best_val_balanced_accuracy', float('nan')),
    }
    fold_results.append(fold_summary)

    print(f"Fold {fold} completed. Best AUC: {best_val_auc:.4f} at epoch {best_epoch}")

# Save detailed training history (per epoch)
history_df = pd.DataFrame(all_training_logs)
history_df.to_csv(output_model_path / "training_history_full.csv", index=False)
print(f"Detailed training history saved to {output_model_path / 'training_history_full.csv'}")

# Save fold summary
results_df = pd.DataFrame(fold_results)
results_df.to_csv(output_model_path / "training_summary.csv", index=False)


print(f"\n{'='*60}")
print("TRAINING COMPLETED")
print(f"{'='*60}")
print("\nFold Summary:")
print(results_df)
print(f"\nMean AUC across folds: {results_df['best_val_auc'].mean():.4f} ± {results_df['best_val_auc'].std():.4f}")
print(f"Mean Accuracy across folds: {results_df['final_val_accuracy'].mean():.4f} ± {results_df['final_val_accuracy'].std():.4f}")
print(f"Mean Balanced Acc across folds: {results_df['final_val_balanced_accuracy'].mean():.4f} ± {results_df['final_val_balanced_accuracy'].std():.4f}")
print(f"Mean Accuracy at best-AUC epoch: {results_df['best_val_accuracy'].mean():.4f} ± {results_df['best_val_accuracy'].std():.4f}")
print(f"Mean Balanced Acc at best-AUC epoch: {results_df['best_val_balanced_accuracy'].mean():.4f} ± {results_df['best_val_balanced_accuracy'].std():.4f}")

### Example: Visualizing Model Explanations with Grad-CAM

After training, you can use Grad-CAM to visualize which brain regions the model focuses on when making its predictions.

In [None]:
# Example: Generate Grad-CAM visualizations for a sample from validation set
# Run this cell after training is complete
import joblib
import warnings
import nibabel as nib
from google.colab import files

# --- SET MODALITY PRESET HERE ---
MODALITY = 'fdg'  # Options: 'tau' or 'fdg'
# --------------------------------

# Define paths
study_path = "/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/hyperparameter_study.pkl"
fold_to_analyze = 4 if MODALITY == 'tau' else 5
model_path = output_model_path / f"mci_model_fold_{fold_to_analyze}_best.pth"
explainability_output_path = output_model_path / "gradcam_visualizations"
explainability_output_path.mkdir(parents=True, exist_ok=True)

# Try to load parameters from the study file (best effort: fall back to manual values)
try:
    # Load hyperparameters to make this cell standalone
    study = joblib.load(study_path)
    best_params = study.best_trial.params
    print("Successfully loaded parameters from study file.")
except Exception as e:
    print(f"Could not load study file: {e}")
    print("Using manual fallback parameters.")
    best_params = {
        "n_layers": 3,
        "base_filters": 16,
        "dropout_rate": 0.2,
        "dense_units": 64,
    }

# Load the model architecture.
# Training used the DROPOUT constant, and the study's params may not contain
# "dropout_rate" at all, so fall back to DROPOUT instead of raising KeyError.
# (Dropout is inactive in eval mode, so the value does not change predictions.)
model = TunableCNN3D(
    n_layers=best_params["n_layers"],
    base_filters=best_params["base_filters"],
    dropout_rate=best_params.get("dropout_rate", DROPOUT),
    dense_units=best_params["dense_units"]
).to(DEVICE)

# Load trained weights (map_location allows GPU-saved weights on a CPU runtime)
if model_path.exists():
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print(f"Loaded model weights from {model_path}")
else:
    print(f"Warning: Model path {model_path} does not exist. Using random weights.")

model.eval()
# Logit view of the model, built once and reused for every sample below.
wrapper = LogitWrapper(model)

# Get validation loader (NO AUGMENTATION).
# shuffle=False so that the loop index `i` is the dataset index. With shuffling,
# the printed "sample index" (and the patient_{i} file names produced from this
# loader later) did not identify a sample and changed on every run.
val_dataset = ADNIDataset(kfold_path / f"val_fold_{fold_to_analyze}.pkl")
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

print("Searching for a correctly predicted 'Converter' case (True Positive)...")

found_sample = False
max_samples_to_check = 200  # Avoid scanning the whole split

for i, (images, labels) in enumerate(val_loader):
    if i >= max_samples_to_check:
        print(f"Checked {max_samples_to_check} samples without finding a match.")
        break

    input_img = images[0:1].to(DEVICE)
    true_label = labels[0].item()

    # We only want True Label = 1 (Converter)
    if true_label != 1:
        continue

    # Get prediction
    with torch.no_grad():
        logit = wrapper(input_img).item()
        prob = torch.sigmoid(torch.tensor(logit)).item()

    # Accept as a confident True Positive when Prob > 0.7
    if prob > 0.7:
        found_sample = True
        print(f"Found match at sample index {i}!")
        print(f"True Label: {true_label} (Converter)")
        print(f"Model Prediction (Prob): {prob:.4f}")
        break

if found_sample:
    predicted_class = 'Converter'
    true_class = 'Converter'

    # Compute Grad-CAM heatmap
    # Explain the predicted class (Converter => 1)
    heatmap = compute_3d_gradcam(model, input_img, target_class=1, layer_index=-3)

    # Create filename
    filename = f"gradcam_fold{fold_to_analyze}_TRUE_POSITIVE_prob{prob:.3f}_{MODALITY}.png"
    save_path = explainability_output_path / filename

    # Visualize (uses centralized logic via modality arg)
    plot_explainability(
        input_img,
        heatmap,
        title=f"Grad-CAM (Logits, Layer -3): Correct Converter Prediction",
        save_path=save_path,
        modality=MODALITY
    )

    # --- NIfTI EXPORT FOR 3D SLICER ---
    print("\nPreparing NIfTI files for 3D Slicer...")

    # Prepare Numpy arrays
    img_np = input_img.squeeze().detach().cpu().numpy()
    h_np = heatmap.squeeze() # Raw unsmoothed map

    # Use centralized processing to ensure consistency with the plot
    settings = get_vis_settings(MODALITY)
    _, h_norm, _ = process_cam_image(img_np, h_np, settings)

    # Create NIfTI files
    # Identity affine: no voxel spacing/orientation information is available here
    affine = np.eye(4)

    # Save raw PET image
    pet_nii = nib.Nifti1Image(img_np.astype(np.float32), affine)
    # Save processed (smoothed/masked) Heatmap
    cam_nii = nib.Nifti1Image(h_norm.astype(np.float32), affine)

    # Save to the current working directory
    nib.save(pet_nii, "pet_image.nii.gz")
    nib.save(cam_nii, "gradcam_heatmap.nii.gz")

    # (Optional) Save a thresholded CAM mask
    h_thr = (h_norm > 0.4).astype(np.float32)
    nib.save(nib.Nifti1Image(h_thr, affine), "gradcam_mask.nii.gz")

    print("Files saved locally: pet_image.nii.gz, gradcam_heatmap.nii.gz, gradcam_mask.nii.gz")

    # Download files
    # print("Triggering downloads...")
    # files.download("pet_image.nii.gz")
    # files.download("gradcam_heatmap.nii.gz")
    # files.download("gradcam_mask.nii.gz")

else:
    print("Could not find a correctly predicted Converter case in the searched samples.")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from pathlib import Path

# Collect up to NUM_SAMPLES validation converters predicted with probability
# > CONFIDENCE_THRESHOLD, together with their processed Grad-CAM maps, for the
# paper figures generated below. Relies on `model`, `val_loader` (batch_size=1)
# and `output_model_path` from earlier cells.
MODALITY='fdg'

# --- CONFIGURATION ---
OUTPUT_DIR = output_model_path / "paper_figures"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
CONFIDENCE_THRESHOLD = 0.70  # Threshold for "High Confidence"
NUM_SAMPLES = 10        # Number of samples to include in the combined figure
MODALITY_SETTINGS = get_vis_settings(MODALITY)

print(f"Searching for {NUM_SAMPLES} high-confidence (>{CONFIDENCE_THRESHOLD}) Converter samples...")

# 1. Collect High Confidence Samples
high_conf_samples = []
model.eval()

# Iterate through validation loader to find samples
# Note: Assuming val_loader is already defined from previous cells with batch_size=1
for i, (images, labels) in enumerate(val_loader):
    # Stop once enough samples have been collected.
    if len(high_conf_samples) >= NUM_SAMPLES:
        break

    true_label = labels[0].item()
    # We want True Positives (Converters)
    if true_label != 1:
        continue

    input_img = images.to(DEVICE)

    # Probability from the logit (same computation as the Grad-CAM target).
    with torch.no_grad():
        wrapper = LogitWrapper(model)
        logit = wrapper(input_img).item()
        prob = torch.sigmoid(torch.tensor(logit)).item()

    if prob > CONFIDENCE_THRESHOLD:
        print(f"Found Sample {len(high_conf_samples)+1}: Index {i}, Prob: {prob:.4f}")

        # Compute GradCAM
        heatmap = compute_3d_gradcam(model, input_img, target_class=1, layer_index=-3)

        # Prepare Data for Plotting
        img_np = input_img.squeeze().cpu().numpy()
        hmap_np = heatmap.squeeze()

        # Process (Normalize, Mask, Smooth); brain_mask is not used further here.
        imgw, hmap_norm, brain_mask = process_cam_image(img_np, hmap_np, MODALITY_SETTINGS)

        # Store data.
        # NOTE(review): 'idx' is the position in val_loader's iteration order; it equals
        # the dataset index only if val_loader was built with shuffle=False.
        high_conf_samples.append({
            'idx': i,
            'prob': prob,
            'img': imgw,
            'hmap': hmap_norm
        })

if not high_conf_samples:
    print("No high confidence samples found. Try lowering the threshold.")
else:
    print(f"\nProcessing {len(high_conf_samples)} samples for paper figures...")

# 2. Helper function to plot a single slice clean
def plot_clean_slice(ax, img_slice, hmap_slice, alpha):
    """Draw a grayscale slice with a 'hot' heat-map overlay and no axis decoration.

    ``alpha`` is the overlay opacity. Ticks are cleared explicitly in addition
    to ``axis('off')``, and the aspect ratio is forced to 1:1.
    """
    ax.imshow(img_slice, cmap='gray')
    ax.imshow(hmap_slice, cmap='hot', alpha=alpha)
    ax.axis('off')
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.set_aspect('equal')

# 3. Generate Individual Figures and Per-Patient Strips
# Overlay opacity shared by every figure below.
alpha = MODALITY_SETTINGS['alpha']

# For each collected sample: save the three single-view PNGs and one
# 3-view strip annotated with the predicted probability.
for sample_idx, sample in enumerate(high_conf_samples):
    img = sample['img']
    hmap = sample['hmap']
    orig_idx = sample['idx']
    prob = sample['prob']

    # Use the middle slice along each axis.
    # Assumes volumes are stored (D, H, W) with axis 0 ~ axial, 1 ~ coronal, 2 ~ sagittal
    # (approximate, depends on the preprocessing orientation).
    slice_ax = img.shape[0] // 2
    slice_cor = img.shape[1] // 2
    slice_sag = img.shape[2] // 2

    # --- A. Save Individual Slices (Axial, Sagittal, Coronal) ---
    slices = {
        'axial': (img[slice_ax, :, :], hmap[slice_ax, :, :]),
        'coronal': (img[:, slice_cor, :], hmap[:, slice_cor, :]), # Note orientation usually needs flipping/rotation depending on lib
        'sagittal': (img[:, :, slice_sag], hmap[:, :, slice_sag])
    }

    # Non-axial views are rotated 90 degrees (np.rot90) before display so they
    # appear upright in matplotlib; the axial view is shown as stored.

    # Save each slice individually
    for view_name, (img_s, hmap_s) in slices.items():
        fig, ax = plt.subplots(figsize=(4, 4))
        # Fix orientation for non-axial usually involves np.rot90
        if view_name != 'axial':
             img_s = np.rot90(img_s)
             hmap_s = np.rot90(hmap_s)

        plot_clean_slice(ax, img_s, hmap_s, alpha)

        # Save without border
        fname = OUTPUT_DIR / f"patient_{orig_idx}_{view_name}.png"
        plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
        plt.margins(0,0)
        plt.savefig(fname, bbox_inches='tight', pad_inches=0, dpi=300)
        plt.close(fig)

    # --- B. Save 3-Slice Strip WITH PROBABILITY (Next to each other) ---
    # Using 4 columns: 3 for slices, 1 for text
    # Width ratios favor the images (1) over the text (0.3)
    fig, axes = plt.subplots(1, 4, figsize=(14, 4), gridspec_kw={'width_ratios': [1, 1, 1, 0.3]})

    # Axial
    plot_clean_slice(axes[0], slices['axial'][0], slices['axial'][1], alpha)
    # Sagittal (rotated like the individual sagittal PNG)
    plot_clean_slice(axes[1], np.rot90(slices['sagittal'][0]), np.rot90(slices['sagittal'][1]), alpha)
    # Coronal (rotated like the individual coronal PNG)
    plot_clean_slice(axes[2], np.rot90(slices['coronal'][0]), np.rot90(slices['coronal'][1]), alpha)

    # Probability Text
    axes[3].axis('off')
    axes[3].text(0.1, 0.5, f"Prob:\n{prob:.4f}", va='center', ha='left', fontsize=16)

    # wspace=0: panels touch, with no vertical gap between them
    plt.subplots_adjust(wspace=0, hspace=0)
    fname_strip = OUTPUT_DIR / f"patient_{orig_idx}_combined_strip.png"
    plt.savefig(fname_strip, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close(fig)

print(f"Individual figures and strips saved to {OUTPUT_DIR}")

# 4. Generate the combined grid figure: one row per collected sample (up to NUM_SAMPLES)
if len(high_conf_samples) > 0:
    # Create a figure with N rows and 4 columns (3 slices + 1 text)
    n_rows = len(high_conf_samples)
    fig, axes = plt.subplots(n_rows, 4, figsize=(12, 3 * n_rows), gridspec_kw={'width_ratios': [1, 1, 1, 0.3]})

    # Handle case if n_rows=1 where axes is 1D
    if n_rows == 1:
        axes = np.expand_dims(axes, axis=0)

    for i, sample in enumerate(high_conf_samples):
        img = sample['img']
        hmap = sample['hmap']
        prob = sample['prob']

        # Middle slice along each axis (same choice as the per-patient strips)
        slice_ax = img.shape[0] // 2
        slice_cor = img.shape[1] // 2
        slice_sag = img.shape[2] // 2

        # Axial
        plot_clean_slice(axes[i, 0], img[slice_ax, :, :], hmap[slice_ax, :, :], alpha)

        # Sagittal
        img_sag = np.rot90(img[:, :, slice_sag])
        hmap_sag = np.rot90(hmap[:, :, slice_sag])
        plot_clean_slice(axes[i, 1], img_sag, hmap_sag, alpha)

        # Coronal
        img_cor = np.rot90(img[:, slice_cor, :])
        hmap_cor = np.rot90(hmap[:, slice_cor, :])
        plot_clean_slice(axes[i, 2], img_cor, hmap_cor, alpha)

        # Probability Text
        axes[i, 3].axis('off')
        axes[i, 3].text(0.1, 0.5, f"Prob:\n{prob:.4f}", va='center', ha='left', fontsize=14)

    # wspace=0 removes the gap between columns; hspace=0.05 leaves a thin gap between rows
    plt.subplots_adjust(wspace=0, hspace=0.05)
    fname_composite = OUTPUT_DIR / f"composite_paper_figure_{MODALITY}.png"
    plt.savefig(fname_composite, bbox_inches='tight', dpi=300)
    plt.show()
    print(f"Composite figure saved to {fname_composite}")

In [None]:
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
import joblib # Explicitly import joblib for loading study file

# Assuming `TunableCNN3D`, `ADNIDataset`, `get_vis_settings`, `process_cam_image`,
# `LogitWrapper`, `compute_3d_gradcam` are defined in previous cells or globally available.

def collect_high_conf_truepos_gradcams(
    model,
    loader,
    device,
    modality: str,
    conf_thr: float = 0.70,
    max_n: int = 40,
    target_class: int = 1,
    layer_index: int = -4,
):
    """
    Collect PET volumes and Grad-CAM heatmaps for confidently detected converters.

    A sample is kept when its label is 1 and the sigmoid of the model logit is
    >= `conf_thr` (a true positive at that threshold). Collection stops as soon
    as `max_n` samples have been gathered.

    Every sample in each batch is examined individually, so the loader is not
    restricted to batch_size=1.

    Args:
        model: trained classifier; wrapped in `LogitWrapper` for the forward pass
            and passed unwrapped to `compute_3d_gradcam`.
        loader: iterable of (images, labels) batches.
        device: device each sample is moved to before inference.
        modality: forwarded to `get_vis_settings`; also used in the progress label.
        conf_thr: minimum predicted probability for a sample to be kept.
        max_n: maximum number of samples to return (<= 0 returns empty lists).
        target_class: forwarded to `compute_3d_gradcam`.
        layer_index: forwarded to `compute_3d_gradcam` to choose the target layer.

    Returns:
        imgs:  list of float32 numpy arrays (image output of process_cam_image;
               presumably (D, H, W) - confirm against process_cam_image)
        cams:  list of float32 numpy heatmaps (heatmap output of process_cam_image)
        probs: list of float probabilities, aligned with imgs/cams
    """
    model.eval()
    imgs, cams, probs = [], [], []
    if max_n <= 0:
        return imgs, cams, probs

    wrapper = LogitWrapper(model).to(device).eval()

    # Visualization settings used to post-process every heatmap
    settings = get_vis_settings(modality)

    for images, labels in tqdm(loader, desc=f"Collecting {modality} high-conf TP"):
        for image, label in zip(images, labels):
            # Only true converters (label == 1) are candidates
            if float(label.item()) != 1.0:
                continue

            # Restore the batch dimension removed by iterating over the batch
            x = image.unsqueeze(0).to(device)
            with torch.no_grad():
                p = torch.sigmoid(wrapper(x)).item()

            if p < conf_thr:
                continue

            # 1. Raw Grad-CAM heatmap; computed outside no_grad because
            #    Grad-CAM needs gradients
            h_raw = compute_3d_gradcam(
                model=model,
                input_tensor=x,
                target_class=target_class,
                layer_index=layer_index,
            )

            # 2. Normalize / mask via the shared visualization helper
            img_np = x.squeeze().cpu().numpy()
            imgw, hmap_norm, _ = process_cam_image(img_np, h_raw, settings)

            imgs.append(imgw.astype(np.float32))
            cams.append(hmap_norm.astype(np.float32))
            probs.append(float(p))

            if len(imgs) >= max_n:
                return imgs, cams, probs

    return imgs, cams, probs


def aggregate_maps(imgs, cams, agg="mean"):
    """
    Aggregate per-subject volumes and Grad-CAM maps voxel-wise across subjects.

    Args:
        imgs: sequence of N same-shaped arrays (one image per subject).
        cams: sequence of N Grad-CAM arrays, paired one-to-one with `imgs`.
        agg: "median" for the voxel-wise median; any other value (including
            the default "mean") gives the voxel-wise mean.

    Returns:
        (img_agg, cam_agg): aggregated arrays with the per-subject shape.

    Raises:
        ValueError: if no samples are given, or if `imgs` and `cams` differ in
            length (otherwise they would be aggregated over different subject
            sets without any error).
    """
    if len(imgs) == 0 or len(cams) == 0:
        raise ValueError("aggregate_maps: no samples to aggregate (imgs/cams is empty)")
    if len(imgs) != len(cams):
        raise ValueError(
            f"aggregate_maps: got {len(imgs)} images but {len(cams)} Grad-CAM maps"
        )

    reducer = np.median if agg == "median" else np.mean
    img_agg = reducer(np.stack(imgs, axis=0), axis=0)   # (N, ...) -> (...)
    cam_agg = reducer(np.stack(cams, axis=0), axis=0)
    return img_agg, cam_agg


# Load best hyperparameters to define model architecture.
# Assumes hyperparameter_study.pkl exists and TunableCNN3D is defined in an earlier cell.
study = joblib.load("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/hyperparameter_study.pkl")
best_params = study.best_trial.params

# -------- RUN PER MODALITY (Aggregating Across All Folds) --------
# One modality per run: run with "fdg", then switch to the tau paths defined at
# the top of the notebook, set MODALITY = "tau" and run again.
MODALITY = "fdg"
CONF_THR = 0.70      # minimum predicted probability for a true positive to be kept
MAX_N_PER_FOLD = 30  # Max number of samples to collect per fold
AGG = "mean"         # or "median"
LAYER_INDEX = -3     # keep consistent across modalities

OUTPUT_DIR = output_model_path / "thesis_figures"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Global lists to store results from all folds
all_imgs = []
all_cams = []
all_probs = []

print(f"Aggregating Grad-CAMs for {MODALITY} across {NUM_FOLDS} folds...")

for fold in range(1, NUM_FOLDS + 1):
    print(f"Processing Fold {fold}...")

    # 1. Locate the best trained weights for this fold before building the model
    model_fold_path = output_model_path / f"mci_model_fold_{fold}_best.pth"
    if not model_fold_path.exists():
        print(f"Warning: Model for fold {fold} not found at {model_fold_path}. Skipping this fold.")
        continue

    # 2. Instantiate the architecture from the best hyperparameters and load weights.
    #    map_location lets checkpoints saved on GPU load on a CPU-only runtime.
    model = TunableCNN3D(
        n_layers=best_params["n_layers"],
        base_filters=best_params["base_filters"],
        dropout_rate=best_params["dropout_rate"],
        dense_units=best_params["dense_units"]
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_fold_path, map_location=DEVICE))
    model.eval()

    # 3. Validation data for this fold. shuffle=False keeps the iteration order
    #    fixed, so the MAX_N_PER_FOLD cap selects the same subjects on every run
    #    (the aggregated figure is reproducible).
    val_dataset = ADNIDataset(kfold_path / f"val_fold_{fold}.pkl")
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # 4. Collect high-confidence true positives and their Grad-CAMs for this fold
    fold_imgs, fold_cams, fold_probs = collect_high_conf_truepos_gradcams(
        model=model,
        loader=val_loader,
        device=DEVICE,
        modality=MODALITY,
        conf_thr=CONF_THR,
        max_n=MAX_N_PER_FOLD,
        layer_index=LAYER_INDEX
    )

    # 5. Extend global lists with results from the current fold
    all_imgs.extend(fold_imgs)
    all_cams.extend(fold_cams)
    all_probs.extend(fold_probs)

if not all_imgs:
    # Nothing qualified in any fold: np.min on an empty list would raise and
    # there is nothing to aggregate, so report and skip saving.
    print(f"\n{MODALITY}: no high-confidence true-positive samples found "
          f"(conf_thr={CONF_THR}); nothing aggregated or saved.")
else:
    print(f"\n{MODALITY}: collected total N={len(all_imgs)} high-confidence true-positive samples "
          f"(mean p={np.mean(all_probs):.3f}, min={np.min(all_probs):.3f}, max={np.max(all_probs):.3f})")

    # 6. Aggregate across all collected samples from all folds
    img_mean, cam_mean = aggregate_maps(all_imgs, all_cams, agg=AGG)

    # Save the aggregated maps together with the settings that produced them
    agg_path = OUTPUT_DIR / f"agg_gradcam_{MODALITY}_all_folds.npz"
    np.savez_compressed(
        agg_path,
        img=img_mean,
        cam=cam_mean,
        probs=np.array(all_probs),
        conf_thr=CONF_THR,
        max_n_total_collected=len(all_imgs),  # total samples across all folds
        agg=AGG,
        layer_index=LAYER_INDEX
    )

    print("Saved:", agg_path)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
from pathlib import Path

# --- PAPER QUALITY SETTINGS ---
# Serif typography at a uniform 12 pt for publication figures. Times New Roman
# is tried first; DejaVu Serif is next in the font.serif priority list.
_PAPER_RC_PARAMS = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "DejaVu Serif"],
    "font.size": 12,
    "axes.titlesize": 12,
    "axes.labelsize": 12,
}
plt.rcParams.update(_PAPER_RC_PARAMS)

def robust_display(img, lo=1, hi=99, fg_thr=1e-4):
    """
    Percentile-normalize an array to [0, 1] for display.

    The lower bound is the `lo`-th percentile of foreground values
    (> `fg_thr`), so a large zero background does not drag it down. If there
    are no foreground values, the global minimum is used instead. The upper
    bound is the `hi`-th percentile of all values. Values are clipped to
    [vmin, vmax] and rescaled, and anything at or below vmin is set to 0.

    Args:
        img: numeric array of any shape.
        lo: percentile of the foreground used as the lower display bound.
        hi: percentile of all values used as the upper display bound.
        fg_thr: foreground threshold. The default 1e-4 is the previously
            hard-coded value.

    Returns:
        A new array with the same shape as `img`; the input is not modified.
    """
    fg = img[img > fg_thr]
    vmin = np.percentile(fg, lo) if fg.size else np.min(img)
    vmax = np.percentile(img, hi)
    img = np.clip(img, vmin, vmax)  # new array; caller's data untouched
    norm_img = (img - vmin) / (vmax - vmin + 1e-8)  # epsilon avoids 0/0 on flat input
    norm_img[img <= vmin] = 0
    return norm_img

def plot_aggregated_fdg_tau(
    fdg_img, fdg_cam,
    tau_img, tau_cam,
    slices=None,
    save_path=None,
    axis=0,          # 0=Axial, 1=Coronal, 2=Sagittal (see get_slice)
    cam_thr=0.0,
    voxel_size_mm=1.0,
):
    """
    Create a compact two-row (FDG above Tau) aggregated Grad-CAM figure (paper style).

    Each column shows one slice index along `axis`. The gray background is the
    `robust_display`-normalized volume. The overlay is the Grad-CAM map, with
    values below `cam_thr` masked out, drawn with the "hot" colormap over
    [cam_thr, 1].

    Args:
        fdg_img, fdg_cam: FDG volume and Grad-CAM map (3-D arrays).
        tau_img, tau_cam: Tau volume and Grad-CAM map (3-D arrays).
        slices: slice indices along `axis`. Defaults to 5 slices at offsets
            -12, -6, 0, +6, +12 around the centre of `fdg_img` along `axis`.
        save_path: if given, the figure is also saved there (parent dirs created).
        axis: 0 = axial, 1 = coronal, 2 = sagittal slicing.
        cam_thr: overlay threshold and lower bound of the colour scale.
        voxel_size_mm: voxel spacing along `axis`, used only for the slice-offset
            labels. The default 1.0 reproduces the original labels and assumes
            1 mm voxels (TODO confirm for the registered volumes).

    Raises:
        ValueError: if `slices` is empty, or any index is outside [0, size)
            along `axis` for any of the four volumes.
    """
    def get_slice(arr, idx, ax):
        """Return a 2-D slice along `ax`; coronal/sagittal are rotated 90 degrees."""
        if ax == 0: return arr[idx, :, :]              # Axial
        elif ax == 1: return np.rot90(arr[:, idx, :])  # Coronal
        else: return np.rot90(arr[:, :, idx])          # Sagittal

    if slices is None:
        mid = fdg_img.shape[axis] // 2
        slices = [mid - 12, mid - 6, mid, mid + 6, mid + 12]
    slices = list(slices)

    # Validate up front. An empty list would leave the colorbar mappable `im`
    # undefined. Negative indices (e.g. mid - 12 on a thin volume) would
    # silently wrap around to the opposite side of the volume.
    if not slices:
        raise ValueError("slices must contain at least one slice index")
    for vol_name, vol in (("fdg_img", fdg_img), ("fdg_cam", fdg_cam),
                          ("tau_img", tau_img), ("tau_cam", tau_cam)):
        size = vol.shape[axis]
        bad = [s for s in slices if not 0 <= s < size]
        if bad:
            raise ValueError(
                f"slice indices {bad} out of range for {vol_name} "
                f"(size {size} along axis {axis})"
            )

    fdg_bg = robust_display(fdg_img)
    tau_bg = robust_display(tau_img)

    ncols = len(slices)

    # --- DYNAMIC FIGURE SIZE CALCULATION ---
    # 1. Aspect ratio of a single slice, so the panels tile without padding
    sample_slice = get_slice(fdg_img, slices[0], axis)
    sh, sw = sample_slice.shape
    img_aspect = sw / sh  # width / height

    # 2. Layout in "image-height units": FDG row (1) + spacer + Tau row (1)
    spacer_ratio = 0.05   # spacer height relative to one image height
    total_img_h_units = 2 + spacer_ratio
    # Each image is img_aspect units wide when its height is 1 unit
    total_img_w_units = ncols * img_aspect

    # Aspect ratio of the image area of the figure (width / height)
    data_area_aspect = total_img_w_units / total_img_h_units

    # 3. Fix the image-area height, derive the width from its aspect,
    #    and reserve a bottom margin for the colorbar
    data_h_inches = 5.0
    fig_w_inches = data_h_inches * data_area_aspect

    bottom_margin_inches = 0.6
    fig_h_inches = data_h_inches + bottom_margin_inches

    # Bottom margin as a fraction of the figure height (GridSpec coordinates)
    gs_bottom = bottom_margin_inches / fig_h_inches

    fig = plt.figure(figsize=(fig_w_inches, fig_h_inches), dpi=300)

    gs = fig.add_gridspec(
        3, ncols,
        height_ratios=[1, spacer_ratio, 1],   # FDG, Spacer, Tau
        wspace=0,
        hspace=0,
        left=0, right=1, top=1, bottom=gs_bottom
    )

    axes = np.empty((2, ncols), dtype=object)

    # FDG row
    for j in range(ncols):
        axes[0, j] = fig.add_subplot(gs[0, j])

    # Tau row (GridSpec row 1 is the empty spacer)
    for j in range(ncols):
        axes[1, j] = fig.add_subplot(gs[2, j])

    # Offsets in the labels are relative to the middle entry of `slices`
    mid_idx = slices[len(slices) // 2]

    for j, s_idx in enumerate(slices):
        offset = s_idx - mid_idx
        if offset == 0:
            lbl = "Midline"
        else:
            # :+g keeps integer-valued offsets as "+6" (same as the original label)
            lbl = f"{offset * voxel_size_mm:+g} mm"

        # --- FDG ROW ---
        bg_s = get_slice(fdg_bg, s_idx, axis)
        cam_s = get_slice(fdg_cam, s_idx, axis)
        cam_s_masked = np.ma.masked_where(cam_s < cam_thr, cam_s)

        axes[0, j].imshow(bg_s, cmap="gray", origin='lower', aspect='equal')
        # `im` feeds the shared colorbar; both rows use the same vmin/vmax/cmap
        im = axes[0, j].imshow(cam_s_masked, cmap="hot", alpha=0.6, vmin=cam_thr, vmax=1.0, origin='lower', aspect='equal')
        axes[0, j].axis("off")

        # Inset slice label with a dark outline for legibility on any background
        axes[0, j].text(0.05, 0.95, lbl, transform=axes[0, j].transAxes,
                        color='white', fontsize=10, va='top', ha='left', fontweight='bold',
                        path_effects=[path_effects.withStroke(linewidth=2, foreground="black")])

        # --- TAU ROW ---
        bg_s_tau = get_slice(tau_bg, s_idx, axis)
        cam_s_tau = get_slice(tau_cam, s_idx, axis)
        cam_s_tau_masked = np.ma.masked_where(cam_s_tau < cam_thr, cam_s_tau)

        axes[1, j].imshow(bg_s_tau, cmap="gray", origin='lower', aspect='equal')
        axes[1, j].imshow(cam_s_tau_masked, cmap="hot", alpha=0.6, vmin=cam_thr, vmax=1.0, origin='lower', aspect='equal')
        axes[1, j].axis("off")

    # Rotated row labels, placed just left of the first panel of each row
    axes[0, 0].text(-0.02, 0.5, "FDG-PET", transform=axes[0, 0].transAxes,
                    va='center', ha='right', fontsize=14, fontweight='bold', rotation=90)
    axes[1, 0].text(-0.02, 0.5, "Tau-PET", transform=axes[1, 0].transAxes,
                    va='center', ha='right', fontsize=14, fontweight='bold', rotation=90)

    # Horizontal colorbar centred in the reserved bottom margin
    # ([left, bottom, width, height] in figure coordinates)
    cbar_width = 0.4
    cbar_height = 0.02
    cbar_bottom = (gs_bottom - cbar_height) / 2  # vertically centred in the margin

    cbar_ax = fig.add_axes([0.5 - cbar_width/2, cbar_bottom, cbar_width, cbar_height])
    cbar = fig.colorbar(im, cax=cbar_ax, orientation='horizontal')
    cbar.set_label(f"Relevance Probability (Grad-CAM > {cam_thr})", fontsize=11)
    cbar.outline.set_visible(False)
    cbar.ax.tick_params(size=0)

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(exist_ok=True, parents=True)
        # The GridSpec spans left=0..right=1, so the row labels are drawn
        # outside the figure; bbox_inches="tight" expands the saved area to include them.
        plt.savefig(save_path, bbox_inches="tight", dpi=300)
        print("Saved figure to:", save_path)

    plt.show()


# -------- LOAD AND PLOT --------
base_fdg_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion/thesis_figures")
base_tau_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion_tau/thesis_figures")

PLOT_AXIS = 1  # 0 = axial, 1 = coronal, 2 = sagittal
_AXIS_NAMES = {0: "axial", 1: "coronal", 2: "sagittal"}


def _find_agg_npz(base_dir, modality):
    """Return the aggregated Grad-CAM .npz path for `modality`, or None if absent.

    Prefers the cross-fold file written by the aggregation cell
    (agg_gradcam_<modality>_all_folds.npz) and falls back to the older
    single-run name (agg_gradcam_<modality>.npz).
    """
    for fname in (f"agg_gradcam_{modality}_all_folds.npz", f"agg_gradcam_{modality}.npz"):
        candidate = base_dir / fname
        if candidate.exists():
            return candidate
    return None


fdg_file = _find_agg_npz(base_fdg_path, "fdg")
tau_file = _find_agg_npz(base_tau_path, "tau")

if fdg_file is not None and tau_file is not None:
    # Read the arrays into memory and close the archives
    with np.load(fdg_file) as fdg_npz, np.load(tau_file) as tau_npz:
        fdg_img, fdg_cam = fdg_npz["img"], fdg_npz["cam"]
        tau_img, tau_cam = tau_npz["img"], tau_npz["cam"]

    # Centre the slices along the axis that is actually plotted
    mid = fdg_img.shape[PLOT_AXIS] // 2
    center_slices = [mid - 12, mid - 6, mid, mid + 6, mid + 12]

    plot_aggregated_fdg_tau(
        fdg_img=fdg_img,
        fdg_cam=fdg_cam,
        tau_img=tau_img,
        tau_cam=tau_cam,
        slices=center_slices,
        save_path=base_fdg_path / f"FIG_agg_gradcam_fdg_vs_tau_{_AXIS_NAMES[PLOT_AXIS]}_PAPER.png",
        axis=PLOT_AXIS,
        cam_thr=0.25,
    )
else:
    print("Could not find .npz files.")
    print(f"  searched: {base_fdg_path} and {base_tau_path}")