In [None]:
!pip install aicsimageio[nd2]
!pip install xlsxwriter
!pip install reportlab

In [None]:
from pathlib import Path
import napari
from napari.settings import get_settings
import pandas as pd
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D
import SimpleITK as sitk
import skimage

from scipy import ndimage as ndi
from scipy.ndimage import label, zoom, binary_dilation, generate_binary_structure
from scipy.stats import gaussian_kde
from skimage.segmentation import watershed, relabel_sequential, expand_labels
from skimage.draw import line as draw_line
from vispy.color import Colormap
from matplotlib.colors import to_rgb
from csbdeep.utils import normalize
from stardist.models import StarDist2D
from collections import defaultdict
from aicsimageio import AICSImage
from nd2reader import ND2Reader
import meshlib.mrmeshpy as mr
import meshlib.mrmeshnumpy as mrn
import meshio
import statistics as st
import tetgen

from skimage import filters, morphology
from skimage.feature import peak_local_max
from skimage.morphology import ball
from skimage.filters import threshold_otsu, threshold_sauvola
from skimage.measure import find_contours, regionprops

from PIL import Image as PILImage
from reportlab.platypus import Image as RLImage
from reportlab.platypus import (
    SimpleDocTemplate, Image, Paragraph, Spacer, Table, TableStyle, PageBreak
)
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import inch
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib import colors
import xlsxwriter

# Enable napari's interactive IPython integration for viewers created in this
# notebook. The display-setup cell later switches it to False before calling
# napari.run(); `settings` is reused there.
settings = get_settings()
settings.application.ipy_interactive = True

#### Functions

In [None]:
# Image processing and utility functions

def hist_plot(im_in, stain_complete_df, thresh=0, legend=False):
    """Plot a log-scaled intensity histogram and a scaled CDF for each channel.

    Parameters
    ----------
    im_in : np.ndarray
        4D image indexed as ``im_in[:, :, :, c]``. Intensities are binned over
        the fixed range [0, 256), so the data is expected to be 8-bit scaled.
    stain_complete_df : pandas.DataFrame
        One row per channel, in channel order. Row ``c`` supplies the panel
        title (its index label) and the bar colour (``'Color'`` column).
        ``'WHITE'`` is drawn as gray so it stays visible on a white background.
    thresh : float, optional
        If > 0, a vertical green line is drawn at this intensity.
    legend : bool, optional
        Show a 'cdf' / 'histogram' legend on every panel.
    """
    n_channels = im_in.shape[3]
    # squeeze=False keeps `axs` 2D, so a single-channel image still indexes
    # correctly (plt.subplots(1, 1) would otherwise return a bare Axes).
    _, axs = plt.subplots(1, n_channels, figsize=(15, 2), squeeze=False)
    for c, ax in enumerate(axs[0]):
        values = im_in[:, :, :, c].flatten()
        hist, _ = np.histogram(values, 256, [0, 256])
        cdf = hist.cumsum()
        # Rescale the CDF to the histogram peak so both share one y-axis.
        # Skip the rescale when no value falls in [0, 256) (it would divide by 0).
        if cdf.max() > 0:
            cdf_normalized = cdf * hist.max() / cdf.max()
        else:
            cdf_normalized = cdf.astype(float)
        color = stain_complete_df.loc[stain_complete_df.index[c], 'Color']
        ax.plot(cdf_normalized, color='b')
        ax.hist(values, 256, [0, 256], color=color if color != 'WHITE' else 'GRAY')
        ax.set_xlim([0, 256])
        if legend:
            ax.legend(('cdf', 'histogram'), loc='upper left')
        if thresh > 0:
            ax.plot([thresh, thresh], [0, cdf_normalized.max()], color='g')
        ax.set_title(stain_complete_df.index[c])
        ax.set_yscale('log')

def napari_contrast_gamma_uint8(image, contrast_limits, gamma):
    """
    Apply Napari-style contrast limits + gamma correction and return uint8.

    The mapping is ``clip((x - clim_min) / (clim_max - clim_min), 0, 1) ** gamma``,
    then scaled to 0-255 and rounded.

    Parameters
    ----------
    image : array-like
        Input image (any numeric dtype); converted to float32.
    contrast_limits : tuple (clim_min, clim_max)
        Same values as shown in the Napari GUI.
    gamma : float
        Gamma value from the Napari GUI.

    Returns
    -------
    out_uint8 : np.ndarray (uint8)
        Display-mapped image in 0-255.

    Notes
    -----
    If ``clim_min == clim_max`` the linear ramp is undefined (division by
    zero would give NaN, and NaN cast to uint8 is garbage). Its limit is used
    instead: values above the limit map to 255 and all others to 0.
    """
    clim_min, clim_max = contrast_limits

    img = np.asarray(image, dtype=np.float32)

    span = clim_max - clim_min
    if span == 0:
        # Degenerate limits: step function (limit of the ramp as max -> min).
        img = (img > clim_min).astype(np.float32)
    else:
        # Napari contrast normalisation.
        img = np.clip((img - clim_min) / span, 0.0, 1.0)

    # Napari gamma.
    img = img ** gamma

    # Display range [0, 1] -> uint8 [0, 255].
    return (img * 255).round().astype(np.uint8)

def remove_small_islands(binary_matrix, area_threshold):
    """Zero out connected components smaller than ``area_threshold`` pixels.

    Components are found with ``scipy.ndimage.label`` using its default
    (face-connected) structure. The input array is modified IN PLACE and also
    returned.

    Parameters
    ----------
    binary_matrix : np.ndarray
        Binary mask; non-zero entries are foreground.
    area_threshold : int or float
        Components with fewer pixels than this are set to 0.

    Returns
    -------
    np.ndarray
        The same ``binary_matrix`` object, with small components removed.
    """
    labeled_array, num_features = label(binary_matrix)
    if num_features == 0:
        return binary_matrix
    # One pass over the image instead of one full-array comparison per
    # component (which was O(pixels * components)).
    sizes = np.bincount(labeled_array.ravel())
    too_small = sizes < area_threshold
    too_small[0] = False  # never touch the background label
    binary_matrix[too_small[labeled_array]] = 0
    return binary_matrix

def stardist3d_from_2d(
    img_3d,
    model_name="2D_versatile_fluo",
    nucleus_radius=5,
    voxel_size=(1.0, 0.5, 0.5),
    norm=True,
):
    """
    Apply StarDist2D slice-by-slice to a 3D stack, merge predictions,
    and split weakly connected nuclei using distance-based watershed.
    Handles anisotropic voxel spacing.

    Steps:
    1. Predict 2D instances on every z-slice (optionally percentile-normalised).
    2. Take the union of all detections and label its face-connected 3D components.
    3. Compute a distance transform using the physical voxel spacing.
    4. Seed one marker per local maximum and run a watershed inside the foreground.

    Parameters
    ----------
    img_3d : np.ndarray
        Input 3D grayscale image, shape (Z, Y, X).
    model_name : str
        Name of pretrained StarDist2D model.
    nucleus_radius : float
        Approximate nucleus radius in XY pixels. It is truncated to an int
        (minimum 1) and used as the XY width of the peak-search footprint.
    voxel_size : tuple(float)
        Physical voxel size as (z_spacing, y_spacing, x_spacing).
    norm : bool
        Normalize each 2D slice (1st-99.8th percentile) before prediction.

    Returns
    -------
    labels_split : np.ndarray
        3D label array from the watershed, same shape as input.
    """

    assert img_3d.ndim == 3, "Input must be 3D (Z, Y, X)"
    z_spacing, y_spacing, x_spacing = voxel_size

    print(f"Running StarDist2D on {img_3d.shape[0]} z-slices...")
    model = StarDist2D.from_pretrained(model_name)

    # Only the union of per-slice detections is needed: the 2D instance ids
    # are discarded by the 3D relabelling below, so no cross-slice id offset
    # is kept.
    foreground = np.zeros(img_3d.shape, dtype=bool)
    for z in range(img_3d.shape[0]):
        img = img_3d[z]
        if norm:
            img = normalize(img, 1, 99.8, axis=None)
        labels2d, _ = model.predict_instances(img)
        foreground[z] = labels2d > 0

    # Merge objects touching across slices into 3D components.
    labels_3d = skimage.measure.label(foreground, connectivity=1)
    mask = labels_3d > 0

    # --- Anisotropic distance-based splitting ---
    print("Computing distance transform with anisotropic voxel spacing...")
    distance = ndi.distance_transform_edt(mask, sampling=voxel_size)

    # int() of a radius below 1 would give a zero-width (empty) footprint.
    xy_size = max(1, int(nucleus_radius))
    # NOTE(review): the z extent scales with z_spacing / y_spacing rather than
    # nucleus_radius * y_spacing / z_spacing, so its physical size is not
    # matched to the XY extent. Confirm this is intended.
    footprint = np.ones(
        (max(1, int(z_spacing / y_spacing)), xy_size, xy_size),
        dtype=bool,
    )

    local_max = peak_local_max(
        distance,
        footprint=footprint,
        labels=mask,
        exclude_border=False,
    )

    # One watershed seed per local maximum, numbered 1..N.
    markers = np.zeros_like(labels_3d, dtype=int)
    for i, coord in enumerate(local_max, start=1):
        markers[tuple(coord)] = i

    print("Running 3D watershed to split connected nuclei...")
    labels_split = watershed(-distance, markers, mask=mask)

    print(f"Done. Found {labels_split.max()} nuclei.")
    return labels_split

