# CellPose-SAM Segmentation for ark-analysis

This notebook segments cells using CellPose (with optional SAM integration) as an alternative to DeepCell/Mesmer. It takes the same `deepcell_input` files and saves segmentation masks compatible with the ark-analysis pipeline.

**Prerequisites:**
- CellPose installed: `pip install cellpose`
- For the GUI: `pip install cellpose[gui]`
- For GPU inference: a CUDA-enabled PyTorch build (the `[gui]` extra does not add GPU support)
- Optimized parameters from local CellPose GUI testing

**Workflow:**
1. Load deepcell_input files (2-channel: nuclear + membrane)
2. Segment using CellPose with user-defined parameters
3. Save masks to `cellpose_output` in ark-analysis compatible format

## 0. Setup and Imports

In [None]:
# Install CellPose if needed
# !pip install cellpose
# !pip install cellpose[gui]  # Adds the GUI; GPU use needs a CUDA-enabled PyTorch build

In [None]:
import os
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage import io as skio
from tqdm.auto import tqdm

# CellPose imports
try:
    from cellpose import models, io as cellpose_io
    from cellpose.models import CellposeModel
    CELLPOSE_AVAILABLE = True
    print("CellPose loaded successfully")
except ImportError:
    CELLPOSE_AVAILABLE = False
    print("ERROR: CellPose not installed. Run: pip install cellpose")

# Check GPU availability
try:
    import torch
    GPU_AVAILABLE = torch.cuda.is_available()
    if GPU_AVAILABLE:
        print(f"GPU available: {torch.cuda.get_device_name(0)}")
    else:
        print("GPU not available, using CPU")
except ImportError:
    GPU_AVAILABLE = False
    print("PyTorch not installed, using CPU")

# ark-analysis imports
from alpineer import io_utils
from ark.utils import plot_utils

%matplotlib inline
plt.rcParams['figure.dpi'] = 100
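
Cellpose's Python API has changed between major releases, so it can help to record which versions are installed before running the rest of the notebook. The following is a minimal sketch using only the standard library; the helper names `installed_version` and `major_version` are hypothetical and not part of Cellpose or ark-analysis.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def major_version(version: Optional[str]) -> Optional[int]:
    """Extract the leading integer of a version string such as '3.0.10'."""
    if not version:
        return None
    head = version.split(".")[0]
    return int(head) if head.isdigit() else None


# Report the packages this notebook depends on for segmentation
for pkg in ("cellpose", "torch"):
    print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

Saving this output alongside the masks makes it easier to reproduce a segmentation run later.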

## 1. Configuration

### 1.1 Set Directories

Configure paths to match your data structure.

In [None]:
# =============================================================================
# DIRECTORY CONFIGURATION
# =============================================================================

# Base directory for your project
base_dir = "/path/to/your/project/"

# Input: deepcell_input directory (same format as used for Mesmer)
deepcell_input_dir = os.path.join(base_dir, "segmentation/deepcell_input")

# Output: CellPose segmentation masks
cellpose_output_dir = os.path.join(base_dir, "segmentation/cellpose_output")

# Optional: Visualization output
cellpose_visualization_dir = os.path.join(base_dir, "segmentation/cellpose_visualization")

# Original TIFF directory (for visualization overlay)
tiff_dir = os.path.join(base_dir, "image_data")

In [None]:
# Create output directories
for directory in [cellpose_output_dir, cellpose_visualization_dir]:
    if not os.path.exists(directory):
        os.makedirs(directory)
        print(f"Created: {directory}")

# Validate input directory exists
if os.path.exists(deepcell_input_dir):
    print(f"Input directory found: {deepcell_input_dir}")
else:
    print(f"WARNING: Input directory not found: {deepcell_input_dir}")

### 1.2 CellPose Parameters

**IMPORTANT:** These parameters should be optimized using the CellPose GUI locally before running batch segmentation.

To optimize parameters:
1. Run `python -m cellpose` to open the GUI
2. Load a representative image from your deepcell_input
3. Adjust diameter, flow_threshold, and cellprob_threshold
4. Test on multiple images to ensure robust settings
5. Copy the optimized parameters here

In [None]:
# =============================================================================
# CELLPOSE PARAMETERS - OPTIMIZE THESE LOCALLY FIRST
# =============================================================================

# Model selection
# Options: 'cyto3' (recommended), 'cyto2', 'cyto', 'nuclei', or path to custom model
MODEL_TYPE = "cyto3"

# Cell diameter in pixels
# Set to None for auto-estimation, or specify exact value from GUI optimization
# Tip: Use the GUI to measure typical cell diameters in your images
DIAMETER = None  # e.g., 30.0, 35.0, 40.0

# Flow threshold (0.0 to 1.0)
# Maximum allowed flow error per mask
# Higher = more cells kept, but may include ill-shaped artifacts
# Lower = fewer cells, stricter quality filter
# Default: 0.4
FLOW_THRESHOLD = 0.4

# Cell probability threshold (-6.0 to 6.0)
# Higher = fewer cells (stricter)
# Lower = more cells (more permissive)
# Default: 0.0
CELLPROB_THRESHOLD = 0.0

# Minimum cell size in pixels
# Cells smaller than this will be removed
MIN_SIZE = 15

# Additional options
USE_GPU = True                # Use GPU if available
SEGMENT_NUCLEI = True         # Also segment nuclei separately
RESAMPLE = True               # Run dynamics at original image size (slower, smoother boundaries)
AUGMENT = False               # Test-time augmentation (slower but more robust)
NET_AVG = False               # Average multiple networks (Cellpose 2.x option; not accepted by 3.x)

# Tiling for large images
TILE = True                   # Enable tiling for large images
TILE_OVERLAP = 0.1            # Overlap fraction between tiles
BATCH_SIZE = 8                # Batch size for tile processing
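
A quick sanity check on the values copied from the GUI can catch typos before a long batch run. The sketch below is an assumption-laden illustration: `SegmentationParams` is a hypothetical container, and the ranges it checks are the ones stated in the comments of the cell above, not limits enforced by Cellpose itself.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class SegmentationParams:
    """Hypothetical container mirroring the constants defined in this notebook."""
    diameter: Optional[float] = None
    flow_threshold: float = 0.4
    cellprob_threshold: float = 0.0

    def problems(self) -> List[str]:
        """Return human-readable descriptions of out-of-range values."""
        issues = []
        if self.diameter is not None and self.diameter <= 0:
            issues.append("diameter must be positive or None")
        if not 0.0 <= self.flow_threshold <= 1.0:
            issues.append("flow_threshold should lie in [0, 1]")
        if not -6.0 <= self.cellprob_threshold <= 6.0:
            issues.append("cellprob_threshold should lie in [-6, 6]")
        return issues


# Defaults pass; an obviously wrong value is reported
print(SegmentationParams().problems())
print(SegmentationParams(cellprob_threshold=9.0).problems())
```

In the notebook this could be called as `SegmentationParams(DIAMETER, FLOW_THRESHOLD, CELLPROB_THRESHOLD).problems()`.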

In [None]:
# Print parameter summary
print("CellPose Parameters:")
print("=" * 40)
print(f"Model:                 {MODEL_TYPE}")
print(f"Diameter:              {DIAMETER if DIAMETER else 'auto'}")
print(f"Flow threshold:        {FLOW_THRESHOLD}")
print(f"Cell prob threshold:   {CELLPROB_THRESHOLD}")
print(f"Min size:              {MIN_SIZE}")
print(f"GPU:                   {USE_GPU and GPU_AVAILABLE}")
print(f"Segment nuclei:        {SEGMENT_NUCLEI}")
print(f"Resample:              {RESAMPLE}")
print(f"Augment:               {AUGMENT}")
print(f"Tiling:                {TILE}")

## 2. Discover and Preview Input Files

In [None]:
# Discover input files
input_files = sorted(
    list(Path(deepcell_input_dir).glob("*.tiff")) + 
    list(Path(deepcell_input_dir).glob("*.tif"))
)

fov_names = [f.stem for f in input_files]

print(f"Found {len(input_files)} input files:")
for i, f in enumerate(input_files[:10]):
    print(f"  {i+1}. {f.name}")
if len(input_files) > 10:
    print(f"  ... and {len(input_files) - 10} more")

In [None]:
# Optional: Select specific FOVs to process
# Set to None to process all FOVs
SELECTED_FOVS = None  # e.g., ["fov0", "fov1", "fov2"]

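if SELECTED_FOVS:
    input_files = [f for f in input_files if f.stem in SELECTED_FOVS]
    # Keep fov_names in sync so later steps (e.g. the cell table) use the same subset
    fov_names = [f.stem for f in input_files]
    print(f"Processing {len(input_files)} selected FOVs")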
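
A misspelled name in `SELECTED_FOVS` is silently dropped by the filter above. As a sketch of one way to surface that, the hypothetical helper `select_fovs` below returns both the matching files and the requested names that matched nothing; the file names in the demo are made up.

```python
from pathlib import Path
from typing import Iterable, List, Optional, Tuple


def select_fovs(files: Iterable[Path],
                selected: Optional[Iterable[str]]) -> Tuple[List[Path], List[str]]:
    """Filter `files` by stem; also return requested names that matched nothing."""
    files = list(files)
    if not selected:
        # None or an empty selection means "process everything"
        return files, []
    wanted = list(selected)
    stems = {f.stem for f in files}
    kept = [f for f in files if f.stem in set(wanted)]
    missing = [name for name in wanted if name not in stems]
    return kept, missing


demo_files = [Path("fov0.tiff"), Path("fov1.tiff"), Path("fov2.tif")]
kept, missing = select_fovs(demo_files, ["fov1", "fov9"])
print([f.name for f in kept], missing)  # ['fov1.tiff'] ['fov9']
```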

In [None]:
# Preview a sample input image
if input_files:
    sample_img = skio.imread(str(input_files[0]))
    print(f"Sample image shape: {sample_img.shape}")
    print(f"Sample image dtype: {sample_img.dtype}")
    
    # Display channels
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    if sample_img.ndim == 3 and sample_img.shape[0] == 2:
        # Channels-first format
        axes[0].imshow(sample_img[0], cmap='gray')
        axes[0].set_title('Channel 0: Nuclear')
        axes[1].imshow(sample_img[1], cmap='gray')
        axes[1].set_title('Channel 1: Membrane')
    elif sample_img.ndim == 3 and sample_img.shape[2] == 2:
        # Channels-last format
        axes[0].imshow(sample_img[:,:,0], cmap='gray')
        axes[0].set_title('Channel 0: Nuclear')
        axes[1].imshow(sample_img[:,:,1], cmap='gray')
        axes[1].set_title('Channel 1: Membrane')
    
    for ax in axes:
        ax.axis('off')
    
    plt.suptitle(f"Sample: {input_files[0].name}")
    plt.tight_layout()
    plt.show()
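
The preview above silently shows empty axes when the image is neither channels-first nor channels-last with two channels. A small shape check makes that case explicit. This is a sketch only: `channel_axis` is a hypothetical helper, and it tests the first axis before the last, in the same order as the preview cell.

```python
from typing import Tuple


def channel_axis(shape: Tuple[int, ...], n_channels: int = 2) -> int:
    """Return 0 for channels-first or 2 for channels-last, given a 3D image shape.

    Raises ValueError when the shape is not 3D or no axis has `n_channels` entries.
    """
    if len(shape) != 3:
        raise ValueError(f"expected a 3D array, got shape {shape}")
    if shape[0] == n_channels:
        return 0
    if shape[2] == n_channels:
        return 2
    raise ValueError(f"no axis of length {n_channels} in shape {shape}")


print(channel_axis((2, 1024, 1024)))  # 0
print(channel_axis((1024, 1024, 2)))  # 2
```

In the notebook, `channel_axis(sample_img.shape)` would raise a clear error for an unexpected layout instead of producing a blank figure.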

## 3. Initialize CellPose Model

In [None]:
# Initialize the CellPose model
print(f"Loading CellPose model: {MODEL_TYPE}")

use_gpu = USE_GPU and GPU_AVAILABLE

# Load model
if os.path.exists(MODEL_TYPE):
    # Custom model from file
    print(f"Loading custom model from: {MODEL_TYPE}")
    model = models.CellposeModel(
        pretrained_model=MODEL_TYPE,
        gpu=use_gpu
    )
else:
    # Built-in model
    # net_avg is not passed here: Cellpose 3.x (which provides cyto3) does not accept it
    model = models.Cellpose(
        model_type=MODEL_TYPE,
        gpu=use_gpu
    )

print(f"Model loaded. GPU: {use_gpu}")

# Also load nuclei model if segmenting nuclei separately
if SEGMENT_NUCLEI:
    nuclei_model = models.Cellpose(
        model_type='nuclei',
        gpu=use_gpu
    )
    print("Nuclei model loaded.")

## 4. Test Segmentation on Single Image

Before running batch processing, test on a single image to verify parameters.

In [None]:
def segment_single_image(image, model, diameter, flow_threshold, cellprob_threshold, 
                         min_size, resample=True, augment=False, tile=True, 
                         tile_overlap=0.1, batch_size=8):
    """
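    Segment a single deepcell_input format image.

    Expects a 2-channel image (nuclear, membrane), channels-first or
    channels-last, and returns (masks, flows, styles) from CellPose.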
    """
    # Handle channels-first format
    if image.ndim == 3 and image.shape[0] == 2:
        # Convert to channels-last for CellPose
        image_cp = np.moveaxis(image, 0, -1)
    else:
        image_cp = image
    
    # CellPose channels: [cytoplasm, nucleus]
    # DeepCell format: ch0=nuclear, ch1=membrane
    # So: cytoplasm=membrane(ch1), nucleus=nuclear(ch0)
    channels = [2, 1]  # [membrane channel, nuclear channel]
    
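    # Run segmentation
    # models.Cellpose.eval returns (masks, flows, styles, diams), while
    # models.CellposeModel.eval returns (masks, flows, styles); keep the first three
    outputs = model.eval(
        image_cp,
        diameter=diameter,
        channels=channels,
        flow_threshold=flow_threshold,
        cellprob_threshold=cellprob_threshold,
        min_size=min_size,
        resample=resample,
        augment=augment,
        tile=tile,
        tile_overlap=tile_overlap,
        batch_size=batch_size,
    )
    masks, flows, styles = outputs[:3]
    
    return masks, flows, styles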
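
The `channels = [2, 1]` setting in the function above follows the Cellpose 2.x/3.x convention in which the pair is `[cytoplasm, nucleus]`, channel positions are counted from 1, and 0 is reserved (grayscale for the first entry, "no nucleus channel" for the second). The sketch below spells out that mapping; `cellpose_channels` is a hypothetical helper, intended for multi-channel arrays only.

```python
from typing import List, Optional


def cellpose_channels(cyto_index: int, nucleus_index: Optional[int] = None) -> List[int]:
    """Convert zero-based channel positions to a Cellpose `channels` pair.

    Cellpose counts image channels from 1 and reserves 0, so position 0
    becomes 1, position 1 becomes 2, and a missing nucleus channel becomes 0.
    """
    if cyto_index < 0 or (nucleus_index is not None and nucleus_index < 0):
        raise ValueError("channel positions must be non-negative")
    nucleus = 0 if nucleus_index is None else nucleus_index + 1
    return [cyto_index + 1, nucleus]


# deepcell_input layout: position 0 = nuclear, position 1 = membrane
print(cellpose_channels(cyto_index=1, nucleus_index=0))  # [2, 1]
```

For a single 2D grayscale image, such as the nuclear channel passed to the nuclei model later in this notebook, `[0, 0]` is used instead.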

In [None]:
# Test on first image
if input_files:
    test_file = input_files[0]
    print(f"Testing segmentation on: {test_file.name}")
    
    # Load image
    test_img = skio.imread(str(test_file))
    
    # Segment
    test_masks, test_flows, test_styles = segment_single_image(
        test_img, model, DIAMETER, FLOW_THRESHOLD, CELLPROB_THRESHOLD,
        MIN_SIZE, RESAMPLE, AUGMENT, TILE, TILE_OVERLAP, BATCH_SIZE
    )
    
    n_cells = len(np.unique(test_masks)) - 1  # Exclude background
    print(f"Segmented {n_cells} cells")

In [None]:
# Visualize test segmentation
if 'test_masks' in dir():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image (composite)
    if test_img.ndim == 3 and test_img.shape[0] == 2:
        composite = np.zeros((*test_img.shape[1:], 3), dtype=np.float32)
        # Small epsilon avoids division by zero for an empty channel
        composite[:,:,0] = test_img[1] / (test_img[1].max() + 1e-6)  # Membrane = red
        composite[:,:,2] = test_img[0] / (test_img[0].max() + 1e-6)  # Nuclear = blue
    else:
        composite = test_img
    
    axes[0].imshow(composite)
    axes[0].set_title('Input (Red=Membrane, Blue=Nuclear)')
    axes[0].axis('off')
    
    # Segmentation mask
    axes[1].imshow(test_masks, cmap='nipy_spectral')
    axes[1].set_title(f'Segmentation ({n_cells} cells)')
    axes[1].axis('off')
    
    # Overlay
    from cellpose import plot
    overlay = plot.mask_overlay(composite, test_masks)
    axes[2].imshow(overlay)
    axes[2].set_title('Overlay')
    axes[2].axis('off')
    
    plt.suptitle(f"Test Segmentation: {test_file.name}")
    plt.tight_layout()
    plt.show()

### 4.1 Adjust Parameters if Needed

If the segmentation doesn't look good, adjust the parameters in Section 1.2 and re-run.

In [None]:
# Quick parameter testing cell - modify and re-run as needed
# Uncomment and adjust to test different parameters

# TEST_DIAMETER = 35.0
# TEST_FLOW_THRESHOLD = 0.6
# TEST_CELLPROB_THRESHOLD = -1.0

# test_masks_v2, _, _ = segment_single_image(
#     test_img, model, TEST_DIAMETER, TEST_FLOW_THRESHOLD, TEST_CELLPROB_THRESHOLD,
#     MIN_SIZE, RESAMPLE, AUGMENT, TILE, TILE_OVERLAP, BATCH_SIZE
# )
# print(f"Test v2: {len(np.unique(test_masks_v2)) - 1} cells")
# plt.figure(figsize=(8, 8))
# plt.imshow(test_masks_v2, cmap='nipy_spectral')
# plt.title(f"Test: d={TEST_DIAMETER}, ft={TEST_FLOW_THRESHOLD}, cpt={TEST_CELLPROB_THRESHOLD}")
# plt.axis('off')
# plt.show()
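
When several candidate values need comparing, a small grid sweep can replace repeated manual edits of the cell above. The sketch below shows only the looping pattern: `sweep` is a hypothetical helper, and `fake_count` is a stand-in that returns made-up numbers so the example runs without Cellpose.

```python
from itertools import product
from typing import Callable, Dict, Iterable, List


def sweep(count_cells: Callable[[float, float], int],
          diameters: Iterable[float],
          flow_thresholds: Iterable[float]) -> List[Dict[str, float]]:
    """Call `count_cells(diameter, flow_threshold)` for every combination."""
    rows = []
    for diameter, flow in product(diameters, flow_thresholds):
        rows.append({
            "diameter": diameter,
            "flow_threshold": flow,
            "n_cells": count_cells(diameter, flow),
        })
    return rows


def fake_count(diameter: float, flow_threshold: float) -> int:
    """Stand-in for a real segmentation call; returns a made-up count."""
    return int(1000 / diameter + 100 * flow_threshold)


for row in sweep(fake_count, [30.0, 40.0], [0.4, 0.6]):
    print(row)
```

In the notebook, `fake_count` would be replaced by a function that calls `segment_single_image` on `test_img` with the given diameter and flow threshold and returns the number of labels in the resulting mask.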

## 5. Run Batch Segmentation

In [None]:
def save_segmentation_masks(whole_cell_mask, nuclear_mask, output_dir, fov_name):
    """
    Save segmentation masks in ark-analysis compatible format.
    """
    # Save whole cell mask
    wc_path = os.path.join(output_dir, f"{fov_name}_whole_cell.tiff")
    skio.imsave(wc_path, whole_cell_mask.astype(np.int32), check_contrast=False)
    
    # Save nuclear mask if provided
    if nuclear_mask is not None:
        nuc_path = os.path.join(output_dir, f"{fov_name}_nuclear.tiff")
        skio.imsave(nuc_path, nuclear_mask.astype(np.int32), check_contrast=False)
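
The naming scheme used by `save_segmentation_masks` (`<fov>_whole_cell.tiff` and `<fov>_nuclear.tiff`) can be checked after a run to confirm that every FOV produced its files. The sketch below uses only the standard library; `mask_paths` and `missing_masks` are hypothetical helpers, and the demo creates an empty placeholder file in a temporary directory rather than a real mask.

```python
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List


def mask_paths(output_dir: str, fov_name: str) -> Dict[str, Path]:
    """Paths that save_segmentation_masks writes for one FOV."""
    out = Path(output_dir)
    return {
        "whole_cell": out / f"{fov_name}_whole_cell.tiff",
        "nuclear": out / f"{fov_name}_nuclear.tiff",
    }


def missing_masks(output_dir: str, fov_names: Iterable[str],
                  require_nuclear: bool = True) -> List[Path]:
    """List expected mask files that do not exist on disk."""
    missing = []
    for fov in fov_names:
        paths = mask_paths(output_dir, fov)
        wanted = [paths["whole_cell"]]
        if require_nuclear:
            wanted.append(paths["nuclear"])
        missing.extend(p for p in wanted if not p.exists())
    return missing


# Demo: only the whole-cell file exists, so the nuclear one is reported
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "fov0_whole_cell.tiff").touch()
    demo_missing = [p.name for p in missing_masks(tmp, ["fov0"])]

print(demo_missing)  # ['fov0_nuclear.tiff']
```

After batch segmentation, `missing_masks(cellpose_output_dir, fov_names, require_nuclear=SEGMENT_NUCLEI)` should return an empty list.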

In [None]:
# Run batch segmentation
print("\n" + "="*60)
print("Starting Batch Segmentation")
print("="*60)
print(f"Processing {len(input_files)} images...")
print(f"Output directory: {cellpose_output_dir}\n")

results = []

for input_file in tqdm(input_files, desc="Segmenting"):
    fov_name = input_file.stem
    
    try:
        # Load image
        image = skio.imread(str(input_file))
        
        # Segment whole cells
        whole_cell_mask, flows, styles = segment_single_image(
            image, model, DIAMETER, FLOW_THRESHOLD, CELLPROB_THRESHOLD,
            MIN_SIZE, RESAMPLE, AUGMENT, TILE, TILE_OVERLAP, BATCH_SIZE
        )
        
        # Segment nuclei separately if requested
        nuclear_mask = None
        if SEGMENT_NUCLEI:
            # Extract nuclear channel
            if image.ndim == 3 and image.shape[0] == 2:
                nuc_img = image[0]
            else:
                nuc_img = image[:,:,0]
            
            # models.Cellpose.eval returns four values; index [0] takes the masks
            nuclear_mask = nuclei_model.eval(
                nuc_img,
                diameter=DIAMETER,
                channels=[0, 0],
                flow_threshold=FLOW_THRESHOLD,
                cellprob_threshold=CELLPROB_THRESHOLD,
                min_size=MIN_SIZE,
            )[0]
        
        # Save masks
        save_segmentation_masks(
            whole_cell_mask, nuclear_mask, cellpose_output_dir, fov_name
        )
        
        n_cells = len(np.unique(whole_cell_mask)) - 1
        n_nuclei = len(np.unique(nuclear_mask)) - 1 if nuclear_mask is not None else 0
        
        results.append({
            'fov': fov_name,
            'n_cells': n_cells,
            'n_nuclei': n_nuclei,
            'status': 'success'
        })
        
    except Exception as e:
        print(f"\nError processing {fov_name}: {e}")
        results.append({
            'fov': fov_name,
            'n_cells': 0,
            'n_nuclei': 0,
            'status': f'error: {str(e)}'
        })

print("\nSegmentation complete!")

In [None]:
# Create results summary
results_df = pd.DataFrame(results)

print("\n" + "="*60)
print("SEGMENTATION SUMMARY")
print("="*60)

successful = results_df[results_df['status'] == 'success']
print(f"\nProcessed: {len(successful)}/{len(results_df)} images successfully")

if len(successful) > 0:
    print(f"Total cells: {successful['n_cells'].sum()}")
    print(f"Average cells per FOV: {successful['n_cells'].mean():.1f}")
    print(f"Min cells: {successful['n_cells'].min()}")
    print(f"Max cells: {successful['n_cells'].max()}")
    
    if SEGMENT_NUCLEI:
        print(f"\nTotal nuclei: {successful['n_nuclei'].sum()}")
        print(f"Average nuclei per FOV: {successful['n_nuclei'].mean():.1f}")

# Show failed FOVs if any
failed = results_df[results_df['status'] != 'success']
if len(failed) > 0:
    print(f"\nFailed FOVs ({len(failed)}):")
    for _, row in failed.iterrows():
        print(f"  - {row['fov']}: {row['status']}")

In [None]:
# Display results table
results_df

In [None]:
# Save results summary
summary_path = os.path.join(cellpose_output_dir, "segmentation_summary.csv")
results_df.to_csv(summary_path, index=False)
print(f"Saved summary to: {summary_path}")
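
The saved summary can also drive a re-run of only the FOVs that failed. The following sketch reads a file with the same columns as `segmentation_summary.csv` using the standard `csv` module; `failed_fovs` is a hypothetical helper, and the rows in the demo are made up.

```python
import csv
import os
import tempfile
from typing import List


def failed_fovs(summary_csv: str) -> List[str]:
    """Return FOV names whose status column is anything other than 'success'."""
    with open(summary_csv, newline="") as handle:
        return [row["fov"] for row in csv.DictReader(handle)
                if row["status"] != "success"]


# Demo with a made-up summary file in the same column layout
with tempfile.TemporaryDirectory() as tmp:
    demo_csv = os.path.join(tmp, "segmentation_summary.csv")
    with open(demo_csv, "w", newline="") as handle:
        writer = csv.DictWriter(
            handle, fieldnames=["fov", "n_cells", "n_nuclei", "status"])
        writer.writeheader()
        writer.writerow({"fov": "fov0", "n_cells": 812,
                         "n_nuclei": 790, "status": "success"})
        writer.writerow({"fov": "fov1", "n_cells": 0,
                         "n_nuclei": 0, "status": "error: out of memory"})
    retry = failed_fovs(demo_csv)

print(retry)  # ['fov1']
```

The returned list can be assigned to `SELECTED_FOVS` in Section 2 to process only those FOVs again.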

## 6. Visualize Results

In [None]:
# Plot cell count distribution
if len(successful) > 0:
    fig, ax = plt.subplots(figsize=(10, 4))
    
    ax.bar(range(len(successful)), successful['n_cells'].values, color='steelblue', alpha=0.7)
    ax.axhline(successful['n_cells'].mean(), color='red', linestyle='--', label=f"Mean: {successful['n_cells'].mean():.1f}")
    ax.set_xlabel('FOV Index')
    ax.set_ylabel('Number of Cells')
    ax.set_title('Cells Segmented per FOV')
    ax.legend()
    plt.tight_layout()
    plt.show()
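
Beyond the bar chart, FOVs with unusually low counts can be flagged for manual review. The sketch below uses a simple rule, a count below half the median, which is an arbitrary assumption rather than an ark-analysis or Cellpose recommendation; `flag_low_count_fovs` is a hypothetical helper and the counts are made up.

```python
from statistics import median
from typing import Dict, List


def flag_low_count_fovs(counts: Dict[str, int], fraction: float = 0.5) -> List[str]:
    """Return FOVs whose cell count is below `fraction` of the median count."""
    if not counts:
        return []
    cutoff = fraction * median(counts.values())
    return sorted(fov for fov, n in counts.items() if n < cutoff)


demo_counts = {"fov0": 950, "fov1": 1020, "fov2": 310, "fov3": 880}
print(flag_low_count_fovs(demo_counts))  # ['fov2']
```

In the notebook, the input could be built with `dict(zip(successful['fov'], successful['n_cells']))`.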

In [None]:
# Visualize segmentation for a few FOVs
n_display = min(4, len(successful))

if n_display > 0:
    fig, axes = plt.subplots(n_display, 3, figsize=(15, 4*n_display))
    if n_display == 1:
        axes = axes.reshape(1, -1)
    
    for i in range(n_display):
        fov_name = successful.iloc[i]['fov']
        
        # Load input image
        input_path = os.path.join(deepcell_input_dir, f"{fov_name}.tiff")
        if not os.path.exists(input_path):
            input_path = os.path.join(deepcell_input_dir, f"{fov_name}.tif")
        img = skio.imread(input_path)
        
        # Load segmentation mask
        mask_path = os.path.join(cellpose_output_dir, f"{fov_name}_whole_cell.tiff")
        mask = skio.imread(mask_path)
        
        # Create composite
        if img.ndim == 3 and img.shape[0] == 2:
            composite = np.zeros((*img.shape[1:], 3), dtype=np.float32)
            composite[:,:,0] = img[1] / (img[1].max() + 1e-6)
            composite[:,:,2] = img[0] / (img[0].max() + 1e-6)
        else:
            composite = img
        
        axes[i, 0].imshow(composite)
        axes[i, 0].set_title(f'{fov_name}: Input')
        axes[i, 0].axis('off')
        
        axes[i, 1].imshow(mask, cmap='nipy_spectral')
        axes[i, 1].set_title(f'{successful.iloc[i]["n_cells"]} cells')
        axes[i, 1].axis('off')
        
        # Overlay
        from cellpose import plot
        overlay = plot.mask_overlay(composite, mask)
        axes[i, 2].imshow(overlay)
        axes[i, 2].set_title('Overlay')
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

## 7. Save Visualization Overlays (Optional)

In [None]:
# Save overlay visualizations for all FOVs
SAVE_VISUALIZATIONS = True  # Set to False to skip

if SAVE_VISUALIZATIONS:
    print("Saving visualization overlays...")
    
    from cellpose import plot
    
    for _, row in tqdm(successful.iterrows(), total=len(successful), desc="Saving overlays"):
        fov_name = row['fov']
        
        # Load input and mask
        input_path = os.path.join(deepcell_input_dir, f"{fov_name}.tiff")
        if not os.path.exists(input_path):
            input_path = os.path.join(deepcell_input_dir, f"{fov_name}.tif")
        
        img = skio.imread(input_path)
        mask = skio.imread(os.path.join(cellpose_output_dir, f"{fov_name}_whole_cell.tiff"))
        
        # Create composite
        if img.ndim == 3 and img.shape[0] == 2:
            composite = np.zeros((*img.shape[1:], 3), dtype=np.float32)
            composite[:,:,0] = img[1] / (img[1].max() + 1e-6)
            composite[:,:,2] = img[0] / (img[0].max() + 1e-6)
        else:
            composite = img
        
        # Create and save overlay
        overlay = plot.mask_overlay(composite, mask)
        overlay_path = os.path.join(cellpose_visualization_dir, f"{fov_name}_overlay.png")
        plt.imsave(overlay_path, overlay)
    
    print(f"Saved overlays to: {cellpose_visualization_dir}")

## 8. Next Steps

The segmentation masks are now saved in `cellpose_output` in the same format as DeepCell output. You can proceed with the rest of the ark-analysis pipeline:

1. **Extract cell table:** Use `marker_quantification.generate_cell_table()` with `segmentation_dir=cellpose_output_dir`
2. **Pixie clustering:** Proceed to Pixie notebooks for cell phenotyping
3. **Spatial analysis:** Run neighborhood and spatial enrichment analysis

In [None]:
# Example: Generate cell table using CellPose segmentation
# Uncomment to run

# from ark.segmentation import marker_quantification

# cell_table_size_normalized, cell_table_arcsinh_transformed = \
#     marker_quantification.generate_cell_table(
#         segmentation_dir=cellpose_output_dir,
#         tiff_dir=tiff_dir,
#         img_sub_folder=None,
#         fovs=fov_names,
#         batch_size=4,
#         nuclear_counts=SEGMENT_NUCLEI,
#         fast_extraction=False
#     )

# # Save cell tables
# cell_table_dir = os.path.join(base_dir, "segmentation/cell_table")
# os.makedirs(cell_table_dir, exist_ok=True)
# cell_table_size_normalized.to_csv(os.path.join(cell_table_dir, 'cell_table_size_normalized.csv'), index=False)
# cell_table_arcsinh_transformed.to_csv(os.path.join(cell_table_dir, 'cell_table_arcsinh_transformed.csv'), index=False)

---

## Parameters Reference

| Parameter | Default | Description |
|-----------|---------|-------------|
| `MODEL_TYPE` | cyto3 | CellPose model (cyto3, cyto2, cyto, nuclei, or custom path) |
| `DIAMETER` | None | Cell diameter in pixels (None=auto) |
| `FLOW_THRESHOLD` | 0.4 | Maximum flow error per mask (0-1, higher=more cells, less strict) |
| `CELLPROB_THRESHOLD` | 0.0 | Cell probability threshold (-6 to 6, higher=fewer cells) |
| `MIN_SIZE` | 15 | Minimum cell size in pixels |
| `RESAMPLE` | True | Resample images for accuracy |
| `AUGMENT` | False | Test-time augmentation (slower) |
| `TILE` | True | Enable tiling for large images |

**Tip:** Use the CellPose GUI (`python -m cellpose`) to find optimal parameters for your data.