# Imports

In [1]:
import tifffile as tiff
import numpy as np
from scipy import ndimage
import tempfile
from tqdm import tqdm
import os
import plotly.graph_objs as go
import napari
from joblib import Parallel, delayed

from skimage import morphology, exposure, transform
from multiprocessing import Pool, cpu_count

from scipy.ndimage import binary_erosion
from scipy.spatial.distance import cdist

from scipy.ndimage import distance_transform_edt
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

from scipy import ndimage as ndi

from magicgui import magicgui
from skimage.measure import label
from tkinter import filedialog, Tk

# Helper functions

In [7]:
def downsample_stack(image_stack, factor=2):
    """Downsample a 3D image stack by a given factor."""
    return image_stack[:, ::factor, ::factor]

def upsample_stack(labeled_stack, original_shape):
    """
    Upsamples the labeled stack to the original shape.
    """
    if labeled_stack.ndim == 3:
        return np.array([transform.resize(labeled_stack[z], original_shape[1:], 
                                      order=0, preserve_range=True, anti_aliasing=False) 
                     for z in range(labeled_stack.shape[0])], dtype=np.int32)
    elif labeled_stack.ndim == 2:
        return transform.resize(labeled_stack, original_shape, order=0, preserve_range=True, anti_aliasing=False)
    else:
        raise ValueError("Input stack must be either 2D or 3D.")


