In [126]:
import anndata
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
from scipy import sparse
from scipy.stats import spearmanr, wilcoxon
from statsmodels.stats.multitest import multipletests

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

In [127]:
# Colormap: a custom magma-like gradient built from six hex colour stops,
# interpolated to 256 levels.
color_cts = clr.LinearSegmentedColormap.from_list(
    "magma",
    ["#000003", "#3B0F6F", "#8C2980", "#F66E5B", "#FD9F6C", "#FBFCBF"],
    N=256,
)

In [128]:
# Utility functions
def npz_to_adata(npz_path: str, adata_tumor):
    """Wrap a saved sparse matrix (.npz) as an AnnData aligned to ``adata_tumor``.

    The matrix is loaded with ``scipy.sparse.load_npz`` and converted to CSR.
    It must have exactly one row per cell and one column per gene of
    ``adata_tumor``. Rows and columns are taken to be in the same order as
    ``adata_tumor`` (this is assumed, not checked).

    Parameters
    ----------
    npz_path : str
        Path to the ``.npz`` file.
    adata_tumor : AnnData
        Reference object whose ``obs``/``var`` (copied) annotate the result.

    Returns
    -------
    AnnData
        New object with the loaded CSR matrix as ``X``.

    Raises
    ------
    ValueError
        If the matrix shape differs from ``(adata_tumor.n_obs, adata_tumor.n_vars)``.
    """
    matrix = sparse.load_npz(npz_path).tocsr()
    expected_shape = (adata_tumor.n_obs, adata_tumor.n_vars)

    if matrix.shape != expected_shape:
        raise ValueError(f"Shape mismatch for {npz_path}: {matrix.shape} vs tumor {expected_shape}")

    # copy annotations so edits to the new object never leak back into adata_tumor
    return anndata.AnnData(
        X=matrix,
        obs=adata_tumor.obs.copy(),
        var=adata_tumor.var.copy(),
    )


def rowwise_pearson_sparse(Xn: sp.csr_matrix, Xs: sp.csr_matrix, mask: np.ndarray):
    """Pearson correlation between matching rows of two sparse matrices.

    Parameters
    ----------
    Xn, Xs : scipy.sparse.csr_matrix
        Matrices with aligned rows and columns. Each selected row is densified
        one at a time.
    mask : array-like of bool, length Xn.shape[0]
        Rows to evaluate.

    Returns
    -------
    np.ndarray of float32, length Xn.shape[0]
        Correlation per row. NaN for unselected rows and for rows where either
        vector is constant (zero variance).
    """
    result = np.full(Xn.shape[0], np.nan, dtype=np.float32)

    for row in np.flatnonzero(mask):
        u = Xn[row, :].toarray().ravel()
        v = Xs[row, :].toarray().ravel()

        u_centered = u - u.mean()
        v_centered = v - v.mean()
        scale = np.sqrt((u_centered @ u_centered) * (v_centered @ v_centered))

        # zero scale means at least one constant vector: correlation undefined
        result[row] = np.nan if scale == 0 else (u_centered @ v_centered) / scale

    return result


def rowwise_spearman_sparse(Xn: sp.csr_matrix, Xs: sp.csr_matrix, mask: np.ndarray):
    """Spearman rank correlation between matching rows of two sparse matrices.

    Parameters
    ----------
    Xn, Xs : scipy.sparse.csr_matrix
        Matrices with aligned rows and columns. Each selected row is densified.
    mask : array-like of bool, length Xn.shape[0]
        Rows to evaluate.

    Returns
    -------
    np.ndarray of float32, length Xn.shape[0]
        Spearman rho per selected row; NaN for unselected rows. A constant row
        makes ``scipy.stats.spearmanr`` return NaN, which is stored as-is.
    """
    out = np.full(Xn.shape[0], np.nan, dtype=np.float32)

    for row in np.flatnonzero(mask):
        dense_n = Xn[row, :].toarray().ravel()
        dense_s = Xs[row, :].toarray().ravel()
        rho, _pvalue = spearmanr(dense_n, dense_s)
        out[row] = rho

    return out


