In [None]:
import cv2
import numpy as np
import os
import pandas as pd
import scanpy as sc
import tifffile
from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
# Adaptive thresholding with size filter and dilation
def adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=10):
    """Build a dilated binary mask of the large bright blobs in ``img``.

    Pipeline: Gaussian adaptive threshold -> external contours -> keep only
    contours whose area is at least ``min_area`` -> draw them filled on a blank
    mask -> dilate with an elliptical structuring element.

    Parameters
    ----------
    img : numpy.ndarray
        Input image handed straight to ``cv2.adaptiveThreshold`` (the caller in
        this notebook passes a uint8 image).
    block_size : int
        Neighbourhood size forwarded to ``cv2.adaptiveThreshold``.
    c : int or float
        Constant forwarded to ``cv2.adaptiveThreshold``.
    min_area : int or float
        Contours with ``cv2.contourArea`` below this value are discarded.
    radius : int
        The dilation kernel is an ellipse of size ``(2*radius+1, 2*radius+1)``.

    Returns
    -------
    numpy.ndarray
        Mask with the same shape and dtype as the thresholded image; kept
        (and dilated) regions are 255, everything else is 0.
    """
    binary = cv2.adaptiveThreshold(
        img,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        c,
    )

    # Only outer boundaries are retrieved, so drawing them filled below also
    # fills any holes inside a blob.
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Size filter: drop everything smaller than min_area.
    kept = []
    for contour in found:
        if cv2.contourArea(contour) >= min_area:
            kept.append(contour)

    # contourIdx=-1 draws every kept contour; thickness=-1 fills them.
    mask = np.zeros_like(binary)
    cv2.drawContours(mask, kept, -1, 255, -1)

    side = 2 * radius + 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    return cv2.dilate(mask, ellipse)

In [None]:
# Settings
# Per-dataset parameters, keyed by dataset directory name.
#   pixel_size              divisor that converts global_x / global_y into DAPI pixel indices
#   x_shift / y_shift       integer offsets added to the column / row index
#                           (int() truncates the fractional part toward zero)
#   cutoff                  threshold applied to the rotated cutoff coordinate
#   theta                   rotation angle in radians
#   coordinate_for_rotation ordered pair of columns fed to the 2-D rotation
#   coordinate_for_cutoff   column whose rotated value is compared with `cutoff`
#   cutoff_direction        "smaller" keeps values <= cutoff, "larger" keeps values >= cutoff
settings = {
    "MERSCOPE_WT_1": dict(
        pixel_size=0.10799861,
        x_shift=int(-266.1734),
        y_shift=int(180.2510),
        cutoff=6250,
        theta=10 * np.pi / 180,
        coordinate_for_rotation=["global_y", "global_x"],
        coordinate_for_cutoff="global_y",
        cutoff_direction="smaller",
    ),
    "MERSCOPE_AD_1": dict(
        pixel_size=0.10799905,
        x_shift=int(-126.9911),
        y_shift=int(-20.3805),
        cutoff=-4200,
        theta=170 * np.pi / 180,
        coordinate_for_rotation=["global_x", "global_y"],
        coordinate_for_cutoff="global_x",
        cutoff_direction="larger",
    ),
}

In [None]:
# ==================== Main operations on transcripts ==================== #

# Column that is exported as "overlaps_nucleus" in the final parquet file.
in_soma_label = "overlaps_nucleus_5_dilation"

# Dilation radii evaluated for every z-layer; each one fills its own
# "overlaps_nucleus_{radius}_dilation" column.
dilation_radii = [5, 10]
overlap_columns = [f"overlaps_nucleus_{radius}_dilation" for radius in dilation_radii]

