In [1]:
!pip install 'zarr<3'
!pip install timm


Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# ALWAYS RUN THIS FIRST!
import os
import sys
from pathlib import Path

NOTEBOOK_DIR = Path("/rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest")
os.chdir(NOTEBOOK_DIR)
sys.path.insert(0, str(NOTEBOOK_DIR))

print(f"‚úÖ Working directory: {os.getcwd()}")

‚úÖ Working directory: /rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest


## Data Loading

In [3]:
# Cell 3: Import and create dataloaders
from dataset import Config, create_dataloaders

# Just use the correct relative path from your working directory
config = Config("configs/config_fold13.yaml")  # Note: "configs" not "config"
config.print_config()

train_loader, val_loader, test_loader = create_dataloaders(config)
print("\n‚úÖ Ready to use!")

‚úÖ CRC Dataset Package v1.0.0 loaded
CRC DATASET CONFIGURATION
Config File: configs/config_fold13.yaml
Zarr Base: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/ORION-CRC-Syn/zarr_data
Cache: ./cache/multimodal_syn_dataset_cache_fold3.pkl
Strategy: memory

üìä Data Splits:
  Train: 33 samples
  Val: 9 samples
  Test: 8 samples

üîÑ DataLoader:
  Batch Size: 4
  Num Workers: 0
  Pin Memory: True

üé® Augmentation:
  Training: True
  Probability: 0.0

üéØ HV Maps:
  Generate: True
  Method: pannuke
  HE Nuclei: True
  HE Cells: True
  MIF Nuclei: True
  MIF Cells: True

üîç Filtering:
  Min Instances: 0
  Filter Empty: True

CREATING DATALOADERS
Strategy: memory
Use Cache: True
Batch Size: 4
Num Workers: 0

Train split: 27 CRC + 6 Xenium samples
Val split: 7 CRC + 2 Xenium samples
Test split: 7 CRC + 1 Xenium samples

üì¶ Loading from cache: ./cache/combined_cache_train_0be9581c.pkl
üì¶ Loaded 3294 patches from cache
üì¶ Loading from cache: ./cache/combined_cache_val_0

## Syn model Segmentions 

In [8]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    dice_coefficient, 
    iou_score, 
    precision_score, 
    recall_score,
    f1_score
)
from vitaminp import VitaminPSyn, SimplePreprocessing
from tqdm import tqdm

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load DUAL model
print("\nüì¶ Loading VitaminPSyn model...")
model = VitaminPSyn(model_size='base', dropout_rate=0.3, freeze_backbone=False)
checkpoint_path = "checkpoints/vitamin_p_syn_base_fold12_best.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
model = model.to(device)
print(f"‚úÖ VitaminPDual model loaded")

# Preprocessor
preprocessor = SimplePreprocessing()

# Initialize metric storage PER SAMPLE for BOTH H&E and MIF
sample_metrics = defaultdict(lambda: {
    'he_nuclei_dice': [], 'he_nuclei_iou': [], 
    'he_cell_dice': [], 'he_cell_iou': [],
    'mif_nuclei_dice': [], 'mif_nuclei_iou': [],
    'mif_cell_dice': [], 'mif_cell_iou': [],
    'patch_count': 0
})

print(f"\nüîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - BINARY METRICS ONLY)...")

