In [6]:
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import gseapy as gp
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
from matplotlib.patches import Patch
from scipy import sparse
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

In [7]:
# Color
# Custom six-stop sequential colormap (near-black -> purples -> orange -> pale yellow),
# interpolated to 256 levels. It is used below as the `color_map` for the spatial
# pathway-score scatter plots.
color_cts = clr.LinearSegmentedColormap.from_list(
    name="magma",
    colors=[
        "#000003",
        "#3B0F6F",
        "#8C2980",
        "#F66E5B",
        "#FD9F6C",
        "#FBFCBF",
    ],
    N=256,
)

In [8]:
# ==================== ssGSEA functions ==================== #

# Read GMT file into dict: {pathway: [genes]}
def read_gmt(gmt_path: str) -> dict:
    """Parse a GMT gene-set file into ``{gene_set_name: [gene, ...]}``.

    Every non-blank line is split on tabs. Field 0 is the gene-set name,
    field 1 is skipped, and all remaining non-empty fields are taken as the
    genes of that set. Lines with fewer than three tab-separated fields are
    ignored. When the same name appears on several lines, the last one wins.

    Parameters
    ----------
    gmt_path : str
        Path to the tab-separated GMT file (read as UTF-8).

    Returns
    -------
    dict
        Mapping from gene-set name to its list of gene identifiers, in file
        order. Names and genes have surrounding whitespace removed.
    """
    gene_sets = {}
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(gmt_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            parts = line.rstrip("\r\n").split("\t")
            if len(parts) < 3:
                continue
            gs_name = parts[0].strip()
            # Strip padding so e.g. "TP53 " still matches "TP53" downstream, and
            # drop empty / whitespace-only fields produced by trailing tabs.
            genes = [g.strip() for g in parts[2:] if g.strip()]
            gene_sets[gs_name] = genes
    return gene_sets

# Convert gseapy ssGSEA res (res.res2d, long format) to scores matrix (sample by pathway)
def res2d_to_scores(res, score_col = "NES"):
    """Pivot a long-format ``res.res2d`` table into a wide score matrix.

    Parameters
    ----------
    res
        Object exposing a ``res2d`` DataFrame with one row per
        (name, term) pair.
    score_col : str
        Column holding the value to spread; matched case-insensitively.

    Returns
    -------
    pandas.DataFrame
        Rows indexed by the "name" column (index renamed to ``cell_id``),
        one column per distinct value of the "term" column.

    Raises
    ------
    ValueError
        If no column matches ``score_col``.
    """
    long_df = res.res2d.copy()

    # Case-insensitive lookup of the row / column keys; fall back to the
    # capitalised spelling when the table has no such column.
    by_lower = {}
    for col in long_df.columns:
        by_lower[col.lower()] = col
    sample_col = by_lower.get("name", "Name")
    pathway_col = by_lower.get("term", "Term")

    # First column whose upper-cased name equals the requested one.
    value_col = next(
        (col for col in long_df.columns if col.upper() == score_col.upper()),
        None,
    )
    if value_col is None:
        raise ValueError(f"Score column {score_col} not found. Available: {list(long_df.columns)}")

    wide = long_df.pivot(index=sample_col, columns=pathway_col, values=value_col)
    wide.index.name = "cell_id"
    return wide

# ssGSEA from cell by gene matrix (npz format)
def ssGSEA_from_cellxgene_npz_filtered(npz_path: str, cell_ids: list, gene_ids: list, gmt_path: str, out_path: str, chunk_size: int = 2000, min_geneset_size: int = 5, max_geneset_size: int = 5000, do_log1p: bool = True, do_cpm: bool = True, min_total_counts: int = 5, min_nnz_genes: int = 20):
    """Run gseapy ssGSEA on the rows of a sparse cell-by-gene matrix that pass a count filter.

    The matrix is loaded from ``npz_path``; cells with fewer than
    ``min_total_counts`` total counts or fewer than ``min_nnz_genes`` non-zero
    genes are skipped and keep a score of 0 for every pathway. The remaining
    cells are processed in chunks of ``chunk_size`` and their NES values are
    written into a full (all cells x all pathways) float32 table, which is
    saved to ``out_path`` as parquet and returned.

    Parameters
    ----------
    npz_path : str
        Path to a ``scipy.sparse`` ``.npz`` file holding the cell-by-gene matrix.
    cell_ids, gene_ids : sequence
        Row / column labels of the matrix; their lengths must match its shape.
    gmt_path : str
        GMT file with the gene sets (parsed with ``read_gmt``).
    out_path : str
        Destination parquet file for the score table.
    chunk_size : int
        Number of kept cells passed to gseapy per call; must be positive.
    min_geneset_size, max_geneset_size : int
        Forwarded to ``gp.ssgsea`` as ``min_size`` / ``max_size``.
    do_log1p, do_cpm : bool
        Apply per-cell scaling to 1e6 total counts and/or ``log1p`` to the
        non-zero entries before scoring.
    min_total_counts, min_nnz_genes : int
        Per-cell thresholds for being scored at all.

    Returns
    -------
    pandas.DataFrame
        float32 table indexed by ``cell_id`` (all cells, input order) with one
        column per gene set in the GMT file. Pathways missing from a gseapy
        result are NaN for the cells of that chunk.

    Raises
    ------
    ValueError
        If the matrix shape disagrees with the id lists or ``chunk_size`` is
        not positive.

    NOTE(review): gseapy appears to derive ssGSEA NES by scaling ES with the ES
    range of the samples supplied in a single call. If that is the case, NES is
    normalised per chunk here and not strictly comparable across chunks -
    confirm against the gseapy documentation.
    """
    # A negative step would make the chunk loop a silent no-op (all-zero output).
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # load cell by gene matrix
    X = sparse.load_npz(npz_path).tocsr()
    if X.shape != (len(cell_ids), len(gene_ids)):
        raise ValueError(f"Shape mismatch: X {X.shape} vs {(len(cell_ids), len(gene_ids))}")
    # Drop explicitly stored zeros so the indptr-based count below really is the
    # number of non-zero genes per cell.
    X.eliminate_zeros()
    n_cells_total = X.shape[0]

    # parse GMT into dict: {pathway: [genes]}
    gene_sets = read_gmt(gmt_path)
    pathway_names = list(gene_sets.keys())

    # ---------
    # 1) filter cells (SG-positive + enough genes)
    # ---------
    total_counts = np.asarray(X.sum(axis=1)).ravel()
    nnz_genes = np.diff(X.indptr)  # number of nonzero genes per row in CSR

    keep = (total_counts >= min_total_counts) & (nnz_genes >= min_nnz_genes)
    keep_idx = np.where(keep)[0]

    # Guard the percentage against an empty matrix (0 rows).
    pct_kept = keep_idx.size / n_cells_total * 100 if n_cells_total else 0.0
    print(f"ssGSEA filtering: keeping {keep_idx.size} / {n_cells_total} cells ({pct_kept:.2f}%)")

    # pre-allocate full scores (all cells) as zeros
    scores_full = pd.DataFrame(
        0.0,
        index=pd.Index(cell_ids, name="cell_id"),
        columns=pathway_names,
        dtype=np.float32,
    )

    # if nothing passes filtering, just write zeros and return
    if keep_idx.size == 0:
        scores_full.to_parquet(out_path)
        return scores_full

    # ---------
    # 2) run ssGSEA on kept cells only (chunk over keep_idx)
    # ---------
    for start in range(0, keep_idx.size, chunk_size):
        idx = keep_idx[start:start + chunk_size]
        chunk_ids = [cell_ids[i] for i in idx]

        # Row selection + astype yields a private matrix, so it can be edited in place.
        Xb = X[idx, :].astype(np.float32)

        # optional: CPM + log1p to reduce ties (many zeros) and depth effects
        if do_cpm:
            libsize = np.asarray(Xb.sum(axis=1)).ravel()
            libsize[libsize == 0] = 1.0  # avoid division by zero for empty rows
            Xb = Xb.multiply(1e6 / libsize[:, None])
        if do_log1p:
            # log1p(0) == 0, so transforming only the stored entries is sufficient
            Xb.data = np.log1p(Xb.data)

        # gseapy wants genes by samples (DataFrame)
        expr = pd.DataFrame(
            Xb.toarray().T,
            index=gene_ids,
            columns=chunk_ids,
        )

        res = gp.ssgsea(
            data=expr,
            gene_sets=gene_sets,
            sample_norm_method="rank",
            min_size=min_geneset_size,
            max_size=max_geneset_size,
            outdir=None,
            verbose=False,
            processes=1,
        )

        # sample by pathway, aligned to this chunk's cells and the GMT pathway order;
        # pathways absent from the result (presumably those removed by gseapy's
        # size filter - verify) become NaN columns here
        scores = res2d_to_scores(res, score_col="NES")
        scores = scores.reindex(index=chunk_ids, columns=pathway_names)

        # write back into the full matrix (others stay 0)
        scores_full.loc[scores.index, scores.columns] = scores.astype(np.float32)

    scores_full.to_parquet(out_path)
    return scores_full

In [9]:
# ==================== Main operations ==================== #
# For each dataset: load the AnnData objects, aggregate granule-level counts to
# per-tumor-cell counts, score pathways with ssGSEA, test each pathway's NES
# against zero, and save box plots plus spatial maps of the top pathways.

# Per-dataset flags. NOTE(review): `settings` is only referenced by the
# commented-out loop header below and `cell_type_label` is never read in this cell.
settings = {"Xenium_5K_BC": {"cell_type_label": True},
            "Xenium_5K_OC": {"cell_type_label": True},
            "Xenium_5K_CC": {"cell_type_label": True},
            "Xenium_5K_LC": {"cell_type_label": False},
            "Xenium_5K_Prostate": {"cell_type_label": False},
            "Xenium_5K_Skin": {"cell_type_label": False}}

# for data in settings.keys():
# Currently restricted to a single dataset.
for data in ["Xenium_5K_BC"]:
    
    print(f"========== Processing {data}... ==========")
    
    # paths (relative to the notebook's working directory)
    data_dir = f"../../data/{data}/"
    utils_dir = "../../data/_utils/"
    output_dir = f"../../output/{data}/"
    
    # read data; keep only cells annotated as "Malignant cell" for the analysis
    adata = sc.read_h5ad(data_dir + "intermediate_data/adata.h5ad")
    adata_tumor = adata[adata.obs["cell_type_merged"] == "Malignant cell"].copy()
    granule_adata = sc.read_h5ad(data_dir + "processed_data/granule_adata.h5ad")
    
    # determine plot size: scale the tissue extent so the shorter side maps to 5
    # and truncate both sides to integers
    x_range = adata.obs["global_x"].max() - adata.obs["global_x"].min()
    y_range = adata.obs["global_y"].max() - adata.obs["global_y"].min()
    short_edge = min(x_range, y_range)

    scale = 5 / short_edge
    plot_figsize = (int(x_range * scale), int(y_range * scale))
    print(f"Plot size: {plot_figsize}")
    
    # granule expression information (summary statistics only; the matrix is
    # densified here, which can be memory-heavy for large granule tables)
    X = granule_adata.X
    if not isinstance(X, np.ndarray):
        X = X.toarray()

    molecules_per_granule = X.sum(axis=1)
    median_molecules = np.median(molecules_per_granule)

    # number of granules in which each gene has a non-zero count
    gene_detection_counts = (X > 0).sum(axis=0)
    total_granules = X.shape[0]
    threshold = 0.05 * total_granules
    genes_above_threshold = np.sum(gene_detection_counts > threshold)

    print(f"Median mRNA molecules per granule: {median_molecules}")
    print(f"Genes detected in at least one granule: {np.count_nonzero(gene_detection_counts)}")
    print(f"Genes detected in more than 5% of granules: {genes_above_threshold}")
    
    # check cell and gene IDs: the saved ID arrays must list the tumor cells and
    # genes in exactly the same order as the current adata_tumor
    cell_ids = list(adata_tumor.obs["cell_id"])
    gene_ids = list(adata_tumor.var.index)
    
    cell_ids_npz = np.load(data_dir + "processed_data/cell_ids.npy", allow_pickle = True).tolist()
    gene_ids_npz = np.load(data_dir + "processed_data/gene_ids.npy", allow_pickle = True).tolist()
    
    if cell_ids_npz != cell_ids:
        raise ValueError("Cell ID order mismatch between NPZ and current adata_tumor!")

    if gene_ids_npz != gene_ids:
        raise ValueError("Gene order mismatch between NPZ and current adata_tumor!")
    
    # prepare data
    # NOTE: `cell_ids` is re-bound here from a list to a string ndarray.
    cell_ids = adata_tumor.obs["cell_id"].astype(str).to_numpy()
    granule_cell_ids = granule_adata.obs["cell_id"].astype(str).to_numpy()
    
    # map every granule to the row of its cell in adata_tumor; granules whose
    # cell is not among the tumor cells get -1 and are dropped
    cell2row = {cid: i for i, cid in enumerate(cell_ids)}
    rows = np.fromiter((cell2row.get(cid, -1) for cid in granule_cell_ids),
                       dtype=np.int64, count=len(granule_cell_ids))
    keep = rows >= 0
    rows = rows[keep]
    
    # granule-by-gene counts as CSR, restricted to the retained granules
    Xg = granule_adata.X
    if not sp.isspmatrix(Xg):
        Xg = sp.csr_matrix(Xg)
    else:
        Xg = Xg.tocsr()
    Xg = Xg[keep, :]

    n_cells = adata_tumor.n_obs
    n_granules = Xg.shape[0]

    # cell-by-granule indicator matrix: A[c, g] = 1 when granule g belongs to cell c
    A = sp.csr_matrix(
        (np.ones(n_granules, dtype=np.float32),
        (rows, np.arange(n_granules, dtype=np.int64))),
        shape=(n_cells, n_granules),
    )
    
    # A @ Xg sums the granule counts per cell -> cell-by-gene SG expression matrix;
    # it is written to disk because the ssGSEA helper reloads it from `npz_path`
    X_sg = (A @ Xg).tocsr()
    sparse.save_npz(data_dir + "processed_data/SG_expression_matrix.npz", X_sg)
    print(f"Shape of the SG expression matrix: {X_sg.shape}")
    
    # run ssGSEA on SG expression
    # NOTE(review): the output file name says "hallmark" although the active GMT
    # is "all_pathways_filtered.gmt" - confirm the intended name.
    # gmt_path = utils_dir + "hallmark_pathways_filtered.gmt"
    gmt_path = utils_dir + "all_pathways_filtered.gmt"
    scores_all = ssGSEA_from_cellxgene_npz_filtered(
            npz_path = data_dir + "processed_data/SG_expression_matrix.npz",
            cell_ids = cell_ids,
            gene_ids = gene_ids,
            gmt_path = gmt_path,
            out_path = data_dir + "processed_data/ssgsea_hallmark_sg.parquet",
        )
    # cells with at least one non-zero score (cells skipped by the ssGSEA filter are all-zero)
    scores_nz = scores_all.loc[(scores_all != 0).any(axis = 1)].copy()
    
    # global order based on SG-positive cells: pathways sorted by descending median NES;
    # tick labels drop the first underscore-separated token and capitalise the remaining words
    scores_long_tmp = scores_nz.reset_index().melt(id_vars="cell_id", var_name="Pathway", value_name="NES")
    global_order = scores_long_tmp.groupby("Pathway")["NES"].median().sort_values(ascending=False).index
    global_order_labels = [" ".join(s.capitalize() for s in i.split("_")[1:]) for i in global_order]
    
    # same analysis twice: once on all tumor cells, once on the non-zero subset
    for scores, label in zip([scores_all, scores_nz], ["all_cells", "SG_positive"]):
    
        # long format
        scores_long = scores.reset_index().melt(id_vars="cell_id", var_name="Pathway", value_name="NES")
        
        # statistical tests: per pathway, two-sided Wilcoxon signed-rank test of NES against 0
        # (zero_method="wilcox" discards zero values); an all-zero pathway is assigned p = 1
        # without testing. `stat` is assigned but not used afterwards.
        stats = []
        for pathway, df in scores_long.groupby("Pathway"):
            nes = df["NES"].dropna().to_numpy(dtype=float)
            if np.allclose(nes, 0):
                pval = 1.0
                stat = 0.0
            else:
                stat, pval = wilcoxon(nes, alternative="two-sided", zero_method="wilcox")
            stats.append({"Pathway": pathway, "median": np.median(nes), "pval": pval})
        stats_df = pd.DataFrame(stats)
        # Benjamini-Hochberg FDR correction across pathways
        stats_df["qval"] = multipletests(stats_df["pval"], method="fdr_bh")[1]
        
        # determine significance: q < alpha, with the direction taken from the sign of the median
        # (a median of exactly 0 stays "nonsignificant")
        alpha = 0.05
        stats_df["significance"] = "nonsignificant"
        stats_df.loc[(stats_df["qval"] < alpha) & (stats_df["median"] > 0), "significance"] = "positive"
        stats_df.loc[(stats_df["qval"] < alpha) & (stats_df["median"] < 0), "significance"] = "negative"
        
        # per-pathway box colour: red = positive, blue = negative, grey = not significant
        scores_long = scores_long.merge(stats_df[["Pathway", "significance"]], on="Pathway", how="left")
        palette = {row.Pathway: "#d73027" if row["significance"] == "positive" else "#4575b4" if row["significance"] == "negative" else "lightgray" for _, row in stats_df.iterrows()}
        legend_handles = [Patch(facecolor="#d73027", edgecolor="black", label="Positive median NES"),
                            Patch(facecolor="#4575b4", edgecolor="black", label="Negative median NES"),
                            Patch(facecolor="lightgray", edgecolor="black", label="Not significant")]

        # boxplot of all pathways, in the shared `global_order`, saved per label
        plt.figure(figsize=(25, 6))
        ax = sns.boxplot(data=scores_long, x="Pathway", y="NES", order=global_order, showfliers=False, palette=palette)
        ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
        ax.set_xticklabels(global_order_labels, rotation=45, ha="right", fontsize=8)
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.legend(handles=legend_handles, loc="upper right", frameon=True, fontsize=12)
        plt.savefig(output_dir + f"ssgsea_hallmark_sg_{label}.jpeg", dpi = 300, bbox_inches = "tight")
        plt.close()
    
    # plot top pathways: the n_top pathways with the highest mean NES over all tumor cells
    n_top = 5

    scores_mean = scores_all.mean(axis = 0).sort_values(ascending = False)
    top_scores = scores_mean.head(n_top)
    
    for pathway in top_scores.index:
        
        # add pathway to adata_tumor.obs
        # (positional assignment via .values: relies on scores_all rows being in adata_tumor order)
        pathway_label = f"{pathway}"
        adata_tumor.obs[pathway_label] = scores_all[pathway].values
        
        # plot pathway score as a bare spatial scatter (no grid, ticks, labels, title or spines)
        sc.set_figure_params(figsize = plot_figsize)
        ax = sc.pl.scatter(adata_tumor, x="global_x", y="global_y", color=pathway_label, color_map=color_cts, size=1, show=False)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_title("")
        for spine in ax.spines.values():
            spine.set_visible(False)
        plt.savefig(output_dir + f"sg_{pathway_label}.jpeg", dpi = 300, bbox_inches = "tight")
        plt.close()

Plot size: (5, 7)
Median mRNA molecules per granule: 8.0
Genes detected in at least one granule: 4985
Genes detected in more than 5% of granules: 14
Shape of the SG expression matrix: (102180, 5001)
ssGSEA filtering: keeping 31907 / 102180 cells (31.23%)


In [10]:
# NES table for cells with at least one non-zero score. Referenced by its own
# name rather than through the inner-loop variable `scores`, which only held
# this frame as a leftover of the last loop iteration.
scores_nz

Unnamed: 0_level_0,GOBP_POSITIVE_REGULATION_OF_RNA_METABOLIC_PROCESS,GOBP_CELLULAR_RESPONSE_TO_STRESS,GOBP_REGULATION_OF_RESPONSE_TO_STRESS,GOBP_REGULATION_OF_RESPONSE_TO_EXTERNAL_STIMULUS,REACTOME_RNA_POLYMERASE_II_TRANSCRIPTION,GOBP_NEGATIVE_REGULATION_OF_TRANSCRIPTION_BY_RNA_POLYMERASE_II,REACTOME_POST_TRANSLATIONAL_PROTEIN_MODIFICATION,GOBP_REGULATION_OF_CELLULAR_RESPONSE_TO_STRESS,REACTOME_SIGNALING_BY_INTERLEUKINS,GOBP_POSITIVE_REGULATION_OF_RESPONSE_TO_EXTERNAL_STIMULUS,...,REACTOME_SENSORY_PERCEPTION,REACTOME_PROCESSING_OF_CAPPED_INTRON_CONTAINING_PRE_MRNA,REACTOME_SIGNALING_BY_TGFB_FAMILY_MEMBERS,GOBP_RIBONUCLEOPROTEIN_COMPLEX_SUBUNIT_ORGANIZATION,HALLMARK_SPERMATOGENESIS,REACTOME_HOST_INTERACTIONS_OF_HIV_FACTORS,REACTOME_POTENTIAL_THERAPEUTICS_FOR_SARS,GOBP_STRESS_FIBER_ASSEMBLY,GOBP_REGULATION_OF_MIRNA_TRANSCRIPTION,REACTOME_ION_CHANNEL_TRANSPORT
cell_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
aaaabkoj-1,0.076168,0.034577,-0.019903,-0.008922,0.058823,0.117428,0.061361,0.008573,0.005351,0.012420,...,0.057843,0.220716,0.019160,0.084002,-0.232082,-0.090494,0.034057,0.253198,0.243978,0.092077
aaababaf-1,0.076198,0.035394,-0.020473,-0.014264,0.046729,0.107377,0.055947,0.015324,-0.012759,0.008215,...,0.011568,0.223515,0.021213,0.068939,-0.230464,-0.088519,0.036281,0.279331,0.214983,0.093861
aaabbcid-1,0.074988,0.038123,-0.014413,-0.006463,0.055188,0.114568,0.067417,0.010903,-0.004805,0.009703,...,0.042601,0.209792,0.047430,0.111418,-0.234812,-0.093291,0.030439,0.270201,0.250071,0.087251
aaacljgp-1,0.080699,0.025355,-0.018407,-0.006115,0.060528,0.109995,0.063628,0.007933,-0.016598,0.010614,...,0.008560,0.218505,0.061986,0.103470,-0.233094,-0.090693,0.048037,0.330920,0.243532,0.091576
aaaecblc-1,0.075674,0.032483,-0.018940,-0.005460,0.045224,0.108730,0.058015,0.003729,-0.015984,0.016287,...,0.009432,0.205313,0.019014,0.050891,-0.232306,-0.090279,0.034148,0.253317,0.243995,0.092007
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
nkkmbikp-1,0.072945,0.029932,-0.022058,-0.010661,0.054944,0.109733,0.063390,0.003032,-0.010771,0.010483,...,0.009889,0.195693,0.019771,0.051938,-0.234173,-0.047134,0.106214,0.256361,0.246907,0.093742
nkknfabc-1,0.072817,0.035284,-0.017397,-0.010275,0.050662,0.103846,0.057392,0.016986,-0.002369,0.009418,...,0.030125,0.194435,0.057990,0.143484,-0.234826,-0.075387,0.033672,0.254825,0.228732,0.092251
nkknicmk-1,0.071515,0.035386,-0.008924,0.003858,0.061200,0.105385,0.069413,0.019454,-0.020218,0.011865,...,0.003342,0.236517,0.012890,0.189888,-0.239239,-0.080880,0.027894,0.279561,0.217477,0.085571
nmnlccmb-1,0.073242,0.033070,-0.022889,-0.017442,0.056373,0.119537,0.068208,0.018356,-0.019678,0.009059,...,0.035773,0.262257,0.014781,0.059554,-0.237958,-0.052378,0.068396,0.276582,0.208958,0.133463
