In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import pickle
import pyvips
import seaborn as sns
import tifffile
from matplotlib import pyplot as plt
from scipy import ndimage

### Settings

In [None]:
# File paths
dataset = "MERSCOPE_WT_1"
data_path = f"../data/{dataset}/"
output_path = f"../output/{dataset}/"

# Transformation parameters: pixel_size is the mosaic pixel edge length in µm; x_shift / y_shift are
# integer pixel offsets added after dividing global µm coordinates by pixel_size (int() truncates
# the offsets toward zero).
pixel_size = 0.10799861
x_shift, y_shift = int(-266.1734), int(180.2510)

# All DAPI mosaic images (per-z layers and, once derived, the MIP), sorted by file name
files = sorted(f for f in os.listdir(data_path + "raw_data/DAPI_images/") if f.startswith("mosaic"))
files

### Derive the maximum intensity projection (MIP) of the DAPI z-stack (run once)

In [None]:
# paths = [data_path + "raw_data/DAPI_images/" + f"mosaic_DAPI_z{i}.tif" for i in range(7)]
# imgs = [pyvips.Image.new_from_file(p, access="sequential") for p in paths]
# mip = imgs[0]
# for im in imgs[1:]:
#     mip = pyvips.Image.maxpair(mip, im)
# mip.tiffsave(
#     data_path + "raw_data/DAPI_images/mosaic_DAPI_MIP.tif",
#     bigtiff=True,
#     compression="lzw",   # or "deflate"
#     tile=True,
#     tile_width=1024,
#     tile_height=1024,
#     pyramid=False
# )
# print("Saved:", "mosaic_DAPI_MIP.tif")

### Benchmark 1: area of detected objects on the MIP

In [None]:
# ==================== Helper functions ==================== #


# Thresholding followed by dilation with circular kernel
def adaptive_thresholding_with_dilation(img, block_size=49, c=-1, radius=10):
    """Binarize `img` with Gaussian adaptive thresholding, then grow the foreground.

    Args:
        img: 8-bit single-channel image (the input cv2.adaptiveThreshold accepts).
        block_size: odd neighbourhood size used to compute each pixel's local threshold.
        c: constant subtracted from the Gaussian-weighted local mean (negative -> stricter).
        radius: radius in pixels of the elliptical dilation kernel.

    Returns:
        Binary (0/255) image after dilation.
    """
    diameter = 2 * radius + 1
    disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
    binary = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, c)
    return cv2.dilate(binary, disk)


def adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=50, radius=10):
    """Adaptive threshold, drop small objects, then dilate the survivors.

    Args:
        img: 8-bit single-channel image (the input cv2.adaptiveThreshold accepts).
        block_size: odd neighbourhood size used to compute each pixel's local threshold.
        c: constant subtracted from the Gaussian-weighted local mean (negative -> stricter).
        min_area: objects whose outer contour area (pixels) is below this are discarded.
        radius: radius in pixels of the elliptical dilation kernel.

    Returns:
        Binary (0/255) image containing only the kept objects, dilated. Kept objects are
        redrawn filled from their outer contours, so interior holes are filled as well.
    """
    binary = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, c)
    outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    kept = [outline for outline in outlines if cv2.contourArea(outline) >= min_area]

    # Redraw only the kept outlines, filled (thickness -1), on an empty canvas
    mask = np.zeros_like(binary)
    cv2.drawContours(mask, kept, -1, 255, -1)

    size = 2 * radius + 1
    footprint = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    return cv2.dilate(mask, footprint)


def otsu_with_dilation(img, radius=10):
    """Binarize `img` with a global Otsu threshold, then dilate the foreground.

    Args:
        img: single-channel image; Otsu's automatically chosen threshold replaces the 0 passed in.
        radius: radius in pixels of the elliptical dilation kernel.

    Returns:
        Binary (0/255) image after dilation.
    """
    size = 2 * radius + 1
    footprint = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    return cv2.dilate(binary, footprint)