skipped_batches = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):
        
        try:
            # ========== PREPARE INPUTS ==========
            he_img = batch['he_image'].to(device)
            mif_img = batch['mif_image'].to(device)
            
            # Normalize
            he_img = preprocessor.percentile_normalize(he_img)
            mif_img = preprocessor.percentile_normalize(mif_img)
            
            # Ground truth masks
            he_nuclei_mask_gt = batch['he_nuclei_mask'].float().unsqueeze(1).to(device)
            he_cell_mask_gt = batch['he_cell_mask'].float().unsqueeze(1).to(device)
            mif_nuclei_mask_gt = batch['mif_nuclei_mask'].float().unsqueeze(1).to(device)
            mif_cell_mask_gt = batch['mif_cell_mask'].float().unsqueeze(1).to(device)
            
            # Get sample names
            sample_names = batch['sample_name']
            
            # ========== INFERENCE WITH DUAL MODEL (processes both H&E and MIF together) ==========
            outputs = model(he_img, mif_img)
            
            # Get predictions for H&E
            pred_he_nuclei = (outputs['he_nuclei_seg'] > 0.5).float()
            pred_he_cell = (outputs['he_cell_seg'] > 0.5).float()
            
            # Get predictions for MIF
            pred_mif_nuclei = (outputs['mif_nuclei_seg'] > 0.5).float()
            pred_mif_cell = (outputs['mif_cell_seg'] > 0.5).float()
            
            # Process each sample in the batch
            batch_size = pred_he_nuclei.shape[0]
            
            for i in range(batch_size):
                sample_name = sample_names[i]
                
                # ========== H&E METRICS ==========
                sample_metrics[sample_name]['he_nuclei_dice'].append(
                    dice_coefficient(pred_he_nuclei[i:i+1], he_nuclei_mask_gt[i:i+1]))
                sample_metrics[sample_name]['he_nuclei_iou'].append(
                    iou_score(pred_he_nuclei[i:i+1], he_nuclei_mask_gt[i:i+1]))
                sample_metrics[sample_name]['he_cell_dice'].append(
                    dice_coefficient(pred_he_cell[i:i+1], he_cell_mask_gt[i:i+1]))
                sample_metrics[sample_name]['he_cell_iou'].append(
                    iou_score(pred_he_cell[i:i+1], he_cell_mask_gt[i:i+1]))
                
                # ========== MIF METRICS ==========
                sample_metrics[sample_name]['mif_nuclei_dice'].append(
                    dice_coefficient(pred_mif_nuclei[i:i+1], mif_nuclei_mask_gt[i:i+1]))
                sample_metrics[sample_name]['mif_nuclei_iou'].append(
                    iou_score(pred_mif_nuclei[i:i+1], mif_nuclei_mask_gt[i:i+1]))
                sample_metrics[sample_name]['mif_cell_dice'].append(
                    dice_coefficient(pred_mif_cell[i:i+1], mif_cell_mask_gt[i:i+1]))
                sample_metrics[sample_name]['mif_cell_iou'].append(
                    iou_score(pred_mif_cell[i:i+1], mif_cell_mask_gt[i:i+1]))
                
                sample_metrics[sample_name]['patch_count'] += 1
        
        except RuntimeError as e:
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
print("\n" + "="*100)
print("üìä PER-SAMPLE RESULTS - H&E (VitaminPDual)")
print("="*100)

# Separate CRC and Xenium samples
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Calculate averages for this sample
    avg_metrics = {
        'he_nuclei_dice': np.mean(metrics['he_nuclei_dice']),
        'he_nuclei_iou': np.mean(metrics['he_nuclei_iou']),
        'he_cell_dice': np.mean(metrics['he_cell_dice']),
        'he_cell_iou': np.mean(metrics['he_cell_iou']),
        'mif_nuclei_dice': np.mean(metrics['mif_nuclei_dice']),
        'mif_nuclei_iou': np.mean(metrics['mif_nuclei_iou']),
        'mif_cell_dice': np.mean(metrics['mif_cell_dice']),
        'mif_cell_iou': np.mean(metrics['mif_cell_iou']),
        'patch_count': metrics['patch_count']
    }
    
    if sample_name.startswith('CRC'):
        crc_samples[sample_name] = avg_metrics
    else:
        xenium_samples[sample_name] = avg_metrics

# Print CRC samples - H&E
if crc_samples:
    print("\nüî¨ CRC SAMPLES - H&E:")
    print("-" * 100)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei Dice':>13} {'Nuclei IoU':>12} {'Cell Dice':>12} {'Cell IoU':>11}")
    print("-" * 100)
    
    for sample_name in sorted(crc_samples.keys()):
        m = crc_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['he_nuclei_dice']:>13.4f} {m['he_nuclei_iou']:>12.4f} "
              f"{m['he_cell_dice']:>12.4f} {m['he_cell_iou']:>11.4f}")

# Print Xenium samples - H&E
if xenium_samples:
    print("\nüß¨ XENIUM SAMPLES - H&E:")
    print("-" * 100)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei Dice':>13} {'Nuclei IoU':>12} {'Cell Dice':>12} {'Cell IoU':>11}")
    print("-" * 100)
    
    for sample_name in sorted(xenium_samples.keys()):
        m = xenium_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['he_nuclei_dice']:>13.4f} {m['he_nuclei_iou']:>12.4f} "
              f"{m['he_cell_dice']:>12.4f} {m['he_cell_iou']:>11.4f}")

# ========== MIF RESULTS ==========
print("\n" + "="*100)
print("üìä PER-SAMPLE RESULTS - MIF (VitaminPDual)")
print("="*100)

# Print CRC samples - MIF
if crc_samples:
    print("\nüî¨ CRC SAMPLES - MIF:")
    print("-" * 100)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei Dice':>13} {'Nuclei IoU':>12} {'Cell Dice':>12} {'Cell IoU':>11}")
    print("-" * 100)
    
    for sample_name in sorted(crc_samples.keys()):
        m = crc_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['mif_nuclei_dice']:>13.4f} {m['mif_nuclei_iou']:>12.4f} "
              f"{m['mif_cell_dice']:>12.4f} {m['mif_cell_iou']:>11.4f}")

