In [None]:
from functools import partial

from skimage.morphology import dilation, erosion
from skimage.segmentation import clear_border
from scipy.ndimage import gaussian_filter
from skimage.morphology import convex_hull_image
from skimage.segmentation import morphological_geodesic_active_contour
from skimage.filters import threshold_otsu
import numpy as np
from cellpose.models import CellposeModel

from calmutils.morphology.structuring_elements import hypersphere_centered
from calmutils.morphology.mask_postprocessing import keep_only_largest_component, get_mask_outlines


def periphery_segmentation_active_contours(img, blur_sigma=2, active_contour_iterations=40, return_outlines=True, outline_radius=2):
    """
    Segment a single object whose border is stained (e.g. nucleoporins or lamin for a nucleus).

    The image is Gaussian-blurred and Otsu-thresholded, reduced to the convex hull of its
    largest connected component, and that hull is then refined against the blurred
    intensities with a morphological geodesic active contour.
    NOTE: This assumes a single object in the input.

    Parameters:
        img: intensity image (converted to float before blurring).
        blur_sigma: sigma of the Gaussian blur applied before thresholding.
        active_contour_iterations: number of active contour iterations.
        return_outlines: if True, return the outlines of the refined mask, otherwise the mask itself.
        outline_radius: passed as both the inner and outer expansion to get_mask_outlines.

    Returns:
        The refined mask, or its outlines when return_outlines is True.
    """

    smoothed = gaussian_filter(img.astype(float), blur_sigma)

    # initial level set: convex hull of the largest above-threshold component
    above_threshold = smoothed > threshold_otsu(smoothed)
    initial_level_set = convex_hull_image(keep_only_largest_component(above_threshold))

    # the active contour snaps to low values, hence the sign flip of the blurred image
    refined_mask = morphological_geodesic_active_contour(
        -smoothed,
        active_contour_iterations,
        init_level_set=initial_level_set,
    )

    if return_outlines:
        # keyword spelling `expand_innner` is kept exactly as in the existing call
        # (presumably it matches the calmutils signature -- confirm before renaming)
        return get_mask_outlines(refined_mask, expand_innner=outline_radius, expand_outer=outline_radius)

    return refined_mask


def get_axis_aligned_poles(mask):

    """
    Get "poles" of a mask image along all axes
    i.e. coordinates of the first and last nonzero pixel along each axis

    Parameters:
        mask: n-dimensional array; every nonzero element counts as foreground.

    Returns:
        Integer array of index coordinates with one row per pole, ordered as
        (min of axis 0, max of axis 0, min of axis 1, max of axis 1, ...), i.e. 2 rows per axis.
        The same pixel may appear more than once if it is extremal along several axes.
        For a mask without any nonzero pixel, an empty array of shape (0, ndim) is returned.
    """

    # one row per nonzero pixel, one column per axis
    nonzero_coords = np.stack(np.nonzero(mask)).T

    # empty mask (e.g. after eroding a small object): there are no poles to report,
    # and argmin/argmax would raise on an empty sequence
    if len(nonzero_coords) == 0:
        return nonzero_coords

    # collect the row index of the min and max coordinate for each dimension
    pole_idxs = []
    for axis_coords in nonzero_coords.T:
        pole_idxs.extend((np.argmin(axis_coords), np.argmax(axis_coords)))

    return nonzero_coords[pole_idxs]


def cellpose_segment_in_projection(img, model, expansion_radius=3, diameter=50):
    """
    Segment an image with the given model, working on a maximum projection for stacks.

    Singleton dimensions are squeezed first. Inputs with more than 2 dimensions are
    max-projected along their first axis, segmented, and the resulting mask is
    repeated along that axis; 2D inputs are segmented directly.

    Parameters:
        img: image array.
        model: object whose eval(image, diameter=...) returns three values, the first being the mask
            (a cellpose CellposeModel in this notebook).
        expansion_radius: radius of the structuring element used to dilate the mask.
        diameter: forwarded to model.eval.

    Returns:
        The border-cleared, dilated mask; for inputs with more than 2 dimensions,
        img.shape[0] copies of it stacked along a new first axis.
    """

    img = img.squeeze()
    needs_projection = img.ndim > 2

    # collapse the first axis by maximum projection if necessary
    projection = img.max(axis=0) if needs_projection else img

    print(projection.shape)

    # run the model; only the first of the three returned values is used
    segmentation, _, _ = model.eval(projection, diameter=diameter)

    # drop objects touching the image border, then grow the remaining ones
    footprint = hypersphere_centered(projection.ndim, expansion_radius)
    expanded = dilation(clear_border(segmentation), footprint)

    # 2D input: the mask is the result
    if not needs_projection:
        return expanded

    # projected input: repeat the mask once per plane along the first axis
    return np.array([expanded for _ in range(img.shape[0])])

