# Batch Nuclei and Cytoplasm Membrane Segmentation (Distributed, Xarray-based)

This notebook processes large microscopy images in a user-specified folder using distributed, blockwise segmentation with Cellpose and xarray. It segments nuclei and cytoplasm, relabels nuclei to match cytoplasm, extracts membrane regions, and reports per-file progress with tqdm progress bars.

**Workflow:**
1. Import required libraries
2. Get user input for folder and file extension
3. List and filter files
4. For each file, load as xarray (chunked)
5. Segment nuclei and cytoplasm using distributed segmentation
6. Relabel nuclei to match cytoplasm
7. Apply erosion to both label sets (distributed)
8. Extract membrane regions (distributed)
9. Save and display results


In [None]:
"""
Import required libraries for distributed, xarray-based image processing and segmentation.
"""
import os
import glob
from pathlib import Path
import numpy as np
from tqdm import tqdm
import xarray as xr
import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from cellpose import models
from cellpose.contrib import distributed_segmentation as ds
import skimage.morphology
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt
import ipywidgets as widgets

# Import the display *function*. `__import__("IPython.display")` returns the
# top-level `IPython` package, so taking `.display` from it yields the
# submodule (not a callable) and `display(ui)` would raise TypeError.
from IPython.display import display


In [None]:
"""
Collect the input folder and file extension from the user with ipywidgets.

Clicking "Confirm" stores the (whitespace-stripped) widget values in the
module-level ``user_inputs`` dict under the keys ``"folder"`` and ``"ext"``.
"""
folder_widget = widgets.Text(
    value="", placeholder="Enter folder path", description="Folder:", disabled=False
)
ext_widget = widgets.Text(
    value=".tif", placeholder="e.g. .tif", description="Extension:", disabled=False
)
button = widgets.Button(description="Confirm")
# Messages printed from a widget callback are only reliably shown when they are
# captured by an Output widget that is part of the displayed UI.
feedback = widgets.Output()
ui = widgets.VBox([folder_widget, ext_widget, button, feedback])
display(ui)

user_inputs = {}


def on_button_clicked(b):
    """
    Store the current widget values in ``user_inputs`` and echo them.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button (passed by ipywidgets; not used).
    """
    # Strip stray whitespace so a pasted path such as "/data/run1 " still works.
    user_inputs["folder"] = folder_widget.value.strip()
    user_inputs["ext"] = ext_widget.value.strip()
    feedback.clear_output()
    with feedback:
        print(f"Selected folder: {user_inputs['folder']}")
        print(f"File extension: {user_inputs['ext']}")
        if not user_inputs["folder"] or not user_inputs["ext"]:
            print("Both a folder and an extension are required before continuing.")


button.on_click(on_button_clicked)


In [None]:
"""
List the files in the chosen folder whose names end with the chosen extension.

The "Confirm" button of the input cell must have been clicked before this cell
runs; otherwise a RuntimeError explains what to do.
"""
import glob
import os

# Do not busy-wait for the button here: while a cell is executing, a standard
# Jupyter kernel does not process widget events, so the click callback could
# never fill `user_inputs` and the loop would spin forever. Fail fast instead.
folder = user_inputs.get("folder")
ext = user_inputs.get("ext")
if not folder or not ext:
    raise RuntimeError(
        "Folder and extension are not set: fill in the widgets above, click "
        "'Confirm', then re-run this cell."
    )
if not os.path.isdir(folder):
    raise NotADirectoryError(f"'{folder}' does not exist or is not a directory.")

# Escape the folder so characters like '[' or '*' in the path are matched
# literally; the extension is left as typed.
file_list = sorted(glob.glob(os.path.join(glob.escape(folder), f"*{ext}")))
print(f"Found {len(file_list)} files with extension '{ext}' in '{folder}'.")


In [None]:
"""
Open every listed file with bioio (Bio-Formats reader) and keep its xarray view.

Each entry of ``xr_images`` is a dict with the source path under ``"filename"``
and the image under ``"xarr"``. No chunking is applied in this cell.
"""
from bioio import BioImage
import bioio_bioformats