# Analyze detected objects
def analyze_objects(contours, method_name, pixel_size_um=None):
    """Convert contours into per-object size records.

    Each contour with a positive area becomes one record holding its area in pixels and µm²
    plus the radius/diameter of the circle with the same area (equivalent-circle size).

    Args:
        contours: sequence of contours as returned by cv2.findContours.
        method_name: label stored in each record's "method" field.
        pixel_size_um: edge length of one pixel in µm. Defaults to the notebook-level
            `pixel_size` setting, so existing two-argument calls behave as before.

    Returns:
        List of dicts with keys "method", "object_id" (1-based position in `contours`;
        IDs of skipped zero-area contours are simply absent), "area_pixels", "area_um2",
        "radius_um" and "diameter_um".
    """
    if pixel_size_um is None:
        pixel_size_um = pixel_size
    um2_per_pixel = pixel_size_um ** 2

    objects = []
    for i, contour in enumerate(contours):
        area_pixels = cv2.contourArea(contour)
        if area_pixels <= 0:
            continue  # zero-area contours (e.g. single-pixel or line-like outlines) are not objects
        area_um2 = area_pixels * um2_per_pixel
        radius_um = np.sqrt(area_um2 / np.pi)
        objects.append({"method": method_name,
                        "object_id": i + 1,
                        "area_pixels": area_pixels,
                        "area_um2": area_um2,
                        "radius_um": radius_um,
                        "diameter_um": radius_um * 2})
    return objects


def print_method_results(objects, method_name):
    """Print the object count and summary statistics for one thresholding method.

    Args:
        objects: list of records as produced by analyze_objects (keys "area_pixels",
            "area_um2", "radius_um", "diameter_um").
        method_name: used only in the message printed when `objects` is empty.
    """
    if not objects:
        print(f"No objects detected with {method_name}!")
        return

    print(f"Number of detected objects: {len(objects)}")

    # (label, record key, decimals); labels are padded so the "Mean=" columns line up
    summary_rows = (
        ("  Area (pixels²):  ", "area_pixels", 1),
        ("  Area (μm²):      ", "area_um2", 2),
        ("  Radius (μm):     ", "radius_um", 2),
        ("  Diameter (μm):   ", "diameter_um", 2),
    )
    print("\nSummary statistics:")
    for label, key, decimals in summary_rows:
        values = [obj[key] for obj in objects]
        print(f"{label}Mean={np.mean(values):6.{decimals}f}, Median={np.median(values):6.{decimals}f}, "
              f"Min={np.min(values):6.{decimals}f}, Max={np.max(values):6.{decimals}f}")


# Extract statistics for plotting
def get_method_statistics(method_results, method_name, stat_type="area_pixels"):
    """Return the `stat_type` value of every object recorded for `method_name`.

    `method_results` maps method name -> {"objects": [record, ...], ...}. If `method_name`
    has no entry, a notice is printed and an empty list is returned.
    """
    if method_name in method_results:
        return [record[stat_type] for record in method_results[method_name]["objects"]]
    print(f"Method '{method_name}' not found in results")
    return []


def get_all_methods_statistics(method_results, stat_type="area_pixels"):
    """Map each method name in `method_results` to its list of `stat_type` values."""
    return {name: get_method_statistics(method_results, name, stat_type) for name in method_results}


# Downsample image
def downsample_image(img, scale_factor=5000):
    """Resize `img` so its shorter side is `scale_factor` pixels, preserving aspect ratio.

    Despite its name, `scale_factor` is the target length (in pixels) of the shorter side,
    not a ratio. The previous implementation ignored it and always used 5000; it is now
    honored, and the default keeps the old behavior.

    Args:
        img: 2-D (grayscale) or 3-D (H x W x C) image array.
        scale_factor: target length of the shorter side in pixels. For a square image the
            width is used as the reference side.

    Returns:
        The resized image (cv2.INTER_AREA interpolation). The longer side is scaled by the
        same factor and truncated to an int.
    """
    target = int(scale_factor)
    height, width = img.shape[:2]
    if height < width:
        new_height = target
        new_width = int(width * (target / height))
    else:
        new_width = target
        new_height = int(height * (target / width))
    return cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)

In [None]:
# Initialize for comparison
fname = "mosaic_DAPI_MIP.tif"

# cv2.imwrite signals failure (e.g. a missing output directory) by returning False rather than
# raising, so make sure the directory exists and check every write.
os.makedirs(output_path, exist_ok=True)


def _imwrite_checked(path, image):
    """Write `image` to `path` via cv2.imwrite; raise OSError if OpenCV reports failure."""
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite could not write {path}")