def rowwise_jaccard_csr(B1: sp.csr_matrix, B2: sp.csr_matrix, mask: np.ndarray):
    """Row-wise Jaccard similarity between the non-zero supports of two CSR matrices.

    For each selected row ``i``, the support of that row in ``B1`` and in ``B2``
    is the set of column indices whose stored value is non-zero. The result is
    ``|A & B| / |A | B|``. Explicitly stored zeros and duplicate column indices
    (non-canonical CSR) are ignored. The inputs therefore need not be
    pre-binarized or canonical.

    Parameters
    ----------
    B1, B2 : scipy.sparse.csr_matrix
        Matrices with the same number of rows; rows are assumed to be aligned.
    mask : array-like of bool, length B1.shape[0]
        Rows to evaluate; unselected rows stay NaN.

    Returns
    -------
    np.ndarray of float32, length B1.shape[0]
        Jaccard similarity per row. NaN for unselected rows and for rows where
        both supports are empty (similarity undefined).
    """
    jac = np.full(B1.shape[0], np.nan, dtype=np.float32)

    def _support(B: sp.csr_matrix, i: int) -> np.ndarray:
        # non-zero column indices of row i, de-duplicated (np.unique also sorts)
        start, end = B.indptr[i], B.indptr[i + 1]
        row_cols = B.indices[start:end]
        return np.unique(row_cols[B.data[start:end] != 0])

    for i in np.flatnonzero(mask):
        a = _support(B1, i)
        b = _support(B2, i)

        if a.size == 0 and b.size == 0:
            continue  # leave NaN

        # both arrays are unique, so assume_unique is now actually valid
        inter = np.intersect1d(a, b, assume_unique=True).size
        union = a.size + b.size - inter  # > 0 because not both supports are empty
        jac[i] = inter / union

    return jac

In [130]:
# ==================== Main operations ==================== #
# For each dataset below: compare each malignant cell's nuclear expression
# profile with its SG expression profile.
#   1. Cell level: Pearson, Spearman and Jaccard similarity between the two
#      log-normalized profiles, restricted to cells with enough SG signal.
#   2. Gene level: mean SG - nucleus difference and a paired Wilcoxon test.
# All figures are written to SG_vs_nuclei/.

settings = {"Xenium_5K_BC": {"cell_type_label": True},
            "Xenium_5K_OC": {"cell_type_label": True},
            "Xenium_5K_CC": {"cell_type_label": True},
            "Xenium_5K_LC": {"cell_type_label": False},
            "Xenium_5K_Prostate": {"cell_type_label": False},
            "Xenium_5K_Skin": {"cell_type_label": False}}

min_total_counts = 5    # min RAW SG counts per cell to include the cell
min_nnz_genes = 20      # min number of genes detected in the cell's SG profile
min_nonzero_cells = 50  # min non-zero paired differences for a gene to be tested
top_k = 5               # genes labelled on each side of the volcano plot
n_heatmap_genes = 10    # genes per side (SG-enriched / nuclear-retained) in the heatmap
os.makedirs("SG_vs_nuclei", exist_ok=True)