def make_anisotropic_footprint(radius_Z, radius_Y, radius_X):
    """Boolean ellipsoid structuring element with per-axis radii (in voxels).

    Returns an array of shape ``(2*radius_Z+1, 2*radius_Y+1, 2*radius_X+1)``
    that is True where ``(z/radius_Z)**2 + (y/radius_Y)**2 + (x/radius_X)**2 <= 1``.

    A radius of 0 collapses that axis to a single plane (e.g. ``radius_Z=0``
    gives a flat 2D ellipse). Previously it produced 0/0 = NaN and an
    all-False footprint.
    """
    zz, yy, xx = np.ogrid[
        -radius_Z:radius_Z + 1,
        -radius_Y:radius_Y + 1,
        -radius_X:radius_X + 1,
    ]

    def _term(coord, radius):
        # With radius 0 the only coordinate is 0, so the axis contributes 0.
        return (coord / radius) ** 2 if radius else np.zeros(coord.shape)

    return (_term(zz, radius_Z) + _term(yy, radius_Y) + _term(xx, radius_X)) <= 1

def voxel_volume(ri_x, ri_y, ri_z, zooms):
    """Return ``ri_x * ri_y * ri_z`` divided by the product of ``zooms``.

    Presumably the ``ri_*`` arguments are per-axis pixel sizes and ``zooms``
    the resampling factors, giving the voxel volume after resampling.
    Confirm against callers.
    """
    product_of_sizes = ri_x * ri_y * ri_z
    return product_of_sizes / np.prod(zooms)

def save_raw_png(arr, filename, contrast_limits=None, gamma=None):
    """
    Save a 2D array to PNG, optionally with Napari-style contrast/gamma.

    Parameters
    ----------
    arr : array-like
        2D image.
    filename : str or path
        Output path.
    contrast_limits : tuple (min, max), optional
        If given, the image is mapped with ``napari_contrast_gamma_uint8``
        (so it matches the Napari display) and saved as uint8.
    gamma : float, optional
        Gamma applied after the contrast mapping (default 1.0).

    Without ``contrast_limits`` (or if the mapping fails, which now emits a
    warning instead of failing silently), the fallback is:

    - uint8 / uint16: saved unchanged.
    - bool: False -> 0, True -> 255 (uint8).
    - float: scaled by its maximum to 0-255 (max <= 255) or 0-65535.
      A non-positive maximum gives an all-black uint8 image.
    - other integers: negatives clamped to 0, then cast to uint8 / uint16,
      or rescaled to 0-65535 if the maximum exceeds 65535.

    Returns
    -------
    The ``filename`` that was written.

    Raises
    ------
    ValueError
        For dtypes that are none of the above (e.g. complex).
    """
    import warnings

    arr = np.asarray(arr)

    def _write(data):
        PILImage.fromarray(data).save(filename)
        return filename

    # Napari-style mapping when requested.
    if contrast_limits is not None:
        g = 1.0 if gamma is None else float(gamma)
        try:
            out = napari_contrast_gamma_uint8(
                arr.astype(np.float32),
                (float(contrast_limits[0]), float(contrast_limits[1])),
                g,
            )
            return _write(out)
        except Exception as exc:
            warnings.warn(
                f"save_raw_png: contrast/gamma mapping failed ({exc!r}); "
                f"saving {filename} with the raw fallback instead"
            )

    # --- Fallback / legacy behaviour ---
    if arr.dtype == np.uint8 or arr.dtype == np.uint16:
        return _write(arr)

    if arr.dtype == np.bool_:
        return _write(arr.astype(np.uint8) * 255)

    if np.issubdtype(arr.dtype, np.floating):
        maxv = float(arr.max()) if arr.size else 0.0
        # Dividing by a negative maximum would invert the image; treat it as empty.
        if maxv <= 0:
            return _write(np.zeros_like(arr, dtype=np.uint8))
        if maxv <= 255:
            return _write(np.clip(arr / maxv * 255.0, 0, 255).astype(np.uint8))
        return _write(np.clip(arr / maxv * 65535.0, 0, 65535).astype(np.uint16))

    if np.issubdtype(arr.dtype, np.integer):
        # PNG has no signed modes: clamp negatives to 0 instead of letting the
        # unsigned cast wrap them around to bright values.
        arr = np.clip(arr, 0, None)
        maxv = int(arr.max()) if arr.size else 0
        if maxv <= 255:
            return _write(arr.astype(np.uint8))
        if maxv <= 65535:
            return _write(arr.astype(np.uint16))
        return _write((arr / maxv * 65535).astype(np.uint16))

    raise ValueError("Unsupported dtype for PNG saving.")

def crop_nucleus_with_padding(nucleus_mask, full_img_stack, pad=20):
    """Crop every channel around a nucleus, on the z-slice where it is largest.

    Parameters
    ----------
    nucleus_mask : np.ndarray of bool
        3D (Z, Y, X) or 2D (Y, X) mask of a single nucleus.
    full_img_stack : dict
        ``{condition: 3D array}``; each array is indexed ``[z, y, x]``.
    pad : int
        Margin in pixels added on every side of the nucleus bounding box,
        clamped to the image borders.

    Returns
    -------
    crop_dict : dict or None
        ``{condition: 2D float crop}`` (original intensities, not
        normalised), or ``None`` if the selected mask slice is empty.
    best_z : int
        Index of the z-slice with the most mask pixels (0 for a 2D mask).
    origin : tuple(int, int) or None
        ``(y_min, x_min)`` of the crop in full-image coordinates.
    size : tuple(int, int) or None
        Smallest ``(height, width)`` over all crops, so callers can trim
        every crop to a common shape.
    """
    if nucleus_mask.ndim == 3:
        best_z = int(np.argmax(nucleus_mask.sum(axis=(1, 2))))
        nuc2d = nucleus_mask[best_z]
    else:
        best_z = 0
        nuc2d = nucleus_mask

    ys, xs = np.nonzero(nuc2d)
    if xs.size == 0:
        return None, best_z, None, None

    y_min = max(0, int(ys.min()) - pad)
    x_min = max(0, int(xs.min()) - pad)
    # Slice stops are exclusive: +1 keeps `pad` pixels below/right of the last
    # mask pixel, matching the margin above/left (it used to be pad - 1).
    y_stop = int(ys.max()) + pad + 1
    x_stop = int(xs.max()) + pad + 1

    crop_dict = {}
    for cond, img3d in full_img_stack.items():
        _, h_full, w_full = img3d.shape
        crop_dict[cond] = img3d[
            best_z, y_min:min(y_stop, h_full), x_min:min(x_stop, w_full)
        ].astype(float)

    min_h = min((crop.shape[0] for crop in crop_dict.values()), default=0)
    min_w = min((crop.shape[1] for crop in crop_dict.values()), default=0)
    return crop_dict, best_z, (y_min, x_min), (int(min_h), int(min_w))