# Load DAPI image, rescale intensities to 0-255 uint8, downsample, and save representative patches
img_path = os.path.join(data_path, "raw_data/DAPI_images", fname)
img = tifffile.imread(img_path)
img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

img_downsampled = downsample_image(img, scale_factor=5000)
_imwrite_checked(output_path + "downsampled_img.png", img_downsampled)

# Patch top-left corners (x0, y0) in downsampled-image pixels; each patch is w x h
target_coords = [(3750, 1500), (3750, 1650), (3750, 1800)]
w, h = 150, 150
for patch_idx, (x0, y0) in enumerate(target_coords):
    patch = img_downsampled[y0:y0+h, x0:x0+w]
    _imwrite_checked(output_path + f"downsampled_patch_{patch_idx}.png", patch)

print(f"Analyzing DAPI image: {fname}")
print("="*80)

# Define thresholding methods
thresholding_methods = {
    "Adaptive Thresholding": {
        "function": lambda img: cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 49, -1),
        "description": "Gaussian-weighted adaptive thresholding with block size 49"
    },
    "Adaptive + Dilation": {
        "function": lambda img: adaptive_thresholding_with_dilation(img, block_size=49, c=-1, radius=10),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=10px)"
    },
    "Adaptive + Small Size Filter + Small Dilation": {
        "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=5),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=5px) and size filter (min_area=100px)"
    },
    "Adaptive + Small Size Filter + Moderate Dilation": {
        "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=10),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=10px) and size filter (min_area=100px)"
    },
    "Adaptive + Small Size Filter + Large Dilation": {
        "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=15),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=15px) and size filter (min_area=100px)"
    },
    # "Adaptive + Moderate Size Filter + Small Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=250, radius=5),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=5px) and size filter (min_area=250px)"
    # },
    # "Adaptive + Moderate Size Filter + Moderate Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=250, radius=10),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=10px) and size filter (min_area=250px)"
    # },
    # "Adaptive + Moderate Size Filter + Large Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=250, radius=15),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=15px) and size filter (min_area=250px)"
    # },
    # "Adaptive + Large Size Filter + Small Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=500, radius=5),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=5px) and size filter (min_area=500px)"
    # },
    # "Adaptive + Large Size Filter + Moderate Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=500, radius=10),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=10px) and size filter (min_area=500px)"
    # },
    # "Adaptive + Large Size Filter + Large Dilation": {
    #     "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=500, radius=15),
    #     "description": "Adaptive thresholding followed by dilation with circular kernel (radius=15px) and size filter (min_area=500px)"
    # },
    # "Otsu Thresholding": {
    #     "function": lambda img: cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1],
    #     "description": "Automatic global thresholding using Otsu's method"
    # },
    # "Otsu + Dilation": {
    #     "function": lambda img: otsu_with_dilation(img, radius=10),
    #     "description": "Otsu thresholding followed by dilation with circular kernel (radius=10px)"
    # },
}

print(f"Total methods available: {list(thresholding_methods.keys())}")

# Main analysis loop
all_results = {}

