In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tifffile
from csbdeep.utils import normalize
from skimage.exposure import rescale_intensity
from skimage.measure import regionprops
from stardist.models import StarDist2D
from stardist.plot import render_label

from utils.helpers import clean_mask, cut_out_image

In [2]:
# Load the pretrained '2D_versatile_he' StarDist2D model; every prediction
# cell below reuses this single `model` instance.
model = StarDist2D.from_pretrained('2D_versatile_he')

Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


In [3]:
def _list_visible_entries(directory):
    """Return the sorted paths of all non-hidden entries in ``directory``."""
    return sorted(
        f"{directory}/{name}"
        for name in os.listdir(directory)
        if not name.startswith(".")
    )


image_paths = _list_visible_entries("data/images")
mask_paths = _list_visible_entries("data/masks")

# Images and masks are paired purely by their sorted position, so a differing
# file count would make zip() silently drop or mispair samples. Fail loudly.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Found {len(image_paths)} images but {len(mask_paths)} masks; "
        "every image needs exactly one mask."
    )

images = [tifffile.imread(path) for path in image_paths]
masks = [tifffile.imread(path) for path in mask_paths]

# List of (image, mask) tuples in sorted-filename order.
data = list(zip(images, masks))

In [33]:
# test_image = images[0]
# test_mask = masks[0]

# cleaned_mask = clean_mask(test_mask)
# cut_image = cut_out_image(test_image, cleaned_mask)

# plt.figure(figsize=(20, 10))

# plt.subplot(1, 2, 1)
# plt.imshow(test_image)
# plt.axis("off")
# plt.title("Image")

# plt.subplot(1, 2, 2)
# plt.imshow(cut_image)
# plt.axis("off")
# plt.title("Cut out image")

# plt.show()

In [32]:
# for image, mask in data:
#     cleaned_mask = clean_mask(mask)
#     cut_image = cut_out_image(image, cleaned_mask)
#     image_normed = rescale_intensity(cut_image, out_range=(0, 1))

#     labels, data_dict = model.predict_instances(image_normed, axes='YXC', prob_thresh=0.05, nms_thresh=0.3, return_labels=True)

#     plt.figure(figsize=(24, 12))

#     plt.subplot(1,2,1)
#     plt.imshow(image, cmap="gray")
#     plt.axis("off")
#     plt.title("input image")

#     plt.subplot(1,2,2)
#     plt.imshow(render_label(labels, img=image_normed, cmap=(1.0, 1.0, 0), alpha=0.6))
#     plt.axis("off")
#     plt.title("prediction + input overlay")

#     plt.show()

In [8]:
# Run the whole pipeline on the first (image, mask) pair as a smoke test.
test_image, test_mask = data[0]
test_img = test_image  # alias kept in case another cell still uses the short name

# Clean the mask, cut the image with it, then rescale intensities to [0, 1]
# before handing the result to the network.
test_cleaned_mask = clean_mask(test_mask)
test_cut_image = cut_out_image(test_image, test_cleaned_mask)
test_image_normed = rescale_intensity(test_cut_image, out_range=(0, 1))

# prob_thresh=0.05 deliberately overrides the model's default threshold
# (0.692478 according to the load log) so that many more candidates are kept.
test_labels, test_data_dict = model.predict_instances(
    test_image_normed,
    axes='YXC',
    prob_thresh=0.05,
    nms_thresh=0.3,
    return_labels=True,
)

In [11]:
# Inspect the dimensions of the label array returned by the prediction above.
test_labels.shape

(1920, 2560)

In [34]:
# plt.imshow(test_labels, cmap="grey")