def save_merged_figure(
    nucleus_mask, full_img_stack, condition_colors, nucleus_id,
    seg_stack,
    nucleus_color='blue', cytoplasm_color='green', pcm_color='magenta',
    pad=20, out_dir="merged_png"
):
    """Save an RGB overview PNG of one nucleus and return its path.

    The image is cropped around the nucleus with ``crop_nucleus_with_padding``
    (on the z-slice where the nucleus is largest). It is built additively from:

    1. The nucleus mask filled with ``nucleus_color`` at 20 % intensity.
    2. A white, 20 % dashed outline of the union of this cell's
       ``seg_stack['Cytoplasm']`` and ``seg_stack['PCM']`` regions (pixels
       equal to ``nucleus_id``), for whichever of those keys is present.
    3. Every channel of ``full_img_stack``, tinted with
       ``condition_colors[cond]`` (default gray) at 60 % intensity.
       Channels use the contrast limits / gamma from the module-level
       ``stain_complete_df`` when it has ``Cont_min`` and a row for ``cond``;
       otherwise they are max-normalised.

    Parameters
    ----------
    nucleus_mask : np.ndarray of bool
        2D or 3D mask of the nucleus.
    full_img_stack : dict
        ``{condition: 3D image}`` indexed ``[z, y, x]``.
    condition_colors : dict
        ``{condition: matplotlib colour}``.
    nucleus_id : int
        Label of this cell in the ``seg_stack`` label images.
    seg_stack : dict
        May contain ``'Cytoplasm'`` and/or ``'PCM'`` label images with the
        same shape as ``nucleus_mask``.
    nucleus_color : matplotlib colour
        Fill colour of the nucleus (default 'blue').
    cytoplasm_color, pcm_color :
        Currently unused (kept for API compatibility); the outline is white.
    pad : int
        Margin around the nucleus bounding box, in pixels.
    out_dir : str
        Output directory (created if missing).

    Returns
    -------
    str or None
        ``out_dir/n{nucleus_id}_merged.png``, or ``None`` if the nucleus is
        empty on its best slice or the crop has zero size.
    """
    os.makedirs(out_dir, exist_ok=True)

    crop_dict, best_z, origin, size = crop_nucleus_with_padding(
        nucleus_mask, full_img_stack, pad=pad
    )
    # For an empty nucleus the helper returns (None, best_z, None, None).
    # Unpacking origin/size into pairs before this check used to raise TypeError.
    if crop_dict is None or size is None or size[0] <= 0 or size[1] <= 0:
        return None
    (y0, x0), (min_H, min_W) = origin, size
    window = (slice(y0, y0 + min_H), slice(x0, x0 + min_W))

    def _best_slice(mask):
        # 3D masks are shown on the same z-slice as the channel crops.
        return mask[best_z] if mask.ndim == 3 else mask

    structure_opacity = 0.2
    merged_rgb = np.zeros((min_H, min_W, 3), dtype=float)

    # 1. Nucleus fill.
    nuc_crop = _best_slice(nucleus_mask)[window].astype(float)
    nucleus_rgb = np.array(mcolors.to_rgb(nucleus_color))
    merged_rgb += nuc_crop[..., None] * nucleus_rgb * structure_opacity

    # 2. Dashed white outline of cytoplasm ∪ PCM for this cell.
    empty = np.zeros_like(nucleus_mask)
    combined_mask = (
        (seg_stack.get('Cytoplasm', empty) == nucleus_id)
        | (seg_stack.get('PCM', empty) == nucleus_id)
    )
    if np.any(combined_mask):
        combined_crop = _best_slice(combined_mask)[window].astype(float)
        outline_rgb = np.array([1.0, 1.0, 1.0]) * structure_opacity
        contour_rgb = np.zeros((min_H, min_W, 3), dtype=float)
        for contour in find_contours(combined_crop, 0.5, fully_connected='low'):
            # Map crop coordinates onto the (min_H, min_W) canvas.
            contour[:, 0] *= min_H / combined_crop.shape[0]
            contour[:, 1] *= min_W / combined_crop.shape[1]
            pts = contour.astype(int)
            # Draw segments between points (0,1), (2,3), ...; the skipped
            # pairs between them form the gaps of the dash pattern.
            for i in range(0, len(pts) - 1, 2):
                rr, cc = draw_line(int(pts[i, 0]), int(pts[i, 1]),
                                   int(pts[i + 1, 0]), int(pts[i + 1, 1]))
                contour_rgb[rr, cc] = outline_rgb
        merged_rgb += contour_rgb

    # 3. Marker channels, tinted and added at 60 % intensity.
    has_display_settings = 'Cont_min' in stain_complete_df.columns
    for cond, img in crop_dict.items():
        img_small = img[:min_H, :min_W]
        img_normalized = None
        if has_display_settings and cond in stain_complete_df.index:
            try:
                clim = (float(stain_complete_df.loc[cond, 'Cont_min']),
                        float(stain_complete_df.loc[cond, 'Cont_max']))
                gamma = (stain_complete_df.loc[cond, 'Gamma']
                         if 'Gamma' in stain_complete_df.columns else 1.0)
                img_display = napari_contrast_gamma_uint8(
                    img_small.astype(np.float32), clim, float(gamma)
                )
                img_normalized = img_display.astype(float) / 255.0
            except Exception:
                # Unusable settings row (e.g. duplicated index, non-numeric
                # value): fall back to max-normalisation below.
                img_normalized = None
        if img_normalized is None:
            img_normalized = img_small / (img_small.max() + 1e-6)

        color = np.array(mcolors.to_rgb(condition_colors.get(cond, 'gray')))
        merged_rgb += img_normalized[..., None] * color * 0.6

    merged_uint8 = (np.clip(merged_rgb, 0, 1.0) * 255).astype(np.uint8)

    fname = os.path.join(out_dir, f"n{nucleus_id}_merged.png")
    PILImage.fromarray(merged_uint8).save(fname)
    return fname

def double_plateau_hist_equalization_nd(
    img: np.ndarray,
    num_plateaus: int = 2,
    plateau_factor: float = 0.5
) -> np.ndarray:
    """
    Multi-plateau histogram equalization for uint8 images or volumes.

    Dispatch by shape:

    - 2D (H, W), or 3D whose last axis is not 3 (treated as a (Z, H, W)
      volume): one histogram over all values, equalized together.
    - 3D with last axis 3: treated as a colour image (BGR, as OpenCV
      expects). Only the luma (Y of YCrCb) is equalized. Note that a volume
      whose last axis happens to be 3 also takes this path.

    Parameters
    ----------
    img : np.ndarray
        uint8 input.
    num_plateaus : int
        Number of plateau levels (2 = double plateau).
    plateau_factor : float
        Scales the plateau levels relative to the mean bin count.

    Returns
    -------
    np.ndarray
        Equalized array with the same shape.

    Raises
    ------
    ValueError
        If ``img`` is not uint8 or has an unsupported number of dimensions.
    """
    if img.dtype != np.uint8:
        raise ValueError("Input must be uint8")

    is_color = img.ndim == 3 and img.shape[-1] == 3

    if img.ndim == 2 or (img.ndim == 3 and not is_color):
        return _mphe_channel(img, num_plateaus, plateau_factor)

    if is_color:
        luma, cr, cb = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb))
        luma_eq = _mphe_channel(luma, num_plateaus, plateau_factor)
        return cv2.cvtColor(cv2.merge([luma_eq, cr, cb]), cv2.COLOR_YCrCb2BGR)

    raise ValueError("Unsupported input shape")


def _mphe_flat(
    flat: np.ndarray,
    num_plateaus: int,
    plateau_factor: float
) -> np.ndarray:
    """Multi-plateau HE on a flat uint8 array."""
    # Compute histogram
    hist = np.bincount(flat, minlength=256).astype(np.float64)
    total_pixels = flat.size

    mean_count = total_pixels / 256.0
    base_plateau = plateau_factor * mean_count

    plateau_levels = np.linspace(
        base_plateau * 0.5,
        base_plateau * (0.5 + num_plateaus),
        num_plateaus
    )

    clipped_hist = hist.copy()

    for p in plateau_levels:
        excess = np.maximum(clipped_hist - p, 0)
        clipped_hist = np.minimum(clipped_hist, p)
        redistributed = excess.sum() / 256.0
        clipped_hist += redistributed

    cdf = np.cumsum(clipped_hist)
    cdf_norm = cdf / cdf[-1]
    lut = np.floor(255 * cdf_norm).astype(np.uint8)

    return lut[flat]


def _mphe_channel(channel: np.ndarray,
                  num_plateaus: int,
                  plateau_factor: float) -> np.ndarray:
    """Run ``_mphe_flat`` on an array of any shape and restore that shape."""
    return _mphe_flat(channel.ravel(), num_plateaus, plateau_factor).reshape(channel.shape)

In [None]:
# Unused watershed functions removed
# The main nuclei segmentation is performed using shrink_to_markers_robust in the segmentation section

In [None]:
# Marker extraction: shrink each island to a seed, optionally merging small touching seeds

def shrink_to_markers(binary_3d, connectivity=2, max_iter=100,
                      min_final_size=3, max_final_cc=3,
                      merge_small_touching=False, size_ratio_thresh=0.5,
                      min_cell_size=20.0):
    """
    Erode each connected island of a 3D mask down to a seed (marker) region.

    Every connected component of ``binary_3d`` is eroded repeatedly with
    ``ndi.generate_binary_structure(3, connectivity)``. Erosion stops at the
    first of these conditions, keeping the last accepted state:

    - the island has ``min_final_size`` voxels or fewer;
    - ``max_iter`` erosions have been applied;
    - the next erosion would delete it entirely;
    - the next erosion would split it into more than ``max_final_cc`` pieces.

    Parameters
    ----------
    binary_3d : array-like
        3D foreground mask (cast to bool).
    connectivity : int
        Connectivity for labelling and erosion.
    max_iter, min_final_size, max_final_cc : int
        Stopping criteria described above.
    merge_small_touching : bool
        If True, touching marker islands are merged with
        ``remove_small_island_labels``. An island merges into its biggest
        neighbour when its size is at most ``size_ratio_thresh`` times that
        neighbour's, or below ``min_cell_size`` voxels.
    size_ratio_thresh : float
        Size-ratio merge criterion.
    min_cell_size : float
        Absolute merge criterion. Defaults to the previously hard-coded 20.

    Returns
    -------
    marker_image : np.ndarray of bool
        Union of all marker regions.
    marker_labels : np.ndarray of int
        Marker labels 1..N (after merging, if requested).
    """
    binary_3d = np.asarray(binary_3d, dtype=bool)
    selem = ndi.generate_binary_structure(3, connectivity)
    cc_labels, _ = ndi.label(binary_3d, structure=selem)
    marker_image = np.zeros_like(binary_3d, dtype=bool)

    # --- shrink step ---
    # Work inside each component's bounding box instead of the full volume.
    # Voxels outside the box never belong to the component, and
    # binary_erosion treats out-of-array voxels as background (border_value=0),
    # so the result is identical but much cheaper when there are many islands.
    for cc_id, bbox in enumerate(ndi.find_objects(cc_labels), start=1):
        if bbox is None:
            continue
        current = cc_labels[bbox] == cc_id
        for _ in range(max_iter):
            if current.sum() <= min_final_size:
                break
            eroded = ndi.binary_erosion(current, structure=selem)
            if not eroded.any():
                break  # erosion would delete the island: keep the last state
            if ndi.label(eroded, structure=selem)[1] > max_final_cc:
                break  # erosion would fragment it: keep the pre-split state
            current = eroded
        marker_image[bbox] |= current

    marker_labels, n_markers = ndi.label(marker_image, structure=selem)
    if not merge_small_touching or n_markers <= 1:
        return marker_image, marker_labels

    # --- size-based merging of touching markers ---
    # Distinct labels from ndi.label can only "touch" through contacts that
    # the label structure does not count (e.g. voxel corners when
    # connectivity=2). Re-running ndi.label on the merged mask therefore
    # split every merge apart again. Compact the label values instead.
    _, merged = remove_small_island_labels(
        marker_labels,
        connectivity=connectivity,
        size_ratio_thresh=size_ratio_thresh,
        min_cell_size=min_cell_size,
    )
    present = np.unique(merged[merged > 0])
    lut = np.zeros(int(merged.max()) + 1, dtype=marker_labels.dtype)
    lut[present] = np.arange(1, present.size + 1)
    final_labels = lut[merged]
    return final_labels > 0, final_labels