# Print Xenium samples - MIF
if xenium_samples:
    print("\nüß¨ XENIUM SAMPLES - MIF:")
    print("-" * 100)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei Dice':>13} {'Nuclei IoU':>12} {'Cell Dice':>12} {'Cell IoU':>11}")
    print("-" * 100)
    
    for sample_name in sorted(xenium_samples.keys()):
        m = xenium_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['mif_nuclei_dice']:>13.4f} {m['mif_nuclei_iou']:>12.4f} "
              f"{m['mif_cell_dice']:>12.4f} {m['mif_cell_iou']:>11.4f}")

# ========== OVERALL STATISTICS ==========
print("\n" + "="*100)
print("üìä OVERALL TEST SET RESULTS (VitaminPDual)")
print("="*100)

# Collect all metrics across all samples
all_he_nuclei_dice = []
all_he_nuclei_iou = []
all_he_cell_dice = []
all_he_cell_iou = []
all_mif_nuclei_dice = []
all_mif_nuclei_iou = []
all_mif_cell_dice = []
all_mif_cell_iou = []

for sample_name, metrics in sample_metrics.items():
    all_he_nuclei_dice.extend(metrics['he_nuclei_dice'])
    all_he_nuclei_iou.extend(metrics['he_nuclei_iou'])
    all_he_cell_dice.extend(metrics['he_cell_dice'])
    all_he_cell_iou.extend(metrics['he_cell_iou'])
    all_mif_nuclei_dice.extend(metrics['mif_nuclei_dice'])
    all_mif_nuclei_iou.extend(metrics['mif_nuclei_iou'])
    all_mif_cell_dice.extend(metrics['mif_cell_dice'])
    all_mif_cell_iou.extend(metrics['mif_cell_iou'])

