# Imports

In [1]:
import tifffile as tiff
import numpy as np
from scipy import ndimage
import tempfile
from tqdm import tqdm
import os
import plotly.graph_objs as go
import napari
from joblib import Parallel, delayed
from skimage import morphology, exposure, transform, feature
from multiprocessing import Pool, cpu_count
from scipy.spatial.distance import cdist
from scipy.ndimage import distance_transform_edt, binary_erosion
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
from scipy import ndimage as ndi
from magicgui import magicgui
from skimage.measure import label
from tkinter import filedialog, Tk
from itertools import product

from scipy import optimize
from skimage import exposure
from scipy.ndimage import gaussian_filter1d

#Import time
import time

import yaml
import shutil
from typing import Dict, Any
from functools import partial


# Helper functions

In [2]:
def downsample_stack(image_stack, factor=2):
    """Downsample a 3D image stack by a given factor."""
    return image_stack[:, ::factor, ::factor]

def upsample_stack(labeled_stack, original_shape):
    """
    Upsamples the labeled stack to the original shape.
    """
    if labeled_stack.ndim == 3:
        return np.array([transform.resize(labeled_stack[z], original_shape[1:], 
                                      order=0, preserve_range=True, anti_aliasing=False) 
                     for z in range(labeled_stack.shape[0])], dtype=np.int32)
    elif labeled_stack.ndim == 2:
        return transform.resize(labeled_stack, original_shape, order=0, preserve_range=True, anti_aliasing=False)
    else:
        raise ValueError("Input stack must be either 2D or 3D.")


def interpolate_settings(z_position, z_size, slice_settings):
    """
    Interpolate settings for a given z position based on slice settings.
    """
    if len(slice_settings) >= 2:
        z_indices = np.array(sorted(slice_settings.keys()))
        # Scale z_indices to match the full volume size
        z_indices = (z_indices * z_size / max(z_indices)).astype(int)
        
        # Get settings for interpolation
        intensity_thresholds = np.array([
            slice_settings[z].get("intensity_threshold", 0.5) 
            for z in sorted(slice_settings.keys())
        ])
        min_volumes = np.array([
            slice_settings[z].get("min_volume", 500) 
            for z in sorted(slice_settings.keys())
        ])
        
        # Interpolate values for the specific z position
        intensity_threshold = np.interp(z_position, z_indices, intensity_thresholds)
        min_volume = np.interp(z_position, z_indices, min_volumes)
    else:
        # Default values if insufficient slice settings
        intensity_threshold = 0.5
        min_volume = 500
        
    return {
        "intensity_threshold": intensity_threshold,
        "min_volume": min_volume
    }

def get_volume_slices(shape, block_size, overlap):
    """
    Generate overlapping volume coordinates for parallel processing.
    Returns list of tuples (z_start, z_end, y_start, y_end, x_start, x_end)
    """
    print("Generating volume slices...")
    slices = []
    for dim_size, block_dim, overlap_dim in zip(shape, block_size, overlap):
        starts = list(range(0, dim_size - overlap_dim, block_dim - overlap_dim))
        if starts[-1] + block_dim < dim_size:
            starts.append(dim_size - block_dim)
        slices.append([(start, start + block_dim) for start in starts])
    
    coords = list(product(slices[0], slices[1], slices[2]))
    print(f"Created {len(coords)} subvolumes for processing")
    return coords

def process_subvolume(args):
    """
    Process a single subvolume using 3D watershed segmentation with depth-adaptive parameters.
    """
    volume_data, coords = args
    z_start, z_end = coords[0]
    y_start, y_end = coords[1]
    x_start, x_end = coords[2]
    
    # Get subvolume
    subvolume = volume_data['volume'][z_start:z_end, y_start:y_end, x_start:x_end]
    
    # Get depth-adaptive settings for the middle of this subvolume
    z_mid = (z_start + z_end) // 2
    settings = interpolate_settings(
        z_mid, 
        volume_data['volume'].shape[0],
        volume_data['slice_settings']
    )
    
    # Find local maxima with adaptive threshold
    local_max_coords = feature.peak_local_max(
        subvolume,
        min_distance=int(settings['min_volume']/3),
        threshold_rel=settings['intensity_threshold'],
        exclude_border=False
    )
    
    local_max = np.zeros_like(subvolume, dtype=bool)
    if len(local_max_coords) > 0:
        local_max[tuple(local_max_coords.T)] = True
    
    struct = ndimage.generate_binary_structure(3, 2)
    mask = subvolume > (settings['intensity_threshold'] * np.max(subvolume))
    distance = ndimage.distance_transform_edt(mask)
    markers = label(local_max)
    labels = watershed(-distance, markers, mask=mask)
    filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
    
    return coords, filtered_labels

def make_labels_unique(results, start_label=1):
    """
    Ensure all labels across all blocks are unique by adding offsets.
    Returns modified results and the next available label.
    """
    print("Making labels unique across blocks...")
    new_results = []
    current_label = start_label
    
    for coords, labels in tqdm(results, desc="Relabeling blocks"):
        max_label = np.max(labels) if labels.any() else 0
        if max_label > 0:
            # Create mapping for this block's labels
            unique_labels = np.unique(labels)
            unique_labels = unique_labels[unique_labels > 0]
            label_map = {old: new for old, new in zip(unique_labels, 
                                                     range(current_label, current_label + len(unique_labels)))}
            
            # Apply mapping
            new_labels = np.zeros_like(labels)
            for old, new in label_map.items():
                new_labels[labels == old] = new
            
            new_results.append((coords, new_labels))
            current_label += len(unique_labels)
        else:
            new_results.append((coords, labels))
    
    return new_results, current_label

def find_overlapping_labels(label_array1, label_array2):
    """
    Find corresponding labels in overlapping regions between two arrays.
    Returns a dictionary mapping labels from array2 to array1.
    """
    mapping = {}
    
    # Find overlapping non-zero labels
    mask1 = label_array1 > 0
    mask2 = label_array2 > 0
    overlap_mask = mask1 & mask2
    
    if overlap_mask.any():
        labels1 = label_array1[overlap_mask]
        labels2 = label_array2[overlap_mask]
        
        # Count occurrences of each label pair
        unique_pairs, counts = np.unique(np.vstack((labels1, labels2)), axis=1, return_counts=True)
        
        # Sort by count to handle multiple overlaps
        sort_idx = np.argsort(-counts)
        unique_pairs = unique_pairs[:, sort_idx]
        
        # Create mapping from label2 to most overlapping label1
        for label1, label2 in unique_pairs.T:
            if label2 not in mapping:
                mapping[label2] = label1
    
    return mapping

def stitch_volumes(results, full_shape, overlap):
    """
    Stitch overlapping subvolumes ensuring consistent labeling across blocks.
    """
    print("Starting volume stitching...")
    
    # First make all labels unique across blocks
    results, _ = make_labels_unique(results)
    
    # Initialize final volume
    final_labels = np.zeros(full_shape, dtype=np.int32)
    label_mappings = {}
    
    # First pass: Place blocks and find overlapping regions
    print("First pass: Finding overlapping regions...")
    for coords, labels in tqdm(results, desc="Processing blocks"):
        z_start, z_end = coords[0]
        y_start, y_end = coords[1]
        x_start, x_end = coords[2]
        
        # Check overlaps with existing labels
        if z_start > 0:  # Check overlap with previous z block
            overlap_slice = slice(z_start, z_start + overlap[0])
            existing = final_labels[overlap_slice, y_start:y_end, x_start:x_end]
            new = labels[:overlap[0], :, :]
            mapping = find_overlapping_labels(existing, new)
            label_mappings.update(mapping)
        
        if y_start > 0:  # Check overlap with previous y block
            overlap_slice = slice(y_start, y_start + overlap[1])
            existing = final_labels[z_start:z_end, overlap_slice, x_start:x_end]
            new = labels[:, :overlap[1], :]
            mapping = find_overlapping_labels(existing, new)
            label_mappings.update(mapping)
            
        if x_start > 0:  # Check overlap with previous x block
            overlap_slice = slice(x_start, x_start + overlap[2])
            existing = final_labels[z_start:z_end, y_start:y_end, overlap_slice]
            new = labels[:, :, :overlap[2]]
            mapping = find_overlapping_labels(existing, new)
            label_mappings.update(mapping)
        
        # Place the block in the final volume
        final_labels[z_start:z_end, y_start:y_end, x_start:x_end] = labels
    
    # Optional: Relabel to ensure consecutive labels
    print("Relabeling to ensure consecutive labels...")
    unified_labels = label(final_labels > 0)
    
    return unified_labels


