# 2D Stardist segmentation on 2D/3D/timelapse OMERO images

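This notebook runs StarDist nuclei segmentation on OMERO data: a plate, a dataset or a single image. Each z-slice is segmented with the 2D StarDist model. Currently only the first timepoint is processed. Parts of the notebook are inspired by https://github.com/ome/omero-guide-cellprofiler/idr0002.ipynb.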

### Import Packages

In [7]:
# Import OMERO Python BlitzGateway
import omero
from omero.gateway import BlitzGateway
import ezomero
# Import Numpy
import numpy as np

# Import Python System Packages
import os
import tempfile
import pandas
import warnings

#stardist related
from stardist.models import StarDist2D
from csbdeep.utils import normalize
from stardist.plot import render_label
import matplotlib.pyplot as plt
from tifffile import imsave

# Load the pretrained StarDist 2D model for fluorescent nuclei.
# ProcessImage._segment_slice uses this notebook-level `model` for every plane.
STARDIST_MODEL_NAME = '2D_versatile_fluo'
model = StarDist2D.from_pretrained(STARDIST_MODEL_NAME)

Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.


### Set Temp Output Directory

In [8]:
# Scratch directory for the TIFF files that are written locally before being
# attached to the OMERO images.
_scratch_dir = tempfile.mkdtemp()
new_output_directory = os.path.normcase(_scratch_dir)

### Set up the connection to OMERO

In [9]:
# Read the connection settings from the environment so credentials are not
# hard-coded in the notebook. The defaults reproduce the original local setup.
OMERO_HOST = os.environ.get("OMERO_HOST", "localhost")
OMERO_USER = os.environ.get("OMERO_USER", "root")
OMERO_PASSWORD = os.environ.get("OMERO_PASSWORD", "omero")

conn = BlitzGateway(host=OMERO_HOST, username=OMERO_USER, passwd=OMERO_PASSWORD, secure=True)
connected = conn.connect()
print(connected)
# Stop here instead of failing later with obscure errors on every OMERO call.
if not connected:
    raise ConnectionError(f"Could not connect to OMERO at {OMERO_HOST} as {OMERO_USER}")
# Enable the client keep-alive (60 s interval) for the long-running processing loop.
conn.c.enableKeepAlive(60)

True


### Select the plate, dataset or image to process

In [10]:
datatype = "plate"  # "plate", "dataset", "image"
data_id = 52
nucl_channel = 0  # index of the channel that holds the nuclear stain

# OMERO object type queried for each supported datatype.
OMERO_TYPES = {"plate": "Plate", "dataset": "Dataset", "image": "Image"}
if datatype not in OMERO_TYPES:
    raise ValueError(f"Unsupported datatype {datatype!r}; expected one of {sorted(OMERO_TYPES)}")

omero_type = OMERO_TYPES[datatype]
omero_object = conn.getObject(omero_type, data_id)
# getObject returns None when the object does not exist or cannot be read
# by the current user; fail with a clear message instead of an AttributeError.
if omero_object is None:
    raise ValueError(f"{omero_type} with ID {data_id} not found")
print(f'{omero_type} Name: ', omero_object.getName())

# Later cells refer to the selected object by its datatype name.
if datatype == "plate":
    plate = omero_object
elif datatype == "dataset":
    dataset = omero_object
else:
    image = omero_object

Plate Name:  day7


### Run StarDist on the selected data

#### Function definitions

In [15]:
import logging
from typing import List, Tuple, Any, Optional
import numpy as np
from pathlib import Path
import os
from tifffile import imsave
import ezomero

class ProcessImage:
    """Segment nuclei in an OMERO image with StarDist and upload the results.

    Every z-slice of timepoint 0 is segmented independently with the 2D
    StarDist model. The results can be saved back to OMERO as a TIFF file
    attachment, as a new label image and as polygon ROIs.

    Uses the notebook-level names ``model`` (StarDist2D), ``normalize``
    (csbdeep), ``imsave`` (tifffile) and ``ezomero``.
    """

    # Class constants
    SEGMENTATION_NAMESPACE = "stardist.segmentation"
    ROI_NAME = "Stardist Nuclei"
    ROI_DESCRIPTION = "Nuclei segmentation using Stardist"

    def __init__(self, image: Any, conn: Any) -> None:
        """
        Initialize ProcessImage instance.

        Args:
            image: OMERO image object
            conn: OMERO connection object

        Raises:
            ValueError: If image or connection is invalid
        """
        if not image or not conn:
            raise ValueError("Image and connection must be provided")

        self._image = image
        self._conn = conn
        self._pixels = image.getPrimaryPixels()
        self._size_c = image.getSizeC()
        self._size_z = image.getSizeZ()
        # One entry per z-slice once segment_nuclei() has run; None before that.
        self._labels: Optional[Tuple[np.ndarray, ...]] = None
        self._polygons: Optional[Tuple[Any, ...]] = None

        # Set up logging
        self.logger = logging.getLogger(__name__)

    @property
    def labels(self) -> Tuple[np.ndarray, ...]:
        """Label images, one per z-slice.

        Raises:
            ValueError: If segmentation has not been performed yet.
        """
        if self._labels is None:
            raise ValueError("Segmentation has not been performed yet")
        return self._labels

    @property
    def pixels(self) -> Any:
        """OMERO primary pixels object of the image."""
        return self._pixels

    @property
    def size_z(self) -> int:
        """Number of z-slices in the image."""
        return self._size_z

    @property
    def size_c(self) -> int:
        """Number of channels in the image."""
        return self._size_c

    def segment_nuclei(self, nucl_channel: int) -> None:
        """
        Segment nuclei in the specified channel.

        Args:
            nucl_channel: Channel number for nuclear staining

        Raises:
            ValueError: If channel number is invalid
        """
        if not 0 <= nucl_channel < self._size_c:
            raise ValueError(f"Invalid channel number: {nucl_channel}")

        try:
            if self._size_z > 1:
                self._segment_3d_image(nucl_channel)
            else:
                self._segment_2d_image(nucl_channel)
        except Exception as e:
            self.logger.error(f"Segmentation failed: {str(e)}")
            raise

    def _segment_3d_image(self, channel: int) -> None:
        """Segment every z-slice of timepoint 0 in ``channel``."""
        # Fetch and segment one plane at a time, so only a single raw plane
        # is held in memory next to the results.
        results = [
            self._segment_slice(self._pixels.getPlane(z, channel, 0))
            for z in range(self._size_z)
        ]
        labels, polygons = zip(*results)
        self._labels, self._polygons = labels, polygons

    def _segment_2d_image(self, channel: int) -> None:
        """Segment the single plane (z=0, t=0) in ``channel``."""
        labels, polygons = self._segment_slice(self._pixels.getPlane(0, channel, 0))
        # Store as 1-tuples so 2D and 3D results share the same per-z layout.
        self._labels, self._polygons = (labels,), (polygons,)

    def _segment_slice(self, plane: np.ndarray) -> Tuple[np.ndarray, Any]:
        """
        Segment a single image plane.

        Args:
            plane: 2D numpy array representing image plane

        Returns:
            Tuple of (labels, polygons) as returned by model.predict_instances
        """
        try:
            img = normalize(plane)
            return model.predict_instances(img)
        except Exception as e:
            self.logger.error(f"Slice segmentation failed: {str(e)}")
            raise

    def _parent_dataset(self) -> Any:
        """Return the image's parent if it is a Dataset, otherwise None."""
        parent = self._image.getParent()
        # NOTE(review): for images in a plate, getParent() may return a
        # non-Dataset container. createImageFromNumpySeq expects a Dataset,
        # so the parent is only passed on when it really is one.
        if parent is not None and getattr(parent, "OMERO_CLASS", None) == "Dataset":
            return parent
        return None

    def save_segmentation_to_omero_as_new_image(self, new_img_name: str, desc: str) -> None:
        """Save the label stack as a new single-channel OMERO image.

        The new image is linked to the source image's Dataset when the source
        image lives in one; otherwise it is created without a dataset link.

        Args:
            new_img_name: Name of the new image.
            desc: Description of the new image.
        """
        dataset = self._parent_dataset()
        try:
            new_img = self._conn.createImageFromNumpySeq(
                iter(self.labels),
                new_img_name,
                self._size_z,
                1,
                1,
                description=desc,
                dataset=dataset
            )
            self.logger.info(f'Created new Image:{new_img.getId()} Name:"{new_img.getName()}"')
        except Exception as e:
            self.logger.error(f"Failed to save new image: {str(e)}")
            raise

    def save_segmentation_to_omero_as_attach(self, tmp_dir: str, desc: str) -> None:
        """Save the label stack as a TIFF and attach it to the image in OMERO.

        The TIFF is written to ``tmp_dir`` (created if missing). It is deleted
        after the upload attempt, whether or not the upload succeeded.

        Args:
            tmp_dir: Directory for the temporary TIFF file.
            desc: Description stored on the file annotation.
        """
        import re

        tmp_path = Path(tmp_dir)
        tmp_path.mkdir(parents=True, exist_ok=True)

        # Image names may contain path separators or characters that are not
        # allowed in file names; replace them so the file lands in tmp_path.
        safe_name = re.sub(r'[\\/:*?"<>|]', "_", self._image.getName())
        tif_file = tmp_path / f"{safe_name}_segmentation.tif"

        try:
            # Stack the per-z label images into one (z, y, x) array.
            imsave(tif_file, np.asarray(self.labels))
            # file_path is passed by keyword: its position differs between
            # ezomero versions.
            file_annotation_id = ezomero.post_file_annotation(
                self._conn,
                file_path=str(tif_file),
                ns=self.SEGMENTATION_NAMESPACE,
                object_type="Image",
                object_id=self._image.getId(),
                description=desc
            )
            self.logger.info(f'File annotation ID: {file_annotation_id}')
        finally:
            if tif_file.exists():
                tif_file.unlink()

    def _create_polygon_shapes(self) -> List[Any]:
        """Convert the StarDist polygons of every z-slice into ezomero Polygons.

        Contours that fail to convert are logged and skipped.
        """
        all_polygons = []
        for z, polygons in enumerate(self._polygons):
            # One (2, n_rays) array per detected nucleus; row 0 holds the
            # y (row) coordinates and row 1 the x (column) coordinates.
            for contour_idx, contour in enumerate(polygons['coord']):
                try:
                    rows, cols = contour[0], contour[1]
                    # OMERO polygon points are (x, y), so swap StarDist's order.
                    points = [(float(x), float(y)) for y, x in zip(rows, cols)]

                    ezomero_polygon = ezomero.rois.Polygon(
                        points=points,
                        z=z,
                        c=None,
                        t=None,
                        label="nuclei",
                        fill_color=None,
                        stroke_color=None,
                        stroke_width=None
                    )
                    all_polygons.append(ezomero_polygon)
                except Exception as e:
                    # Best effort: one malformed contour should not drop the rest.
                    self.logger.warning(
                        "Error processing contour %s at z=%s: %s", contour_idx, z, e
                    )
        return all_polygons

    def save_segmentation_to_omero_as_roi(self) -> None:
        """Save all segmented nuclei as polygon shapes of a single OMERO ROI.

        Raises:
            ValueError: If segmentation has not been performed yet.
        """
        if not self._polygons:
            raise ValueError("No polygons available - run segmentation first")

        all_polygons = self._create_polygon_shapes()

        if all_polygons:
            try:
                roi_id = ezomero.post_roi(
                    conn=self._conn,
                    image_id=self._image.getId(),
                    shapes=all_polygons,
                    name=self.ROI_NAME,
                    description=self.ROI_DESCRIPTION
                )
                self.logger.info(f"Created ROI with ID: {roi_id}")
            except Exception as e:
                self.logger.error(f"Error creating ROI: {str(e)}")
                raise
        else:
            self.logger.warning("No valid polygons were created")

In [16]:
import pyclesperanto_prototype as cle
import pandas as pd

def measure_intensity(pixels, labels, size_z, size_c, t=0, statistics_func=None):
    """Measure per-nucleus statistics in every channel of every z-slice.

    Args:
        pixels: OMERO pixels object; ``pixels.getPlane(z, c, t)`` must return
            a 2D intensity array.
        labels: Label images, one per z-slice (as stored in
            ``ProcessImage.labels``). A single 2D label array is also accepted
            and treated as a one-slice stack.
        size_z: Number of z-slices to measure.
        size_c: Number of channels to measure.
        t: Timepoint the intensities are read from (default 0, the timepoint
            that is segmented).
        statistics_func: Callable ``(intensity_image, label_image)`` returning
            a column -> values mapping. Defaults to
            ``cle.statistics_of_labelled_pixels``.

    Returns:
        DataFrame with the statistics of every (z, channel) pair stacked
        row-wise, plus ``z`` and ``channel`` columns. The DataFrame is empty if
        nothing was measured.
    """
    if statistics_func is None:
        statistics_func = cle.statistics_of_labelled_pixels

    # A bare 2D label image (2D data) becomes a one-slice stack, so 2D and 3D
    # inputs go through the same loop. 2D images are read at z=0.
    if isinstance(labels, np.ndarray) and labels.ndim == 2:
        labels = [labels]

    all_statistics = []
    for z, label_image in zip(range(size_z), labels):
        for c in range(size_c):
            statistics = pd.DataFrame(statistics_func(pixels.getPlane(z, c, t), label_image))
            statistics['z'] = z
            statistics['channel'] = c
            all_statistics.append(statistics)

    # pd.concat raises on an empty list, so return an empty frame instead.
    if not all_statistics:
        return pd.DataFrame()

    # Concatenate all statistics into a single DataFrame
    return pd.concat(all_statistics, ignore_index=True)


#### Segment all selected images and upload the results

In [None]:
# TODO: extend to segment all timepoints (currently only t=0 is processed)

# Optional cap on the number of wells/images processed (useful for testing);
# None processes everything.
MAX_ITEMS = None
SEGMENTATION_DESC = "Stardist nuclei segmentation"


def process_and_upload(image):
    """Segment one OMERO image and upload its results to OMERO.

    Uploads the label stack as an attachment and as a new image
    (``<name>_nucleisegmentation``), the nuclei as polygon ROIs, and the
    per-nucleus intensity measurements as a table on the original image.
    """
    new_img_name = image.getName() + "_nucleisegmentation"
    img = ProcessImage(image, conn)
    img.segment_nuclei(nucl_channel)
    img.save_segmentation_to_omero_as_attach(new_output_directory, SEGMENTATION_DESC)
    img.save_segmentation_to_omero_as_new_image(new_img_name, SEGMENTATION_DESC)
    img.save_segmentation_to_omero_as_roi()

    all_statistics_df = measure_intensity(img._pixels, img.labels, img._size_z, img._size_c)
    if all_statistics_df.empty:
        print('No measurements for image', image.getId(), '- no table created')
        return
    table_id = ezomero.post_table(conn, object_type="Image", object_id=image.getId(),
                                  table=all_statistics_df, title="Nuclei_measurements")
    print('Created table ID:', table_id)


if datatype == "plate":
    wells = list(plate.listChildren())[:MAX_ITEMS]
    well_count = len(wells)
    for count, well in enumerate(wells):
        print('Well: %s/%s' % (count + 1, well_count), 'row:', well.row, 'column:', well.column)
        # Process every field (image) of the well.
        for field in range(well.countWellSample()):
            print('Field:', field)
            process_and_upload(well.getImage(field))

elif datatype == "dataset":
    images = list(dataset.listChildren())[:MAX_ITEMS]
    image_count = len(images)
    for count, dataset_image in enumerate(images):
        print('Image: %s/%s' % (count + 1, image_count), dataset_image.getName())
        process_and_upload(dataset_image)

elif datatype == "image":
    process_and_upload(image)

Well: 1/72 row: 0 column: 0
Field: 0