total_patches = sum(m['patch_count'] for m in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

print("\n" + "="*50)
print("H&E RESULTS")
print("="*50)
print("\nüî¨ H&E NUCLEI METRICS (all patches):")
print(f"  Dice:      {np.mean(all_he_nuclei_dice):.4f} ¬± {np.std(all_he_nuclei_dice):.4f}")
print(f"  IoU:       {np.mean(all_he_nuclei_iou):.4f} ¬± {np.std(all_he_nuclei_iou):.4f}")

print("\nüß¨ H&E CELL METRICS (all patches):")
print(f"  Dice:      {np.mean(all_he_cell_dice):.4f} ¬± {np.std(all_he_cell_dice):.4f}")
print(f"  IoU:       {np.mean(all_he_cell_iou):.4f} ¬± {np.std(all_he_cell_iou):.4f}")

print("\n" + "="*50)
print("MIF RESULTS")
print("="*50)
print("\nüî¨ MIF NUCLEI METRICS (all patches):")
print(f"  Dice:      {np.mean(all_mif_nuclei_dice):.4f} ¬± {np.std(all_mif_nuclei_dice):.4f}")
print(f"  IoU:       {np.mean(all_mif_nuclei_iou):.4f} ¬± {np.std(all_mif_nuclei_iou):.4f}")

print("\nüß¨ MIF CELL METRICS (all patches):")
print(f"  Dice:      {np.mean(all_mif_cell_dice):.4f} ¬± {np.std(all_mif_cell_dice):.4f}")
print(f"  IoU:       {np.mean(all_mif_cell_iou):.4f} ¬± {np.std(all_mif_cell_iou):.4f}")

# Statistics by cancer type
if crc_samples:
    print("\n" + "="*50)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print("="*50)
    
    crc_he_nuclei_dice = [m['he_nuclei_dice'] for m in crc_samples.values()]
    crc_he_cell_dice = [m['he_cell_dice'] for m in crc_samples.values()]
    crc_mif_nuclei_dice = [m['mif_nuclei_dice'] for m in crc_samples.values()]
    crc_mif_cell_dice = [m['mif_cell_dice'] for m in crc_samples.values()]
    
    print(f"\nH&E:")
    print(f"  Nuclei Dice: {np.mean(crc_he_nuclei_dice):.4f} ¬± {np.std(crc_he_nuclei_dice):.4f}")
    print(f"  Cell Dice:   {np.mean(crc_he_cell_dice):.4f} ¬± {np.std(crc_he_cell_dice):.4f}")
    print(f"\nMIF:")
    print(f"  Nuclei Dice: {np.mean(crc_mif_nuclei_dice):.4f} ¬± {np.std(crc_mif_nuclei_dice):.4f}")
    print(f"  Cell Dice:   {np.mean(crc_mif_cell_dice):.4f} ¬± {np.std(crc_mif_cell_dice):.4f}")

if xenium_samples:
    print("\n" + "="*50)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print("="*50)
    
    xenium_he_nuclei_dice = [m['he_nuclei_dice'] for m in xenium_samples.values()]
    xenium_he_cell_dice = [m['he_cell_dice'] for m in xenium_samples.values()]
    xenium_mif_nuclei_dice = [m['mif_nuclei_dice'] for m in xenium_samples.values()]
    xenium_mif_cell_dice = [m['mif_cell_dice'] for m in xenium_samples.values()]
    
    print(f"\nH&E:")
    print(f"  Nuclei Dice: {np.mean(xenium_he_nuclei_dice):.4f} ¬± {np.std(xenium_he_nuclei_dice):.4f}")
    print(f"  Cell Dice:   {np.mean(xenium_he_cell_dice):.4f} ¬± {np.std(xenium_he_cell_dice):.4f}")
    print(f"\nMIF:")
    print(f"  Nuclei Dice: {np.mean(xenium_mif_nuclei_dice):.4f} ¬± {np.std(xenium_mif_nuclei_dice):.4f}")
    print(f"  Cell Dice:   {np.mean(xenium_mif_cell_dice):.4f} ¬± {np.std(xenium_mif_cell_dice):.4f}")

print("\n" + "="*100)
print("‚úÖ Evaluation complete for VitaminPDual (both H&E and MIF)!")

Using device: cuda

üì¶ Loading VitaminPSyn model...
Building H&E encoder with DINOv2-base
Building Synthetic MIF encoder with DINOv2-base
Building shared encoder with DINOv2-base
‚úì VitaminPSyn initialized with base backbone
  Embed dim: 768 | Decoder dims: [768, 384, 192, 96]
‚úÖ VitaminPDual model loaded

üîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - BINARY METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [00:25<00:00,  4.91it/s]


üìä PER-SAMPLE RESULTS - H&E (VitaminPDual)

üî¨ CRC SAMPLES - H&E:
----------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei Dice   Nuclei IoU    Cell Dice    Cell IoU
----------------------------------------------------------------------------------------------------
CRC15              18        0.8584       0.7526       0.9256      0.8626
CRC16             142        0.9138       0.8417       0.9626      0.9282
CRC17              76        0.8316       0.7132       0.9374      0.8827
CRC18              80        0.8745       0.7775       0.9314      0.8723
CRC19              29        0.8332       0.7165       0.9354      0.8793
CRC20              35        0.8291       0.7092       0.9197      0.8530
CRC39              62        0.8553       0.7492       0.9168      0.8478

üß¨ XENIUM SAMPLES - H&E:
----------------------------------------------------------------------------------------------------
Sample




## Syn model instance metrics

In [4]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPSyn, SimplePreprocessing
from postprocessing import process_model_outputs
from tqdm import tqdm

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load SYN model
print("\nüì¶ Loading VitaminPSyn model...")
model = VitaminPSyn(model_size='base', dropout_rate=0.3, freeze_backbone=False)
checkpoint_path = "checkpoints/vitamin_p_syn_base_fold13_best.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
model = model.to(device)
print(f"‚úÖ VitaminPSyn model loaded")

# Preprocessor
preprocessor = SimplePreprocessing()

# Initialize metric storage PER SAMPLE for BOTH H&E and MIF - INSTANCE METRICS ONLY
sample_metrics = defaultdict(lambda: {
    'he_nuclei_pq': [], 'he_nuclei_dq': [], 'he_nuclei_sq': [], 'he_nuclei_aji': [], 
    'he_cell_pq': [], 'he_cell_dq': [], 'he_cell_sq': [], 'he_cell_aji': [],
    'mif_nuclei_pq': [], 'mif_nuclei_dq': [], 'mif_nuclei_sq': [], 'mif_nuclei_aji': [],
    'mif_cell_pq': [], 'mif_cell_dq': [], 'mif_cell_sq': [], 'mif_cell_aji': [],
    'patch_count': 0
})

print(f"\nüîÑ Evaluating on all test samples (H&E + MIF - SYN MODEL - INSTANCE METRICS ONLY)...")

skipped_batches = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):
        
        try:
            # ========== PREPARE INPUTS ==========
            he_img = batch['he_image'].to(device)
            mif_img = batch['mif_image'].to(device)
            
            # Normalize
            he_img = preprocessor.percentile_normalize(he_img)
            mif_img = preprocessor.percentile_normalize(mif_img)
            
            # Ground truth instance maps
            he_nuclei_inst_gt = batch['he_nuclei_instance'].cpu().numpy()
            he_cell_inst_gt = batch['he_cell_instance'].cpu().numpy()
            mif_nuclei_inst_gt = batch['mif_nuclei_instance'].cpu().numpy()
            mif_cell_inst_gt = batch['mif_cell_instance'].cpu().numpy()
            
            # Get sample names
            sample_names = batch['sample_name']
            
            # ========== INFERENCE WITH SYN MODEL (processes both H&E and MIF together) ==========
            outputs = model(he_img, mif_img)
            
            # Process each sample in the batch
            batch_size = he_img.shape[0]
            
            for i in range(batch_size):
                sample_name = sample_names[i]
                
                # ========== H&E INSTANCE SEGMENTATION ==========
                # Extract predictions for post-processing
                he_nuclei_seg_np = outputs['he_nuclei_seg'][i, 0].cpu().numpy()
                he_nuclei_h_map = outputs['he_nuclei_hv'][i, 0].cpu().numpy()
                he_nuclei_v_map = outputs['he_nuclei_hv'][i, 1].cpu().numpy()
                
                he_cell_seg_np = outputs['he_cell_seg'][i, 0].cpu().numpy()
                he_cell_h_map = outputs['he_cell_hv'][i, 0].cpu().numpy()
                he_cell_v_map = outputs['he_cell_hv'][i, 1].cpu().numpy()
                
                # Apply post-processing to get instance maps
                he_nuclei_inst_pred, _, _ = process_model_outputs(
                    he_nuclei_seg_np, he_nuclei_h_map, he_nuclei_v_map,
                    magnification=40
                )
                
                he_cell_inst_pred, _, _ = process_model_outputs(
                    he_cell_seg_np, he_cell_h_map, he_cell_v_map,
                    magnification=40
                )
                
                # Compute H&E instance metrics (returns pq, dq, sq)
                he_nuclei_pq, he_nuclei_dq, he_nuclei_sq = get_fast_pq(he_nuclei_inst_gt[i], he_nuclei_inst_pred)
                he_cell_pq, he_cell_dq, he_cell_sq = get_fast_pq(he_cell_inst_gt[i], he_cell_inst_pred)
                
                sample_metrics[sample_name]['he_nuclei_pq'].append(he_nuclei_pq)
                sample_metrics[sample_name]['he_nuclei_dq'].append(he_nuclei_dq)
                sample_metrics[sample_name]['he_nuclei_sq'].append(he_nuclei_sq)
                sample_metrics[sample_name]['he_nuclei_aji'].append(
                    aggregated_jaccard_index(he_nuclei_inst_gt[i], he_nuclei_inst_pred))
                
                sample_metrics[sample_name]['he_cell_pq'].append(he_cell_pq)
                sample_metrics[sample_name]['he_cell_dq'].append(he_cell_dq)
                sample_metrics[sample_name]['he_cell_sq'].append(he_cell_sq)
                sample_metrics[sample_name]['he_cell_aji'].append(
                    aggregated_jaccard_index(he_cell_inst_gt[i], he_cell_inst_pred))
                
                # ========== MIF INSTANCE SEGMENTATION ==========
                # Extract predictions for post-processing
                mif_nuclei_seg_np = outputs['mif_nuclei_seg'][i, 0].cpu().numpy()
                mif_nuclei_h_map = outputs['mif_nuclei_hv'][i, 0].cpu().numpy()
                mif_nuclei_v_map = outputs['mif_nuclei_hv'][i, 1].cpu().numpy()
                
                mif_cell_seg_np = outputs['mif_cell_seg'][i, 0].cpu().numpy()
                mif_cell_h_map = outputs['mif_cell_hv'][i, 0].cpu().numpy()
                mif_cell_v_map = outputs['mif_cell_hv'][i, 1].cpu().numpy()
                
                # Apply post-processing to get instance maps
                mif_nuclei_inst_pred, _, _ = process_model_outputs(
                    mif_nuclei_seg_np, mif_nuclei_h_map, mif_nuclei_v_map,
                    magnification=40
                )
                
                mif_cell_inst_pred, _, _ = process_model_outputs(
                    mif_cell_seg_np, mif_cell_h_map, mif_cell_v_map,
                    magnification=40
                )
                
                # Compute MIF instance metrics (returns pq, dq, sq)
                mif_nuclei_pq, mif_nuclei_dq, mif_nuclei_sq = get_fast_pq(mif_nuclei_inst_gt[i], mif_nuclei_inst_pred)
                mif_cell_pq, mif_cell_dq, mif_cell_sq = get_fast_pq(mif_cell_inst_gt[i], mif_cell_inst_pred)
                
                sample_metrics[sample_name]['mif_nuclei_pq'].append(mif_nuclei_pq)
                sample_metrics[sample_name]['mif_nuclei_dq'].append(mif_nuclei_dq)
                sample_metrics[sample_name]['mif_nuclei_sq'].append(mif_nuclei_sq)
                sample_metrics[sample_name]['mif_nuclei_aji'].append(
                    aggregated_jaccard_index(mif_nuclei_inst_gt[i], mif_nuclei_inst_pred))
                
                sample_metrics[sample_name]['mif_cell_pq'].append(mif_cell_pq)
                sample_metrics[sample_name]['mif_cell_dq'].append(mif_cell_dq)
                sample_metrics[sample_name]['mif_cell_sq'].append(mif_cell_sq)
                sample_metrics[sample_name]['mif_cell_aji'].append(
                    aggregated_jaccard_index(mif_cell_inst_gt[i], mif_cell_inst_pred))
                
                sample_metrics[sample_name]['patch_count'] += 1
        
        except RuntimeError as e:
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
print("\n" + "="*120)
print("üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPSyn)")
print("="*120)

