### Tutorial: Training a WSI Classification Model with ABMIL on PANDA Dataset (MIL-Lab)

This tutorial trains an attention-based multiple instance learning model on the PANDA (Prostate cANcer graDe Assessment) dataset using pre-extracted UNI v2 features.

**Key Difference**: This notebook uses the **MIL-Lab framework** instead of Trident's `ABMILSlideEncoder`.

#### A- Dataset Information & Preprocessing

- **Dataset**: PANDA (Prostate cANcer graDe Assessment)
- **Features**: Pre-extracted UNI v2 features (1536 dimensions)
- **Patch size**: 256x256 pixels at 20x magnification
- **Task**: Multi-class classification (ISUP grades 0-5)
- **Data**:
  - WSI directory: `/media/nadim/Data/prostate-cancer-grade-assessment/train_images`
  - Features directory: `/media/nadim/Data/prostate-cancer-grade-assessment/panda/`
  - Labels CSV: `/media/nadim/Data/prostate-cancer-grade-assessment/train.csv`

**Preprocessing Steps**:
1. Read slide IDs and labels from CSV
2. Scan features directory for available .h5 files
3. Match CSV labels with available features
4. Perform stratified train/val/test split (70/20/10)
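
The 70/20/10 split in step 4 is done in two stages, so the second hold-out fraction has to be expressed relative to the remaining 90% rather than the full dataset. The sketch below works through that arithmetic in plain Python. `two_stage_sizes` is a hypothetical helper, not part of the notebook, and it assumes that the held-out count is rounded up for a fractional `test_size`, which is how scikit-learn's `train_test_split` is understood to behave.

```python
import math

def two_stage_sizes(n_total, test_frac=0.10, val_frac_of_rest=0.222):
    """Return (train, val, test) counts for a two-stage hold-out split."""
    # Stage 1: hold out the test set from all slides (count rounded up, by assumption).
    n_test = math.ceil(test_frac * n_total)
    n_rest = n_total - n_test
    # Stage 2: hold out validation from the remaining slides (count rounded up).
    n_val = math.ceil(val_frac_of_rest * n_rest)
    n_train = n_rest - n_val
    return n_train, n_val, n_test

# Share of the remaining 90% that corresponds to 20% of the whole dataset.
val_frac_of_rest = 0.20 / 0.90  # roughly 0.2222, hence test_size=0.222

sizes = two_stage_sizes(10615)  # 10615 is the matched slide count reported in this run
print(sizes, round(val_frac_of_rest, 4))
```

For 10615 matched slides this yields 7432 / 2121 / 1062, the same counts as the split summary printed by the preprocessing cell.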

In [1]:
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from glob import glob

# Configuration
csv_path = '/media/nadim/Data/prostate-cancer-grade-assessment/train.csv'
feats_path = '/media/nadim/Data/prostate-cancer-grade-assessment/panda/'
SEED = 42

# Set random seed for reproducibility
np.random.seed(SEED)

print(f"{'='*70}")
print("IMPROVED DATA PREPROCESSING")
print(f"{'='*70}\n")

# Step 1: Read labels and slide IDs from CSV
print("Step 1: Reading labels from CSV...")
df_labels = pd.read_csv(csv_path)

# Select only necessary columns
if 'isup_grade' in df_labels.columns:
    df_labels = df_labels[['image_id', 'isup_grade']].rename(columns={'image_id': 'slide_id', 'isup_grade': 'label'})
elif 'label' in df_labels.columns:
    df_labels = df_labels[['slide_id', 'label']]
else:
    # Fail early: the cells below require a 'label' column
    raise ValueError("Could not find a label column ('isup_grade' or 'label') in the CSV")
    
print(f"  Found {len(df_labels)} slides in CSV with labels")
print(f"  Label distribution in CSV:")
for grade in sorted(df_labels['label'].unique()):
    count = len(df_labels[df_labels['label'] == grade])
    print(f"    ISUP {grade}: {count}")

# Step 2: Find all available feature files
print(f"\nStep 2: Scanning features directory...")
feature_files = glob(os.path.join(feats_path, '*.h5'))
available_slide_ids = [os.path.basename(f).replace('.h5', '') for f in feature_files]
print(f"  Found {len(available_slide_ids)} feature files")

# Step 3: Match CSV with available features
print(f"\nStep 3: Matching CSV labels with available features...")
df_labels['has_features'] = df_labels['slide_id'].isin(available_slide_ids)
df_matched = df_labels[df_labels['has_features']].drop(columns=['has_features']).reset_index(drop=True)

missing_count = len(df_labels) - len(df_matched)
print(f"  Matched: {len(df_matched)} slides")
print(f"  Missing features: {missing_count} slides")

print(f"\n  Label distribution after matching:")
for grade in sorted(df_matched['label'].unique()):
    count = len(df_matched[df_matched['label'] == grade])
    print(f"    ISUP {grade}: {count}")

# Step 4: Perform stratified train/val/test split
print(f"\nStep 4: Performing stratified split (70% train, 20% val, 10% test)...")

# First split: separate test set (10%)
train_val_df, test_df = train_test_split(
    df_matched, 
    test_size=0.10, 
    stratify=df_matched['label'],
    random_state=SEED
)

# Second split: separate train and val from remaining 90% (77.78% train, 22.22% val of remaining)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.222,  # 0.222 * 0.9 ≈ 0.20 of total
    stratify=train_val_df['label'],
    random_state=SEED
)

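# Add split column (assign returns new frames, which avoids pandas' SettingWithCopyWarning)
train_df = train_df.assign(split='train')
val_df = val_df.assign(split='val')
test_df = test_df.assign(split='test')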

# Combine back into single dataframe
df = pd.concat([train_df, val_df, test_df], ignore_index=True)

print(f"\n{'='*70}")
print("SPLIT SUMMARY")
print(f"{'='*70}")
print(f"Total slides: {len(df)}\n")

print("Split distribution:")
print(f"  Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)")
print(f"  Val:   {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)")
print(f"  Test:  {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)")

print(f"\nLabel distribution per split:")
for split_name in ['train', 'val', 'test']:
    split_df = df[df['split'] == split_name]
    print(f"\n{split_name.upper()}:")
    for grade in sorted(df['label'].unique()):
        count = len(split_df[split_df['label'] == grade])
        pct = count / len(split_df) * 100
        print(f"  ISUP {grade}: {count:4d} ({pct:5.1f}%)")

print(f"\n{'='*70}\n")

# Display first few rows
df.head(10)

IMPROVED DATA PREPROCESSING

Step 1: Reading labels from CSV...
  Found 10616 slides in CSV with labels
  Label distribution in CSV:
    ISUP 0: 2892
    ISUP 1: 2666
    ISUP 2: 1343
    ISUP 3: 1242
    ISUP 4: 1249
    ISUP 5: 1224

Step 2: Scanning features directory...
  Found 10615 feature files

Step 3: Matching CSV labels with available features...
  Matched: 10615 slides
  Missing features: 1 slides

  Label distribution after matching:
    ISUP 0: 2891
    ISUP 1: 2666
    ISUP 2: 1343
    ISUP 3: 1242
    ISUP 4: 1249
    ISUP 5: 1224

Step 4: Performing stratified split (70% train, 20% val, 10% test)...

SPLIT SUMMARY
Total slides: 10615

Split distribution:
  Train: 7432 (70.0%)
  Val:   2121 (20.0%)
  Test:  1062 (10.0%)

Label distribution per split:

TRAIN:
  ISUP 0: 2024 ( 27.2%)
  ISUP 1: 1866 ( 25.1%)
  ISUP 2:  941 ( 12.7%)
  ISUP 3:  870 ( 11.7%)
  ISUP 4:  874 ( 11.8%)
  ISUP 5:  857 ( 11.5%)

VAL:
  ISUP 0:  578 ( 27.3%)
  ISUP 1:  533 ( 25.1%)
  ISUP 2:  268 ( 1

|   | slide_id | label | split |
|---|----------|-------|-------|
| 0 | b600f26e7bc2daf1917e4078496d37ac | 2 | train |
| 1 | 36531354b9f3353938e4227f7d8a3ece | 2 | train |
| 2 | 4e94a11759c86bca9836e27cfd2fb670 | 3 | train |
| 3 | e2b069d3db5dd82610abc684befac2d3 | 5 | train |
| 4 | 5289d3ebfc24bbe70206daf1f546b687 | 0 | train |
| 5 | caaf2645fd3ea43541dad8081e351689 | 5 | train |
| 6 | 42a4d1519445a251149be56ff8d9ed12 | 3 | train |
| 7 | 01c977c97e2f5543e65e559d98dec93c | 3 | train |
| 8 | 369a3711e3d38086b74319999fcf77f5 | 2 | train |
| 9 | 5d1abefd41a6663e33995134cbd44838 | 1 | train |


#### B- Training an ABMIL Model (using MIL-Lab)

In [59]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, balanced_accuracy_score, cohen_kappa_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Add MIL-Lab to path if needed
# sys.path.insert(0, '/home/nadim/Source/MIL-Lab/MIL-Lab')

# Import MIL-Lab model builder
from src.builder import create_model

# Set deterministic behavior
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Custom dataset for PANDA with UNI v2 features
class PANDAH5Dataset(Dataset):
    def __init__(self, feats_path, df, split, num_features=512):
        self.df = df[df["split"] == split].reset_index(drop=True)
        self.feats_path = feats_path
        self.num_features = num_features
        self.split = split
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        feat_path = os.path.join(self.feats_path, row['slide_id'] + '.h5')
        
        with h5py.File(feat_path, "r") as f:
            # UNI v2 features have shape (1, num_patches, 1536)
            features = torch.from_numpy(f["features"][:]).squeeze(0)  # Remove batch dimension -> (num_patches, 1536)

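        # Sample a fixed-size bag of patches (num_features) for training so slides can be batched
        # and memory stays bounded. The generator is re-seeded with SEED on every call, so a given
        # slide yields the same patch subset in every epoch.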
        if self.split == 'train':
            num_available = features.shape[0]
            if num_available >= self.num_features:
                indices = torch.randperm(num_available, generator=torch.Generator().manual_seed(SEED))[:self.num_features]
            else:
                indices = torch.randint(num_available, (self.num_features,), generator=torch.Generator().manual_seed(SEED))
            features = features[indices]

        label = torch.tensor(row["label"], dtype=torch.long)
        return features, label

# Initialize MIL-Lab ABMIL model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create ABMIL model using MIL-Lab
# Model naming: 'abmil.base.uni_v2.none' means:
#   - abmil: model architecture
#   - base: configuration (uses default hyperparameters from src/model_configs/abmil/base.yaml)
#   - uni_v2: encoder type (automatically sets in_dim=1536)
#   - none: no pretrained weights (random initialization)
model = create_model(
    'abmil.base.uni_v2.none',  # Model specification
    num_classes=6,              # 6 ISUP grades (0-5)
    dropout=0.2,                # Override default dropout
    gate=True                   # Use gated attention
).to(device)

print(f"Device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")

# Create dataloaders
feats_path = '/media/nadim/Data/prostate-cancer-grade-assessment/panda'

batch_size = 32

train_dataset = PANDAH5Dataset(feats_path, df, "train")
val_dataset = PANDAH5Dataset(feats_path, df, "val")
test_dataset = PANDAH5Dataset(feats_path, df, "test")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val samples: {len(val_loader)}")
print(f"Test samples: {len(test_loader)}")

Device: cuda
Model parameters: 1184391

Train batches: 87
Val samples: 798
Test samples: 389




In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Training loop with validation
num_epochs = 5
best_val_loss = float('inf')
train_losses = []
val_losses = []
val_accuracies = []

print(f"\n{'='*70}")
print("Starting training...")
print(f"{'='*70}\n")

for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss = 0.

    for batch_idx, (features, labels) in enumerate(train_loader):
        # MIL-Lab models take features directly as a [B, M, D] tensor
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()

        # MIL-Lab forward returns (results_dict, log_dict)
        results_dict, log_dict = model(features, loss_fn=criterion, label=labels)
        loss = results_dict['loss']
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Print progress every 50 batches (inside the batch loop, so it fires during the epoch)
        if (batch_idx + 1) % 50 == 0:
            print(f"  Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx+1}/{len(train_loader)}] Loss: {loss.item():.4f}")
    
    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    
    # Validation
    model.eval()
    val_loss = 0.
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for features, labels in val_loader:
            features, labels = features.to(device), labels.to(device)

            # Forward pass with MIL-Lab
            results_dict, log_dict = model(features, loss_fn=criterion, label=labels)
            logits = results_dict['logits']
            loss = results_dict['loss']
            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_acc = accuracy_score(all_labels, all_preds)
    val_accuracies.append(val_acc)
    
    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    
    # Learning rate scheduling
    old_lr = current_lr
    scheduler.step(avg_val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f}")
    print(f"  Val Acc:    {val_acc:.4f}")
    print(f"  LR:         {new_lr:.6f}")
    
    if new_lr < old_lr:
        print(f"  >>> Learning rate reduced: {old_lr:.6f} -> {new_lr:.6f}")
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_model_panda_millab.pth')
        print(f"  >>> Saved best model (Val Loss: {best_val_loss:.4f})")
    
    print(f"{'-'*70}")

print(f"\n{'='*70}")
print("Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"{'='*70}\n")


Starting training...



In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1.plot(train_losses, label='Train Loss', marker='o')
ax1.plot(val_losses, label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss (MIL-Lab)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curve
ax2.plot(val_accuracies, label='Val Accuracy', marker='o', color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Validation Accuracy (MIL-Lab)')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves_panda_millab.png', dpi=300, bbox_inches='tight')
plt.show()

#### C- Evaluating the ABMIL Model on Test Set
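
The evaluation cell below reports quadratic weighted kappa through scikit-learn's `cohen_kappa_score(..., weights='quadratic')`. Because ISUP grades are ordinal, this metric penalizes a prediction more the further it lands from the true grade. The following is a minimal plain-Python sketch of the computation for illustration only; `quadratic_weighted_kappa` is a hypothetical helper and is not meant to replace the scikit-learn call.

```python
def quadratic_weighted_kappa(y_true, y_pred, num_classes):
    """Cohen's kappa with quadratic disagreement weights (illustrative sketch)."""
    n = len(y_true)
    # Observed confusion counts: rows are true grades, columns are predicted grades.
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    row_totals = [sum(row) for row in observed]
    col_totals = [sum(observed[i][j] for i in range(num_classes))
                  for j in range(num_classes)]

    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            weight = (i - j) ** 2                         # penalty grows with grade distance
            expected = row_totals[i] * col_totals[j] / n  # count expected by chance
            weighted_observed += weight * observed[i][j]
            weighted_expected += weight * expected
    # Undefined (division by zero) if every label and prediction is the same class.
    return 1.0 - weighted_observed / weighted_expected

truth = [0, 1, 2, 3, 4, 5]
near_miss = quadratic_weighted_kappa(truth, [1, 1, 2, 3, 4, 5], 6)  # one error, off by 1 grade
far_miss = quadratic_weighted_kappa(truth, [5, 1, 2, 3, 4, 5], 6)   # one error, off by 5 grades
print(near_miss > far_miss)
```

Both toy predictions contain a single mistake and therefore have the same plain accuracy, but the off-by-five error lowers kappa far more than the off-by-one error.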

In [None]:
# Load best model
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model.load_state_dict(torch.load('best_model_panda_millab.pth', map_location=device))
model.eval()

# Test evaluation
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for features, labels in test_loader:
        features, labels = features.to(device), labels.to(device)
        
        # MIL-Lab forward pass
        results_dict, log_dict = model(features)
        logits = results_dict['logits']
        
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Calculate metrics
test_acc = accuracy_score(all_labels, all_preds)
test_balanced_acc = balanced_accuracy_score(all_labels, all_preds)
test_kappa = cohen_kappa_score(all_labels, all_preds, weights='quadratic')

print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Balanced Accuracy: {test_balanced_acc:.4f}")
print(f"Test Quadratic Weighted Kappa: {test_kappa:.4f}")

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=[f'ISUP {i}' for i in range(6)],
            yticklabels=[f'ISUP {i}' for i in range(6)])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix - PANDA Test Set (MIL-Lab)')
plt.savefig('confusion_matrix_panda_millab.png', dpi=300, bbox_inches='tight')
plt.show()

# Per-class accuracy
print("\nPer-class accuracy:")
for i in range(6):
    class_acc = cm[i, i] / cm[i].sum() if cm[i].sum() > 0 else 0
    print(f"  ISUP {i}: {class_acc:.4f} ({cm[i, i]}/{int(cm[i].sum())})")

#### D- Generate Attention Heatmap

**Note**: This section demonstrates how to extract attention scores from MIL-Lab ABMIL models.

**Key Difference from Trident**: 
- MIL-Lab returns attention in `log_dict['attention']` when `return_attention=True`
- Attention has shape `[B, K, M]` where K is number of attention heads (usually 1)
- These are raw attention scores before softmax
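
Since the returned scores are pre-softmax, they can be negative and do not sum to one. The sketch below uses plain Python, with nested lists standing in for a `[B, K, M]` tensor, to show how such scores could be reduced to one value per patch and normalized into weights. The `softmax` helper and the toy values are illustrative assumptions, not MIL-Lab code.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy raw attention with shape [B=1, K=1, M=4]: one slide, one head, four patches.
raw_attention = [[[2.0, -1.0, 0.5, 2.0]]]

# Drop the batch and head dimensions to get one raw score per patch: shape [M].
patch_scores = raw_attention[0][0]

# Normalize so the weights are positive and sum to 1 across patches.
patch_weights = softmax(patch_scores)
print([round(w, 3) for w in patch_weights])
```

With a real tensor the same normalization would be `torch.softmax(attention, dim=-1)`. For heatmaps the raw scores are often sufficient, because only the relative ordering of patches matters once the scores are normalized for display.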

In [None]:
# Example: Get attention scores for a single slide
from PIL import Image

# Select a slide from test set
test_df = df[df['split'] == 'test']
slide_id = test_df.iloc[0]['slide_id']
true_label = test_df.iloc[0]['label']

print(f"Processing slide: {slide_id}")
print(f"True label: ISUP {true_label}")

# Load features
feat_path = os.path.join(feats_path, slide_id + '.h5')
with h5py.File(feat_path, 'r') as f:
    patch_features = torch.from_numpy(f['features'][:]).squeeze(0)  # (num_patches, 1536)
    coords = f['coords_patching'][:]
    
    if hasattr(f['coords_patching'], 'attrs') and 'patch_size' in f['coords_patching'].attrs:
        patch_size_level0 = int(f['coords_patching'].attrs['patch_size'])
    else:
        patch_size_level0 = 256

# Ensure coords and features have the same length
min_len = min(len(coords), len(patch_features))
coords = coords[:min_len]
patch_features = patch_features[:min_len]

print(f"Loaded {len(coords)} patches")

# Get attention scores using MIL-Lab
model.eval()
with torch.no_grad():
    features = patch_features.float().to(device).unsqueeze(0)  # Add batch dimension: [1, M, D]
    
    # Request attention scores in log_dict
    results_dict, log_dict = model(features, return_attention=True)
    
    logits = results_dict['logits']
    attention = log_dict['attention']  # Shape: [B, K, M]

predicted_class = torch.argmax(logits, dim=1).item()
attention_scores = attention.cpu().numpy().squeeze()  # Remove batch and head dimensions: [M]

# Ensure attention scores match coords length
attention_scores = attention_scores[:len(coords)]

print(f"Predicted: ISUP {predicted_class}")
print(f"Attention range: [{attention_scores.min():.4f}, {attention_scores.max():.4f}]")
print(f"Attention shape: {attention_scores.shape}")

# Optional: If you have trident installed, you can visualize the heatmap
try:
    from trident import load_wsi, visualize_heatmap
    
    slide_path = f'/media/nadim/Data/prostate-cancer-grade-assessment/train_images/{slide_id}.tiff'
    
    if os.path.exists(slide_path):
        job_dir = './heatmap_output_panda_millab'
        os.makedirs(job_dir, exist_ok=True)
        
        slide = load_wsi(slide_path=slide_path, lazy_init=False)
        
        # Generate heatmap with jet colormap
        heatmap_save_path = visualize_heatmap(
            wsi=slide,
            scores=attention_scores,
            coords=coords,
            vis_level=1,
            patch_size_level0=patch_size_level0,
            normalize=True,
            num_top_patches_to_save=0,
            output_dir=job_dir,
            cmap='jet',
            filename=f'{slide_id}_ISUP{true_label}_heatmap_millab.png'
        )
        
        print(f"\nHeatmap saved to: {heatmap_save_path}")
        
        # Display heatmap
        heatmap_img = Image.open(heatmap_save_path)
        plt.figure(figsize=(12, 10))
        plt.imshow(heatmap_img)
        plt.axis('off')
        plt.title(f'Attention Heatmap (MIL-Lab)\nSlide: {slide_id} | True: ISUP {true_label} | Predicted: ISUP {predicted_class}',
                 fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()
    else:
        print(f"\nSlide not found at {slide_path}")
        print("Attention scores extracted successfully. Check the WSI path to visualize heatmaps.")
        
except ImportError:
    print("\nTrident not installed. Attention scores extracted successfully.")
    print("To visualize heatmaps, install trident following the instructions in its repository.")

#### E- Heatmap Visualization (Trident-Compatible)

**Important**: Since the patches were extracted using **Trident**, we use the Trident-compatible visualizer.

**Key Features:**
- **Works with Trident coordinates** - No coordinate mismatch issues
- **Rank-based Normalization** - Same as Trident's approach
- **Alpha Blending** - Smooth overlay on original H&E image
- **Top-K Patch Sampling** - Automatically saves patches with highest attention
- **Multiple Colormaps** - Support for jet, coolwarm, viridis, etc.

**Why Not CLAM-style?** CLAM does tissue segmentation at visualization time, which doesn't align with Trident's patch coordinates. Trident does tissue masking during extraction, so we just overlay on the full image.
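
Rank-based normalization replaces each attention score by its relative rank, so a few extreme values cannot wash out the rest of the heatmap. The sketch below shows the general idea in plain Python; `rank_normalize` is a hypothetical helper, and Trident's actual implementation may differ in details such as tie handling or scaling.

```python
def rank_normalize(scores):
    """Map scores to (0, 1] by rank: the highest score gets 1.0."""
    n = len(scores)
    # Patch indices ordered from lowest to highest score (ties keep their original order).
    order = sorted(range(n), key=lambda i: scores[i])
    normalized = [0.0] * n
    for rank, idx in enumerate(order, start=1):
        normalized[idx] = rank / n
    return normalized

raw = [-3.2, 0.1, 7.5, 0.4]
print(rank_normalize(raw))  # [0.25, 0.5, 1.0, 0.75]
```

The outlier `7.5` maps to 1.0, while the remaining patches stay evenly spread over the color range instead of being compressed near zero.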

In [None]:
# Import Trident-compatible visualizer
from src.visualization import TridentVisualizer
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

# Visualize examples from different ISUP grades
test_df = df[df['split'] == 'test']

# Get one example from each ISUP grade (0-5)
example_slides = []
for grade in range(6):
    grade_slides = test_df[test_df['label'] == grade]
    if len(grade_slides) > 0:
        example_slides.append(grade_slides.iloc[0])
    else:
        print(f"Warning: No test slides found for ISUP grade {grade}")

print(f"{'='*70}")
print(f"Visualizing {len(example_slides)} slides from different ISUP grades")
print(f"{'='*70}\n")

# Create output directory
output_dir = './heatmap_output_trident_style'
os.makedirs(output_dir, exist_ok=True)

# Visualize each example
num_examples = len(example_slides)
fig, axes = plt.subplots(2, 3, figsize=(24, 16))
axes = axes.flatten()

for idx, slide_row in enumerate(example_slides):
    slide_id = slide_row['slide_id']
    true_label = slide_row['label']
    
    print(f"Processing slide {idx+1}/{num_examples}: {slide_id} (ISUP {true_label})")
    
    # Load features and coordinates
    feat_path = os.path.join(feats_path, slide_id + '.h5')
    slide_path = f'/media/nadim/Data/prostate-cancer-grade-assessment/train_images/{slide_id}.tiff'
    
    with h5py.File(feat_path, 'r') as f:
        patch_features = torch.from_numpy(f['features'][:]).squeeze(0)
        coords = f['coords_patching'][:]
        
        if hasattr(f['coords_patching'], 'attrs') and 'patch_size' in f['coords_patching'].attrs:
            patch_size_level0 = int(f['coords_patching'].attrs['patch_size'])
        else:
            patch_size_level0 = 256
    
    # Ensure matching lengths
    min_len = min(len(coords), len(patch_features))
    coords = coords[:min_len]
    patch_features = patch_features[:min_len]
    
    # Get attention scores
    model.eval()
    with torch.no_grad():
        features_input = patch_features.float().to(device).unsqueeze(0)
        results_dict, log_dict = model(features_input, return_attention=True)
        attention_scores = log_dict['attention'].cpu().numpy().squeeze()[:len(coords)]
        predicted_class = torch.argmax(results_dict['logits'], dim=1).item()
    
    # Initialize visualizer
    viz = TridentVisualizer(model, wsi_path=slide_path)
    
    # Create heatmap
    heatmap = viz.create_heatmap(
        features=patch_features,
        coords=coords,
        attention_scores=attention_scores,
        patch_size_level0=patch_size_level0,
        vis_level=-1,
        cmap='jet',
        alpha=0.4,
        normalize=True,
        output_path=os.path.join(output_dir, f'{slide_id}_ISUP{true_label}_pred{predicted_class}.png')
    )
    
    # Display heatmap with colorbar
    im = axes[idx].imshow(heatmap)
    axes[idx].axis('off')
    
    # Add colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(axes[idx])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap='jet'), cax=cax)
    cbar.set_label('Attention Score', rotation=270, labelpad=15)
    
    # Title with true and predicted labels
    title_color = 'green' if predicted_class == true_label else 'red'
    axes[idx].set_title(
        f'ISUP {true_label} → Pred: {predicted_class}\nSlide: {slide_id}',
        fontsize=12, fontweight='bold', color=title_color, pad=10
    )
    
    print(f"  Attention range: [{attention_scores.min():.4f}, {attention_scores.max():.4f}]")
    print(f"  Predicted: ISUP {predicted_class}\n")

# Hide unused subplots if less than 6 grades available
for idx in range(len(example_slides), 6):
    axes[idx].axis('off')

plt.suptitle(
    'ABMIL Attention Heatmaps - One Example per ISUP Grade',
    fontsize=18, fontweight='bold', y=0.995
)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'all_grades_comparison.png'), dpi=200, bbox_inches='tight')
plt.show()

print(f"\n{'='*70}")
print("Visualization Complete!")
print(f"{'='*70}")
print(f"Output directory: {output_dir}")
print(f"- Individual heatmaps saved for each grade")
print(f"- Comparison grid: all_grades_comparison.png")

In [None]:
# Compare different colormaps for the same slide
print("Comparing different colormaps...\n")

# Use the first example slide
slide_row = example_slides[0]
slide_id = slide_row['slide_id']
true_label = slide_row['label']

# Load features
feat_path = os.path.join(feats_path, slide_id + '.h5')
slide_path = f'/media/nadim/Data/prostate-cancer-grade-assessment/train_images/{slide_id}.tiff'

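with h5py.File(feat_path, 'r') as f:
    patch_features = torch.from_numpy(f['features'][:]).squeeze(0)
    coords = f['coords_patching'][:]
    # Read the patch size from the file when available (as in the cells above), else fall back to 256
    patch_size_level0 = int(f['coords_patching'].attrs.get('patch_size', 256))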

min_len = min(len(coords), len(patch_features))
coords = coords[:min_len]
patch_features = patch_features[:min_len]

# Get attention scores
model.eval()
with torch.no_grad():
    features_input = patch_features.float().to(device).unsqueeze(0)
    results_dict, log_dict = model(features_input, return_attention=True)
    attention_scores = log_dict['attention'].cpu().numpy().squeeze()[:len(coords)]
    predicted_class = torch.argmax(results_dict['logits'], dim=1).item()

# Initialize visualizer
viz = TridentVisualizer(model, wsi_path=slide_path)

# Different colormap options
cmaps = [
    ('jet', 'Jet (Rainbow)'),
    ('coolwarm', 'Coolwarm (Diverging)'),
    ('hot', 'Hot (Sequential)'),
    ('viridis', 'Viridis (Perceptual)')
]

fig, axes = plt.subplots(2, 2, figsize=(20, 16))
axes = axes.flatten()

for idx, (cmap_name, cmap_label) in enumerate(cmaps):
    print(f"Creating heatmap with {cmap_label}...")
    
    heatmap = viz.create_heatmap(
        features=patch_features,
        coords=coords,
        attention_scores=attention_scores,
        patch_size_level0=patch_size_level0,
        vis_level=-1,
        cmap=cmap_name,
        alpha=0.4,
        normalize=True
    )
    
    # Display heatmap
    im = axes[idx].imshow(heatmap)
    axes[idx].axis('off')
    
    # Add colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(axes[idx])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap=cmap_name), cax=cax)
    cbar.set_label('Attention', rotation=270, labelpad=15)
    
    axes[idx].set_title(cmap_label, fontsize=14, fontweight='bold', pad=10)

plt.suptitle(
    f'Colormap Comparison - {slide_id} (ISUP {true_label} → Pred: {predicted_class})',
    fontsize=18, fontweight='bold', y=0.995
)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, f'{slide_id}_colormap_comparison.png'), dpi=200, bbox_inches='tight')
plt.show()

print("\nColormap comparison complete!")

---

## Key Differences: MIL-Lab vs Trident

### Model Creation
**Trident**:
```python
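from trident.slide_encoder_models import ABMILSlideEncoder
# MultiClassABMILModel is not imported from trident; it is assumed here to be a
# user-defined nn.Module that wraps ABMILSlideEncoder with a classification head
model = MultiClassABMILModel(input_feature_dim=1536, n_classes=6)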
```

**MIL-Lab**:
```python
from src.builder import create_model
model = create_model('abmil.base.uni_v2.none', num_classes=6)
```

### Forward Pass
**Trident**:
```python
features_dict = {'features': features.to(device)}
outputs = model(features_dict)
```

**MIL-Lab**:
```python
features = features.to(device)  # Direct tensor, no dict wrapper
results_dict, log_dict = model(features, loss_fn=criterion, label=labels)
logits = results_dict['logits']
loss = results_dict['loss']
```

### Attention Retrieval
**Trident**:
```python
logits, attention = model(features_dict, return_raw_attention=True)
```

**MIL-Lab**:
```python
results_dict, log_dict = model(features, return_attention=True)
attention = log_dict['attention']  # Shape: [B, K, M]
```

### Visualization (IMPORTANT!)

**Trident**:
```python
from trident import visualize_heatmap
heatmap_path = visualize_heatmap(wsi=wsi, scores=scores, coords=coords, ...)
```

**MIL-Lab (Trident-Compatible)**:
```python
from src.visualization import TridentVisualizer
viz = TridentVisualizer(model, wsi_path=slide_path)
heatmap = viz.create_heatmap(features=features, coords=coords, ...)
```

**Why Trident-Compatible?** Since patches were extracted using Trident, the coordinates are Trident's. MIL-Lab provides a Trident-compatible visualizer that works with these coordinates without tissue segmentation mismatches.

### Advantages of MIL-Lab
1. **Standardized interface** across all MIL models (ABMIL, TransMIL, CLAM, etc.)
2. **Easy model switching**: Change `'abmil'` to `'transmil'`, `'clam'`, `'dsmil'`, etc.
3. **Pretrained weights**: Load `'abmil.base.uni_v2.pc108-24k'` for transfer learning
4. **Encoder flexibility**: Automatically handles different feature dimensions (UNI, UNIv2, CONCH, etc.)
5. **HuggingFace integration**: Compatible with transformers library
6. **Dual visualization**: Both Trident-compatible and CLAM-style visualizers available
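
The model-switching point relies on the four-part specification string described earlier (`architecture.config.encoder.weights`). The sketch below shows how such a string could be taken apart and modified in plain Python. `ModelSpec`, `parse_model_spec`, and `swap_architecture` are hypothetical helpers for illustration, not MIL-Lab functions, and they do not check whether a given combination actually exists in MIL-Lab.

```python
from typing import NamedTuple

class ModelSpec(NamedTuple):
    architecture: str  # e.g. 'abmil'
    config: str        # e.g. 'base'
    encoder: str       # e.g. 'uni_v2'
    weights: str       # e.g. 'none' for random initialization

def parse_model_spec(spec):
    """Split 'architecture.config.encoder.weights' into its four fields."""
    parts = spec.split('.')
    if len(parts) != 4:
        raise ValueError(f"Expected 4 dot-separated fields, got {len(parts)}: {spec!r}")
    return ModelSpec(*parts)

def swap_architecture(spec, new_architecture):
    """Return the same specification with a different architecture field."""
    return '.'.join(parse_model_spec(spec)._replace(architecture=new_architecture))

print(parse_model_spec('abmil.base.uni_v2.none'))
print(swap_architecture('abmil.base.uni_v2.none', 'transmil'))
```

The resulting string would then be passed to `create_model` in place of `'abmil.base.uni_v2.none'`, keeping the rest of the training code unchanged.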