def intensity_based_segmentation(image_stack, slice_settings=None, temp_dir=None, downsample_factor=2):
    """
    3D cell segmentation using parallel processing of overlapping subvolumes with depth-adaptive parameters.
    """
    print("Initializing 3D segmentation...")
    if slice_settings is None:
        slice_settings = {}
    os.makedirs(temp_dir, exist_ok=True)

    # Default block processing settings
    default_block_settings = {
        "block_size": (64, 64, 64),
        "overlap": (16, 16, 16)
    }
    
    print("Preprocessing image stack...")
    print("Enhancing contrast...")
    enhanced_stack = adaptive_contrast_enhancement(image_stack)
    print("Downsampling...")
    downsampled_stack = downsample_stack(enhanced_stack, factor=downsample_factor)
    
    
    volume_coords = get_volume_slices(
        downsampled_stack.shape,
        default_block_settings["block_size"],
        default_block_settings["overlap"]
    )
    
    # Package data for parallel processing
    volume_data = {
        'volume': downsampled_stack,
        'slice_settings': slice_settings  # Pass full slice settings for interpolation
    }
    
    num_workers = min(cpu_count(), len(volume_coords))
    print(f"Starting parallel processing with {num_workers} workers...")
    
    with Pool(num_workers) as pool:
        results = list(tqdm(
            pool.imap(process_subvolume, [(volume_data, coord) for coord in volume_coords]),
            total=len(volume_coords),
            desc="Processing subvolumes"
        ))
    
    print("Stitching subvolumes...")
    labeled_volume = stitch_volumes(results, downsampled_stack.shape, default_block_settings["overlap"])
    
    if downsample_factor > 1:
        print("Upsampling results...")
        labeled_volume = upsample_stack(labeled_volume, image_stack.shape)
    
    print("Saving results...")
    save_path = os.path.join(temp_dir, "segmented_cells.npy")
    np.save(save_path, labeled_volume)
    print(f"Results saved to: {save_path}")
    
    return labeled_volume


def get_depth_weight(z, z_depth, edge_protection=0.2):
    """
    Calculate weight for enhancement strength based on depth position.
    Reduces enhancement strength at the edges of the stack.
    
    Parameters:
    -----------
    z : int
        Current z position
    z_depth : int
        Total depth of stack
    edge_protection : float
        Controls how quickly enhancement strength drops at edges (0-1)
        Higher values mean more conservative enhancement at edges
    """
    # Convert to 0-1 range
    rel_pos = z / (z_depth - 1)
    
    # Create a curve that drops off at the edges
    # Using smooth cosine falloff
    edge_weight = np.cos((rel_pos - 0.5) * np.pi) * 0.5 + 0.5
    
    # Apply edge protection factor
    edge_weight = edge_weight ** edge_protection
    
    return edge_weight

def fit_percentile_curve(image_stack, percentile, window_size=5, smoothing_sigma=1.0):
    """
    Fits a smooth curve to intensity percentiles across z-depth.
    """
    z_depth = image_stack.shape[0]
    percentile_values = np.zeros(z_depth)
    
    # Calculate percentiles with moving window
    pad_size = window_size // 2
    padded_stack = np.pad(image_stack, ((pad_size, pad_size), (0, 0), (0, 0)), mode='reflect')
    
    for z in range(z_depth):
        window = padded_stack[z:z + window_size]
        percentile_values[z] = np.percentile(window, percentile)
    
    # Smooth the percentile curve
    smoothed_values = gaussian_filter1d(percentile_values, smoothing_sigma)
    
    return smoothed_values

def fit_depth_decay(z_values, intensity_values):
    """
    Fits an exponential decay curve to intensity values across depth.
    """
    def exp_decay(z, amplitude, decay_rate, offset):
        return amplitude * np.exp(-decay_rate * z) + offset
    
    # Initial parameter guesses
    p0 = [
        np.max(intensity_values) - np.min(intensity_values),
        1/len(z_values),
        np.min(intensity_values)
    ]
    
    try:
        params, _ = optimize.curve_fit(exp_decay, z_values, intensity_values, p0=p0)
        return params
    except RuntimeError:
        return None

def adaptive_contrast_enhancement(image_stack, low_percentile=25, high_percentile=99.9, 
                                   window_size=5, smoothing_sigma=1.0, edge_protection=0.4):
    """
    Enhances contrast for each slice using depth-aware percentile curves with edge protection.
    
    Parameters:
    -----------
    image_stack : ndarray
        3D image stack
    low_percentile : float
        Lower percentile for contrast stretching
    high_percentile : float
        Upper percentile for contrast stretching
    window_size : int
        Size of moving window for percentile calculation
    smoothing_sigma : float
        Sigma for Gaussian smoothing of percentile values
    edge_protection : float
        Strength of edge protection (0-1). Higher values mean more conservative 
        enhancement at stack edges.
    """
    z_depth = image_stack.shape[0]
    z_values = np.arange(z_depth)
    
    # Get smoothed percentile curves
    low_curve = fit_percentile_curve(image_stack, low_percentile, window_size, smoothing_sigma)
    high_curve = fit_percentile_curve(image_stack, high_percentile, window_size, smoothing_sigma)
    
    # Fit decay curves
    low_params = fit_depth_decay(z_values, low_curve)
    high_params = fit_depth_decay(z_values, high_curve)
    
    if low_params is not None and high_params is not None:
        # Use fitted curves
        def exp_decay(z, amplitude, decay_rate, offset):
            return amplitude * np.exp(-decay_rate * z) + offset
        
        low_fitted = exp_decay(z_values, *low_params)
        high_fitted = exp_decay(z_values, *high_params)
    else:
        # Fallback to smoothed curves if fitting fails
        low_fitted = low_curve
        high_fitted = high_curve
    
    # Apply enhancement using fitted curves with edge protection
    enhanced_stack = np.zeros_like(image_stack, dtype=np.float32)
    global_min = np.min(image_stack)
    global_max = np.max(image_stack)
    
    for z in range(z_depth):
        # Get depth-dependent weight
        weight = get_depth_weight(z, z_depth, edge_protection)
        
        # Calculate enhanced and conservative intensity ranges
        p_low_enhanced = low_fitted[z]
        p_high_enhanced = high_fitted[z]
        
        # Blend between enhanced and conservative ranges based on weight
        p_low = global_min * (1 - weight) + p_low_enhanced * weight
        p_high = global_max * (1 - weight) + p_high_enhanced * weight
        
        # Apply enhancement
        enhanced_stack[z] = exposure.rescale_intensity(
            image_stack[z],
            in_range=(p_low, p_high),
            out_range=(0, 1)
        )
    
    return enhanced_stack

def validate_segment(segment, min_volume=100, min_solidity=0.5):
    """
    Validate a segment based on size and shape criteria.
    
    Parameters:
    -----------
    segment : np.ndarray
        Binary mask of the segment
    min_volume : int
        Minimum volume in voxels
    min_solidity : float
        Minimum solidity (ratio of volume to convex hull volume)
    
    Returns:
    --------
    bool
        True if segment passes validation
    """
    # Check volume
    volume = np.sum(segment)
    if volume < min_volume:
        return False
        
    # Check solidity
    # Get convex hull volume using bounding box fill ratio as approximation
    coords = np.where(segment)
    bbox_volume = (coords[0].max() - coords[0].min() + 1) * \
                 (coords[1].max() - coords[1].min() + 1) * \
                 (coords[2].max() - coords[2].min() + 1)
    solidity = volume / bbox_volume
    
    return solidity >= min_solidity