def intensity_based_segmentation(image_stack, slice_settings=None, temp_dir=None, downsample_factor=2):
    """
    3D cell segmentation using intensity-based region growing with depth-adaptive parameters.
    Parameters are interpolated based on user-defined settings per slice.
    """
    if slice_settings is None:
        slice_settings = {}
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp()
    else:
        os.makedirs(temp_dir, exist_ok=True)

    # Default settings
    default_settings = {
        "intensity_threshold": 0.9,
        "min_volume": 500,
    }

    num_slices = image_stack.shape[0]

    # Interpolation logic for depth-adaptive parameters
    if len(slice_settings) >= 2:
        print("Fitting depth-adaptive parameters based on user-defined settings...")
        z_indices = np.array(sorted(slice_settings.keys()))
        intensity_thresholds = np.array([slice_settings[z].get("intensity_threshold", default_settings["intensity_threshold"]) for z in z_indices])
        min_volumes = np.array([slice_settings[z].get("min_volume", default_settings["min_volume"]) for z in z_indices])

        # Linear interpolation across Z-dimension
        interp_intensity_threshold = np.interp(range(num_slices), z_indices, intensity_thresholds)
        interp_min_volume = np.interp(range(num_slices), z_indices, min_volumes)
    else:
        # Use default values for all slices
        print("Using default parameters for segmentation.")
        interp_intensity_threshold = np.full(num_slices, default_settings["intensity_threshold"])
        interp_min_volume = np.full(num_slices, default_settings["min_volume"])

    # Step 1: Downsample and enhance contrast
    downsampled_stack = downsample_stack(image_stack, factor=downsample_factor)
    enhanced_stack = adaptive_contrast_enhancement(downsampled_stack)

    # Step 2: Perform parallel region growing on downsampled slices
    print("Performing intensity-based region growing in parallel...")
    max_intensity = np.max(enhanced_stack)
    num_workers = min(cpu_count(), num_slices // downsample_factor)
    
    slice_data = [
        (
            z,
            enhanced_stack[z],
            np.argwhere(enhanced_stack[z] >= interp_intensity_threshold[z] * max_intensity),
            max_intensity,
            interp_intensity_threshold[z],
        )
        for z in range(0, enhanced_stack.shape[0], downsample_factor)
    ]

    with Pool(num_workers) as pool:
        results = list(tqdm(pool.imap(process_slice, slice_data), total=len(slice_data), desc="Processing slices"))

    # Combine labeled slices into a downsampled 3D stack
    labeled_stack = np.zeros_like(enhanced_stack, dtype=np.int32)
    current_label = 1
    for slice_index, slice_label, last_label in results:
        labeled_stack[slice_index] = slice_label + (current_label - 1) * (slice_label > 0)
        current_label += last_label

    # Upsample to the original resolution if downsampling was applied
    if downsample_factor > 1:
        labeled_stack = upsample_stack(labeled_stack, image_stack.shape)

    # Step 3: 3D connected component labeling and volume filtering
    print("Labeling connected components in 3D...")
    labeled_stack, num_features = ndimage.label(labeled_stack > 0)

    final_labeled_stack = np.zeros_like(labeled_stack, dtype=np.int32)

    # Define the number of sub-volumes along the Z-axis
    num_sub_volumes = 5  # Adjust this based on the dataset size
    z_slices = labeled_stack.shape[0]
    sub_volume_ranges = np.array_split(range(z_slices), num_sub_volumes)

    # Process each sub-volume with its interpolated min_volume
    for z_range in sub_volume_ranges:
        start_z, end_z = z_range[0], z_range[-1]
        sub_volume = labeled_stack[start_z:end_z + 1]

        # Compute mean or median `min_volume` for this sub-volume
        local_min_volume = int(interp_min_volume[start_z:end_z + 1].mean())

        # Apply small object removal in 3D
        filtered_sub_volume = morphology.remove_small_objects(
            sub_volume, 
            min_size=local_min_volume
        )

        # Store the filtered sub-volume back in the final stack
        final_labeled_stack[start_z:end_z + 1] = filtered_sub_volume

    # Relabel the entire stack to ensure consistent labels across sub-volumes
    final_labeled_stack, num_features = ndimage.label(final_labeled_stack > 0)

    # Save labeled stack to a file
    output_path = os.path.join(temp_dir, "segmented_cells.npy")
    np.save(output_path, final_labeled_stack)

    print(f"Segmentation complete. Results saved at {output_path}")
    return final_labeled_stack

    
def merge_close_rois_optimized(roi_array, temp_dir, dilation_radius=1, chunk_size=10, num_threads=2):
    """
    Optimized function to merge close ROIs by dilating and re-labeling connected components,
    with a focus on quality and memory efficiency.
    
    Parameters:
    - roi_array: 3D numpy array with labeled regions, where each cell has a unique ROI number
    - dilation_radius: int, radius for dilation to bridge close components
    - chunk_size: int, size of chunks to process independently in the z-axis
    - num_threads: int, number of threads for parallel processing
    
    Returns:
    - Path to the memmapped merged ROI array
    """
    print("Starting quality-preserving merging of close ROIs...")

    # Prepare memmap for output array
    output_path = os.path.join(temp_dir, "merged_roi_array_optimized.dat")
    merged_roi_array_memmap = np.memmap(output_path, dtype=np.int32, mode='w+', shape=roi_array.shape)
    
    # Initialize an empty boolean array for dilated regions
    dilated_roi_array = np.memmap(os.path.join(temp_dir, "dilated_roi_array.dat"), dtype=bool, mode='w+', shape=roi_array.shape)
    dilated_roi_array[:] = False  # Set all to False initially

    # Get unique labels, excluding background (label 0)
    unique_labels = np.unique(roi_array)
    unique_labels = unique_labels[unique_labels != 0]
    
    def dilate_roi(label):
        """Dilates the given ROI label and updates the global dilated array."""
        # Mask for current label
        mask = (roi_array == label)
        
        # Dilate mask and add to the dilated ROI array (in-place)
        dilated = morphology.binary_dilation(mask, morphology.ball(dilation_radius))
        dilated_roi_array[:] |= dilated  # Union with existing dilated regions

    # Process each label in parallel
    print("Dilating each ROI label in parallel...")
    Parallel(n_jobs=num_threads)(delayed(dilate_roi)(label) for label in tqdm(unique_labels))

    # After dilation, label connected components in the dilated ROI array
    print("Labeling connected components to create merged ROIs...")
    labeled_dilated_array, num_features = ndimage.label(dilated_roi_array)

    # Store labeled regions in the final output memmap array
    merged_roi_array_memmap[:] = labeled_dilated_array
    
    # Clean up intermediate arrays
    del dilated_roi_array

    print(f"Merging complete. Number of merged ROIs: {num_features}")
    print("Output saved to:", output_path)
    
    return merged_roi_array_memmap


def adaptive_contrast_enhancement(image_stack, low_percentile=25, high_percentile=99.9):
    """
    Enhances contrast for each slice by stretching intensities based on slice-specific percentiles.
    """
    enhanced_stack = np.zeros_like(image_stack, dtype=np.float32)
    for z in range(image_stack.shape[0]):
        slice_img = image_stack[z]
        p_low, p_high = np.percentile(slice_img, (low_percentile, high_percentile))
        enhanced_stack[z] = exposure.rescale_intensity(slice_img, in_range=(p_low, p_high), out_range=(0, 1))
    return enhanced_stack


# def downsample_stack(image_stack, factor=2):
#     """
#     Downsamples each slice by a given factor to reduce processing load.
#     """
#     print(f"Downsampling image stack by factor of {factor}...")
#     return np.array([transform.resize(image_stack[z], 
#                                       (image_stack.shape[1] // factor, image_stack.shape[2] // factor),
#                                       anti_aliasing=True) for z in range(image_stack.shape[0])], dtype=np.float32)


# def upsample_stack(labeled_stack, original_shape):
#     """
#     Upsamples the labeled stack to the original shape.
#     """
#     print("Upsampling labeled stack to original resolution...")
#     return np.array([transform.resize(labeled_stack[z], original_shape[1:], 
#                                       order=0, preserve_range=True, anti_aliasing=False) 
#                      for z in range(labeled_stack.shape[0])], dtype=np.int32)


def process_slice(slice_data):
    """
    Process a single slice for region growing and labeling based on intensity threshold.
    """
    slice_index, slice_intensity, seeds, max_intensity, intensity_threshold = slice_data
    slice_label = np.zeros_like(slice_intensity, dtype=np.int32)
    current_label = 1
    
    for seed in seeds:
        region = morphology.flood(slice_intensity, tuple(seed), tolerance=(max_intensity * (1 - intensity_threshold)))
        slice_label[region] = current_label
        current_label += 1
    
    return slice_index, slice_label, current_label

# def process_slice(args):
#     """Process a single 2D slice for region-growing segmentation."""
#     slice_index, slice_image, seeds, max_intensity, intensity_threshold = args
#     mask = np.zeros_like(slice_image, dtype=np.int32)
#     if len(seeds) > 0:
#         for seed in seeds:
#             if slice_image[tuple(seed)] >= intensity_threshold * max_intensity:
#                 mask[tuple(seed)] = 1
#         mask = morphology.binary_dilation(mask)
#     labeled_slice = label(mask)
#     return slice_index, labeled_slice, labeled_slice.max()

# def intensity_based_segmentation(image_stack, intensity_threshold=0.9, min_volume=500, temp_dir=None, downsample_factor=2):
#     """
#     3D cell segmentation using intensity-based region growing with parallel processing.
#     """
#     if temp_dir is None:
#         temp_dir = tempfile.mkdtemp()
#     else:
#         os.makedirs(temp_dir, exist_ok=True)

#     # Step 1: Downsample and enhance contrast
#     downsampled_stack = downsample_stack(image_stack, factor=downsample_factor)
#     enhanced_stack = adaptive_contrast_enhancement(downsampled_stack)

#     # Step 2: Identify high-intensity seed points selectively
#     max_intensity = np.max(enhanced_stack)
#     seed_threshold = intensity_threshold * max_intensity
#     seeds = [np.argwhere(enhanced_stack[z] >= seed_threshold) for z in range(0, enhanced_stack.shape[0], downsample_factor)]
    
#     # Step 3: Perform parallel region growing on downsampled slices
#     print("Performing intensity-based region growing in parallel...")
#     num_workers = min(cpu_count(), len(seeds))
    
#     slice_data = [
#         (z, enhanced_stack[z], seeds[z // downsample_factor], max_intensity, intensity_threshold)
#         for z in range(0, enhanced_stack.shape[0], downsample_factor)
#     ]

#     with Pool(num_workers) as pool:
#         results = list(tqdm(pool.imap(process_slice, slice_data), total=len(slice_data), desc="Processing slices"))

#     # Combine labeled slices into a downsampled 3D stack
#     labeled_stack = np.zeros_like(enhanced_stack, dtype=np.int32)
#     current_label = 1
#     for slice_index, slice_label, last_label in results:
#         labeled_stack[slice_index] = slice_label + (current_label - 1) * (slice_label > 0)
#         current_label += last_label

#     # Upsample to the original resolution if downsampling was applied
#     if downsample_factor > 1:
#         labeled_stack = upsample_stack(labeled_stack, image_stack.shape)

#     # Step 4: 3D connected component labeling and volume filtering
#     print("Labeling connected components in 3D...")
#     labeled_stack, num_features = ndimage.label(labeled_stack > 0)
#     final_labeled_stack = morphology.remove_small_objects(labeled_stack, min_size=min_volume)

#     # Save labeled stack to a file
#     output_path = os.path.join(temp_dir, "segmented_cells.npy")
#     np.save(output_path, final_labeled_stack)

#     print(f"Segmentation complete. Results saved at {output_path}")
#     return final_labeled_stack


def save_surface_points_to_memmap(roi_surface, z_spacing, xy_spacing, temp_dir, label_value):
    """
    Save surface points to a memory-mapped file, scaled by the given spacing.
    """
    surface_points = np.argwhere(roi_surface).astype(float)
    surface_points[:, 0] *= z_spacing
    surface_points[:, 1] *= xy_spacing
    surface_points[:, 2] *= xy_spacing

    # Create a memmap file to store the surface points for this ROI
    memmap_path = os.path.join(temp_dir, f"roi_{label_value}_surface.memmap")
    surface_memmap = np.memmap(memmap_path, dtype='float32', mode='w+', shape=surface_points.shape)
    surface_memmap[:] = surface_points[:]
    del surface_memmap  # Ensure it’s written to disk

    return memmap_path, surface_points.shape


def calculate_min_distance_memmap(roi_memmap_path, roi_shape, other_memmap_path, other_shape, chunk_size=500):
    """
    Calculate minimum surface-to-surface distance between two ROIs stored in memory-mapped files.
    Returns the minimum distance and coordinates of the closest points on each surface.
    """
    min_distance = np.inf
    min_coords = (None, None)  # Initialize to store coordinates of closest points
    roi_surface_memmap = np.memmap(roi_memmap_path, dtype='float32', mode='r', shape=roi_shape)
    other_surface_memmap = np.memmap(other_memmap_path, dtype='float32', mode='r', shape=other_shape)

    # Calculate distance in chunks
    for i in range(0, other_shape[0], chunk_size):
        other_chunk = other_surface_memmap[i:i+chunk_size]
        distances = cdist(roi_surface_memmap, other_chunk)
        min_dist_in_chunk = distances.min()
        
        # If this chunk contains a closer pair, update min_distance and coordinates
        if min_dist_in_chunk < min_distance:
            min_distance = min_dist_in_chunk
            # Find the indices of the closest pair of points
            roi_idx, other_idx = np.unravel_index(distances.argmin(), distances.shape)
            min_coords = (roi_surface_memmap[roi_idx], other_chunk[other_idx])
        
        if min_distance == 0:
            break  # Stop if distance is zero (surfaces touch)

    return min_distance, min_coords  # Return both distance and coordinates


def compute_surface_min_distances_memmap(labeled_stack, z_spacing, xy_spacing, temp_dir=None, select_ids=None, chunk_size=500, n_jobs=-1):
    """
    Calculate minimum surface-to-surface distances between ROIs using memory-mapped arrays for memory efficiency.
    
    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs
    - z_spacing: Spacing in microns between z-planes
    - xy_spacing: Spacing in microns for x and y dimensions
    - temp_dir: Path to an external directory for temporary files (default is a new temp directory)
    - chunk_size: Number of points to load per chunk for distance calculation
    - n_jobs: Number of parallel jobs (default is -1, using all processors)
    
    Returns:
    - distances: Array of minimum distances between ROIs
    - coordinates: List of coordinate pairs for each minimum distance
    """
    # Create or use the specified temporary directory
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp()
    else:
        os.makedirs(temp_dir, exist_ok=True)

    # Get unique ROI labels
    labels = np.unique(labeled_stack)
    labels = labels[labels != 0]

    # Save each ROI's surface points to a memmap file
    roi_memmaps = {}
    for label_value in labels:
        roi_mask = (labeled_stack == label_value)
        eroded_roi = binary_erosion(roi_mask)
        roi_surface = roi_mask & ~eroded_roi
        memmap_path, shape = save_surface_points_to_memmap(roi_surface, z_spacing, xy_spacing, temp_dir, label_value)
        roi_memmaps[label_value] = (memmap_path, shape)

    # Compute minimum distances in parallel
    def calculate_min_distance_for_label(label_value):
        min_distance = np.inf
        closest_points = (None, None)
        closest_roi_id = None
        roi_memmap_path, roi_shape = roi_memmaps[label_value]

        for other_label in labels:
            if other_label == label_value:
                continue

            other_memmap_path, other_shape = roi_memmaps[other_label]
            label_min_distance, label_coords = calculate_min_distance_memmap(
                roi_memmap_path, roi_shape, other_memmap_path, other_shape, chunk_size
            )
            if label_min_distance < min_distance:
                min_distance = label_min_distance
                closest_points = label_coords
                closest_roi_id = other_label
            
            if min_distance == 0:
                break  # Stop if distance is zero (surfaces touch)

        return set(label_value, closest_roi_id), min_distance, closest_points  # Return both distance and coordinates

    # Run the parallel computation
    if select_ids is not None:
        select_labels = select_ids
    else:
        select_labels = labels
    results = Parallel(n_jobs=n_jobs)(
        delayed(calculate_min_distance_for_label)(label_value) for label_value in select_labels
    )

    # Split results into distances and coordinates
    roi_id_set, min_surface_distances, closest_coords = zip(*results)

    # Remove all the files inside the temporary directory
    for file in os.listdir(temp_dir):
        os.remove(os.path.join(temp_dir, file))

    #Save min_surface_distances and closest_coords to files
    output_path0 = os.path.join(temp_dir, "closest_roi_ids.npy")
    output_path1 = os.path.join(temp_dir, "min_surface_distances.npy")
    output_path2 = os.path.join(temp_dir, "closest_coords.npy")
    np.save(output_path0, np.array(roi_id_set))
    np.save(output_path1, np.array(min_surface_distances))
    np.save(output_path2, np.array(closest_coords))


    return np.array(min_surface_distances), closest_coords  # Return both distances and coordinates


def compute_roi_volumes(labeled_stack, z_spacing, xy_spacing, labels=None):
    """
    Compute the volume of each ROI in the labeled stack, given the true voxel dimensions.
    
    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs
    - z_spacing: Spacing in microns between z-planes
    - xy_spacing: Spacing in microns for x and y dimensions
    
    Returns:
    - roi_volumes: Dictionary with ROI labels as keys and volumes in cubic microns as values
    """
    print("Computing ROI volumes...")
    # Calculate the volume of a single voxel in cubic microns
    voxel_volume = z_spacing * xy_spacing * xy_spacing
    
    # Get unique ROI labels (excluding 0 as it represents the background)
    if labels is None:
        labels = np.unique(labeled_stack)
        labels = labels[labels != 0]
    
    # Calculate volume for each ROI
    roi_volumes = {}
    for label_value in labels:
        # Count the number of voxels for this ROI
        voxel_count = np.sum(labeled_stack == label_value)
        
        # Compute volume in cubic microns
        roi_volume = voxel_count * voxel_volume
        roi_volumes[label_value] = roi_volume

    return roi_volumes


def visualize_with_napari_overlay(original_stack, merged_roi_array, closest_coords=None, z_scale_factor=2, z_spacing=1, xy_spacing=1):
    """
    Visualize the original 3D image stack overlaid with the ROI segmentation in Napari,
    with lines connecting the closest points between ROIs.

    Parameters:
    - original_stack: 3D numpy array of the original microscopy image data.
    - merged_roi_array: 3D numpy array with labeled regions, where each cell has a unique ROI number.
    - closest_coords: List of tuples, where each tuple contains two coordinate arrays (point1, point2) representing 
                      the closest points between ROIs.
    - z_scale_factor: Scaling factor to adjust z-dimension for proper aspect ratio.
    - z_spacing: Spacing in microns between z-planes (used to convert line coordinates to voxel units).
    - xy_spacing: Spacing in microns for x and y dimensions (used to convert line coordinates to voxel units).
    """
    print("Launching Napari visualization with overlay...")

    # Initialize the Napari viewer
    viewer = napari.Viewer()

    # Add the original image stack as the first layer
    viewer.add_image(
        original_stack,
        name="Original Image",
        scale=(z_scale_factor, 1, 1),  # Apply only z scaling factor
        colormap='gray',
        opacity=0.8
    )

    # Add the segmented ROIs as a labels layer
    viewer.add_labels(
        merged_roi_array,
        name="ROIs",
        scale=(z_scale_factor, 1, 1),  # Apply only z scaling factor
        opacity=0.6  # Adjust opacity for overlay effect
    )

    if closest_coords is not None:
        # Prepare lines for closest points, converting each coordinate to voxel units
        lines = []
        for point1, point2 in closest_coords:
            # Convert micron coordinates to voxel coordinates
            start_point = (point1[0] * z_scale_factor, point1[1] / xy_spacing, point1[2] / xy_spacing)
            end_point = (point2[0] * z_scale_factor, point2[1] / xy_spacing, point2[2] / xy_spacing)
            lines.append([start_point, end_point])

        # Add the lines connecting closest points as shapes in Napari
        viewer.add_shapes(
            lines,
            shape_type='line',
            edge_color='red',
            edge_width=1.5,
            name="Closest Points Connections"
        )

    # Switch to 3D mode
    viewer.dims.ndisplay = 3

    # Start the Napari viewer event loop
    napari.run()


def split_large_rois_with_intensity(labeled_stack, z_spacing, xy_spacing, mean_guess, std_guess, tempdir, max_iters=10):
    """
    Identifies large ROIs based on volume and recursively splits them into smaller ROIs.

    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs.
    - z_spacing: Spacing in microns between z-planes.
    - xy_spacing: Spacing in microns for x and y dimensions.
    - mean_guess: the best guess for the mean volume of single cells
    - std_guess: the best guess for the standard deviation of single cells
    - max_iters: Maximum number of iterations for splitting.

    Returns:
    - updated_stack: Updated 3D numpy array with split ROIs.
    """
    volumes = compute_roi_volumes(labeled_stack, z_spacing=z_spacing, xy_spacing=xy_spacing)
    # mean_volume = np.mean(list(volumes.values()))
    # median_volume = np.median(list(volumes.values()))
    # std_volume = np.std(list(volumes.values()))
    updated_stack = np.copy(labeled_stack)
    print("Identifying large ROIs...")
    
    # Find large ROIs
    roi_labels = np.unique(labeled_stack)
    large_rois = [label for label, volume in volumes.items() if volume > mean_guess + 2 * std_guess]

    for roi_label in large_rois:
        if roi_label == 0:
            continue
        print(f"Processing large ROI {roi_label}...")
        roi_coords = np.array(np.where(updated_stack == roi_label)).T  # Extract coordinates
        updated_stack = recursive_split(
            updated_stack, roi_coords, z_spacing, xy_spacing,
            mean_guess, std_guess, max_iters=max_iters
        )

    #save updated_stack as memmap
    output_path = os.path.join(tempdir, "updated_stack.npy")
    np.save(output_path, updated_stack)

    print("Finished processing all large ROIs.")

    return updated_stack


def recursive_split(labeled_stack, roi_coords, z_spacing, xy_spacing, mean_guess, std_guess, max_iters=10):
    """
    Recursively split a large ROI into smaller ROIs using watershed segmentation.

    Parameters:
    - labeled_stack: 3D numpy array containing labeled ROIs.
    - roi_coords: Coordinates (z, y, x) of the ROI mask.
    - z_spacing: Spacing in microns between z-planes.
    - xy_spacing: Spacing in microns for x and y dimensions.
    - mean_guess: the best guess for the mean volume of single cells
    - std_guess: the best guess for the standard deviation of single cells
    - max_iters: Maximum number of iterations for recursive splitting.

    Returns:
    - updated_stack: Updated 3D numpy array with split ROIs.
    """
    print("Starting recursive split...")
    updated_stack = np.copy(labeled_stack)

    # Compute bounding box for ROI
    z_coords, y_coords, x_coords = roi_coords[:, 0], roi_coords[:, 1], roi_coords[:, 2]
    z_min, z_max = z_coords.min(), z_coords.max()
    y_min, y_max = y_coords.min(), y_coords.max()
    x_min, x_max = x_coords.min(), x_coords.max()

    print(f"Bounding box: z({z_min}-{z_max}), y({y_min}-{y_max}), x({x_min}-{x_max})")

    # Extract bounding box region
    roi_mask_bbox = np.zeros((z_max - z_min + 1, y_max - y_min + 1, x_max - x_min + 1), dtype=bool)
    roi_mask_bbox[z_coords - z_min, y_coords - y_min, x_coords - x_min] = True

    # Compute distance transform
    print("Computing distance transform...")
    distance_map = distance_transform_edt(roi_mask_bbox, sampling=(z_spacing, xy_spacing, xy_spacing))

    # Calculate the expected number of splits based on the volume ratio
    ones_roi_mask_bbox = np.zeros_like(roi_mask_bbox, dtype=np.int32)
    ones_roi_mask_bbox[roi_mask_bbox==True] = 1
    print(f"Bounding box dimensions: {ones_roi_mask_bbox.shape}")
    roi_volume = compute_roi_volumes(ones_roi_mask_bbox, z_spacing=z_spacing, xy_spacing=xy_spacing)[1]
    expected_splits = max(1, int(np.round(roi_volume / mean_guess)))
    print(f"Expected splits: {expected_splits}")

    coords = peak_local_max(distance_map, footprint=np.ones((10, 25, 25)), labels=roi_mask_bbox)
    mask_var = np.zeros(distance_map.shape, dtype=bool)
    mask_var[tuple(coords.T)] = True
    markers, _ = ndi.label(mask_var)

    # Apply watershed
    print("Applying watershed...")
    segmented_bbox = watershed(-distance_map, markers, mask=roi_mask_bbox)
    if segmented_bbox.max() == 1:
        print("No splits detected; returning...")
        return updated_stack

    print(f"Actual splits: {segmented_bbox.max()}")

    # Iterate over segments
    max_label = labeled_stack.max()
    for seg_label in np.unique(segmented_bbox):
        if seg_label == 0:
            continue  # Skip background

        # Extract segment mask
        seg_mask_bbox = (segmented_bbox == seg_label)
        seg_volume = np.sum(seg_mask_bbox) * z_spacing * xy_spacing * xy_spacing  # Compute volume

        if seg_volume > mean_guess + 2 * std_guess:
            print(f"Segment {seg_label} is too large; recursing...")
            seg_coords = np.array(np.where(seg_mask_bbox)).T + np.array([z_min, y_min, x_min])
            updated_stack = recursive_split(
                updated_stack, seg_coords, z_spacing, xy_spacing,
                mean_guess, std_guess, max_iters=max_iters
            )
        else:
            print(f"Segment {seg_label} is fine; assigning new label...")
            seg_coords = np.array(np.where(seg_mask_bbox)).T + np.array([z_min, y_min, x_min])
            updated_stack[seg_coords[:, 0], seg_coords[:, 1], seg_coords[:, 2]] = max_label + 1
            max_label += 1

    print("Finished recursive split.")
    return updated_stack


# Interactive Segmentation Function
def interactive_segmentation_with_file_prompt():
    """
    Launch a file prompt to load a .tif file, initialize Napari, and enable interactive segmentation.
    """
    # Prompt user to load a .tif file
    Tk().withdraw()  # Hide the main tkinter window
    file_loc = filedialog.askopenfilename(title="Select a .tif file", filetypes=[("TIFF files", "*.tif")])
    if not file_loc:
        print("No file selected. Exiting.")
        return

    # Load the .tif file
    image_stack = tiff.imread(file_loc)
    print(f"Loaded stack with shape {image_stack.shape}")

    #Create a mask with the same shape as the image stack
    initial_segmentation = np.zeros_like(image_stack, dtype=np.int32)

    # Prompt user for voxel dimensions
    # voxel_x = float(input("Image x dimension in microns: "))
    # voxel_y = float(input("Image y dimension in microns: "))
    # voxel_z = float(input("Distance between z-slices in microns: "))

    voxel_x = 471.4
    voxel_y = 471.4
    voxel_z = 1

    # Create a settings dictionary to store parameters for each slice
    slice_settings = {}

    def apply_mask(slice_index, intensity_threshold, min_volume, downsample_factor):
        """
        Generate and apply the segmentation mask for a given slice.
        """
        image_slice = enhanced_stack[slice_index]
        downsampled = downsample_stack(image_slice[None, ...], factor=downsample_factor)[0]
        enhanced = downsampled
        seed_threshold = intensity_threshold * np.max(enhanced)
        seeds = np.argwhere(enhanced >= seed_threshold)

        mask = np.zeros_like(enhanced, dtype=np.int32)
        for seed in seeds:
            if enhanced[tuple(seed)] >= seed_threshold:
                mask[tuple(seed)] = 1
        mask = morphology.binary_dilation(mask)
        labeled_mask = label(mask)
        filtered_mask = morphology.remove_small_objects(labeled_mask, min_size=min_volume)

        # Upsample the mask to original resolution
        if downsample_factor > 1:
            filtered_mask = upsample_stack(filtered_mask, image_slice.shape)

        #replace the correct z index in the mask
        initial_segmentation[slice_index] = filtered_mask

        # Update the mask layer in Napari
        viewer.layers["Initial segmentation"].data = initial_segmentation

        # Save the current slice settings
        slice_settings[slice_index] = {
            "intensity_threshold": intensity_threshold,
            "min_volume": min_volume,
            "downsample_factor": downsample_factor,
        }

    @magicgui(
        intensity_threshold={"label": "Intensity Threshold", "min": 0.0, "max": 1.0, "step": 0.01},
        min_volume={"label": "Minimum Volume (px)", "min": 1, "step": 10},
        downsample_factor={"label": "Downsample Factor", "min": 1, "max": 4, "step": 1},
        call_button="Update Mask"
    )
    def update_mask(intensity_threshold: float = 0.9, min_volume: int = 500, downsample_factor: int = 2):
        """
        Updates the segmentation mask for the currently displayed slice.
        """
        slice_idx = viewer.dims.current_step[0]  # Get the current slice index from Napari
        apply_mask(slice_idx, intensity_threshold, min_volume, downsample_factor)

    @magicgui(call_button="Run Full Segmentation")
    def run_full_segmentation():
        """
        Runs the full 3D segmentation pipeline using depth-adaptive parameters.
        """
        print("Running segmentation...")
        #Get directory of the data file
        inputdir = os.path.dirname(file_loc)

        #Get the filename basename
        basename = os.path.basename(file_loc).split('.')[0]
        #check if basename_processed directory exists
        processed_dir = os.path.join(inputdir, f"{basename}_processed")

        if not os.path.exists(processed_dir):
            os.makedirs(processed_dir)

        #Check if the segmented_cells.npy file exists
        if os.path.exists(os.path.join(processed_dir, "segmented_cells.npy")):
            labeled_cells = np.load(os.path.join(processed_dir, "segmented_cells.npy"))
        else:
            labeled_cells =  intensity_based_segmentation(image_stack, slice_settings=slice_settings, temp_dir=processed_dir, downsample_factor=2)
        #Check if merged_roi_array_optimized.dat exists
        merged_roi_array_loc = os.path.join(processed_dir, "merged_roi_array_optimized.dat")
        if os.path.exists(merged_roi_array_loc):
            merged_roi_array = np.memmap(merged_roi_array_loc, dtype=np.int32, mode='r+', shape=image_stack.shape)
        else:
            merged_roi_array = merge_close_rois_optimized(labeled_cells, temp_dir = processed_dir, dilation_radius=1, chunk_size=20, num_threads=46)
        #check if updated_stack.npy exists
        updated_stack_loc = os.path.join(processed_dir, "updated_stack.npy")
        if os.path.exists(updated_stack_loc):
            updated_stack = np.load(updated_stack_loc)
        else:
            updated_stack = split_large_rois_with_intensity(merged_roi_array, voxel_z, x_spacing, mean_guess=5000, std_guess=2500, tempdir = processed_dir, max_iters=10)
            
        
        viewer.add_labels(labeled_cells, name="Intermediate segmentation 1", scale=(z_scale_factor, 1, 1))
        viewer.add_labels(merged_roi_array, name="Intermediate segmentation 2", scale=(z_scale_factor, 1, 1))
        viewer.add_labels(updated_stack, name="Segmentation without large volumes", scale=(z_scale_factor, 1, 1))

    # Initialize the Napari viewer
    viewer = napari.Viewer()
    x_spacing = voxel_x / image_stack.shape[1]
    z_scale_factor = voxel_z/x_spacing

    viewer.add_image(image_stack, name="Original stack", scale=(z_scale_factor, 1, 1))

    enhanced_stack = adaptive_contrast_enhancement(image_stack)
    viewer.add_image(enhanced_stack, name="Enhanced stack", scale=(z_scale_factor, 1, 1))
    viewer.add_labels(initial_segmentation, name="Initial segmentation", scale=(z_scale_factor, 1, 1))

    # Add GUI components to the viewer
    viewer.window.add_dock_widget(update_mask, area="right")
    viewer.window.add_dock_widget(run_full_segmentation, area="right")

    napari.run()


# Load and process data

In [None]:
labeled_cells =  intensity_based_segmentation(image_stack, slice_settings=slice_settings, temp_dir=processed_dir, downsample_factor=2)
merged_roi_array = np.memmap("data/merged_roi_array_optimized.dat", dtype=np.int32, mode='r+', shape=img.shape)
xy_dim = 471.4 
xy_spacing = xy_dim / merged_roi_array.shape[1]

visualize_with_napari_overlay(img, merged_roi_array, closest_coords=None,  z_scale_factor=1/xy_spacing, z_spacing=1, xy_spacing=xy_spacing)

Launching Napari visualization with overlay...


In [10]:
xy_dim = 471.4 
xy_spacing = xy_dim / merged_roi_array.shape[1]
min_surface_distances, closest_points_dict = compute_surface_min_distances_memmap(
    merged_roi_array, z_spacing=1, xy_spacing=xy_spacing, temp_dir='/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/tmp', chunk_size=500, n_jobs=44
)


In [7]:
#Using true dimensions in microns, calculate scaling factor for visualisation
img = tiff.imread('data/Added later-D31_no cyto_1.tif')
merged_roi_array = np.memmap("data/merged_roi_array_optimized.dat", dtype=np.int32, mode='r+', shape=img.shape)
closest_points_dict = np.load("/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/tmp/closest_coords.npy", allow_pickle=True)

xy_dim = 471.4 
z_spacing = 1
xy_spacing = xy_dim / merged_roi_array.shape[1]
enhanced_stack = adaptive_contrast_enhancement(img)
visualize_with_napari_overlay(enhanced_stack, merged_roi_array, closest_points_dict,  z_scale_factor=1/xy_spacing, z_spacing=1, xy_spacing=xy_spacing)
# visualize_with_napari_overlay(img, merged_roi_array, closest_points_dict, z_scale_factor=1/xy_spacing)


Enhancing contrast per slice:   2%|▏         | 3/195 [00:00<00:07, 24.28it/s]

Applying adaptive contrast enhancement...


Enhancing contrast per slice: 100%|██████████| 195/195 [00:07<00:00, 27.06it/s]


Launching Napari visualization with overlay...


In [8]:
interactive_segmentation_with_file_prompt()

Loaded stack with shape (90, 1536, 1536)


In [8]:
#Plot a histogram of the volumes
image_stack = tiff.imread('data/Added later-D31_no cyto_1.tif')
processed_dir = "data/Added later-D31_no cyto_1_processed"
merged_roi_array_loc = os.path.join(processed_dir, "merged_roi_array_optimized.dat")
merged_roi_array = np.memmap(merged_roi_array_loc, dtype=np.int32, mode='r+', shape=image_stack.shape)
xy_spacing = 471.4 / merged_roi_array.shape[1]
volumes = compute_roi_volumes(merged_roi_array, z_spacing=1, xy_spacing=xy_spacing)


In [12]:
volumes_list = list(volumes.values())
fig = go.Figure(data=[go.Histogram(x=volumes_list)])
fig.update_layout(title_text='Histogram of ROI Volumes')
fig.show()

In [None]:
import numpy as np
from scipy.ndimage import distance_transform_edt, label
from skimage.measure import regionprops

def refine_rois(labeled_stack, z_spacing, xy_spacing, temp_dir, min_volume=4000, max_volume=10000, distance_cutoff=30):
    """
    Refines ROIs in a 3D labeled stack:
    - Removes small ROIs that are far from larger ROIs.
    - Merges small ROIs with the closest larger ROI if within the distance cutoff.

    Parameters:
        labeled_stack (numpy.ndarray): 3D labeled image stack.
        min_volume (int): Minimum volume (in voxels) for a "normal" ROI.
        max_volume (int): Maximum volume (in voxels) to exclude exceptionally large ROIs.
        distance_cutoff (int): Maximum distance (in pixels) for merging small ROIs with normal ROIs.

    Returns:
        numpy.ndarray: Refined 3D labeled stack.
    """
    print("Refining ROIs...")

    #Create a memmap to hold the final labeled stack
    final_labels = np.memmap(os.path.join(temp_dir, "final_labels.dat"), dtype=np.int32, mode='w+', shape=labeled_stack.shape)

    #Copy the labeled stack to the final labels
    final_labels[:] = labeled_stack[:]

    # Step 1: Calculate all the volumes
    all_volumes = compute_roi_volumes(final_labels, z_spacing=z_spacing, xy_spacing=xy_spacing)

    # Separate ROIs into categories based on volume
    small_rois = []
    normal_rois = []
    for roi_label, roi_volume in all_volumes.items():
        if roi_volume < min_volume:
            small_rois.append(roi_label)
        else:
            normal_rois.append(roi_label)
    
    to_retest = [k for k in all_volumes.keys()]
    to_remove = []

    while True:
        print(f"Identified {len(small_rois)} small ROIs and {len(normal_rois)} normal ROIs.")

        closest_roi_ids, min_surface_distances, closest_points_dict = compute_surface_min_distances_memmap(final_labels, 
                                                                                                       z_spacing, 
                                                                                                       xy_spacing, 
                                                                                                       temp_dir=temp_dir, 
                                                                                                       select_ids=normal_rois,
                                                                                                       chunk_size=500, 
                                                                                                       n_jobs=-1)
        removed_small_rois = []
        #Sort each three lists by the min_surface_distances (from smallest to largest)
        closest_roi_ids = [x for _, x in sorted(zip(min_surface_distances, closest_roi_ids))]
        min_surface_distances = sorted(min_surface_distances)
        closest_points_dict = [x for _, x in sorted(zip(min_surface_distances, closest_points_dict))]

        # Step 2: Identify normal ROIs whose shortest distance is to a small ROI
        for distances_pos, roi_pair in tqdm(enumerate(closest_roi_ids)):
            roi_1, roi_2 = roi_pair
            if roi_1 in removed_small_rois or roi_2 in removed_small_rois:
                continue
            small_roi, normal_roi = None, None
            if roi_1 in normal_rois and roi_2 in small_rois:
                small_roi = roi_2
                normal_roi = roi_1
            elif roi_2 in normal_rois and roi_1 in small_rois:
                small_roi = roi_1
                normal_roi = roi_2
            if small_roi is not None:
                if min_surface_distances[distances_pos] < distance_cutoff:
                    # Modify labels in the final labels array to merge the small ROI with the normal ROI
                    final_labels[final_labels == small_roi] = normal_roi
                    removed_small_rois.append(small_roi)
                    #Remove small ROI from all_volumes
                    all_volumes.pop(small_roi)
                    #Remove small ROI from small_rois
                    small_rois.remove(small_roi)

                    #To make sure the surface is continuous, fill any background values with normal roi valuesa along the line of shorthest distance
                    coords = closest_points_dict[distances_pos]
                    line = np.linspace(coords[0], coords[1], num=100)
                    for point in line:
                        point = point.astype(int)
                        if final_labels[point[0], point[1], point[2]] == 0:
                            final_labels[point[0], point[1], point[2]] = normal_roi
                    #Recalculate the volume of the normal ROI
                    normal_roi_volume = compute_roi_volumes(final_labels, z_spacing=z_spacing, xy_spacing=xy_spacing, labels=[normal_roi])[normal_roi]
                    all_volumes[normal_roi] = normal_roi_volume
        if len(removed_small_rois) == 0:
            break
    #Remove all the remaining small ROIs that are not close to any normal ROIs (merge with background)
    for roi in small_rois:
        if roi not in removed_small_rois:
            final_labels[final_labels == roi] = 0

    print("ROI refinement complete.")
    return final_labels


In [16]:
#Load updated_stack.npy
updated_stack = np.load('/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed/updated_stack.npy')

xy_spacing = 471.4 / updated_stack.shape[1]
refine_rois(updated_stack, 
            z_spacing=1, 
            xy_spacing=xy_spacing, 
            temp_dir='/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed', 
            min_volume=4000, 
            max_volume=10000, 
            distance_cutoff=30)


Refining ROIs...
Computing ROI volumes...
Identified 4646 small ROIs and 112 normal ROIs.


KeyboardInterrupt: 

In [34]:
import numpy as np
import os
import multiprocessing as mp
from scipy.spatial import cKDTree
from tqdm import tqdm

import numpy as np
import multiprocessing as mp

def extract_surface_points_for_label(args):
    """
    Extracts surface points for a single ROI using matrix operations.
    
    Parameters:
        args (tuple): Contains the labeled stack, label, z_spacing, and xy_spacing.
    
    Returns:
        tuple: A tuple of (label, scaled surface points).
    """
    labeled_stack, label, z_spacing, xy_spacing = args
    binary_mask = (labeled_stack == label)

    # Compute shifts for neighbors in a 3D grid
    shifts = np.array([
        [-1,  0,  0], [1,  0,  0],  # z-axis neighbors
        [ 0, -1,  0], [0,  1,  0],  # y-axis neighbors
        [ 0,  0, -1], [0,  0,  1],  # x-axis neighbors
    ])
    
    # Find surface voxels
    padded_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=False)
    neighbors = np.zeros_like(binary_mask, dtype=bool)
    
    for shift in shifts:
        shifted = padded_mask[
            1 + shift[0]:1 + shift[0] + binary_mask.shape[0],
            1 + shift[1]:1 + shift[1] + binary_mask.shape[1],
            1 + shift[2]:1 + shift[2] + binary_mask.shape[2],
        ]
        neighbors |= ~shifted  # Mark as surface if adjacent to background
    
    surface_voxels = np.argwhere(binary_mask & neighbors)
    
    # Scale voxel coordinates to physical distances
    scaled_points = surface_voxels * [z_spacing, xy_spacing, xy_spacing]
    return label, scaled_points