# Separate CRC and Xenium samples
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Calculate averages for this sample
    avg_metrics = {
        'he_nuclei_pq': np.mean(metrics['he_nuclei_pq']),
        'he_nuclei_dq': np.mean(metrics['he_nuclei_dq']),
        'he_nuclei_sq': np.mean(metrics['he_nuclei_sq']),
        'he_nuclei_aji': np.mean(metrics['he_nuclei_aji']),
        'he_cell_pq': np.mean(metrics['he_cell_pq']),
        'he_cell_dq': np.mean(metrics['he_cell_dq']),
        'he_cell_sq': np.mean(metrics['he_cell_sq']),
        'he_cell_aji': np.mean(metrics['he_cell_aji']),
        'mif_nuclei_pq': np.mean(metrics['mif_nuclei_pq']),
        'mif_nuclei_dq': np.mean(metrics['mif_nuclei_dq']),
        'mif_nuclei_sq': np.mean(metrics['mif_nuclei_sq']),
        'mif_nuclei_aji': np.mean(metrics['mif_nuclei_aji']),
        'mif_cell_pq': np.mean(metrics['mif_cell_pq']),
        'mif_cell_dq': np.mean(metrics['mif_cell_dq']),
        'mif_cell_sq': np.mean(metrics['mif_cell_sq']),
        'mif_cell_aji': np.mean(metrics['mif_cell_aji']),
        'patch_count': metrics['patch_count']
    }
    
    if sample_name.startswith('CRC'):
        crc_samples[sample_name] = avg_metrics
    else:
        xenium_samples[sample_name] = avg_metrics

