# PANDA-PLUS-Bench: Evaluating WSI-Specific Feature Collapse

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dellacortelab/PANDA-PLUS-Bench/blob/main/PANDA_PLUS_Bench_Evaluation.ipynb)

This notebook evaluates pathology foundation models on PANDA-PLUS-Bench, a benchmark that measures WSI-specific feature collapse (WSI = whole-slide image).

**What this notebook does:**
1. Loads pre-augmented patches from the benchmark dataset
2. Extracts embeddings using your foundation model (Phikon by default)
3. Computes robustness metrics (within-slide vs cross-slide accuracy, etc.)
4. Compares your model against published results from 7 foundation models
5. Generates publication-quality visualizations

**Citation:**
```
Ebbert J, Della Corte D. PANDA-PLUS-Bench: A Benchmark for Evaluating 
WSI-Specific Feature Collapse in Pathology Foundation Models. 2025.
```

---

## Step 0: Setup and Installation

Run this cell first to install required packages (~2-3 minutes).

In [None]:
# Install required packages
!pip install -q datasets huggingface_hub transformers timm torch torchvision
!pip install -q scikit-learn scipy matplotlib seaborn tqdm pandas numpy

import sys
import warnings
warnings.filterwarnings('ignore')

# Verify GPU availability
import torch
if torch.cuda.is_available():
    print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    DEVICE = torch.device('cuda')
else:
    print("⚠ No GPU detected. Embedding extraction will be slow.")
    print("  Go to Runtime → Change runtime type → GPU")
    DEVICE = torch.device('cpu')

print("\n✓ Setup complete!")

## Step 1: Configuration

**Configure your evaluation here:**
- Set your HuggingFace token (if using gated models like UNI, Virchow)
- Choose which foundation model to evaluate
- Select augmentation conditions to test

In [None]:
#@title Configuration {display-mode: "form"}

#@markdown ### HuggingFace Token (optional)
#@markdown Required only for gated models (UNI, Virchow, etc.)
#@markdown Get your token at: https://huggingface.co/settings/tokens
HF_TOKEN = ""  #@param {type:"string"}

#@markdown ### Model Selection
#@markdown Choose a model to evaluate, or enter a custom HuggingFace model ID
MODEL_CHOICE = "owkin/phikon"  #@param ["owkin/phikon", "owkin/phikon-v2", "MahmoodLab/UNI", "paige-ai/Virchow", "paige-ai/Virchow2", "custom"]
CUSTOM_MODEL_ID = ""  #@param {type:"string"}

#@markdown ### Augmentation Conditions
#@markdown Select which augmentation conditions to evaluate
EVAL_BASELINE = True  #@param {type:"boolean"}
EVAL_COLOR_JITTER = False  #@param {type:"boolean"}
EVAL_GRAYSCALE = False  #@param {type:"boolean"}
EVAL_GAUSSIAN_NOISE = False  #@param {type:"boolean"}
EVAL_HEAVY_GEOMETRIC = False  #@param {type:"boolean"}
EVAL_COMBINED_AGGRESSIVE = True  #@param {type:"boolean"}
EVAL_MACENKO = False  #@param {type:"boolean"}
EVAL_HED = False  #@param {type:"boolean"}

#@markdown ### Evaluation Settings
N_NEIGHBORS = 5  #@param {type:"integer"}
BATCH_SIZE = 32  #@param {type:"integer"}

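# Process configuration
# Fail early if "custom" is selected without a model ID; otherwise loading fails later with an unclear error
if MODEL_CHOICE == "custom" and not CUSTOM_MODEL_ID.strip():
    raise ValueError("MODEL_CHOICE is 'custom' but CUSTOM_MODEL_ID is empty.")
MODEL_ID = CUSTOM_MODEL_ID.strip() if MODEL_CHOICE == "custom" else MODEL_CHOICE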

AUGMENTATIONS_TO_EVAL = []
if EVAL_BASELINE: AUGMENTATIONS_TO_EVAL.append("baseline")
if EVAL_COLOR_JITTER: AUGMENTATIONS_TO_EVAL.append("color_jitter")
if EVAL_GRAYSCALE: AUGMENTATIONS_TO_EVAL.append("grayscale")
if EVAL_GAUSSIAN_NOISE: AUGMENTATIONS_TO_EVAL.append("gaussian_noise")
if EVAL_HEAVY_GEOMETRIC: AUGMENTATIONS_TO_EVAL.append("heavy_geometric")
if EVAL_COMBINED_AGGRESSIVE: AUGMENTATIONS_TO_EVAL.append("combined_aggressive")
if EVAL_MACENKO: AUGMENTATIONS_TO_EVAL.append("macenko_normalization")
if EVAL_HED: AUGMENTATIONS_TO_EVAL.append("hed_stain_augmentation")

print("Configuration Summary:")
print(f"  Model: {MODEL_ID}")
print(f"  Augmentations: {AUGMENTATIONS_TO_EVAL}")
print(f"  HF Token: {'Set' if HF_TOKEN else 'Not set (using public models only)'}")
print(f"  Device: {DEVICE}")

## Step 2: Load Benchmark Dataset

Downloads patches from HuggingFace Hub (~2-5 GB depending on augmentations selected).

In [None]:
from datasets import load_dataset
import json
import numpy as np

# Dataset repository
DATASET_REPO = "dellacorte/PANDA-PLUS-Bench"

print("Loading PANDA-PLUS-Bench dataset...")
print("(First run will download ~2-5 GB, subsequent runs use cache)\n")

# Load the dataset
try:
    dataset = load_dataset(DATASET_REPO, token=HF_TOKEN if HF_TOKEN else None)
    print(f"✓ Dataset loaded successfully!")
    print(f"  Available splits: {list(dataset.keys())}")
    
    # Show dataset info
    sample_split = list(dataset.keys())[0]
    print(f"  Total patches in '{sample_split}': {len(dataset[sample_split])}")
    print(f"  Features: {dataset[sample_split].features}")
    
except Exception as e:
    print(f"✗ Error loading dataset: {e}")
    print("\nTroubleshooting:")
    print("  1. Check that the dataset exists at the specified repo")
    print("  2. If private, ensure HF_TOKEN is set correctly")
    raise

# Load paper results for comparison
print("\nLoading published benchmark results...")
try:
    # Load from GitHub repo
    import urllib.request
    paper_results_url = "https://raw.githubusercontent.com/dellacortelab/PANDA-PLUS-Bench/main/paper_results.json"
    with urllib.request.urlopen(paper_results_url) as response:
        PAPER_RESULTS = json.loads(response.read().decode())
    print(f"✓ Loaded results for {len(PAPER_RESULTS)} models from paper")
    print(f"  Models: {list(PAPER_RESULTS.keys())}")
except Exception as e:
    print(f"⚠ Paper results not found: {e}")
    print("  Comparison will be skipped.")
    PAPER_RESULTS = {}

## Step 3: Load Foundation Model

Loads the selected foundation model for embedding extraction.

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoImageProcessor
import timm
from torchvision import transforms

def load_foundation_model(model_id, hf_token=None):
    """
    Load a pathology foundation model.
    
    Supports:
    - Phikon / Phikon-v2 (owkin)
    - UNI / UNI2 (MahmoodLab)
    - Virchow / Virchow2 (paige-ai)
    - Any timm-compatible model
    - Any HuggingFace transformers model
    """
    print(f"Loading model: {model_id}")
    
    model = None
    processor = None
    embed_dim = None
    
    # Try HuggingFace transformers first
    try:
        model = AutoModel.from_pretrained(
            model_id, 
            token=hf_token,
            trust_remote_code=True  # runs code from the model repo; only use model IDs you trust
        )
        
        try:
            processor = AutoImageProcessor.from_pretrained(
                model_id, 
                token=hf_token
            )
        except Exception:
            # No image processor could be loaded; the default transform is used instead
            processor = None
        
        # Get embedding dimension
        if hasattr(model.config, 'hidden_size'):
            embed_dim = model.config.hidden_size
        elif hasattr(model, 'embed_dim'):
            embed_dim = model.embed_dim
        else:
            embed_dim = None
            
        print(f"  ✓ Loaded via HuggingFace transformers")
        
    except Exception as e:
        print(f"  HuggingFace loading failed: {e}")
        print(f"  Trying timm...")
        
        # Try timm
        try:
            model = timm.create_model(model_id, pretrained=True, num_classes=0)
            embed_dim = model.num_features
            processor = None
            print(f"  ✓ Loaded via timm")
        except Exception as e2:
            # Chain the original exception so the full traceback stays available
            raise ValueError(f"Could not load model '{model_id}' via HuggingFace or timm: {e2}") from e2
    
    model = model.to(DEVICE)
    model.eval()
    
    # Default transform if no processor
    if processor is None:
        default_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    else:
        default_transform = None
    
    print(f"  Embedding dimension: {embed_dim}")
    print(f"  Device: {DEVICE}")
    
    return model, processor, default_transform, embed_dim


# Load the model
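# Pass None rather than an empty string when no token is set (matches the dataset loading step)
model, processor, transform, embed_dim = load_foundation_model(MODEL_ID, HF_TOKEN or None)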
print("\n✓ Model ready for embedding extraction!")

## Step 4: Extract Embeddings

Extracts embeddings for all patches in selected augmentation conditions.

**Expected time:** ~5-15 minutes per augmentation on Colab GPU (T4)
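
When a model returns token-level outputs, the extraction code in this step takes the first token of `last_hidden_state` (the CLS token) as the patch embedding. The sketch below shows that indexing on a dummy NumPy array. The shapes (4 patches, 197 tokens, 768 dimensions) are assumptions chosen for illustration, and the mean-pooling line is included only for contrast; it is not what this notebook uses.

```python
import numpy as np

# Dummy transformer output with shape (batch, tokens, dim). Shapes are illustrative only.
rng = np.random.default_rng(0)
last_hidden_state = rng.normal(size=(4, 197, 768))

# CLS-token embedding: the first token of each sequence (the choice made in this notebook)
cls_embeddings = last_hidden_state[:, 0, :]

# Alternative shown for contrast: mean over the remaining patch tokens
mean_embeddings = last_hidden_state[:, 1:, :].mean(axis=1)

print(cls_embeddings.shape, mean_embeddings.shape)
```

Both variants yield one vector per patch, so the downstream kNN metrics accept either; which one works better depends on the model and is not specified here.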

In [None]:
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from PIL import Image
import numpy as np

class PatchDataset(Dataset):
    """Dataset wrapper for benchmark patches."""
    def __init__(self, hf_dataset, processor=None, transform=None):
        self.dataset = hf_dataset
        self.processor = processor
        self.transform = transform
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Get image (handle different formats)
        if 'image' in item:
            image = item['image']
        elif 'patch' in item:
            image = item['patch']
        else:
            raise KeyError(f"No image field found. Available: {item.keys()}")
        
        # Convert to PIL if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        
        # Apply preprocessing
        if self.processor is not None:
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].squeeze(0)
        elif self.transform is not None:
            pixel_values = self.transform(image)
        else:
            raise ValueError("No processor or transform provided")
        
        return {
            'pixel_values': pixel_values,
            'label': item.get('label', -1),
            'slide_id': item.get('slide_id', 'unknown')
        }


@torch.no_grad()
def extract_embeddings(model, dataloader, device):
    """
    Extract embeddings from all patches.
    
    Returns:
        embeddings: numpy array of shape (n_patches, embed_dim)
        labels: numpy array of shape (n_patches,)
        slide_ids: numpy array of shape (n_patches,)
    """
    model.eval()
    
    all_embeddings = []
    all_labels = []
    all_slide_ids = []
    
    for batch in tqdm(dataloader, desc="Extracting embeddings"):
        pixel_values = batch['pixel_values'].to(device)
        
        # Forward pass
        outputs = model(pixel_values)
        
        # Extract embeddings (handle different output formats)
        if hasattr(outputs, 'last_hidden_state'):
            embeddings = outputs.last_hidden_state[:, 0, :]
        elif hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            embeddings = outputs.pooler_output
        elif isinstance(outputs, torch.Tensor):
            embeddings = outputs
        else:
            raise ValueError(f"Unknown output format: {type(outputs)}")
        
        all_embeddings.append(embeddings.cpu().numpy())
        all_labels.extend(batch['label'].numpy())
        all_slide_ids.extend(batch['slide_id'])
    
    embeddings = np.vstack(all_embeddings)
    labels = np.array(all_labels)
    slide_ids = np.array(all_slide_ids)
    
    return embeddings, labels, slide_ids


# Extract embeddings for each augmentation
print(f"\nExtracting embeddings for {len(AUGMENTATIONS_TO_EVAL)} augmentation(s)...\n")

all_results = {}

for aug in AUGMENTATIONS_TO_EVAL:
    print(f"\n{'='*60}")
    print(f"Processing: {aug}")
    print(f"{'='*60}")
    
    if aug not in dataset:
        print(f"  ⚠ Augmentation '{aug}' not found in dataset. Skipping.")
        continue
    
    patch_dataset = PatchDataset(
        dataset[aug], 
        processor=processor, 
        transform=transform
    )
    
    dataloader = DataLoader(
        patch_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=(DEVICE.type == 'cuda')  # pinned memory only helps for GPU transfers
    )
    
    print(f"  Patches: {len(patch_dataset)}")
    print(f"  Batches: {len(dataloader)}")
    
    embeddings, labels, slide_ids = extract_embeddings(model, dataloader, DEVICE)
    
    all_results[aug] = {
        'embeddings': embeddings,
        'labels': labels,
        'slide_ids': slide_ids
    }
    
    print(f"  ✓ Embeddings shape: {embeddings.shape}")
    print(f"  ✓ Unique slides: {len(np.unique(slide_ids))}")
    print(f"  ✓ Unique labels: {np.unique(labels)}")

print(f"\n\n✓ Embedding extraction complete!")
print(f"  Processed {len(all_results)} augmentation(s)")

## Step 5: Compute Benchmark Metrics

Computes the accuracy-based PANDA-PLUS-Bench metrics:
- **Within-slide accuracy**: kNN classification with training and test patches drawn from the same slide
- **Cross-slide accuracy**: Leave-one-slide-out kNN classification
- **Accuracy gap**: Within-slide minus cross-slide accuracy (higher = more collapse)

The benchmark also covers the following metrics, which the cell below does not compute:
- **Silhouette scores**: Cluster quality by class and by slide
- **Confusion attribution**: Entropy of kNN label distribution
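
The silhouette and confusion-attribution metrics can be sketched in plain NumPy as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the function names `silhouette` and `knn_label_entropy` are hypothetical, Euclidean distance is assumed, and "confusion attribution" is interpreted here as the mean Shannon entropy of class labels among each patch's k nearest neighbors. In the notebook itself, `sklearn.metrics.silhouette_score` would be the usual choice for the silhouette part.

```python
import numpy as np

def silhouette(embeddings, groups):
    """Mean silhouette coefficient (Euclidean). Needs at least 2 groups; O(N^2) memory."""
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    uniq = np.unique(groups)
    scores = np.zeros(len(groups))
    for i in range(len(groups)):
        same = groups == groups[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton groups score 0 by convention
        # Mean distance to own group (self-distance is 0, so divide by n_same - 1)
        a = dists[i, same].sum() / (n_same - 1)
        # Smallest mean distance to any other group
        b = min(dists[i, groups == g].mean() for g in uniq if g != groups[i])
        scores[i] = (b - a) / max(a, b)
    return float(scores.mean())

def knn_label_entropy(embeddings, labels, k=5):
    """Mean Shannon entropy (nats) of class labels among each patch's k nearest neighbors."""
    sq = (embeddings ** 2).sum(axis=1)
    dists = sq[:, None] + sq[None, :] - 2.0 * embeddings @ embeddings.T
    np.fill_diagonal(dists, np.inf)  # a patch is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]
    entropies = []
    for row in labels[neighbors]:
        _, counts = np.unique(row, return_counts=True)
        p = counts / counts.sum()
        entropies.append(-(p * np.log(p)).sum())
    return float(np.mean(entropies))

# Synthetic example: the class label drives the embedding structure, the slide does not
rng = np.random.default_rng(0)
n = 40
labels = np.repeat([0, 1], n // 2)          # two classes
slide_ids = np.tile(["s1", "s2"], n // 2)   # two slides, interleaved across classes
emb = rng.normal(scale=0.1, size=(n, 8))
emb[:, 0] += 5.0 * labels

print("silhouette by class:", silhouette(emb, labels))
print("silhouette by slide:", silhouette(emb, slide_ids))
print("kNN label entropy:", knn_label_entropy(emb, labels, k=5))
```

On this synthetic data the class silhouette should be high and the slide silhouette close to zero, which is the pattern expected from embeddings that encode tissue class rather than slide identity. A model suffering from WSI-specific collapse would show the opposite tendency.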

In [None]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import silhouette_score
from scipy.stats import entropy

def compute_within_cross_accuracy(embeddings, labels, slide_ids, n_neighbors=5, seed=0):
    """Compute within-slide and cross-slide kNN classification accuracy.
    
    Within-slide: random train/test split (about 80/20) inside each slide.
    Cross-slide: leave-one-slide-out.
    `seed` fixes the within-slide splits so repeated runs give the same result.
    """
    rng = np.random.default_rng(seed)
    unique_slides = np.unique(slide_ids)
    within_scores = []
    cross_scores = []
    
    # Within-slide accuracy
    for slide in unique_slides:
        slide_mask = slide_ids == slide
        slide_embeddings = embeddings[slide_mask]
        slide_labels = labels[slide_mask]
        
        # Skip slides with too few patches for a kNN split
        if len(slide_embeddings) < n_neighbors + 1:
            continue
        
        n_test = max(1, len(slide_embeddings) // 5)
        indices = rng.permutation(len(slide_embeddings))
        train_idx, test_idx = indices[:-n_test], indices[-n_test:]
        
        # Cap k by the training set size, but never let it drop below 1
        k = max(1, min(n_neighbors, len(train_idx) - 1))
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(slide_embeddings[train_idx], slide_labels[train_idx])
        within_scores.append(knn.score(slide_embeddings[test_idx], slide_labels[test_idx]))
    
    # Cross-slide accuracy
    logo = LeaveOneGroupOut()
    for train_idx, test_idx in logo.split(embeddings, labels, slide_ids):
        X_train, X_test = embeddings[train_idx], embeddings[test_idx]
        y_train, y_test = labels[train_idx], labels[test_idx]
        
        knn = KNeighborsClassifier(n_neighbors=n_neighbors)
        knn.fit(X_train, y_train)
        cross_scores.append(knn.score(X_test, y_test))
    
    return np.mean(within_scores), np.mean(cross_scores)

# Compute metrics for each augmentation
print("\nComputing metrics...\n")
metrics = {}

for aug, results in all_results.items():
    print(f"Computing metrics for {aug}...")
    
    within_acc, cross_acc = compute_within_cross_accuracy(
        results['embeddings'],
        results['labels'],
        results['slide_ids'],
        n_neighbors=N_NEIGHBORS
    )
    
    metrics[aug] = {
        'within_accuracy': within_acc,
        'cross_accuracy': cross_acc,
        'accuracy_gap': within_acc - cross_acc
    }
    
    print(f"  Within-slide accuracy: {within_acc:.3f}")
    print(f"  Cross-slide accuracy: {cross_acc:.3f}")
    print(f"  Accuracy gap: {within_acc - cross_acc:.3f}\n")

print("✓ Metrics computed!")

## Step 6: Display Results

Display your model's results in a summary table.

In [None]:
import pandas as pd

# Create results dataframe
results_data = []
for aug, m in metrics.items():
    results_data.append({
        'Augmentation': aug,
        'Within-Slide Acc': f"{m['within_accuracy']:.3f}",
        'Cross-Slide Acc': f"{m['cross_accuracy']:.3f}",
        'Accuracy Gap': f"{m['accuracy_gap']:.3f}"
    })

results_df = pd.DataFrame(results_data)
print(f"\n{'='*60}")
print(f"Results for {MODEL_ID}")
print(f"{'='*60}\n")
print(results_df.to_string(index=False))
print(f"\n{'='*60}")