for dataset in settings.keys():

    print("=" * 25)
    print(f"Processing {dataset}...")
    print("=" * 25)

    # File paths
    data_path = f"../data/{dataset}/"
    output_path = f"../output/{dataset}/"
    dapi_dir = data_path + "raw_data/DAPI_images/"
    processed_dir = data_path + "processed_data/"
    preview_dir = "intermediate_data/images/"

    # Transformation parameters
    pixel_size = settings[dataset]["pixel_size"]
    x_shift = settings[dataset]["x_shift"]
    y_shift = settings[dataset]["y_shift"]

    # Hemisphere-cut parameters. Validate the direction before doing any heavy
    # work: an unknown value used to leave `transcripts_cut` undefined (or, worse,
    # silently reuse the previous dataset's frame).
    cutoff = settings[dataset]["cutoff"]
    theta = settings[dataset]["theta"]
    coordinate_for_rotation = settings[dataset]["coordinate_for_rotation"]
    coordinate_for_cutoff = settings[dataset]["coordinate_for_cutoff"]
    cutoff_direction = settings[dataset]["cutoff_direction"]
    if cutoff_direction not in ("smaller", "larger"):
        raise ValueError(f"Unknown cutoff_direction {cutoff_direction!r} for {dataset}; expected 'smaller' or 'larger'.")

    # List DAPI images.
    # NOTE(review): the sort is lexicographic and the position in the sorted list
    # is used as the z index below, so this assumes at most 10 z-planes named
    # ..._z0, ..._z1, ... - confirm against the raw data.
    files = os.listdir(dapi_dir)
    files = [i for i in files if i.startswith("mosaic_DAPI_z")]
    files.sort()
    if not files:
        # Without images every transcript would silently be labelled 0.
        raise FileNotFoundError(f"No 'mosaic_DAPI_z*' images found in {dapi_dir}")

    # Read transcripts
    transcripts = pd.read_csv(data_path + "raw_data/transcripts.csv")
    transcripts = transcripts[["cell_id", "gene", "global_x", "global_y", "global_z"]].copy()

    # Compute DAPI pixel coordinates (y -> row, x -> column)
    transcripts["row"] = (transcripts["global_y"] / pixel_size).astype(int) + y_shift
    transcripts["col"] = (transcripts["global_x"] / pixel_size).astype(int) + x_shift

    # Add default overlap columns (0 = not overlapping a dilated nucleus)
    for column in overlap_columns:
        transcripts[column] = 0

    # cv2.imwrite fails silently when the target directory is missing.
    # NOTE(review): preview file names carry neither dataset nor radius, so each
    # write overwrites the previous one for the same z index.
    os.makedirs(preview_dir, exist_ok=True)

    # Per-layer in-nucleus ratios, collected per radius.
    ratio_by_radius = {radius: [] for radius in dilation_radii}

    for j, fname in enumerate(files):

        # Load and normalise the DAPI image once per z-layer; this does not
        # depend on the dilation radius.
        img = tifffile.imread(dapi_dir + fname)
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        # Select transcripts in this z-layer
        z_index = transcripts.index[transcripts["global_z"] == j]
        row_vals = transcripts.loc[z_index, "row"].astype(int).values
        col_vals = transcripts.loc[z_index, "col"].astype(int).values

        for radius in dilation_radii:

            print(f"Processing dilation radius: {radius}")

            # Threshold and dilate
            th = adaptive_thresholding_with_size_filter_and_dilation(img, block_size=49, c=-1, min_area=100, radius=radius)

            # Save resized visualization (optional)
            th_small = cv2.resize(th, (3500, 5000), interpolation=cv2.INTER_AREA)
            cv2.imwrite(f"intermediate_data/images/z_{j}_small.png", th_small)

            # Avoid out-of-bounds indexing
            height, width = th.shape
            row_ok = (row_vals >= 0) & (row_vals < height)
            col_ok = (col_vals >= 0) & (col_vals < width)
            valid = row_ok & col_ok

            # Assign in-nucleus labels; out-of-bounds transcripts stay 0
            overlaps = np.zeros(len(z_index), dtype=int)
            overlaps[valid] = (th[row_vals[valid], col_vals[valid]] != 0).astype(int)

            # Update main DataFrame in-place
            transcripts.loc[z_index, f"overlaps_nucleus_{radius}_dilation"] = overlaps

            # Track per-layer ratio (NaN when the layer has no transcripts)
            ratio = overlaps.mean() if overlaps.size else float("nan")
            ratio_by_radius[radius].append(ratio)

            # Count out-of-bounds rows/columns directly. The previous
            # `row_vals != row_valid` compared arrays of different lengths,
            # which raises on NumPy >= 1.25 and was meaningless before that.
            row_mismatches = int(np.sum(~row_ok))
            col_mismatches = int(np.sum(~col_ok))
            print(f"Iteration {j+1}: {row_mismatches} row mismatches, {col_mismatches} column mismatches, {ratio:.2%} in-nucleus")

    # Same ordering as before the loops were swapped: all layers of the first
    # radius, then all layers of the next one.
    global_ratio = [ratio for radius in dilation_radii for ratio in ratio_by_radius[radius]]

    for radius, column in zip(dilation_radii, overlap_columns):
        print(f"Average in-nucleus ratio ({radius} dilation): {transcripts[column].mean()}")

    # Final labeled transcripts
    transcripts = transcripts[["cell_id", *overlap_columns, "gene", "global_x", "global_y", "global_z"]].copy()
    # NOTE(review): presumably converts the z-plane index into a physical
    # distance (1.5 units per plane) - confirm the spacing.
    transcripts["global_z"] *= 1.5
    transcripts = transcripts.rename(columns = {"gene": "target"})

    # Cut hemisphere: rotate the two chosen coordinates by theta, then keep the
    # side of the rotated cutoff coordinate selected by cutoff_direction.
    rotation_matrix = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
    coords = transcripts[coordinate_for_rotation].to_numpy()
    transformed_coords = coords @ rotation_matrix.T
    transcripts[f"{coordinate_for_rotation[0]}_new"] = transformed_coords[:, 0]
    transcripts[f"{coordinate_for_rotation[1]}_new"] = transformed_coords[:, 1]

    rotated_cutoff_values = transcripts[f"{coordinate_for_cutoff}_new"]
    if cutoff_direction == "smaller":
        keep = rotated_cutoff_values <= cutoff
    else:  # "larger" (validated at the top of the loop)
        keep = rotated_cutoff_values >= cutoff

    transcripts_cut = transcripts.loc[keep, ["cell_id", in_soma_label, "target", "global_x", "global_y", "global_z"]].copy()
    transcripts_cut = transcripts_cut.rename(columns={in_soma_label: "overlaps_nucleus"})

    # to_parquet does not create missing parent directories.
    os.makedirs(processed_dir, exist_ok=True)
    transcripts_cut.to_parquet(data_path + "processed_data/transcripts.parquet")
    print(f"Processed {dataset}.")

