# Hippocampal Cell Segmentation with CellSeg3D

This notebook demonstrates how to:
1. Extract hippocampal ROIs using atlas registration (zarrnii)
2. Run 3D cell segmentation using CellSeg3D models
3. Perform instance segmentation to identify individual cells
4. Extract quantitative statistics per region

## Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import tifffile
import pandas as pd

# zarrnii for ROI extraction
from zarrnii import ZarrNii, ZarrNiiAtlas

# CellSeg3D for segmentation
from napari_cellseg3d.predict import inference_on_np3d
from napari_cellseg3d.code_models.instance_segmentation import (
    binary_watershed,
    binary_connected,
    voronoi_otsu,
    volume_stats
)
from napari_cellseg3d.create_model import create_model

from tqdm.notebook import tqdm

## Configuration

In [None]:
# --- Input data ---------------------------------------------------------------
# OME-Zarr lightsheet volume for this subject.
ZARR_PATH = '/nfs/trident3/lightsheet/prado/mouse_app_lecanemab_ki3/bids/sub-AS134F3/micr/sub-AS134F3_sample-brain_acq-imaris4x_SPIM.ome.zarr'
# Atlas label image (dseg) deformed into this subject's space.
ATLAS_PATH = '/nfs/trident3/lightsheet/prado/mouse_app_lecanemab_ki3/derivatives/spimquant_aae813e/sub-AS134F3/micr/sub-AS134F3_sample-brain_acq-imaris4x_seg-all_from-ABAv3_level-5_desc-deform_dseg.nii.gz'
# Label table (TSV) that accompanies the atlas label image.
ATLAS_TSV = '/nfs/trident3/lightsheet/prado/mouse_app_lecanemab_ki3/derivatives/spimquant_aae813e/tpl-ABAv3/seg-all_tpl-ABAv3_dseg.tsv'

# --- Model --------------------------------------------------------------------
# Handed as-is to `inference_on_np3d`.
# NOTE(review): num_classes=2 and out_channels=1 look inconsistent — confirm
# which one the model / inference code actually uses.
MODEL_CONFIG = dict(
    device='cuda',                                       # switch to 'cpu' when no GPU is present
    num_classes=2,                                       # background + cells
    model_weight_path='path/to/your/trained_model.pth',  # UPDATE THIS
    input_brightness_range=None,                         # None = auto-detect from the data
    in_channels=1,
    out_channels=1,
)

# --- What to process ----------------------------------------------------------
CHANNELS = ['Iba1', 'Abeta']

# Output name -> regex passed to `atlas.get_region_bounding_box(regex=...)`.
REGIONS = dict(
    Left_Hippocampus='Left Hipp',
    Right_Hippocampus='Right Hipp',
    Left_CA1='Left.*Field CA1',
    Right_CA1='Right.*Field CA1',
)

# --- Processing parameters ----------------------------------------------------
RESOLUTION_LEVEL = 0   # pyramid level: 0 = full resolution, 2 = intermediate (4x downsampled)
ROI_SIZE = [64] * 3    # sliding-window size used during inference
OUTPUT_DIR = Path('./segmentation_results')
OUTPUT_DIR.mkdir(exist_ok=True)

## Step 1: Load Atlas and Image Data

In [None]:
# Load the subject-space atlas from its label image and label table.
print("Loading atlas...")
atlas = ZarrNiiAtlas.from_files(ATLAS_PATH, ATLAS_TSV)
print("Atlas loaded: {} regions".format(len(atlas.labels_df)))

# Preview the first rows of the label table (rendered as the cell output).
atlas.labels_df.head()

## Step 2: Extract ROIs for Each Region and Channel

In [None]:
def extract_roi(zarr_path, channel, region_regex, atlas, level=0):
    """Crop one channel of the OME-Zarr image to an atlas region's bounding box.

    Args:
        zarr_path: Path to the OME-Zarr image.
        channel: Channel label to load (e.g. 'Iba1').
        region_regex: Regex passed to ``atlas.get_region_bounding_box``.
        atlas: Atlas object supplying the region bounding box.
        level: Pyramid level to read (0 = full resolution).

    Returns:
        ``(roi_data, roi)`` — the cropped volume as a NumPy array with the
        leading channel axis removed, and the cropped ZarrNii object.
    """
    # Read the requested level without isotropic resampling, so voxels keep
    # the original resolution used for segmentation.
    image = ZarrNii.from_ome_zarr(
        zarr_path,
        channel_labels=[channel],
        level=level,
        downsample_near_isotropic=False,
    )

    # The bounding box is applied as RAS coordinates (ras_coords=True).
    region_box = atlas.get_region_bounding_box(regex=region_regex)
    cropped = image.crop_with_bounding_box(*region_box, ras_coords=True)

    # Take the single loaded channel; presumably (Z, Y, X) afterwards.
    volume = np.array(cropped.data[0])

    lowest, highest = volume.min(), volume.max()
    print(f"  Extracted {channel} - {region_regex}: shape={volume.shape}, "
          f"dtype={volume.dtype}, range=[{lowest}, {highest}]")

    return volume, cropped

# Crop every (region, channel) pair. A failure is printed and skipped so one
# bad combination does not stop the remaining extractions.
rois = {}
for region_name, region_regex in REGIONS.items():
    print(f"\nProcessing region: {region_name}")
    region_rois = {}
    rois[region_name] = region_rois

    for channel in CHANNELS:
        try:
            roi_data, roi_obj = extract_roi(
                ZARR_PATH, channel, region_regex, atlas, level=RESOLUTION_LEVEL
            )
        except Exception as e:
            print(f"  ERROR extracting {channel}: {e}")
            continue
        region_rois[channel] = {'data': roi_data, 'roi_obj': roi_obj}

## Step 3: Visualize ROIs (Optional)

In [None]:
def plot_mip(data, title=""):
    """Show maximum-intensity projections of a 3D volume in a 1x3 figure.

    Panels project along axes 0, 1 and 2, labelled Z, Y and X respectively.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, axis_index, axis_name in zip(axes, range(3), "ZYX"):
        ax.imshow(data.max(axis=axis_index), cmap='gray')
        ax.set_title(f"{title} - {axis_name} projection")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show MIPs of every channel that was extracted for the first configured region.
first_region = list(REGIONS)[0]
for channel in CHANNELS:
    if channel not in rois[first_region]:
        continue
    plot_mip(rois[first_region][channel]['data'], f"{first_region} - {channel}")

## Step 4: Run Cell Segmentation

This performs:
1. **Semantic segmentation**: Probability maps from CellSeg3D model
2. **Instance segmentation**: Separate individual cells using watershed

In [None]:
def _foreground_probability(prob_maps):
    """Return the foreground probability volume of the first batch item.

    Accepts a torch tensor (detached, moved to CPU and converted) or anything
    ``np.asarray`` understands. Assumes a (batch, channel, ...) layout —
    TODO confirm against ``inference_on_np3d``.
    """
    if hasattr(prob_maps, "detach"):
        # torch tensor: drop autograd history and move off the device first.
        prob_maps = prob_maps.detach().cpu().numpy()
    prob_maps = np.asarray(prob_maps)
    # Class 1 is foreground for a background+cells output. A single-channel
    # output (e.g. out_channels=1) presumably already is the foreground map.
    fg_channel = 1 if prob_maps.shape[1] > 1 else 0
    return prob_maps[0, fg_channel]


def segment_cells(roi_data, config, roi_size=None, method='watershed'):
    """
    Run full segmentation pipeline on a 3D ROI.

    Args:
        roi_data: 3D numpy array (Z, Y, X)
        config: Model configuration dict, passed to ``inference_on_np3d``.
        roi_size: Sliding window size for inference. Defaults to
            ``[64, 64, 64]`` (a fresh list on every call).
        method: 'watershed', 'connected', or 'voronoi'

    Returns:
        instance_labels: 3D array with unique label for each cell
        prob_map: Foreground probability map

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is
            checked before inference, so a typo does not cost a full model pass.
    """
    if roi_size is None:
        roi_size = [64, 64, 64]
    if method not in ('watershed', 'connected', 'voronoi'):
        raise ValueError(f"Unknown method: {method}")

    print("  Running semantic segmentation (CellSeg3D)...")

    # Semantic segmentation (probability maps)
    prob_maps = inference_on_np3d(config, roi_data, roi_size)
    foreground_prob = _foreground_probability(prob_maps)

    print(f"  Probability map range: [{foreground_prob.min():.3f}, {foreground_prob.max():.3f}]")

    # Instance segmentation
    print(f"  Running instance segmentation ({method})...")

    if method == 'watershed':
        instance_labels = binary_watershed(
            foreground_prob,
            thres_objects=0.3,      # Foreground threshold
            thres_seeding=0.9,      # Seed threshold
            thres_small=30,         # Remove objects < 30 voxels
            rem_seed_thres=3        # Remove seeds < 3 voxels
        )
    elif method == 'connected':
        instance_labels = binary_connected(
            foreground_prob,
            thres=0.8,
            thres_small=30
        )
    else:  # 'voronoi' (validated above)
        instance_labels = voronoi_otsu(
            foreground_prob,
            spot_sigma=2.0,
            outline_sigma=2.0,
            remove_small_size=30
        )

    # Count distinct non-zero labels (0 is background). Unlike
    # len(unique) - 1, this stays correct when no voxel is background and
    # never goes negative on an empty array.
    n_cells = int(np.count_nonzero(np.unique(instance_labels)))
    print(f"  Detected {n_cells} cells")

    return instance_labels, foreground_prob

# Segment every extracted (region, channel) ROI. A failure is printed and
# skipped so the remaining ROIs are still processed.
segmentation_results = {}

for region_name in tqdm(REGIONS.keys(), desc="Regions"):
    print(f"\nSegmenting {region_name}...")
    region_results = segmentation_results.setdefault(region_name, {})
    region_rois = rois[region_name]

    for channel in CHANNELS:
        if channel not in region_rois:
            continue

        print(f"  Channel: {channel}")
        roi_data = region_rois[channel]['data']

        try:
            instance_labels, prob_map = segment_cells(
                roi_data, MODEL_CONFIG, roi_size=ROI_SIZE, method='watershed'
            )
        except Exception as e:
            print(f"  ERROR during segmentation: {e}")
        else:
            region_results[channel] = {
                'instance_labels': instance_labels,
                'prob_map': prob_map,
            }

## Step 5: Extract Quantitative Statistics

In [None]:
# Collect per-cell statistics for every segmented (region, channel) pair.
all_stats = []

for region_name in REGIONS.keys():
    for channel in CHANNELS:
        if channel not in segmentation_results[region_name]:
            continue

        instance_labels = segmentation_results[region_name][channel]['instance_labels']

        print(f"\n{region_name} - {channel}:")
        stats = volume_stats(instance_labels)
        if stats is None:
            continue
        if len(stats.volume) == 0:
            # Nothing detected: skip the summary lines, which would otherwise
            # print NaN means and emit "Mean of empty slice" warnings.
            print("  Total cells: 0")
            continue

        print(f"  Total cells: {stats.number_objects}")
        print(f"  Mean volume: {np.mean(stats.volume):.2f} voxels")
        print(f"  Std volume: {np.std(stats.volume):.2f} voxels")
        # np.ravel accepts a scalar or a one-element sequence, so this works
        # whichever form volume_stats uses for filling_ratio.
        print(f"  Filling ratio: {np.ravel(stats.filling_ratio)[0]:.4f}")

        # One row per detected object.
        # NOTE: cell_id is a 0-based row index within this ROI, not
        # necessarily the label value stored in instance_labels.
        # NOTE(review): centroid_x/y/z mirror the volume_stats field names;
        # confirm their axis order against the (Z, Y, X) ROI layout.
        df = pd.DataFrame({
            'region': region_name,
            'channel': channel,
            'cell_id': range(len(stats.volume)),
            'volume': stats.volume,
            'centroid_x': stats.centroid_x,
            'centroid_y': stats.centroid_y,
            'centroid_z': stats.centroid_z,
            'sphericity': stats.sphericity_ax
        })
        all_stats.append(df)

# Combine all statistics
if all_stats:
    stats_df = pd.concat(all_stats, ignore_index=True)

    # Save to CSV
    stats_csv = OUTPUT_DIR / 'cell_statistics.csv'
    stats_df.to_csv(stats_csv, index=False)
    print(f"\nStatistics saved to: {stats_csv}")

    # Display summary. Named aggregation produces flat, explicitly named
    # columns instead of renaming a MultiIndex by position.
    print("\nSummary by region and channel:")
    summary = stats_df.groupby(['region', 'channel']).agg(
        cell_count=('cell_id', 'count'),
        mean_volume=('volume', 'mean'),
        std_volume=('volume', 'std'),
        mean_sphericity=('sphericity', 'mean'),
    )
    display(summary)
else:
    print("No statistics to save")

## Step 6: Visualize Segmentation Results

In [None]:
def visualize_segmentation(raw, prob_map, instance_labels, title=""):
    """Visualize segmentation results as maximum-intensity projections.

    Draws a 3x3 grid. Rows project along axes 0, 1 and 2 (labelled Z, Y and
    X). Columns show raw intensity, foreground probability (colour range fixed
    to 0-1) and instance labels. The top row also carries ``title`` and the
    number of detected cells.
    """
    # Distinct non-zero labels; correct even if no voxel is background (0).
    n_cells = int(np.count_nonzero(np.unique(instance_labels)))
    _, axes = plt.subplots(3, 3, figsize=(15, 15))

    for axis_index, axis_name in enumerate("ZYX"):
        raw_ax, prob_ax, label_ax = axes[axis_index]

        raw_ax.imshow(raw.max(axis_index), cmap='gray')
        prob_ax.imshow(prob_map.max(axis_index), cmap='hot', vmin=0, vmax=1)
        label_ax.imshow(instance_labels.max(axis_index), cmap='nipy_spectral')

        raw_title = f"Raw ({axis_name} proj)"
        label_title = f"Instance Labels ({axis_name} proj)"
        if axis_index == 0:
            raw_title = f"{title}\n{raw_title}"
            label_title = f"{label_title}\n{n_cells} cells"

        raw_ax.set_title(raw_title)
        prob_ax.set_title(f"Probability Map ({axis_name} proj)")
        label_ax.set_title(label_title)

        for ax in (raw_ax, prob_ax, label_ax):
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Plot each segmented (region, channel) result next to its raw ROI.
for region_name in REGIONS:
    region_results = segmentation_results[region_name]
    for channel in CHANNELS:
        result = region_results.get(channel)
        if result is None:
            continue

        visualize_segmentation(
            rois[region_name][channel]['data'],
            result['prob_map'],
            result['instance_labels'],
            f"{region_name} - {channel}",
        )

## Step 7: Save Segmentation Results

In [None]:
# Write one instance-label TIFF (uint32) and one probability-map TIFF
# (float32) per segmented (region, channel) pair.
for region_name in REGIONS:
    region_results = segmentation_results[region_name]
    for channel in CHANNELS:
        result = region_results.get(channel)
        if result is None:
            continue

        labels_path = OUTPUT_DIR / f"{region_name}_{channel}_instance_labels.tif"
        tifffile.imwrite(labels_path, result['instance_labels'].astype(np.uint32))
        print(f"Saved: {labels_path}")

        prob_path = OUTPUT_DIR / f"{region_name}_{channel}_probability_map.tif"
        tifffile.imwrite(prob_path, result['prob_map'].astype(np.float32))
        print(f"Saved: {prob_path}")

print(f"\nAll results saved to: {OUTPUT_DIR}")

## Step 8: Cell Density Analysis

In [None]:
# Calculate cell density (cells per mm^3) for each segmented (region, channel).
#
# NOTE: the denominator is the volume of the cropped ROI, i.e. the region's
# bounding box, not the atlas region itself. The result is therefore a density
# within the bounding box. A true per-region density would require masking the
# ROI with the atlas label.
#
# Densities depend only on the segmentation results. ROIs where no cells were
# detected are therefore reported with a density of 0 rather than being
# skipped.

# Voxel size must match the pyramid level that was segmented.
# NOTE(review): presumably full-resolution (level 0) spacing — confirm, and
# rescale if RESOLUTION_LEVEL != 0.
VOXEL_SIZE_MM = (0.001625, 0.001625, 0.0022010)  # X, Y, Z in mm from zarrnii_example
voxel_volume_mm3 = float(np.prod(VOXEL_SIZE_MM))

density_stats = []

for region_name in REGIONS.keys():
    for channel in CHANNELS:
        if channel not in segmentation_results[region_name]:
            continue

        instance_labels = segmentation_results[region_name][channel]['instance_labels']

        # Distinct non-zero labels (0 is background). This is robust to label
        # images with no background voxels or with no voxels at all.
        n_cells = int(np.count_nonzero(np.unique(instance_labels)))
        volume_voxels = instance_labels.size
        volume_mm3 = volume_voxels * voxel_volume_mm3
        # An empty ROI has no volume; report NaN instead of dividing by zero.
        density = n_cells / volume_mm3 if volume_mm3 > 0 else float('nan')

        density_stats.append({
            'region': region_name,
            'channel': channel,
            'n_cells': n_cells,
            'volume_mm3': volume_mm3,
            'density_per_mm3': density
        })

if density_stats:
    density_df = pd.DataFrame(density_stats)
    density_csv = OUTPUT_DIR / 'cell_density.csv'
    density_df.to_csv(density_csv, index=False)

    print("\nCell Density Summary:")
    display(density_df)
    print(f"\nDensity statistics saved to: {density_csv}")
else:
    print("No density statistics to save")

## Next Steps

1. **Model Training**: If you don't have a trained model, you'll need to:
   - Create training data (manual annotations)
   - Train a CellSeg3D model on your specific cell type
   - See `train.py` and training notebook examples

2. **Parameter Tuning**: Adjust instance segmentation thresholds based on your data:
   - `thres_objects`: Lower to detect dimmer cells
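   - `thres_seeding`: Probability above which voxels become watershed seeds. Raise it to shrink the seeds, so touching cells are more likely to be split; lowering it can merge neighbouring seeds into one.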
   - `thres_small`: Adjust minimum cell size

3. **Multi-Channel Analysis**: 
   - Co-localization analysis between Iba1+ cells and Abeta plaques
   - Distance measurements
   - Spatial statistics

4. **Batch Processing**: Process multiple subjects in parallel

5. **Quality Control**: Visually inspect results and adjust parameters