def process_single_mask(args):
    """
    Process a single mask for parallel execution with improved validation.
    """
    label, coords, raw_chunk, min_distance, min_intensity_ratio, min_volume, min_solidity = args
    
    # Extract mask bounds
    z_slice = slice(coords['z_min'], coords['z_max'])
    y_slice = slice(coords['y_min'], coords['y_max'])
    x_slice = slice(coords['x_min'], coords['x_max'])
    
    # Create mask in minimal volume
    mask = np.zeros_like(raw_chunk, dtype=bool)
    mask[coords['mask_indices']] = True
    
    # Skip small objects
    if not validate_segment(mask, min_volume, min_solidity):
        return label, mask, None
    
    # Get intensity image within current mask
    masked_intensity = np.where(mask, raw_chunk, 0)
    
    # Find local maxima
    max_intensity = np.max(masked_intensity)
    peaks = peak_local_max(
        masked_intensity,
        min_distance=min_distance,
        threshold_abs=max_intensity * min_intensity_ratio,
        exclude_border=False,
        labels=mask
    )
    
    # If only one peak, return original mask
    if len(peaks) <= 1:
        return label, mask, None
        
    # Prepare markers for watershed
    markers = np.zeros_like(mask, dtype=int)
    for i, peak in enumerate(peaks):
        markers[peak[0], peak[1], peak[2]] = i + 1
        
    # Apply watershed
    distance = ndimage.distance_transform_edt(mask)
    watershed_labels = watershed(-distance, markers, mask=mask)
    
    # Validate each watershed segment
    valid_labels = set()
    for i in range(1, watershed_labels.max() + 1):
        segment = watershed_labels == i
        if validate_segment(segment, min_volume, min_solidity):
            valid_labels.add(i)
    
    # If no valid segments after splitting, return original mask
    if len(valid_labels) <= 1:
        return label, mask, None
    
    # Create new watershed_labels with only valid segments
    new_watershed_labels = np.zeros_like(watershed_labels)
    new_label = 1
    for old_label in valid_labels:
        new_watershed_labels[watershed_labels == old_label] = new_label
        new_label += 1
    
    return label, mask, new_watershed_labels

def split_merged_masks(labeled_cells, raw_intensity, min_distance=50, 
                              min_intensity_ratio=0.3, min_volume=100, 
                              min_solidity=0.1, n_processes=None):
    """
    Parallel implementation of merged cell mask splitting with improved validation.
    
    Parameters:
    -----------
    labeled_cells : np.ndarray
        3D array where each unique value represents a cell mask
    raw_intensity : np.ndarray
        Original 3D intensity image
    min_distance : int
        Minimum distance between intensity peaks (in pixels)
    min_intensity_ratio : float
        Minimum ratio of peak intensity to max intensity
    min_volume : int
        Minimum volume for valid segments
    min_solidity : float
        Minimum solidity for valid segments
    n_processes : int, optional
        Number of processes to use
    """
    if n_processes is None:
        n_processes = max(1, cpu_count() - 1)
    
    print(f"Starting parallel processing with {n_processes} processes")
    print(f"Minimum volume threshold: {min_volume} voxels")
    print(f"Minimum solidity threshold: {min_solidity}")
    start_time = time.time()
    
    # Initialize output array
    output = np.zeros_like(labeled_cells)
    next_label = 1
    
    # Get all unique labels (excluding background)
    unique_labels = np.unique(labeled_cells)
    unique_labels = unique_labels[unique_labels > 0]
    
    print(f"Found {len(unique_labels)} cells to process")
    
    # Prepare arguments for parallel processing
    parallel_args = []
    for label in unique_labels:
        mask = labeled_cells == label
        z, y, x = np.where(mask)
        
        padding = max(5, min_distance)
        z_min, z_max = max(0, z.min() - padding), min(labeled_cells.shape[0], z.max() + padding + 1)
        y_min, y_max = max(0, y.min() - padding), min(labeled_cells.shape[1], y.max() + padding + 1)
        x_min, x_max = max(0, x.min() - padding), min(labeled_cells.shape[2], x.max() + padding + 1)
        
        local_z = z - z_min
        local_y = y - y_min
        local_x = x - x_min
        
        raw_chunk = raw_intensity[z_min:z_max, y_min:y_max, x_min:x_max]
        
        coords = {
            'z_min': z_min, 'z_max': z_max,
            'y_min': y_min, 'y_max': y_max,
            'x_min': x_min, 'x_max': x_max,
            'mask_indices': (local_z, local_y, local_x)
        }
        
        parallel_args.append((label, coords, raw_chunk, min_distance, 
                            min_intensity_ratio, min_volume, min_solidity))
    
    # Process in parallel with progress bar
    with Pool(n_processes) as pool:
        results = list(tqdm(
            pool.imap(process_single_mask, parallel_args),
            total=len(parallel_args),
            desc="Processing cells"
        ))
    
    # Combine results
    print("Combining results...")
    invalid_segments = 0
    for label, mask, watershed_result in results:
        if watershed_result is None:
            coords = parallel_args[label-1][1]
            z_slice = slice(coords['z_min'], coords['z_max'])
            y_slice = slice(coords['y_min'], coords['y_max'])
            x_slice = slice(coords['x_min'], coords['x_max'])
            
            output_mask = np.zeros_like(output[z_slice, y_slice, x_slice])
            output_mask[mask] = next_label
            output[z_slice, y_slice, x_slice] = np.where(
                output[z_slice, y_slice, x_slice] == 0,
                output_mask,
                output[z_slice, y_slice, x_slice]
            )
            next_label += 1
        else:
            coords = parallel_args[label-1][1]
            z_slice = slice(coords['z_min'], coords['z_max'])
            y_slice = slice(coords['y_min'], coords['y_max'])
            x_slice = slice(coords['x_min'], coords['x_max'])
            
            for i in range(1, watershed_result.max() + 1):
                output_mask = np.zeros_like(output[z_slice, y_slice, x_slice])
                output_mask[watershed_result == i] = next_label
                output[z_slice, y_slice, x_slice] = np.where(
                    output[z_slice, y_slice, x_slice] == 0,
                    output_mask,
                    output[z_slice, y_slice, x_slice]
                )
                next_label += 1
    
    end_time = time.time()
    print(f"Processing completed in {end_time - start_time:.2f} seconds")
    if invalid_segments > 0:
        print(f"Filtered out {invalid_segments} invalid segments")
    
    return output

def save_surface_points_to_memmap(roi_surface, z_spacing, xy_spacing, temp_dir, label_value):
    """
    Save surface points to a memory-mapped file, scaled by the given spacing.
    """
    surface_points = np.argwhere(roi_surface).astype(float)
    surface_points[:, 0] *= z_spacing
    surface_points[:, 1] *= xy_spacing
    surface_points[:, 2] *= xy_spacing

    # Create a memmap file to store the surface points for this ROI
    memmap_path = os.path.join(temp_dir, f"roi_{label_value}_surface.memmap")
    surface_memmap = np.memmap(memmap_path, dtype='float32', mode='w+', shape=surface_points.shape)
    surface_memmap[:] = surface_points[:]
    del surface_memmap  # Ensure it’s written to disk

    return memmap_path, surface_points.shape


def calculate_min_distance_memmap(roi_memmap_path, roi_shape, other_memmap_path, other_shape, chunk_size=500):
    """
    Calculate minimum surface-to-surface distance between two ROIs stored in memory-mapped files.
    Returns the minimum distance and coordinates of the closest points on each surface.
    """
    min_distance = np.inf
    min_coords = (None, None)  # Initialize to store coordinates of closest points
    roi_surface_memmap = np.memmap(roi_memmap_path, dtype='float32', mode='r', shape=roi_shape)
    other_surface_memmap = np.memmap(other_memmap_path, dtype='float32', mode='r', shape=other_shape)

    # Calculate distance in chunks
    for i in range(0, other_shape[0], chunk_size):
        other_chunk = other_surface_memmap[i:i+chunk_size]
        distances = cdist(roi_surface_memmap, other_chunk)
        min_dist_in_chunk = distances.min()
        
        # If this chunk contains a closer pair, update min_distance and coordinates
        if min_dist_in_chunk < min_distance:
            min_distance = min_dist_in_chunk
            # Find the indices of the closest pair of points
            roi_idx, other_idx = np.unravel_index(distances.argmin(), distances.shape)
            min_coords = (roi_surface_memmap[roi_idx], other_chunk[other_idx])
        
        if min_distance == 0:
            break  # Stop if distance is zero (surfaces touch)

    return min_distance, min_coords  # Return both distance and coordinates