In [None]:
from autosted import AcquisitionPipeline
from autosted.taskgeneration import AcquisitionTaskGenerator
from autosted.callback_buildingblocks import (
    JSONSettingsLoader,
    LocationRemover,
    SpiralOffsetGenerator,
    LocationKeeper,
    NewestSettingsSelector, FOVSettingsGenerator
)
from autosted.imspector import get_current_stage_coords, ImspectorConnection
from autosted.stoppingcriteria import (
    MaximumAcquisitionsStoppingCriterion,
)
from autosted.detection import SegmentationWrapper, CoordinateDetectorWrapper

import logging
# configure the root logger so messages of level INFO and above are emitted
logging.basicConfig(level=logging.INFO)


# path to parameters
# one JSON settings file per hierarchy level: overview, cell and detail (see the task generators below)
# paths are relative, i.e. resolved against the notebook's working directory
whole_fov_config = "config_json/20251126_60x_3d_ov_coarse_640.json"
cell_conv_config = "config_json/20251126_60x_3d_ov_fine_640.json"
sted_config = "config_json/20251126_60x_3d_detail_sted_640_fast.json"

# 3-level pipeline overview, cell (confocal), detail (STED)
# NOTE(review): save_combined_hdf5 presumably writes all acquisitions into one HDF5 file
# under data_save_path -- confirm against the autosted documentation
pipeline = AcquisitionPipeline(
    data_save_path="acquisition_data/test_nup_640_poles",
    hierarchy_levels=["overview", 'cell', 'detail'],
    save_combined_hdf5=True
)

# for shorter delays between measurements, we can re-use them
# (replaces the pipeline's connection object with one created with reuse_measurement=True)
pipeline.imspector_connection = ImspectorConnection(reuse_measurement=True)


### callback 1: next overviews in spiral
# Task generator for the "overview" level, built from the coarse overview settings file.
# NOTE(review): LocationRemover presumably strips position parameters from the loaded
# settings so the spiral offsets below decide where to image -- confirm in autosted.
next_overview_generator = AcquisitionTaskGenerator(
    "overview",
    LocationRemover(JSONSettingsLoader(whole_fov_config)),
    SpiralOffsetGenerator(
        # step between overviews; presumably in meters (75 um) for the two lateral axes -- TODO confirm
        move_size=[75e-6, 75e-6],
        # the stage position is queried once, when this cell is executed
        start_position=get_current_stage_coords(),
    ),
)


### callback 2: confocal image of individual cells (segment in projection via Cellpose)
# NOTE(review): newer Cellpose releases changed how pretrained models are selected --
# confirm that `model_type='nuclei'` is honored by the installed version.
model = CellposeModel(gpu=True, model_type='nuclei')
# bind the model and a larger expansion radius; `diameter` keeps the function's default
segment_in_ov_fun = partial(cellpose_segment_in_projection, model=model, expansion_radius=7)

# Task generator for the "cell" level: settings come from the fine confocal settings file,
# segmentation runs on channel 0 via segment_in_ov_fun.
# NOTE(review): LocationKeeper(NewestSettingsSelector()) presumably carries over position
# information from the most recent acquisition, and offset_parameters='scan' presumably
# applies detections as scan offsets rather than stage moves -- confirm in autosted.
cell_generator = AcquisitionTaskGenerator(
    "cell",
    LocationRemover(JSONSettingsLoader(cell_conv_config)),
    LocationKeeper(NewestSettingsSelector()),
    SegmentationWrapper(segment_in_ov_fun, offset_parameters='scan', channels=0),
)


### callback 3: STED detail acquisitions within each imaged cell
# Two alternative task generators are built under distinct names;
# `sted_generator` is then bound explicitly to the one that should run.

