## Import packages

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import czifile
from czifile import CziFile
from skimage import exposure, filters
from scipy.stats import pearsonr
import os
import glob
import matplotlib.colors

## Set up working directory

In [None]:
"""
Define the path to the working directory (input .czi files) and the output folder
where the colocalisation result figures (PDF) will be saved.
Collect the names of the files to be processed.
"""
data_path = 'path_to_working_directory'
output_path = 'path_to_output_directory'
# os.path.join works whether or not data_path ends with a path separator; only .czi
# files are collected because load_czi_channels can only read CZI images.
file_name = sorted(glob.glob(os.path.join(data_path, '*.czi')))  # List of file names to process

In [None]:
"""
Ensure the output directory exists and that files are read correctly.
"""
os.makedirs(output_path, exist_ok=True)  # Create the output folder if missing (no-op if it already exists)
total_files = len(file_name)  # Number of input files found in the working directory
print(f"Working folder contains {total_files} files")

## <a id='toc1_3_'></a>[Define functions for image processing](#toc0_)

In [None]:
class Colocalisation:
    """
    Colocalisation analysis for two-channel confocal images (.czi files).

    Provides methods to load the first two channels of a CZI file, enhance
    contrast with CLAHE, build an Otsu-based colocalisation mask, compute
    colocalisation metrics (Pearson, Manders M1/M2, ICQ) and render a summary
    figure saved as PDF.
    """

    def __init__(self, clahe_clip_limit=2.0, clahe_tile_grid_size=(4, 4)):
        """
        Build the CLAHE operator used for contrast enhancement.

        Parameters
        ----------
        clahe_clip_limit : float
            Contrast limit passed to ``cv2.createCLAHE``.
        clahe_tile_grid_size : tuple of int
            Tile grid size passed to ``cv2.createCLAHE``.
        """
        self.clahe = cv2.createCLAHE(
            clipLimit=clahe_clip_limit,
            tileGridSize=clahe_tile_grid_size,
        )

    # ------------------
    # Load CZI channels
    # ------------------
    def load_czi_channels(self, czi_path):
        """
        Load a confocal image (.czi file) and return its first two channels.

        All size-1 axes are squeezed away and the first remaining axis is taken
        as the channel axis (assumption inherited from the original pipeline —
        images are expected to be saved as 2D maximum intensity projections).
        If a channel still has more than two dimensions (e.g. an un-projected
        Z-stack), it is reduced to a 2D maximum intensity projection over all
        but its last two axes.

        Returns
        -------
        (channel1, channel2) : tuple of numpy.ndarray
            Two 2D arrays.

        Raises
        ------
        ValueError
            If the squeezed data has fewer than 3 dimensions.
        """
        with czifile.CziFile(czi_path) as czi:
            data = czi.asarray()  # Read the full image data from the .czi file
        data_squeezed = np.squeeze(data)  # Drop size-1 (metadata) dimensions
        if data_squeezed.ndim < 3:  # Need at least channels x Y x X
            raise ValueError("Unexpected CZI format: expected at least 3 dimensions")
        channel1 = self._project_to_2d(data_squeezed[0])  # First channel
        channel2 = self._project_to_2d(data_squeezed[1])  # Second channel
        return channel1, channel2

    @staticmethod
    def _project_to_2d(channel):
        """Max-project all axes except the last two; 2D input is returned unchanged."""
        if channel.ndim > 2:
            channel = channel.max(axis=tuple(range(channel.ndim - 2)))
        return channel

    # ---------------------------------
    # Contrast enhancement with CLAHE
    # ---------------------------------
    def enhance_contrast(self, channel):
        """
        Min-max normalise a 2D channel to 8-bit and apply CLAHE.

        Returns
        -------
        numpy.ndarray
            uint8 image of the same height/width as ``channel``.
        """
        # Rescale to [0, 255] and store as uint8, the input type CLAHE is applied to here.
        mip_image = cv2.normalize(channel, None, 0, 255,
                                  cv2.NORM_MINMAX).astype(np.uint8)
        # The image is already single-channel, so it is passed to CLAHE directly
        # (a GRAY->RGB->GRAY conversion would return the same pixels).
        return self.clahe.apply(mip_image)

    # ---------------------
    # Colocalization mask
    # ---------------------
    def calculate_colocalization_mask(self, ch1, ch2):
        """
        Contrast-enhance both channels and build a colocalisation mask.

        Each enhanced channel is binarised with Otsu thresholding; the
        colocalisation mask is the pixel-wise AND of the two binary masks.

        Returns
        -------
        (coloc_mask, ch1_enh, ch2_enh)
            Boolean mask and the two CLAHE-enhanced uint8 channels.
        """
        ch1_enh = self.enhance_contrast(ch1)
        ch2_enh = self.enhance_contrast(ch2)
        _, mask1 = cv2.threshold(ch1_enh, 0, 255,
                                 cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        _, mask2 = cv2.threshold(ch2_enh, 0, 255,
                                 cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        coloc_mask = mask1.astype(bool) & mask2.astype(bool)
        return coloc_mask, ch1_enh, ch2_enh

    # ---------------------
    # Pearson correlation
    # ---------------------
    def calculate_pearson_correlation(self, ch1, ch2, mask=None):
        """
        Pearson correlation between two channels.

        Only pixels inside ``mask`` (if given) where at least one channel is
        non-zero are used.

        Returns
        -------
        (pearson_coef, p_value)
            ``(nan, nan)`` when fewer than two pixels remain or when either
            channel is constant over those pixels (correlation undefined).
        """
        if mask is not None:
            ch1_sel = ch1[mask]
            ch2_sel = ch2[mask]
        else:
            ch1_sel = ch1.ravel()
            ch2_sel = ch2.ravel()
        non_zero = (ch1_sel > 0) | (ch2_sel > 0)  # Ignore background pixels that are zero in both channels
        ch1_sel = np.asarray(ch1_sel[non_zero], dtype=np.float64)
        ch2_sel = np.asarray(ch2_sel[non_zero], dtype=np.float64)
        # pearsonr raises for fewer than 2 samples and is undefined for constant input.
        if ch1_sel.size < 2 or np.ptp(ch1_sel) == 0 or np.ptp(ch2_sel) == 0:
            return np.nan, np.nan
        pearson_coef, p_value = pearsonr(ch1_sel, ch2_sel)
        return pearson_coef, p_value

    # ----------------------
    # Manders coefficients
    # ----------------------
    def calculate_manders_coefficients(self, ch1, ch2, coloc_mask):
        """
        Manders coefficients M1 and M2.

        M1 is the fraction of channel-1 intensity located inside
        ``coloc_mask``; M2 the same for channel 2. Both lie in [0, 1] for
        non-negative images and are NaN when the channel's total intensity is 0.
        """
        # Sum in float64 so large images cannot overflow an integer accumulator.
        total1 = np.sum(ch1, dtype=np.float64)
        total2 = np.sum(ch2, dtype=np.float64)
        m1 = np.sum(ch1[coloc_mask], dtype=np.float64) / total1 if total1 > 0 else np.nan
        m2 = np.sum(ch2[coloc_mask], dtype=np.float64) / total2 if total2 > 0 else np.nan
        return m1, m2

    # ------------------------------------------------------------
    # ICQ (Intensity Correlation Quotient)
    # ------------------------------------------------------------
    def calculate_icq(self, ch1, ch2):
        """
        Intensity Correlation Quotient.

        ICQ = (number of pixels with positive (A - mean A)(B - mean B)) / N - 0.5,
        ranging from -0.5 to +0.5. Returns NaN for empty input.
        """
        a = np.asarray(ch1, dtype=np.float64)
        b = np.asarray(ch2, dtype=np.float64)
        if a.size == 0:
            return np.nan
        prod = (a - a.mean()) * (b - b.mean())  # Product of deviations from each channel's mean
        return np.count_nonzero(prod > 0) / prod.size - 0.5

    # -------------------
    # Figure generation
    # -------------------
    def create_analysis_figure(self, ch1, ch2, coloc_mask, output_path_pdf):
        """
        Plot the channels, merge, colocalisation overlay, intensity scatter and
        metrics, save the figure as PDF and display it.

        ``ch1`` and ``ch2`` are divided by 255 for display, so they are assumed
        to be 8-bit images (e.g. the output of ``enhance_contrast``).
        """
        merge = np.zeros((ch1.shape[0], ch1.shape[1], 3), dtype=np.float32)
        merge[:, :, 1] = ch1 / 255  # Channel 1 -> green
        merge[:, :, 0] = ch2 / 255  # Channel 2 -> red
        merge[:, :, 2] = ch2 / 255  # Channel 2 -> blue (magenta together with red)
        coloc_overlay = merge.copy()
        coloc_overlay[coloc_mask] = [1, 1, 0]  # Yellow overlay on colocalised pixels

        pearson_coef, _ = self.calculate_pearson_correlation(ch1, ch2)
        manders_m1, manders_m2 = self.calculate_manders_coefficients(ch1, ch2, coloc_mask)
        icq = self.calculate_icq(ch1, ch2)
        total_signal = np.sum((ch1 > 0) | (ch2 > 0))  # Pixels with signal in either channel
        coloc_percentage = (np.sum(coloc_mask) / total_signal * 100) if total_signal > 0 else 0

        # Figure layout: 2 rows x 3 columns
        fig = plt.figure(figsize=(16, 10))
        gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

        panels = [
            ('Channel 1', ch1, gs[0, 0], 'gray'),
            ('Channel 2', ch2, gs[0, 1], 'gray'),
            ('Merge', merge, gs[0, 2], None),
            ('Colocalization', coloc_overlay, gs[1, 0], None),
        ]
        for title, img, grid, cmap in panels:
            ax = fig.add_subplot(grid)
            ax.imshow(img, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

        # Pixel-intensity scatter on a random subsample to keep plotting fast
        ax_corr = fig.add_subplot(gs[1, 1])
        n_samples = min(10000, ch1.size)
        indices = np.random.choice(ch1.size, n_samples, replace=False)
        ch1_flat = ch1.ravel()[indices] / 255
        ch2_flat = ch2.ravel()[indices] / 255
        coloc_flat = coloc_mask.ravel()[indices]
        ax_corr.scatter(ch1_flat[~coloc_flat], ch2_flat[~coloc_flat],
                        c='gray', s=1, alpha=0.3, label='Non-coloc')
        ax_corr.scatter(ch1_flat[coloc_flat], ch2_flat[coloc_flat],
                        c='red', s=2, alpha=0.6, label='Coloc')
        ax_corr.plot([0, 1], [0, 1], 'k--', alpha=0.5)  # Identity line
        ax_corr.set_xlabel('Channel 1 Intensity')
        ax_corr.set_ylabel('Channel 2 Intensity')
        ax_corr.set_title('Pixel Intensity Correlation')
        ax_corr.legend(fontsize=8)

        # Text panel with the computed statistics
        ax_stats = fig.add_subplot(gs[1, 2])
        ax_stats.axis('off')
        stats_text = (f"Pearson R: {pearson_coef:.2f}\n"
                      f"Manders M1: {manders_m1:.2f}\n"
                      f"Manders M2: {manders_m2:.2f}\n"
                      f"ICQ: {icq:.2f}\n"
                      f"Coloc %: {coloc_percentage:.1f}%")
        ax_stats.text(0.05, 0.5, stats_text, fontsize=12, verticalalignment='center')

        # Save before showing: some backends close or clear the figure on show(),
        # which would otherwise produce a blank PDF.
        fig.savefig(output_path_pdf, dpi=300, format='pdf')
        plt.show()
        plt.close(fig)


In [None]:
coloc = Colocalisation()
os.makedirs(output_path, exist_ok=True)  # Ensure the PDF destination exists before saving

for file in file_name:
    try:
        ch1, ch2 = coloc.load_czi_channels(file)  # Load each confocal image (.czi file)
    except ValueError as err:  # Unexpected CZI layout: report it and continue with the next file
        print(f"Skipping {file}: {err}")
        continue
    coloc_mask, ch1_enh, ch2_enh = coloc.calculate_colocalization_mask(ch1, ch2)  # Calculate colocalisation mask
    # Name each output after its input file so results do not overwrite one another.
    sampleID = os.path.splitext(os.path.basename(file))[0]
    output_title = os.path.join(output_path, f"{sampleID}_colocalisation_result.pdf")  # Figure save path
    coloc.create_analysis_figure(ch1_enh, ch2_enh, coloc_mask, output_title)  # Generate and save figure as PDF