def compute_surface_min_distances_memmap(labeled_stack, z_spacing, xy_spacing, temp_dir=None, select_ids=None, chunk_size=500, n_jobs=-1):
    """
    Calculate minimum surface-to-surface distances between ROIs using memory-mapped arrays for memory efficiency.
    
    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs
    - z_spacing: Spacing in microns between z-planes
    - xy_spacing: Spacing in microns for x and y dimensions
    - temp_dir: Path to an external directory for temporary files (default is a new temp directory)
    - chunk_size: Number of points to load per chunk for distance calculation
    - n_jobs: Number of parallel jobs (default is -1, using all processors)
    
    Returns:
    - distances: Array of minimum distances between ROIs
    - coordinates: List of coordinate pairs for each minimum distance
    """
    # Create or use the specified temporary directory
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp()
    else:
        os.makedirs(temp_dir, exist_ok=True)

    # Get unique ROI labels
    labels = np.unique(labeled_stack)
    labels = labels[labels != 0]

    # Save each ROI's surface points to a memmap file
    roi_memmaps = {}
    for label_value in labels:
        roi_mask = (labeled_stack == label_value)
        eroded_roi = binary_erosion(roi_mask)
        roi_surface = roi_mask & ~eroded_roi
        memmap_path, shape = save_surface_points_to_memmap(roi_surface, z_spacing, xy_spacing, temp_dir, label_value)
        roi_memmaps[label_value] = (memmap_path, shape)

    # Compute minimum distances in parallel
    def calculate_min_distance_for_label(label_value):
        min_distance = np.inf
        closest_points = (None, None)
        closest_roi_id = None
        roi_memmap_path, roi_shape = roi_memmaps[label_value]

        for other_label in labels:
            if other_label == label_value:
                continue

            other_memmap_path, other_shape = roi_memmaps[other_label]
            label_min_distance, label_coords = calculate_min_distance_memmap(
                roi_memmap_path, roi_shape, other_memmap_path, other_shape, chunk_size
            )
            if label_min_distance < min_distance:
                min_distance = label_min_distance
                closest_points = label_coords
                closest_roi_id = other_label
            
            if min_distance == 0:
                break  # Stop if distance is zero (surfaces touch)

        return set(label_value, closest_roi_id), min_distance, closest_points  # Return both distance and coordinates

    # Run the parallel computation
    if select_ids is not None:
        select_labels = select_ids
    else:
        select_labels = labels
    results = Parallel(n_jobs=n_jobs)(
        delayed(calculate_min_distance_for_label)(label_value) for label_value in select_labels
    )

    # Split results into distances and coordinates
    roi_id_set, min_surface_distances, closest_coords = zip(*results)

    # Remove all the files inside the temporary directory
    for file in os.listdir(temp_dir):
        os.remove(os.path.join(temp_dir, file))

    #Save min_surface_distances and closest_coords to files
    output_path0 = os.path.join(temp_dir, "closest_roi_ids.npy")
    output_path1 = os.path.join(temp_dir, "min_surface_distances.npy")
    output_path2 = os.path.join(temp_dir, "closest_coords.npy")
    np.save(output_path0, np.array(roi_id_set))
    np.save(output_path1, np.array(min_surface_distances))
    np.save(output_path2, np.array(closest_coords))


    return np.array(min_surface_distances), closest_coords  # Return both distances and coordinates


def compute_roi_volumes(labeled_stack, z_spacing, xy_spacing, labels=None):
    """
    Compute the volume of each ROI in the labeled stack, given the true voxel dimensions.
    
    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs
    - z_spacing: Spacing in microns between z-planes
    - xy_spacing: Spacing in microns for x and y dimensions
    
    Returns:
    - roi_volumes: Dictionary with ROI labels as keys and volumes in cubic microns as values
    """
    print("Computing ROI volumes...")
    # Calculate the volume of a single voxel in cubic microns
    voxel_volume = z_spacing * xy_spacing * xy_spacing
    
    # Get unique ROI labels (excluding 0 as it represents the background)
    if labels is None:
        labels = np.unique(labeled_stack)
        labels = labels[labels != 0]
    
    # Calculate volume for each ROI
    roi_volumes = {}
    for label_value in labels:
        # Count the number of voxels for this ROI
        voxel_count = np.sum(labeled_stack == label_value)
        
        # Compute volume in cubic microns
        roi_volume = voxel_count * voxel_volume
        roi_volumes[label_value] = roi_volume

    return roi_volumes


def visualize_with_napari_overlay(original_stack, merged_roi_array, closest_coords=None, z_scale_factor=2, z_spacing=1, xy_spacing=1):
    """
    Visualize the original 3D image stack overlaid with the ROI segmentation in Napari,
    with lines connecting the closest points between ROIs.

    Parameters:
    - original_stack: 3D numpy array of the original microscopy image data.
    - merged_roi_array: 3D numpy array with labeled regions, where each cell has a unique ROI number.
    - closest_coords: List of tuples, where each tuple contains two coordinate arrays (point1, point2) representing 
                      the closest points between ROIs.
    - z_scale_factor: Scaling factor to adjust z-dimension for proper aspect ratio.
    - z_spacing: Spacing in microns between z-planes (used to convert line coordinates to voxel units).
    - xy_spacing: Spacing in microns for x and y dimensions (used to convert line coordinates to voxel units).
    """
    print("Launching Napari visualization with overlay...")

    # Initialize the Napari viewer
    viewer = napari.Viewer()

    # Add the original image stack as the first layer
    viewer.add_image(
        original_stack,
        name="Original Image",
        scale=(z_scale_factor, 1, 1),  # Apply only z scaling factor
        colormap='gray',
        opacity=0.8
    )

    # Add the segmented ROIs as a labels layer
    viewer.add_labels(
        merged_roi_array,
        name="ROIs",
        scale=(z_scale_factor, 1, 1),  # Apply only z scaling factor
        opacity=0.6  # Adjust opacity for overlay effect
    )

    if closest_coords is not None:
        # Prepare lines for closest points, converting each coordinate to voxel units
        lines = []
        for point1, point2 in closest_coords:
            # Convert micron coordinates to voxel coordinates
            start_point = (point1[0] * z_scale_factor, point1[1] / xy_spacing, point1[2] / xy_spacing)
            end_point = (point2[0] * z_scale_factor, point2[1] / xy_spacing, point2[2] / xy_spacing)
            lines.append([start_point, end_point])

        # Add the lines connecting closest points as shapes in Napari
        viewer.add_shapes(
            lines,
            shape_type='line',
            edge_color='red',
            edge_width=1.5,
            name="Closest Points Connections"
        )

    # Switch to 3D mode
    viewer.dims.ndisplay = 3

    # Start the Napari viewer event loop
    napari.run()


def create_parameter_widget(param_name: str, param_config: Dict[str, Any], callback):
    """Create a widget for a single parameter"""
    
    # Define the function with the appropriate type annotation
    if param_config["type"] == "float":
        def parameter_widget(value: float = param_config["value"]):
            callback(value)  # Modified to only pass the value
            return value
    elif param_config["type"] == "int":
        def parameter_widget(value: int = param_config["value"]):
            callback(value)  # Modified to only pass the value
            return value
    else:
        def parameter_widget(value: float = param_config["value"]):
            callback(value)  # Modified to only pass the value
            return value
    
    # Create the widget with magicgui
    widget = magicgui(
        parameter_widget,
        auto_call=True,
        value={
            "label": param_config["label"],
            "min": param_config["min"],
            "max": param_config["max"],
            "step": param_config["step"]
        }
    )
    
    # Store the original parameter name as an attribute
    widget.param_name = param_name
    
    return widget