### callback 3 (full nucleus version): STED images on tiled border of cell
def seg_fun(img):
    """
    Segment the outline of the nuclear periphery in img (outline radius 5).

    The result is multiplied by 1, which turns a boolean outline mask into an integer array.
    """
    return periphery_segmentation_active_contours(img, outline_radius=5) * 1

sted_generator_full_nucleus = AcquisitionTaskGenerator(
    "detail",
    LocationRemover(JSONSettingsLoader(sted_config)),
    LocationKeeper(NewestSettingsSelector()),
    SegmentationWrapper(seg_fun, channels=0)
)

### callback 3 (only tiles version): segment as above, but only image 2x2x2 micron tiles at poles of nucleus
def pole_detect_fun(img):
    """
    Detect the axis-aligned "poles" of the object in img.

    Returns the index coordinates of the first and last nonzero pixel along each axis
    of the eroded segmentation mask, as produced by get_axis_aligned_poles.
    """
    mask = periphery_segmentation_active_contours(img, return_outlines=False)
    # erode a bit so we catch more when imaging around "poles"
    mask = erosion(mask, hypersphere_centered(img.ndim, 2))
    pole_coords = get_axis_aligned_poles(mask)
    return pole_coords

sted_generator_poles = AcquisitionTaskGenerator(
    "detail",
    LocationRemover(JSONSettingsLoader(sted_config)),
    LocationKeeper(NewestSettingsSelector()),
    FOVSettingsGenerator(lengths=[2e-6, 2e-6, 2e-6]),
    CoordinateDetectorWrapper(pole_detect_fun, channels=0)
)

# choose the version that is registered with the pipeline:
# pole tiles here; use sted_generator_full_nucleus to image the whole nuclear border instead
sted_generator = sted_generator_poles


# add the callbacks and a stopping condition
# NOTE(review): the second argument is presumably the hierarchy level whose finished
# acquisitions trigger the callback -- confirm in autosted. Read that way:
#   after each overview -> queue the next overview and segment cells in it
#   after each cell     -> queue detail (STED) acquisitions
# `sted_generator` is whichever callback-3 generator was bound to that name last.
pipeline.add_callback(next_overview_generator, "overview")
pipeline.add_callback(cell_generator, "overview")
pipeline.add_callback(sted_generator, "cell")
# limit on the "overview" level: at most 2 acquisitions
pipeline.add_stopping_condition(
    MaximumAcquisitionsStoppingCriterion(max_acquisitions_per_level={"overview": 2})
)

# start the run; the overview generator is used as the initial callback
pipeline.run(initial_callback=next_overview_generator)

## Test: acquisition timing from the HDF5 dataset

In [None]:
import h5py as h5
import datetime
from collections import defaultdict
import pandas as pd
from natsort import natsort_keygen

# column name -> list of per-measurement values, filled while reading the HDF5 file
timing_columns = defaultdict(list)

# open read-only explicitly: only attributes are inspected here, and the default mode
# of h5py.File is not read-only in older h5py versions
with h5.File('acquisition_data/test_border_4/ea0a0374.h5', 'r') as fd:

    for k in fd['experiment'].keys():

        # run start/end timestamps are stored as attributes on the '0' entry of each measurement
        dataset_config0 = fd[f'experiment/{k}/0']

        # timestamps are converted to datetimes in local time
        start_time = datetime.datetime.fromtimestamp(dataset_config0.attrs['run_start_time'])
        end_time = datetime.datetime.fromtimestamp(dataset_config0.attrs['run_end_time'])

        timing_columns['meas_idx'].append(k)
        timing_columns['start_time'].append(start_time)
        timing_columns['end_time'].append(end_time)

# naturally sort by measurement idx -> chronological order
df = (
    pd.DataFrame(timing_columns)
    .sort_values(by='meas_idx', key=natsort_keygen())
    .reset_index(drop=True)
)

# gap between the end of the previous measurement and the start of this one
# (NaT for the first row, which has no predecessor)
df['dt_from_previous'] = df['start_time'] - df['end_time'].shift(1)
df['meas_duration'] = df['end_time'] - df['start_time']

# cell output: mean gap between consecutive measurements
df['dt_from_previous'].mean()