In [2]:
from arcos4py.tools import track_events_image, remove_image_background
from arcos4py.tools._detect_events import upscale_image
from arcos4py.tools._cleandata import blockwise_median
import pandas as pd
import napari
from scipy.ndimage import distance_transform_edt
from scipy.ndimage import binary_dilation, binary_fill_holes, binary_erosion
from skimage.morphology import erosion, remove_small_objects, square
from skimage import io, exposure
from scipy import ndimage
from skimage.morphology import closing
from statsmodels.nonparametric.smoothers_lowess import lowess
from tqdm.auto import tqdm

import numpy as np
from scipy.signal import savgol_filter



In [3]:
def drop_scattered_small_labels(label_image, min_size=100):
    """
    Removes small scattered regions of each label from a labeled image.

    The image is first passed through ``closing`` (skimage grey-level closing
    with its default footprint). Then, for every non-zero label, the pixels of
    that label are split into connected components (``scipy.ndimage.label``
    default connectivity) and only components with at least ``min_size``
    pixels are kept; everything else becomes 0.

    Parameters:
    - label_image: 2D numpy array representing the labeled image; 0 is background.
    - min_size: Minimum pixel size for keeping a scattered part of a label.

    Returns:
    - Processed image (same shape and dtype as the closed input) with small
      scattered labels dropped. Background pixels always stay 0, even when
      ``min_size <= 0``.
    """
    # NOTE(review): grey-level closing operates on label *values*, so it may
    # change pixels where different labels touch or fill background gaps
    # between labels — confirm this is intended for a label image.
    closed = closing(label_image)
    output_image = np.zeros_like(closed)

    for label in np.unique(closed):
        if label == 0:  # 0 is the background
            continue

        # Connected components of the current label only.
        components, num_components = ndimage.label(closed == label)

        # Pixel count per component id; id 0 is "not this label".
        sizes = np.bincount(components.ravel(), minlength=num_components + 1)
        keep_component = sizes >= min_size
        # Component 0 is everything outside this label; it must never be kept,
        # otherwise min_size <= 0 would paint the whole background.
        keep_component[0] = False

        # Write the surviving components of this label into the output.
        output_image[keep_component[components]] = label

    return output_image


def process_time_series_label_images(time_series_label_images, min_size=100):
    """
    Drop small scattered label fragments from every frame of a labeled time series.

    Parameters:
    - time_series_label_images: numpy array of labeled images stacked along the
      first axis (time), e.g. (t, y, x).
    - min_size: Minimum pixel size for keeping a scattered part of a label;
      forwarded unchanged to ``drop_scattered_small_labels``.

    Returns:
    - Array with the same shape and dtype as the input whose frame ``t`` is
      ``drop_scattered_small_labels(time_series_label_images[t], min_size=min_size)``.
    """
    num_frames = time_series_label_images.shape[0]
    cleaned = np.zeros_like(time_series_label_images)

    # Clean each frame independently and store it in the preallocated result.
    for frame_index in range(num_frames):
        frame = time_series_label_images[frame_index]
        cleaned[frame_index] = drop_scattered_small_labels(frame, min_size=min_size)

    return cleaned


def filter_by_centroid_displacement(labeled_stack, min_distance):
    """
    Removes tracks from a labeled image stack if the total displacement of their centroid
    is less than the specified minimum distance.

    Parameters
    ----------
    labeled_stack : numpy array
        Labeled stack whose first axis is time (e.g. (t, y, x)); label 0 is background.
        Each non-zero label is treated as one track.
    min_distance : float
        Tracks whose summed frame-to-frame centroid displacement (over the frames
        in which they appear, in time order) is below this value are removed.

    Returns
    -------
    numpy array
        A copy of ``labeled_stack`` with the removed tracks set to 0; the input
        is not modified.
    """
    labeled_stack = np.copy(labeled_stack)

    labels = np.unique(labeled_stack)
    # Exclude background explicitly: slicing off the first unique value would
    # drop a real label whenever the stack contains no 0 pixels.
    labels = labels[labels != 0]

    for label in labels:
        track_mask = labeled_stack == label

        # One row per pixel of the track: column 0 is the frame, the rest are spatial coords.
        pixel_idx = np.argwhere(track_mask)
        frame_idx = pixel_idx[:, 0]
        coords = pixel_idx[:, 1:]

        # Centroid for each frame the track appears in, in time order.
        frames = np.unique(frame_idx)
        centroids = np.array([coords[frame_idx == t].mean(axis=0) for t in frames])

        # Path length: sum of distances between consecutive centroids (0 for a single frame).
        total_distance = np.linalg.norm(np.diff(centroids, axis=0), axis=1).sum()

        if total_distance < min_distance:
            labeled_stack[track_mask] = 0

    return labeled_stack