In [None]:
def extract_stardist_features(labels, data_dict):
    """
    Extract per-object features from StarDist segmentation results.

    Geometric features are measured with ``regionprops`` on the label image.
    The StarDist probability and center point of each object are looked up by
    label id: StarDist orders its result arrays so that label ``k`` in
    ``labels`` corresponds to entry ``k - 1`` of ``data_dict``. Looking up by
    label id (instead of by enumeration position) keeps the values aligned
    even when some label ids are absent from the mask.

    Parameters:
    -----------
    labels : numpy.ndarray
        Instance segmentation mask from StarDist
    data_dict : dict
        Dictionary returned by StarDist with 'coord', 'points', 'prob'.
        Only 'points' (object centers) and 'prob' are read; 'coord' holds the
        polygon vertices and is not a center location.

    Returns:
    --------
    pandas.DataFrame
        One row per labelled object. The columns are always present (also for
        an empty mask); StarDist values that are unavailable are NaN.
    """
    columns = [
        'label_id', 'area', 'perimeter', 'centroid_y', 'centroid_x',
        'bbox_min_row', 'bbox_min_col', 'bbox_max_row', 'bbox_max_col',
        'eccentricity', 'solidity', 'extent', 'major_axis_length',
        'minor_axis_length', 'orientation', 'equivalent_diameter',
        'stardist_probability', 'stardist_center_y', 'stardist_center_x',
    ]
    missing = float('nan')

    points = data_dict.get('points')
    if points is None:
        points = []
    probabilities = data_dict.get('prob')
    if probabilities is None:
        probabilities = []

    features = []
    for prop in regionprops(labels):
        min_row, min_col, max_row, max_col = prop.bbox[:4]
        centroid_y, centroid_x = prop.centroid[:2]

        feature_dict = {
            'label_id': prop.label,
            'area': prop.area,
            'perimeter': prop.perimeter,
            'centroid_y': centroid_y,
            'centroid_x': centroid_x,
            'bbox_min_row': min_row,
            'bbox_min_col': min_col,
            'bbox_max_row': max_row,
            'bbox_max_col': max_col,
            'eccentricity': prop.eccentricity,
            'solidity': prop.solidity,
            'extent': prop.extent,
            'major_axis_length': prop.major_axis_length,
            'minor_axis_length': prop.minor_axis_length,
            'orientation': prop.orientation,
            'equivalent_diameter': prop.equivalent_diameter,
            'stardist_probability': missing,
            'stardist_center_y': missing,
            'stardist_center_x': missing,
        }

        # Label ids start at 1, the StarDist result arrays at index 0.
        idx = prop.label - 1

        if 0 <= idx < len(probabilities):
            feature_dict['stardist_probability'] = float(probabilities[idx])

        if 0 <= idx < len(points) and len(points[idx]) >= 2:
            feature_dict['stardist_center_y'] = float(points[idx][0])
            feature_dict['stardist_center_x'] = float(points[idx][1])

        features.append(feature_dict)

    # Passing `columns` guarantees a stable schema even when nothing was found.
    return pd.DataFrame(features, columns=columns)

In [None]:
def get_segmentation_summary(labels, data_dict):
    """
    Get summary statistics from StarDist segmentation.

    Parameters:
    -----------
    labels : numpy.ndarray
        Instance segmentation mask from StarDist
    data_dict : dict
        Dictionary returned by StarDist alongside ``labels``

    Returns:
    --------
    dict
        'total_objects' (int) plus mean/median/std of the object area and
        mean/median of the StarDist probability. A statistic is NaN when no
        object was detected or when its source column is unavailable.
    """
    features_df = extract_stardist_features(labels, data_dict)

    def _column_stat(column, stat):
        # The column can be missing (no detections / no probabilities from
        # StarDist); report NaN instead of failing with a KeyError.
        if features_df.empty or column not in features_df.columns:
            return float('nan')
        values = pd.to_numeric(features_df[column], errors='coerce')
        return float(getattr(values, stat)())

    summary = {
        'total_objects': len(features_df),
        'average_area': _column_stat('area', 'mean'),
        'median_area': _column_stat('area', 'median'),
        'std_area': _column_stat('area', 'std'),
        'average_probability': _column_stat('stardist_probability', 'mean'),
        'median_probability': _column_stat('stardist_probability', 'median'),
    }

    return summary

In [None]:
def visualize_segmentation()