def shrink_to_markers_robust(binary_3d, min_marker_size=1, size_ratio_thresh=0.4, **kwargs):
    """Run ``shrink_to_markers`` (connectivity 2), drop tiny markers, relabel.

    Parameters
    ----------
    binary_3d : array-like
        3D foreground mask.
    min_marker_size : int
        Markers with fewer voxels than this are discarded.
    size_ratio_thresh : float
        Forwarded to ``shrink_to_markers``; only matters when
        ``merge_small_touching=True`` is passed in ``kwargs``. It used to be
        accepted but silently ignored.
    **kwargs
        Other ``shrink_to_markers`` options. ``connectivity`` is fixed to 2
        and must not be passed.

    Returns
    -------
    filtered : np.ndarray of bool
        Mask of the kept markers.
    final_labels : np.ndarray of int
        ``ndi.label`` of ``filtered`` with its default face (6-)connectivity.
        Markers connected only through edges or corners therefore get
        separate labels here.
    """
    _, markers_lab = shrink_to_markers(
        binary_3d, connectivity=2, size_ratio_thresh=size_ratio_thresh, **kwargs
    )

    # Absolute size filter (index 0 of bincount is the background).
    sizes = np.bincount(markers_lab.ravel())[1:]
    keep_labels = np.flatnonzero(sizes >= min_marker_size) + 1

    filtered = np.isin(markers_lab, keep_labels)
    final_labels, _ = ndi.label(filtered)

    return filtered, final_labels

def remove_small_island_labels(marker_labels, connectivity=1, size_ratio_thresh=0.5, min_cell_size=5.0):
    """Merge small labelled islands into the biggest label they touch.

    A label L that touches at least one other label merges into its biggest
    touching neighbour B when ``size(L) <= size_ratio_thresh * size(B)`` or
    ``size(L) < min_cell_size``. Ties between equally sized neighbours go to
    the smaller label id.

    Merges are transitive. If A merges into B and B into C, all three become
    one region, which keeps the id of its largest member (smallest id on
    ties). Previously A kept B's old id while B itself moved to C.

    Parameters
    ----------
    marker_labels : np.ndarray of int
        3D label image; 0 is background.
    connectivity : int
        1 -> only the 6 face neighbours count as touching; any other value ->
        all 26 neighbours.
    size_ratio_thresh : float
        Relative size criterion.
    min_cell_size : float
        Absolute size criterion, in voxels.

    Returns
    -------
    final_markers : np.ndarray of bool
        Foreground mask (merging does not change it).
    relabeled : np.ndarray
        Label image after merging, same dtype as the input. Label values are
        NOT renumbered to be consecutive.
    """
    marker_labels = np.asarray(marker_labels)
    labels, counts = np.unique(marker_labels, return_counts=True)
    sizes = {int(lab): int(cnt) for lab, cnt in zip(labels, counts) if lab != 0}
    if not sizes:
        return marker_labels > 0, marker_labels.copy()

    # --- which labels touch which (vectorised over the volume) ---
    offsets = [
        (dz, dy, dx)
        for dz in (-1, 0, 1) for dy in (-1, 0, 1) for dx in (-1, 0, 1)
        if (dz, dy, dx) != (0, 0, 0)
        and (connectivity != 1 or abs(dz) + abs(dy) + abs(dx) == 1)
    ]
    shape = marker_labels.shape
    neighbors = defaultdict(set)
    for off in offsets:
        # Pair every voxel p with p + off, keeping only pairs inside the array.
        src = tuple(slice(max(0, -d), n - max(0, d)) for d, n in zip(off, shape))
        dst = tuple(slice(max(0, d), n - max(0, -d)) for d, n in zip(off, shape))
        a, b = marker_labels[src], marker_labels[dst]
        touching = (a > 0) & (b > 0) & (a != b)
        for c, n in set(zip(a[touching].tolist(), b[touching].tolist())):
            neighbors[c].add(n)

    # --- merge decisions, combined with union-find ---
    parent = {lab: lab for lab in sizes}

    def _find(lab):
        while parent[lab] != lab:
            parent[lab] = parent[parent[lab]]
            lab = parent[lab]
        return lab

    for lab in sorted(neighbors):
        # sorted() makes the tie-break between equal-size neighbours deterministic.
        big = max(sorted(neighbors[lab]), key=lambda L: sizes.get(L, 0))
        size_big = sizes.get(big, 0)
        if size_big <= 0:
            continue
        size_lab = sizes[lab]
        if (size_lab <= size_ratio_thresh * size_big) or (size_lab < min_cell_size):
            root_lab, root_big = _find(lab), _find(big)
            if root_lab != root_big:
                parent[root_lab] = root_big

    # --- one id per merged group: its largest member ---
    groups = defaultdict(list)
    for lab in sizes:
        groups[_find(lab)].append(lab)
    lut = np.arange(int(marker_labels.max()) + 1, dtype=marker_labels.dtype)
    for members in groups.values():
        lut[members] = max(members, key=lambda L: (sizes[L], -L))
    relabeled = lut[marker_labels]

    return relabeled > 0, relabeled

In [None]:
# Additional functions are defined in the Functions cell above

# INPUTS

### File upload

In [None]:
# Input ND2 file and loading mode. With big_image=True the file is read lazily
# (dask) and only the ROI below is loaded into memory.
input_file = 'GELMA1.nd2'
big_image = True

# Crop box in pixels, ordered [x0, x1, y0, y1, z0, z1]. An upper bound of 0
# keeps the full extent of that axis.
ROI = [0] * 6

In [None]:
meta = AICSImage(input_file)
if big_image:
    x0, x1, y0, y1, z0, z1 = ROI
    # meta.shape indices 2/3/4 are Z/Y/X; an upper bound of 0 means full extent.
    x1 = x1 or meta.shape[4]
    y1 = y1 or meta.shape[3]
    z1 = z1 or meta.shape[2]
    # Lazy dask array in ZYXC order; only the cropped box is computed into memory.
    lazy = meta.get_image_dask_data("ZYXC")
    sub = lazy[z0:z1, y0:y1, x0:x1, :]
    img = sub.compute()
    # The crop is already applied, so later cells must not crop again.
    ROI = [0, 0, 0, 0, 0, 0]
else:
    # Whole first time point, also in ZYXC order.
    img = meta.get_image_data("ZYXC", T=0)

print(img.shape)

In [None]:
# Physical pixel sizes (um/px) reported by the file metadata.
r_X, r_Y, r_Z = (
    meta.physical_pixel_sizes.X,
    meta.physical_pixel_sizes.Y,
    meta.physical_pixel_sizes.Z,
)
print([r_X, r_Y, r_Z])

if not big_image:
    # Pixel dtype and bit depth (this reads the full image into memory).
    imdata = meta.get_image_data()
    imtype = imdata.dtype
    bdepth = imtype.itemsize * 8
    print(imtype)

# Acquisition date and channel names from the ND2 metadata. The `nd2` reader
# object is used again by the staining cell to get the channel order.
with ND2Reader(input_file) as nd2:
    print("Date:", nd2.metadata.get("date"))
    print("Channels:", nd2.metadata.get("channels"))

### Sample

In [None]:
# Expected object diameters in um.
nuclei_diameter = 10
cell_diameter = 30

# NOTE(review): scale factors consumed by later cells (not shown here).
cyto_factor = 3.0
PCM_factor = 4.0

# Condition -> [marker, ND2 channel name, display colour]. The channel name
# must match an entry of the file's channel list, which fixes the channel order.
stain_dict = dict(
    LIVE=['Calcein', '488_10x', 'Green'],
    DEAD=['EthD', '555_10x', 'Red'],
)

### Image

In [None]:
scale_factor = 0.5
# Per-axis resampling factors, all scaled by scale_factor.
# NOTE(review): labelled XYZ originally, but the resampling cell passes them to
# scipy's zoom on (Z, Y, X) arrays. This only matters if the factors differ.
zoom_factors = [scale_factor * f for f in (1.0, 1.0, 1.0)]

### Setup

In [None]:
# Base name of the saved display-settings file: f"{name_setup}_setup.csv".
name_setup = 'CARLOTTA_LD_30m'
use_setup = True        # reuse saved contrast/gamma settings if that CSV exists
trig_stardist = False   # True -> segment nuclei with the StarDist model
multilabel = True       # NOTE(review): consumed by later cells not shown here

## Derived sizes and ROI crop

In [None]:
# Idealised spherical sizes derived from the diameters (um and um^3).
nuclei_radius = nuclei_diameter * 0.5
cell_radius = cell_diameter * 0.5

nuclei_volume = np.ceil(4.0 * (nuclei_radius ** 3.0) * np.pi / 3.0)
cell_volume = np.ceil(4.0 * (cell_radius ** 3.0) * np.pi / 3.0)

x0, x1, y0, y1, z0, z1 = ROI

