# Model Interpretability Analysis

This notebook applies several interpretability techniques to understand the stroke classifier's predictions:
- Grad-CAM and Grad-CAM++
- Feature space visualization (t-SNE and PCA; UMAP is not yet covered in this notebook)
- Guided backpropagation (planned; not yet implemented in this notebook)
- Feature separability analysis

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.data.dataset import StrokeDataset
from src.data.augmentation import get_val_augmentation
from src.models.cnn import ResNetClassifier
from src.visualization.gradcam import GradCAM, GradCAMPlusPlus, get_target_layer, overlay_heatmap_on_image
from src.visualization.features import extract_features, compute_tsne, compute_pca, plot_embedding
from src.utils.helpers import load_config, get_device

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Load Model and Data

In [None]:
# Configuration: every path a reader might need to change lives here.
# Paths are relative to the working directory (the notebook's folder when
# launched from Jupyter).
checkpoint_path = '../experiments/checkpoints/best_model.pth'
config_path = '../config/default_config.yaml'
data_dir = '../data/processed'

# Parse the experiment config first, then pick the compute device.
config, device = load_config(config_path), get_device()

print("Device:", device)

In [None]:
# Rebuild the architecture described in the config, then restore trained weights.
model_cfg = config['model']
model = ResNetClassifier(
    arch=model_cfg['architecture'],
    num_classes=model_cfg['num_classes'],
    pretrained=False,  # weights come from the checkpoint below, not a pretrained download
)

# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from trusted sources (consider weights_only=True if the file contains only tensors).
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)
model = model.to(device)
# Inference mode: disables dropout and uses running batch-norm statistics, if present.
model.eval()

print("Model loaded successfully")

In [None]:
# Validation split, preprocessed with the validation-time transform at the configured image size.
val_split_file = '../data/splits/val.json'
val_transform = get_val_augmentation(config['data']['image_size'])

dataset = StrokeDataset(
    data_dir=data_dir,
    split='val',
    split_file=val_split_file,
    transform=val_transform,
)

print("Dataset size:", len(dataset))

## 2. Grad-CAM Visualization

In [None]:
# Grad-CAM for one randomly chosen validation sample.
# The draw is seeded so Restart & Run All shows the same sample every time.
SAMPLE_SEED = 42
rng = np.random.default_rng(SAMPLE_SEED)

target_layer = get_target_layer(model, 'resnet')
gradcam = GradCAM(model, target_layer)

# Select a (reproducible) random sample
idx = int(rng.integers(len(dataset)))
image, true_label = dataset[idx]
true_label = int(true_label)  # the dataset may return the label as a tensor

# Grad-CAM is gradient-based, so it is computed outside torch.no_grad().
image_input = image.unsqueeze(0).to(device)
cam = gradcam.generate_cam(image_input)

# Get prediction
with torch.no_grad():
    output = model(image_input)
    pred_label = output.argmax(dim=1).item()
    confidence = torch.softmax(output, dim=1)[0, pred_label].item()

# ImageNet mean/std, used to undo the input normalization for display.
# Defined unconditionally because later cells also read `mean` / `std`.
# NOTE(review): assumes get_val_augmentation normalizes with these stats — confirm.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Prepare image for visualization (assumes a CHW tensor — TODO confirm channel count)
img_display = image.permute(1, 2, 0).cpu().numpy()
# A mean/std-normalized tensor has values outside [0, 1] (e.g. negatives);
# a plain [0, 1] tensor is displayed as-is.
if img_display.min() < 0 or img_display.max() > 1.0:
    img_display = std * img_display + mean
# Always clip before the uint8 cast: out-of-range floats would otherwise wrap around.
img_display = (np.clip(img_display, 0, 1) * 255).astype(np.uint8)

# Overlay
overlayed = overlay_heatmap_on_image(img_display, cam, alpha=0.4)

# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
class_names = ['CE', 'LAA']

axes[0].imshow(img_display)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(cam, cmap='jet')
axes[1].set_title('Grad-CAM Heatmap')
axes[1].axis('off')

axes[2].imshow(overlayed)
axes[2].set_title('Overlayed')
axes[2].axis('off')

color = 'green' if pred_label == true_label else 'red'
title = f"True: {class_names[true_label]} | Pred: {class_names[pred_label]} ({confidence:.2%})"
fig.suptitle(title, fontsize=14, fontweight='bold', color=color)
plt.tight_layout()
plt.show()

## 3. Grad-CAM++ Comparison

In [None]:
# Side-by-side comparison of Grad-CAM and Grad-CAM++ on the sample chosen above.
gradcam_pp = GradCAMPlusPlus(model, target_layer)

# Both heatmaps are computed for the same input tensor.
cam_regular = gradcam.generate_cam(image_input)
cam_plus = gradcam_pp.generate_cam(image_input)