class DynamicGUIManager:
    def __init__(self, viewer, config, image_stack, file_loc):
        self.viewer = viewer
        self.config = config
        self.image_stack = image_stack
        self.file_loc = file_loc
        self.current_widgets = {}
        self.current_step = {"value": 0}
        self.processing_steps = ["initial_segmentation", "merge_rois", "split_rois"]
        self.parameter_values = {}
        self.active_dock_widgets = set()
        
        # Set up processing directory
        self.inputdir = os.path.dirname(self.file_loc)
        self.basename = os.path.basename(self.file_loc).split('.')[0]
        self.processed_dir = os.path.join(self.inputdir, f"{self.basename}_processed")
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)
        
        # Initialize processing state
        self.initial_segmentation = np.zeros_like(self.image_stack, dtype=np.int32)
        self.slice_settings = {}
        
        # Set up image enhancement
        self.enhanced_stack = adaptive_contrast_enhancement(self.image_stack)
        
        # Initialize viewer layers
        self._initialize_layers()


    def cleanup_step(self, step_number):
        """Clean up the results and layers from a specific step"""
        if step_number == 1:
            if "Intermediate segmentation 1" in self.viewer.layers:
                self.viewer.layers.remove("Intermediate segmentation 1")
            segmented_cells_path = os.path.join(self.processed_dir, "segmented_cells.npy")
            if os.path.exists(segmented_cells_path):
                os.remove(segmented_cells_path)

        elif step_number == 2:
            if "Intermediate segmentation 2" in self.viewer.layers:
                self.viewer.layers.remove("Intermediate segmentation 2")
            merged_roi_array_loc = os.path.join(self.processed_dir, "merged_roi_array_optimized.dat")
            if os.path.exists(merged_roi_array_loc):
                os.remove(merged_roi_array_loc)

        elif step_number == 3:
            if "Segmentation without large volumes" in self.viewer.layers:
                self.viewer.layers.remove("Segmentation without large volumes")
            updated_stack_loc = os.path.join(self.processed_dir, "updated_stack.npy")
            if os.path.exists(updated_stack_loc):
                os.remove(updated_stack_loc)


    def execute_processing_step(self):
        """Execute the next step in the processing pipeline"""
        try:
            if self.current_step["value"] == 0:
                # Step 1: Initial cell segmentation
                print("Running initial cell segmentation...")
                segmented_cells_path = os.path.join(self.processed_dir, "segmented_cells.npy")
                
                # Remove existing results if present
                self.cleanup_step(1)
                
                # Get current parameter values
                current_values = self.get_current_values()
                
                labeled_cells = intensity_based_segmentation(
                    self.image_stack,
                    slice_settings=self.slice_settings,
                    temp_dir=self.processed_dir,
                    downsample_factor=current_values.get("downsample_factor", 2)
                )
                
                np.save(segmented_cells_path, labeled_cells)
                self.viewer.add_labels(
                    labeled_cells,
                    name="Intermediate segmentation 1",
                    scale=(self.z_scale_factor, 1, 1)
                )
                
                self.current_step["value"] += 1
                self.create_step_widgets("merge_rois")

            elif self.current_step["value"] == 1:
                # Step 2: Merge close ROIs
                print("Merging close ROIs...")
                segmented_cells_path = os.path.join(self.processed_dir, "segmented_cells.npy")
                merged_roi_array_loc = os.path.join(self.processed_dir, "merged_roi_array_optimized.dat")
                
                # Remove existing results if present
                self.cleanup_step(2)
                
                # Get current parameter values
                current_values = self.get_current_values()
                
                labeled_cells = np.load(segmented_cells_path)
                merged_roi_array = split_merged_masks(
                    labeled_cells,
                    self.enhanced_stack,
                    min_distance=current_values.get("min_distance", 10),
                    min_intensity_ratio=current_values.get("min_intensity_ratio", 0.3)
                )
                
                # Save the merged array
                merged_roi_array.tofile(merged_roi_array_loc)
                
                self.viewer.add_labels(
                    merged_roi_array,
                    name="Intermediate segmentation 2",
                    scale=(self.z_scale_factor, 1, 1)
                )
                
                self.current_step["value"] += 1
                self.create_step_widgets("split_rois")

            elif self.current_step["value"] == 2:
                # Step 3: Split large ROIs
                print("Splitting large ROIs...")
                merged_roi_array_loc = os.path.join(self.processed_dir, "merged_roi_array_optimized.dat")
                updated_stack_loc = os.path.join(self.processed_dir, "updated_stack.npy")
                
                # Remove existing results if present
                self.cleanup_step(3)
                
                # Get current parameter values
                current_values = self.get_current_values()
                
                merged_roi_array = np.memmap(
                    merged_roi_array_loc,
                    dtype=np.int32,
                    mode='r',
                    shape=self.image_stack.shape
                )
                
                updated_stack = split_large_rois_with_intensity(
                    merged_roi_array,
                    self.config['voxel_dimensions']['z'],
                    self.x_spacing,
                    mean_guess=current_values.get("mean_guess", 5000),
                    std_guess=current_values.get("std_guess", 2500),
                    tempdir=self.processed_dir,
                    max_iters=current_values.get("max_iters", 10)
                )
                
                np.save(updated_stack_loc, updated_stack)
                self.viewer.add_labels(
                    updated_stack,
                    name="Segmentation without large volumes",
                    scale=(self.z_scale_factor, 1, 1)
                )
                
                self.current_step["value"] += 1
                print("Processing complete!")
                
                # Save final configuration
                self.save_updated_config()
            
            else:
                print("All processing steps completed.")
                
        except Exception as e:
            print(f"Error during processing step {self.current_step['value']}: {str(e)}")
            raise

    def save_updated_config(self):
        """Save the current configuration to a YAML file"""
        config_save_path = os.path.join(self.processed_dir, "processing_config.yaml")
        with open(config_save_path, 'w') as file:
            yaml.dump(self.config, file, default_flow_style=False)

    def clear_current_widgets(self):
        """Remove all current widgets"""
        # Get list of dock widgets from the viewer
        dock_widgets = list(self.viewer.window._dock_widgets.values())
        
        # Remove each widget
        for dock_widget in dock_widgets:
            if dock_widget in self.current_widgets:
                try:
                    self.viewer.window.remove_dock_widget(dock_widget)
                except Exception as e:
                    print(f"Warning: Failed to remove dock widget: {str(e)}")
                
        # Clear the tracking dictionary
        self.current_widgets.clear()

    def create_step_widgets(self, step_name: str):
        """Create all widgets for a processing step"""
        try:
            # Remove existing widgets
            self.clear_current_widgets()
            
            # Create new widgets for each parameter in the step
            if step_name not in self.config:
                print(f"Warning: {step_name} not found in config")
                return
                
            step_config = self.config[step_name]
            if "parameters" not in step_config:
                print(f"Warning: no parameters found for {step_name}")
                return
            
            # Reset parameter values for this step
            self.parameter_values = {}
            
            # Create parameter widgets
            for param_name, param_config in step_config["parameters"].items():
                try:
                    # Create callback for this specific parameter
                    callback = lambda value, pn=param_name: self.parameter_changed(step_name, pn, value)
                    
                    # Create widget
                    widget = create_parameter_widget(param_name, param_config, callback)
                    dock_widget = self.viewer.window.add_dock_widget(widget, area="right")
                    self.current_widgets[dock_widget] = widget
                    
                    # Store initial value
                    self.parameter_values[param_name] = param_config["value"]
                except Exception as e:
                    print(f"Error creating widget for {param_name}: {str(e)}")
            
            # Add Update Mask button for initial segmentation
            if step_name == "initial_segmentation":
                try:
                    @magicgui(call_button="Update Mask")
                    def update_mask():
                        slice_idx = self.viewer.dims.current_step[0]
                        params = self.get_current_values()
                        self.apply_mask(
                            slice_idx,
                            params["intensity_threshold"],
                            params["min_volume"],
                            params["downsample_factor"]
                        )
                    
                    dock_widget = self.viewer.window.add_dock_widget(update_mask, area="right")
                    self.current_widgets[dock_widget] = update_mask
                except Exception as e:
                    print(f"Error creating update mask widget: {str(e)}")
                    
        except Exception as e:
            print(f"Error in create_step_widgets: {str(e)}")

    def remove_widget(self, dock_widget):
        """Safely remove a single widget"""
        try:
            if dock_widget in self.current_widgets:
                self.viewer.window.remove_dock_widget(dock_widget)
                del self.current_widgets[dock_widget]
        except Exception as e:
            print(f"Warning: Failed to remove dock widget: {str(e)}")
        
    def _initialize_layers(self):
        """Initialize the basic layers in the viewer"""
        # Get voxel dimensions from config
        voxel_x = self.config.get('voxel_dimensions', {}).get('x', 1)
        voxel_y = self.config.get('voxel_dimensions', {}).get('y', 1)
        voxel_z = self.config.get('voxel_dimensions', {}).get('z', 1)
        self.x_spacing = voxel_x / self.image_stack.shape[1]
        self.z_scale_factor = voxel_z/self.x_spacing
        
        # Add layers
        self.viewer.add_image(
            self.image_stack, 
            name="Original stack", 
            scale=(self.z_scale_factor, 1, 1)
        )
        self.viewer.add_image(
            self.enhanced_stack, 
            name="Enhanced stack", 
            scale=(self.z_scale_factor, 1, 1)
        )
        self.viewer.add_labels(
            self.initial_segmentation, 
            name="Initial segmentation", 
            scale=(self.z_scale_factor, 1, 1)
        )
        
    def parameter_changed(self, step_name: str, param_name: str, value: Any):
        """Callback for when a parameter value changes"""
        if step_name in self.config and "parameters" in self.config[step_name]:
            self.config[step_name]["parameters"][param_name]["value"] = value
            self.parameter_values[param_name] = value
        
    def apply_mask(self, slice_index, intensity_threshold, min_volume, downsample_factor):
        """Generate and apply the segmentation mask for a given slice."""
        try:
            image_slice = self.enhanced_stack[slice_index]
            downsampled = downsample_stack(image_slice[None, ...], factor=downsample_factor)[0]
            enhanced = downsampled
            seed_threshold = intensity_threshold * np.max(enhanced)
            seeds = np.argwhere(enhanced >= seed_threshold)

            mask = np.zeros_like(enhanced, dtype=np.int32)
            for seed in seeds:
                if enhanced[tuple(seed)] >= seed_threshold:
                    mask[tuple(seed)] = 1
            mask = morphology.binary_dilation(mask)
            labeled_mask = label(mask)
            filtered_mask = morphology.remove_small_objects(labeled_mask, min_size=min_volume)

            if downsample_factor > 1:
                filtered_mask = upsample_stack(filtered_mask, image_slice.shape)

            self.initial_segmentation[slice_index] = filtered_mask
            self.viewer.layers["Initial segmentation"].data = self.initial_segmentation
            self.slice_settings[slice_index] = {
                "intensity_threshold": intensity_threshold,
                "min_volume": min_volume,
                "downsample_factor": downsample_factor
            }
            
        except Exception as e:
            print(f"Error in apply_mask: {str(e)}")


    def get_current_values(self) -> Dict[str, Any]:
        """Get current values for all parameters in the current step"""
        return self.parameter_values.copy()
        
    def apply_initial_segmentation(self, values):
        """Apply initial segmentation with current parameters"""
        slice_idx = self.viewer.dims.current_step[0]
        self.apply_mask(
            slice_idx, 
            values["intensity_threshold"],
            values["min_volume"],
            values["downsample_factor"]
        )

    def apply_merge_rois(self, values):
        """Apply ROI merging with current parameters from the GUI"""
        if "Intermediate segmentation 1" not in self.viewer.layers:
            print("Error: Previous segmentation layer not found")
            return
            
        labeled_cells = self.viewer.layers["Intermediate segmentation 1"].data
        
        merged_roi_array = split_merged_masks(
            labeled_cells,
            self.enhanced_stack,
            min_distance=values["min_distance"],
            min_intensity_ratio=values["min_intensity_ratio"]
        )
        
        self.viewer.add_labels(
            merged_roi_array, 
            name="Intermediate segmentation 2", 
            scale=(self.z_scale_factor, 1, 1)
        )

    def apply_split_rois(self, values):
        """Apply ROI splitting with current parameters from the GUI"""
        if "Intermediate segmentation 2" not in self.viewer.layers:
            print("Error: Previous segmentation layer not found")
            return
            
        merged_roi_array = self.viewer.layers["Intermediate segmentation 2"].data
        
        updated_stack = split_large_rois_with_intensity(
            merged_roi_array,
            self.config['voxel_dimensions']['z'],
            self.x_spacing,
            mean_guess=values["mean_guess"],
            std_guess=values["std_guess"],
            tempdir=self.processed_dir,
            max_iters=values["max_iters"]
        )
        
        self.viewer.add_labels(
            updated_stack, 
            name="Segmentation without large volumes", 
            scale=(self.z_scale_factor, 1, 1)
        )