# img is ZYXC: shape[0]=Z, shape[1]=Y, shape[2]=X, shape[3]=C.
# An upper bound of 0 means "full extent".
if z1 == 0:
    z1 = img.shape[0]
if y1 == 0:
    y1 = img.shape[1]
if x1 == 0:
    x1 = img.shape[2]

if big_image:
    # Already cropped while loading (ROI was reset to zeros there).
    im_original = img.astype('float32')
    im_original_ROI = im_original.copy()
else:
    # Reuse the ZYXC array loaded by the loading cell instead of reading the
    # whole file from disk a second time.
    im_original = img.astype('float32')
    im_original_ROI = im_original[z0:z1, y0:y1, x0:x1, :]

im_final_stack = {'Original image': im_original_ROI}

In [None]:
if big_image:
    # The image should already be ZYXC; report its shape and drop a trailing
    # singleton axis if a 5D array slipped through.
    orig = im_final_stack['Original image']
    print('Original image shape (ZYXC):', orig.shape)
    if orig.ndim == 5 and orig.shape[-1] == 1:
        orig = orig[..., 0]
        im_final_stack['Original image'] = orig
        print('After removing singleton dimension:', orig.shape)

### Information about the staining

In [None]:
# Upper-case every stain entry and build the condition table.
stain_dict = {k.upper(): [item.upper() if isinstance(item, str) else item for item in v] for k, v in stain_dict.items()}
stain_df = pd.DataFrame.from_dict(stain_dict, orient='index', columns=['Marker', 'Laser', 'Color'])

# Channel names in acquisition order. `nd2` is the ND2Reader from the metadata
# cell; this relies on its parsed metadata still being readable after its
# `with` block closed the file.
laser_order = nd2.metadata.get("channels")

# Sort conditions into acquisition order so that row i matches image channel i.
order_map = {name.strip().upper(): i for i, name in enumerate(laser_order)}
stain_df['order'] = stain_df['Laser'].map(order_map)

# A laser missing from the file would silently sort last and shift the channel
# index of every later condition, so report it.
missing_lasers = stain_df.index[stain_df['order'].isna()].tolist()
if missing_lasers:
    print(f"Laser not found among the ND2 channels {list(order_map)}: {missing_lasers}")

stain_df = stain_df.sort_values('order').drop(columns='order')

stain_df.index.name = 'Condition'

if 'NUCLEI' not in stain_df.index:
    print('No nuclei condition!')

In [None]:
# Quick look at every channel, stretched to 0-255 and overlaid additively.
im_in = im_final_stack['Original image'].copy()

viewer_0 = napari.Viewer()
for c, c_name in enumerate(stain_df['Marker']):
    im_channel = im_in[:, :, :, c]
    lo, hi = im_channel.min(), im_channel.max()
    # A constant channel would divide by zero (NaN); show it as black instead.
    stretched = (im_channel - lo) / (hi - lo) * 255 if hi > lo else np.zeros_like(im_channel)
    im_8b = stretched.clip(0, 255).astype('uint8')

    # .iloc: the index holds condition names, so a plain [c] relies on
    # pandas' deprecated positional fallback for Series.__getitem__.
    viewer_0.add_image(im_8b, name=f"{stain_df.index[c]} ({c_name})",
                       colormap=stain_df['Color'].iloc[c], blending='additive')

viewer_0.scale_bar.visible = True
# NOTE(review): layers are added without a `scale`, so the bar presumably
# measures pixels despite the 'um' unit.
viewer_0.scale_bar.unit = 'um'

### Display settings setup (contrast limits and gamma per channel)

In [None]:
# Setup for acquisition and contrast/gamma settings
# Per-channel display settings (Cont_min, Cont_max, Gamma) end up in
# `stain_complete_df`, indexed by Condition. They are read from
# f"{name_setup}_setup.csv" when `use_setup` is True and the file covers every
# (Condition, Marker, Laser) row. Otherwise a napari viewer is opened for manual
# tuning, and the chosen values are written back to that CSV.
im_in=im_final_stack['Original image'].copy()

# Work on a (Condition, Marker, Laser) MultiIndex; defaults are 0-255, gamma 1.
stain_df = stain_df.reset_index(drop=False)
stain_initial_df = stain_df.copy()
stain_initial_df.set_index(['Condition', 'Marker', 'Laser'], inplace=True)
stain_initial_df[['Cont_min', 'Cont_max', 'Gamma']] = [0, 255, 1]
stain_complete_df=stain_initial_df.copy()

setup_path = f"{name_setup}_setup.csv"
if use_setup and os.path.exists(setup_path):
    stain_setup_df = pd.read_csv(setup_path)
    stain_setup_df.set_index(['Condition', 'Marker', 'Laser'], inplace=True)
    for idx in stain_complete_df.index:
        if idx in stain_setup_df.index:
            stain_complete_df.loc[idx] = stain_setup_df.loc[idx]
            # The CSV written below has no 'Color' column, so the row assignment
            # above presumably leaves Color empty; restore it from the defaults.
            stain_complete_df['Color'] = stain_initial_df['Color']
        else:
            # Any condition missing from the CSV forces manual tuning of all channels.
            use_setup = False

if not use_setup or not os.path.exists(setup_path):
    stain_complete_df=stain_initial_df.copy()
    # Interactive mode off before napari.run(); note it is not switched back on afterwards.
    settings.application.ipy_interactive = False
    viewer_1 = napari.Viewer()
    for c, idx in enumerate(stain_complete_df.index):
        im_channel = im_in[:,:,:,c]
        # NOTE(review): a constant channel divides by zero here (NaN before the uint8 cast).
        im_channel = ((im_channel - im_channel.min()) / (im_channel.max() - im_channel.min()) * 255).clip(0, 255).astype('uint8')
        viewer_1.add_image(im_channel, name=f"{idx[0]} ({idx[1]})", colormap=stain_initial_df.loc[idx]['Color'], blending='additive')
    napari.run()
    # Read back the contrast limits / gamma left on each layer, keyed by layer name.
    image_layers = [layer for layer in viewer_1.layers if isinstance(layer, napari.layers.Image)]
    contrast_limits = {layer.name: layer.contrast_limits for layer in image_layers}
    gamma_val = {layer.name: layer.gamma for layer in image_layers}
    stain_complete_df.sort_index(inplace=True)
    for c, idx in enumerate(stain_complete_df.index):
        # Must match the layer names used in add_image above.
        name = f"{idx[0]} ({idx[1]})"
        # Contrast limits are truncated to int before being stored.
        stain_complete_df.loc[idx, 'Cont_min'] = int(contrast_limits[name][0])
        stain_complete_df.loc[idx, 'Cont_max'] = int(contrast_limits[name][1])
        stain_complete_df.loc[idx, 'Gamma'] = gamma_val[name]
    # Merge into an existing CSV so rows for other conditions are kept.
    if os.path.exists(setup_path):
        stain_setup_df = pd.read_csv(setup_path)
        stain_setup_df.set_index(['Condition', 'Marker', 'Laser'], inplace=True)
        for idx in stain_complete_df.index:
            stain_setup_df.loc[idx] = stain_complete_df.loc[idx]
    else:
        stain_setup_df = stain_complete_df.copy()
    stain_csv_setup_df = stain_setup_df.reset_index().sort_values(by='Condition')
    stain_csv_setup_df = stain_csv_setup_df[['Condition', 'Marker', 'Laser', 'Cont_min', 'Cont_max', 'Gamma']]
    stain_csv_setup_df.to_csv(setup_path, index=False)

# Back to a Condition index, in stain_df's (channel) order, with a fixed column order.
stain_df = stain_df.set_index('Condition')
stain_complete_df = stain_complete_df.reset_index().set_index('Condition')
stain_complete_df = stain_complete_df.loc[stain_df.index]
stain_complete_df = stain_complete_df[['Marker', 'Laser', 'Color', 'Cont_min', 'Cont_max', 'Gamma']]
# Untouched copy of the final settings.
original_stain_complete_df=stain_complete_df.copy()

In [None]:
# Display the final per-condition stain settings (Marker, Laser, Color, Cont_min, Cont_max, Gamma)
stain_complete_df

## IMAGE PROCESSING

In [None]:
# Rescale every channel of the raw stack independently to the full 0-255 range
im_in = im_final_stack['Original image'].copy()
im_out = im_in.copy()
for c in range(im_in.shape[3]):
    im_ori = im_in[:, :, :, c]
    span = np.ptp(im_ori)
    # A constant channel has zero range: divide by 1 so it maps to 0 instead of NaN
    im_out[:, :, :, c] = ((im_ori - im_ori.min()) / (span if span else 1) * 255).clip(0, 255).astype('uint8')

im_final_stack['Normalized image'] = im_out.copy()

# Plot histogram for each channel
hist_plot(im_final_stack['Normalized image'], stain_complete_df)

In [None]:
# Resample every channel onto an isotropic grid.
# zoom_factors follows the array axis order (Z, Y, X): zoom_factors[0] scales axis 0 (Z),
# zoom_factors[2] scales axis 2 (X).
im_in = im_final_stack['Normalized image'].copy()

out_shape = tuple(round(size * factor) for size, factor in zip(im_in.shape[:3], zoom_factors)) + (im_in.shape[3],)
im_out = np.zeros(out_shape)

# Voxel size [um] along each axis after resampling (each axis divided by its own factor)
r_zX = meta.physical_pixel_sizes.X / zoom_factors[2]
r_zY = meta.physical_pixel_sizes.Y / zoom_factors[1]
r_zZ = meta.physical_pixel_sizes.Z / zoom_factors[0]