for method_idx, (method_name, method_info) in enumerate(thresholding_methods.items()):
    print(f"METHOD {method_idx+1}: {method_name.upper()}")
    print("-" * 40)
    print(f"Description: {method_info['description']}")

    # Apply thresholding method
    thresholded_img = method_info["function"](img)

    # Find contours (outer boundaries only)
    contours, _ = cv2.findContours(thresholded_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Save downsampled thresholded image and the same representative patches as above
    thresholded_img_downsampled = downsample_image(thresholded_img, scale_factor=5000)
    safe_filename = method_name.lower().replace(" ", "_").replace("(", "").replace(")", "").replace(",", "").replace("+", "and")
    _imwrite_checked(output_path + f"thresholded_img_{safe_filename}_downsampled.png", thresholded_img_downsampled)

    for patch_idx, (x0, y0) in enumerate(target_coords):
        patch = thresholded_img_downsampled[y0:y0+h, x0:x0+w]
        _imwrite_checked(output_path + f"thresholded_patch_{safe_filename}_downsampled_{patch_idx}.png", patch)

    # Analyze objects
    objects = analyze_objects(contours, method_name)

    # Print results
    print_method_results(objects, method_name)

    # Store results
    # NOTE(review): "thresholded_image" is the full-resolution mask, so all_results (and the
    # pickle written below) holds one full-size image per method; drop it if nothing downstream needs it.
    all_results[method_name] = {
        "objects": objects,
        "thresholded_image": thresholded_img,
        "num_objects": len(objects)
    }

    print("\n" + "="*80)

# Method Comparison Summary
print("COMPARISON SUMMARY")
print("-" * 40)
for method_name, results in all_results.items():
    print(f"{method_name}: {results['num_objects']} objects detected")

print("\n" + "="*80)

# Store results for further analysis
method_results = all_results
with open(output_path + "method_results.pickle", "wb") as handle:
    pickle.dump(method_results, handle)

In [None]:
# Save per-object areas (µm²) of the two small-size-filter methods to ../validation/
areas_um2 = get_all_methods_statistics(method_results, "area_um2")
area_csv_paths = {
    "Adaptive + Small Size Filter + Small Dilation": "../validation/areas_MERSCOPE_small_size_small_dilation.csv",
    "Adaptive + Small Size Filter + Moderate Dilation": "../validation/areas_MERSCOPE_small_size_moderate_dilation.csv",
}
for area_method, csv_path in area_csv_paths.items():
    pd.DataFrame({"area": areas_um2[area_method]}).to_csv(csv_path, index=False)

### Benchmark 2: per-gene in-soma ratio of transcripts

In [None]:
# Read transcripts. Only the columns used downstream are parsed (usecols), which saves time and
# memory on the large raw table; the explicit selection afterwards fixes the column order.
transcript_cols = ["cell_id", "gene", "global_x", "global_y", "global_z"]
transcripts = pd.read_csv(data_path + "raw_data/transcripts.csv", usecols=transcript_cols)[transcript_cols].copy()
transcripts.head()

In [None]:
# Define target genes: the "genes" column of processed_data/genes.csv
all_genes = pd.read_csv(data_path + "processed_data/genes.csv")["genes"].tolist()

# Granule markers are stored under "overlap_genes_select" in a shared pickle.
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
with open("../data/utils/overlap_genes.pickle", "rb") as handle:
    overlap_genes = pickle.load(handle)
granule_markers = overlap_genes["overlap_genes_select"]

# Keep only transcripts of the target genes
transcripts = transcripts.loc[transcripts["gene"].isin(all_genes)].copy()

In [None]:
# ==================== Thresholding Methods Benchmarking ====================#

# Define thresholding methods (subset of the Benchmark 1 methods)
thresholding_methods = {
    "Adaptive Thresholding": {
        "function": lambda img: cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 49, -1),
        "description": "Gaussian-weighted adaptive thresholding with block size 49"
    },
    "Adaptive + Small Size Filter + Small Dilation": {
        "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=5),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=5px) and size filter (min_area=100px)"
    },
    "Adaptive + Small Size Filter + Moderate Dilation": {
        "function": lambda img: adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=10),
        "description": "Adaptive thresholding followed by dilation with circular kernel (radius=10px) and size filter (min_area=100px)"
    },
}


def _safe_method_name(method_name):
    """Make a method display name usable in column/file names ("A + B" -> "A_and_B")."""
    return method_name.replace(" ", "_").replace("(", "").replace(")", "").replace(",", "").replace("+", "and")


def _z_index(fname):
    """Return the z-layer index encoded in a file name such as "mosaic_DAPI_z3.tif" (-> 3)."""
    return int(os.path.splitext(fname)[0].rsplit("_z", 1)[1])


os.makedirs(output_path, exist_ok=True)  # cv2.imwrite fails silently if the directory is missing

# Initialize transcript dataframe with one binary in-soma column per method (0 = not in soma)
in_soma_cols = {name: f"in_soma_{_safe_method_name(name)}" for name in thresholding_methods}
transcripts_with_labels = transcripts.copy()
for in_soma_col in in_soma_cols.values():
    transcripts_with_labels[in_soma_col] = 0

print("Starting thresholding methods benchmarking...\n")

# Process each z-layer (exclude MIP file). Order by the parsed z index instead of by name so the
# file <-> global_z pairing stays correct with 10 or more layers ("z10" sorts before "z2" as text).
z_layer_files = sorted((f for f in files if f.startswith("mosaic_DAPI_z")), key=_z_index)
visualization_tag = _safe_method_name(next(iter(thresholding_methods)))