def interactive_segmentation_with_config():
    """
    Launch interactive segmentation with dynamic GUI based on YAML configuration
    """
    # First prompt for config file
    Tk().withdraw()
    config_path = filedialog.askopenfilename(
        title="Select config YAML file (optional)", 
        filetypes=[("YAML files", "*.yaml *.yml"), ("All files", "*.*")]
    )
    
    # Load or create config
    if config_path:
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
    else:
        print("No config file selected. Please select a config file.")
        return
    
    # Prompt for input file
    file_loc = filedialog.askopenfilename(
        title="Select a .tif file", 
        filetypes=[("TIFF files", "*.tif")]
    )
    if not file_loc:
        print("No file selected. Exiting.")
        return

    # Create processed directory
    inputdir = os.path.dirname(file_loc)
    basename = os.path.basename(file_loc).split('.')[0]
    processed_dir = os.path.join(inputdir, f"{basename}_processed")
    if not os.path.exists(processed_dir):
        os.makedirs(processed_dir)

    # Load the .tif file
    image_stack = tiff.imread(file_loc)
    print(f"Loaded stack with shape {image_stack.shape}")

    # Initialize viewer and GUI manager
    viewer = napari.Viewer()
    gui_manager = DynamicGUIManager(viewer, config, image_stack, file_loc)
    
    @magicgui(call_button="Continue Processing")
    def continue_processing():
        """Execute the next step in the processing pipeline"""
        gui_manager.execute_processing_step()
        update_navigation_buttons()

    @magicgui(call_button="Previous Step")
    def go_to_previous_step():
        """Go back one step in the processing pipeline"""
        if gui_manager.current_step["value"] > 0:
            gui_manager.current_step["value"] -= 1
            step_name = gui_manager.processing_steps[gui_manager.current_step["value"]]
            gui_manager.create_step_widgets(step_name)
            gui_manager.cleanup_step(gui_manager.current_step["value"] + 1)
            update_navigation_buttons()

    def update_navigation_buttons():
        """Update the state of navigation buttons"""
        previous_step_button.enabled = gui_manager.current_step["value"] > 0
        continue_processing_button.enabled = gui_manager.current_step["value"] < len(gui_manager.processing_steps)

    # Add navigation buttons
    continue_processing_button = continue_processing
    previous_step_button = go_to_previous_step
    viewer.window.add_dock_widget(continue_processing_button, area="right")
    viewer.window.add_dock_widget(previous_step_button, area="right")
    update_navigation_buttons()

    # Create initial GUI widgets
    gui_manager.create_step_widgets("initial_segmentation")

    napari.run()

# Load and process data

In [None]:
interactive_segmentation_with_config()

Loaded stack with shape (90, 1536, 1536)




Running initial cell segmentation...
Initializing 3D segmentation...
Preprocessing image stack...
Enhancing contrast...
Downsampling...
Generating volume slices...
Created 512 subvolumes for processing
Starting parallel processing with 24 workers...


  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
  filtered_labels = morphology.remove_small_objects(labels, min_size=settings['min_volume'])
Processing subvolumes: 100%|██████████| 512/512 [03:49<00:00,  2.23it/s]


Stitching subvolumes...
Starting volume stitching...
Making labels unique across blocks...


Relabeling blocks: 100%|██████████| 512/512 [00:00<00:00, 2461.95it/s]


First pass: Finding overlapping regions...


Processing blocks: 100%|██████████| 512/512 [00:00<00:00, 2122.78it/s]


Relabeling to ensure consecutive labels...
Upsampling results...
Saving results...
Results saved to: /home/kirill/Desktop/A_1_processed/segmented_cells.npy


# Debugging

In [8]:
file_loc = '/home/kirill/Desktop/A_1.czi.tif'
default_settings = {
        "intensity_threshold": 0.9,
        "min_volume": 500,
    }
slice_settings = {44 : {"intensity_threshold" : 0.2, "min_volume" : 50}, 
                  64 : {"intensity_threshold" : 0.3, "min_volume" : 50}, 
                  77 : {"intensity_threshold" : 0.8, "min_volume" : 50}}

inputdir = os.path.dirname(file_loc)
#Get the filename basename
basename = os.path.basename(file_loc).split('.')[0]
#check if basename_processed directory exists
processed_dir = os.path.join(inputdir, f"{basename}_processed")
image_stack = tiff.imread(file_loc)

if not os.path.exists(processed_dir):
    os.makedirs(processed_dir)
merged_roi_array =  intensity_based_segmentation(image_stack, slice_settings=slice_settings, temp_dir=processed_dir, downsample_factor=2)
merged_roi_array_loc = os.path.join(processed_dir, "merged_roi_array_optimized.dat")
if os.path.exists(merged_roi_array_loc):
    merged_roi_array = np.memmap(merged_roi_array_loc, dtype=np.int32, mode='r+', shape=image_stack.shape)
xy_dim = 471.4 
xy_spacing = xy_dim / merged_roi_array.shape[1]

#Load enhanced stack
enhanced_stack_path = os.path.join(processed_dir, "enhanced_stack.tif")
enhanced_stack = tiff.imread(enhanced_stack_path)

visualize_with_napari_overlay(enhanced_stack, merged_roi_array, closest_coords=None,  z_scale_factor=1/xy_spacing, z_spacing=1, xy_spacing=xy_spacing)