for c in range(im_in.shape[3]):
    resampled = zoom(im_in[:, :, :, c], zoom=zoom_factors, order=1)
    span = np.ptp(resampled)
    # Re-stretch to 0-255; a constant channel maps to 0 instead of NaN
    im_out[:, :, :, c] = ((resampled - resampled.min()) / (span if span else 1) * 255).clip(0, 255).astype('uint8')

im_final_stack['Zoomed image'] = im_out.copy()
hist_plot(im_final_stack['Zoomed image'], stain_complete_df)

In [None]:
# Noise removal: median filter applied to each channel independently
im_in = im_final_stack['Zoomed image'].copy()
# float64 buffer allocated here so the cell does not depend on an im_out left over from another cell
im_out = np.zeros(im_in.shape)
for c in range(im_in.shape[3]):
    im_out[:, :, :, c] = filters.median(im_in[:, :, :, c])
im_final_stack['Denoised image'] = im_out.copy()
hist_plot(im_final_stack['Denoised image'], stain_complete_df)

In [None]:
# Apply each channel's contrast limits and gamma from the stain settings table
im_in = im_final_stack['Denoised image'].copy()
# float64 buffer allocated here so the cell does not depend on an im_out left over from another cell
im_out = np.zeros(im_in.shape)
for c in range(im_in.shape[3]):
    stain_row = stain_complete_df.iloc[c]
    im_out[:, :, :, c] = napari_contrast_gamma_uint8(
        im_in[:, :, :, c],
        (stain_row['Cont_min'], stain_row['Cont_max']),
        stain_row['Gamma'],
    )

im_final_stack['Adjusted image'] = im_out.copy()
hist_plot(im_final_stack['Adjusted image'], stain_complete_df)

In [None]:
# Gaussian smoothing (sigma = 0.5 voxel) of each channel, keeping the 0-255 intensity range
im_in = im_final_stack['Adjusted image'].copy()
# float64 buffer allocated here so the cell does not depend on an im_out left over from another cell
im_out = np.zeros(im_in.shape)
for c in range(im_in.shape[3]):
    im_out[:, :, :, c] = filters.gaussian(im_in[:, :, :, c], 0.5, preserve_range=True)

# astype truncates the smoothed values to uint8
im_final_stack['Filtered image'] = im_out.astype('uint8')
hist_plot(im_final_stack['Filtered image'], stain_complete_df)

In [None]:
# Export the 256-bin intensity histogram of every channel of the contrast-adjusted stack
# to Excel, one worksheet per marker.
output_path = Path(input_file).stem + '_histograms.xlsx'
im_in = im_final_stack['Adjusted image'].copy()


def _safe_sheet_name(name, used):
    """Return an Excel-legal worksheet name derived from ``name`` that is unique in ``used``.

    Excel names may not contain any of []:*?/\\ , are limited to 31 characters and are
    compared case-insensitively; xlsxwriter raises on any violation. ``used`` holds the
    lower-cased names already taken and is updated in place.
    """
    base = ''.join('_' if ch in '[]:*?/\\' else ch for ch in str(name))[:31].strip("'") or 'Sheet'
    candidate, k = base, 1
    while candidate.lower() in used:
        suffix = f'_{k}'
        candidate = base[:31 - len(suffix)] + suffix
        k += 1
    used.add(candidate.lower())
    return candidate


used_sheet_names = set()
with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer:
    for c in range(im_in.shape[3]):
        # Count of voxels for each integer intensity 0..255
        values, counts = np.unique(im_in[:, :, :, c].astype('int'), return_counts=True)
        hist = np.zeros(256, dtype=int)
        hist[values] = counts

        # Totals, percentages and cumulative values
        total = hist.sum()
        percentage = (hist / total) * 100

        df = pd.DataFrame({
            "Pixel_Value": np.arange(256),
            "Count": hist,
            "Percentage": percentage,
            "Cumulative_Count": np.cumsum(hist),
            "Cumulative_Percentage": np.cumsum(percentage),
        })

        marker = stain_complete_df['Marker'].iloc[c]
        df.to_excel(writer, sheet_name=_safe_sheet_name(marker, used_sheet_names), index=False)

print(f"Saved to: {output_path}")

In [None]:
# Double-plateau histogram equalization of each channel, then rescaling to 0-255;
# the equalized stack is the input of the thresholding step.
im_in = im_final_stack['Filtered image'].copy()
# float64 buffer allocated here (not np.empty_like, since the input is uint8) so the
# cell does not depend on an im_out left over from another cell
im_out = np.zeros(im_in.shape)
for c in range(im_in.shape[3]):
    im_out[:, :, :, c] = double_plateau_hist_equalization_nd(im_in[:, :, :, c].astype('uint8'), num_plateaus=2, plateau_factor=0.7)
    im_ori = im_out[:, :, :, c].copy()
    span = np.ptp(im_ori)
    # A constant channel maps to 0 instead of NaN
    im_out[:, :, :, c] = ((im_ori - im_ori.min()) / (span if span else 1) * 255).clip(0, 255).astype('uint8')

im_final_stack['Equalized image'] = im_out.copy()
hist_plot(im_final_stack['Equalized image'], stain_complete_df)

In [None]:
# Thresholding: blend a local Sauvola map, a global statistical threshold and Otsu,
# then gate the result with an intensity-gain test against the background level.
im_in = im_final_stack["Equalized image"].copy()
im_out = im_in.copy()

# Characteristic object sizes in voxels of the resampled grid
nuclei_size = int(nuclei_diameter / (np.mean([r_zX, r_zY])))
cell_size = int(cell_diameter / (np.mean([r_zX, r_zY])))

# Gain over background required when the background mode is 0; it decreases
# linearly with the background mode (see gain_ass below)
gain_tot = 6.0

for c in range(im_in.shape[3]):
    is_nuclei = stain_complete_df.index[c] == "NUCLEI"
    img = sitk.GetImageFromArray(im_in[:, :, :, c])

    # Global Otsu threshold of the channel
    th_filter = sitk.OtsuThresholdImageFilter()
    _ = th_filter.Execute(img)
    otsu_value2 = th_filter.GetThreshold()

    # threshold_sauvola rejects even window sizes, so round up to the next odd value
    window_size = nuclei_size if is_nuclei else 4 * cell_size + 1
    if window_size % 2 == 0:
        window_size += 1

    arr = sitk.GetArrayFromImage(img)

    # Local Sauvola threshold map
    sauvola_value = threshold_sauvola(arr, window_size=int(window_size))

    # -------- Global statistical background, excluding zeros --------
    non_zero = arr[arr > 0]
    if non_zero.size > 0:
        hist, bins = np.histogram(non_zero, bins=256, range=(0, non_zero.max()))
        mode_bin = bins[np.argmax(hist)]  # left edge of the most populated bin
        bg_mask = (arr >= mode_bin - 5) & (arr <= mode_bin + 5) & (arr > 0)
        # Required gain shrinks as the background brightens (negative for mode_bin > 63.75)
        gain_ass = gain_tot * (255.0 - 4.0 * mode_bin) / 255.0
        bg_vals = arr[bg_mask]
        if bg_vals.size < 50:
            # Too few voxels around the mode: use the darkest 10 % of non-zero voxels
            p10 = np.percentile(non_zero, 10)
            bg_vals = non_zero[non_zero <= p10]
    else:
        # All-zero channel: gain_ass must still be defined for THIS channel (it would
        # otherwise be undefined or left over from the previous channel). With arr == 0
        # the gain below is 0, so no voxel passes either gate.
        gain_ass = gain_tot
        bg_vals = arr

    bg_mean = bg_vals.mean()

    # Global (zero-including) statistics and statistical threshold mean + z * std
    bg_mean_z = arr.mean()
    bg_std_z = arr.std() + 1e-6
    z = 3.0
    statistical_thr = bg_mean_z + z * bg_std_z

    # Soften Sauvola for large/bright cells: it may not exceed the global mean + 2 std
    max_sauvola = bg_mean_z + 2.0 * bg_std_z
    sauvola_clipped = np.minimum(sauvola_value, max_sauvola)

    # Final combined threshold map
    final_thr = (
        0.60 * sauvola_clipped +
        0.25 * statistical_thr +
        0.15 * otsu_value2
    )

    # Intensity gain relative to the non-zero background mean
    gain = arr / (bg_mean + 1e-6)
    mask_gain = gain > gain_ass

    primary = (arr > final_thr) & mask_gain
    # Rescue voxels with a clearly higher gain that are above the statistical threshold
    # (gain threshold above the primary one, to stay conservative)
    rescue = (gain > (gain_ass + 3.0)) & (arr > statistical_thr)
    arrayseg = primary | rescue

    # Minimum island size passed to remove_small_islands (smaller for the nuclear channel)
    area_factor = 0.4 if is_nuclei else 0.8
    min_size = np.ceil(area_factor * np.pi * ((nuclei_size / 2) ** 2))

    im_out[:, :, :, c] = remove_small_islands(arrayseg, min_size)

im_final_stack["Threshold image"] = im_out.copy()

In [None]:
# Segmentation of nuclei: StarDist or distance-transform watershed on the NUCLEI channel,
# or connected components of the union of all channel masks when there is no NUCLEI channel.
im_in = im_final_stack['Threshold image'].copy()