xr_images = [
    {
        "filename": path,
        # NOTE(review): `xarray_data` presumably reads the whole image into
        # memory (bioio also offers a lazy `xarray_dask_data`) - confirm this
        # is intended for very large files.
        "xarr": BioImage(path, reader=bioio_bioformats.Reader).xarray_data,
    }
    for path in tqdm(file_list, desc="Loading images as xarray (bioio)")
]
print(f"Loaded {len(xr_images)} images as xarray.DataArray objects.")


In [None]:
"""
Run the Cellpose "nuclei" model on every loaded image via distributed_eval.

Channel 0 is used as the nuclei channel (``channels=[0, 0]``). One result dict
per image is appended to ``nuclei_results`` with the keys ``"filename"`` and
``"nuclei_seg"``.
"""
from cellpose.contrib.distributed_segmentation import myLocalCluster

nuclei_results = []
for entry in tqdm(xr_images, desc="Distributed nuclei segmentation"):
    # A fresh local cluster (4 CPUs; adjust as needed) is started per image and
    # torn down when the `with` block exits.
    with myLocalCluster(ncpus=4) as local_cluster:
        # NOTE(review): confirm that the installed cellpose version accepts the
        # `input_xr` keyword and returns an xarray object here; the upstream
        # documentation should be checked for the exact signature.
        labels_xr, _ = ds.distributed_eval(
            input_xr=entry["xarr"],
            write_path=None,  # keep the result in memory
            model_kwargs={"model_type": "nuclei", "gpu": True},
            eval_kwargs={"channels": [0, 0], "diameter": 30},
            cluster=local_cluster,
        )
    nuclei_results.append({"filename": entry["filename"], "nuclei_seg": labels_xr})


In [None]:
"""
Run the Cellpose "cyto" model on every loaded image via distributed_eval.

``channels=[0, 1]`` is passed to the model. One result dict per image is
appended to ``cyto_results`` with the keys ``"filename"`` and ``"cyto_seg"``.
"""
cyto_results = []
for entry in tqdm(xr_images, desc="Distributed cyto segmentation"):
    # A fresh local cluster (4 CPUs; adjust as needed) is started per image and
    # torn down when the `with` block exits.
    with myLocalCluster(ncpus=4) as local_cluster:
        # NOTE(review): confirm that the installed cellpose version accepts the
        # `input_xr` keyword and returns an xarray object here; the upstream
        # documentation should be checked for the exact signature.
        labels_xr, _ = ds.distributed_eval(
            input_xr=entry["xarr"],
            write_path=None,  # keep the result in memory
            model_kwargs={"model_type": "cyto", "gpu": True},
            eval_kwargs={"channels": [0, 1], "diameter": 30},
            cluster=local_cluster,
        )
    cyto_results.append({"filename": entry["filename"], "cyto_seg": labels_xr})


In [None]:
"""
Relabel every nucleus with the label of the cytoplasm object it overlaps most.

Both label images are pulled into memory as NumPy arrays (``.values``); the
overlap statistics are computed with vectorised NumPy operations, so the cell
works for label arrays of any dimensionality.
"""


