# Initialize

In [3]:
import os
import glob
import pickle
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cv2
from pathlib import Path
import re 
# Scikit image
from skimage import color, segmentation, filters, measure, morphology
from skimage.io import imread, imsave
from skimage.segmentation import slic
from skimage.measure import regionprops, regionprops_table, find_contours
from skimage.transform import resize

from scipy import ndimage as ndi
from scipy.ndimage import center_of_mass, gaussian_filter, generic_filter,binary_closing
from scipy.stats import mode
from scipy.ndimage import find_objects, label, center_of_mass


# Cellpose
from cellpose import core, utils, io, models, metrics, plot
from cellpose.io import logger_setup

# Personal 
import importlib
import preprocessing as pp
import lesion_mask as lm

# Image loading modules
from aicsimageio import AICSImage
import stackview

# pyclesperanto
import pyclesperanto_prototype as cle
from pyclesperanto_prototype import imshow

In [2]:
# Render every matplotlib figure in this notebook at 300 dpi
mpl.rcParams.update({'figure.dpi': 300})

# Ask Cellpose whether a usable GPU is available (printed as 0/1)
use_GPU = core.use_gpu()
print('>>> GPU activated? %d' % (use_GPU,))

# Enable Cellpose's log output so model runs are recorded
logger_setup();

>>> GPU activated? 0
creating new log file
2025-01-24 19:35:17,952 [INFO] WRITING LOG OUTPUT TO /home/ucloud/.cellpose/run.log
2025-01-24 19:35:17,954 [INFO] 
cellpose version: 	3.0.11 
platform:       	linux 
python version: 	3.12.3 
torch version:  	2.5.0+cu124


## Helper Functions

In [3]:
# Function to apply tissue mask
def apply_tissue_mask(image, tissue_mask):
    """
    Apply a tissue mask to the input image to isolate the tissue regions.

    Pixels outside the tissue (mask False/0) are set to 0. Tissue pixels keep
    their original values, and the image dtype is preserved.

    Args:
        image (ndarray): Channels-last image ``(..., H, W, C)``, or a
            single-channel image with the same number of dimensions as
            ``tissue_mask``.
        tissue_mask (ndarray): Mask representing tissue areas. Any nonzero
            value counts as tissue.

    Returns:
        ndarray: The masked image, same shape as ``image``.
    """
    # Coerce to bool so a 0/255 (or other non-0/1) mask selects pixels instead
    # of scaling them, which could also overflow integer image dtypes.
    mask = np.asarray(tissue_mask).astype(bool)
    if np.ndim(image) == mask.ndim:
        # Single-channel image: the mask already lines up with the image.
        return image * mask
    # Multi-channel image: broadcast the mask over the trailing channel axis.
    return image * mask[..., None]

def get_cellpose_centroids_with_table(masks):
    """
    Return the (row, col) centroid of every labelled object in a Cellpose mask.

    Centroids are read from ``regionprops_table`` (``centroid-0`` is the row /
    y coordinate, ``centroid-1`` the column / x coordinate).

    Args:
        masks (ndarray): Labeled mask from Cellpose.

    Returns:
        list: One ``(y, x)`` tuple per labelled region, in the order produced by
        ``regionprops_table``.
    """
    table = regionprops_table(masks, properties=["centroid"])
    return [(row, col) for row, col in zip(table["centroid-0"], table["centroid-1"])]

# Workflow - Segment lesions and quantify cells for each image

In [None]:
# # Reload modules (useful during development)
# #importlib.reload(lm)
# #importlib.reload(pp)

# # Initialize Cellpose model
# model = models.Cellpose(gpu=True, model_type='cyto3')

# # Define the base path for image files
# base_path = r"/work/imaging_data_f2805_HEV/2023_07_14_F2805_w.*"

# # Use glob to find all matching files
# image_paths = glob.glob(base_path)

# # Extract week numbers from file paths
# def extract_week_number(path):
#     match = re.search(r'w\.(\d+)', path)
#     return int(match.group(1)) if match else None

# # Sort paths by week number
# image_paths = sorted(image_paths, key=extract_week_number)

# # Group paths by week
# week_groups = {}
# for path in image_paths:
#     week = extract_week_number(path)
#     if week is not None:
#         week_groups.setdefault(week, []).append(path)

# # Channel names for easy reference
# channel_names = ['T_cell', 'B_cell', 'HEV', 'DAPI']

# # Set output folder
# batch_size = 4
# output_folder = "full_batch_results"
# os.makedirs(output_folder, exist_ok=True)

# # Process images grouped by week
# for week, week_paths in week_groups.items():
#     print(f"Processing week {week} with {len(week_paths)} images.")