# (panel title, image to show) for each of the three columns.
panels = [
    ('Original', img_display),
    ('Grad-CAM', overlay_heatmap_on_image(img_display, cam_regular, alpha=0.4)),
    ('Grad-CAM++', overlay_heatmap_on_image(img_display, cam_plus, alpha=0.4)),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (panel_title, panel_image) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')

plt.suptitle('Grad-CAM vs Grad-CAM++', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Feature Space Visualization

In [None]:
from torch.utils.data import DataLoader

# Walk the validation set in dataset order (no shuffling) and collect the
# model's features together with their labels.
dataloader = DataLoader(dataset, shuffle=False, batch_size=32)
features, labels = extract_features(model, dataloader, device)

print("Feature shape:", features.shape)

In [None]:
# t-SNE embedding of the extracted features, plotted together with their labels.
# NOTE(review): t-SNE is stochastic — confirm compute_tsne fixes a random_state
# so this plot is reproducible across runs.
tsne_embedded = compute_tsne(features)
tsne_title = "t-SNE Feature Embedding"
plot_embedding(tsne_embedded, labels, title=tsne_title)

In [None]:
# PCA projection of the same features, as a linear counterpart to the t-SNE view.
pca_title = "PCA Feature Embedding"
pca_embedded = compute_pca(features)
plot_embedding(pca_embedded, labels, title=pca_title)

## 5. Feature Separability Analysis

In [None]:
from src.visualization.features import analyze_feature_separability

# Quantify how well the classes separate in feature space; the helper returns
# a mapping of metric name -> numeric score.
separability = analyze_feature_separability(features, labels)

print("Feature Separability Metrics:")
for metric_name, metric_value in separability.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 6. Multiple Sample Visualization

In [None]:
# Grad-CAM for several random validation samples, one row per sample:
# original image | raw Grad-CAM heatmap | heatmap overlaid on the image.
# Loop variables use a `sample_` prefix so this cell does not overwrite
# `image_input` / `img_display` / `cam`, which the Grad-CAM++ cell reuses.
GRID_SEED = 42
assert len(dataset) > 0, "validation dataset is empty"
n_samples = min(6, len(dataset))  # cannot draw more unique samples than exist
rng = np.random.default_rng(GRID_SEED)
indices = rng.choice(len(dataset), size=n_samples, replace=False)

# ImageNet mean/std, used to undo the input normalization for display.
# Defined here so this cell does not depend on a branch of an earlier cell.
# NOTE(review): assumes get_val_augmentation normalizes with these stats — confirm.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# squeeze=False keeps `axes` 2-D even when n_samples == 1.
fig, axes = plt.subplots(n_samples, 3, figsize=(12, 4 * n_samples), squeeze=False)

for row, sample_idx in enumerate(indices):
    sample_image, sample_label = dataset[int(sample_idx)]
    sample_label = int(sample_label)  # the dataset may return the label as a tensor
    sample_input = sample_image.unsqueeze(0).to(device)

    # Grad-CAM is gradient-based, so it is computed outside torch.no_grad().
    sample_cam = gradcam.generate_cam(sample_input)

    with torch.no_grad():
        sample_logits = model(sample_input)
        sample_pred = sample_logits.argmax(dim=1).item()
        sample_conf = torch.softmax(sample_logits, dim=1)[0, sample_pred].item()

    # Convert to HWC uint8 for display (assumes a CHW tensor — TODO confirm channel count).
    sample_display = sample_image.permute(1, 2, 0).cpu().numpy()
    # A mean/std-normalized tensor has values outside [0, 1]; a plain [0, 1]
    # tensor is displayed as-is.
    if sample_display.min() < 0 or sample_display.max() > 1.0:
        sample_display = std * sample_display + mean
    # Always clip before the uint8 cast: out-of-range floats would otherwise wrap around.
    sample_display = (np.clip(sample_display, 0, 1) * 255).astype(np.uint8)

    sample_overlay = overlay_heatmap_on_image(sample_display, sample_cam, alpha=0.4)

    row_axes = axes[row]
    row_axes[0].imshow(sample_display)
    row_axes[1].imshow(sample_cam, cmap='jet')
    row_axes[2].imshow(sample_overlay)
    for ax in row_axes:
        # Hide ticks and frame but keep the axes on: axis('off') would also
        # hide the y-label that carries the per-row prediction summary.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Row annotation; `class_names` comes from the single-sample Grad-CAM cell.
    color = 'green' if sample_pred == sample_label else 'red'
    row_axes[0].set_ylabel(
        f"True: {class_names[sample_label]}\nPred: {class_names[sample_pred]} ({sample_conf:.2%})",
        fontsize=10,
        color=color,
        fontweight='bold'
    )

axes[0, 0].set_title('Original', fontsize=12)
axes[0, 1].set_title('Grad-CAM', fontsize=12)
axes[0, 2].set_title('Overlay', fontsize=12)

plt.tight_layout()
plt.show()

## 7. Clinical Interpretation

**Questions to consider:**
- Does the model focus on clinically relevant regions?
- Are the attention patterns consistent with medical knowledge?
- Do misclassifications show attention on confounding features?
- Is there a pattern in which samples have low confidence?