for data in settings.keys():

    print(f"========== Processing {data}... ==========")

    # paths
    data_dir = f"../../data/{data}/"
    utils_dir = "../../data/_utils/"
    output_dir = "SG_vs_nuclei/"

    # read data
    adata = sc.read_h5ad(data_dir + "intermediate_data/adata.h5ad")
    adata_tumor = adata[adata.obs["cell_type_merged"] == "Malignant cell"].copy()
    granule_adata = sc.read_h5ad(data_dir + "processed_data/granule_adata.h5ad")

    # load nucleus and SG expression matrices (shape-checked against adata_tumor)
    adata_nuc = npz_to_adata(data_dir + "processed_data/nuclear_expression_matrix.npz", adata_tumor)
    adata_sg = npz_to_adata(data_dir + "processed_data/SG_expression_matrix.npz", adata_tumor)

    # SG library size and detected-gene count on RAW counts. This must happen
    # before normalize_total/log1p below rewrite adata_sg.X; afterwards the row
    # sums are sums of log-normalized values, not counts.
    sg_raw = adata_sg.X.tocsr()
    sg_total_counts = np.asarray(sg_raw.sum(axis=1)).ravel()
    sg_nnz_genes = np.asarray((sg_raw != 0).sum(axis=1)).ravel()  # ignores stored zeros
    sg_libsize = sg_total_counts

    adata_tumor.obs["SG_present"] = sg_libsize > 0
    adata_tumor.obs["SG_libsize"] = sg_libsize

    adata_sg.obs["SG_total_counts"] = sg_total_counts
    adata_sg.obs["SG_nnz_genes"] = sg_nnz_genes
    adata_sg.obs["SG_present"] = sg_total_counts > 0

    # normalize nucleus and SG data (library-size normalization + log1p)
    sc.pp.normalize_total(adata_nuc, target_sum=1e4)
    sc.pp.log1p(adata_nuc)

    sc.pp.normalize_total(adata_sg, target_sum=1e4)
    sc.pp.log1p(adata_sg)

    Xn = adata_nuc.X.tocsr()
    Xs = adata_sg.X.tocsr()

    # select cells with enough SG signal for a meaningful comparison
    keep = (sg_total_counts >= min_total_counts) & (sg_nnz_genes >= min_nnz_genes)
    keep_idx = np.flatnonzero(keep)

    print(f"Keep {keep_idx.size}/{Xs.shape[0]} cells ({keep_idx.size/Xs.shape[0]*100:.2f}%) with SG_total_counts>={min_total_counts} and SG_nnz_genes>={min_nnz_genes}")

    if keep_idx.size == 0:
        # nothing to compare; downstream medians, tests and heatmap would fail
        print(f"No cells pass the SG filters for {data}; skipping.")
        continue

    # cell-wise Pearson & Spearman correlation
    pearson_corr = rowwise_pearson_sparse(Xn, Xs, keep)
    adata_sg.obs["corr_nuc_vs_sg_pearson"] = pearson_corr
    print(f"Pearson corr (selected cells): {np.nanmedian(pearson_corr[keep_idx])}, n = {np.sum(keep)}")

    spearman_corr = rowwise_spearman_sparse(Xn, Xs, keep)
    adata_sg.obs["corr_nuc_vs_sg_spearman"] = spearman_corr
    print(f"Spearman corr (selected cells): {np.nanmedian(spearman_corr[keep_idx])}, n = {np.sum(keep)}")

    # cell-wise Jaccard similarity of detected-gene sets.
    # Drop stored zeros FIRST, then set the remaining entries to 1. The reverse
    # order would turn stored zeros into 1s that eliminate_zeros can no longer
    # remove.
    Bn = Xn.copy()
    Bn.eliminate_zeros()
    Bn.data = np.ones_like(Bn.data)

    Bs = Xs.copy()
    Bs.eliminate_zeros()
    Bs.data = np.ones_like(Bs.data)

    jacc = rowwise_jaccard_csr(Bn, Bs, keep)
    adata_sg.obs["jaccard_nuc_vs_sg"] = jacc
    print(f"Median Jaccard (selected cells): {np.nanmedian(jacc[keep_idx])}, n = {np.sum(keep)}")

    # similarity distributions and similarity vs SG abundance (raw counts)
    for metric, label in zip(["corr_nuc_vs_sg_pearson", "corr_nuc_vs_sg_spearman", "jaccard_nuc_vs_sg"], ["Pearson correlation", "Spearman correlation", "Jaccard similarity"]):

        df_plot = adata_sg.obs.loc[keep, [metric, "SG_total_counts"]].copy()
        df_plot = df_plot.dropna(subset=[metric, "SG_total_counts"])  # same rows dropped for both
        vals = df_plot[metric].to_numpy()
        libs = df_plot["SG_total_counts"].to_numpy()

        plt.figure(figsize=(6,4))
        sns.histplot(vals, bins=50, kde=False, edgecolor="gray")
        plt.xlabel(" ")
        plt.ylabel("Number of SG+ tumor cells")
        plt.title(f"Distribution of {label}")
        plt.savefig(output_dir + f"{data}_{metric}_distribution.png", dpi=300, bbox_inches="tight")
        plt.close()

        plt.figure(figsize=(6,4))
        plt.scatter(libs, vals, s=4, edgecolor="gray", linewidths=0.25)
        plt.xlabel("SG total counts per cell")
        plt.ylabel(label)
        plt.title(f"{label} vs SG abundance")
        plt.savefig(output_dir + f"{data}_{metric}_vs_SG_abundance.png", dpi=300, bbox_inches="tight")
        plt.close()

    # cell-wise contrasts: per-cell, per-gene SG - nucleus difference (log1p scale)
    Delta = (Xs[keep_idx, :] - Xn[keep_idx, :]).tocsr()

    gene_names = adata_nuc.var_names.to_numpy()
    delta_mean = np.asarray(Delta.mean(axis=0)).ravel()
    df_delta = pd.DataFrame({"gene": gene_names, "delta_mean": delta_mean})
    df_delta_sorted = df_delta.sort_values("delta_mean", ascending=False)

    print("Top SG-enriched genes:")
    print(df_delta_sorted.head(5))
    print("\nTop nuclear-retained genes:")
    print(df_delta_sorted.tail(5))

    # volcano plot: paired Wilcoxon signed-rank test of per-cell differences vs 0
    Delta_csc = Delta.tocsc()  # column access per gene

    pvals = np.ones(Delta_csc.shape[1], dtype=float)
    nnz_by_gene = np.diff(Delta_csc.indptr)

    for j in range(Delta_csc.shape[1]):

        # skip genes with too few nonzero cells (p stays 1)
        if nnz_by_gene[j] < min_nonzero_cells:
            continue

        d = Delta_csc.getcol(j).toarray().ravel()
        if np.allclose(d, 0):
            pvals[j] = 1.0
        else:
            try:
                _, p = wilcoxon(d, alternative="two-sided", zero_method="wilcox")
                pvals[j] = p if np.isfinite(p) else 1.0
            except ValueError:
                pvals[j] = 1.0

    res = pd.DataFrame({"gene": gene_names, "delta_mean": delta_mean, "pval": pvals, "nnz_cells": nnz_by_gene})
    res["neglog10p"] = -np.log10(np.maximum(res["pval"], 1e-300))  # clip to avoid -log10(0)

    top_pos = res.sort_values("delta_mean", ascending=False).head(top_k)
    top_neg = res.sort_values("delta_mean", ascending=True).head(top_k)

    plt.figure(figsize=(6,5))
    plt.scatter(res["delta_mean"], res["neglog10p"], s=4, edgecolors="gray", linewidths=0.25)
    plt.axvline(0, linestyle="--", color="red", linewidth=0.8)
    # label both the most SG-enriched and the most nuclear-retained genes
    for _, r in pd.concat([top_pos, top_neg]).iterrows():
        plt.text(r["delta_mean"], r["neglog10p"], r["gene"], fontsize=4)
    plt.xlabel("Mean SG–nucleus difference")
    plt.ylabel("-log10 p-value")
    plt.title(" ")
    plt.savefig(output_dir + f"{data}_volcano_plot.png", dpi=300, bbox_inches="tight")
    plt.close()

    # heatmap of per-cell differences for the top SG-enriched (first half of
    # cols) and top nuclear-retained (second half) genes
    top_genes = pd.concat([res.sort_values("delta_mean", ascending=False).head(n_heatmap_genes), res.sort_values("delta_mean", ascending=True).head(n_heatmap_genes)])["gene"].tolist()
    gene_to_idx = {g: i for i, g in enumerate(gene_names)}
    cols = [gene_to_idx[g] for g in top_genes if g in gene_to_idx]
    genes = [gene_names[c] for c in cols]  # column labels, aligned with cols

    M = Delta[:, cols].toarray()

    # order cells by (enriched-gene mean - retained-gene mean), descending
    k = len(cols)//2
    score = M[:, :k].mean(axis=1) - M[:, k:].mean(axis=1)
    row_order = np.argsort(score)[::-1]
    M = M[row_order, :]

    # order genes by mean difference, descending; keep labels in sync
    col_order = np.argsort(M.mean(axis=0))[::-1]
    M = M[:, col_order]
    genes = [genes[j] for j in col_order]

    # symmetric colour scale clipped at the 98th percentile of |difference|
    vmax = np.quantile(np.abs(M), 0.98)
    if vmax == 0:
        vmax = 1.0  # all-zero matrix: avoid a degenerate vmin == vmax scale
    vmin = -vmax

    plt.figure(figsize=(12, 7))
    ax = sns.heatmap(M, cmap="vlag", vmin=vmin, vmax=vmax, xticklabels=genes, yticklabels=False, cbar_kws={"label": "SG–nucleus difference (log1p-normalized)"}, rasterized=True)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center", fontsize=7)
    ax.set_title(" ")
    plt.savefig(output_dir + f"{data}_heatmap.png", dpi=300, bbox_inches="tight")
    plt.close()

Keep 31907/102180 cells (31.23%) with SG_total_counts>=5 and SG_nnz_genes>=20
Pearson corr (selected cells): 0.09244156628847122, n = 31907
Spearman corr (selected cells): 0.08664211630821228, n = 31907
Median Jaccard (selected cells): 0.0357142873108387, n = 31907
Top SG-enriched genes:
        gene  delta_mean
473     CA12    0.775163
2177   IL6ST    0.447454
2808    MYH9    0.421381
888     CLTC    0.385632
3569  PTP4A1    0.377083

Top nuclear-retained genes:
           gene  delta_mean
4141     SNHG15   -1.740062
1972    HNRNPH1   -1.751142
3130    PABPC1L   -1.790734
4186    SOX2-OT   -2.050587
2975  NOTCH2NLA   -2.438242
Keep 48919/160250 cells (30.53%) with SG_total_counts>=5 and SG_nnz_genes>=20
Pearson corr (selected cells): 0.11742059886455536, n = 48919
Spearman corr (selected cells): 0.1050107479095459, n = 48919
Median Jaccard (selected cells): 0.0451977401971817, n = 48919
Top SG-enriched genes:
        gene  delta_mean
2808    MYH9    0.040996
3144   PALLD    0.002623
2