#     # Batch processing for the current week
#     for batch_num, i in enumerate(range(0, len(week_paths), batch_size)):
#         batch_paths = week_paths[i:i + batch_size]
#         results_dict = {}  # Initialize a dictionary to store batch results

#         print(f"Processing batch {batch_num + 1} for week {week} with {len(batch_paths)} images.")

#         for path in batch_paths:
#             print(f"Processing image: {path}")

#             # Load and preprocess the image
#             scenes = pp.load_images(path)
#             filename = os.path.splitext(os.path.basename(path))[0]

#             # Process each scene except the last (assumed metadata or unused)
#             for scene in range(len(scenes) - 1):
#                 scene_start = time.time()

#                 outfile = f"{filename}_Scene_{scene}"
#                 image = scenes[scene]

#                 # Compute lesion and tissue masks
#                 lesion_start = time.time()
#                 lesion_mask, tissue_mask = lm.get_lesion_mask(
#                     image,
#                     kernel_size=20,
#                     sigma=20,
#                     std_thresh=0.2,
#                     scale_percent=10,
#                     re_gauss_sigma=20,
#                     re_gauss_threshold=0.5
#                 )
#                 lesion_end = time.time() - lesion_start
#                 print("Total lesion time:", lesion_end)

#                 # Apply tissue mask to the image
#                 tissue_image = apply_tissue_mask(image, tissue_mask)

#                 # Initialize centroid lists for T and B cells
#                 T_centroids = []
#                 B_centroids = []

#                 # Run Cellpose on T and B cell channels
#                 cell_start = time.time()
#                 for ch, centroids_list in zip(range(2), [T_centroids, B_centroids]):
#                     cell_channel_start = time.time()

#                     # Threshold and process the specific channel
#                     image_thresholded = np.where(tissue_image[:, :, ch] < 500, 0, tissue_image[:, :, ch])
#                     masks, flows, styles, diams = model.eval(
#                         image_thresholded, diameter=30, flow_threshold=None, channels=[0, 0])
#                     cell_channel_end = time.time() - cell_channel_start
#                     print(f"Cellpose time for channel {channel_names[ch]}: {cell_channel_end:.2f} seconds")

#                     # Compute centroids and store them
#                     centroids = get_cellpose_centroids_with_table(masks)
#                     centroids_list.extend(centroids)

#                 cell_end = time.time() - cell_start
#                 print("Total cellpose time:", cell_end)

#                 # Store the results in the dictionary
#                 results_dict[outfile] = {
#                     'DAPI_image': image[:, :, 3],
#                     'lesion_mask': lesion_mask,
#                     'tissue_mask': tissue_mask,
#                     'T_centroids': np.array(T_centroids),
#                     'B_centroids': np.array(B_centroids),
#                 }

#                 scene_end = time.time() - scene_start
#                 print(f"Scene processing time: {scene_end:.2f} seconds")

#         # Save batch results to a pickle file
#         batch_file = Path(output_folder) / f"lesion_cell_results_week_{week}_batch_{batch_num + 1}.pkl"
#         with open(batch_file, 'wb') as f:
#             pickle.dump(results_dict, f, protocol=4)

#         print(f"Week {week}, Batch {batch_num + 1} results saved to {batch_file}")

#         # Discard batch results to free memory
#         del results_dict

        #112 minutes


## Helper functions for plotting the images

In [4]:
# Function to mark cells within the lesion area
def mark_within_lesion(props_df, lesion_mask, cell_type):
    """
    Flag which cell centroids fall inside the lesion mask.

    The input frame is not modified; an annotated copy is returned.

    Args:
        props_df (pandas.DataFrame): Centroids with columns 'centroid-0' (row)
            and 'centroid-1' (col). May be empty.
        lesion_mask (numpy.ndarray): 2-D lesion mask; any nonzero value counts
            as lesion.
        cell_type (str): Label written to the 'cell_type' column.

    Returns:
        pandas.DataFrame: Copy of ``props_df`` with the following changes:
            - the centroid columns are truncated to int (pixel indices);
            - a boolean 'within_lesion' column is added, which is False for
              centroids outside the mask bounds;
            - a 'cell_type' column is added.
    """
    marked = props_df.copy()  # never mutate the caller's frame

    if marked.empty:
        marked['within_lesion'] = np.zeros(0, dtype=bool)
    else:
        # Truncate centroids to integer pixel indices (same as before).
        marked['centroid-0'] = marked['centroid-0'].astype(int)
        marked['centroid-1'] = marked['centroid-1'].astype(int)
        rows = marked['centroid-0'].to_numpy()
        cols = marked['centroid-1'].to_numpy()

        # Bool mask so 'within_lesion' is a true flag: a 0/255 mask would
        # otherwise make .sum() count each inside cell 255 times.
        mask = np.asarray(lesion_mask).astype(bool)
        in_bounds = (
            (rows >= 0) & (rows < mask.shape[0]) &
            (cols >= 0) & (cols < mask.shape[1])
        )
        within = np.zeros(len(marked), dtype=bool)
        within[in_bounds] = mask[rows[in_bounds], cols[in_bounds]]
        marked['within_lesion'] = within

    marked['cell_type'] = cell_type
    return marked