def _majority_overlap_map(nuclei, cyto):
    """
    Build a lookup table from nucleus label to best-matching cytoplasm label.

    Parameters
    ----------
    nuclei : numpy.ndarray
        Integer label image of nuclei (0 = background).
    cyto : numpy.ndarray
        Integer label image of cytoplasm, same shape as ``nuclei``.

    Returns
    -------
    numpy.ndarray
        1-D array ``m`` of length ``nuclei.max() + 1`` where ``m[k]`` is the
        non-zero cytoplasm label covering most pixels of nucleus ``k``, or 0
        if nucleus ``k`` touches no cytoplasm label. Ties go to the smallest
        cytoplasm label.

    Raises
    ------
    ValueError
        If the two label images do not have the same shape.
    """
    nuclei = np.asarray(nuclei)
    cyto = np.asarray(cyto)
    if nuclei.shape != cyto.shape:
        raise ValueError(
            f"Label images differ in shape: {nuclei.shape} vs {cyto.shape}"
        )

    # Keep the cytoplasm dtype so large label values are not truncated.
    map_dtype = cyto.dtype if np.issubdtype(cyto.dtype, np.integer) else np.int64
    table_size = int(nuclei.max()) + 1 if nuclei.size else 1
    relabel_map = np.zeros(table_size, dtype=map_dtype)

    # Only pixels that belong to a nucleus AND a cytoplasm object can vote.
    overlap = (nuclei > 0) & (cyto > 0)
    if not overlap.any():
        return relabel_map

    # Count pixels for every distinct (nucleus label, cyto label) pair.
    pairs = np.stack(
        [nuclei[overlap].astype(np.int64), cyto[overlap].astype(np.int64)], axis=1
    )
    unique_pairs, counts = np.unique(pairs, axis=0, return_counts=True)
    nuc_ids = unique_pairs[:, 0]
    cyto_ids = unique_pairs[:, 1]

    # Sort by nucleus, then by descending count, then by cyto label, so the
    # first row of each nucleus group is its winning cytoplasm label.
    order = np.lexsort((cyto_ids, -counts, nuc_ids))
    nuc_sorted = nuc_ids[order]
    cyto_sorted = cyto_ids[order]
    group_start = np.ones(nuc_sorted.size, dtype=bool)
    group_start[1:] = nuc_sorted[1:] != nuc_sorted[:-1]

    relabel_map[nuc_sorted[group_start]] = cyto_sorted[group_start]
    return relabel_map


relabelled_nuclei = []
for nres, cres in tqdm(
    zip(nuclei_results, cyto_results),
    total=len(nuclei_results),
    desc="Relabel nuclei to match cyto",
):
    nuclei_xr = nres["nuclei_seg"]
    nuclei = nuclei_xr.values
    cyto = cres["cyto_seg"].values
    # The table is indexed by the ORIGINAL nucleus labels, so it can be applied
    # to the label image directly by fancy indexing.
    relabel_map = _majority_overlap_map(nuclei, cyto)
    nuclei_relabel = relabel_map[nuclei]
    relabelled_nuclei.append(
        {
            "filename": nres["filename"],
            "nuclei_relabel": xr.DataArray(nuclei_relabel, dims=nuclei_xr.dims),
        }
    )


In [None]:
"""
Apply erosion to nuclei and cytoplasm labels using xarray and Dask for distributed, chunked processing.
Efficient for large images using blockwise operations.
"""


def erode_labels_xr(label_xr, erosion_radius=2):
    """
    Erode each label in a label image by a given radius using xarray and Dask.

    A pixel keeps its label only if every pixel under a disk of radius
    ``erosion_radius`` centred on it carries that same label; all other pixels
    become 0. The disk acts on the last two dimensions; any leading dimensions
    are processed plane by plane.

    Parameters
    ----------
    label_xr : xarray.DataArray
        Labeled image (0 = background) with at least two dimensions. It may be
        NumPy- or Dask-backed; Dask-backed inputs are processed chunk by chunk.
    erosion_radius : int
        Radius for morphological erosion. Must be non-negative.

    Returns
    -------
    xarray.DataArray
        Eroded label image with the dims, coords and dtype of ``label_xr``.

    Raises
    ------
    ValueError
        If ``erosion_radius`` is negative or the image has fewer than two
        dimensions.

    Examples
    --------
    >>> eroded = erode_labels_xr(label_xr, 2)
    """
    if erosion_radius < 0:
        raise ValueError(f"erosion_radius must be >= 0, got {erosion_radius}")
    disk = skimage.morphology.disk(erosion_radius)

    def erode_block(block):
        block = np.asarray(block)
        if block.size == 0:
            return block
        if block.ndim < 2:
            raise ValueError("label image must have at least two dimensions")
        # Lift the 2-D disk to the block's rank so it only spans the last two
        # axes (singleton extent along every leading axis).
        footprint = disk.reshape((1,) * (block.ndim - 2) + disk.shape)
        # Grayscale erosion/dilation give the min/max label under the disk.
        # They agree exactly where the whole neighbourhood holds one label,
        # which is the per-label binary erosion without looping over labels.
        lowest = skimage.morphology.erosion(block, footprint)
        highest = skimage.morphology.dilation(block, footprint)
        return np.where(lowest == highest, block, 0).astype(block.dtype, copy=False)

    # apply_ufunc hands plain arrays to `erode_block`: the whole array for
    # NumPy-backed input, one chunk at a time for Dask-backed input.
    # NOTE(review): chunks are eroded independently (no overlap), so labels
    # crossing a chunk edge are not eroded from the neighbouring chunk's side.
    return xr.apply_ufunc(
        erode_block,
        label_xr,
        dask="parallelized",
        output_dtypes=[label_xr.dtype],
        keep_attrs=True,
    )


