# In [10]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.filters import gaussian, threshold_otsu, threshold_multiotsu, sobel
from skimage.morphology import remove_small_objects, disk, binary_closing
from scipy.ndimage import zoom, binary_dilation, binary_erosion, distance_transform_edt
from skimage.measure import label, regionprops
from skimage import io, exposure, color
from skimage import measure, morphology
from skimage import exposure
from czifile import imread
import cv2
import re
from matplotlib.ticker import MaxNLocator
from cellpose import models, plot 
#model = models.Cellpose(gpu=False, model_type='cyto3')

# In [11]:
# Size limits (in pixels) for candidate inclusions.
# MIN_INCLUSION_SIZE is passed as `min_size` to remove_small_objects in extract_inclusions;
# MAX_INCLUSION_SIZE is not used by any active code in this file.
MIN_INCLUSION_SIZE, MAX_INCLUSION_SIZE = 5, 2000

# In [12]:
def display_image(image, path, type):
    """Show *image* with matplotlib, axes hidden, titled "<path> <type>".

    ``type`` is only a caption fragment for the title. It shadows the builtin
    name, but the parameter name is kept so keyword callers keep working.
    """
    caption = f"{path} {type}"
    plt.imshow(image)
    plt.axis('off')
    plt.title(caption)
    plt.show()

def extract_image_paths(folder):
    """Return the paths of all regular files directly inside *folder*, sorted by name.

    Subdirectories are skipped. Nothing is filtered by extension, so every
    regular file present (including hidden ones) is returned. The names are
    sorted because ``os.listdir`` yields entries in arbitrary order, which would
    make the processing order differ between machines and runs.
    """
    candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
    return [path for path in candidates if os.path.isfile(path)]

def read_image(image_path):
    """Load the image stored at *image_path* using ``czifile.imread`` and return the result.

    NOTE(review): ``czifile`` is a reader for Zeiss CZI files; an earlier
    docstring mentioned LSM. Confirm the input format matches.
    """
    pixels = imread(image_path)
    return pixels

def count(mask):
    """Return the number of connected foreground regions in *mask*.

    Regions are found with ``skimage.measure.label`` (background value 0). The
    count comes from ``return_num=True``. That stays correct when the mask has
    no background pixel at all: an all-True mask is one region, not zero. The
    previous ``len(np.unique(...)) - 1`` assumed label 0 was always present.
    """
    _, num_regions = label(mask, return_num=True)
    return num_regions