def compute_surface_points(labeled_stack, z_spacing, xy_spacing, n_processes=None):
    """
    Computes the 3D surface points of each labeled ROI in the stack using parallel processing.
    
    Parameters:
        labeled_stack (np.ndarray): 3D array with labeled ROIs.
        z_spacing (float): Spacing in microns between z-planes.
        xy_spacing (float): Spacing in microns for x and y dimensions.
        n_processes (int, optional): Number of processes to use. Defaults to os.cpu_count().
    
    Returns:
        dict: A dictionary where keys are ROI IDs and values are arrays of surface points.
    """
    print("Computing surface points for each ROI...")
    unique_labels = np.unique(labeled_stack)
    unique_labels = unique_labels[unique_labels != 0]  # Exclude background (0)
    
    # Prepare arguments for parallel processing
    args = [(labeled_stack, label, z_spacing, xy_spacing) for label in unique_labels]
    
    # Use multiprocessing to extract surface points for each label
    with mp.Pool(n_processes or mp.cpu_count()) as pool:
        results = tqdm(pool.map(extract_surface_points_for_label, args))
    
    # Combine results into a dictionary
    surface_points = {label: points for label, points in results}
    return surface_points


def compute_pairwise_distance(args):
    """
    Computes the shortest distance between two sets of surface points.

    Parameters:
        args (tuple): Contains surface points of ROI1, ROI2, and ROI IDs.

    Returns:
        tuple: A tuple of (roi1, roi2, {"distance": shortest_distance, "points": (point1, point2)}).
    """
    roi1, roi2, points1, points2 = args

    tree1 = cKDTree(points1)
    tree2 = cKDTree(points2)

    # Compute distances in both directions and find the minimum
    distances1, indices1 = tree1.query(points2, k=1)
    distances2, indices2 = tree2.query(points1, k=1)

    min_distance1 = np.min(distances1)
    min_distance2 = np.min(distances2)

    if min_distance1 < min_distance2:
        idx = np.argmin(distances1)
        closest_points = (points1[indices1[idx]], points2[idx])
        shortest_distance = min_distance1
    else:
        idx = np.argmin(distances2)
        closest_points = (points1[idx], points2[indices2[idx]])
        shortest_distance = min_distance2

    return roi1, roi2, {"distance": shortest_distance, "points": closest_points}

