In [1]:
# # Install CellSAM
# # !pip install git+https://github.com/vanvalenlab/cellSAM.git
# # Alternative installation with all dependencies
# !pip install torch torchvision  # Make sure PyTorch is installed first
# !pip install git+https://github.com/vanvalenlab/cellSAM.git

In [2]:
# ALWAYS RUN THIS FIRST!
import os
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Specify GPU 0 (out of 4 available GPUs)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

NOTEBOOK_DIR = Path("/rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest")
os.chdir(NOTEBOOK_DIR)
sys.path.insert(0, str(NOTEBOOK_DIR))
print(f"‚úÖ Working directory: {os.getcwd()}")
print(f"‚úÖ Using GPU: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

‚úÖ Working directory: /rsrch9/home/plm/idso_fa1_pathology/codes/yshokrollahi/vitamin-p-latest
‚úÖ Using GPU: 2


In [3]:
# Step 1: Check what's available in cellSAM
import cellSAM
print("Available functions in cellSAM:")
print(dir(cellSAM))

Available functions in cellSAM:
['AnchorDETR', 'CellSAM', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_auth', 'cellsam_pipeline', 'download_training_data', 'get_local_model', 'get_model', 'model', 'sam_inference', 'segment_cellular_image', 'utils', 'wsi']


In [4]:
import os
import numpy as np
from PIL import Image
from cellSAM import get_model, segment_cellular_image

# Step 1: Set the access token
os.environ['DEEPCELL_ACCESS_TOKEN'] = 'JbVVUStF.A6Ec6pe5vKsoB3RhTnSOaqXJ1thDE3B6'

# Step 2: Load the model (this will download it the first time)
print("Loading CellSAM model...")
model = get_model(model='cellsam_general')
print("Model loaded successfully!")

# Step 3: Load your test image
img = np.array(Image.open("test_images/prostate-he_chunk_12.png"))
print(f"Image shape: {img.shape}")
print(f"Image dtype: {img.dtype}")

# Step 4: Run segmentation
print("Running segmentation...")
mask, _, _ = segment_cellular_image(img, model=model, device='cuda')

print(f"Mask shape: {mask.shape}")
print("Segmentation complete!")

Loading CellSAM model...
Model loaded successfully!
Image shape: (512, 512, 3)
Image dtype: uint8
Running segmentation...
Mask shape: (512, 512)
Segmentation complete!


In [5]:
# import numpy as np
# import torch
# import time
# import pandas as pd
# from PIL import Image
# import matplotlib.pyplot as plt
# from cellSAM import segment_cellular_image, CellSAM
# from cellSAM.wsi import segment_wsi
# import psutil
# import gc
# import json
# import os
# import tifffile
# # ============================================
# # SETUP OUTPUT DIRECTORY
# # ============================================

# OUTPUT_DIR = "benchmark_results/cellsam"
# MASKS_DIR = os.path.join(OUTPUT_DIR, "masks")
# os.makedirs(OUTPUT_DIR, exist_ok=True)
# os.makedirs(MASKS_DIR, exist_ok=True)

# print(f"üìÅ Output directory: {OUTPUT_DIR}")
# print(f"   Masks will be saved to: {MASKS_DIR}")

# # ============================================
# # METRIC EXTRACTION FUNCTIONS
# # ============================================

# def compute_segmentation_metrics(mask):
#     """Compute detailed segmentation statistics"""
#     unique_labels = np.unique(mask)
#     n_cells = len(unique_labels) - 1  # exclude background (0)
    
#     # Cell size statistics
#     cell_sizes = []
#     for label in unique_labels[1:]:  # skip background
#         cell_sizes.append(np.sum(mask == label))
    
#     return {
#         'n_cells': n_cells,
#         'mean_cell_size': np.mean(cell_sizes) if cell_sizes else 0,
#         'median_cell_size': np.median(cell_sizes) if cell_sizes else 0,
#         'std_cell_size': np.std(cell_sizes) if cell_sizes else 0,
#         'min_cell_size': np.min(cell_sizes) if cell_sizes else 0,
#         'max_cell_size': np.max(cell_sizes) if cell_sizes else 0,
#     }

# def measure_gpu_memory():
#     """Measure current GPU memory usage"""
#     if torch.cuda.is_available():
#         return {
#             'allocated_gb': torch.cuda.memory_allocated() / 1024**3,
#             'reserved_gb': torch.cuda.memory_reserved() / 1024**3,
#             'max_allocated_gb': torch.cuda.max_memory_allocated() / 1024**3,
#         }
#     return {'allocated_gb': 0, 'reserved_gb': 0, 'max_allocated_gb': 0}

# def measure_cpu_memory():
#     """Measure current CPU memory usage"""
#     process = psutil.Process()
#     return process.memory_info().rss / 1024**3  # GB

# def save_mask(mask, filepath):
#     """Save mask as TIFF"""
#     Image.fromarray(mask.astype(np.uint16)).save(filepath)

# # ============================================
# # LOAD MODEL
# # ============================================

# print("\n" + "=" * 70)
# print("Loading CellSAM model...")
# print("  ‚úì Model loaded")

# # Load test images
# tile_512 = np.array(Image.open("test_images/ovarian-he_chunk_92.png"))
# wsi_5000 = tifffile.imread("test_images/breat_cancer_15000x15000.tiff")

# print(f"Tile shape: {tile_512.shape}")
# print(f"WSI shape: {wsi_5000.shape}")

# # ============================================
# # 1. SINGLE-TASK BENCHMARK: Nuclei Only
# # ============================================

# print("\n" + "=" * 70)
# print("CELLSAM BENCHMARK 1: SINGLE-TASK (Nuclei Only, 512√ó512)")
# print("=" * 70)

# torch.cuda.reset_peak_memory_stats()
# torch.cuda.empty_cache()
# gc.collect()

# # Warmup
# print("Warmup runs...")
# for _ in range(3):
#     _ = segment_cellular_image(tile_512, model=model, device='cuda')

# # Benchmark
# n_runs = 10
# times_nuclei = []
# memory_nuclei = []
# n_nuclei_list = []

# print(f"Running {n_runs} timed iterations...")
# for i in range(n_runs):
#     torch.cuda.reset_peak_memory_stats()
#     torch.cuda.synchronize()
    
#     start = time.time()
#     mask_nuclei, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
#     torch.cuda.synchronize()
#     end = time.time()
    
#     times_nuclei.append(end - start)
#     memory_nuclei.append(measure_gpu_memory()['max_allocated_gb'])
    
#     seg_metrics = compute_segmentation_metrics(mask_nuclei)
#     n_nuclei_list.append(seg_metrics['n_cells'])
    
#     print(f"  Run {i+1}/{n_runs}: {times_nuclei[-1]:.3f}s, {memory_nuclei[-1]:.2f} GB, Nuclei: {n_nuclei_list[-1]}")

# # Save mask from last run
# print("Saving nuclei masks...")
# save_mask(mask_nuclei, os.path.join(MASKS_DIR, 'tile_512_nuclei.tif'))

# results_nuclei = {
#     'task': 'Nuclei Only',
#     'image_size': '512√ó512',
#     'n_instances': n_nuclei_list[-1],
#     'mean_time_s': np.mean(times_nuclei),
#     'std_time_s': np.std(times_nuclei),
#     'mean_time_ms': np.mean(times_nuclei) * 1000,
#     'peak_memory_gb': np.mean(memory_nuclei),
#     'instances_per_second': n_nuclei_list[-1] / np.mean(times_nuclei),
# }

# print("\nRESULTS:")
# for key, value in results_nuclei.items():
#     if isinstance(value, float):
#         print(f"  {key}: {value:.4f}")
#     else:
#         print(f"  {key}: {value}")

# # ============================================
# # 2. SINGLE-TASK BENCHMARK: Cells Only
# # ============================================

# print("\n" + "=" * 70)
# print("CELLSAM BENCHMARK 2: SINGLE-TASK (Cells Only, 512√ó512)")
# print("=" * 70)

# torch.cuda.reset_peak_memory_stats()
# torch.cuda.empty_cache()
# gc.collect()

# # Warmup
# print("Warmup runs...")
# for _ in range(3):
#     _ = segment_cellular_image(tile_512, model=model, device='cuda')

# # Benchmark
# times_cells = []
# memory_cells = []
# n_cells_list = []

# print(f"Running {n_runs} timed iterations...")
# for i in range(n_runs):
#     torch.cuda.reset_peak_memory_stats()
#     torch.cuda.synchronize()
    
#     start = time.time()
#     mask_cells, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
#     torch.cuda.synchronize()
#     end = time.time()
    
#     times_cells.append(end - start)
#     memory_cells.append(measure_gpu_memory()['max_allocated_gb'])
    
#     seg_metrics = compute_segmentation_metrics(mask_cells)
#     n_cells_list.append(seg_metrics['n_cells'])
    
#     print(f"  Run {i+1}/{n_runs}: {times_cells[-1]:.3f}s, {memory_cells[-1]:.2f} GB, Cells: {n_cells_list[-1]}")

# # Save mask from last run
# print("Saving cell masks...")
# save_mask(mask_cells, os.path.join(MASKS_DIR, 'tile_512_cells.tif'))

# results_cells = {
#     'task': 'Cells Only',
#     'image_size': '512√ó512',
#     'n_instances': n_cells_list[-1],
#     'mean_time_s': np.mean(times_cells),
#     'std_time_s': np.std(times_cells),
#     'mean_time_ms': np.mean(times_cells) * 1000,
#     'peak_memory_gb': np.mean(memory_cells),
#     'instances_per_second': n_cells_list[-1] / np.mean(times_cells),
# }

# print("\nRESULTS:")
# for key, value in results_cells.items():
#     if isinstance(value, float):
#         print(f"  {key}: {value:.4f}")
#     else:
#         print(f"  {key}: {value}")

# # ============================================
# # 3. DUAL-TASK SIMULATION: Run Twice
# # ============================================

# print("\n" + "=" * 70)
# print("‚≠ê CELLSAM BENCHMARK 3: DUAL-TASK SIMULATION (Run Twice, 512√ó512)")
# print("=" * 70)
# print("Note: CellSAM requires TWO separate runs for nuclei + cells")

# torch.cuda.reset_peak_memory_stats()
# torch.cuda.empty_cache()
# gc.collect()

# # Warmup
# print("Warmup runs...")
# for _ in range(3):
#     # Run 1: Nuclei
#     _ = segment_cellular_image(tile_512, model=model, device='cuda')
#     # Run 2: Cells
#     _ = segment_cellular_image(tile_512, model=model, device='cuda')

# # Benchmark - Run TWICE per iteration
# times_dual = []
# memory_dual = []
# n_nuclei_dual = []
# n_cells_dual = []

# print(f"Running {n_runs} timed iterations (2 runs each)...")
# for i in range(n_runs):
#     torch.cuda.reset_peak_memory_stats()
#     torch.cuda.synchronize()
    
#     # RUN 1: Nuclei
#     start = time.time()
#     mask_n, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
#     torch.cuda.synchronize()
#     time_nuclei = time.time() - start
    
#     # RUN 2: Cells
#     start = time.time()
#     mask_c, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
#     torch.cuda.synchronize()
#     time_cells = time.time() - start
    
#     # Total time
#     total_time = time_nuclei + time_cells
#     times_dual.append(total_time)
#     memory_dual.append(measure_gpu_memory()['max_allocated_gb'])
    
#     # Count instances
#     seg_n = compute_segmentation_metrics(mask_n)
#     seg_c = compute_segmentation_metrics(mask_c)
#     n_nuclei_dual.append(seg_n['n_cells'])
#     n_cells_dual.append(seg_c['n_cells'])
    
#     print(f"  Run {i+1}/{n_runs}: {total_time:.3f}s ({time_nuclei:.3f}s + {time_cells:.3f}s), {memory_dual[-1]:.2f} GB, Nuclei: {n_nuclei_dual[-1]}, Cells: {n_cells_dual[-1]}")

# # Save masks from last run
# print("Saving dual-task masks...")
# save_mask(mask_n, os.path.join(MASKS_DIR, 'tile_512_dual_nuclei.tif'))
# save_mask(mask_c, os.path.join(MASKS_DIR, 'tile_512_dual_cells.tif'))

# total_instances_dual = n_nuclei_dual[-1] + n_cells_dual[-1]
# mean_time_dual = np.mean(times_dual)
# mean_time_single = np.mean(times_nuclei)

# results_dual = {
#     'task': 'Nuclei + Cells (Sequential)',
#     'image_size': '512√ó512',
#     'n_nuclei': n_nuclei_dual[-1],
#     'n_cells': n_cells_dual[-1],
#     'total_instances': total_instances_dual,
#     'mean_time_s': mean_time_dual,
#     'std_time_s': np.std(times_dual),
#     'mean_time_ms': mean_time_dual * 1000,
#     'peak_memory_gb': np.mean(memory_dual),
#     'instances_per_second': total_instances_dual / mean_time_dual,
#     'overhead_vs_single_pct': ((mean_time_dual / mean_time_single) - 1) * 100,
# }

# print("\nRESULTS:")
# for key, value in results_dual.items():
#     if isinstance(value, float):
#         print(f"  {key}: {value:.4f}")
#     else:
#         print(f"  {key}: {value}")

# print("\n‚≠ê KEY INSIGHTS:")
# print(f"  Single-task (nuclei only): {mean_time_single:.3f}s")
# print(f"  Dual-task (2 sequential runs): {mean_time_dual:.3f}s")
# print(f"  Overhead: {results_dual['overhead_vs_single_pct']:.1f}%")
# print(f"  üî¥ CellSAM requires 2 separate runs (no simultaneous dual-task)")

# # ============================================
# # 4. WSI BENCHMARK: Nuclei Only
# # ============================================

# print("\n" + "=" * 70)
# print("CELLSAM BENCHMARK 4: WSI SINGLE-TASK (Nuclei Only, 15000√ó15000)")
# print("=" * 70)

# torch.cuda.reset_peak_memory_stats()
# torch.cuda.empty_cache()
# gc.collect()

# cpu_mem_before = measure_cpu_memory()

# start_wsi = time.time()
# mask_wsi_nuclei = segment_wsi(
#     wsi_5000,
#     block_size=512,
#     overlap=64,
#     iou_depth=1,
#     iou_threshold=0.5,
#     model=model,
#     device='cuda'
# )

# # Convert dask array if needed
# if hasattr(mask_wsi_nuclei, 'compute'):
#     mask_wsi_nuclei = mask_wsi_nuclei.compute()

# torch.cuda.synchronize()
# elapsed_wsi = time.time() - start_wsi

# cpu_mem_after = measure_cpu_memory()
# gpu_mem_wsi = measure_gpu_memory()

# seg_metrics_wsi = compute_segmentation_metrics(mask_wsi_nuclei)
# n_nuclei_wsi = seg_metrics_wsi['n_cells']

# # Save WSI mask
# print("Saving WSI nuclei masks...")
# save_mask(mask_wsi_nuclei, os.path.join(MASKS_DIR, 'wsi_15000_nuclei.tif'))

# # Tile statistics
# tile_size = 512
# overlap_pixels = 64
# stride = tile_size - overlap_pixels
# n_tiles_x = int(np.ceil(wsi_5000.shape[1] / stride))
# n_tiles_y = int(np.ceil(wsi_5000.shape[0] / stride))
# total_tiles = n_tiles_x * n_tiles_y

# results_wsi_nuclei = {
#     'task': 'Nuclei Only',
#     'image_size': '15000√ó15000',
#     'n_tiles': total_tiles,
#     'tile_size': '512√ó512',
#     'n_instances': n_nuclei_wsi,
#     'total_time_s': elapsed_wsi,
#     'total_time_min': elapsed_wsi / 60,
#     'time_per_tile_ms': (elapsed_wsi / total_tiles) * 1000,
#     'peak_gpu_memory_gb': gpu_mem_wsi['max_allocated_gb'],
#     'instances_per_second': n_nuclei_wsi / elapsed_wsi,
#     'throughput_mpx_per_min': (15000 * 15000 / 1e6) / (elapsed_wsi / 60),
# }

# print("\nRESULTS:")
# for key, value in results_wsi_nuclei.items():
#     if isinstance(value, float):
#         print(f"  {key}: {value:.4f}")
#     else:
#         print(f"  {key}: {value}")

# # ============================================
# # 5. WSI BENCHMARK: Dual-Task Simulation
# # ============================================

# print("\n" + "=" * 70)
# print("‚≠ê CELLSAM BENCHMARK 5: WSI DUAL-TASK SIMULATION (15000√ó15000)")
# print("=" * 70)
# print("Note: Running CellSAM TWICE for nuclei + cells")

# torch.cuda.reset_peak_memory_stats()
# torch.cuda.empty_cache()
# gc.collect()

# # RUN 1: Nuclei
# print("Running nuclei segmentation...")
# start = time.time()
# mask_wsi_n = segment_wsi(
#     wsi_5000,
#     block_size=512,
#     overlap=64,
#     iou_depth=1,
#     iou_threshold=0.5,
#     model=model,
#     device='cuda'
# )
# if hasattr(mask_wsi_n, 'compute'):
#     mask_wsi_n = mask_wsi_n.compute()
# torch.cuda.synchronize()
# time_wsi_nuclei = time.time() - start

# # RUN 2: Cells
# print("Running cell segmentation...")
# start = time.time()
# mask_wsi_c = segment_wsi(
#     wsi_5000,
#     block_size=512,
#     overlap=64,
#     iou_depth=1,
#     iou_threshold=0.5,
#     model=model,
#     device='cuda'
# )
# if hasattr(mask_wsi_c, 'compute'):
#     mask_wsi_c = mask_wsi_c.compute()
# torch.cuda.synchronize()
# time_wsi_cells = time.time() - start

# elapsed_wsi_dual = time_wsi_nuclei + time_wsi_cells
# gpu_mem_wsi_dual = measure_gpu_memory()

# seg_n_wsi = compute_segmentation_metrics(mask_wsi_n)
# seg_c_wsi = compute_segmentation_metrics(mask_wsi_c)
# n_nuclei_wsi_dual = seg_n_wsi['n_cells']
# n_cells_wsi_dual = seg_c_wsi['n_cells']
# total_instances_wsi = n_nuclei_wsi_dual + n_cells_wsi_dual

# # Save WSI dual masks
# print("Saving WSI dual-task masks...")
# save_mask(mask_wsi_n, os.path.join(MASKS_DIR, 'wsi_15000_dual_nuclei.tif'))
# save_mask(mask_wsi_c, os.path.join(MASKS_DIR, 'wsi_15000_dual_cells.tif'))

# results_wsi_dual = {
#     'task': 'Nuclei + Cells (Sequential)',
#     'image_size': '15000√ó15000',
#     'n_tiles': total_tiles,
#     'tile_size': '512√ó512',
#     'n_nuclei': n_nuclei_wsi_dual,
#     'n_cells': n_cells_wsi_dual,
#     'total_instances': total_instances_wsi,
#     'total_time_s': elapsed_wsi_dual,
#     'total_time_min': elapsed_wsi_dual / 60,
#     'time_per_tile_ms': (elapsed_wsi_dual / total_tiles) * 1000,
#     'peak_gpu_memory_gb': gpu_mem_wsi_dual['max_allocated_gb'],
#     'instances_per_second': total_instances_wsi / elapsed_wsi_dual,
#     'throughput_mpx_per_min': (15000 * 15000 / 1e6) / (elapsed_wsi_dual / 60),
#     'overhead_vs_single_pct': ((elapsed_wsi_dual / elapsed_wsi) - 1) * 100,
# }

# print("\nRESULTS:")
# for key, value in results_wsi_dual.items():
#     if isinstance(value, float):
#         print(f"  {key}: {value:.4f}")
#     else:
#         print(f"  {key}: {value}")

# print("\n‚≠ê KEY INSIGHTS:")
# print(f"  Single-task WSI: {elapsed_wsi / 60:.2f} min")
# print(f"  Dual-task WSI (2 sequential runs): {elapsed_wsi_dual / 60:.2f} min ({time_wsi_nuclei / 60:.2f} + {time_wsi_cells / 60:.2f})")
# print(f"  Overhead: {results_wsi_dual['overhead_vs_single_pct']:.1f}%")
# print(f"  üî¥ CellSAM requires 2 separate runs (no simultaneous dual-task)")

# # ============================================
# # 6. COMPILE ALL RESULTS
# # ============================================

# print("\n" + "=" * 70)
# print("CELLSAM COMPREHENSIVE BENCHMARK SUMMARY")
# print("=" * 70)

# summary_df = pd.DataFrame([
#     {
#         'Test': 'Single Tile - Nuclei',
#         'Task': 'Nuclei',
#         'Size': '512√ó512',
#         'Instances': results_nuclei['n_instances'],
#         'Time (s)': results_nuclei['mean_time_s'],
#         'Memory (GB)': results_nuclei['peak_memory_gb'],
#         'Inst/sec': results_nuclei['instances_per_second'],
#     },
#     {
#         'Test': 'Single Tile - Cells',
#         'Task': 'Cells',
#         'Size': '512√ó512',
#         'Instances': results_cells['n_instances'],
#         'Time (s)': results_cells['mean_time_s'],
#         'Memory (GB)': results_cells['peak_memory_gb'],
#         'Inst/sec': results_cells['instances_per_second'],
#     },
#     {
#         'Test': 'Single Tile - DUAL',
#         'Task': 'Both (2 runs)',
#         'Size': '512√ó512',
#         'Instances': results_dual['total_instances'],
#         'Time (s)': results_dual['mean_time_s'],
#         'Memory (GB)': results_dual['peak_memory_gb'],
#         'Inst/sec': results_dual['instances_per_second'],
#     },
#     {
#         'Test': 'WSI - Nuclei',
#         'Task': 'Nuclei',
#         'Size': '15000√ó15000',
#         'Instances': results_wsi_nuclei['n_instances'],
#         'Time (s)': results_wsi_nuclei['total_time_s'],
#         'Memory (GB)': results_wsi_nuclei['peak_gpu_memory_gb'],
#         'Inst/sec': results_wsi_nuclei['instances_per_second'],
#     },
#     {
#         'Test': 'WSI - DUAL',
#         'Task': 'Both (2 runs)',
#         'Size': '15000√ó15000',
#         'Instances': results_wsi_dual['total_instances'],
#         'Time (s)': results_wsi_dual['total_time_s'],
#         'Memory (GB)': results_wsi_dual['peak_gpu_memory_gb'],
#         'Inst/sec': results_wsi_dual['instances_per_second'],
#     }
# ])

# print("\n" + summary_df.to_string(index=False))

# # ============================================
# # 7. SAVE ALL RESULTS
# # ============================================

# print("\n" + "=" * 70)
# print("SAVING RESULTS...")
# print("=" * 70)

# # Save CSV
# csv_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_summary.csv')
# summary_df.to_csv(csv_path, index=False)
# print(f"‚úì Saved CSV: {csv_path}")

# # Save JSON
# all_results = {
#     'single_tile_nuclei': results_nuclei,
#     'single_tile_cells': results_cells,
#     'single_tile_dual': results_dual,
#     'wsi_nuclei': results_wsi_nuclei,
#     'wsi_dual': results_wsi_dual,
#     'hardware': {
#         'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A',
#         'gpu_memory_total_gb': torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0,
#     },
#     'benchmark_date': time.strftime('%Y-%m-%d %H:%M:%S'),
# }

# json_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_complete.json')
# with open(json_path, 'w') as f:
#     json.dump(all_results, f, indent=2)
# print(f"‚úì Saved JSON: {json_path}")

# # ============================================
# # 8. VISUALIZATION
# # ============================================

# print("\nGenerating plots...")

# fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# # Plot 1: Single Tile Comparison
# tasks = ['Nuclei\nOnly', 'Cells\nOnly', 'Both\n(2 runs)']
# times = [results_nuclei['mean_time_ms'], results_cells['mean_time_ms'], results_dual['mean_time_ms']]
# axes[0, 0].bar(tasks, times, color=['steelblue', 'coral', 'red'])
# axes[0, 0].set_ylabel('Time (ms)', fontsize=12)
# axes[0, 0].set_title('Single Tile Processing Time (512√ó512)', fontsize=14, fontweight='bold')
# axes[0, 0].grid(True, alpha=0.3, axis='y')

# # Plot 2: WSI Comparison
# wsi_tasks = ['Nuclei\nOnly', 'Both\n(2 runs)']
# wsi_times = [
#     results_wsi_nuclei['total_time_min'],
#     results_wsi_dual['total_time_min']
# ]
# axes[0, 1].bar(wsi_tasks, wsi_times, color=['steelblue', 'red'])
# axes[0, 1].set_ylabel('Time (minutes)', fontsize=12)
# axes[0, 1].set_title('WSI Processing Time (15000√ó15000)', fontsize=14, fontweight='bold')
# axes[0, 1].grid(True, alpha=0.3, axis='y')

# # Plot 3: Throughput
# throughputs = [
#     results_nuclei['instances_per_second'],
#     results_cells['instances_per_second'],
#     results_dual['instances_per_second']
# ]
# axes[1, 0].bar(tasks, throughputs, color=['steelblue', 'coral', 'red'])
# axes[1, 0].set_ylabel('Instances/Second', fontsize=12)
# axes[1, 0].set_title('Segmentation Throughput', fontsize=14, fontweight='bold')
# axes[1, 0].grid(True, alpha=0.3, axis='y')

# # Plot 4: Memory Usage
# memory_vals = [
#     results_nuclei['peak_memory_gb'],
#     results_cells['peak_memory_gb'],
#     results_dual['peak_memory_gb']
# ]
# axes[1, 1].bar(tasks, memory_vals, color=['steelblue', 'coral', 'red'])
# axes[1, 1].set_ylabel('Peak GPU Memory (GB)', fontsize=12)
# axes[1, 1].set_title('Memory Usage', fontsize=14, fontweight='bold')
# axes[1, 1].grid(True, alpha=0.3, axis='y')

# plt.tight_layout()
# plot_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_plots.png')
# plt.savefig(plot_path, dpi=300, bbox_inches='tight')
# print(f"‚úì Saved plot: {plot_path}")
# plt.close()

# # ============================================
# # 9. LIST SAVED MASK FILES
# # ============================================

# print("\nüìÅ Saved mask files:")
# mask_files = sorted([f for f in os.listdir(MASKS_DIR) if f.endswith('.tif')])
# for mf in mask_files:
#     size_mb = os.path.getsize(os.path.join(MASKS_DIR, mf)) / (1024**2)
#     print(f"  - {mf} ({size_mb:.2f} MB)")

# print("\n" + "=" * 70)
# print("BENCHMARK COMPLETE!")
# print("=" * 70)
# print(f"\nüìÅ All results saved to: {os.path.abspath(OUTPUT_DIR)}")
# print("\n‚≠ê KEY FINDINGS:")
# print(f"  Single Tile - Nuclei: {results_nuclei['mean_time_ms']:.1f} ms")
# print(f"  Single Tile - Dual (2 runs): {results_dual['mean_time_ms']:.1f} ms")
# print(f"  WSI - Nuclei: {results_wsi_nuclei['total_time_min']:.2f} min")
# print(f"  WSI - Dual (2 runs): {results_wsi_dual['total_time_min']:.2f} min")
# print(f"  üî¥ CellSAM requires 2 separate runs for nuclei + cells")
# print("\nüìä Files created:")
# print(f"  - {os.path.basename(csv_path)}")
# print(f"  - {os.path.basename(json_path)}")
# print(f"  - {os.path.basename(plot_path)}")
# print(f"  - masks/ folder with {len(mask_files)} segmentation outputs")
# print("=" * 70)

In [None]:
import numpy as np
import torch
import time
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from cellSAM import segment_cellular_image, CellSAM
from cellSAM.wsi import segment_wsi
import psutil
import gc
import json
import os
import tifffile

# ============================================
# SETUP OUTPUT DIRECTORY
# ============================================

OUTPUT_DIR = "benchmark_results/cellsam"
MASKS_DIR = os.path.join(OUTPUT_DIR, "masks")
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MASKS_DIR, exist_ok=True)

print(f"üìÅ Output directory: {OUTPUT_DIR}")
print(f"   Masks will be saved to: {MASKS_DIR}")

# ============================================
# METRIC EXTRACTION FUNCTIONS
# ============================================

def compute_segmentation_metrics(mask):
    """Compute detailed segmentation statistics"""
    unique_labels = np.unique(mask)
    n_cells = len(unique_labels) - 1  # exclude background (0)
    
    # Cell size statistics
    cell_sizes = []
    for label in unique_labels[1:]:  # skip background
        cell_sizes.append(np.sum(mask == label))
    
    return {
        'n_cells': n_cells,
        'mean_cell_size': np.mean(cell_sizes) if cell_sizes else 0,
        'median_cell_size': np.median(cell_sizes) if cell_sizes else 0,
        'std_cell_size': np.std(cell_sizes) if cell_sizes else 0,
        'min_cell_size': np.min(cell_sizes) if cell_sizes else 0,
        'max_cell_size': np.max(cell_sizes) if cell_sizes else 0,
    }

def measure_gpu_memory():
    """Measure current GPU memory usage"""
    if torch.cuda.is_available():
        return {
            'allocated_gb': torch.cuda.memory_allocated() / 1024**3,
            'reserved_gb': torch.cuda.memory_reserved() / 1024**3,
            'max_allocated_gb': torch.cuda.max_memory_allocated() / 1024**3,
        }
    return {'allocated_gb': 0, 'reserved_gb': 0, 'max_allocated_gb': 0}

def measure_cpu_memory():
    """Measure current CPU memory usage"""
    process = psutil.Process()
    return process.memory_info().rss / 1024**3  # GB

def save_mask(mask, filepath):
    """Save mask as TIFF"""
    Image.fromarray(mask.astype(np.uint16)).save(filepath)

# ============================================
# LOAD MODEL & TEST IMAGES
# ============================================

print("\n" + "=" * 70)
print("Loading CellSAM model...")

# Set the access token
os.environ['DEEPCELL_ACCESS_TOKEN'] = 'JbVVUStF.A6Ec6pe5vKsoB3RhTnSOaqXJ1thDE3B6'

# Load the model using get_model
from cellSAM import get_model
model = get_model(model='cellsam_general')
print("  ‚úì Model loaded successfully")

# Load test images
tile_512 = np.array(Image.open("test_images/ovarian-he_chunk_92.png"))
wsi_5000 = tifffile.imread("test_images/breat_cancer_15000x15000.tiff")

print(f"Tile shape: {tile_512.shape}")
print(f"WSI shape: {wsi_5000.shape}")

# ============================================
# 1. SINGLE-TASK BENCHMARK: Nuclei Only
# ============================================

print("\n" + "=" * 70)
print("CELLSAM BENCHMARK 1: SINGLE-TASK (Nuclei Only, 512√ó512)")
print("=" * 70)

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

# Warmup
print("Warmup runs...")
for _ in range(3):
    _ = segment_cellular_image(tile_512, model=model, device='cuda')

# Benchmark
n_runs = 10
times_nuclei = []
memory_nuclei = []
n_nuclei_list = []

print(f"Running {n_runs} timed iterations...")
for i in range(n_runs):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    
    start = time.time()
    mask_nuclei, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
    torch.cuda.synchronize()
    end = time.time()
    
    times_nuclei.append(end - start)
    memory_nuclei.append(measure_gpu_memory()['max_allocated_gb'])
    
    seg_metrics = compute_segmentation_metrics(mask_nuclei)
    n_nuclei_list.append(seg_metrics['n_cells'])
    
    print(f"  Run {i+1}/{n_runs}: {times_nuclei[-1]:.3f}s, {memory_nuclei[-1]:.2f} GB, Nuclei: {n_nuclei_list[-1]}")

# Save mask from last run
print("Saving nuclei masks...")
save_mask(mask_nuclei, os.path.join(MASKS_DIR, 'tile_512_nuclei.tif'))

results_nuclei = {
    'task': 'Nuclei Only',
    'image_size': '512√ó512',
    'n_instances': n_nuclei_list[-1],
    'mean_time_s': np.mean(times_nuclei),
    'std_time_s': np.std(times_nuclei),
    'mean_time_ms': np.mean(times_nuclei) * 1000,
    'peak_memory_gb': np.mean(memory_nuclei),
    'instances_per_second': n_nuclei_list[-1] / np.mean(times_nuclei),
}

print("\nRESULTS:")
for key, value in results_nuclei.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# ============================================
# 2. SINGLE-TASK BENCHMARK: Cells Only
# ============================================

print("\n" + "=" * 70)
print("CELLSAM BENCHMARK 2: SINGLE-TASK (Cells Only, 512√ó512)")
print("=" * 70)

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

# Warmup
print("Warmup runs...")
for _ in range(3):
    _ = segment_cellular_image(tile_512, model=model, device='cuda')

# Benchmark
times_cells = []
memory_cells = []
n_cells_list = []

print(f"Running {n_runs} timed iterations...")
for i in range(n_runs):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    
    start = time.time()
    mask_cells, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
    torch.cuda.synchronize()
    end = time.time()
    
    times_cells.append(end - start)
    memory_cells.append(measure_gpu_memory()['max_allocated_gb'])
    
    seg_metrics = compute_segmentation_metrics(mask_cells)
    n_cells_list.append(seg_metrics['n_cells'])
    
    print(f"  Run {i+1}/{n_runs}: {times_cells[-1]:.3f}s, {memory_cells[-1]:.2f} GB, Cells: {n_cells_list[-1]}")

# Save mask from last run
print("Saving cell masks...")
save_mask(mask_cells, os.path.join(MASKS_DIR, 'tile_512_cells.tif'))

results_cells = {
    'task': 'Cells Only',
    'image_size': '512√ó512',
    'n_instances': n_cells_list[-1],
    'mean_time_s': np.mean(times_cells),
    'std_time_s': np.std(times_cells),
    'mean_time_ms': np.mean(times_cells) * 1000,
    'peak_memory_gb': np.mean(memory_cells),
    'instances_per_second': n_cells_list[-1] / np.mean(times_cells),
}

print("\nRESULTS:")
for key, value in results_cells.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# ============================================
# 3. DUAL-TASK SIMULATION: Run Twice
# ============================================

print("\n" + "=" * 70)
print("‚≠ê CELLSAM BENCHMARK 3: DUAL-TASK SIMULATION (Run Twice, 512√ó512)")
print("=" * 70)
print("Note: CellSAM requires TWO separate runs for nuclei + cells")

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

# Warmup
print("Warmup runs...")
for _ in range(3):
    # Run 1: Nuclei
    _ = segment_cellular_image(tile_512, model=model, device='cuda')
    # Run 2: Cells
    _ = segment_cellular_image(tile_512, model=model, device='cuda')

# Benchmark - Run TWICE per iteration
times_dual = []
memory_dual = []
n_nuclei_dual = []
n_cells_dual = []

print(f"Running {n_runs} timed iterations (2 runs each)...")
for i in range(n_runs):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    
    # RUN 1: Nuclei
    start = time.time()
    mask_n, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
    torch.cuda.synchronize()
    time_nuclei = time.time() - start
    
    # RUN 2: Cells
    start = time.time()
    mask_c, _, _ = segment_cellular_image(tile_512, model=model, device='cuda')
    torch.cuda.synchronize()
    time_cells = time.time() - start
    
    # Total time
    total_time = time_nuclei + time_cells
    times_dual.append(total_time)
    memory_dual.append(measure_gpu_memory()['max_allocated_gb'])
    
    # Count instances
    seg_n = compute_segmentation_metrics(mask_n)
    seg_c = compute_segmentation_metrics(mask_c)
    n_nuclei_dual.append(seg_n['n_cells'])
    n_cells_dual.append(seg_c['n_cells'])
    
    print(f"  Run {i+1}/{n_runs}: {total_time:.3f}s ({time_nuclei:.3f}s + {time_cells:.3f}s), {memory_dual[-1]:.2f} GB, Nuclei: {n_nuclei_dual[-1]}, Cells: {n_cells_dual[-1]}")

# Save masks from last run
print("Saving dual-task masks...")
save_mask(mask_n, os.path.join(MASKS_DIR, 'tile_512_dual_nuclei.tif'))
save_mask(mask_c, os.path.join(MASKS_DIR, 'tile_512_dual_cells.tif'))

total_instances_dual = n_nuclei_dual[-1] + n_cells_dual[-1]
mean_time_dual = np.mean(times_dual)
mean_time_single = np.mean(times_nuclei)

results_dual = {
    'task': 'Nuclei + Cells (Sequential)',
    'image_size': '512√ó512',
    'n_nuclei': n_nuclei_dual[-1],
    'n_cells': n_cells_dual[-1],
    'total_instances': total_instances_dual,
    'mean_time_s': mean_time_dual,
    'std_time_s': np.std(times_dual),
    'mean_time_ms': mean_time_dual * 1000,
    'peak_memory_gb': np.mean(memory_dual),
    'instances_per_second': total_instances_dual / mean_time_dual,
    'overhead_vs_single_pct': ((mean_time_dual / mean_time_single) - 1) * 100,
}

print("\nRESULTS:")
for key, value in results_dual.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

print("\n‚≠ê KEY INSIGHTS:")
print(f"  Single-task (nuclei only): {mean_time_single:.3f}s")
print(f"  Dual-task (2 sequential runs): {mean_time_dual:.3f}s")
print(f"  Overhead: {results_dual['overhead_vs_single_pct']:.1f}%")
print(f"  üî¥ CellSAM requires 2 separate runs (no simultaneous dual-task)")

# ============================================
# 4. WSI BENCHMARK: Nuclei Only
# ============================================

print("\n" + "=" * 70)
print("CELLSAM BENCHMARK 4: WSI SINGLE-TASK (Nuclei Only, 15000√ó15000)")
print("=" * 70)
print("‚è±Ô∏è  This will take a while - please be patient...")

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

cpu_mem_before = measure_cpu_memory()

# Start timing - INCLUDE EVERYTHING
start_wsi = time.time()

mask_wsi_nuclei = segment_wsi(
    wsi_5000,
    block_size=512,
    overlap=64,
    iou_depth=1,
    iou_threshold=0.5,
    model=model,
    device='cuda'
)

# Convert dask array if needed
if hasattr(mask_wsi_nuclei, 'compute'):
    mask_wsi_nuclei = mask_wsi_nuclei.compute()

# Force ALL computation to complete
print("  Computing segmentation metrics (forces computation)...")
seg_metrics_wsi = compute_segmentation_metrics(mask_wsi_nuclei)
n_nuclei_wsi = seg_metrics_wsi['n_cells']

# Synchronize GPU
torch.cuda.synchronize()

# Save masks (this might also take time)
print("  Saving WSI nuclei masks...")
save_mask(mask_wsi_nuclei, os.path.join(MASKS_DIR, 'wsi_15000_nuclei.tif'))

# Final sync and time
torch.cuda.synchronize()
elapsed_wsi = time.time() - start_wsi

cpu_mem_after = measure_cpu_memory()
gpu_mem_wsi = measure_gpu_memory()

# Tile statistics
tile_size = 512
overlap_pixels = 64
stride = tile_size - overlap_pixels
n_tiles_x = int(np.ceil(wsi_5000.shape[1] / stride))
n_tiles_y = int(np.ceil(wsi_5000.shape[0] / stride))
total_tiles = n_tiles_x * n_tiles_y

results_wsi_nuclei = {
    'task': 'Nuclei Only',
    'image_size': '15000√ó15000',
    'n_tiles': total_tiles,
    'tile_size': '512√ó512',
    'n_instances': n_nuclei_wsi,
    'total_time_s': elapsed_wsi,
    'total_time_min': elapsed_wsi / 60,
    'time_per_tile_ms': (elapsed_wsi / total_tiles) * 1000,
    'peak_gpu_memory_gb': gpu_mem_wsi['max_allocated_gb'],
    'instances_per_second': n_nuclei_wsi / elapsed_wsi,
    'throughput_mpx_per_min': (15000 * 15000 / 1e6) / (elapsed_wsi / 60),
}

print("\nRESULTS:")
for key, value in results_wsi_nuclei.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# ============================================
# 5. WSI BENCHMARK: Dual-Task Simulation
# ============================================

print("\n" + "=" * 70)
print("‚≠ê CELLSAM BENCHMARK 5: WSI DUAL-TASK SIMULATION (15000√ó15000)")
print("=" * 70)
print("Note: Running CellSAM TWICE for nuclei + cells")
print("‚è±Ô∏è  This will take even longer - please be patient...")

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()

# RUN 1: Nuclei
print("\nüî¨ Running nuclei segmentation...")
start_total = time.time()

start = time.time()
mask_wsi_n = segment_wsi(
    wsi_5000,
    block_size=512,
    overlap=64,
    iou_depth=1,
    iou_threshold=0.5,
    model=model,
    device='cuda'
)

if hasattr(mask_wsi_n, 'compute'):
    mask_wsi_n = mask_wsi_n.compute()

# Force computation
print("  Computing nuclei metrics...")
seg_n_wsi = compute_segmentation_metrics(mask_wsi_n)
n_nuclei_wsi_dual = seg_n_wsi['n_cells']
torch.cuda.synchronize()

# Save nuclei masks
print("  Saving nuclei masks...")
save_mask(mask_wsi_n, os.path.join(MASKS_DIR, 'wsi_15000_dual_nuclei.tif'))
torch.cuda.synchronize()

time_wsi_nuclei = time.time() - start
print(f"  ‚úì Nuclei complete: {time_wsi_nuclei / 60:.2f} min")

# RUN 2: Cells
print("\nüî¨ Running cell segmentation...")
start = time.time()
mask_wsi_c = segment_wsi(
    wsi_5000,
    block_size=512,
    overlap=64,
    iou_depth=1,
    iou_threshold=0.5,
    model=model,
    device='cuda'
)

if hasattr(mask_wsi_c, 'compute'):
    mask_wsi_c = mask_wsi_c.compute()

# Force computation
print("  Computing cell metrics...")
seg_c_wsi = compute_segmentation_metrics(mask_wsi_c)
n_cells_wsi_dual = seg_c_wsi['n_cells']
torch.cuda.synchronize()

# Save cell masks
print("  Saving cell masks...")
save_mask(mask_wsi_c, os.path.join(MASKS_DIR, 'wsi_15000_dual_cells.tif'))
torch.cuda.synchronize()

time_wsi_cells = time.time() - start
print(f"  ‚úì Cells complete: {time_wsi_cells / 60:.2f} min")

elapsed_wsi_dual = time.time() - start_total
gpu_mem_wsi_dual = measure_gpu_memory()

total_instances_wsi = n_nuclei_wsi_dual + n_cells_wsi_dual

results_wsi_dual = {
    'task': 'Nuclei + Cells (Sequential)',
    'image_size': '15000√ó15000',
    'n_tiles': total_tiles,
    'tile_size': '512√ó512',
    'n_nuclei': n_nuclei_wsi_dual,
    'n_cells': n_cells_wsi_dual,
    'total_instances': total_instances_wsi,
    'total_time_s': elapsed_wsi_dual,
    'total_time_min': elapsed_wsi_dual / 60,
    'time_per_tile_ms': (elapsed_wsi_dual / total_tiles) * 1000,
    'peak_gpu_memory_gb': gpu_mem_wsi_dual['max_allocated_gb'],
    'instances_per_second': total_instances_wsi / elapsed_wsi_dual,
    'throughput_mpx_per_min': (15000 * 15000 / 1e6) / (elapsed_wsi_dual / 60),
    'overhead_vs_single_pct': ((elapsed_wsi_dual / elapsed_wsi) - 1) * 100,
}

print("\nRESULTS:")
for key, value in results_wsi_dual.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

print("\n‚≠ê KEY INSIGHTS:")
print(f"  Single-task WSI: {elapsed_wsi / 60:.2f} min")
print(f"  Dual-task WSI (2 sequential runs): {elapsed_wsi_dual / 60:.2f} min ({time_wsi_nuclei / 60:.2f} + {time_wsi_cells / 60:.2f})")
print(f"  Overhead: {results_wsi_dual['overhead_vs_single_pct']:.1f}%")
print(f"  üî¥ CellSAM requires 2 separate runs (no simultaneous dual-task)")

# ============================================
# 6. COMPILE ALL RESULTS
# ============================================

print("\n" + "=" * 70)
print("CELLSAM COMPREHENSIVE BENCHMARK SUMMARY")
print("=" * 70)

summary_df = pd.DataFrame([
    {
        'Test': 'Single Tile - Nuclei',
        'Task': 'Nuclei',
        'Size': '512√ó512',
        'Instances': results_nuclei['n_instances'],
        'Time (s)': results_nuclei['mean_time_s'],
        'Memory (GB)': results_nuclei['peak_memory_gb'],
        'Inst/sec': results_nuclei['instances_per_second'],
    },
    {
        'Test': 'Single Tile - Cells',
        'Task': 'Cells',
        'Size': '512√ó512',
        'Instances': results_cells['n_instances'],
        'Time (s)': results_cells['mean_time_s'],
        'Memory (GB)': results_cells['peak_memory_gb'],
        'Inst/sec': results_cells['instances_per_second'],
    },
    {
        'Test': 'Single Tile - DUAL',
        'Task': 'Both (2 runs)',
        'Size': '512√ó512',
        'Instances': results_dual['total_instances'],
        'Time (s)': results_dual['mean_time_s'],
        'Memory (GB)': results_dual['peak_memory_gb'],
        'Inst/sec': results_dual['instances_per_second'],
    },
    {
        'Test': 'WSI - Nuclei',
        'Task': 'Nuclei',
        'Size': '15000√ó15000',
        'Instances': results_wsi_nuclei['n_instances'],
        'Time (s)': results_wsi_nuclei['total_time_s'],
        'Memory (GB)': results_wsi_nuclei['peak_gpu_memory_gb'],
        'Inst/sec': results_wsi_nuclei['instances_per_second'],
    },
    {
        'Test': 'WSI - DUAL',
        'Task': 'Both (2 runs)',
        'Size': '15000√ó15000',
        'Instances': results_wsi_dual['total_instances'],
        'Time (s)': results_wsi_dual['total_time_s'],
        'Memory (GB)': results_wsi_dual['peak_gpu_memory_gb'],
        'Inst/sec': results_wsi_dual['instances_per_second'],
    }
])

print("\n" + summary_df.to_string(index=False))

# ============================================
# 7. SAVE ALL RESULTS
# ============================================

print("\n" + "=" * 70)
print("SAVING RESULTS...")
print("=" * 70)

# Save CSV
csv_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_summary.csv')
summary_df.to_csv(csv_path, index=False)
print(f"‚úì Saved CSV: {csv_path}")

# Save JSON
all_results = {
    'single_tile_nuclei': results_nuclei,
    'single_tile_cells': results_cells,
    'single_tile_dual': results_dual,
    'wsi_nuclei': results_wsi_nuclei,
    'wsi_dual': results_wsi_dual,
    'hardware': {
        'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A',
        'gpu_memory_total_gb': torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0,
    },
    'benchmark_date': time.strftime('%Y-%m-%d %H:%M:%S'),
}

json_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_complete.json')
with open(json_path, 'w') as f:
    json.dump(all_results, f, indent=2)
print(f"‚úì Saved JSON: {json_path}")

# ============================================
# 8. VISUALIZATION
# ============================================

print("\nGenerating plots...")

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Single Tile Comparison
tasks = ['Nuclei\nOnly', 'Cells\nOnly', 'Both\n(2 runs)']
times = [results_nuclei['mean_time_ms'], results_cells['mean_time_ms'], results_dual['mean_time_ms']]
axes[0, 0].bar(tasks, times, color=['steelblue', 'coral', 'red'])
axes[0, 0].set_ylabel('Time (ms)', fontsize=12)
axes[0, 0].set_title('Single Tile Processing Time (512√ó512)', fontsize=14, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3, axis='y')

# Plot 2: WSI Comparison
wsi_tasks = ['Nuclei\nOnly', 'Both\n(2 runs)']
wsi_times = [
    results_wsi_nuclei['total_time_min'],
    results_wsi_dual['total_time_min']
]
axes[0, 1].bar(wsi_tasks, wsi_times, color=['steelblue', 'red'])
axes[0, 1].set_ylabel('Time (minutes)', fontsize=12)
axes[0, 1].set_title('WSI Processing Time (15000√ó15000)', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3, axis='y')

# Plot 3: Throughput
throughputs = [
    results_nuclei['instances_per_second'],
    results_cells['instances_per_second'],
    results_dual['instances_per_second']
]
axes[1, 0].bar(tasks, throughputs, color=['steelblue', 'coral', 'red'])
axes[1, 0].set_ylabel('Instances/Second', fontsize=12)
axes[1, 0].set_title('Segmentation Throughput', fontsize=14, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# Plot 4: Memory Usage
memory_vals = [
    results_nuclei['peak_memory_gb'],
    results_cells['peak_memory_gb'],
    results_dual['peak_memory_gb']
]
axes[1, 1].bar(tasks, memory_vals, color=['steelblue', 'coral', 'red'])
axes[1, 1].set_ylabel('Peak GPU Memory (GB)', fontsize=12)
axes[1, 1].set_title('Memory Usage', fontsize=14, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plot_path = os.path.join(OUTPUT_DIR, 'cellsam_benchmark_plots.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"‚úì Saved plot: {plot_path}")
plt.close()

# ============================================
# 9. LIST SAVED MASK FILES
# ============================================

print("\nüìÅ Saved mask files:")
mask_files = sorted([f for f in os.listdir(MASKS_DIR) if f.endswith('.tif')])
for mf in mask_files:
    size_mb = os.path.getsize(os.path.join(MASKS_DIR, mf)) / (1024**2)
    print(f"  - {mf} ({size_mb:.2f} MB)")

print("\n" + "=" * 70)
print("BENCHMARK COMPLETE!")
print("=" * 70)
print(f"\nüìÅ All results saved to: {os.path.abspath(OUTPUT_DIR)}")
print("\n‚≠ê KEY FINDINGS:")
print(f"  Single Tile - Nuclei: {results_nuclei['mean_time_ms']:.1f} ms")
print(f"  Single Tile - Dual (2 runs): {results_dual['mean_time_ms']:.1f} ms")
print(f"  WSI - Nuclei: {results_wsi_nuclei['total_time_min']:.2f} min")
print(f"  WSI - Dual (2 runs): {results_wsi_dual['total_time_min']:.2f} min")
print(f"  üî¥ CellSAM requires 2 separate runs for nuclei + cells")
print("\nüìä Files created:")
print(f"  - {os.path.basename(csv_path)}")
print(f"  - {os.path.basename(json_path)}")
print(f"  - {os.path.basename(plot_path)}")
print(f"  - masks/ folder with {len(mask_files)} segmentation outputs")
print("=" * 70)

üìÅ Output directory: benchmark_results/cellsam
   Masks will be saved to: benchmark_results/cellsam/masks

Loading CellSAM model...
  ‚úì Model loaded successfully
Tile shape: (512, 512, 3)
WSI shape: (15000, 15000, 3)

CELLSAM BENCHMARK 1: SINGLE-TASK (Nuclei Only, 512√ó512)
Warmup runs...
Running 10 timed iterations...
  Run 1/10: 1.133s, 3.48 GB, Nuclei: 125
  Run 2/10: 1.258s, 3.48 GB, Nuclei: 125
  Run 3/10: 1.293s, 3.48 GB, Nuclei: 125
  Run 4/10: 1.217s, 3.48 GB, Nuclei: 125
  Run 5/10: 1.228s, 3.48 GB, Nuclei: 125
  Run 6/10: 1.205s, 3.48 GB, Nuclei: 125
  Run 7/10: 1.207s, 3.48 GB, Nuclei: 125
  Run 8/10: 1.218s, 3.48 GB, Nuclei: 125
  Run 9/10: 1.239s, 3.48 GB, Nuclei: 125
  Run 10/10: 1.248s, 3.48 GB, Nuclei: 125
Saving nuclei masks...

RESULTS:
  task: Nuclei Only
  image_size: 512√ó512
  n_instances: 125
  mean_time_s: 1.2246
  std_time_s: 0.0397
  mean_time_ms: 1224.5991
  peak_memory_gb: 3.4841
  instances_per_second: 102.0742

CELLSAM BENCHMARK 2: SINGLE-TASK (Cells O

532it [11:59,  1.37it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
560it [12:29,  1.03s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
564it [12:31,  1.82it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
589it [13:01,  1.11s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
590it [13:01,  1.13it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
618it [13:35,  1.07s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
621it [13:37,  1.57it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
648it [14:14,  1.11it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
672it [14:45,  1.07s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
673it [14:45,  1.16it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribu

  Computing segmentation metrics (forces computation)...
  Saving WSI nuclei masks...

RESULTS:
  task: Nuclei Only
  image_size: 15000√ó15000
  n_tiles: 1156
  tile_size: 512√ó512
  n_instances: 60197
  total_time_s: 11441.3129
  total_time_min: 190.6885
  time_per_tile_ms: 9897.3295
  peak_gpu_memory_gb: 3.4858
  instances_per_second: 5.2614
  throughput_mpx_per_min: 1.1799

‚≠ê CELLSAM BENCHMARK 5: WSI DUAL-TASK SIMULATION (15000√ó15000)
Note: Running CellSAM TWICE for nuclei + cells
‚è±Ô∏è  This will take even longer - please be patient...

üî¨ Running nuclei segmentation...
Total blocks: 900


532it [11:55,  1.38it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
560it [12:25,  1.05s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
564it [12:27,  1.78it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
589it [12:57,  1.11s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
590it [12:58,  1.12it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
618it [13:32,  1.08s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
621it [13:33,  1.56it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
648it [14:11,  1.11it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
672it [14:42,  1.06s/it]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribute 'ndim'
673it [14:42,  1.16it/s]ERROR:root:Error segmenting chunk: 'NoneType' object has no attribu

  Computing nuclei metrics...
