# Uncertainty Quantification

This notebook demonstrates uncertainty quantification techniques:
- Monte Carlo Dropout
- Test-Time Augmentation
- Confidence calibration
- Identifying uncertain predictions

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.data.dataset import StrokeDataset
from src.data.augmentation import get_train_augmentation, get_val_augmentation
from src.models.cnn import ResNetClassifier
from src.evaluation.uncertainty import (
    monte_carlo_dropout, test_time_augmentation,
    calculate_confidence_metrics, calibration_curve,
    expected_calibration_error, identify_uncertain_samples
)
from src.utils.helpers import load_config, get_device

%matplotlib inline

## 1. Load Model and Data

In [None]:
# Filesystem locations (relative to the notebook directory) for the trained
# weights, the experiment configuration and the processed data.
checkpoint_path = '../experiments/checkpoints/best_model.pth'
config_path = '../config/default_config.yaml'
data_dir = '../data/processed'

config = load_config(config_path)
device = get_device()

# Rebuild the classifier from the configured architecture. pretrained=False
# because the weights are restored from the checkpoint just below.
model_cfg = config['model']
model = ResNetClassifier(
    arch=model_cfg['architecture'],
    num_classes=model_cfg['num_classes'],
    pretrained=False,
)

# map_location lets a checkpoint saved on another device load onto `device`.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# NOTE(review): the model's train/eval mode is never set explicitly in this
# notebook -- confirm that the uncertainty helpers select the mode they need.
print("Model loaded")

In [None]:
# Validation split, preprocessed with the transforms from get_val_augmentation
val_transform = get_val_augmentation(config['data']['image_size'])

dataset = StrokeDataset(
    data_dir=data_dir,
    split='val',
    split_file='../data/splits/val.json',
    transform=val_transform,
)

# shuffle=False keeps the batch order identical to the dataset order, so
# per-sample results gathered later line up with dataset indices.
dataloader = DataLoader(dataset, shuffle=False, batch_size=16)

print(f"Dataset size: {len(dataset)}")

## 2. Monte Carlo Dropout

In [None]:
# Number of MC dropout iterations to run for the demo batch
n_iterations = 30

# Only the first validation batch is used here; the full set is processed later
images, labels = next(iter(dataloader))

mc_outputs = monte_carlo_dropout(
    model, images, device=device, n_iterations=n_iterations
)
mean_probs, std_probs, all_predictions = mc_outputs

# Report the shape of each returned array
print(f"Mean probabilities shape: {mean_probs.shape}")
print(f"Std probabilities shape: {std_probs.shape}")
print(f"All predictions shape: {all_predictions.shape}")

In [None]:
# Inspect how the MC dropout iterations disagree for one sample of the batch
sample_idx = 0
sample_predictions = all_predictions[:, sample_idx, :]  # (n_iterations, n_classes)
mean_confidence = mean_probs[sample_idx].max()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
dist_ax, conf_ax = axes

# Left panel: histogram of each class probability across the iterations
for class_col, class_name in enumerate(('CE', 'LAA')):
    dist_ax.hist(sample_predictions[:, class_col], bins=20, alpha=0.7, label=class_name)
dist_ax.set_xlabel('Probability')
dist_ax.set_ylabel('Frequency')
dist_ax.set_title('MC Dropout Prediction Distribution')

# Right panel: top-class probability per iteration, with the confidence of
# the averaged prediction drawn as a dashed reference line
max_probs = sample_predictions.max(axis=1)
conf_ax.plot(max_probs)
conf_ax.axhline(mean_confidence, color='r', linestyle='--', label='Mean')
conf_ax.set_xlabel('Iteration')
conf_ax.set_ylabel('Max Probability')
conf_ax.set_title('Confidence Across MC Dropout Iterations')

# Shared cosmetics for both panels
for panel in axes:
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Mean confidence: {mean_confidence:.4f}")
print(f"Std: {std_probs[sample_idx].max():.4f}")

## 3. Confidence Metrics

In [None]:
# Run MC dropout over every validation batch and collect the per-batch outputs
all_mean_probs, all_std_probs, all_predictions_mc, all_labels = [], [], [], []

for images, labels in dataloader:
    batch_mean, batch_std, batch_preds = monte_carlo_dropout(
        model, images, device=device, n_iterations=30
    )

    all_mean_probs.append(batch_mean)
    all_std_probs.append(batch_std)
    all_predictions_mc.append(batch_preds)
    all_labels.append(labels.numpy())

# Stitch the batches back together. The raw per-iteration predictions are
# joined along axis 1 because axis 0 indexes the MC dropout iteration.
mean_probs = np.concatenate(all_mean_probs, axis=0)
std_probs = np.concatenate(all_std_probs, axis=0)
predictions_mc = np.concatenate(all_predictions_mc, axis=1)
labels = np.concatenate(all_labels, axis=0)

# Summary statistics computed by the project helper
metrics = calculate_confidence_metrics(mean_probs, predictions_mc)

print("Confidence Metrics:")
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 4. Calibration Analysis

In [None]:
# Top-class prediction and its probability for every validation sample
predictions = mean_probs.argmax(axis=1)
confidences = mean_probs.max(axis=1)

# 1 where the predicted class equals the label, 0 otherwise. This correctness
# indicator (not the raw label) is what the calibration helpers receive.
y_true_binary = (predictions == labels).astype(int)

bin_confs, bin_accs = calibration_curve(y_true_binary, confidences, n_bins=10)
ece = expected_calibration_error(y_true_binary, confidences)

# Reliability diagram: the dashed diagonal marks perfect calibration
fig, ax = plt.subplots(figsize=(8, 8))
ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
ax.plot(bin_confs, bin_accs, 'o-', label='Model', markersize=10)
ax.set_xlabel('Confidence', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Calibration Curve (Reliability Diagram)', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Annotate the ECE in the top-left corner (axes-relative coordinates)
ece_box = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.5}
ax.text(
    0.05, 0.95, f'ECE: {ece:.4f}',
    transform=ax.transAxes,
    fontsize=12,
    verticalalignment='top',
    bbox=ece_box,
)

plt.tight_layout()
plt.show()

print(f"Expected Calibration Error: {ece:.4f}")

## 5. Identify Uncertain Samples

In [None]:
# Flag samples the project helper considers uncertain, using the confidence
# and entropy thresholds below. The result is used as a boolean mask over the
# validation samples (it is summed, negated and used for indexing).
uncertain = identify_uncertain_samples(
    mean_probs, predictions_mc,
    threshold_confidence=0.7,
    threshold_entropy=0.5
)

n_total = len(uncertain)
n_uncertain = uncertain.sum()
# Guard the percentage against an empty result instead of dividing by zero
pct_uncertain = 100 * n_uncertain / n_total if n_total else 0.0

print(f"Uncertain samples: {n_uncertain} / {n_total} ({pct_uncertain:.1f}%)")

# Compare accuracy on certain vs uncertain
certain_mask = ~uncertain
is_correct = predictions == labels

# Either subset can be empty (nothing flagged, or everything flagged). Taking
# the mean of an empty selection yields NaN plus a RuntimeWarning, so report
# NaN explicitly and tell the reader which subset had no samples.
certain_acc = is_correct[certain_mask].mean() if certain_mask.any() else float('nan')
uncertain_acc = is_correct[uncertain].mean() if uncertain.any() else float('nan')

print(f"\nAccuracy on certain samples: {certain_acc:.4f}")
print(f"Accuracy on uncertain samples: {uncertain_acc:.4f}")

if not certain_mask.any():
    print("Note: no samples were classified as certain; accuracy is undefined for that subset.")
if not uncertain.any():
    print("Note: no samples were flagged as uncertain; accuracy is undefined for that subset.")

## 6. Confidence vs Correctness

In [None]:
# Split the validation samples by whether the top-class prediction was right
correct_mask = predictions == labels
incorrect_mask = ~correct_mask

# (legend label, sample mask, y-level used in the scatter panel)
groups = (
    ('Correct', correct_mask, 1.0),
    ('Incorrect', incorrect_mask, 0.0),
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
hist_ax, scatter_ax = axes

for group_label, group_mask, y_level in groups:
    group_conf = confidences[group_mask]
    # Left panel: normalised confidence histogram for this group
    hist_ax.hist(group_conf, bins=30, alpha=0.7, label=group_label, density=True)
    # Right panel: each sample's confidence at y=1 (correct) or y=0 (incorrect)
    scatter_ax.scatter(group_conf, np.full(group_mask.sum(), y_level),
                       alpha=0.5, label=group_label, s=20)

hist_ax.set_xlabel('Confidence', fontsize=12)
hist_ax.set_ylabel('Density', fontsize=12)
hist_ax.set_title('Confidence Distribution', fontsize=14, fontweight='bold')

scatter_ax.set_xlabel('Confidence', fontsize=12)
scatter_ax.set_ylabel('Correctness', fontsize=12)
scatter_ax.set_title('Confidence vs Correctness', fontsize=14, fontweight='bold')

# Shared cosmetics for both panels
for panel in axes:
    panel.legend(fontsize=10)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Test-Time Augmentation (Sample)

In [None]:
# TTA on a single sample (computationally expensive, so not run on the full set)
# Use the configured image size rather than a hard-coded 224 so the augmented
# inputs match the resolution used for the validation transforms above.
augmentation_fn = get_train_augmentation(
    image_size=config['data']['image_size'], p=1.0
)

# dataset[0] has already been passed through the validation transform.
# NOTE(review): confirm that augmentation_fn accepts that already-transformed
# sample as input.
sample_image, sample_label = dataset[0]
sample_image_input = sample_image.unsqueeze(0)  # add a leading batch dimension

tta_mean, tta_std = test_time_augmentation(
    model, sample_image_input, augmentation_fn,
    n_augmentations=20, device=device
)

print("Test-Time Augmentation Results:")
print(f"Mean prediction: {tta_mean[0]}")
print(f"Std: {tta_std[0]}")
print(f"Predicted class: {tta_mean[0].argmax()}")
print(f"True class: {sample_label}")

## 8. Clinical Recommendations

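**Based on the uncertainty analysis** (the confidence thresholds below are illustrative and are only meaningful if the model is reasonably well calibrated, i.e. the ECE from Section 4 is low):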

1. **High Confidence Predictions** (>0.9): Can be used with minimal human oversight
2. **Medium Confidence** (0.7-0.9): Recommend expert review
3. **Low Confidence** (<0.7): Requires expert diagnosis
4. **Uncertain Samples** (those flagged in Section 5 using the confidence and entropy thresholds): Flag for manual review or additional testing

**Key Insights:**
- Model calibration (ECE) indicates how well confidence matches accuracy
- MC dropout reveals epistemic uncertainty (model uncertainty)
- TTA reveals aleatoric uncertainty (data uncertainty)
- Uncertain predictions often indicate ambiguous cases