In [1]:
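# One-time environment setup. Uncomment to install; prefer %pip over !pip so the
# packages land in the environment of the running kernel.
# %pip install 'zarr<3'
# %pip install timm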


Defaulting to user installation because normal site-packages is not writeable
Collecting zarr<3
  Using cached zarr-2.18.7-py3-none-any.whl.metadata (5.8 kB)
Collecting asciitree (from zarr<3)
  Using cached asciitree-0.3.3-py3-none-any.whl
Collecting fasteners (from zarr<3)
  Using cached fasteners-0.20-py3-none-any.whl.metadata (4.8 kB)
Collecting numcodecs!=0.14.0,!=0.14.1,<0.16,>=0.10.0 (from zarr<3)
  Using cached numcodecs-0.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting deprecated (from numcodecs!=0.14.0,!=0.14.1,<0.16,>=0.10.0->zarr<3)
  Using cached deprecated-1.3.1-py2.py3-none-any.whl.metadata (5.9 kB)
Using cached zarr-2.18.7-py3-none-any.whl (211 kB)
Using cached numcodecs-0.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.9 MB)
Using cached fasteners-0.20-py3-none-any.whl (18 kB)
Using cached deprecated-1.3.1-py2.py3-none-any.whl (11 kB)
Installing collected packages: asciitree, fasteners, deprecated, numcode

In [1]:
# ALWAYS RUN THIS FIRST!
# Moves the kernel into the project checkout and makes that directory importable,
# so the relative paths used by later cells (configs/, checkpoints/, cache/) resolve.
import os
import sys
from pathlib import Path

# Default checkout location. Set the VITAMINP_DIR environment variable to run the
# notebook from a different checkout without editing this cell.
NOTEBOOK_DIR = Path(os.environ.get(
    "VITAMINP_DIR",
    "/rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest",
))
os.chdir(NOTEBOOK_DIR)

# Keep the checkout at the front of sys.path, but drop any entry left by a previous
# run of this cell so re-running does not pile up duplicates.
_notebook_dir = str(NOTEBOOK_DIR)
sys.path[:] = [entry for entry in sys.path if entry != _notebook_dir]
sys.path.insert(0, _notebook_dir)

print(f"‚úÖ Working directory: {os.getcwd()}")

‚úÖ Working directory: /rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest


## Data Loading

In [2]:
# Data loading: read the fold-3 training config and build the three dataloaders.
from dataset import Config, create_dataloaders

# Relative to the working directory chosen in the setup cell.
# Note: the folder is "configs", not "config".
config_path = "configs/training/config_fold3.yaml"
config = Config(config_path)
config.print_config()

# create_dataloaders hands back the train, validation and test loaders, in that order.
loaders = create_dataloaders(config)
train_loader, val_loader, test_loader = loaders
print("\n‚úÖ Ready to use!")

‚úÖ CRC Dataset Package v1.0.0 loaded
CRC DATASET CONFIGURATION
Config File: configs/training/config_fold3.yaml
Zarr Base: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/ORION-CRC/zarr_data
Cache: ./cache/multimodal_dataset_cache_fold3.pkl
Strategy: memory

üìä Data Splits:
  Train: 33 samples
  Val: 9 samples
  Test: 8 samples

üîÑ DataLoader:
  Batch Size: 4
  Num Workers: 0
  Pin Memory: True

üé® Augmentation:
  Training: True
  Probability: 0.0

üéØ HV Maps:
  Generate: True
  Method: pannuke
  HE Nuclei: True
  HE Cells: True
  MIF Nuclei: True
  MIF Cells: True

üîç Filtering:
  Min Instances: 0
  Filter Empty: True

CREATING DATALOADERS
Strategy: memory
Use Cache: True
Batch Size: 4
Num Workers: 0

Train split: 27 CRC + 6 Xenium samples
Val split: 7 CRC + 2 Xenium samples
Test split: 7 CRC + 1 Xenium samples

üì¶ Loading from cache: ./cache/combined_cache_train_0be9581c.pkl
üì¶ Loaded 3294 patches from cache
üì¶ Loading from cache: ./cache/combined_cache_val_0

## Flex Instance Metrics

In [4]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPFlex, SimplePreprocessing, prepare_he_input, prepare_mif_input
from postprocessing import process_model_outputs
from tqdm import tqdm

# Setup: use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Load model: build the large VitaminPFlex on `device`, restore the checkpoint
# weights, then put the network in eval mode.
print("\nüì¶ Loading model...")
checkpoint_path = "checkpoints/vitamin_p_flex_large_fold3_best.pth"
model = VitaminPFlex(model_size='large')
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
print("‚úÖ Model loaded")

# Preprocessor
preprocessor = SimplePreprocessing()


def _new_sample_record():
    """Return the empty per-sample store: one list per instance metric plus a patch counter.

    Keys are '<modality>_<target>_<metric>' for modality in (he, mif), target in
    (nuclei, cell) and metric in (pq, dq, sq, aji), followed by 'patch_count'.
    """
    record = {
        f"{modality}_{target}_{metric}": []
        for modality in ('he', 'mif')
        for target in ('nuclei', 'cell')
        for metric in ('pq', 'dq', 'sq', 'aji')
    }
    record['patch_count'] = 0
    return record


# Metric storage PER SAMPLE for BOTH H&E and MIF - INSTANCE METRICS ONLY.
# A record is created the first time a sample name is looked up.
sample_metrics = defaultdict(_new_sample_record)

print(f"\nüîÑ Evaluating on all test samples (H&E + MIF - INSTANCE METRICS ONLY)...")

skipped_batches = 0

# Heads that are scored. Each name is the prefix of the model output keys
# ('<head>_seg', '<head>_hv'), of the ground-truth batch key ('<head>_instance')
# and of the metric keys stored in sample_metrics.
_FLEX_HEADS = ('he_nuclei', 'he_cell', 'mif_nuclei', 'mif_cell')


def _flex_patch_instance_metrics(outputs, gt_instances, head, i):
    """Score the predicted instance map of one head for patch `i` of the batch.

    outputs      -- model output dict holding '<head>_seg' and '<head>_hv'
    gt_instances -- ground-truth instance maps for the whole batch (indexed by patch)
    head         -- one of the prefixes in _FLEX_HEADS
    i            -- index of the patch inside the batch

    Channel 0 of the seg head and channels 0 / 1 of the HV head are handed to
    process_model_outputs (magnification=40); its first return value is used as the
    predicted instance map.

    Returns a dict with keys '<head>_pq', '<head>_dq', '<head>_sq', '<head>_aji'.
    """
    seg_np = outputs[f'{head}_seg'][i, 0].cpu().numpy()
    h_map = outputs[f'{head}_hv'][i, 0].cpu().numpy()
    v_map = outputs[f'{head}_hv'][i, 1].cpu().numpy()

    inst_pred, _, _ = process_model_outputs(seg_np, h_map, v_map, magnification=40)

    pq, dq, sq = get_fast_pq(gt_instances[i], inst_pred)
    aji = aggregated_jaccard_index(gt_instances[i], inst_pred)
    return {
        f'{head}_pq': pq,
        f'{head}_dq': dq,
        f'{head}_sq': sq,
        f'{head}_aji': aji,
    }


with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        try:
            # ========== H&E / MIF INPUTS ==========
            he_img = prepare_he_input(batch['he_image'].to(device))
            he_img = preprocessor.percentile_normalize(he_img)

            mif_img = prepare_mif_input(batch['mif_image'].to(device))
            mif_img = preprocessor.percentile_normalize(mif_img)

            # Ground-truth instance maps, one array per head.
            gt_by_head = {
                head: batch[f'{head}_instance'].cpu().numpy()
                for head in _FLEX_HEADS
            }

            # Get sample names
            sample_names = batch['sample_name']

            # ========== INFERENCE (one forward pass per modality) ==========
            outputs_he = model(he_img)
            outputs_mif = model(mif_img)

            # Process each sample in the batch
            for i in range(he_img.shape[0]):
                # Compute every metric of the patch BEFORE touching sample_metrics.
                # If any step raises, nothing of this patch has been stored, so the
                # per-sample lists always stay the same length as patch_count.
                patch_values = {}
                for head in _FLEX_HEADS:
                    head_outputs = outputs_he if head.startswith('he_') else outputs_mif
                    patch_values.update(
                        _flex_patch_instance_metrics(head_outputs, gt_by_head[head], head, i)
                    )

                record = sample_metrics[sample_names[i]]
                for key, value in patch_values.items():
                    record[key].append(value)
                record['patch_count'] += 1

        except RuntimeError as e:
            # Best effort: report and move on. Patches of this batch that were fully
            # scored before the failure remain recorded; the failing patch and the
            # rest of the batch are dropped.
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
wide_rule = "=" * 120
print("\n" + wide_rule)
print("üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS")
print(wide_rule)

# Names of the 16 per-patch metric lists kept for every sample.
metric_keys = [
    f"{modality}_{target}_{stat}"
    for modality in ('he', 'mif')
    for target in ('nuclei', 'cell')
    for stat in ('pq', 'dq', 'sq', 'aji')
]

# Samples are split by name: anything starting with 'CRC' goes to the CRC group,
# every other sample is treated as Xenium.
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Mean over the patches of this sample, one value per metric.
    avg_metrics = {key: np.mean(metrics[key]) for key in metric_keys}
    avg_metrics['patch_count'] = metrics['patch_count']

    group = crc_samples if sample_name.startswith('CRC') else xenium_samples
    group[sample_name] = avg_metrics

def _print_sample_table(title, samples, modality):
    """Print a fixed-width table of per-sample averages for one modality.

    title    -- heading line printed above the table
    samples  -- mapping of sample name -> averaged metrics dict
    modality -- 'he' or 'mif'; selects which '<modality>_*' columns are shown

    Rows are ordered by sample name. Nothing is printed when `samples` is empty.
    """
    if not samples:
        return

    divider = "-" * 120
    print(title)
    print(divider)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print(divider)

    # (metric suffix, column width) in display order.
    columns = (
        ('nuclei_pq', 11), ('nuclei_dq', 11), ('nuclei_sq', 11), ('nuclei_aji', 12),
        ('cell_pq', 9), ('cell_dq', 9), ('cell_sq', 9), ('cell_aji', 10),
    )
    for name in sorted(samples):
        row = samples[name]
        cells = [
            format(row[modality + '_' + suffix], f">{width}.4f")
            for suffix, width in columns
        ]
        print(f"{name:<12} {row['patch_count']:>8} " + " ".join(cells))


# H&E tables: CRC first, then Xenium.
_print_sample_table("\nüî¨ CRC SAMPLES - H&E INSTANCES:", crc_samples, 'he')
_print_sample_table("\nüß¨ XENIUM SAMPLES - H&E INSTANCES:", xenium_samples, 'he')

# ========== MIF INSTANCE RESULTS ==========
print("\n" + "=" * 120)
print("üìä PER-SAMPLE RESULTS - MIF INSTANCE METRICS")
print("=" * 120)

# MIF tables: CRC first, then Xenium.
_print_sample_table("\nüî¨ CRC SAMPLES - MIF INSTANCES:", crc_samples, 'mif')
_print_sample_table("\nüß¨ XENIUM SAMPLES - MIF INSTANCES:", xenium_samples, 'mif')

# ========== OVERALL STATISTICS ==========
overall_rule = "=" * 100
print("\n" + overall_rule)
print("üìä OVERALL TEST SET RESULTS - INSTANCE METRICS")
print(overall_rule)


def _pool_patches(key):
    """Gather the per-patch values stored under `key` from every sample into one list."""
    return [value for per_sample in sample_metrics.values() for value in per_sample[key]]


def _print_overall(title, pq, dq, sq, aji):
    """Print mean and standard deviation of PQ / DQ / SQ / AJI under `title`."""
    print(title)
    labelled = (
        ("  PQ (Panoptic Quality):   ", pq),
        ("  DQ (Detection Quality):  ", dq),
        ("  SQ (Segmentation Quality): ", sq),
        ("  AJI (Agg. Jaccard):      ", aji),
    )
    for label, values in labelled:
        print(f"{label}{np.mean(values):.4f} ¬± {np.std(values):.4f}")


# Pool every metric over all patches of all samples (patch-level, not sample-level).
all_he_nuclei_pq = _pool_patches('he_nuclei_pq')
all_he_nuclei_dq = _pool_patches('he_nuclei_dq')
all_he_nuclei_sq = _pool_patches('he_nuclei_sq')
all_he_nuclei_aji = _pool_patches('he_nuclei_aji')
all_he_cell_pq = _pool_patches('he_cell_pq')
all_he_cell_dq = _pool_patches('he_cell_dq')
all_he_cell_sq = _pool_patches('he_cell_sq')
all_he_cell_aji = _pool_patches('he_cell_aji')
all_mif_nuclei_pq = _pool_patches('mif_nuclei_pq')
all_mif_nuclei_dq = _pool_patches('mif_nuclei_dq')
all_mif_nuclei_sq = _pool_patches('mif_nuclei_sq')
all_mif_nuclei_aji = _pool_patches('mif_nuclei_aji')
all_mif_cell_pq = _pool_patches('mif_cell_pq')
all_mif_cell_dq = _pool_patches('mif_cell_dq')
all_mif_cell_sq = _pool_patches('mif_cell_sq')
all_mif_cell_aji = _pool_patches('mif_cell_aji')

total_patches = sum(per_sample['patch_count'] for per_sample in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

section_rule = "=" * 50

print("\n" + section_rule)
print("H&E INSTANCE RESULTS")
print(section_rule)
_print_overall("\nüî¨ H&E NUCLEI INSTANCE METRICS (all patches):",
               all_he_nuclei_pq, all_he_nuclei_dq, all_he_nuclei_sq, all_he_nuclei_aji)
_print_overall("\nüß¨ H&E CELL INSTANCE METRICS (all patches):",
               all_he_cell_pq, all_he_cell_dq, all_he_cell_sq, all_he_cell_aji)

print("\n" + section_rule)
print("MIF INSTANCE RESULTS")
print(section_rule)
_print_overall("\nüî¨ MIF NUCLEI INSTANCE METRICS (all patches):",
               all_mif_nuclei_pq, all_mif_nuclei_dq, all_mif_nuclei_sq, all_mif_nuclei_aji)
_print_overall("\nüß¨ MIF CELL INSTANCE METRICS (all patches):",
               all_mif_cell_pq, all_mif_cell_dq, all_mif_cell_sq, all_mif_cell_aji)

def _print_group_summary(group_name, samples, he_rows, mif_rows):
    """Print mean and std across the samples of one group (sample-level, not patch-level).

    group_name -- label used in the heading, e.g. 'CRC'
    samples    -- mapping of sample name -> averaged metrics dict
    he_rows    -- (label, metric key) pairs listed under 'H&E:'
    mif_rows   -- (label, metric key) pairs listed under 'MIF:'

    Nothing is printed when `samples` is empty.
    """
    if not samples:
        return

    print("\n" + "=" * 50)
    print(f"{group_name} SAMPLES ({len(samples)} samples)")
    print("=" * 50)

    for heading, rows in (("\nH&E:", he_rows), ("\nMIF:", mif_rows)):
        print(heading)
        for label, key in rows:
            per_sample = [averages[key] for averages in samples.values()]
            print(f"{label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")


# Statistics by cancer type.
# The CRC block additionally reports nuclei DQ and SQ for H&E; Xenium reports PQ only.
_print_group_summary(
    "CRC", crc_samples,
    he_rows=(
        ("  Nuclei PQ: ", 'he_nuclei_pq'),
        ("  Nuclei DQ: ", 'he_nuclei_dq'),
        ("  Nuclei SQ: ", 'he_nuclei_sq'),
        ("  Cell PQ:   ", 'he_cell_pq'),
    ),
    mif_rows=(
        ("  Nuclei PQ: ", 'mif_nuclei_pq'),
        ("  Cell PQ:   ", 'mif_cell_pq'),
    ),
)

_print_group_summary(
    "XENIUM", xenium_samples,
    he_rows=(
        ("  Nuclei PQ: ", 'he_nuclei_pq'),
        ("  Cell PQ:   ", 'he_cell_pq'),
    ),
    mif_rows=(
        ("  Nuclei PQ: ", 'mif_nuclei_pq'),
        ("  Cell PQ:   ", 'mif_cell_pq'),
    ),
)

print("\n" + "=" * 100)
print("‚úÖ Instance segmentation evaluation complete for both H&E and MIF!")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda

üì¶ Loading model...
‚úì VitaminPFlex initialized with large backbone
  Architecture: Shared Encoder ‚Üí 4 Separate Decoders
  Embed dim: 1024 | Decoder dims: [1024, 512, 256, 128]
‚úÖ Model loaded

üîÑ Evaluating on all test samples (H&E + MIF - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [15:43<00:00,  7.43s/it]


üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS

üî¨ CRC SAMPLES - H&E INSTANCES:
------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.6679      0.8345      0.8001       0.6630    0.5431    0.7479    0.7199     0.4845
CRC16             142      0.7804      0.9075      0.8596       0.8037    0.6584    0.8597    0.7644     0.5498
CRC17              76      0.5865      0.7588      0.7718       0.5804    0.5397    0.7467    0.7204     0.5026
CRC18              80      0.6376      0.8061      0.7898       0.6339    0.5407    0.7391    0.7295     0.4609
CRC19              29      0.6099      0.7806      0.7802       0.6038    0.4980    0.6951    0.7145     0.5053
CRC




## Dual Instance Metrics

In [5]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPDual, SimplePreprocessing
from postprocessing import process_model_outputs
from tqdm import tqdm

# Setup: use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Load DUAL model: weights are restored and eval mode is set before the move to `device`.
# NOTE(review): the checkpoint file name says fold21, while the data cell loads
# configs/training/config_fold3.yaml -- confirm the test split matches this checkpoint.
print("\nüì¶ Loading VitaminPDual model...")
checkpoint_path = "checkpoints/vitamin_p_dual_base_fold21_best.pth"
model = VitaminPDual(model_size='base', dropout_rate=0.3, freeze_backbone=False)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
model = model.to(device)
print("‚úÖ VitaminPDual model loaded")

# Preprocessor
preprocessor = SimplePreprocessing()


def _new_dual_sample_record():
    """Return the empty per-sample store: one list per instance metric plus a patch counter.

    Keys are '<modality>_<target>_<metric>' for modality in (he, mif), target in
    (nuclei, cell) and metric in (pq, dq, sq, aji), followed by 'patch_count'.
    """
    record = {
        f"{modality}_{target}_{metric}": []
        for modality in ('he', 'mif')
        for target in ('nuclei', 'cell')
        for metric in ('pq', 'dq', 'sq', 'aji')
    }
    record['patch_count'] = 0
    return record


# Metric storage PER SAMPLE for BOTH H&E and MIF - INSTANCE METRICS ONLY.
# This rebinds sample_metrics, so results of the previous evaluation cell are discarded.
sample_metrics = defaultdict(_new_dual_sample_record)

print(f"\nüîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - INSTANCE METRICS ONLY)...")

skipped_batches = 0

# Heads that are scored. Each name is the prefix of the model output keys
# ('<head>_seg', '<head>_hv'), of the ground-truth batch key ('<head>_instance')
# and of the metric keys stored in sample_metrics.
_DUAL_HEADS = ('he_nuclei', 'he_cell', 'mif_nuclei', 'mif_cell')


def _dual_patch_instance_metrics(outputs, gt_instances, head, i):
    """Score the predicted instance map of one head for patch `i` of the batch.

    outputs      -- model output dict holding '<head>_seg' and '<head>_hv'
    gt_instances -- ground-truth instance maps for the whole batch (indexed by patch)
    head         -- one of the prefixes in _DUAL_HEADS
    i            -- index of the patch inside the batch

    Channel 0 of the seg head and channels 0 / 1 of the HV head are handed to
    process_model_outputs (magnification=40); its first return value is used as the
    predicted instance map.

    Returns a dict with keys '<head>_pq', '<head>_dq', '<head>_sq', '<head>_aji'.
    """
    seg_np = outputs[f'{head}_seg'][i, 0].cpu().numpy()
    h_map = outputs[f'{head}_hv'][i, 0].cpu().numpy()
    v_map = outputs[f'{head}_hv'][i, 1].cpu().numpy()

    inst_pred, _, _ = process_model_outputs(seg_np, h_map, v_map, magnification=40)

    pq, dq, sq = get_fast_pq(gt_instances[i], inst_pred)
    aji = aggregated_jaccard_index(gt_instances[i], inst_pred)
    return {
        f'{head}_pq': pq,
        f'{head}_dq': dq,
        f'{head}_sq': sq,
        f'{head}_aji': aji,
    }


with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        try:
            # ========== PREPARE INPUTS ==========
            he_img = batch['he_image'].to(device)
            mif_img = batch['mif_image'].to(device)

            # Normalize
            he_img = preprocessor.percentile_normalize(he_img)
            mif_img = preprocessor.percentile_normalize(mif_img)

            # Ground-truth instance maps, one array per head.
            gt_by_head = {
                head: batch[f'{head}_instance'].cpu().numpy()
                for head in _DUAL_HEADS
            }

            # Get sample names
            sample_names = batch['sample_name']

            # ========== INFERENCE WITH DUAL MODEL (H&E and MIF in one forward pass) ==========
            outputs = model(he_img, mif_img)

            # Process each sample in the batch
            for i in range(he_img.shape[0]):
                # Compute every metric of the patch BEFORE touching sample_metrics.
                # If any step raises, nothing of this patch has been stored, so the
                # per-sample lists always stay the same length as patch_count.
                patch_values = {}
                for head in _DUAL_HEADS:
                    patch_values.update(
                        _dual_patch_instance_metrics(outputs, gt_by_head[head], head, i)
                    )

                record = sample_metrics[sample_names[i]]
                for key, value in patch_values.items():
                    record[key].append(value)
                record['patch_count'] += 1

        except RuntimeError as e:
            # Best effort: report and move on. Patches of this batch that were fully
            # scored before the failure remain recorded; the failing patch and the
            # rest of the batch are dropped.
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
wide_rule = "=" * 120
print("\n" + wide_rule)
print("üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPDual)")
print(wide_rule)

# Names of the 16 per-patch metric lists kept for every sample.
metric_keys = [
    f"{modality}_{target}_{stat}"
    for modality in ('he', 'mif')
    for target in ('nuclei', 'cell')
    for stat in ('pq', 'dq', 'sq', 'aji')
]

# Samples are split by name: anything starting with 'CRC' goes to the CRC group,
# every other sample is treated as Xenium.
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Mean over the patches of this sample, one value per metric.
    avg_metrics = {key: np.mean(metrics[key]) for key in metric_keys}
    avg_metrics['patch_count'] = metrics['patch_count']

    group = crc_samples if sample_name.startswith('CRC') else xenium_samples
    group[sample_name] = avg_metrics

def _print_dual_sample_table(title, samples, modality):
    """Print a fixed-width table of per-sample averages for one modality.

    title    -- heading line printed above the table
    samples  -- mapping of sample name -> averaged metrics dict
    modality -- 'he' or 'mif'; selects which '<modality>_*' columns are shown

    Rows are ordered by sample name. Nothing is printed when `samples` is empty.
    """
    if not samples:
        return

    divider = "-" * 120
    print(title)
    print(divider)
    print(f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}")
    print(divider)

    # (metric suffix, column width) in display order.
    columns = (
        ('nuclei_pq', 11), ('nuclei_dq', 11), ('nuclei_sq', 11), ('nuclei_aji', 12),
        ('cell_pq', 9), ('cell_dq', 9), ('cell_sq', 9), ('cell_aji', 10),
    )
    for name in sorted(samples):
        row = samples[name]
        cells = [
            format(row[modality + '_' + suffix], f">{width}.4f")
            for suffix, width in columns
        ]
        print(f"{name:<12} {row['patch_count']:>8} " + " ".join(cells))


# H&E tables: CRC first, then Xenium.
_print_dual_sample_table("\nüî¨ CRC SAMPLES - H&E INSTANCES:", crc_samples, 'he')
_print_dual_sample_table("\nüß¨ XENIUM SAMPLES - H&E INSTANCES:", xenium_samples, 'he')

# ========== MIF INSTANCE RESULTS ==========
print("\n" + "=" * 120)
print("üìä PER-SAMPLE RESULTS - MIF INSTANCE METRICS (VitaminPDual)")
print("=" * 120)

# MIF tables: CRC first, then Xenium.
_print_dual_sample_table("\nüî¨ CRC SAMPLES - MIF INSTANCES:", crc_samples, 'mif')
_print_dual_sample_table("\nüß¨ XENIUM SAMPLES - MIF INSTANCES:", xenium_samples, 'mif')

# ========== OVERALL STATISTICS ==========
# Patch-level summary: every per-patch score from every sample is pooled into
# one list per metric, then reported as mean ¬± std (np.std, i.e. population std).
overall_rule = "=" * 100
print("\n" + overall_rule)
print("üìä OVERALL TEST SET RESULTS - INSTANCE METRICS (VitaminPDual)")
print(overall_rule)

# One pooled list per metric key, created in the order the names are unpacked below.
score_names = ('pq', 'dq', 'sq', 'aji')
pooled = {
    f'{group}_{score}': []
    for group in ('he_nuclei', 'he_cell', 'mif_nuclei', 'mif_cell')
    for score in score_names
}
for sample_name, metrics in sample_metrics.items():
    for metric_key, bucket in pooled.items():
        bucket.extend(metrics[metric_key])

# Expose each pooled list under its own name (dicts preserve insertion order).
(all_he_nuclei_pq, all_he_nuclei_dq, all_he_nuclei_sq, all_he_nuclei_aji,
 all_he_cell_pq, all_he_cell_dq, all_he_cell_sq, all_he_cell_aji,
 all_mif_nuclei_pq, all_mif_nuclei_dq, all_mif_nuclei_sq, all_mif_nuclei_aji,
 all_mif_cell_pq, all_mif_cell_dq, all_mif_cell_sq, all_mif_cell_aji) = pooled.values()

total_patches = sum(record['patch_count'] for record in sample_metrics.values())

print("\nTotal samples:", len(sample_metrics))
print("Total patches:", total_patches)
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

# Line prefixes, aligned with `score_names`; the trailing spaces are part of the layout.
stat_labels = (
    "  PQ (Panoptic Quality):   ",
    "  DQ (Detection Quality):  ",
    "  SQ (Segmentation Quality): ",
    "  AJI (Agg. Jaccard):      ",
)
overall_report = (
    ("H&E INSTANCE RESULTS", (
        ("\nüî¨ H&E NUCLEI INSTANCE METRICS (all patches):", 'he_nuclei'),
        ("\nüß¨ H&E CELL INSTANCE METRICS (all patches):", 'he_cell'),
    )),
    ("MIF INSTANCE RESULTS", (
        ("\nüî¨ MIF NUCLEI INSTANCE METRICS (all patches):", 'mif_nuclei'),
        ("\nüß¨ MIF CELL INSTANCE METRICS (all patches):", 'mif_cell'),
    )),
)

for modality_heading, sections in overall_report:
    print("\n" + "=" * 50)
    print(modality_heading)
    print("=" * 50)
    for section_title, group in sections:
        print(section_title)
        for line_label, score in zip(stat_labels, score_names):
            values = pooled[f'{group}_{score}']
            print(f"{line_label}{np.mean(values):.4f} ¬± {np.std(values):.4f}")

# Statistics by cancer type
# Sample-level summary: each value fed to mean/std below is one sample's
# already-averaged score, so the spread is across samples, not across patches.
cohort_rule = "=" * 50

if crc_samples:
    print("\n" + cohort_rule)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print(cohort_rule)

    crc_rows = list(crc_samples.values())
    (crc_he_nuclei_pq, crc_he_nuclei_dq, crc_he_nuclei_sq,
     crc_he_cell_pq, crc_mif_nuclei_pq, crc_mif_cell_pq) = (
        [row[metric_key] for row in crc_rows]
        for metric_key in ('he_nuclei_pq', 'he_nuclei_dq', 'he_nuclei_sq',
                           'he_cell_pq', 'mif_nuclei_pq', 'mif_cell_pq')
    )

    print("\nH&E:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", crc_he_nuclei_pq),
        ("  Nuclei DQ: ", crc_he_nuclei_dq),
        ("  Nuclei SQ: ", crc_he_nuclei_sq),
        ("  Cell PQ:   ", crc_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

    print("\nMIF:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", crc_mif_nuclei_pq),
        ("  Cell PQ:   ", crc_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

if xenium_samples:
    print("\n" + cohort_rule)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print(cohort_rule)

    xenium_rows = list(xenium_samples.values())
    (xenium_he_nuclei_pq, xenium_he_cell_pq,
     xenium_mif_nuclei_pq, xenium_mif_cell_pq) = (
        [row[metric_key] for row in xenium_rows]
        for metric_key in ('he_nuclei_pq', 'he_cell_pq', 'mif_nuclei_pq', 'mif_cell_pq')
    )

    print("\nH&E:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", xenium_he_nuclei_pq),
        ("  Cell PQ:   ", xenium_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

    print("\nMIF:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", xenium_mif_nuclei_pq),
        ("  Cell PQ:   ", xenium_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

print("\n" + "=" * 100)
print("‚úÖ Instance segmentation evaluation complete for VitaminPDual (both H&E and MIF)!")

Using device: cuda

üì¶ Loading VitaminPDual model...
Building H&E encoder with DINOv2-base
Building MIF encoder with DINOv2-base
Building shared encoder with DINOv2-base
‚úì VitaminPDual initialized with base backbone
  Embed dim: 768 | Decoder dims: [768, 384, 192, 96]
‚úÖ VitaminPDual model loaded

üîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [15:16<00:00,  7.22s/it]


üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPDual)

üî¨ CRC SAMPLES - H&E INSTANCES:
------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.6712      0.8411      0.7978       0.6788    0.7526    0.9016    0.8346     0.6396
CRC16             142      0.7838      0.9090      0.8618       0.8110    0.8153    0.9282    0.8781     0.7121
CRC17              76      0.5757      0.7438      0.7733       0.5665    0.7590    0.8964    0.8466     0.6783
CRC18              80      0.6511      0.8191      0.7938       0.6521    0.7348    0.8722    0.8417     0.5853
CRC19              29      0.6009      0.7693      0.7796       0.6109    0.7814    0.9099    0.8586




## Dual instance metrics: H&E predictions compared with mIF GT

In [3]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPDual, SimplePreprocessing
from postprocessing import process_model_outputs
from tqdm import tqdm

# Setup: run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device: {}".format(device))

# Load DUAL model
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted location (consider weights_only=True if the installed PyTorch
# supports it and the file is a plain state dict).
# NOTE(review): the checkpoint name is tagged "fold21" while the dataloaders were
# built from configs/training/config_fold3.yaml - confirm this pairing is intended.
print("\nüì¶ Loading VitaminPDual model...")
checkpoint_path = "checkpoints/vitamin_p_dual_base_fold21_best.pth"
model = VitaminPDual(model_size='base', dropout_rate=0.3, freeze_backbone=False)
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
model = model.to(device)
print("‚úÖ VitaminPDual model loaded")

# Preprocessor
preprocessor = SimplePreprocessing()


def _new_dual_sample_record():
    """Build the empty per-sample accumulator: one list per metric key plus a patch counter."""
    record = {
        f'{group}_{score}': []
        for group in ('he_nuclei', 'he_cell', 'he_nuclei_vs_mif', 'mif_nuclei', 'mif_cell')
        for score in ('pq', 'dq', 'sq', 'aji')
    }
    record['patch_count'] = 0
    return record


# Metric storage PER SAMPLE for BOTH H&E and MIF - instance metrics only.
# A record is created the first time a sample name is looked up.
sample_metrics = defaultdict(_new_dual_sample_record)

print("\nüîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - INSTANCE METRICS ONLY)...")

# Number of batches abandoned by the evaluation loop.
skipped_batches = 0

def _predict_instance_map(outputs, branch, index):
    """Post-process one patch's raw outputs for `branch` into an instance map.

    Args:
        outputs: mapping returned by the model; read through the keys
            f'{branch}_seg' and f'{branch}_hv'.
        branch: output prefix such as 'he_nuclei' or 'mif_cell'.
        index: position of the patch within the batch.

    Returns:
        The first element of the 3-tuple returned by `process_model_outputs`
        (used by the caller as the predicted instance map).
    """
    seg_map = outputs[f'{branch}_seg'][index, 0].cpu().numpy()
    h_map = outputs[f'{branch}_hv'][index, 0].cpu().numpy()
    v_map = outputs[f'{branch}_hv'][index, 1].cpu().numpy()
    inst_pred, _, _ = process_model_outputs(seg_map, h_map, v_map, magnification=40)
    return inst_pred


def _score_instances(prefix, inst_gt, inst_pred):
    """Score one prediction against one ground-truth instance map.

    Returns:
        Dict with keys f'{prefix}_pq', f'{prefix}_dq', f'{prefix}_sq' (the three
        values unpacked from `get_fast_pq`, in that order) and f'{prefix}_aji'
        (from `aggregated_jaccard_index`).
    """
    pq, dq, sq = get_fast_pq(inst_gt, inst_pred)
    return {
        f'{prefix}_pq': pq,
        f'{prefix}_dq': dq,
        f'{prefix}_sq': sq,
        f'{prefix}_aji': aggregated_jaccard_index(inst_gt, inst_pred),
    }


with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        try:
            # ========== PREPARE INPUTS ==========
            he_img = preprocessor.percentile_normalize(batch['he_image'].to(device))
            mif_img = preprocessor.percentile_normalize(batch['mif_image'].to(device))

            # Ground truth instance maps, indexed per patch below.
            he_nuclei_inst_gt = batch['he_nuclei_instance'].cpu().numpy()
            he_cell_inst_gt = batch['he_cell_instance'].cpu().numpy()
            mif_nuclei_inst_gt = batch['mif_nuclei_instance'].cpu().numpy()
            mif_cell_inst_gt = batch['mif_cell_instance'].cpu().numpy()

            sample_names = batch['sample_name']

            # ========== INFERENCE WITH DUAL MODEL (both modalities in one call) ==========
            outputs = model(he_img, mif_img)

            batch_size = he_img.shape[0]

            for i in range(batch_size):
                sample_name = sample_names[i]

                # ========== INSTANCE SEGMENTATION (post-processing) ==========
                he_nuclei_inst_pred = _predict_instance_map(outputs, 'he_nuclei', i)
                he_cell_inst_pred = _predict_instance_map(outputs, 'he_cell', i)
                mif_nuclei_inst_pred = _predict_instance_map(outputs, 'mif_nuclei', i)
                mif_cell_inst_pred = _predict_instance_map(outputs, 'mif_cell', i)

                # ========== METRICS ==========
                # Everything for this patch is computed BEFORE anything is stored.
                # If any step raises, the patch contributes nothing, so the
                # per-sample lists stay the same length and patch_count stays in
                # step with them.
                patch_scores = {}
                patch_scores.update(_score_instances(
                    'he_nuclei', he_nuclei_inst_gt[i], he_nuclei_inst_pred))
                patch_scores.update(_score_instances(
                    'he_cell', he_cell_inst_gt[i], he_cell_inst_pred))
                # H&E nuclei prediction scored against the MIF nuclei ground truth.
                patch_scores.update(_score_instances(
                    'he_nuclei_vs_mif', mif_nuclei_inst_gt[i], he_nuclei_inst_pred))
                patch_scores.update(_score_instances(
                    'mif_nuclei', mif_nuclei_inst_gt[i], mif_nuclei_inst_pred))
                patch_scores.update(_score_instances(
                    'mif_cell', mif_cell_inst_gt[i], mif_cell_inst_pred))

                # Commit the fully scored patch in one go.
                record = sample_metrics[sample_name]
                for metric_key, value in patch_scores.items():
                    record[metric_key].append(value)
                record['patch_count'] += 1

        except RuntimeError as e:
            # Best-effort evaluation: the rest of the batch is abandoned and counted.
            # Patches committed before the failure are kept; the failing patch and
            # the ones after it in this batch are not recorded.
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
section_rule = "=" * 140
print("\n" + section_rule)
print("üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPDual)")
print(section_rule)

# Every list-valued metric that gets reduced to a per-sample mean.
averaged_keys = [
    f'{group}_{score}'
    for group in ('he_nuclei', 'he_nuclei_vs_mif', 'he_cell', 'mif_nuclei', 'mif_cell')
    for score in ('pq', 'dq', 'sq', 'aji')
]

# Samples are split by name: those starting with 'CRC' go to crc_samples,
# everything else is treated as Xenium.
crc_samples, xenium_samples = {}, {}

for sample_name, metrics in sample_metrics.items():
    # Mean of each per-patch list; patch_count is carried over unchanged.
    avg_metrics = {metric_key: np.mean(metrics[metric_key]) for metric_key in averaged_keys}
    avg_metrics['patch_count'] = metrics['patch_count']

    target_cohort = crc_samples if sample_name.startswith('CRC') else xenium_samples
    target_cohort[sample_name] = avg_metrics

# Per-sample tables for the dual model. Output order:
#   CRC:    H&E vs H&E GT, then H&E nuclei vs MIF nuclei GT
#   Xenium: H&E vs H&E GT, then H&E nuclei vs MIF nuclei GT
#   (MIF banner)
#   CRC MIF, Xenium MIF
# A cohort's tables are omitted when that cohort dict is empty.
table_rule = "-" * 140
score_suffixes = ('pq', 'dq', 'sq', 'aji')
nuclei_layout = (('Nuclei PQ', 11), ('Nuclei DQ', 11), ('Nuclei SQ', 11), ('Nuclei AJI', 12))
cell_layout = (('Cell PQ', 9), ('Cell DQ', 9), ('Cell SQ', 9), ('Cell AJI', 10))

# Each column is (header text, width, key into the per-sample averages dict).
he_vs_he_columns = (
    [(title, width, f'he_nuclei_{suffix}')
     for (title, width), suffix in zip(nuclei_layout, score_suffixes)]
    + [(title, width, f'he_cell_{suffix}')
       for (title, width), suffix in zip(cell_layout, score_suffixes)]
)
he_vs_mif_columns = [
    (title, width, f'he_nuclei_vs_mif_{suffix}')
    for (title, width), suffix in zip(nuclei_layout, score_suffixes)
]
mif_columns = (
    [(title, width, f'mif_nuclei_{suffix}')
     for (title, width), suffix in zip(nuclei_layout, score_suffixes)]
    + [(title, width, f'mif_cell_{suffix}')
       for (title, width), suffix in zip(cell_layout, score_suffixes)]
)

he_tables = (
    ("\nüî¨ CRC SAMPLES - H&E INSTANCES (vs H&E GT):", crc_samples, he_vs_he_columns),
    ("\nüî¨ CRC SAMPLES - H&E NUCLEI PREDICTIONS (vs MIF NUCLEI GT):", crc_samples, he_vs_mif_columns),
    ("\nüß¨ XENIUM SAMPLES - H&E INSTANCES (vs H&E GT):", xenium_samples, he_vs_he_columns),
    ("\nüß¨ XENIUM SAMPLES - H&E NUCLEI PREDICTIONS (vs MIF NUCLEI GT):", xenium_samples, he_vs_mif_columns),
)
mif_tables = (
    ("\nüî¨ CRC SAMPLES - MIF INSTANCES:", crc_samples, mif_columns),
    ("\nüß¨ XENIUM SAMPLES - MIF INSTANCES:", xenium_samples, mif_columns),
)

for table_group in (he_tables, mif_tables):
    if table_group is mif_tables:
        # ========== MIF INSTANCE RESULTS ==========
        print("\n" + "=" * 140)
        print("üìä PER-SAMPLE RESULTS - MIF INSTANCE METRICS (VitaminPDual)")
        print("=" * 140)

    for table_title, cohort, columns in table_group:
        if not cohort:
            continue

        print(table_title)
        print(table_rule)
        header_cells = [f"{'Sample':<12}", f"{'Patches':>8}"]
        header_cells.extend(format(title, f'>{width}') for title, width, _ in columns)
        print(" ".join(header_cells))
        print(table_rule)

        # Rows are ordered by sample name; cells are joined by a single space.
        for sample_name in sorted(cohort):
            m = cohort[sample_name]
            row_cells = [f"{sample_name:<12}", f"{m['patch_count']:>8}"]
            row_cells.extend(format(m[metric_key], f'>{width}.4f') for _, width, metric_key in columns)
            print(" ".join(row_cells))

# ========== OVERALL STATISTICS ==========
# Patch-level summary: every per-patch score from every sample is pooled into
# one list per metric, then reported as mean ¬± std (np.std, i.e. population std).
overall_rule = "=" * 100
print("\n" + overall_rule)
print("üìä OVERALL TEST SET RESULTS - INSTANCE METRICS (VitaminPDual)")
print(overall_rule)

# One pooled list per metric key, created in the order the names are unpacked below.
score_names = ('pq', 'dq', 'sq', 'aji')
pooled = {
    f'{group}_{score}': []
    for group in ('he_nuclei', 'he_nuclei_vs_mif', 'he_cell', 'mif_nuclei', 'mif_cell')
    for score in score_names
}
for sample_name, metrics in sample_metrics.items():
    for metric_key, bucket in pooled.items():
        bucket.extend(metrics[metric_key])

# Expose each pooled list under its own name (dicts preserve insertion order).
(all_he_nuclei_pq, all_he_nuclei_dq, all_he_nuclei_sq, all_he_nuclei_aji,
 all_he_nuclei_vs_mif_pq, all_he_nuclei_vs_mif_dq, all_he_nuclei_vs_mif_sq, all_he_nuclei_vs_mif_aji,
 all_he_cell_pq, all_he_cell_dq, all_he_cell_sq, all_he_cell_aji,
 all_mif_nuclei_pq, all_mif_nuclei_dq, all_mif_nuclei_sq, all_mif_nuclei_aji,
 all_mif_cell_pq, all_mif_cell_dq, all_mif_cell_sq, all_mif_cell_aji) = pooled.values()

total_patches = sum(record['patch_count'] for record in sample_metrics.values())

print("\nTotal samples:", len(sample_metrics))
print("Total patches:", total_patches)
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

# Line prefixes, aligned with `score_names`; the trailing spaces are part of the layout.
stat_labels = (
    "  PQ (Panoptic Quality):   ",
    "  DQ (Detection Quality):  ",
    "  SQ (Segmentation Quality): ",
    "  AJI (Agg. Jaccard):      ",
)
overall_report = (
    ("H&E INSTANCE RESULTS", (
        ("\nüî¨ H&E NUCLEI INSTANCE METRICS (vs H&E GT - all patches):", 'he_nuclei'),
        ("\nüî¨ H&E NUCLEI PREDICTIONS (vs MIF NUCLEI GT - all patches):", 'he_nuclei_vs_mif'),
        ("\nüß¨ H&E CELL INSTANCE METRICS (all patches):", 'he_cell'),
    )),
    ("MIF INSTANCE RESULTS", (
        ("\nüî¨ MIF NUCLEI INSTANCE METRICS (all patches):", 'mif_nuclei'),
        ("\nüß¨ MIF CELL INSTANCE METRICS (all patches):", 'mif_cell'),
    )),
)

for modality_heading, sections in overall_report:
    print("\n" + "=" * 50)
    print(modality_heading)
    print("=" * 50)
    for section_title, group in sections:
        print(section_title)
        for line_label, score in zip(stat_labels, score_names):
            values = pooled[f'{group}_{score}']
            print(f"{line_label}{np.mean(values):.4f} ¬± {np.std(values):.4f}")

# Statistics by cancer type
# Sample-level summary: each value fed to mean/std below is one sample's
# already-averaged score, so the spread is across samples, not across patches.
cohort_rule = "=" * 50

if crc_samples:
    print("\n" + cohort_rule)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print(cohort_rule)

    crc_rows = list(crc_samples.values())
    (crc_he_nuclei_pq, crc_he_nuclei_dq, crc_he_nuclei_sq, crc_he_nuclei_vs_mif_pq,
     crc_he_cell_pq, crc_mif_nuclei_pq, crc_mif_cell_pq) = (
        [row[metric_key] for row in crc_rows]
        for metric_key in ('he_nuclei_pq', 'he_nuclei_dq', 'he_nuclei_sq', 'he_nuclei_vs_mif_pq',
                           'he_cell_pq', 'mif_nuclei_pq', 'mif_cell_pq')
    )

    print("\nH&E:")
    for line_label, per_sample in (
        ("  Nuclei PQ (vs H&E GT):     ", crc_he_nuclei_pq),
        ("  Nuclei DQ (vs H&E GT):     ", crc_he_nuclei_dq),
        ("  Nuclei SQ (vs H&E GT):     ", crc_he_nuclei_sq),
        ("  Nuclei PQ (vs MIF GT):     ", crc_he_nuclei_vs_mif_pq),
        ("  Cell PQ:                   ", crc_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

    print("\nMIF:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", crc_mif_nuclei_pq),
        ("  Cell PQ:   ", crc_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

if xenium_samples:
    print("\n" + cohort_rule)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print(cohort_rule)

    xenium_rows = list(xenium_samples.values())
    (xenium_he_nuclei_pq, xenium_he_nuclei_vs_mif_pq, xenium_he_cell_pq,
     xenium_mif_nuclei_pq, xenium_mif_cell_pq) = (
        [row[metric_key] for row in xenium_rows]
        for metric_key in ('he_nuclei_pq', 'he_nuclei_vs_mif_pq', 'he_cell_pq',
                           'mif_nuclei_pq', 'mif_cell_pq')
    )

    print("\nH&E:")
    for line_label, per_sample in (
        ("  Nuclei PQ (vs H&E GT): ", xenium_he_nuclei_pq),
        ("  Nuclei PQ (vs MIF GT): ", xenium_he_nuclei_vs_mif_pq),
        ("  Cell PQ:               ", xenium_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

    print("\nMIF:")
    for line_label, per_sample in (
        ("  Nuclei PQ: ", xenium_mif_nuclei_pq),
        ("  Cell PQ:   ", xenium_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(per_sample):.4f} ¬± {np.std(per_sample):.4f}")

print("\n" + "=" * 100)
print("‚úÖ Instance segmentation evaluation complete for VitaminPDual (both H&E and MIF)!")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda

üì¶ Loading VitaminPDual model...
Building H&E encoder with DINOv2-base
Building MIF encoder with DINOv2-base
Building shared encoder with DINOv2-base
‚úì VitaminPDual initialized with base backbone
  Embed dim: 768 | Decoder dims: [768, 384, 192, 96]
‚úÖ VitaminPDual model loaded

üîÑ Evaluating on all test samples (H&E + MIF - DUAL MODEL - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [19:05<00:00,  9.02s/it]


üìä PER-SAMPLE RESULTS - H&E INSTANCE METRICS (VitaminPDual)

üî¨ CRC SAMPLES - H&E INSTANCES (vs H&E GT):
--------------------------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
--------------------------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.1512      0.2516      0.5873       0.3446    0.8212    0.9402    0.8733     0.7414
CRC16             142      0.2541      0.4213      0.5941       0.4272    0.8686    0.9544    0.9100     0.7897
CRC17              76      0.1004      0.1732      0.5702       0.2871    0.8135    0.9291    0.8754     0.7302
CRC18              80      0.1758      0.2949      0.5941       0.3295    0.7405    0.8745    0.8460     0.5939
CRC19              29      0.1490      0.2468   




## H&E Baseline Instance Metrics

In [6]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPBaselineHE, SimplePreprocessing, prepare_he_input
from postprocessing import process_model_outputs
from tqdm import tqdm

# Device selection: use CUDA when it is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Build the H&E baseline network and restore the fold-3 checkpoint weights.
# The checkpoint is loaded straight into load_state_dict (no intermediate
# variable) so no extra reference to the loaded tensors is kept alive.
print("\nüì¶ Loading VitaminPBaselineHE model...")
checkpoint_path = "checkpoints/vitamin_p_baselinehe_large_fold3_best.pth"
model = VitaminPBaselineHE(
    model_size='large',
    dropout_rate=0.3,
    freeze_backbone=False,
)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()  # switch the module to evaluation mode before inference
print("‚úÖ VitaminPBaselineHE model loaded")

# Helper object whose percentile_normalize method is applied to inputs below.
preprocessor = SimplePreprocessing()

# Per-sample storage of patch-level H&E instance metrics.
# Each sample name maps to one list per metric plus a running patch counter.
sample_metrics = defaultdict(lambda: {
    'he_nuclei_pq': [], 'he_nuclei_dq': [], 'he_nuclei_sq': [], 'he_nuclei_aji': [], 
    'he_cell_pq': [], 'he_cell_dq': [], 'he_cell_sq': [], 'he_cell_aji': [],
    'patch_count': 0
})

print(f"\nüîÑ Evaluating on all test samples (H&E BASELINE - INSTANCE METRICS ONLY)...")

# Number of batches dropped because a RuntimeError was raised while processing them.
skipped_batches = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        # Patch results are staged here and merged into sample_metrics only after
        # the whole batch has been processed. If a RuntimeError interrupts the
        # batch, nothing from it is recorded, so a "skipped" batch contributes no
        # partial data and every metric list stays the same length as patch_count.
        batch_results = []

        try:
            # ========== H&E PROCESSING ONLY ==========
            he_img = prepare_he_input(batch['he_image'].to(device))
            he_img = preprocessor.percentile_normalize(he_img)

            # Ground-truth instance maps, moved to NumPy for the metric functions.
            he_nuclei_inst_gt = batch['he_nuclei_instance'].cpu().numpy()
            he_cell_inst_gt = batch['he_cell_instance'].cpu().numpy()

            # Sample name of every patch in the batch (indexed per patch below).
            sample_names = batch['sample_name']

            # ========== INFERENCE - H&E ONLY ==========
            outputs = model(he_img)

            batch_size = he_img.shape[0]

            for i in range(batch_size):
                # ========== H&E INSTANCE SEGMENTATION ==========
                # Channel 0 of the seg output and channels 0/1 of the hv output
                # are passed to post-processing as segmentation / H map / V map.
                he_nuclei_seg_np = outputs['he_nuclei_seg'][i, 0].cpu().numpy()
                he_nuclei_h_map = outputs['he_nuclei_hv'][i, 0].cpu().numpy()
                he_nuclei_v_map = outputs['he_nuclei_hv'][i, 1].cpu().numpy()

                he_cell_seg_np = outputs['he_cell_seg'][i, 0].cpu().numpy()
                he_cell_h_map = outputs['he_cell_hv'][i, 0].cpu().numpy()
                he_cell_v_map = outputs['he_cell_hv'][i, 1].cpu().numpy()

                # Post-processing returns three values; only the first (the
                # predicted instance map) is used here.
                he_nuclei_inst_pred, _, _ = process_model_outputs(
                    he_nuclei_seg_np, he_nuclei_h_map, he_nuclei_v_map,
                    magnification=40
                )

                he_cell_inst_pred, _, _ = process_model_outputs(
                    he_cell_seg_np, he_cell_h_map, he_cell_v_map,
                    magnification=40
                )

                # get_fast_pq is unpacked as (pq, dq, sq).
                he_nuclei_pq, he_nuclei_dq, he_nuclei_sq = get_fast_pq(he_nuclei_inst_gt[i], he_nuclei_inst_pred)
                he_cell_pq, he_cell_dq, he_cell_sq = get_fast_pq(he_cell_inst_gt[i], he_cell_inst_pred)
                he_nuclei_aji = aggregated_jaccard_index(he_nuclei_inst_gt[i], he_nuclei_inst_pred)
                he_cell_aji = aggregated_jaccard_index(he_cell_inst_gt[i], he_cell_inst_pred)

                # Stage the full metric set for this patch; nothing is written to
                # sample_metrics until every patch in the batch has succeeded.
                batch_results.append((sample_names[i], {
                    'he_nuclei_pq': he_nuclei_pq,
                    'he_nuclei_dq': he_nuclei_dq,
                    'he_nuclei_sq': he_nuclei_sq,
                    'he_nuclei_aji': he_nuclei_aji,
                    'he_cell_pq': he_cell_pq,
                    'he_cell_dq': he_cell_dq,
                    'he_cell_sq': he_cell_sq,
                    'he_cell_aji': he_cell_aji,
                }))

        except RuntimeError as e:
            # Best-effort evaluation: report the (truncated) error and move on.
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

        # Commit the staged results of a fully processed batch.
        for sample_name, patch_metrics in batch_results:
            sample_entry = sample_metrics[sample_name]
            for metric_name, metric_value in patch_metrics.items():
                sample_entry[metric_name].append(metric_value)
            sample_entry['patch_count'] += 1

# ---------- Per-sample means and result tables (H&E baseline) ----------
section_rule = "=" * 120
print("\n" + section_rule)
print("üìä PER-SAMPLE RESULTS - H&E BASELINE INSTANCE METRICS")
print(section_rule)

# Metric key paired with the column width used when printing it.
column_specs = (
    ('he_nuclei_pq', 11), ('he_nuclei_dq', 11), ('he_nuclei_sq', 11), ('he_nuclei_aji', 12),
    ('he_cell_pq', 9), ('he_cell_dq', 9), ('he_cell_sq', 9), ('he_cell_aji', 10),
)

# Samples whose name starts with 'CRC' go to crc_samples; all others are
# treated as Xenium samples.
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Mean of each patch-level metric list, plus the patch count carried over.
    avg_metrics = {key: np.mean(metrics[key]) for key, _ in column_specs}
    avg_metrics['patch_count'] = metrics['patch_count']

    destination = crc_samples if sample_name.startswith('CRC') else xenium_samples
    destination[sample_name] = avg_metrics

# Both tables share the same layout, so they are printed by one loop.
table_rule = "-" * 120
table_header = f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}"

for table_title, table_rows in (
    ("\nüî¨ CRC SAMPLES:", crc_samples),
    ("\nüß¨ XENIUM SAMPLES:", xenium_samples),
):
    if not table_rows:
        continue  # an empty group prints nothing at all

    print(table_title)
    print(table_rule)
    print(table_header)
    print(table_rule)

    for sample_name in sorted(table_rows.keys()):
        m = table_rows[sample_name]
        row_fields = [f"{sample_name:<12}", f"{m['patch_count']:>8}"]
        row_fields.extend(f"{m[key]:>{width}.4f}" for key, width in column_specs)
        print(" ".join(row_fields))

# ---------- Pooled statistics over the whole test set (H&E baseline) ----------
wide_rule = "=" * 100
print("\n" + wide_rule)
print("üìä OVERALL TEST SET RESULTS - H&E BASELINE INSTANCE METRICS")
print(wide_rule)

# Concatenate every sample's patch-level values into one list per metric.
pooled_metrics = {
    'he_nuclei_pq': [], 'he_nuclei_dq': [], 'he_nuclei_sq': [], 'he_nuclei_aji': [],
    'he_cell_pq': [], 'he_cell_dq': [], 'he_cell_sq': [], 'he_cell_aji': [],
}
for sample_name, metrics in sample_metrics.items():
    for pooled_key, pooled_values in pooled_metrics.items():
        pooled_values.extend(metrics[pooled_key])

all_he_nuclei_pq = pooled_metrics['he_nuclei_pq']
all_he_nuclei_dq = pooled_metrics['he_nuclei_dq']
all_he_nuclei_sq = pooled_metrics['he_nuclei_sq']
all_he_nuclei_aji = pooled_metrics['he_nuclei_aji']
all_he_cell_pq = pooled_metrics['he_cell_pq']
all_he_cell_dq = pooled_metrics['he_cell_dq']
all_he_cell_sq = pooled_metrics['he_cell_sq']
all_he_cell_aji = pooled_metrics['he_cell_aji']

total_patches = sum(entry['patch_count'] for entry in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

# Mean and standard deviation across all patches, one line per metric.
metric_labels = (
    "  PQ (Panoptic Quality):   ",
    "  DQ (Detection Quality):  ",
    "  SQ (Segmentation Quality): ",
    "  AJI (Agg. Jaccard):      ",
)
for group_title, group_series in (
    ("\nüî¨ H&E NUCLEI INSTANCE METRICS (all patches):",
     (all_he_nuclei_pq, all_he_nuclei_dq, all_he_nuclei_sq, all_he_nuclei_aji)),
    ("\nüß¨ H&E CELL INSTANCE METRICS (all patches):",
     (all_he_cell_pq, all_he_cell_dq, all_he_cell_sq, all_he_cell_aji)),
):
    print(group_title)
    for metric_label, metric_scores in zip(metric_labels, group_series):
        print(f"{metric_label}{np.mean(metric_scores):.4f} ¬± {np.std(metric_scores):.4f}")

# Group summaries: mean and standard deviation of the per-sample averages
# (not of the individual patches).
narrow_rule = "=" * 50

if crc_samples:
    print("\n" + narrow_rule)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print(narrow_rule)

    crc_rows = list(crc_samples.values())
    crc_he_nuclei_pq = [row['he_nuclei_pq'] for row in crc_rows]
    crc_he_nuclei_dq = [row['he_nuclei_dq'] for row in crc_rows]
    crc_he_nuclei_sq = [row['he_nuclei_sq'] for row in crc_rows]
    crc_he_cell_pq = [row['he_cell_pq'] for row in crc_rows]

    for line_label, line_scores in (
        ("\n  Nuclei PQ: ", crc_he_nuclei_pq),
        ("  Nuclei DQ: ", crc_he_nuclei_dq),
        ("  Nuclei SQ: ", crc_he_nuclei_sq),
        ("  Cell PQ:   ", crc_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(line_scores):.4f} ¬± {np.std(line_scores):.4f}")

if xenium_samples:
    print("\n" + narrow_rule)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print(narrow_rule)

    xenium_rows = list(xenium_samples.values())
    xenium_he_nuclei_pq = [row['he_nuclei_pq'] for row in xenium_rows]
    xenium_he_cell_pq = [row['he_cell_pq'] for row in xenium_rows]

    for line_label, line_scores in (
        ("\n  Nuclei PQ: ", xenium_he_nuclei_pq),
        ("  Cell PQ:   ", xenium_he_cell_pq),
    ):
        print(f"{line_label}{np.mean(line_scores):.4f} ¬± {np.std(line_scores):.4f}")

print("\n" + wide_rule)
print("‚úÖ H&E Baseline instance segmentation evaluation complete!")

Using device: cuda

üì¶ Loading VitaminPBaselineHE model...
Building H&E Baseline encoder with DINOv2-large
‚úì VitaminPBaselineHE initialized with large backbone
  Embed dim: 1024 | Decoder dims: [1024, 512, 256, 128]
‚úÖ VitaminPBaselineHE model loaded

üîÑ Evaluating on all test samples (H&E BASELINE - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [07:22<00:00,  3.49s/it]


üìä PER-SAMPLE RESULTS - H&E BASELINE INSTANCE METRICS

üî¨ CRC SAMPLES:
------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.6945      0.8526      0.8143       0.6865    0.5509    0.7604    0.7190     0.5113
CRC16             142      0.7966      0.9149      0.8703       0.8210    0.6583    0.8621    0.7623     0.5728
CRC17              76      0.6012      0.7661      0.7837       0.5926    0.5504    0.7584    0.7229     0.5173
CRC18              80      0.6671      0.8290      0.8037       0.6619    0.5513    0.7495    0.7332     0.4756
CRC19              29      0.6436      0.8117      0.7916       0.6461    0.5011    0.6971    0.7168     0.5114
CRC20     




## mIF Baseline Instance Metrics

In [7]:
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    get_fast_pq,
    aggregated_jaccard_index
)
from vitaminp import VitaminPBaselineMIF, SimplePreprocessing
from postprocessing import process_model_outputs
from tqdm import tqdm

# Device selection: use CUDA when it is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Build the MIF baseline network and restore the fold-3 checkpoint weights.
# The checkpoint is loaded straight into load_state_dict (no intermediate
# variable) so no extra reference to the loaded tensors is kept alive.
print("\nüì¶ Loading VitaminPBaselineMIF model...")
checkpoint_path = "checkpoints/vitamin_p_baselinemif_large_fold3_best.pth"
model = VitaminPBaselineMIF(
    model_size='large',
    dropout_rate=0.3,
    freeze_backbone=False,
)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()  # switch the module to evaluation mode before inference
print("‚úÖ VitaminPBaselineMIF model loaded")

# Helper object whose percentile_normalize method is applied to inputs below.
preprocessor = SimplePreprocessing()

# Per-sample storage of patch-level MIF instance metrics.
# Each sample name maps to one list per metric plus a running patch counter.
sample_metrics = defaultdict(lambda: {
    'mif_nuclei_pq': [], 'mif_nuclei_dq': [], 'mif_nuclei_sq': [], 'mif_nuclei_aji': [],
    'mif_cell_pq': [], 'mif_cell_dq': [], 'mif_cell_sq': [], 'mif_cell_aji': [],
    'patch_count': 0
})

print(f"\nüîÑ Evaluating on all test samples (MIF BASELINE - INSTANCE METRICS ONLY)...")

# Number of batches dropped because a RuntimeError was raised while processing them.
skipped_batches = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        # Patch results are staged here and merged into sample_metrics only after
        # the whole batch has been processed. If a RuntimeError interrupts the
        # batch, nothing from it is recorded, so a "skipped" batch contributes no
        # partial data and every metric list stays the same length as patch_count.
        batch_results = []

        try:
            # ========== MIF PROCESSING ONLY ==========
            # The MIF tensor is fed to the model as loaded (no prepare_mif_input
            # step); per the original note it is expected to stay 2-channel.
            mif_img = batch['mif_image'].to(device)
            mif_img = preprocessor.percentile_normalize(mif_img)

            # Ground-truth instance maps, moved to NumPy for the metric functions.
            mif_nuclei_inst_gt = batch['mif_nuclei_instance'].cpu().numpy()
            mif_cell_inst_gt = batch['mif_cell_instance'].cpu().numpy()

            # Sample name of every patch in the batch (indexed per patch below).
            sample_names = batch['sample_name']

            # ========== INFERENCE - MIF ONLY ==========
            outputs = model(mif_img)

            batch_size = mif_img.shape[0]

            for i in range(batch_size):
                # ========== MIF INSTANCE SEGMENTATION ==========
                # Channel 0 of the seg output and channels 0/1 of the hv output
                # are passed to post-processing as segmentation / H map / V map.
                mif_nuclei_seg_np = outputs['mif_nuclei_seg'][i, 0].cpu().numpy()
                mif_nuclei_h_map = outputs['mif_nuclei_hv'][i, 0].cpu().numpy()
                mif_nuclei_v_map = outputs['mif_nuclei_hv'][i, 1].cpu().numpy()

                mif_cell_seg_np = outputs['mif_cell_seg'][i, 0].cpu().numpy()
                mif_cell_h_map = outputs['mif_cell_hv'][i, 0].cpu().numpy()
                mif_cell_v_map = outputs['mif_cell_hv'][i, 1].cpu().numpy()

                # Post-processing returns three values; only the first (the
                # predicted instance map) is used here.
                mif_nuclei_inst_pred, _, _ = process_model_outputs(
                    mif_nuclei_seg_np, mif_nuclei_h_map, mif_nuclei_v_map,
                    magnification=40
                )

                mif_cell_inst_pred, _, _ = process_model_outputs(
                    mif_cell_seg_np, mif_cell_h_map, mif_cell_v_map,
                    magnification=40
                )

                # get_fast_pq is unpacked as (pq, dq, sq).
                mif_nuclei_pq, mif_nuclei_dq, mif_nuclei_sq = get_fast_pq(mif_nuclei_inst_gt[i], mif_nuclei_inst_pred)
                mif_cell_pq, mif_cell_dq, mif_cell_sq = get_fast_pq(mif_cell_inst_gt[i], mif_cell_inst_pred)
                mif_nuclei_aji = aggregated_jaccard_index(mif_nuclei_inst_gt[i], mif_nuclei_inst_pred)
                mif_cell_aji = aggregated_jaccard_index(mif_cell_inst_gt[i], mif_cell_inst_pred)

                # Stage the full metric set for this patch; nothing is written to
                # sample_metrics until every patch in the batch has succeeded.
                batch_results.append((sample_names[i], {
                    'mif_nuclei_pq': mif_nuclei_pq,
                    'mif_nuclei_dq': mif_nuclei_dq,
                    'mif_nuclei_sq': mif_nuclei_sq,
                    'mif_nuclei_aji': mif_nuclei_aji,
                    'mif_cell_pq': mif_cell_pq,
                    'mif_cell_dq': mif_cell_dq,
                    'mif_cell_sq': mif_cell_sq,
                    'mif_cell_aji': mif_cell_aji,
                }))

        except RuntimeError as e:
            # Best-effort evaluation: report the (truncated) error and move on.
            print(f"\n‚ö†Ô∏è Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

        # Commit the staged results of a fully processed batch.
        for sample_name, patch_metrics in batch_results:
            sample_entry = sample_metrics[sample_name]
            for metric_name, metric_value in patch_metrics.items():
                sample_entry[metric_name].append(metric_value)
            sample_entry['patch_count'] += 1

# ---------- Per-sample means and result tables (MIF baseline) ----------
section_rule = "=" * 120
print("\n" + section_rule)
print("üìä PER-SAMPLE RESULTS - MIF BASELINE INSTANCE METRICS")
print(section_rule)

# Metric key paired with the column width used when printing it.
column_specs = (
    ('mif_nuclei_pq', 11), ('mif_nuclei_dq', 11), ('mif_nuclei_sq', 11), ('mif_nuclei_aji', 12),
    ('mif_cell_pq', 9), ('mif_cell_dq', 9), ('mif_cell_sq', 9), ('mif_cell_aji', 10),
)

# Samples whose name starts with 'CRC' go to crc_samples; all others are
# treated as Xenium samples.
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    # Mean of each patch-level metric list, plus the patch count carried over.
    avg_metrics = {key: np.mean(metrics[key]) for key, _ in column_specs}
    avg_metrics['patch_count'] = metrics['patch_count']

    destination = crc_samples if sample_name.startswith('CRC') else xenium_samples
    destination[sample_name] = avg_metrics

# Both tables share the same layout, so they are printed by one loop.
table_rule = "-" * 120
table_header = f"{'Sample':<12} {'Patches':>8} {'Nuclei PQ':>11} {'Nuclei DQ':>11} {'Nuclei SQ':>11} {'Nuclei AJI':>12} {'Cell PQ':>9} {'Cell DQ':>9} {'Cell SQ':>9} {'Cell AJI':>10}"

for table_title, table_rows in (
    ("\nüî¨ CRC SAMPLES:", crc_samples),
    ("\nüß¨ XENIUM SAMPLES:", xenium_samples),
):
    if not table_rows:
        continue  # an empty group prints nothing at all

    print(table_title)
    print(table_rule)
    print(table_header)
    print(table_rule)

    for sample_name in sorted(table_rows.keys()):
        m = table_rows[sample_name]
        row_fields = [f"{sample_name:<12}", f"{m['patch_count']:>8}"]
        row_fields.extend(f"{m[key]:>{width}.4f}" for key, width in column_specs)
        print(" ".join(row_fields))

# ---------- Pooled statistics over the whole test set (MIF baseline) ----------
wide_rule = "=" * 100
print("\n" + wide_rule)
print("üìä OVERALL TEST SET RESULTS - MIF BASELINE INSTANCE METRICS")
print(wide_rule)

# Concatenate every sample's patch-level values into one list per metric.
pooled_metrics = {
    'mif_nuclei_pq': [], 'mif_nuclei_dq': [], 'mif_nuclei_sq': [], 'mif_nuclei_aji': [],
    'mif_cell_pq': [], 'mif_cell_dq': [], 'mif_cell_sq': [], 'mif_cell_aji': [],
}
for sample_name, metrics in sample_metrics.items():
    for pooled_key, pooled_values in pooled_metrics.items():
        pooled_values.extend(metrics[pooled_key])

all_mif_nuclei_pq = pooled_metrics['mif_nuclei_pq']
all_mif_nuclei_dq = pooled_metrics['mif_nuclei_dq']
all_mif_nuclei_sq = pooled_metrics['mif_nuclei_sq']
all_mif_nuclei_aji = pooled_metrics['mif_nuclei_aji']
all_mif_cell_pq = pooled_metrics['mif_cell_pq']
all_mif_cell_dq = pooled_metrics['mif_cell_dq']
all_mif_cell_sq = pooled_metrics['mif_cell_sq']
all_mif_cell_aji = pooled_metrics['mif_cell_aji']

total_patches = sum(entry['patch_count'] for entry in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"‚ö†Ô∏è Skipped batches: {skipped_batches}")

# Mean and standard deviation across all patches, one line per metric.
metric_labels = (
    "  PQ (Panoptic Quality):   ",
    "  DQ (Detection Quality):  ",
    "  SQ (Segmentation Quality): ",
    "  AJI (Agg. Jaccard):      ",
)
for group_title, group_series in (
    ("\nüî¨ MIF NUCLEI INSTANCE METRICS (all patches):",
     (all_mif_nuclei_pq, all_mif_nuclei_dq, all_mif_nuclei_sq, all_mif_nuclei_aji)),
    ("\nüß¨ MIF CELL INSTANCE METRICS (all patches):",
     (all_mif_cell_pq, all_mif_cell_dq, all_mif_cell_sq, all_mif_cell_aji)),
):
    print(group_title)
    for metric_label, metric_scores in zip(metric_labels, group_series):
        print(f"{metric_label}{np.mean(metric_scores):.4f} ¬± {np.std(metric_scores):.4f}")

# Group summaries: mean and standard deviation of the per-sample averages
# (not of the individual patches).
narrow_rule = "=" * 50

if crc_samples:
    print("\n" + narrow_rule)
    print(f"CRC SAMPLES ({len(crc_samples)} samples)")
    print(narrow_rule)

    crc_rows = list(crc_samples.values())
    crc_mif_nuclei_pq = [row['mif_nuclei_pq'] for row in crc_rows]
    crc_mif_nuclei_dq = [row['mif_nuclei_dq'] for row in crc_rows]
    crc_mif_nuclei_sq = [row['mif_nuclei_sq'] for row in crc_rows]
    crc_mif_cell_pq = [row['mif_cell_pq'] for row in crc_rows]

    for line_label, line_scores in (
        ("\n  Nuclei PQ: ", crc_mif_nuclei_pq),
        ("  Nuclei DQ: ", crc_mif_nuclei_dq),
        ("  Nuclei SQ: ", crc_mif_nuclei_sq),
        ("  Cell PQ:   ", crc_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(line_scores):.4f} ¬± {np.std(line_scores):.4f}")

if xenium_samples:
    print("\n" + narrow_rule)
    print(f"XENIUM SAMPLES ({len(xenium_samples)} samples)")
    print(narrow_rule)

    xenium_rows = list(xenium_samples.values())
    xenium_mif_nuclei_pq = [row['mif_nuclei_pq'] for row in xenium_rows]
    xenium_mif_cell_pq = [row['mif_cell_pq'] for row in xenium_rows]

    for line_label, line_scores in (
        ("\n  Nuclei PQ: ", xenium_mif_nuclei_pq),
        ("  Cell PQ:   ", xenium_mif_cell_pq),
    ):
        print(f"{line_label}{np.mean(line_scores):.4f} ¬± {np.std(line_scores):.4f}")

print("\n" + wide_rule)
print("‚úÖ MIF Baseline instance segmentation evaluation complete!")

Using device: cuda

üì¶ Loading VitaminPBaselineMIF model...
Building MIF Baseline encoder with DINOv2-large
‚úì VitaminPBaselineMIF initialized with large backbone
  Embed dim: 1024 | Decoder dims: [1024, 512, 256, 128]
‚úÖ VitaminPBaselineMIF model loaded

üîÑ Evaluating on all test samples (MIF BASELINE - INSTANCE METRICS ONLY)...


Processing batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 127/127 [08:26<00:00,  3.98s/it]


üìä PER-SAMPLE RESULTS - MIF BASELINE INSTANCE METRICS

üî¨ CRC SAMPLES:
------------------------------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei PQ   Nuclei DQ   Nuclei SQ   Nuclei AJI   Cell PQ   Cell DQ   Cell SQ   Cell AJI
------------------------------------------------------------------------------------------------------------------------
CRC15              18      0.7774      0.9153      0.8488       0.6593    0.7641    0.9127    0.8370     0.6634
CRC16             142      0.7729      0.9231      0.8371       0.7217    0.8242    0.9328    0.8834     0.7096
CRC17              76      0.7707      0.8976      0.8582       0.7072    0.7779    0.9101    0.8547     0.6939
CRC18              80      0.7607      0.8900      0.8540       0.6667    0.7581    0.8867    0.8545     0.6138
CRC19              29      0.7534      0.8996      0.8372       0.7196    0.7990    0.9235    0.8649     0.7470
CRC20     


