In [1]:
!pip install 'zarr<3'
!pip install timm


Defaulting to user installation because normal site-packages is not writeable
Collecting zarr<3
  Using cached zarr-2.18.7-py3-none-any.whl.metadata (5.8 kB)
Collecting asciitree (from zarr<3)
  Using cached asciitree-0.3.3-py3-none-any.whl
Collecting fasteners (from zarr<3)
  Using cached fasteners-0.20-py3-none-any.whl.metadata (4.8 kB)
Collecting numcodecs!=0.14.0,!=0.14.1,<0.16,>=0.10.0 (from zarr<3)
  Using cached numcodecs-0.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting deprecated (from numcodecs!=0.14.0,!=0.14.1,<0.16,>=0.10.0->zarr<3)
  Using cached deprecated-1.3.1-py2.py3-none-any.whl.metadata (5.9 kB)
Using cached zarr-2.18.7-py3-none-any.whl (211 kB)
Using cached numcodecs-0.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.9 MB)
Using cached fasteners-0.20-py3-none-any.whl (18 kB)
Using cached deprecated-1.3.1-py2.py3-none-any.whl (11 kB)
Installing collected packages: asciitree, fasteners, deprecated, numcode

In [6]:
# ALWAYS RUN THIS FIRST!
# Makes the project root the working directory and puts it first on sys.path,
# so that project modules and relative config/cache/checkpoint paths resolve.
import os
import sys
from pathlib import Path

# The project root can be overridden per machine through the
# VITAMINP_NOTEBOOK_DIR environment variable; the default is the original
# cluster location, so behavior is unchanged when the variable is unset.
NOTEBOOK_DIR = Path(
    os.environ.get(
        "VITAMINP_NOTEBOOK_DIR",
        "/rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest",
    )
)

# Raises FileNotFoundError if the directory does not exist.
os.chdir(NOTEBOOK_DIR)

# Keep the project root at the front of sys.path without stacking a duplicate
# entry every time this cell is re-run.
_project_root = str(NOTEBOOK_DIR)
sys.path[:] = [entry for entry in sys.path if entry != _project_root]
sys.path.insert(0, _project_root)

print(f"✅ Working directory: {os.getcwd()}")

✅ Working directory: /rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest


## Data Loading

In [7]:
# Data loading: read the fold configuration and build the train/val/test loaders.
# Requires the setup cell to have run first, so that `dataset` is importable and
# the relative config path resolves against the project root.
from dataset import Config, create_dataloaders

# Path is relative to the working directory set by the setup cell.
config = Config("configs/config_fold13.yaml")  # Note: "configs" not "config"
config.print_config()

# `test_loader` is consumed by the evaluation cell further down.
train_loader, val_loader, test_loader = create_dataloaders(config)
print("\n✅ Ready to use!")

CRC DATASET CONFIGURATION
Config File: configs/config_fold13.yaml
Zarr Base: /rsrch9/home/plm/idso_fa1_pathology/TIER2/yasin-vitaminp/ORION-CRC-Syn/zarr_data
Cache: ./cache/multimodal_syn_dataset_cache_fold3.pkl
Strategy: memory

📊 Data Splits:
  Train: 33 samples
  Val: 9 samples
  Test: 8 samples

🔄 DataLoader:
  Batch Size: 4
  Num Workers: 0
  Pin Memory: True

🎨 Augmentation:
  Training: True
  Probability: 0.0

🎯 HV Maps:
  Generate: True
  Method: pannuke
  HE Nuclei: True
  HE Cells: True
  MIF Nuclei: True
  MIF Cells: True

🔍 Filtering:
  Min Instances: 0
  Filter Empty: True

CREATING DATALOADERS
Strategy: memory
Use Cache: True
Batch Size: 4
Num Workers: 0

Train split: 27 CRC + 6 Xenium samples
Val split: 7 CRC + 2 Xenium samples
Test split: 7 CRC + 1 Xenium samples

📦 Loading from cache: ./cache/combined_cache_train_0be9581c.pkl
📦 Loaded 3294 patches from cache
📦 Loading from cache: ./cache/combined_cache_val_0ed6113d.pkl
📦 Loaded 1444 patches f

## Syn model

In [8]:
# Evaluate the VitaminPSyn checkpoint on the test loader.
#
# For every test patch this computes binary Dice and IoU for four outputs
# (H&E nuclei, H&E cell, MIF nuclei, MIF cell), groups the scores by the
# patch's `sample_name`, and prints per-sample tables, overall patch-level
# statistics and per-cohort (CRC / Xenium) summaries.
#
# Requires `test_loader` from the data-loading cell.
import torch
import numpy as np
from collections import defaultdict
from metrics import (
    dice_coefficient,
    iou_score,
    precision_score,
    recall_score,
    f1_score
)
from vitaminp import VitaminPSyn, SimplePreprocessing
from tqdm import tqdm