def filter_by_duration(labeled_stack, min_duration):
    """
    Removes tracks from a labeled image stack if their duration is less than the specified minimum.

    Parameters
    ----------
    labeled_stack : numpy array
        Labeled stack whose first axis is time (e.g. (t, y, x)); label 0 is background.
    min_duration : int
        Minimum number of distinct frames a label must occupy to be kept.
        Frames are counted, not spanned: gaps in time do not add to the duration.

    Returns
    -------
    numpy array
        A copy of ``labeled_stack`` with short tracks set to 0; the input is not modified.
    """
    labeled_stack = np.copy(labeled_stack)

    labels = np.unique(labeled_stack)
    # Exclude background explicitly: slicing off the first unique value would
    # drop a real label whenever the stack contains no 0 pixels.
    labels = labels[labels != 0]

    for label in labels:
        track_mask = labeled_stack == label

        # Number of frames that contain at least one pixel of this label.
        n_frames = np.count_nonzero(track_mask.reshape(track_mask.shape[0], -1).any(axis=1))

        if n_frames < min_duration:
            labeled_stack[track_mask] = 0

    return labeled_stack


def _smooth_binary_image(image, expand_iterations, remove_small, remove_small_objects_size):
    """Run the smooth_segmentation pipeline on one 2D image and return an int 0/1 array."""
    image = binary_fill_holes(image)
    image = binary_dilation(image, square(5), iterations=expand_iterations)
    image = erosion(image, footprint=square(5))
    if remove_small:
        # The size parameter is a side length; the pixel threshold is its square.
        image = remove_small_objects(image.astype(bool), min_size=remove_small_objects_size**2)
    image = binary_fill_holes(image)
    return np.where(image, 1, 0)


def smooth_segmentation(binary_objects, expand_iterations=1, remove_small=True, remove_small_objects_size=100):
    """
    Smooths the segmentation by removing small objects and filling holes.

    Each 2D image (or each frame along the first axis of a 3D stack) goes through:
    fill holes -> binary dilation with a 5x5 square (``expand_iterations`` times)
    -> one erosion with a 5x5 square -> optional small-object removal -> fill holes.

    Parameters
    ----------
    binary_objects : numpy array
        Binary image (2D) or stack of binary images (3D) of the segmented objects.
        Only pixels equal to 1 are treated as foreground.
    expand_iterations : int, optional
        Number of 5x5 dilation iterations applied before the single erosion.
        The default is 1.
    remove_small : bool, optional
        Whether to remove small objects. The default is True.
    remove_small_objects_size : int, optional
        Side length of the smallest object to keep: objects with fewer than
        ``remove_small_objects_size**2`` pixels are removed. The default is 100.

    Returns
    -------
    binary_objects : numpy array
        Smoothed segmentation as an integer 0/1 array with the input's shape.
    """
    binary_objects = np.where(binary_objects == 1, 1, 0)
    if binary_objects.ndim == 3:
        # Smooth each frame independently, writing back into the int array.
        for index, frame in enumerate(binary_objects):
            binary_objects[index] = _smooth_binary_image(
                frame, expand_iterations, remove_small, remove_small_objects_size
            )
        return binary_objects
    return _smooth_binary_image(binary_objects, expand_iterations, remove_small, remove_small_objects_size)

def bleach_correction_smooth(img_stack, window_length=11, polyorder=2):
    """
    Perform bleach correction on a t,y,x image stack using Savitzky-Golay smoothing.

    The mean intensity of every frame is smoothed over time with a
    Savitzky-Golay filter. Each frame is then divided by its smoothed mean
    relative to the raw mean of the first frame.

    Parameters:
    - img_stack: numpy array whose first axis is time, e.g. (t, y, x); any
      further axes (z, channels, ...) are averaged together with y and x.
    - window_length: Length of the filter window (must be odd and not larger than t).
    - polyorder: Order of the polynomial used to fit the samples (must be less than window_length).

    Returns:
    - Corrected float64 numpy array with same shape as img_stack (the input is not modified).

    Raises:
    - ValueError: if the first frame has zero mean intensity (normalisation would divide by zero).
    """
    # Convert img_stack to float type for the correction (astype returns a copy).
    img_stack = img_stack.astype(np.float64)

    # Average intensity for each time point, over every non-time axis.
    avg_intensities = img_stack.mean(axis=tuple(range(1, img_stack.ndim)))
    if avg_intensities[0] == 0:
        raise ValueError("First frame has zero mean intensity; cannot normalise bleach correction.")

    # Apply Savitzky-Golay filter to the 1D intensity curve.
    smoothed_intensities = savgol_filter(avg_intensities, window_length, polyorder)

    # Correction factors relative to the first frame's raw mean.
    correction_factors = smoothed_intensities / avg_intensities[0]

    # One factor per frame, broadcast across all remaining axes.
    img_stack /= correction_factors.reshape((-1,) + (1,) * (img_stack.ndim - 1))

    return img_stack


