# 7. Monte Carlo Dropout for Uncertainty Estimation

This notebook applies Monte Carlo (MC) dropout to the trained MCI conversion models to estimate prediction uncertainty. By performing multiple forward passes with dropout enabled at inference time, we can obtain a distribution of predictions for each subject.

The key steps are:

1.  **Custom MCDropout Layer**: A custom dropout layer is defined that stays active at inference time, even after `model.eval()`.
2.  **Modified Model**: The 3D CNN architecture is rebuilt from the best hyperparameters, its standard dropout layers are replaced with the new `MCDropout` layers, and the fine-tuned weights for each fold are then loaded.
3.  **Iterative Inference**: For each subject in each validation fold, the model performs 500 forward passes, generating a distribution of predictions.
4.  **Uncertainty Calculation**: The mean and standard deviation of these prediction distributions are calculated. The standard deviation serves as a measure of the model's uncertainty for that prediction.
5.  **Results Aggregation**: The results for all subjects across all folds (subject ID, ground-truth label, mean prediction, and standard deviation) are aggregated and saved into a single pandas DataFrame for later analysis.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
!pip install optuna

### Define Paths and Parameters

In [None]:
# Mount Google Drive so the data/model paths under /content/drive are accessible (Colab only).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Paths
kfold_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/data/tau/kfold")
ad_cn_model_dir = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/")  # Path to the AD/CN model directory (holds the hyperparameter study)
model_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion_tau/")
output_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion_tau/monte_carlo/")
# parents=True so this also works when the intermediate model folders do not exist yet.
output_path.mkdir(parents=True, exist_ok=True)

# Create visualization directory
# NOTE(review): this notebook processes tau data, but figures are written under
# "figures/fdg/"; confirm this is intended, otherwise tau figures overwrite the FDG ones.
figures_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/reports/figures/fdg/uncertainty_distributions/")
figures_path.mkdir(parents=True, exist_ok=True)

# Parameters
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_FOLDS = 5  # folds are numbered 1..NUM_FOLDS (val_fold_{i}.pkl / mci_model_fold_{i}_best.pth)
MC_ITERATIONS = 500  # stochastic forward passes per subject

In [None]:
# Display the device (CUDA GPU or CPU) that inference will run on.
DEVICE

### Custom Dataset and MCDropout Model

In [None]:
import joblib

class ADNIDataset(Dataset):
    """Dataset backed by a pickled dict with "images", "labels" and "subject_ids".

    Each item is ``(image, label, subject_id)``: the image gets an extra leading
    axis via ``unsqueeze(0)`` and the label is cast to float.
    """

    def __init__(self, pkl_file):
        # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
        with open(pkl_file, 'rb') as handle:
            payload = pickle.load(handle)
        self.images, self.labels, self.subject_ids = (
            payload["images"],
            payload["labels"],
            payload["subject_ids"],
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx].unsqueeze(0)  # extra leading axis
        label = self.labels[idx].float()
        return image, label, self.subject_ids[idx]

class MCDropout(nn.Dropout):
    """Dropout that samples a fresh mask on every call, even after ``model.eval()``.

    This is what makes Monte Carlo dropout possible: repeated forward passes in
    eval mode give different outputs.
    """

    def forward(self, x):
        # training=True is hard-wired so the module's train/eval flag is ignored.
        return nn.functional.dropout(x, p=self.p, training=True, inplace=self.inplace)

# Same architecture as in notebooks 03 and 06. The module order must stay exactly
# as-is so the saved state_dict keys (features.<idx>, classifier.<idx>) still match.
class TunableCNN3D(nn.Module):
    """3D CNN whose depth and width come from the hyperparameter study.

    Args:
        n_layers: number of conv blocks; block ``i`` has ``base_filters * 2**i`` channels.
        base_filters: channel count of the first conv block.
        dropout_rate: dropout probability after conv blocks 2..n and in the dense head.
        dense_units: width of the hidden fully connected layer.

    The head ends in a sigmoid, so ``forward`` returns values in (0, 1) with
    shape ``(batch, 1)``.
    """

    def __init__(self, n_layers, base_filters, dropout_rate, dense_units):
        super().__init__()

        blocks = []
        channels = 1
        for depth in range(n_layers):
            width = base_filters * 2 ** depth
            first_block = depth == 0
            # The first block uses an unpadded conv and has no dropout.
            conv = (
                nn.Conv3d(channels, width, kernel_size=3)
                if first_block
                else nn.Conv3d(channels, width, kernel_size=3, padding='same')
            )
            blocks += [conv, nn.ReLU(), nn.MaxPool3d(2), nn.BatchNorm3d(width)]
            if not first_block:
                blocks.append(nn.Dropout(dropout_rate))  # later swapped for MCDropout
            channels = width

        self.features = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(channels, dense_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # later swapped for MCDropout
            nn.Linear(dense_units, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def convert_to_mc_dropout(module, dropout_p=None):
    """Recursively replace every ``nn.Dropout`` in *module* with ``MCDropout``.

    Child names are kept and dropout layers hold no parameters, so a state_dict
    saved from the unconverted model still loads after conversion.

    Args:
        module: model (or sub-module) to modify in place.
        dropout_p: probability for the new layers. If ``None``, each replaced
            layer keeps its own ``p``.

    Returns:
        The same *module*, for convenience.
    """
    for name, child in module.named_children():
        # MCDropout subclasses nn.Dropout, so re-running this just rebuilds it.
        if isinstance(child, nn.Dropout):
            p = child.p if dropout_p is None else dropout_p
            setattr(module, name, MCDropout(p=p, inplace=child.inplace))
        else:
            convert_to_mc_dropout(child, dropout_p)
    return module

### Perform Monte Carlo Dropout Inference

In [None]:
all_results = []

# Load the hyperparameter study to get the best model architecture.
# NOTE: the study is read from the AD/CN model directory; the fine-tuned MCI
# weights are assumed to fit that same architecture.
try:
    study = joblib.load(ad_cn_model_dir / "hyperparameter_study.pkl")
    best_params = study.best_params
    print("Successfully loaded best hyperparameters from study.")
except FileNotFoundError:
    print("ERROR: Could not find 'hyperparameter_study.pkl'. Cannot determine model architecture.")
    # You might want to hardcode the best params here as a fallback if the file is missing
    best_params = None

if best_params:
    for i in range(1, NUM_FOLDS + 1):
        print(f"\n--- Processing Fold {i} ---")

        # Check for the fold's weights before building a model we would throw away.
        model_file = model_path / f"mci_model_fold_{i}_best.pth"
        if not model_file.exists():
            print(f"Model file not found for fold {i}, skipping.")
            continue

        # 1. Instantiate the model with the best architecture
        model = TunableCNN3D(
            n_layers=best_params['n_layers'],
            base_filters=best_params['base_filters'],
            dropout_rate=best_params['dropout_rate'],
            dense_units=best_params['dense_units']
        )

        # 2. Convert to an MC Dropout model (module names are unchanged, so the
        #    saved state_dict still matches)
        convert_to_mc_dropout(model, dropout_p=best_params['dropout_rate'])
        model.to(DEVICE)

        # 3. Load the fine-tuned weights for this fold
        model.load_state_dict(torch.load(model_file, map_location=DEVICE))
        model.eval()  # BatchNorm uses running stats; MCDropout layers stay stochastic

        # 4. Load data and perform inference
        val_dataset = ADNIDataset(kfold_path / f"val_fold_{i}.pkl")
        val_loader = DataLoader(val_dataset, batch_size=1)

        # One no_grad context for the whole fold instead of one per forward pass.
        with torch.no_grad():
            for image, label, subject_id in tqdm(val_loader, desc=f"MC Dropout on Fold {i}"):
                # ADNIDataset.__getitem__ adds a leading axis; squeeze(1) removes it
                # again (no-op if that axis is not size 1). Presumably the stored
                # images already carry a channel axis -- confirm.
                image = image.to(DEVICE).squeeze(1)
                mc_predictions = np.array([model(image).item() for _ in range(MC_ITERATIONS)])

                all_results.append({
                    "fold": i,
                    "subject_id": subject_id[0],  # DataLoader wraps strings in a list
                    "label": label.item(),
                    "mc_mean": mc_predictions.mean(),
                    "mc_std": mc_predictions.std(),  # population std (ddof=0), as np.std
                })

if all_results:
    results_df = pd.DataFrame(all_results)
    results_df.to_pickle(output_path / "mci_mc_dropout_results.pkl")
    print(f"\nSaved all MC dropout results to {output_path / 'mci_mc_dropout_results.pkl'}")
else:
    print("No results were generated. Please check model paths and data.")

### Visualize Uncertainty Distributions

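Let's visualize approximate prediction distributions for a few selected cases to better understand what the model's uncertainty looks like. Because the individual MC samples are not saved, each distribution is a normal approximation rebuilt from the stored mean and standard deviation. We'll pick: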
1. A high-confidence correct pMCI prediction.
2. A high-confidence correct sMCI prediction.
3. The case with the highest uncertainty (largest standard deviation), which is often misclassified or borderline.


In [None]:
if 'results_df' in locals():
    # Find interesting cases
    correct_pmci = results_df[(results_df['label'] == 1) & (results_df['mc_mean'] > 0.8)].sort_values('mc_std', ascending=True)
    correct_smci = results_df[(results_df['label'] == 0) & (results_df['mc_mean'] < 0.2)].sort_values('mc_std', ascending=True)
    high_uncertainty = results_df.sort_values('mc_std', ascending=False)

    cases = {
        "High-Confidence pMCI": correct_pmci.iloc[0] if not correct_pmci.empty else None,
        "High-Confidence sMCI": correct_smci.iloc[0] if not correct_smci.empty else None,
        "High-Uncertainty Case": high_uncertainty.iloc[0] if not high_uncertainty.empty else None,
    }

    # The raw MC samples are not saved, so each panel shows a normal approximation
    # rebuilt from the stored mean/std (samples may fall slightly outside [0, 1]).
    # A fixed-seed generator keeps the saved figure identical across re-runs.
    rng = np.random.default_rng(0)

    fig, axes = plt.subplots(1, 3, figsize=(20, 5), sharex=True)

    for ax, (title, case) in zip(axes, cases.items()):
        if case is None:
            ax.set_title(f"{title}\n(No sample found)")
            ax.axis('off')
            continue

        predictions = rng.normal(case['mc_mean'], case['mc_std'], 1000)

        sns.histplot(predictions, bins=30, kde=True, ax=ax)
        ax.axvline(case['mc_mean'], color='r', linestyle='--', label=f"Mean: {case['mc_mean']:.2f}")
        ax.axvline(0.5, color='k', linestyle=':', label='Threshold: 0.5')

        true_label = "pMCI" if case['label'] == 1 else "sMCI"
        ax.set_title(f"{title}\nSubject: {case['subject_id']}\nTrue Label: {true_label} | Std Dev: {case['mc_std']:.3f}")
        ax.set_xlabel("Predicted Probability (pMCI)")
        ax.legend()

    plt.tight_layout()
    plt.savefig(figures_path / "uncertainty_examples.png")
    plt.show()
else:
    print("Results DataFrame not available for visualization.")


In [None]:
if 'results_df' in locals():
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Distribution of Mean Predictions
    sns.histplot(data=results_df, x='mc_mean', hue='label', bins=30, kde=True, ax=axes[0])
    axes[0].set_title('Distribution of Mean Predictions by True Label')
    axes[0].set_xlabel('Mean Predicted Probability (pMCI)')
    axes[0].set_ylabel('Count')
    threshold_line = axes[0].axvline(0.5, color='k', linestyle=':', label='Threshold: 0.5')
    # seaborn builds its hue legend from proxy handles, so a bare ax.legend() would
    # replace it with a threshold-only legend. Keep the hue legend as an extra
    # artist and add the threshold legend alongside it.
    hue_legend = axes[0].get_legend()
    if hue_legend is not None:
        axes[0].add_artist(hue_legend)
    axes[0].legend(handles=[threshold_line], loc='upper center')

    # Distribution of Standard Deviations (Uncertainty)
    sns.histplot(data=results_df, x='mc_std', hue='label', bins=30, kde=True, ax=axes[1])
    axes[1].set_title('Distribution of Prediction Uncertainty (Std Dev) by True Label')
    axes[1].set_xlabel('Standard Deviation of Predictions')
    axes[1].set_ylabel('Count')

    plt.tight_layout()
    plt.savefig(figures_path / "overall_uncertainty_distributions_by_label.png")
    plt.show()
else:
    print("Results DataFrame not available for overall visualization.")

In [None]:
if 'results_df' not in locals():
    print("Results DataFrame not available for summary statistics.")
else:
    # Mean and spread of the MC mean/std, split by true label
    print("Overall Summary Statistics:")
    per_label_stats = results_df.groupby('label')[['mc_mean', 'mc_std']].agg(['mean', 'std'])
    display(per_label_stats)

    # Overall distribution of the per-subject standard deviations
    print("\nSummary Statistics of Standard Deviations:")
    std_summary = results_df['mc_std'].describe()
    display(std_summary)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

if 'results_df' in locals():
    # Absolute prediction error: distance between the true label and the MC mean
    results_df['prediction_error'] = np.abs(results_df['label'] - results_df['mc_mean'])

    # Pearson correlation between uncertainty (std) and error
    correlation = results_df['mc_std'].corr(results_df['prediction_error'])
    print(f"Correlation between Uncertainty and Prediction Error: {correlation:.4f}")

    # Visualize relationship
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(data=results_df, x='mc_std', y='prediction_error', hue='label', alpha=0.6, ax=ax)
    ax.set_title(f"Prediction Error vs. Uncertainty (Correlation: {correlation:.2f})")
    ax.set_xlabel("Uncertainty (Standard Deviation)")
    ax.set_ylabel("Absolute Prediction Error |Label - Mean|")
    # Retitle seaborn's hue legend in place; plt.legend() would rebuild it from
    # labelled artists only and can lose the hue entries.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title('True Label')
    ax.grid(True, alpha=0.3)
    plt.show()
else:
    print("results_df is not defined. Please run the inference cells first.")

In [None]:
if 'results_df' in locals():
    # Sort by uncertainty, lowest first; a stable sort keeps tied rows in a fixed order
    sorted_df = results_df.sort_values('mc_std', ascending=True, kind='mergesort')
    n_total = len(sorted_df)

    fractions = np.linspace(0.1, 1.0, 20)
    accuracies = []
    retained_ratios = []

    for frac in fractions:
        # Keep the top `frac` most certain subjects
        n_samples = int(n_total * frac)
        if n_samples == 0:
            continue  # too few subjects for this fraction

        subset = sorted_df.iloc[:n_samples]

        # Prediction is 1 if mean probability > 0.5, else 0
        preds = (subset['mc_mean'] > 0.5).astype(float)
        accuracies.append((preds == subset['label']).mean())
        retained_ratios.append(frac)

    if accuracies:
        plt.figure(figsize=(10, 6))
        plt.plot(retained_ratios, accuracies, marker='o', linewidth=2)
        plt.xlabel("Fraction of Data Retained (Low Uncertainty -> High Uncertainty)")
        plt.ylabel("Accuracy on Retained Data")
        plt.title("Accuracy vs. Uncertainty Retention Curve")
        # Inverted x-axis: 1.0 (all data) on the left, the smallest retained fraction on the right
        plt.gca().invert_xaxis()
        plt.grid(True)
        # Point at the actual most-certain subset that was plotted (its x/y pair),
        # not a hard-coded x position.
        first_frac, first_acc = retained_ratios[0], accuracies[0]
        plt.annotate('Keeping only most\ncertain samples', xy=(first_frac, first_acc),
                     xytext=(first_frac + 0.25, first_acc - 0.05),
                     arrowprops=dict(facecolor='black', shrink=0.05))
        plt.show()

        print("Interpretation:\nIf the curve goes UP as you move to the RIGHT (keeping fewer, more certain samples; the x-axis is inverted),\nit means the uncertainty metric is working: the model is more accurate on cases where it is confident.")
    else:
        print("Not enough subjects to compute the retention curve.")
else:
    print("results_df is not defined. Please run the inference cells first.")

In [None]:
if 'results_df' in locals():
    # prediction_error is normally added by the error-vs-uncertainty cell. The saved
    # results pickle is written before that column exists, so recompute it here to
    # keep this cell working right after reloading results.
    if 'prediction_error' not in results_df.columns:
        results_df['prediction_error'] = np.abs(results_df['label'] - results_df['mc_mean'])

    # 1. Identify Misclassifications (Threshold at 0.5)
    results_df['predicted_label'] = (results_df['mc_mean'] > 0.5).astype(float)
    results_df['is_correct'] = results_df['predicted_label'] == results_df['label']

    # Filter for errors only
    errors_df = results_df[~results_df['is_correct']].copy()

    print(f"Total Misclassified Subjects: {len(errors_df)} out of {len(results_df)}")

    report_columns = ['subject_id', 'label', 'mc_mean', 'mc_std', 'prediction_error']

    # 2. Most Uncertain Errors
    # These subjects sit in the high-uncertainty "tail" of the retention curve;
    # including them is what lowers accuracy as more data is kept.
    uncertain_errors = errors_df.sort_values('mc_std', ascending=False)

    print("\n--- Top 10 Most Uncertain Misclassifications (Likely causing the curve drop) ---")
    display(uncertain_errors.head(10)[report_columns])

    # 3. Most Confident Errors
    # Wrong predictions with very consistent MC samples (low std); these lower the
    # accuracy of even the most certain subset.
    confident_errors = errors_df.sort_values('mc_std', ascending=True)

    print("\n--- Top 10 Most Confident Misclassifications (High confidence, wrong prediction) ---")
    display(confident_errors.head(10)[report_columns])
else:
    print("results_df not found. Please run the inference step first.")

In [None]:
if 'results_df' in locals():
    # predicted_label / is_correct are normally added by the misclassification cell;
    # recompute them if that cell has not been run (e.g. after reloading results).
    if 'predicted_label' not in results_df.columns:
        results_df['predicted_label'] = (results_df['mc_mean'] > 0.5).astype(float)
    if 'is_correct' not in results_df.columns:
        results_df['is_correct'] = results_df['predicted_label'] == results_df['label']

    # Strategy: Reject the top 20% most uncertain predictions
    threshold = results_df['mc_std'].quantile(0.80)

    # Split into Accepted (Low Uncertainty) and Rejected (High Uncertainty)
    accepted_df = results_df[results_df['mc_std'] <= threshold]
    rejected_df = results_df[results_df['mc_std'] > threshold]

    # Calculate accuracies
    acc_original = (results_df['predicted_label'] == results_df['label']).mean()
    acc_accepted = (accepted_df['predicted_label'] == accepted_df['label']).mean()

    print(f"Uncertainty Threshold (80th percentile): {threshold:.4f}")
    print(f"Original Accuracy (All {len(results_df)} subjects): {acc_original:.2%}")
    print(f"Filtered Accuracy (Best {len(accepted_df)} subjects): {acc_accepted:.2%}")
    print(f"\n--- Subjects 'Removed' due to High Uncertainty ({len(rejected_df)} subjects) ---")

    # Show the list of subjects to remove, sorted by how uncertain they are
    display(rejected_df.sort_values('mc_std', ascending=False)[['subject_id', 'label', 'mc_mean', 'mc_std', 'is_correct']].head(15))

    # Report the actual error count among rejected subjects instead of assuming it.
    if not rejected_df.empty:
        n_wrong = int((~rejected_df['is_correct'].astype(bool)).sum())
        print(f"\nNote: {n_wrong} of {len(rejected_df)} rejected subjects were misclassified (is_correct == False).")
else:
    print("results_df is needed for this analysis.")

## Plots

In [None]:
# Reload previously saved MC dropout results so the analysis below can run
# without repeating the (slow) inference step.
output_path = Path("/content/drive/MyDrive/Mestrado/TFM/new_pipeline/model_outputs/saved_models/mci_conversion_tau/monte_carlo/")

results_file = output_path / "mci_mc_dropout_results.pkl"

if not results_file.exists():
    print(f"Results file not found at {results_file}. Please run the inference cell above to generate results.")
else:
    print(f"Loading saved MC Dropout results from {results_file}...")
    results_df = pd.read_pickle(results_file)
    print(f"Successfully loaded {len(results_df)} subject predictions.")
    display(results_df.head())

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score

# --- Calculate Optimal Threshold ---
# NOTE: the threshold is tuned on the same pooled validation predictions it is
# scored on, so the reported balanced accuracy is optimistic.
if 'results_df' in locals():
    y_true = results_df['label'].values.astype(int)
    mc_mean = results_df['mc_mean'].values

    thresholds = np.linspace(0, 1, 1001)
    scores = [balanced_accuracy_score(y_true, (mc_mean >= t).astype(int)) for t in thresholds]

    best_idx = int(np.argmax(scores))  # first (lowest) threshold among ties
    THRESH = thresholds[best_idx]
    print(f"Calculated Optimal Threshold (Balanced Accuracy): {THRESH:.4f} (Score: {scores[best_idx]:.4f})")

    # --- Prepare Data for Visualization ---
    df = results_df.copy()

    # SDpercentile in [0,100]; higher percentile = larger std = lower confidence
    df["sd_percentile"] = df["mc_std"].rank(pct=True) * 100.0

    # Confidence Score (CS): higher is better (lower std)
    df["CS"] = 100.0 - df["sd_percentile"]

    # Convenience: distance to threshold
    df["dist_to_thresh"] = (df["mc_mean"] - THRESH).abs()

    display(df.head())
else:
    # Without results there is nothing to visualize; keep THRESH defined for later cells.
    THRESH = 0.31
    print(f"results_df not found. Using default threshold: {THRESH}")

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- assumes df has: mc_mean, mc_std and THRESH ---

# Compute Confidence Score (paper definition) if missing
if "CS" not in df.columns:
    df = df.copy()
    df["sd_percentile"] = df["mc_std"].rank(pct=True) * 100.0
    df["CS"] = 100.0 - df["sd_percentile"]

fig, ax = plt.subplots(figsize=(8, 6))

# MAIN plot: CS vs mean score (show ALL points), coloured by raw uncertainty
sc = ax.scatter(
    df["mc_mean"],
    df["CS"],
    c=df["mc_std"],
    s=22,
    alpha=0.9)

cb = fig.colorbar(sc, ax=ax)
cb.set_label("MC standard deviation")

# Threshold
ax.axvline(THRESH, linestyle="--", linewidth=2)
ax.text(THRESH - 0.02, 98, "No Progress", ha="right", va="top")
ax.text(THRESH + 0.02, 98, "Progress", ha="left", va="top")

# Left axis (CS)
ax.set_xlabel("Model MCI-to-AD progression score (mean of MCD)")
ax.set_ylabel("Confidence Score (CS)")
ax.set_xlim(0, 1)
ax.set_ylim(0, 100)

# SECONDARY axis: raw SD. CS is a percentile rank of mc_std, not a linear function
# of it, so a linear min..max SD axis would not line up with the points. Share the
# CS limits and label each CS tick with the SD quantile it (approximately)
# corresponds to: CS = 100 - percentile, so high CS at the top maps to low SD.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
cs_ticks = np.linspace(0, 100, 6)
sd_at_ticks = df["mc_std"].quantile(1.0 - cs_ticks / 100.0).to_numpy()
ax2.set_yticks(cs_ticks)
ax2.set_yticklabels([f"{v:.3f}" for v in sd_at_ticks])
ax2.set_ylabel("Standard Deviation")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---- Inputs expected ----
# df: DataFrame with subject_id, mc_mean, mc_std (CS optional)
# THRESH: scalar decision threshold (e.g., 0.31)

df2 = df.copy()

# CS (paper-style) if missing
if "CS" not in df2.columns:
    df2["sd_percentile"] = df2["mc_std"].rank(pct=True) * 100.0
    df2["CS"] = 100.0 - df2["sd_percentile"]

df2["dist_to_thresh"] = (df2["mc_mean"] - THRESH).abs()

# --- Select 4 archetypes from the top/bottom 20% by CS ---
hi_cut = df2["CS"].quantile(0.80)
lo_cut = df2["CS"].quantile(0.20)
hi = df2[df2["CS"] >= hi_cut]
lo = df2[df2["CS"] <= lo_cut]

A = hi.nsmallest(1, "mc_mean").iloc[0]                 # far left, high confidence
B = hi.nlargest(1, "mc_mean").iloc[0]                  # far right, high confidence
C = hi.nsmallest(1, "dist_to_thresh").iloc[0]          # near threshold, high confidence
D = lo.nsmallest(1, "dist_to_thresh").iloc[0]          # near threshold, low confidence

picked = [("A", A), ("B", B), ("C", C), ("D", D)]

# ---- Fitted normal curves ----
x = np.linspace(0, 1, 900)

def normal_pdf(x, mu, sigma):
    """Gaussian density N(mu, sigma^2) at x; sigma is floored at 1e-6 to avoid dividing by zero."""
    sigma = max(float(sigma), 1e-6)
    return (1.0 / (sigma * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((x - mu) / sigma)**2)

# np.trapezoid exists only in NumPy >= 2.0; fall back to np.trapz on older versions.
trapezoid_fn = getattr(np, "trapezoid", None) or np.trapz

plt.figure(figsize=(9.5, 4.8))

for tag, row in picked:
    mu = float(row["mc_mean"])
    sd = float(row["mc_std"])
    cs = float(row["CS"])
    sid = row["subject_id"]

    y = normal_pdf(x, mu, sd)
    area = trapezoid_fn(y, x)
    if area > 0:
        # normalized "percent-like" density over [0,1]; skipped when the curve is
        # too narrow for the grid (area 0) to avoid dividing by zero
        y = y / area * 100.0

    # Draw the line first and reuse its colour for the fill so the pair always matches.
    (line,) = plt.plot(x, y, linewidth=2,
                       label=f"{tag} | ID {sid} | Score {mu:.2f} | CS {cs:.0f}")
    plt.fill_between(x, y, alpha=0.25, color=line.get_color())

plt.axvline(THRESH, linestyle="--", linewidth=2, label=f"Threshold {THRESH:.2f}")

plt.xlabel("Model MCI-to-AD score (mean of MCD)")
plt.ylabel("Score density (normalized, %)")
plt.xlim(0, 1)
# NOTE: fixed upper limit; curves with SD below roughly 0.02 peak above 2000 and are clipped.
plt.ylim(0,2000)
plt.legend(loc="upper right", fontsize=9)
plt.tight_layout()
plt.show()

print("Selected A,B,C,D IDs:", [A["subject_id"], B["subject_id"], C["subject_id"], D["subject_id"]])