if 'NUCLEI' in stain_df.index:
    for c in range(im_in.shape[3]):
        if stain_complete_df.index[c] != 'NUCLEI':
            continue
        if trig_stardist:
            # StarDist runs on the filtered intensities rather than on the binary mask
            nuclei_channel = im_final_stack['Filtered image'][:, :, :, c].copy()
            transl = stardist3d_from_2d(img_3d=nuclei_channel, nucleus_radius=nuclei_diameter / 2.0, voxel_size=(r_zZ, r_zY, r_zX))
            im_mask = transl > 0
            # Erode the mask (2x2x2 footprint), presumably so touching nuclei separate
            # before the connected-component relabelling below
            im_mask = morphology.binary_erosion(im_mask, footprint=np.ones((2, 2, 2))).astype(im_mask.dtype)
            im_out, num = label((transl * im_mask) > 0)
        else:
            binary_mask = im_in[:, :, :, c].copy()
            _, true_markers = shrink_to_markers_robust(binary_mask)
            distance = ndi.distance_transform_edt(binary_mask, sampling=[r_zZ, r_zY, r_zX])
            im_out = watershed(-distance, true_markers, mask=binary_mask)
            _, im_out = remove_small_island_labels(im_out, connectivity=1, size_ratio_thresh=0.5)
            im_out, _, _ = relabel_sequential(im_out)
else:
    im_thresh = np.zeros_like(im_in[:, :, :, 0], dtype=np.int32)
    for c in range(im_in.shape[3]):
        im_thresh = im_thresh | (im_in[:, :, :, c] > 0)
    # Placeholder NUCLEI row with empty settings (written once, not once per channel)
    stain_complete_df.loc['NUCLEI'] = ['', '', '', '', '', '']
    im_out = skimage.measure.label(im_thresh)

im_segmentation_stack = {'Nuclei': im_out, 'Cytoplasm': np.zeros_like(im_out), 'PCM': np.zeros_like(im_out)}

# Random label colormap with black background. At least two colors: with 0 labels
# cm_rand[0] would not exist, and a one-color vispy Colormap cannot span [0, 1].
n_colors = max(int(np.max(im_segmentation_stack['Nuclei'])), 2)
cm_rand = np.random.rand(n_colors, 3)
cm_rand[0, :] = [0.0, 0.0, 0.0]
colormaps_rand = Colormap(cm_rand)

In [None]:
# Per channel, keep only the nucleus labels covered by that channel's threshold mask
# (voxels outside the mask become 0) and store them under the channel's condition name.
im_in = im_final_stack['Threshold image'].copy()
n_channels = im_in.shape[3]
for channel_idx in range(n_channels):
    channel_name = stain_df.index[channel_idx]
    channel_mask = im_in[:, :, :, channel_idx]
    im_segmentation_stack[channel_name] = channel_mask * im_segmentation_stack['Nuclei']

In [None]:
# Visualize every processing stage per channel, plus the nucleus label maps
viewer_0 = napari.Viewer()
scale_zoom = (r_zZ, r_zY, r_zX)
n_channels = im_final_stack['Threshold image'].shape[3]

# (stack key, layer-name prefix) for the stages shown on the resampled grid
zoomed_stages = [
    ('Zoomed image', 'ZOOMED'),
    ('Denoised image', 'DENOISED'),
    ('Adjusted image', 'CORRECTED'),
    ('Filtered image', 'FILTERED'),
    ('Equalized image', 'EQ'),
]

for c in range(n_channels):
    idx = stain_complete_df.index[c]
    marker = stain_complete_df.loc[idx, 'Marker']
    color = stain_complete_df['Color'].iloc[c]
    # The original stack is shown with its native voxel size
    viewer_0.add_image(im_final_stack['Original image'][:, :, :, c], name=f'ORIGINAL {idx} ({marker})', colormap=color, blending='additive', scale=[r_Z, r_Y, r_X])
    for stack_key, prefix in zoomed_stages:
        viewer_0.add_image(im_final_stack[stack_key][:, :, :, c], name=f'{prefix} {idx} ({marker})', colormap=color, blending='additive', scale=scale_zoom)
    viewer_0.add_image(im_final_stack['Threshold image'][:, :, :, c].astype('uint8'), name=f'THRESHOLD {idx} ({marker})', contrast_limits=[0, 1], colormap=color, blending='additive', scale=scale_zoom)

# Label maps keep their full integer range: casting to uint8 would wrap label ids
# above 255 and show unrelated nuclei as the same label.
viewer_0.add_labels(im_segmentation_stack['Nuclei'].astype(np.int32), name='NUCLEI', blending='additive', scale=scale_zoom)
viewer_0.scale_bar.visible = True
viewer_0.scale_bar.unit = 'um'

if 'NUCLEI' in stain_complete_df.index or 'CYTOPLASM' in stain_complete_df.index:
    viewer_1 = napari.Viewer()
    for c in range(len(stain_df.index)):
        idx = stain_complete_df.index[c]
        marker = stain_complete_df.loc[idx, 'Marker']
        viewer_1.add_labels(im_segmentation_stack[stain_df.index[c]].astype(np.int32), name=f'{idx} ({marker})', blending='additive', scale=scale_zoom)
    viewer_1.scale_bar.visible = True
    viewer_1.scale_bar.unit = 'um'

## QUANTIFICATION

# 6. Quantification and Analysis

This section measures the volume and centroid of every segmented nucleus, counts the nuclei positive for each marker, and plots their spatial and size distributions. The results are exported to Excel for further analysis.

In [None]:
# Per-nucleus volume [um^3] and centroid (x, y, z) [um], indexed by label id, then,
# for every marker channel, the nuclei that are positive for that marker.
labels_dict = {}
nuclei_labels = im_segmentation_stack['Nuclei']
n_labels = int(np.max(nuclei_labels))
nuc_sizes = np.zeros(n_labels + 1)
nuc_position = [(0.0, 0.0, 0.0) for _ in range(n_labels + 1)]

r_xyz = (r_zX, r_zY, r_zZ)
zooms = zoom_factors

# Label ids actually present (0 is background, filtered explicitly rather than assumed first)
present = np.unique(nuclei_labels)
present = present[present != 0].astype(int)
if present.size > 0:
    # One pass over the volume instead of one full-volume comparison per label
    voxel_counts = np.bincount(nuclei_labels.ravel().astype(np.int64), minlength=n_labels + 1)
    nuc_sizes[present] = voxel_counts[present] * r_zX * r_zY * r_zZ
    # Centroids in array order (z, y, x), converted to micrometres in (x, y, z) order
    centroids = ndi.center_of_mass(nuclei_labels > 0, nuclei_labels, present)
    for n, (zc, yc, xc) in zip(present, centroids):
        nuc_position[n] = (xc * r_zX, yc * r_zY, zc * r_zZ)

for c in range(len(stain_df.index)):
    condition = stain_complete_df.index[c]
    if condition in ['NUCLEI', 'CYTOPLASM', 'PCM']:
        continue
    marker = condition

    # Nucleus ids overlapping this marker's threshold mask (np.unique output is sorted)
    marker_ids = np.unique(im_segmentation_stack[marker]).astype('int')
    shared_labels = [int(i) for i in marker_ids if i != 0]

    labels_dict[marker] = [
        condition,
        stain_complete_df['Laser'].iloc[c],
        stain_complete_df['Color'].iloc[c],
        len(shared_labels),
        tuple(shared_labels),
        tuple(nuc_position[i] for i in shared_labels),
        tuple(nuc_sizes[i] for i in shared_labels),
    ]

In [None]:
# Assemble the per-marker quantification dict into a table, one row per marker
# combination (the index is named 'Combination').
labels_df = (
    pd.DataFrame.from_dict(
        labels_dict,
        orient='index',
        columns=['Condition', 'Laser', 'Color', 'Number', 'Shared labels', 'Mean nuclei positions [um]', 'Nuclei size [um3]'],
    )
    .rename_axis('Combination')
)

In [None]:
# Display the per-marker quantification table (counts, label ids, nucleus positions and sizes)
labels_df

## Evaluate the spatial distribution of cells

In [None]:
# Plot the spatial distribution (X, Y and Z histograms) of the nuclei positive for each marker
im_in = im_final_stack['Filtered image']

# Physical extent [um] along X, Y, Z; image axes are ordered (Z, Y, X)
axis_extents = (
    im_in.shape[2] * r_X / zoom_factors[2],
    im_in.shape[1] * r_Y / zoom_factors[1],
    im_in.shape[0] * r_Z / zoom_factors[0],
)

fig, axs = plt.subplots(3, 1, figsize=(15, 15))
for i, marker in enumerate(labels_df.index):
    # Positional access via .iloc (Series[int] on a label index is deprecated in pandas)
    condition = labels_df['Condition'].iloc[i]
    positions = labels_df['Mean nuclei positions [um]'].iloc[i]

    # 30-bin histogram of each coordinate; positions are (x, y, z) tuples
    curves = []
    for axis, extent in enumerate(axis_extents):
        counts, edges = np.histogram([p[axis] for p in positions], range=(0, extent), bins=30)
        curves.append(((edges[:-1] + edges[1:]) / 2, counts))

    if np.size(marker) == 1:
        color = stain_complete_df.loc[str(condition)]['Color']
        if color == '':
            color = 'BLUE'
        if condition == 'NUCLEI':
            continue
        line_kwargs = {'color': color}
    else:
        # Marker combination: additive RGB mix of the member colors, drawn dashed
        rgb_list = []
        for k in range(np.size(marker)):
            if stain_df.loc[condition[k]]['Color'] != 'WHITE':
                rgb_list.append(stain_complete_df.loc[condition[k]]['Color'])
            else:
                rgb_list.append('GRAY')
        final_rgb = tuple(min(sum(channel), 1.0) for channel in zip(*(to_rgb(name) for name in rgb_list)))
        line_kwargs = {'linestyle': (0, (2, np.size(marker) - 1)), 'color': final_rgb}

    for ax, (centers, counts) in zip(axs, curves):
        ax.plot(centers, counts, label=str(condition), **line_kwargs)