def compute_surface_min_distances(labeled_stack, z_spacing, xy_spacing, temp_dir):
    """
    Computes the shortest distances between surfaces of each pair of ROIs.

    Parameters:
        labeled_stack (np.ndarray): 3D array with labeled ROIs.
        z_spacing (float): Spacing in microns between z-planes.
        xy_spacing (float): Spacing in microns for x and y dimensions.
        temp_dir (str): Directory path for saving temporary files.

    Returns:
        np.ndarray: Matrix of shape (n_rois, n_rois) containing dictionaries with "distance" and "points".
    """
    os.makedirs(temp_dir, exist_ok=True)

    surface_points = compute_surface_points(labeled_stack, z_spacing, xy_spacing)
    unique_labels = list(surface_points.keys())
    n_rois = len(unique_labels)

    # Create memmap for results
    result_path = os.path.join(temp_dir, "distance_matrix.npy")
    distance_matrix = np.memmap(result_path, dtype=object, mode="w+", shape=(n_rois, n_rois))

    # Prepare arguments for multiprocessing
    args = []
    for i, roi1 in enumerate(unique_labels):
        for j, roi2 in enumerate(unique_labels):
            if i < j:  # Avoid redundant computations
                args.append((roi1, roi2, surface_points[roi1], surface_points[roi2]))

    # Parallel computation of pairwise distances
    with mp.Pool(mp.cpu_count()) as pool:
        results = list(tqdm(pool.imap(compute_pairwise_distance, args), total=len(args)))

    # Fill the result matrix
    for roi1, roi2, result in results:
        idx1 = unique_labels.index(roi1)
        idx2 = unique_labels.index(roi2)
        distance_matrix[idx1, idx2] = result
        distance_matrix[idx2, idx1] = result  # Symmetric matrix

    return distance_matrix