def bleach_correction_loess(img_stack, frac=0.1):
    """
    Perform bleach correction on a t,y,x image stack using LOESS smoothing.

    The mean intensity of every frame is smoothed over time with statsmodels'
    ``lowess``. Each frame is then divided by its smoothed mean relative to the
    raw mean of the first frame.

    Parameters:
    - img_stack: numpy array whose first axis is time, e.g. (t, y, x); any
      further axes (z, channels, ...) are averaged together with y and x.
    - frac: The fraction of data used when estimating each y-value for the lowess fit.
            It determines the span of the window; for example, a value of 0.1 means
            each smoothed point uses 10% of the data points.

    Returns:
    - Corrected float64 numpy array with same shape as img_stack (the input is not modified).

    Raises:
    - ValueError: if the first frame has zero mean intensity (normalisation would divide by zero).
    """
    # Convert img_stack to float type for the correction (astype returns a copy).
    img_stack = img_stack.astype(np.float64)

    # Average intensity for each time point, over every non-time axis.
    avg_intensities = img_stack.mean(axis=tuple(range(1, img_stack.ndim)))
    if avg_intensities[0] == 0:
        raise ValueError("First frame has zero mean intensity; cannot normalise bleach correction.")

    # Frame indices serve as the x-values of the fit.
    t_values = np.arange(len(avg_intensities))

    # return_sorted=False keeps the smoothed values aligned with t_values.
    smoothed_intensities = lowess(avg_intensities, t_values, frac=frac, return_sorted=False)

    # Correction factors relative to the first frame's raw mean.
    correction_factors = smoothed_intensities / avg_intensities[0]

    # One factor per frame, broadcast across all remaining axes.
    img_stack /= correction_factors.reshape((-1,) + (1,) * (img_stack.ndim - 1))

    return img_stack


def match_histogram(source, template, bins=65536):
    """
    Remap the intensities of ``source`` so its histogram matches that of ``template``.

    Both images are histogrammed with ``bins`` unit-width bins over ``[0, bins]``,
    i.e. one bin per integer value. Each source value ``v`` is mapped to the
    smallest template value whose cumulative distribution is at least the
    source CDF at ``v``.

    Parameters
    ----------
    source : numpy array of non-negative integers
        Image to transform. Every value must be < ``bins`` because it indexes
        the lookup table.
    template : numpy array
        Reference image whose intensity distribution is imposed on ``source``.
    bins : int, optional
        Number of unit-width bins (one more than the largest representable
        value). The default, 65536, covers the full uint16 range.

    Returns
    -------
    numpy array
        Array with the shape of ``source``. The dtype is uint16 unless a mapped
        value exceeds 65535, in which case a wider unsigned integer type is used.

    Raises
    ------
    ValueError
        If ``template`` has no values inside the histogram range, so its CDF is undefined.
    """
    hist_source, _ = np.histogram(source.ravel(), bins=bins, range=(0, bins))
    hist_template, _ = np.histogram(template.ravel(), bins=bins, range=(0, bins))

    template_total = hist_template.sum()
    if template_total == 0:
        raise ValueError("template has no pixel values inside the histogram range [0, bins]")

    cdf_source = hist_source.cumsum() / hist_source.sum()
    cdf_template = hist_template.cumsum() / template_total

    # For each source level, the first template level whose CDF reaches it.
    # Both CDFs are non-decreasing, so this matches a two-pointer scan but runs
    # inside numpy instead of a Python loop over every bin.
    lookup_table = np.searchsorted(cdf_template, cdf_source, side="left")
    # Never point past the last bin, even under floating-point round-off.
    np.minimum(lookup_table, bins - 1, out=lookup_table)
    # Keep uint16 output where it fits; widen only when a value would overflow it.
    out_dtype = np.promote_types(np.uint16, np.min_scalar_type(int(lookup_table.max())))
    lookup_table = lookup_table.astype(out_dtype)

    matched = lookup_table[source]
    return matched

In [4]:
# Load the Stage 2 / Position 12 time-lapse; the last axis holds the channels (split in the next cell).
# NOTE(review): hard-coded relative path — consider a DATA_DIR constant in a config cell.
stage_2_pos_12 = io.imread("transfer_187559_files_94515bab/lifeact_myosin_rgbd7_w15TIRF-GFP_s1_t1.TIF_-_Stage2__Position_12_.tiff")

In [5]:


# Split the (t, y, x, channel) stack into its three channels.
myosin = stage_2_pos_12[..., 0]
rGBD = stage_2_pos_12[..., 1]
actin = stage_2_pos_12[..., 2]