for ax, axis_name in zip(axs, 'XYZ'):
    ax.set_title(f'NUCLEI {axis_name} DISTRIBUTION')
    ax.set_xlabel('[μm]')
    ax.legend(loc='upper right')
    ax.set_facecolor('black')

## Evaluate cell size distribution

In [None]:
# Plot the nucleus-volume distribution of the nuclei positive for each marker
all_nuclei_sizes = [size for sizes in labels_df['Nuclei size [um3]'] for size in sizes]
# default=0.0 keeps the cell from raising when no marker-positive nucleus was found
nuclei_max_size = max(all_nuclei_sizes, default=0.0)

fig, ax = plt.subplots()
for i, marker in enumerate(labels_df.index):
    # Positional access via .iloc (Series[int] on a label index is deprecated in pandas)
    condition = labels_df['Condition'].iloc[i]
    nuclei_sizes = list(labels_df['Nuclei size [um3]'].iloc[i])
    if np.size(marker) == 1:
        stain_color = stain_complete_df.loc[condition]['Color']
        if stain_color == '':
            color = 'BLUE'
        elif stain_color != 'WHITE':
            color = stain_complete_df.loc[str(condition)]['Color']
        else:
            color = 'GRAY'
    else:
        # Marker combination: additive RGB mix of the member colors
        rgb_list = []
        for k in range(np.size(marker)):
            if stain_df.loc[condition[k]]['Color'] != 'WHITE':
                rgb_list.append(stain_complete_df.loc[condition[k]]['Color'])
            else:
                rgb_list.append('GRAY')
        color = tuple(min(sum(channel), 1.0) for channel in zip(*(to_rgb(name) for name in rgb_list)))

    ax.hist(nuclei_sizes, range=(0, nuclei_max_size), bins=30, label=str(condition), alpha=1 / len(labels_df), color=color)

ax.set_title('NUCLEI SIZE DISTRIBUTION')
ax.set_xlabel('[μm3]')
ax.legend(loc='upper right');

### Create a complete XLSX report

In [None]:
# Export quantification results to Excel: the stain settings, one sheet per marker with
# the per-nucleus measurements, and a RECAP summary sheet.
#
# Denominator for the '%' column: the total number of segmented nuclei. labels_df never
# contains NUCLEI/CYTOPLASM/PCM rows (they are skipped when it is built), so its first row
# is a marker and cannot serve as the nucleus count.
nuclei_ids = np.unique(im_segmentation_stack['Nuclei'])
n_nuclei_total = int(np.count_nonzero(nuclei_ids))

columns = ['ID', 'X position [um]', 'Y position [um]', 'Z position [um]', 'Nuclei size [um3]']
with pd.ExcelWriter(Path(input_file).stem + '_segmentation.xlsx', engine='xlsxwriter') as writer:
    original_stain_complete_df.to_excel(writer, sheet_name='Staining', index=True)

    for marker in labels_df.index:
        records = [
            (label_id, position[0], position[1], position[2], size)
            for label_id, position, size in zip(
                labels_df.at[marker, 'Shared labels'],
                labels_df.at[marker, 'Mean nuclei positions [um]'],
                labels_df.at[marker, 'Nuclei size [um3]'],
            )
        ]
        # Excel sheet names: at most 31 chars and none of []:*?/\ (xlsxwriter raises otherwise)
        sheet_name = ''.join('_' if ch in '[]:*?/\\' else ch for ch in str(marker))[:31]
        pd.DataFrame(records, columns=columns).to_excel(writer, sheet_name=sheet_name, index=False)

    resume_df = labels_df.copy()
    # Laser and Color are only meaningful for single-marker rows
    is_single = [np.size(cond) == 1 for cond in labels_df['Condition']]
    resume_df['Laser'] = [laser if single else '' for laser, single in zip(labels_df['Laser'], is_single)]
    resume_df['Color'] = [col if single else '' for col, single in zip(labels_df['Color'], is_single)]
    resume_df['%'] = [
        100.0 * number / n_nuclei_total if (cond != 'NUCLEI' and n_nuclei_total > 0) else ''
        for cond, number in zip(labels_df['Condition'], labels_df['Number'])
    ]
    # NaN (without a RuntimeWarning) for markers with no positive nuclei
    resume_df['Mean nuclei size [um3]'] = [
        np.mean(sizes) if len(sizes) > 0 else np.nan for sizes in labels_df['Nuclei size [um3]']
    ]
    resume_df.to_excel(writer, sheet_name='RECAP', index=True)

# CREATE AN ABAQUS .inp MESH FOR FINITE ELEMENT ANALYSIS

In [None]:
# Surface-mesh the nucleus label volume and fill it with tetrahedra for FE analysis.
# Labels are cast to float32; assuming gridToMesh extracts the iso-surface at the given
# value (0.5), the surface separates background (0) from any label >= 1. Two touching
# labels are both > 0.5, so no internal surface is created between them: the result is a
# single volume mesh, and per-nucleus element sets are assigned in the next cell.
simpleVolume = mrn.simpleVolumeFrom3Darray(np.float32(im_segmentation_stack['Nuclei']))
floatGrid = mr.simpleVolumeToDenseGrid(simpleVolume)
# Voxel size 1.0 on every axis: vertex coordinates are in voxel-index units of the
# resampled grid, not micrometres. The next cell uses node coordinates directly as indices
# into im_segmentation_stack['Nuclei'].
# NOTE(review): assumes mesh x/y/z follow numpy axes 0/1/2 of the label array — confirm in the meshlib docs.
mesh_stl = mr.gridToMesh(floatGrid , mr.Vector3f(1.0,1.0,1.0), 0.5)

# Surface vertex coordinates (presumably an N x 3 array) and triangle vertex indices
outVerts = mrn.getNumpyVerts(mesh_stl)

outFaces = mrn.getNumpyFaces(mesh_stl.topology)

# Linear tetrahedra (order=1) with tetgen quality limits: minimum dihedral angle 20 degrees
# and maximum radius-edge ratio 1.5. nodes/elems are used by the next cell.
tet = tetgen.TetGen(outVerts,outFaces)
nodes,elems=tet.tetrahedralize(order=1, mindihedral=20, minratio=1.5)

# ASCII VTK copy of the tetrahedral grid; the next cell converts it to .inp with meshio
tet.write('FE_segmentation_full.vtk', binary=False)

In [None]:
# Convert the tetrahedral mesh to .inp and append one element set per nucleus label.
meshel = meshio.read('FE_segmentation_full.vtk')
meshel.write('FE_segmentation.inp')

nuclei_labels = im_segmentation_stack['Nuclei']
vol_shape = np.array(nuclei_labels.shape)
n_cells = int(np.max(nuclei_labels))
# 1-based element ids belonging to each nucleus label
cell_elements = {c: [] for c in range(1, n_cells + 1)}

for ce, elem_nodes in enumerate(elems):
    # Element centroid in voxel-index units (the mesh was built with unit voxel size),
    # clamped inside the volume
    centroid = np.clip(np.round(np.mean(nodes[elem_nodes], axis=0)).astype(int), 0, vol_shape - 1)
    step = 0
    while True:
        step += 1
        # Grow a cubic window around the centroid, clipped to the volume so slice starts
        # never go negative (a negative start would wrap around and yield an empty window)
        lo = np.maximum(centroid - step, 0)
        hi = np.minimum(centroid + step + 1, vol_shape)
        window = nuclei_labels[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]].ravel()
        found = window[window != 0]
        if found.size > 0:
            # The most frequent label in the window owns the element
            cell_elements[int(st.mode(found))].append(ce + 1)
            break
        if np.all(lo == 0) and np.all(hi == vol_shape):
            break  # whole volume searched without finding a label: leave unassigned

# Also expose each list as the global '<label>cell_el' for any later cell reading those names
for c, members in cell_elements.items():
    globals()[str(c) + 'cell_el'] = members

with open("FE_segmentation.inp", "a") as f:
    for c in range(1, n_cells + 1):
        f.write("*Elset, elset=cell" + str(c) + "\n")
        members = cell_elements[c]
        # Every element is written (none skipped), at most 16 ids per data line
        for start in range(0, len(members), 16):
            f.write(", ".join(str(e) for e in members[start:start + 16]) + "\n")

In [None]:
# Wrap the mesh in an Abaqus part: insert a *PART header before the node block and
# close it with *END PART at the end of the file.
with open("FE_segmentation.inp", "r") as f:
    lines = f.readlines()

part_inserted = False
with open(Path(input_file).stem + "_FEA.inp", "w") as f:
    for line in lines:
        # Abaqus keywords are case-insensitive and may carry options ('*Node, nset=...');
        # only one *PART header is inserted.
        keyword = line.strip().split(',')[0].strip().upper()
        if not part_inserted and keyword == "*NODE":
            f.write("*PART, name=Part-1\n")
            part_inserted = True
        f.write(line)
    f.write("*END PART\n")