# Each full-resolution mosaic is loaded and normalized once, then every method is applied to it
# (previously each layer was re-read from disk once per method).
for layer_idx, fname in enumerate(z_layer_files):
    z = _z_index(fname)
    print(f"Processing z-layer {layer_idx+1}/{len(z_layer_files)}: {fname}")

    # Transcripts assigned to this z-layer
    trans_z_indices = transcripts.index[transcripts["global_z"] == z].values
    if len(trans_z_indices) == 0:
        continue

    # Convert global coordinates (in microns) to pixel coordinates (divide by pixel_size, apply
    # shifts). np.floor rather than astype(int): truncation toward zero would put coordinates in
    # (-pixel_size, 0) into pixel 0 instead of pixel -1.
    col_vals = np.floor(transcripts.loc[trans_z_indices, "global_x"].values / pixel_size).astype(int) + x_shift
    row_vals = np.floor(transcripts.loc[trans_z_indices, "global_y"].values / pixel_size).astype(int) + y_shift

    # Load DAPI image
    img_path = os.path.join(data_path, "raw_data/DAPI_images", fname)
    img = tifffile.imread(img_path)
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Visualize mapped Slc17a7 transcripts for the first z-layer only
    if layer_idx == 0:
        gene_local_indices = np.flatnonzero((transcripts.loc[trans_z_indices, "gene"] == "Slc17a7").to_numpy())
        if len(gene_local_indices) > 0:
            img_color = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            height, width = img_color.shape[:2]
            for r, c in zip(row_vals[gene_local_indices], col_vals[gene_local_indices]):
                if 0 <= r < height and 0 <= c < width:
                    # Point cast to plain Python ints defensively for the OpenCV binding
                    cv2.circle(img_color, (int(c), int(r)), 5, (0, 0, 255), -1)  # Red dot, radius 5
            output_filename = f"mapped_transcripts_Slc17a7_z{z}_{visualization_tag}.png"
            cv2.imwrite(os.path.join(output_path, output_filename), downsample_image(img_color, scale_factor=5000))
            print(f"  Saved visualization to: {output_filename}")

    # Apply each thresholding method; label in-bounds transcripts that fall on foreground pixels
    for method_name, method_info in thresholding_methods.items():
        th = method_info["function"](img)
        height, width = th.shape
        valid = (row_vals >= 0) & (row_vals < height) & (col_vals >= 0) & (col_vals < width)
        if valid.any():
            overlaps = (th[row_vals[valid], col_vals[valid]] != 0).astype(int)
            transcripts_with_labels.loc[trans_z_indices[valid], in_soma_cols[method_name]] = overlaps

print("Completed processing all methods.\n")

# Calculate per-gene in-soma ratios for each method.
# NOTE(review): total_count also includes transcripts that were never checked (outside the image,
# or with a global_z matching no z-layer file); they keep label 0 and lower the ratio.
gene_ratios = []
for method_name, in_soma_col in in_soma_cols.items():
    gene_stats = transcripts_with_labels.groupby("gene").agg({
        in_soma_col: ["sum", "count"]
    }).reset_index()
    gene_stats.columns = ["gene", "in_soma_count", "total_count"]
    gene_stats["in_soma_ratio"] = gene_stats["in_soma_count"] / gene_stats["total_count"]
    gene_stats["method"] = method_name
    gene_ratios.append(gene_stats)

# Combine all gene ratios
gene_ratios_df = pd.concat(gene_ratios, ignore_index=True)

# Add gene category labels
gene_ratios_df["gene_category"] = "Others"
gene_ratios_df.loc[gene_ratios_df["gene"].isin(granule_markers), "gene_category"] = "Granule Markers"

# Save transcript file with in-soma labels
transcripts_output_path = os.path.join(data_path, "processed_data", "transcripts_with_in_soma_labels.parquet")
transcripts_with_labels.to_parquet(transcripts_output_path)
print(f"Transcripts with in-soma labels saved to: {transcripts_output_path}")

# Save gene-level in-soma ratios
gene_ratios_output_path = os.path.join(output_path, "gene_in_soma_ratios.csv")
gene_ratios_df.to_csv(gene_ratios_output_path, index=False)
print(f"Gene-level in-soma ratios saved to: {gene_ratios_output_path}")