# Erode the relabelled nuclei and the cytoplasm labels of every image (default
# radius of erode_labels_xr) and collect both results per source file.
eroded_results = [
    {
        "filename": nuc_entry["filename"],
        "nuclei_eroded": erode_labels_xr(relabel_entry["nuclei_relabel"]),
        "cyto_eroded": erode_labels_xr(cyto_entry["cyto_seg"]),
    }
    for nuc_entry, cyto_entry, relabel_entry in tqdm(
        zip(nuclei_results, cyto_results, relabelled_nuclei),
        total=len(nuclei_results),
        desc="Eroding labels (distributed)",
    )
]


In [None]:
"""
Derive boolean membrane masks from the original and the eroded label images.

A pixel belongs to the membrane when it is labelled in the original image
(``> 0``) but is background in the eroded image (``== 0``). The same rule is
applied to the relabelled nuclei and to the cytoplasm labels.
"""
membrane_results = [
    {
        "filename": nuc_entry["filename"],
        "nuclei_membrane": (relabel_entry["nuclei_relabel"] > 0)
        & (eroded_entry["nuclei_eroded"] == 0),
        "cyto_membrane": (cyto_entry["cyto_seg"] > 0)
        & (eroded_entry["cyto_eroded"] == 0),
    }
    for nuc_entry, cyto_entry, eroded_entry, relabel_entry in tqdm(
        zip(nuclei_results, cyto_results, eroded_results, relabelled_nuclei),
        total=len(nuclei_results),
        desc="Extracting membranes (distributed)",
    )
]


In [None]:
"""
Display (and optionally save) the membrane masks for user inspection.

Set ``save_results`` to True to write each mask as uint8 NetCDF into
``<folder>/membrane_results``. ``max_examples`` controls how many images are
plotted. The source image shown next to the masks is taken from ``xr_images``
(already loaded) instead of re-opening the file from disk.
"""
save_results = False  # write masks with .to_netcdf() when True
max_examples = 1  # number of images to plot


def _first_plane(data_array):
    """
    Return the 2-D plane at index 0 of every leading dimension as a NumPy array.

    NOTE(review): assumes the last two dimensions are the spatial (Y, X) axes -
    confirm against the dimension order of the loaded images.
    """
    leading = {dim: 0 for dim in data_array.dims[:-2]}
    return data_array.isel(leading).values


output_dir = os.path.join(folder, "membrane_results")
os.makedirs(output_dir, exist_ok=True)
# Map each filename to its already-loaded image so nothing is read twice.
source_images = {entry["filename"]: entry["xarr"] for entry in xr_images}

for shown, res in enumerate(tqdm(membrane_results, desc="Saving membrane masks")):
    base = Path(res["filename"]).stem
    if save_results:
        # Boolean masks are stored as uint8 (0/1).
        res["nuclei_membrane"].astype(np.uint8).to_netcdf(
            os.path.join(output_dir, f"{base}_nuclei_membrane.nc")
        )
        res["cyto_membrane"].astype(np.uint8).to_netcdf(
            os.path.join(output_dir, f"{base}_cyto_membrane.nc")
        )
    if shown >= max_examples:
        if save_results:
            continue  # keep saving the remaining images without plotting
        break  # nothing left to do once enough examples were shown

    panels = (
        ("Nuclei channel", _first_plane(source_images[res["filename"]]), "gray"),
        ("Nuclei membrane", _first_plane(res["nuclei_membrane"]), "magma"),
        ("Cytoplasm membrane", _first_plane(res["cyto_membrane"]), "cividis"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for axis, (title, plane, cmap) in zip(axes, panels):
        axis.imshow(plane, cmap=cmap)
        axis.set_title(title)
    fig.suptitle(base)
    plt.show()

print(
    "Results are xarray.DataArray objects in memory. Save with .to_netcdf() if needed."
)