# Print CRC samples - H&E
if crc_samples:
    print("\nüî¨ CRC SAMPLES - H&E INSTANCES:")
    print("-" * 120)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print("-" * 120)
    
    for sample_name in sorted(crc_samples.keys()):
        m = crc_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['he_nuclei_pq']:>11.4f} {m['he_nuclei_dq']:>11.4f} {m['he_nuclei_sq']:>11.4f} {m['he_nuclei_aji']:>12.4f} "
              f"{m['he_cell_pq']:>9.4f} {m['he_cell_dq']:>9.4f} {m['he_cell_sq']:>9.4f} {m['he_cell_aji']:>10.4f}")

# Print Xenium samples - H&E
if xenium_samples:
    print("\nüß¨ XENIUM SAMPLES - H&E INSTANCES:")
    print("-" * 120)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print("-" * 120)
    
    for sample_name in sorted(xenium_samples.keys()):
        m = xenium_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['he_nuclei_pq']:>11.4f} {m['he_nuclei_dq']:>11.4f} {m['he_nuclei_sq']:>11.4f} {m['he_nuclei_aji']:>12.4f} "
              f"{m['he_cell_pq']:>9.4f} {m['he_cell_dq']:>9.4f} {m['he_cell_sq']:>9.4f} {m['he_cell_aji']:>10.4f}")

# ========== MIF INSTANCE RESULTS ==========
print("\n" + "="*120)
print("üìä PER-SAMPLE RESULTS - MIF INSTANCE METRICS (VitaminPSyn)")
print("="*120)

# Print CRC samples - MIF
if crc_samples:
    print("\nüî¨ CRC SAMPLES - MIF INSTANCES:")
    print("-" * 120)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print("-" * 120)
    
    for sample_name in sorted(crc_samples.keys()):
        m = crc_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['mif_nuclei_pq']:>11.4f} {m['mif_nuclei_dq']:>11.4f} {m['mif_nuclei_sq']:>11.4f} {m['mif_nuclei_aji']:>12.4f} "
              f"{m['mif_cell_pq']:>9.4f} {m['mif_cell_dq']:>9.4f} {m['mif_cell_sq']:>9.4f} {m['mif_cell_aji']:>10.4f}")

# Print Xenium samples - MIF
if xenium_samples:
    print("\nüß¨ XENIUM SAMPLES - MIF INSTANCES:")
    print("-" * 120)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print("-" * 120)
    
    for sample_name in sorted(xenium_samples.keys()):
        m = xenium_samples[sample_name]
        print(f"{sample_name:<12} {m['patch_count']:>8} "
              f"{m['mif_nuclei_pq']:>11.4f} {m['mif_nuclei_dq']:>11.4f} {m['mif_nuclei_sq']:>11.4f} {m['mif_nuclei_aji']:>12.4f} "
              f"{m['mif_cell_pq']:>9.4f} {m['mif_cell_dq']:>9.4f} {m['mif_cell_sq']:>9.4f} {m['mif_cell_aji']:>10.4f}")