# The first frame of each channel is the reference every frame is matched to.
reference_myosin = myosin[0]
reference_rGBD = rGBD[0]
reference_actin = actin[0]

# Unit-width histogram bins for matching; must exceed the largest pixel value,
# because match_histogram indexes its lookup table with the raw pixel values.
HIST_MATCH_BINS = 100000


def match_stack_to_reference(stack, reference, desc=None):
    """Histogram-match every frame of ``stack`` to ``reference`` and return the stacked result."""
    return np.stack(
        [match_histogram(frame, reference, bins=HIST_MATCH_BINS) for frame in tqdm(stack, desc=desc)]
    )


# Histogram-matched ("bl") stacks, one per channel.
myosin_bl = match_stack_to_reference(myosin, reference_myosin, desc="myosin")
rGBD_bl = match_stack_to_reference(rGBD, reference_rGBD, desc="rGBD")
actin_bl = match_stack_to_reference(actin, reference_actin, desc="actin")

  0%|          | 0/721 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

  0%|          | 0/721 [00:00<?, ?it/s]

## rGBD channel after histogram matching

In [6]:
# Show the raw rGBD channel and its histogram-matched version in one viewer.
# napari names each layer after the variable passed in (output: 'rGBD_bl').
viewer = napari.Viewer()
viewer.add_image(rGBD)
viewer.add_image(rGBD_bl)

<Image layer 'rGBD_bl' at 0x7fa71ce3c5b0>

## Bin rGBD channel

In [7]:
# Blockwise median over (1, 2, 2) blocks: time is untouched, y and x are halved
# (see the printed shapes below).
# NOTE(review): this bins the raw `rGBD`, not the histogram-matched `rGBD_bl`
# computed above — confirm that is intended.
binned_rGBD = blockwise_median(rGBD, (1, 2, 2))
print(rGBD.shape, binned_rGBD.shape, sep="\n")

(721, 540, 540)
(721, 270, 270)


In [8]:
# Compare the 2x2-binned rGBD stack with the full-resolution raw stack in one viewer.
viewer = napari.Viewer()
viewer.add_image(binned_rGBD)
viewer.add_image(rGBD)

<Image layer 'rGBD' at 0x7fa71cb77c40>

## Remove Background

In [9]:
# Remove the background from the binned rGBD stack with arcos4py's
# remove_image_background (size=(20, 20, 20), gaussian filter) and inspect the result.
# NOTE(review): presumably a gaussian background estimate of that size over (t, y, x)
# is subtracted — confirm against the arcos4py documentation.
rgbd_bg = remove_image_background(binned_rGBD, size=(20, 20, 20), filter_type="gaussian")
viewer = napari.Viewer()
viewer.add_image(rgbd_bg)

<Image layer 'rgbd_bg' at 0x7fa65922e980>

In [17]:
# Inspect event tracking on the background-subtracted stack thresholded at > 10.
# NOTE(review): this cell ran as In [17], after In [12]-[16]; the tracking result is
# not stored (so the layer is just named 'Labels'), and the threshold/eps/minClSz/nPrev
# values are repeated in the later tracking cell — consider shared config constants.
viewer = napari.Viewer()
viewer.add_image(binned_rGBD)
viewer.add_image(rgbd_bg)
viewer.add_labels(track_events_image(rgbd_bg > 10, eps=10, minClSz=50, predictor=True, nPrev=2))

100%|██████████| 721/721 [01:33<00:00,  7.72it/s]


<Labels layer 'Labels' at 0x7fa66910cd60>

In [12]:
from skimage.morphology import opening
from skimage.filters import gaussian

In [13]:
# Grey-level opening (skimage default footprint) followed by a gaussian blur (sigma=1)
# of the background-subtracted stack.
# NOTE(review): `test` is a non-descriptive name; later cells threshold it and use it as a
# napari layer name, so rename them together.
# NOTE(review): skimage `gaussian` converts integer input to floats in [0, 1] unless
# preserve_range=True; the `> 10` threshold downstream assumes `rgbd_bg` is already float — confirm.
test = gaussian(opening(rgbd_bg), sigma=1)

In [14]:
# Run arcos4py event tracking on the binary mask `test > 10` (same parameters as In [17]).
# NOTE(review): threshold 10, eps=10, minClSz=50, nPrev=2 are magic numbers; consider config constants.
tracked_events_rgbd = track_events_image(test > 10, eps=10, minClSz=50, predictor=True, nPrev=2)

100%|██████████| 721/721 [01:09<00:00, 10.43it/s]


In [16]:
# Overlay the tracked rGBD events on the smoothed stack they were detected from.
viewer = napari.Viewer()
viewer.add_image(test)
viewer.add_labels(tracked_events_rgbd)

<Labels layer 'tracked_events_rgbd' at 0x7fa65c020700>