Fitting depth-adaptive parameters based on user-defined settings...
Performing intensity-based region growing in parallel...
Interp_intensity_threshold: [0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.2        0.2        0.2
 0.2        0.2        0.2        0.205      0.21       0.215
 0.22       0.225      0.23       0.235      0.24       0.245
 0.25       0.255      0.26       0.265      0.27       0.275
 0.28       0.285      0.29       0.295      0.3        0.33846154
 0.37692308 0.41538462 0.45384615 0.49230769 0.53076923 0.56923077
 0.60769231 0.64615385 0.68461538 0.72307692 0.76153846 0.8
 0.8        0.8        0.8        0.8        0.

Processing slices: 100%|██████████| 45/45 [03:48<00:00,  5.07s/it]  


Labeling connected components in 3D...


  tiff.imsave(enhanced_stack_path, enhanced_stack)


Segmentation complete. Results saved at /home/kirill/Desktop/A_1_processed/segmented_cells.npy
Launching Napari visualization with overlay...


In [10]:
xy_dim = 471.4 
xy_spacing = xy_dim / merged_roi_array.shape[1]
min_surface_distances, closest_points_dict = compute_surface_min_distances_memmap(
    merged_roi_array, z_spacing=1, xy_spacing=xy_spacing, temp_dir='/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/tmp', chunk_size=500, n_jobs=44
)


In [23]:
#Using true dimensions in microns, calculate scaling factor for visualisation
img = tiff.imread('/home/kirill/Desktop/A_1.czi.tif')
# merged_roi_array = np.memmap("data/merged_roi_array_optimized.dat", dtype=np.int32, mode='r+', shape=img.shape)
# closest_points_dict = np.load("/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/tmp/closest_coords.npy", allow_pickle=True)
labeled_cells = np.load("/home/kirill/Desktop/A_1_processed/segmented_cells.npy")

# xy_dim = 471.4 
# z_spacing = 1
# xy_spacing = xy_dim / merged_roi_array.shape[1]
# enhanced_stack = adaptive_contrast_enhancement(img)
# visualize_with_napari_overlay(enhanced_stack, merged_roi_array, closest_points_dict,  z_scale_factor=1/xy_spacing, z_spacing=1, xy_spacing=xy_spacing)
# visualize_with_napari_overlay(img, merged_roi_array, closest_points_dict, z_scale_factor=1/xy_spacing)

split_merged_masks(labeled_cells, img, min_distance=10,
                              min_intensity_ratio=0.3, n_processes=None)


Starting parallel processing with 23 processes
Found 451 cells to process


Processing cells: 100%|██████████| 451/451 [00:01<00:00, 391.98it/s]


Combining results...
Processing completed in 172.21 seconds


array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 

In [8]:
#Plot a histogram of the volumes
image_stack = tiff.imread('data/Added later-D31_no cyto_1.tif')
processed_dir = "data/Added later-D31_no cyto_1_processed"
merged_roi_array_loc = os.path.join(processed_dir, "merged_roi_array_optimized.dat")
merged_roi_array = np.memmap(merged_roi_array_loc, dtype=np.int32, mode='r+', shape=image_stack.shape)
xy_spacing = 471.4 / merged_roi_array.shape[1]
volumes = compute_roi_volumes(merged_roi_array, z_spacing=1, xy_spacing=xy_spacing)


In [12]:
volumes_list = list(volumes.values())
fig = go.Figure(data=[go.Histogram(x=volumes_list)])
fig.update_layout(title_text='Histogram of ROI Volumes')
fig.show()

In [None]:
import numpy as np
from scipy.ndimage import distance_transform_edt, label
from skimage.measure import regionprops

def refine_rois(labeled_stack, z_spacing, xy_spacing, temp_dir, min_volume=4000, max_volume=10000, distance_cutoff=30):
    """
    Refines ROIs in a 3D labeled stack:
    - Removes small ROIs that are far from larger ROIs.
    - Merges small ROIs with the closest larger ROI if within the distance cutoff.

    Parameters:
        labeled_stack (numpy.ndarray): 3D labeled image stack.
        min_volume (int): Minimum volume (in voxels) for a "normal" ROI.
        max_volume (int): Maximum volume (in voxels) to exclude exceptionally large ROIs.
        distance_cutoff (int): Maximum distance (in pixels) for merging small ROIs with normal ROIs.

    Returns:
        numpy.ndarray: Refined 3D labeled stack.
    """
    print("Refining ROIs...")

    #Create a memmap to hold the final labeled stack
    final_labels = np.memmap(os.path.join(temp_dir, "final_labels.dat"), dtype=np.int32, mode='w+', shape=labeled_stack.shape)

    #Copy the labeled stack to the final labels
    final_labels[:] = labeled_stack[:]

    # Step 1: Calculate all the volumes
    all_volumes = compute_roi_volumes(final_labels, z_spacing=z_spacing, xy_spacing=xy_spacing)

    # Separate ROIs into categories based on volume
    small_rois = []
    normal_rois = []
    for roi_label, roi_volume in all_volumes.items():
        if roi_volume < min_volume:
            small_rois.append(roi_label)
        else:
            normal_rois.append(roi_label)
    
    to_retest = [k for k in all_volumes.keys()]
    to_remove = []

    while True:
        print(f"Identified {len(small_rois)} small ROIs and {len(normal_rois)} normal ROIs.")

        closest_roi_ids, min_surface_distances, closest_points_dict = compute_surface_min_distances_memmap(final_labels, 
                                                                                                       z_spacing, 
                                                                                                       xy_spacing, 
                                                                                                       temp_dir=temp_dir, 
                                                                                                       select_ids=normal_rois,
                                                                                                       chunk_size=500, 
                                                                                                       n_jobs=-1)
        removed_small_rois = []
        #Sort each three lists by the min_surface_distances (from smallest to largest)
        closest_roi_ids = [x for _, x in sorted(zip(min_surface_distances, closest_roi_ids))]
        min_surface_distances = sorted(min_surface_distances)
        closest_points_dict = [x for _, x in sorted(zip(min_surface_distances, closest_points_dict))]

        # Step 2: Identify normal ROIs whose shortest distance is to a small ROI
        for distances_pos, roi_pair in tqdm(enumerate(closest_roi_ids)):
            roi_1, roi_2 = roi_pair
            if roi_1 in removed_small_rois or roi_2 in removed_small_rois:
                continue
            small_roi, normal_roi = None, None
            if roi_1 in normal_rois and roi_2 in small_rois:
                small_roi = roi_2
                normal_roi = roi_1
            elif roi_2 in normal_rois and roi_1 in small_rois:
                small_roi = roi_1
                normal_roi = roi_2
            if small_roi is not None:
                if min_surface_distances[distances_pos] < distance_cutoff:
                    # Modify labels in the final labels array to merge the small ROI with the normal ROI
                    final_labels[final_labels == small_roi] = normal_roi
                    removed_small_rois.append(small_roi)
                    #Remove small ROI from all_volumes
                    all_volumes.pop(small_roi)
                    #Remove small ROI from small_rois
                    small_rois.remove(small_roi)

                    #To make sure the surface is continuous, fill any background values with normal roi valuesa along the line of shorthest distance
                    coords = closest_points_dict[distances_pos]
                    line = np.linspace(coords[0], coords[1], num=100)
                    for point in line:
                        point = point.astype(int)
                        if final_labels[point[0], point[1], point[2]] == 0:
                            final_labels[point[0], point[1], point[2]] = normal_roi
                    #Recalculate the volume of the normal ROI
                    normal_roi_volume = compute_roi_volumes(final_labels, z_spacing=z_spacing, xy_spacing=xy_spacing, labels=[normal_roi])[normal_roi]
                    all_volumes[normal_roi] = normal_roi_volume
        if len(removed_small_rois) == 0:
            break
    #Remove all the remaining small ROIs that are not close to any normal ROIs (merge with background)
    for roi in small_rois:
        if roi not in removed_small_rois:
            final_labels[final_labels == roi] = 0

    print("ROI refinement complete.")
    return final_labels


In [16]:
#Load updated_stack.npy
updated_stack = np.load('/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed/updated_stack.npy')

xy_spacing = 471.4 / updated_stack.shape[1]
refine_rois(updated_stack, 
            z_spacing=1, 
            xy_spacing=xy_spacing, 
            temp_dir='/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed', 
            min_volume=4000, 
            max_volume=10000, 
            distance_cutoff=30)


Refining ROIs...
Computing ROI volumes...
Identified 4646 small ROIs and 112 normal ROIs.


KeyboardInterrupt: 

In [34]:
import numpy as np
import os
import multiprocessing as mp
from scipy.spatial import cKDTree
from tqdm import tqdm

import numpy as np
import multiprocessing as mp

def extract_surface_points_for_label(args):
    """
    Extracts surface points for a single ROI using matrix operations.
    
    Parameters:
        args (tuple): Contains the labeled stack, label, z_spacing, and xy_spacing.
    
    Returns:
        tuple: A tuple of (label, scaled surface points).
    """
    labeled_stack, label, z_spacing, xy_spacing = args
    binary_mask = (labeled_stack == label)

    # Compute shifts for neighbors in a 3D grid
    shifts = np.array([
        [-1,  0,  0], [1,  0,  0],  # z-axis neighbors
        [ 0, -1,  0], [0,  1,  0],  # y-axis neighbors
        [ 0,  0, -1], [0,  0,  1],  # x-axis neighbors
    ])
    
    # Find surface voxels
    padded_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=False)
    neighbors = np.zeros_like(binary_mask, dtype=bool)
    
    for shift in shifts:
        shifted = padded_mask[
            1 + shift[0]:1 + shift[0] + binary_mask.shape[0],
            1 + shift[1]:1 + shift[1] + binary_mask.shape[1],
            1 + shift[2]:1 + shift[2] + binary_mask.shape[2],
        ]
        neighbors |= ~shifted  # Mark as surface if adjacent to background
    
    surface_voxels = np.argwhere(binary_mask & neighbors)
    
    # Scale voxel coordinates to physical distances
    scaled_points = surface_voxels * [z_spacing, xy_spacing, xy_spacing]
    return label, scaled_points