# ========== OVERALL STATISTICS ==========
print("\n" + "="*100)
print("üìä OVERALL TEST SET RESULTS - INSTANCE METRICS (VitaminPSyn)")
print("="*100)

# Collect all metrics across all samples
all_he_nuclei_pq = []
all_he_nuclei_dq = []
all_he_nuclei_sq = []
all_he_nuclei_aji = []
all_he_cell_pq = []
all_he_cell_dq = []
all_he_cell_sq = []
all_he_cell_aji = []
all_mif_nuclei_pq = []
all_mif_nuclei_dq = []
all_mif_nuclei_sq = []
all_mif_nuclei_aji = []
all_mif_cell_pq = []
all_mif_cell_dq = []
all_mif_cell_sq = []
all_mif_cell_aji = []

for sample_name, metrics in sample_metrics.items():
    all_he_nuclei_pq.extend(metrics['he_nuclei_pq'])
    all_he_nuclei_dq.extend(metrics['he_nuclei_dq'])
    all_he_nuclei_sq.extend(metrics['he_nuclei_sq'])
    all_he_nuclei_aji.extend(metrics['he_nuclei_aji'])
    all_he_cell_pq.extend(metrics['he_cell_pq'])
    all_he_cell_dq.extend(metrics['he_cell_dq'])
    all_he_cell_sq.extend(metrics['he_cell_sq'])
    all_he_cell_aji.extend(metrics['he_cell_aji'])
    all_mif_nuclei_pq.extend(metrics['mif_nuclei_pq'])
    all_mif_nuclei_dq.extend(metrics['mif_nuclei_dq'])
    all_mif_nuclei_sq.extend(metrics['mif_nuclei_sq'])
    all_mif_nuclei_aji.extend(metrics['mif_nuclei_aji'])
    all_mif_cell_pq.extend(metrics['mif_cell_pq'])
    all_mif_cell_dq.extend(metrics['mif_cell_dq'])
    all_mif_cell_sq.extend(metrics['mif_cell_sq'])
    all_mif_cell_aji.extend(metrics['mif_cell_aji'])