In [35]:
outpt2 = compute_surface_min_distances(updated_stack, 1, xy_spacing, '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed')

Computing surface points for each ROI...


100%|██████████| 4758/4758 [00:00<00:00, 4992.03it/s]
100%|██████████| 11316903/11316903 [27:02<00:00, 6974.62it/s] 


In [44]:
output_dir = '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed'

#Load the distance matrix memmap
distance_matrix = np.memmap(os.path.join(output_dir, "distance_matrix.npy"), dtype=object, mode='r+', shape=(updated_stack.shape[1], updated_stack.shape[1]))
distance_matrix[1,1]

#Modify distance matrix to add ROI IDs in the dictionary
for i in range(distance_matrix.shape[0]):
    for j in range(distance_matrix.shape[1]):
        if distance_matrix[i, j] is not None:
            roi_set = set([i, j])
            distance_matrix[i, j] = (roi_set, distance_matrix[i, j]["distance"], distance_matrix[i, j]["points"])

In [46]:
output_dir = '/mnt/a5e90321-8b33-423f-ad87-3e20a7c42f90/Takeshi_data/Added later-D31_no cyto_1_processed'

#Load the distance matrix memmap
distance_matrix = np.memmap(os.path.join(output_dir, "distance_matrix.npy"), dtype=object, mode='r+', shape=(updated_stack.shape[1], updated_stack.shape[1]))
distance_matrix[1,2]