# ========== CONFIGURATION ==========
MODEL_NAME = "VitaminPSyn"   # label used in every report heading below
SEG_THRESHOLD = 0.5          # cut-off applied to the segmentation outputs
MODALITIES = ("he", "mif")
TARGETS = ("nuclei", "cell")
MODALITY_LABELS = {"he": "H&E", "mif": "MIF"}
METRIC_FNS = {"dice": dice_coefficient, "iou": iou_score}
# Key order: he_nuclei_dice, he_nuclei_iou, he_cell_dice, ..., mif_cell_iou
METRIC_KEYS = [
    f"{modality}_{target}_{metric}"
    for modality in MODALITIES
    for target in TARGETS
    for metric in METRIC_FNS
]
TABLE_RULE = "-" * 100
TABLE_HEADER = (f"{'Sample':<12} {'Patches':>8} {'Nuclei Dice':>13} "
                f"{'Nuclei IoU':>12} {'Cell Dice':>12} {'Cell IoU':>11}")


# ========== HELPERS ==========
def _new_sample_record():
    """Return an empty per-sample store: one score list per metric key plus a patch counter."""
    record = {key: [] for key in METRIC_KEYS}
    record['patch_count'] = 0
    return record


def format_mean_std(values):
    """Format `values` as 'mean ± std' (4 decimals); 'n/a' when there is nothing to average."""
    if len(values) == 0:
        return "n/a"
    return f"{np.mean(values):.4f} ± {np.std(values):.4f}"


def print_sample_table(title, samples, modality):
    """Print one row of averaged metrics per sample for a modality ('he' or 'mif').

    `samples` maps sample name -> averaged-metrics dict. Prints nothing when empty.
    """
    if not samples:
        return
    print(f"\n{title}:")
    print(TABLE_RULE)
    print(TABLE_HEADER)
    print(TABLE_RULE)
    for sample_name in sorted(samples):
        row = samples[sample_name]
        nuclei_dice, nuclei_iou, cell_dice, cell_iou = (
            row[f"{modality}_{target}_{metric}"]
            for target in TARGETS
            for metric in METRIC_FNS
        )
        print(f"{sample_name:<12} {row['patch_count']:>8} "
              f"{nuclei_dice:>13.4f} {nuclei_iou:>12.4f} "
              f"{cell_dice:>12.4f} {cell_iou:>11.4f}")


def print_group_summary(group_label, samples):
    """Print mean ± std of the per-sample Dice averages for one cohort. Prints nothing when empty."""
    if not samples:
        return
    print("\n" + "="*50)
    print(f"{group_label} SAMPLES ({len(samples)} samples)")
    print("="*50)
    for modality in MODALITIES:
        nuclei_dice = [m[f"{modality}_nuclei_dice"] for m in samples.values()]
        cell_dice = [m[f"{modality}_cell_dice"] for m in samples.values()]
        print(f"\n{MODALITY_LABELS[modality]}:")
        print(f"  Nuclei Dice: {format_mean_std(nuclei_dice)}")
        print(f"  Cell Dice:   {format_mean_std(cell_dice)}")


# ========== SETUP ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

print(f"\n📦 Loading {MODEL_NAME} model...")
model = VitaminPSyn(model_size='base', dropout_rate=0.3, freeze_backbone=False)
checkpoint_path = "checkpoints/vitamin_p_syn_base_fold13_best.pth"
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
model = model.to(device)
print(f"✅ {MODEL_NAME} model loaded")

preprocessor = SimplePreprocessing()

# Metric storage PER SAMPLE for BOTH H&E and MIF
sample_metrics = defaultdict(_new_sample_record)

print(f"\n🔄 Evaluating on all test samples (H&E + MIF - {MODEL_NAME} - BINARY METRICS ONLY)...")

skipped_batches = 0

# ========== EVALUATION LOOP ==========
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader, desc="Processing batches")):

        try:
            # Inputs, normalized with the preprocessor
            he_img = preprocessor.percentile_normalize(batch['he_image'].to(device))
            mif_img = preprocessor.percentile_normalize(batch['mif_image'].to(device))

            # Ground-truth masks; a dimension is inserted at dim 1 before scoring
            gt_masks = {
                f"{modality}_{target}":
                    batch[f"{modality}_{target}_mask"].float().unsqueeze(1).to(device)
                for modality in MODALITIES
                for target in TARGETS
            }
            sample_names = batch['sample_name']

            # Single forward pass on both inputs
            outputs = model(he_img, mif_img)
            pred_masks = {
                name: (outputs[f"{name}_seg"] > SEG_THRESHOLD).float()
                for name in gt_masks
            }

            # Score every patch of the batch BEFORE recording anything, so a
            # failure part-way through cannot leave a "skipped" batch half-recorded.
            batch_size = pred_masks['he_nuclei'].shape[0]
            batch_scores = []
            for i in range(batch_size):
                patch_scores = {
                    f"{name}_{metric}": metric_fn(pred_masks[name][i:i+1],
                                                  gt_masks[name][i:i+1])
                    for name in gt_masks
                    for metric, metric_fn in METRIC_FNS.items()
                }
                batch_scores.append((sample_names[i], patch_scores))

        except RuntimeError as e:
            # Best-effort evaluation: report and skip the whole batch.
            print(f"\n⚠️ Skipping batch {batch_idx} due to error: {str(e)[:100]}")
            skipped_batches += 1
            continue

        for sample_name, patch_scores in batch_scores:
            record = sample_metrics[sample_name]
            for key, value in patch_scores.items():
                record[key].append(value)
            record['patch_count'] += 1