total_patches = sum(m['patch_count'] for m in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

print("\n" + "="*50)
print("H&E INSTANCE RESULTS")
print("="*50)
print("\nüî¨ H&E NUCLEI INSTANCE METRICS (all patches):")
print(f"  PQ (Panoptic Quality):   {np.mean(all_he_nuclei_pq):.4f} ¬± {np.std(all_he_nuclei_pq):.4f}")
print(f"  DQ (Detection Quality):  {np.mean(all_he_nuclei_dq):.4f} ¬± {np.std(all_he_nuclei_dq):.4f}")
print(f"  SQ (Segmentation Quality): {np.mean(all_he_nuclei_sq):.4f} ¬± {np.std(all_he_nuclei_sq):.4f}")
print(f"  AJI (Agg. Jaccard):      {np.mean(all_he_nuclei_aji):.4f} ¬± {np.std(all_he_nuclei_aji):.4f}")

print("\nüß¨ H&E CELL INSTANCE METRICS (all patches):")
print(f"  PQ (Panoptic Quality):   {np.mean(all_he_cell_pq):.4f} ¬± {np.std(all_he_cell_pq):.4f}")
print(f"  DQ (Detection Quality):  {np.mean(all_he_cell_dq):.4f} ¬± {np.std(all_he_cell_dq):.4f}")
print(f"  SQ (Segmentation Quality): {np.mean(all_he_cell_sq):.4f} ¬± {np.std(all_he_cell_sq):.4f}")
print(f"  AJI (Agg. Jaccard):      {np.mean(all_he_cell_aji):.4f} ¬± {np.std(all_he_cell_aji):.4f}")

print("\n" + "="*50)
print("MIF INSTANCE RESULTS")
print("="*50)
print("\nüî¨ MIF NUCLEI INSTANCE METRICS (all patches):")
print(f"  PQ (Panoptic Quality):   {np.mean(all_mif_nuclei_pq):.4f} ¬± {np.std(all_mif_nuclei_pq):.4f}")
print(f"  DQ (Detection Quality):  {np.mean(all_mif_nuclei_dq):.4f} ¬± {np.std(all_mif_nuclei_dq):.4f}")
print(f"  SQ (Segmentation Quality): {np.mean(all_mif_nuclei_sq):.4f} ¬± {np.std(all_mif_nuclei_sq):.4f}")
print(f"  AJI (Agg. Jaccard):      {np.mean(all_mif_nuclei_aji):.4f} ¬± {np.std(all_mif_nuclei_aji):.4f}")

print("\nüß¨ MIF CELL INSTANCE METRICS (all patches):")
print(f"  PQ (Panoptic Quality):   {np.mean(all_mif_cell_pq):.4f} ¬± {np.std(all_mif_cell_pq):.4f}")
print(f"  DQ (Detection Quality):  {np.mean(all_mif_cell_dq):.4f} ¬± {np.std(all_mif_cell_dq):.4f}")
print(f"  SQ (Segmentation Quality): {np.mean(all_mif_cell_sq):.4f} ¬± {np.std(all_mif_cell_sq):.4f}")
print(f"  AJI (Agg. Jaccard):      {np.mean(all_mif_cell_aji):.4f} ¬± {np.std(all_mif_cell_aji):.4f}")

# Statistics by cancer type
if crc_samples:
    print("\n" + "="*50)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print("="*50)
    
    crc_he_nuclei_pq = [m['he_nuclei_pq'] for m in crc_samples.values()]
    crc_he_nuclei_dq = [m['he_nuclei_dq'] for m in crc_samples.values()]
    crc_he_nuclei_sq = [m['he_nuclei_sq'] for m in crc_samples.values()]
    crc_he_cell_pq = [m['he_cell_pq'] for m in crc_samples.values()]
    crc_mif_nuclei_pq = [m['mif_nuclei_pq'] for m in crc_samples.values()]
    crc_mif_cell_pq = [m['mif_cell_pq'] for m in crc_samples.values()]
    
    print(f"\nH&E:")
    print(f"  Nuclei PQ: {np.mean(crc_he_nuclei_pq):.4f} ¬± {np.std(crc_he_nuclei_pq):.4f}")
    print(f"  Nuclei DQ: {np.mean(crc_he_nuclei_dq):.4f} ¬± {np.std(crc_he_nuclei_dq):.4f}")
    print(f"  Nuclei SQ: {np.mean(crc_he_nuclei_sq):.4f} ¬± {np.std(crc_he_nuclei_sq):.4f}")
    print(f"  Cell PQ:   {np.mean(crc_he_cell_pq):.4f} ¬± {np.std(crc_he_cell_pq):.4f}")
    print(f"\nMIF:")
    print(f"  Nuclei PQ: {np.mean(crc_mif_nuclei_pq):.4f} ¬± {np.std(crc_mif_nuclei_pq):.4f}")
    print(f"  Cell PQ:   {np.mean(crc_mif_cell_pq):.4f} ¬± {np.std(crc_mif_cell_pq):.4f}")

if xenium_samples:
    print("\n" + "="*50)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print("="*50)
    
    xenium_he_nuclei_pq = [m['he_nuclei_pq'] for m in xenium_samples.values()]
    xenium_he_cell_pq = [m['he_cell_pq'] for m in xenium_samples.values()]
    xenium_mif_nuclei_pq = [m['mif_nuclei_pq'] for m in xenium_samples.values()]
    xenium_mif_cell_pq = [m['mif_cell_pq'] for m in xenium_samples.values()]
    
    print(f"\nH&E:")
    print(f"  Nuclei PQ: {np.mean(xenium_he_nuclei_pq):.4f} ¬± {np.std(xenium_he_nuclei_pq):.4f}")
    print(f"  Cell PQ:   {np.mean(xenium_he_cell_pq):.4f} ¬± {np.std(xenium_he_cell_pq):.4f}")
    print(f"\nMIF:")
    print(f"  Nuclei PQ: {np.mean(xenium_mif_nuclei_pq):.4f} ¬± {np.std(xenium_mif_nuclei_pq):.4f}")
    print(f"  Cell PQ:   {np.mean(xenium_mif_cell_pq):.4f} ¬± {np.std(xenium_mif_cell_pq):.4f}")

print("\n" + "="*100)
print("‚úÖ Instance segmentation evaluation complete for VitaminPSyn (both H&E and MIF)!")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda

üì¶ Loading VitaminPSyn model...
Building H&E encoder with DINOv2-base
Building Synthetic MIF encoder with DINOv2-base
Building shared encoder with DINOv2-base
‚úì VitaminPSyn initialized with base backbone
  Embed dim: 768 | Decoder dims: [768, 384, 192, 96]
‚úÖ VitaminPSyn model loaded

üîÑ Evaluating on all test samples (H&E + MIF - SYN MODEL - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [16:46<00:00,  7.93s/it]


üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPSyn)

üî¨ CRC SAMPLES - H&E INSTANCES:
------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.6779      0.8493      0.7979       0.6750    0.7501    0.9001    0.8330     0.6562
CRC16             142      0.7805      0.9065      0.8605       0.8083    0.8184    0.9318    0.8781     0.7313
CRC17              76      0.5814      0.7506      0.7737       0.5775    0.7584    0.8981    0.8444     0.6929
CRC18              80      0.6463      0.8146      0.7921       0.6508    0.7372    0.8774    0.8396     0.6263
CRC19              29      0.6033      0.7706      0.7814       0.6107    0.7781    0.9111    0.8538 