def analyze_lesion_area(lesion_mask, t_props_df, b_props_df):
    """
    Annotate T- and B-cell centroids with lesion membership and count them.

    Args:
        lesion_mask (numpy.ndarray): Lesion mask passed to ``mark_within_lesion``.
        t_props_df (pandas.DataFrame): T-cell centroids ('centroid-0', 'centroid-1').
        b_props_df (pandas.DataFrame): B-cell centroids ('centroid-0', 'centroid-1').

    Returns:
        tuple: A pair ``(combined_df, counts)``:
            - ``combined_df`` holds the T-cell rows followed by the B-cell
              rows, with a fresh index;
            - ``counts`` maps 'T_cell'/'B_cell' to
              ``{'total': n_rows, 'within_lesion': n_inside}``.
    """
    # T cells are marked before B cells, and that order is kept in the output.
    marked_by_type = {
        'T_cell': mark_within_lesion(t_props_df, lesion_mask, 'T_cell'),
        'B_cell': mark_within_lesion(b_props_df, lesion_mask, 'B_cell'),
    }

    combined_df = pd.concat(list(marked_by_type.values()), ignore_index=True)

    counts = {
        kind: {'total': len(frame), 'within_lesion': frame['within_lesion'].sum()}
        for kind, frame in marked_by_type.items()
    }
    return combined_df, counts