def extract_channels(image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(image[0], image[1])``, the first two entries along axis 0 of the squeezed image.

    NOTE(review): an earlier docstring described the input as [Z, C, H, W].
    With that layout, axis 0 would be Z, not channels. Confirm that
    ``np.squeeze`` leaves the channel axis first. Also confirm the channel
    order: ``main`` unpacks the result as ``red, green``.
    """
    first_plane = image[0]
    second_plane = image[1]
    return first_plane, second_plane

def preprocess_green_channel(green_channel, sigma=2, cutoff=0.25):
    """
    Preprocess the green fluorescence channel for segmentation and inclusion detection.

    - Applies a Gaussian blur (``sigma``) to reduce noise.
    - Enhances contrast with ``exposure.adjust_sigmoid`` (``cutoff``), applied to
      the *blurred* image.
    - Rescales the result with ``normalize_image`` to [0, 1].

    Parameters:
        green_channel: raw green-channel image.
        sigma: Gaussian standard deviation (default 2, unchanged).
        cutoff: sigmoid cutoff passed to ``adjust_sigmoid`` (default 0.25, unchanged).

    Returns:
        The preprocessed image, as returned by ``normalize_image``.
    """
    blurred = gaussian(green_channel, sigma=sigma)
    # The sigmoid must consume the blurred image; applying it to the raw channel
    # (as before) silently discarded the smoothing computed on the line above.
    contrasted = exposure.adjust_sigmoid(blurred, cutoff=cutoff)
    return normalize_image(contrasted)

def normalize_image(image):
    """
    Linearly rescale *image* so its minimum maps to 0 and its maximum to 1.

    - Non-floating inputs (integers, bools) are converted to float64 first, so
      the subtraction cannot overflow (e.g. int8 data spanning -128..127).
    - Floating inputs keep their dtype.
    - A constant image (max == min) returns all zeros instead of 0/0 = NaN.

    Raises:
        ValueError: if *image* is empty (there is no minimum/maximum).
    """
    values = np.asarray(image)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)
    low = values.min()
    span = values.max() - low
    if span == 0:
        # Degenerate range: every pixel is equal, so map everything to 0.
        return np.zeros_like(values)
    return (values - low) / span

def calculate_surface_area(labeled_image: np.ndarray) -> float:
    """Return the summed ``area`` of every region that ``regionprops`` finds in *labeled_image*.

    The value is a pixel count (voxels for 3-D input) over all labeled regions,
    not a geometric surface area. An image with no labeled regions gives 0.
    """
    total_area = 0
    for region in regionprops(labeled_image):
        total_area += region.area
    return total_area

def high_circularity_mask(labeled_image: np.ndarray, threshold: float = 0.7) -> np.ndarray:
    """
    Return a binary mask with only the regions whose circularity exceeds the threshold.

    Circularity is ``4 * pi * area / perimeter**2``, using ``regionprops``
    ``area`` and ``perimeter``. Regions with a non-positive perimeter are
    skipped. NOTE(review): the perimeter is a discrete estimate, so very small
    objects may score above 1.

    Parameters:
        labeled_image (np.ndarray): Labeled image where each object has a unique label.
        threshold (float): Circularity must be strictly greater than this to keep the object.

    Returns:
        np.ndarray: uint8 mask, same shape as ``labeled_image``; 1 on kept objects, 0 elsewhere.
    """
    kept_labels = []
    for prop in regionprops(labeled_image):
        perimeter = prop.perimeter
        if perimeter > 0:
            circularity = (4 * np.pi * prop.area) / (perimeter ** 2)
            if circularity > threshold:
                kept_labels.append(prop.label)

    # One vectorized pass over the image, instead of a full-image comparison per kept region.
    return np.isin(labeled_image, kept_labels).astype(np.uint8)


def _default_cellpose_model():
    """Return the default Cellpose model (cyto3, CPU), building it on first use and reusing it afterwards."""
    cached = getattr(_default_cellpose_model, "cached_model", None)
    if cached is None:
        # Same configuration as the commented-out module-level `model = ...` line.
        cached = models.Cellpose(gpu=False, model_type='cyto3')
        _default_cellpose_model.cached_model = cached
    return cached


def segment_cells(green_channel, model=None, diameters=range(150, 500, 25)):
    """
    Segment whole cells in the green channel using Cellpose.

    - Normalizes image intensity.
    - Clips pixels above the 99th percentile, so bright spots (e.g. inclusions)
      do not dominate the cell boundaries.
    - Applies a Gaussian blur (sigma=5) and re-normalizes.
    - Tries each diameter in ``diameters`` in order. The first one whose masks
      contain at least one cell wins.

    Parameters:
        green_channel: green-channel image.
        model: object with a Cellpose-style ``eval`` whose first return value is
            the mask image. When None, a cached cyto3 ``models.Cellpose`` is
            used. Previously this relied on an undefined global ``model``.
        diameters: iterable of diameters to try (default 150, 175, ..., 475 as before).

    Returns:
        The ``label``-ed mask image for the first successful diameter, or None
        if no diameter detected any cell.
    """
    if model is None:
        model = _default_cellpose_model()

    green_channel = normalize_image(green_channel)
    ceiling = np.percentile(green_channel, 99)

    # Suppress very bright pixels (inclusions) by clamping them to the 99th percentile.
    prepared = np.where(green_channel < ceiling, green_channel, ceiling)
    prepared = gaussian(prepared, sigma=5)
    prepared = normalize_image(prepared)

    for diameter in diameters:
        # Only the masks are needed; indexing [0] works whether eval returns 3 or 4 values.
        masks = model.eval(prepared, diameter=diameter, channels=[0, 0])[0]
        labeled_cells = label(masks)
        if np.max(labeled_cells) > 0:
            return labeled_cells

    # No diameter produced a detection.
    return None

def extract_inclusions(green_channel, mask, display_graph=False):
    """
    Extract potential inclusions (bright spots) inside one cell.

    - Blurs the channel (sigma=1), applies the cell mask, and keeps the positive
      pixels as the cell's signal.
    - Normalizes those in-cell values and uses their upper quartile (and the
      Freedman-Diaconis bin count) to choose a threshold strategy.
    - Thresholds the globally normalized, unblurred channel inside the mask.
    - Removes connected objects smaller than ``MIN_INCLUSION_SIZE``.
    - Optionally prints the threshold and plots the in-cell histogram.

    Parameters:
        green_channel: intensity image.
        mask: cell mask broadcastable against ``green_channel``; nonzero marks
            the cell (presumably boolean, as built in generate_inclusion_image).
        display_graph: if True, print diagnostics and show a histogram.

    Returns:
        Boolean array of inclusion pixels. All False when the masked region has
        no positive blurred intensity, which previously crashed on an empty array.
    """
    blurred_in_mask = gaussian(green_channel, sigma=1) * mask
    # Pixels outside the mask are 0 after multiplication; keep only the positive in-cell signal.
    in_cell = blurred_in_mask[blurred_in_mask > 0]
    if in_cell.size == 0:
        # Nothing to analyse: normalize_image / np.percentile would raise on an empty array.
        return np.zeros_like(blurred_in_mask, dtype=bool)

    in_cell = normalize_image(in_cell)

    # Descriptive statistics of the in-cell intensity distribution.
    q3 = np.percentile(in_cell, 75)
    # Only the number of bins is used, as a rough indicator of how spread out the values are.
    _, bin_edges = np.histogram(in_cell, bins='fd')
    applied_mask = normalize_image(green_channel) * mask

    # Decide on thresholding strategy based on upper quartile.
    if q3 < 0.4 and len(bin_edges) > 20:
        # NOTE(review): Otsu is computed over the whole image, including every zero
        # outside the cell; the 0.5 floor may be what actually decides the threshold.
        threshold = max(threshold_otsu(applied_mask), 0.5)
    elif q3 >= 0.9:
        threshold = 1  # normalized values (with a 0/1 mask) never exceed 1, so nothing passes
    else:
        threshold = 0.999  # conservatively high

    # Apply threshold and size-based filter.
    inclusions = applied_mask > threshold
    inclusions = remove_small_objects(inclusions, min_size=MIN_INCLUSION_SIZE)

    if display_graph:
        print("Threshold: ", threshold)
        print("Bin count", len(bin_edges))
        plt.hist(in_cell, bins='fd')
        plt.axvline(q3, color='purple', linestyle='dashed', linewidth=2, label=f'Q3: {q3:.2f}')
        plt.legend()
        plt.title("Intensity histogram")
        plt.show()

    return inclusions

def generate_inclusion_image(green_channel, labeled_cells):
    """
    Combine per-cell inclusion masks into a single image.

    For every region of ``labeled_cells`` whose area is at least 100 pixels,
    ``extract_inclusions`` runs on that cell's mask. The result is added into an
    accumulator created with ``np.zeros_like(green_channel)``, so the output has
    green_channel's shape and dtype. Smaller regions are treated as noise and ignored.
    """
    combined = np.zeros_like(green_channel)
    large_cells = (cell for cell in regionprops(labeled_cells) if cell.area >= 100)
    for cell in large_cells:
        cell_mask = labeled_cells == cell.label
        combined += extract_inclusions(green_channel, cell_mask)
    return combined

# In [13]:
def analysis(red: np.ndarray, green: np.ndarray, path: str) -> np.ndarray:
    """Preprocess and display the green channel of one image, then return it.

    This is a work in progress. The steps currently run are
    ``preprocess_green_channel`` followed by ``display_image``. ``red`` is
    accepted but not used yet.

    Returns:
        The preprocessed green channel. The previous ``pd.DataFrame`` annotation
        did not match what the function actually returns.
    """
    print("Starting analysis...")

    green = preprocess_green_channel(green)
    display_image(green, path, "Green Channel")

    # TODO: remaining pipeline steps, previously kept here as commented-out code:
    #   labeled_cells = segment_cells(green)
    #   inclusion_image = generate_inclusion_image(green, labeled_cells)
    #   then per-cell analysis over regionprops(labeled_cells), skipping areas < 100.
    return green
    





# In [14]:
def main(image_folder):
    """Save the green channel of every file in *image_folder* as a PNG.

    Each file is read with ``read_image``, squeezed, and split with
    ``extract_channels``. The second returned array is written unchanged to
    ``./green_channel_images/<basename>_green.png``, relative to the current
    working directory.
    """
    images_to_analyze = extract_image_paths(image_folder)

    green_output_dir = os.path.join(os.getcwd(), "green_channel_images")
    # Create the output folder once, not on every loop iteration.
    os.makedirs(green_output_dir, exist_ok=True)

    for path in images_to_analyze:
        image_squeezed = np.squeeze(read_image(path))
        _red, green = extract_channels(image_squeezed)

        base_name = os.path.splitext(os.path.basename(path))[0]
        output_path = os.path.join(green_output_dir, f"{base_name}_green.png")

        # NOTE(review): the channel is saved as-is, with no normalization or dtype
        # conversion. Float or low-contrast 16-bit data may need rescaling for a
        # viewable PNG. Scaling integer data by 255 before clipping would overflow.
        io.imsave(output_path, green)

# Script entry point. Note that inside a Jupyter kernel __name__ is also "__main__",
# so executing this cell processes the folder as well.
if __name__ == "__main__":
    # Folder of input images; a relative path, resolved against the current working directory.
    image_folder = '52925_images'
    main(image_folder)

    