({1, 2},
 90.82324025018849,
 (array([  3.        , 168.18177083, 230.78958333]),
  array([ 45.        ,  88.08059896, 222.50325521])))

In [None]:
def convert_microns_to_voxels(coords, xy_spacing, z_spacing):
    """
    Convert coordinates in microns to voxel units.
    
    Parameters:
    - coords: Tuple of two 3D coordinate arrays (point1, point2) in microns.
    - xy_spacing: Spacing in microns for x and y dimensions.
    - z_spacing: Spacing in microns between z-planes.
    
    Returns:
    - Tuple of two 3D coordinate arrays (point1, point2) in voxel units.
    """
    point1, point2 = coords
    point1_voxels = point1 / [z_spacing, xy_spacing, xy_spacing]
    point2_voxels = point2 / [z_spacing, xy_spacing, xy_spacing]
    return point1_voxels, point2_voxels


def just_closest_rois(distance_matrix):
    """
    While the distance matrix has all the pairwise distance, this function will return only the closest ROI for each ROI.
    """
    closest_rois = []
    for i in range(distance_matrix.shape[0]):
        min_distance = np.inf
        closest_roi_id = None
        for j in range(distance_matrix.shape[1]):
            if distance_matrix[i, j] is not None:
                if distance_matrix[i, j][1] < min_distance:
                    min_distance = distance_matrix[i, j][1]
                    closest_roi_id = j
        closest_rois.append(distance_matrix[i, closest_roi_id])
    return closest_rois