def plot_lesion_with_centroids(image_name, dapi_image, lesion_mask, centroids_df, counts, output_folder="processed_images", centroid_colors=('#2CA02C', '#E377C2'), contour_color='blue'):
    """
    Plot the lesion outline and T/B cell centroids over a DAPI image and save it.

    The figure is written to ``<output_folder>/<image_name>_plot.png``. It is
    always closed afterwards, even on error, so looping over many images does
    not accumulate open figures.

    Args:
        image_name (str): Sample/scene name, shown as the header and used in the file name.
        dapi_image (numpy.ndarray): DAPI image for background visualization
            (displayed with a fixed 500-2500 intensity window).
        lesion_mask (numpy.ndarray): Binary lesion mask; its 0.5 iso-contour is drawn.
        centroids_df (pandas.DataFrame): Combined centroids with columns
            'centroid-0' (row), 'centroid-1' (col), 'cell_type' and 'within_lesion'.
        counts (dict): Total and within-lesion counts for 'T_cell' and 'B_cell'
            (each a dict with 'total' and 'within_lesion').
        output_folder (str): Folder where the plot image will be saved (created if missing).
        centroid_colors (tuple): Colors for T cells and B cells (default: ('#2CA02C', '#E377C2')).
        contour_color (str): Color / matplotlib format string for lesion contours (default: 'blue').
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    try:
        ax.imshow(dapi_image, cmap='gray', vmin=500, vmax=2500)

        # Sample name header above the axes (the axes title holds the counts)
        ax.text(0.5, 1.05, f"{image_name}", ha='center', va='bottom', fontsize=16,
                color='black', transform=ax.transAxes)

        # One plot call per cell type. The previous version made one call per
        # centroid, creating thousands of Line2D artists, which was very slow.
        # Cells inside and outside the lesion currently share the same style.
        for cell_type, color in zip(['T_cell', 'B_cell'], centroid_colors):
            subset = centroids_df[centroids_df['cell_type'] == cell_type]
            if subset.empty:
                continue
            ax.plot(
                subset['centroid-1'].to_numpy(dtype=float),  # x = column
                subset['centroid-0'].to_numpy(dtype=float),  # y = row
                'o', color=color, markersize=0.3, alpha=0.85, label=cell_type,
            )

        # Lesion outline(s) at the 0.5 iso-level of the mask
        for contour in find_contours(lesion_mask, level=0.5):
            ax.plot(contour[:, 1], contour[:, 0], contour_color, linewidth=1.5)

        # Title with the cell counts
        t_total, t_within = counts['T_cell']['total'], counts['T_cell']['within_lesion']
        b_total, b_within = counts['B_cell']['total'], counts['B_cell']['within_lesion']
        ax.set_title(
            f"T Cells - Total: {t_total}, Within Lesion: {t_within} | "
            f"B Cells - Total: {b_total}, Within Lesion: {b_within}",
            fontsize=14
        )

        # Explicit legend handles so the tiny plot markers get readable swatches
        legend_elements = [
            plt.Line2D([0], [0], color=centroid_colors[0], marker='o', linestyle='None', markersize=10, label='T Cells'),
            plt.Line2D([0], [0], color=centroid_colors[1], marker='o', linestyle='None', markersize=10, label='B Cells')
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=12, frameon=True)

        ax.axis('off')  # Hide axes for better visualization

        os.makedirs(output_folder, exist_ok=True)
        save_path = os.path.join(output_folder, f"{image_name}_plot.png")
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    finally:
        plt.close(fig)  # free memory even if plotting/saving failed

    print(f"Saved plot: {save_path}")



## Plotting the images

In [None]:
# Set the output folder for saving processed images
output_image_folder = "processed_images"
os.makedirs(output_image_folder, exist_ok=True)

# Path to the folder containing batch results
batch_results_folder = "full_batch_results"


def _centroids_to_frame(centroids, key, label):
    """
    Convert stored (row, col) centroids into a DataFrame for lesion analysis.

    Args:
        centroids (array-like): Sequence of (row, col) pairs; may be empty.
        key (str): Scene identifier, used only in the log message.
        label (str): Cell-type letter ("T" or "B"), used only in the log message.

    Returns:
        pandas.DataFrame: Frame with columns 'centroid-0' and 'centroid-1'. It is
        empty, but still has those columns, when there are no centroids.
    """
    if len(centroids) > 0:
        return pd.DataFrame(centroids, columns=["centroid-0", "centroid-1"])
    print(f"No {label} cells detected for {key}.")
    return pd.DataFrame(columns=["centroid-0", "centroid-1"])


# Process result files in sorted order. os.listdir() returns entries in
# arbitrary order, which made runs (and their logs) non-reproducible.
for week_file in sorted(os.listdir(batch_results_folder)):
    if not week_file.endswith(".pkl"):
        continue
    print(f"Loading {week_file}...")
    week_file_path = os.path.join(batch_results_folder, week_file)

    # NOTE(security): pickle.load can execute arbitrary code. Only load result
    # files written by this notebook's own batch-processing step.
    with open(week_file_path, 'rb') as f:
        results_dict = pickle.load(f)

    # Loop through the results and save processed images
    for key, result in results_dict.items():
        print(f"Processing: {key}")
        image_name = key
        dapi_image = result.get('DAPI_image')
        lesion_mask = result.get('lesion_mask')
        T_cells = result.get('T_centroids')
        B_cells = result.get('B_centroids')

        # Validate input data
        if dapi_image is None or lesion_mask is None or T_cells is None or B_cells is None:
            print(f"Missing data for {key}. Skipping.")
            continue

        # Ensure lesion_mask matches the DAPI image shape
        if lesion_mask.shape != dapi_image.shape:
            print("Lesion mask does not match in shape. Resizing...")
            # order=0 (nearest neighbour) keeps the mask binary. With the default
            # interpolation a non-bool mask would be blurred at its edges, and
            # astype(bool) would then turn every partially covered pixel into lesion.
            lesion_mask = resize(lesion_mask, dapi_image.shape, order=0,
                                 anti_aliasing=False).astype(bool)

        # Convert T_cells and B_cells centroids to DataFrames
        t_props_df = _centroids_to_frame(T_cells, key, "T")
        b_props_df = _centroids_to_frame(B_cells, key, "B")

        # Analyze lesion area
        combined_df, counts = analyze_lesion_area(lesion_mask, t_props_df, b_props_df)

        plot_lesion_with_centroids(
            image_name,
            dapi_image,
            lesion_mask,
            combined_df,
            counts,
            output_folder=output_image_folder,
        )


Loading lesion_cell_results_week_4_batch_1.pkl...
Processing: 2023_07_14_F2805_w.4 B3+B4_iBALT_Scene_0
Saved plot: processed_images/2023_07_14_F2805_w.4 B3+B4_iBALT_Scene_0_plot.png
Processing: 2023_07_14_F2805_w.4 B3+B4_iBALT_Scene_1
Saved plot: processed_images/2023_07_14_F2805_w.4 B3+B4_iBALT_Scene_1_plot.png
Processing: 2023_07_14_F2805_w.4 A1+A2_iBALT_Scene_0
Saved plot: processed_images/2023_07_14_F2805_w.4 A1+A2_iBALT_Scene_0_plot.png
Processing: 2023_07_14_F2805_w.4 A1+A2_iBALT_Scene_1
Saved plot: processed_images/2023_07_14_F2805_w.4 A1+A2_iBALT_Scene_1_plot.png
Processing: 2023_07_14_F2805_w.4 B1+B2_iBALT_Scene_0
Saved plot: processed_images/2023_07_14_F2805_w.4 B1+B2_iBALT_Scene_0_plot.png
Processing: 2023_07_14_F2805_w.4 B1+B2_iBALT_Scene_1


KeyboardInterrupt: 