# ========== COMPUTE PER-SAMPLE AVERAGES ==========
# Samples whose name starts with 'CRC' form the CRC cohort; everything else is
# treated as Xenium.
crc_samples = {}
xenium_samples = {}

for sample_name, metrics in sample_metrics.items():
    avg_metrics = {key: np.mean(metrics[key]) for key in METRIC_KEYS}
    avg_metrics['patch_count'] = metrics['patch_count']

    cohort = crc_samples if sample_name.startswith('CRC') else xenium_samples
    cohort[sample_name] = avg_metrics

# ========== PER-SAMPLE TABLES ==========
for modality in MODALITIES:
    label = MODALITY_LABELS[modality]
    print("\n" + "="*100)
    print(f"📊 PER-SAMPLE RESULTS - {label} ({MODEL_NAME})")
    print("="*100)
    print_sample_table(f"🔬 CRC SAMPLES - {label}", crc_samples, modality)
    print_sample_table(f"🧬 XENIUM SAMPLES - {label}", xenium_samples, modality)

# ========== OVERALL STATISTICS ==========
print("\n" + "="*100)
print(f"📊 OVERALL TEST SET RESULTS ({MODEL_NAME})")
print("="*100)

# Pool the per-patch scores of all samples, per metric
all_metrics = {
    key: [score for metrics in sample_metrics.values() for score in metrics[key]]
    for key in METRIC_KEYS
}
# Individual names kept so that later cells relying on them keep working
all_he_nuclei_dice = all_metrics['he_nuclei_dice']
all_he_nuclei_iou = all_metrics['he_nuclei_iou']
all_he_cell_dice = all_metrics['he_cell_dice']
all_he_cell_iou = all_metrics['he_cell_iou']
all_mif_nuclei_dice = all_metrics['mif_nuclei_dice']
all_mif_nuclei_iou = all_metrics['mif_nuclei_iou']
all_mif_cell_dice = all_metrics['mif_cell_dice']
all_mif_cell_iou = all_metrics['mif_cell_iou']

total_patches = sum(m['patch_count'] for m in sample_metrics.values())

print(f"\nTotal samples: {len(sample_metrics)}")
print(f"Total patches: {total_patches}")
if skipped_batches > 0:
    print(f"⚠️ Skipped batches: {skipped_batches}")

for modality in MODALITIES:
    label = MODALITY_LABELS[modality]
    print("\n" + "="*50)
    print(f"{label} RESULTS")
    print("="*50)
    for target, icon in (("nuclei", "🔬"), ("cell", "🧬")):
        dice_scores = all_metrics[f"{modality}_{target}_dice"]
        iou_scores = all_metrics[f"{modality}_{target}_iou"]
        print(f"\n{icon} {label} {target.upper()} METRICS (all patches):")
        print(f"  Dice:      {format_mean_std(dice_scores)}")
        print(f"  IoU:       {format_mean_std(iou_scores)}")

# Statistics by cohort (computed over per-sample averages, not over patches)
print_group_summary("CRC", crc_samples)
print_group_summary("XENIUM", xenium_samples)

print("\n" + "="*100)
print(f"✅ Evaluation complete for {MODEL_NAME} (both H&E and MIF)!")

Using device: cuda

📦 Loading VitaminPSyn model...
Building H&E encoder with DINOv2-base
Building Synthetic MIF encoder with DINOv2-base
Building shared encoder with DINOv2-base
✓ VitaminPSyn initialized with base backbone
  Embed dim: 768 | Decoder dims: [768, 384, 192, 96]
✅ VitaminPDual model loaded

🔄 Evaluating on all test samples (H&E + MIF - DUAL MODEL - BINARY METRICS ONLY)...


Processing batches: 100%|██████████| 127/127 [00:25<00:00,  4.91it/s]


📊 PER-SAMPLE RESULTS - H&E (VitaminPDual)

🔬 CRC SAMPLES - H&E:
----------------------------------------------------------------------------------------------------
Sample        Patches   Nuclei Dice   Nuclei IoU    Cell Dice    Cell IoU
----------------------------------------------------------------------------------------------------
CRC15              18        0.8584       0.7526       0.9256      0.8626
CRC16             142        0.9138       0.8417       0.9626      0.9282
CRC17              76        0.8316       0.7132       0.9374      0.8827
CRC18              80        0.8745       0.7775       0.9314      0.8723
CRC19              29        0.8332       0.7165       0.9354      0.8793
CRC20              35        0.8291       0.7092       0.9197      0.8530
CRC39              62        0.8553       0.7492       0.9168      0.8478

🧬 XENIUM SAMPLES - H&E:
----------------------------------------------------------------------------------------------------
Sample