In [None]:
# # Cut a small region from the hemisphere
# dataset = "MERSCOPE_WT_1"
# data_path = f"../data/{dataset}/"
# output_path = f"../output/{dataset}/"

# # x_min, x_max = 4500, 5500
# # y_min, y_max = 3000, 4000

# x_min, x_max = 3100, 4100
# y_min, y_max = 3100, 4100

# adata = sc.read_h5ad(data_path + "processed_data/adata.h5ad")

# sc.set_figure_params(figsize = (6, 9))
# ax = sc.pl.scatter(adata, x="global_y_new", y="global_x_new", color="brain_area", size=10, title = " ", show=False)
# rect = Rectangle((y_min, x_min), y_max - y_min, x_max - x_min, linewidth=2, edgecolor="red", facecolor="none")
# ax.add_patch(rect)
# ax.grid(False)
# ax.set_xticks([])
# ax.set_yticks([])
# ax.set_xlabel("")
# ax.set_ylabel("")
# for spine in ax.spines.values():
#     spine.set_visible(False)
# # plt.savefig(output_path +"small_region_annotation.png", dpi = 300, bbox_inches = "tight")
# plt.savefig(output_path +"small_region_2_annotation.png", dpi = 300, bbox_inches = "tight")
# plt.close()

# adata_cut = adata[(adata.obs["global_y_new"] >= y_min) & (adata.obs["global_y_new"] <= y_max) & (adata.obs["global_x_new"] >= x_min) & (adata.obs["global_x_new"] <= x_max)].copy()
# domains = "cell_type"
# num_celltype = len(adata_cut.obs[domains].unique())
# plot_color = ["#F56867","#FEB915","#C798EE","#59BE86","#7495D3","#6D1A9C","#15821E","#3A84E6","#997273","#787878","#DB4C6C","#9E7A7A","#554236","#AF5F3C","#93796C","#F9BD3F","#DAB370","#877F6C","#268785"]
# adata_cut.uns[domains+"_colors"] = list(plot_color[:num_celltype])

# sc.set_figure_params(figsize = (5, 5))
# ax = sc.pl.scatter(adata_cut, x="global_y_new", y="global_x_new", color="cell_type", size=15, title = " ", show=False)
# ax.grid(False)
# ax.set_xticks([])
# ax.set_yticks([])
# ax.set_xlabel("")
# ax.set_ylabel("")
# for spine in ax.spines.values():
#     spine.set_visible(False)
# # plt.savefig(output_path +"small_region.png", dpi = 300, bbox_inches = "tight")
# plt.savefig(output_path +"small_region_2.png", dpi = 300, bbox_inches = "tight")
# plt.close()

In [None]:
# # Retain the transcripts in the small region
# theta = 10 * np.pi / 180
# cutoff = 6250

# transcripts = pd.read_parquet(data_path + "processed_data/transcripts.parquet")

# rotation_matrix = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
# coords = transcripts[["global_y", "global_x"]].to_numpy()
# transformed_coords = coords @ rotation_matrix.T
# transcripts["global_y_new"] = transformed_coords[:, 0]
# transcripts["global_x_new"] = transformed_coords[:, 1]
# transcripts["global_y_new"] = cutoff - transcripts["global_y_new"]

# transcripts_cut = transcripts[(transcripts["global_y_new"] >= y_min) & (transcripts["global_y_new"] <= y_max) & (transcripts["global_x_new"] >= x_min) & (transcripts["global_x_new"] <= x_max)].copy()
# # transcripts_cut.to_parquet(data_path + "processed_data/transcripts_small_region.parquet")
# transcripts_cut.to_parquet(data_path + "processed_data/transcripts_small_region_2.parquet")