def compute_surface_points(labeled_stack, z_spacing, xy_spacing, n_processes=None):
    """
    Computes the 3D surface points of each labeled ROI in the stack using parallel processing.
    
    Parameters:
        labeled_stack (np.ndarray): 3D array with labeled ROIs.
        z_spacing (float): Spacing in microns between z-planes.
        xy_spacing (float): Spacing in microns for x and y dimensions.
        n_processes (int, optional): Number of processes to use. Defaults to os.cpu_count().
    
    Returns:
        dict: A dictionary where keys are ROI IDs and values are arrays of surface points.
    """
    print("Computing surface points for each ROI...")
    unique_labels = np.unique(labeled_stack)
    unique_labels = unique_labels[unique_labels != 0]  # Exclude background (0)
    
    # Prepare arguments for parallel processing
    args = [(labeled_stack, label, z_spacing, xy_spacing) for label in unique_labels]
    
    # Use multiprocessing to extract surface points for each label
    with mp.Pool(n_processes or mp.cpu_count()) as pool:
        results = tqdm(pool.map(extract_surface_points_for_label, args))
    
    # Combine results into a dictionary
    surface_points = {label: points for label, points in results}
    return surface_points


def compute_pairwise_distance(args):
    """
    Computes the shortest distance between two sets of surface points.

    Parameters:
        args (tuple): Contains surface points of ROI1, ROI2, and ROI IDs.

    Returns:
        tuple: A tuple of (roi1, roi2, {"distance": shortest_distance, "points": (point1, point2)}).
    """
    roi1, roi2, points1, points2 = args

    tree1 = cKDTree(points1)
    tree2 = cKDTree(points2)

    # Compute distances in both directions and find the minimum
    distances1, indices1 = tree1.query(points2, k=1)
    distances2, indices2 = tree2.query(points1, k=1)

    min_distance1 = np.min(distances1)
    min_distance2 = np.min(distances2)

    if min_distance1 < min_distance2:
        idx = np.argmin(distances1)
        closest_points = (points1[indices1[idx]], points2[idx])
        shortest_distance = min_distance1
    else:
        idx = np.argmin(distances2)
        closest_points = (points1[idx], points2[indices2[idx]])
        shortest_distance = min_distance2

    return roi1, roi2, {"distance": shortest_distance, "points": closest_points}

def compute_surface_min_distances(labeled_stack, z_spacing, xy_spacing, temp_dir):
    """
    Computes the shortest distances between surfaces of each pair of ROIs.

    Parameters:
        labeled_stack (np.ndarray): 3D array with labeled ROIs.
        z_spacing (float): Spacing in microns between z-planes.
        xy_spacing (float): Spacing in microns for x and y dimensions.
        temp_dir (str): Directory path for saving temporary files.

    Returns:
        np.ndarray: Matrix of shape (n_rois, n_rois) containing dictionaries with "distance" and "points".
    """
    os.makedirs(temp_dir, exist_ok=True)

    surface_points = compute_surface_points(labeled_stack, z_spacing, xy_spacing)
    unique_labels = list(surface_points.keys())
    n_rois = len(unique_labels)

    # Create memmap for results
    result_path = os.path.join(temp_dir, "distance_matrix.npy")
    distance_matrix = np.memmap(result_path, dtype=object, mode="w+", shape=(n_rois, n_rois))

    # Prepare arguments for multiprocessing
    args = []
    for i, roi1 in enumerate(unique_labels):
        for j, roi2 in enumerate(unique_labels):
            if i < j:  # Avoid redundant computations
                args.append((roi1, roi2, surface_points[roi1], surface_points[roi2]))

    # Parallel computation of pairwise distances
    with mp.Pool(mp.cpu_count()) as pool:
        results = list(tqdm(pool.imap(compute_pairwise_distance, args), total=len(args)))

    # Fill the result matrix
    for roi1, roi2, result in results:
        idx1 = unique_labels.index(roi1)
        idx2 = unique_labels.index(roi2)
        distance_matrix[idx1, idx2] = result
        distance_matrix[idx2, idx1] = result  # Symmetric matrix

    return distance_matrix


In [35]:
outpt2 = compute_surface_min_distances(updated_stack, 1, xy_spacing, '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed')

Computing surface points for each ROI...


100%|██████████| 4758/4758 [00:00<00:00, 4992.03it/s]
100%|██████████| 11316903/11316903 [27:02<00:00, 6974.62it/s] 


In [44]:
output_dir = '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed'

#Load the distance matrix memmap
distance_matrix = np.memmap(os.path.join(output_dir, "distance_matrix.npy"), dtype=object, mode='r+', shape=(updated_stack.shape[1], updated_stack.shape[1]))
distance_matrix[1,1]

#Modify distance matrix to add ROI IDs in the dictionary
for i in range(distance_matrix.shape[0]):
    for j in range(distance_matrix.shape[1]):
        if distance_matrix[i, j] is not None:
            roi_set = set([i, j])
            distance_matrix[i, j] = (roi_set, distance_matrix[i, j]["distance"], distance_matrix[i, j]["points"])

In [46]:
output_dir = '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed'

#Load the distance matrix memmap
distance_matrix = np.memmap(os.path.join(output_dir, "distance_matrix.npy"), dtype=object, mode='r+', shape=(updated_stack.shape[1], updated_stack.shape[1]))
distance_matrix[1,2]

({1, 2},
 90.82324025018849,
 (array([  3.        , 168.18177083, 230.78958333]),
  array([ 45.        ,  88.08059896, 222.50325521])))

In [None]:
def convert_microns_to_voxels(coords, xy_spacing, z_spacing):
    """
    Convert coordinates in microns to voxel units.
    
    Parameters:
    - coords: Tuple of two 3D coordinate arrays (point1, point2) in microns.
    - xy_spacing: Spacing in microns for x and y dimensions.
    - z_spacing: Spacing in microns between z-planes.
    
    Returns:
    - Tuple of two 3D coordinate arrays (point1, point2) in voxel units.
    """
    point1, point2 = coords
    point1_voxels = point1 / [z_spacing, xy_spacing, xy_spacing]
    point2_voxels = point2 / [z_spacing, xy_spacing, xy_spacing]
    return point1_voxels, point2_voxels


def just_closest_rois(distance_matrix):
    """
    While the distance matrix has all the pairwise distance, this function will return only the closest ROI for each ROI.
    """
    closest_rois = []
    for i in range(distance_matrix.shape[0]):
        min_distance = np.inf
        closest_roi_id = None
        for j in range(distance_matrix.shape[1]):
            if distance_matrix[i, j] is not None:
                if distance_matrix[i, j][1] < min_distance:
                    min_distance = distance_matrix[i, j][1]
                    closest_roi_id = j
        closest_rois.append(distance_matrix[i, closest_roi_id])
    return closest_rois