In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import umap
import pandas as pd
import seaborn as sns
import numpy as np

from data_loader import load_xenium_breast_cancer

from prismo.prismo import (
    PRISMO,
    DataOptions,
    ModelOptions,
    TrainingOptions,
    SmoothOptions,
)
from prismo.gpu import get_free_gpu_idx
from prismo.io import save_model, load_model
from prismo.downstream import match
from prismo.plotting import (
    plot_training_curve,
    plot_variance_explained,
    plot_factor_correlation,
    plot_factors_scatter,
)

In [None]:
# Load the dataset and drop the Visium group; the cells below only use the
# "group_xenium" and "group_chromium" entries.
data = load_xenium_breast_cancer()
del data["group_visium"]
# Display the remaining groups.
data

In [None]:
# Per-feature standard deviation of the Chromium RNA matrix.
feature_stds = data["group_chromium"]["rna"].to_df().std()
# Fraction of features whose standard deviation is non-zero (cell output).
(feature_stds > 0).mean()

In [None]:
# Keep only the 4000 Chromium features with the largest standard deviation.
# The columns end up ordered by decreasing standard deviation, and a copy of the
# subset replaces the original object.
# NOTE(review): this relies on `feature_stds` being indexed by the current
# var_names, so it must run before the var_names are replaced by symbols below.
data["group_chromium"]["rna"] = data["group_chromium"]["rna"][
    :, feature_stds.sort_values(ascending=False).iloc[:4000].index
].copy()

In [None]:
# joint_df = pd.concat([data["group_xenium"]["rna"].to_df(), data["group_chromium"]["rna"].to_df()], axis=0, join="outer")
# joint_df

In [None]:
# Use gene symbols as feature names for the Chromium RNA data.
# `.to_numpy()` drops the Series name so the new index is not itself called
# "symbol" (which would collide with the `symbol` column in `.var`).
data["group_chromium"]["rna"].var_names = (
    data["group_chromium"]["rna"].var["symbol"].astype(str).to_numpy()
)
# Gene symbols are not guaranteed to be unique; duplicated feature names break
# label-based indexing and cross-group feature alignment, so disambiguate them.
# This is a no-op when all symbols are already unique.
data["group_chromium"]["rna"].var_names_make_unique()

In [None]:
# Use gene symbols as feature names for the Xenium RNA data.
# `.to_numpy()` drops the Series name so the new index is not itself called
# "symbol" (which would collide with the `symbol` column in `.var`).
data["group_xenium"]["rna"].var_names = (
    data["group_xenium"]["rna"].var["symbol"].astype(str).to_numpy()
)
# Gene symbols are not guaranteed to be unique; duplicated feature names break
# label-based indexing and cross-group feature alignment, so disambiguate them.
# This is a no-op when all symbols are already unique.
data["group_xenium"]["rna"].var_names_make_unique()

In [None]:
from prismo import feature_sets

In [None]:
def to_upper(feature_set_collection):
    """Build a new collection in which every feature name is upper-cased.

    Each member set keeps its own name and the resulting collection keeps the
    name of ``feature_set_collection``; only the feature identifiers change.
    """
    upper_cased_sets = []
    for feature_set in feature_set_collection:
        upper_features = [feature.upper() for feature in feature_set]
        upper_cased_sets.append(
            feature_sets.FeatureSet(upper_features, feature_set.name)
        )
    return feature_sets.FeatureSets(
        upper_cased_sets, name=feature_set_collection.name
    )

In [None]:
# Load the MSigDB gene-set collections (v7.5.1, gene symbols) from GMT files.
hallmark_collection = feature_sets.from_gmt(
    "../msigdb/h.all.v7.5.1.symbols.gmt", name="hallmark"
)


reactome_collection = feature_sets.from_gmt(
    "../msigdb/c2.cp.reactome.v7.5.1.symbols.gmt", name="reactome"
)

kegg_collection = feature_sets.from_gmt(
    "../msigdb/c2.cp.kegg.v7.5.1.symbols.gmt", name="kegg"
)

# Cell-type markers: tab-separated file without header, "#" lines are comments.
# Column 0 is the set name, column 1 a comma-separated list of features.
celltype_collection = pd.read_csv(
    "adipose_markers.txt", sep="\t", comment="#", header=None
)
celltype_collection[1] = celltype_collection[1].str.split(",")
celltype_collection = feature_sets.from_dataframe(
    celltype_collection, name="celltype", name_col=0, features_col=1
)

# NOTE(review): no merge happens here. Only the hallmark collection is carried
# forward; the reactome, kegg and celltype collections are loaded but not used
# in this cell.
gene_set_collection = hallmark_collection
# Upper-case the feature names of the gene sets.
# NOTE(review): the data var_names are not upper-cased anywhere in this
# notebook, so mixed-case symbols would presumably fail to match; confirm.
gene_set_collection = to_upper(gene_set_collection)
# sorted([fs.name for fs in gene_set_collection])
gene_set_collection

In [None]:
# keep = [
#     'HALLMARK_ANDROGEN_RESPONSE',
#     # 'HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION',
#     'HALLMARK_ESTROGEN_RESPONSE_LATE',
#     # 'HALLMARK_APICAL_SURFACE',
#     'HALLMARK_ESTROGEN_RESPONSE_EARLY',
#     # 'HALLMARK_ALLOGRAFT_REJECTION',
#     # 'HALLMARK_ANGIOGENESIS',
#     # 'HALLMARK_NOTCH_SIGNALING',
#     'HALLMARK_KRAS_SIGNALING_UP',
#     # 'HALLMARK_UV_RESPONSE_DN',
#     # 'HALLMARK_UV_RESPONSE_UP',
#     'HALLMARK_FATTY_ACID_METABOLISM',
#     'HALLMARK_P53_PATHWAY',
#     'HALLMARK_MTORC1_SIGNALING',
#     'HALLMARK_APOPTOSIS',
#     'HALLMARK_DNA_REPAIR',
#     # 'HALLMARK_MYOGENESIS',
#     # 'HALLMARK_UNFOLDED_PROTEIN_RESPONSE',
#     # 'HALLMARK_CHOLESTEROL_HOMEOSTASIS',
#     'HALLMARK_INTERFERON_GAMMA_RESPONSE',
#     'HALLMARK_IL2_STAT5_SIGNALING',
#     'HALLMARK_KRAS_SIGNALING_DN',
#     'HALLMARK_XENOBIOTIC_METABOLISM',
#     'HALLMARK_OXIDATIVE_PHOSPHORYLATION',
#     'HALLMARK_INTERFERON_ALPHA_RESPONSE',
#     'HALLMARK_HYPOXIA',
#     'neutrophil',
#     'HALLMARK_ADIPOGENESIS',
#     'HALLMARK_MYC_TARGETS_V1',
#     'HALLMARK_G2M_CHECKPOINT',
#     # 'HALLMARK_COMPLEMENT',
#     'nk_cell',
#     # 'HALLMARK_REACTIVE_OXYGEN_SPECIES_PATHWAY',
#     't_cell',
#     'HALLMARK_E2F_TARGETS',
#     # 'HALLMARK_BILE_ACID_METABOLISM',
#     # 'HALLMARK_PI3K_AKT_MTOR_SIGNALING',
#     'HALLMARK_COAGULATION',
#     # 'HALLMARK_PANCREAS_BETA_CELLS',
#     # 'HALLMARK_TNFA_SIGNALING_VIA_NFKB',
#     # 'HALLMARK_APICAL_JUNCTION',
#     # 'HALLMARK_HEDGEHOG_SIGNALING',
#     # 'HALLMARK_MITOTIC_SPINDLE',
#     'HALLMARK_MYC_TARGETS_V2',
#     # 'HALLMARK_WNT_BETA_CATENIN_SIGNALING',
#     'mast_cell',
#     # 'HALLMARK_SPERMATOGENESIS',
#     'b_cell',
#     # 'HALLMARK_TGF_BETA_SIGNALING',
#     # 'HALLMARK_PROTEIN_SECRETION',
#     # 'HALLMARK_PEROXISOME',
#     # 'HALLMARK_IL6_JAK_STAT3_SIGNALING',
#     # 'HALLMARK_INFLAMMATORY_RESPONSE',
#     # 'HALLMARK_GLYCOLYSIS',
#     # 'HALLMARK_HEME_METABOLISM',
#     'ASPC'
# ]

# gene_set_collection = gene_set_collection.keep(keep)

In [None]:
# Filter the gene sets against the Chromium feature names using the thresholds
# below. NOTE(review): the exact meaning of min_fraction / min_count / max_count
# is defined by `FeatureSets.filter`; confirm against the prismo docs.
gene_set_collection = gene_set_collection.filter(
    data["group_chromium"]["rna"].var_names,
    min_fraction=0.2,
    min_count=15,
    max_count=300,
)
# sorted([fs.name for fs in gene_set_collection])
gene_set_collection

In [None]:
# Display the `median_size` attribute of the filtered collection.
gene_set_collection.median_size

In [None]:
# Alphabetically sorted names of the gene sets that survived filtering.
sorted(gene_set.name for gene_set in gene_set_collection)

In [None]:
# Names of the gene sets whose name does not contain "HALLMARK".
[
    gene_set.name
    for gene_set in gene_set_collection
    if "HALLMARK" not in gene_set.name
]

In [None]:
# Sorted list of the distinct "celltype" labels in the Xenium observations.
sorted(data["group_xenium"]["rna"].obs["celltype"].unique())

In [None]:
# Build the gene-set mask over the Chromium feature names and store its
# transpose in varm under "gene_set_mask" (the key referenced by the model
# options). NOTE(review): presumably `to_mask` returns sets x features, hence
# the transpose to features x sets; confirm.
data["group_chromium"]["rna"].varm["gene_set_mask"] = gene_set_collection.to_mask(
    data["group_chromium"]["rna"].var_names.tolist()
).T

In [None]:
# Display the Chromium RNA object to check the subset and the new varm entry.
data["group_chromium"]["rna"]

In [None]:
# Prefer a free CUDA device; if looking one up fails for any reason, report the
# error and run on the CPU instead.
try:
    device = f"cuda:{get_free_gpu_idx()}"
except Exception as e:
    print(e)
    device = "cpu"
device

In [None]:
# Data-handling options passed to `PRISMO.fit` below.
data_opts = DataOptions(
    group_by=None,
    scale_per_group=True,
    covariates_obs_key=None,
    # Only the Xenium group supplies covariates (its "spatial" obsm entry);
    # the Chromium group has none.
    covariates_obsm_key={"group_xenium": "spatial", "group_chromium": None},
    # NOTE(review): "union" presumably keeps all observations / features found
    # in any group rather than only the shared ones; confirm in the prismo docs.
    use_obs="union",
    use_var="union",
    plot_data_overview=False,
)

In [None]:
# Model specification passed to `PRISMO.fit` below.
model_opts = ModelOptions(
    # NOTE(review): with annotations supplied, this is presumably the number of
    # additional unannotated factors; confirm.
    n_factors=3,
    weight_prior="Horseshoe",
    # Per-group factor prior: "GP" for Xenium (the group that was given spatial
    # covariates), "Normal" for Chromium.
    factor_prior={"group_xenium": "GP", "group_chromium": "Normal"},
    likelihoods="Normal",
    nonnegative_weights=True,
    nonnegative_factors=True,
    annotations=None,
    # Take the annotations for "rna" from the "gene_set_mask" varm entry that
    # was written for the Chromium RNA object earlier.
    annotations_varm_key={"rna": "gene_set_mask"},
    prior_penalty=0.001,
    init_factors="random",
    init_scale=0.1,
)

In [None]:
# Optimisation settings passed to `PRISMO.fit` below.
training_opts = TrainingOptions(
    # Device string chosen in the device-selection cell ("cuda:<idx>" or "cpu").
    device=device,
    batch_size=10000,
    max_epochs=200,
    n_particles=1,
    lr=0.003,
    early_stopper_patience=10,
    print_every=100,
    # The fitted model is not written to disk by the training call.
    save=False,
    save_path=None,
    # NOTE(review): no seed is set, so with random factor initialisation results
    # are presumably not reproducible across runs; consider fixing a seed.
    seed=None,
)

In [None]:
# Options for the smooth part of the model, passed to `PRISMO.fit` below.
# NOTE(review): presumably these only affect groups with a "GP" factor prior,
# and the empty `warp_groups` list disables warping so the other `warp_*`
# values are inert; confirm against the prismo docs.
smooth_opts = SmoothOptions(
    n_inducing=400,
    kernel="RBF",
    warp_groups=[],
    warp_interval=20,
    warp_open_begin=True,
    warp_open_end=True,
    warp_reference_group=None,
)

In [None]:
# Fit a fresh PRISMO model on the loaded groups with the options defined above.
model = PRISMO()
model.fit(data, data_opts, model_opts, training_opts, smooth_opts)

# Alternative left disabled: load a previously saved model instead of fitting.
# prismo_model = load_model("xenium_scrna_prismo_spatial_model")

In [None]:
# save_model(model, "prismo_model_xenium_chromium")

In [None]:
# Number of gene sets in the filtered collection.
len(gene_set_collection)

In [None]:
# Number of factors in the fitted model, for comparison with the gene-set count.
len(model.factor_names)

In [None]:
from prismo import plotting

In [None]:
# Fetch the factors in the "anndata" format and store them in the model's cache
# under "factors"; later cells read and annotate them from there.
# NOTE(review): this writes to a private attribute (`_cache`) of the model.
model._cache["factors"] = model.get_factors("anndata")
model._cache["factors"]

In [None]:
# Template used to generate the keys of the mapping below:
# celltype_map = {k: "" for k in model._cache["factors"]["group_xenium"].obs["celltype"].unique().tolist()}
# Collapse the fine-grained Xenium cell-type labels into coarser classes
# (e.g. both DCIS subtypes -> "DCIS", CD4+/CD8+ T cells -> "T Cells").
celltype_map = {
    "DCIS_2": "DCIS",
    "Macrophages_1": "Macrophages",
    "Invasive_Tumor": "Invasive Tumor",
    "Stromal": "Stromal",
    "CD4+_T_Cells": "T Cells",
    "Unlabeled": "UNL",
    "CD8+_T_Cells": "T Cells",
    "Prolif_Invasive_Tumor": "Invasive Tumor",
    "Endothelial": "Endothelial",
    "Macrophages_2": "Macrophages",
    "T_Cell_&_Tumor_Hybrid": "T Cells & Tumor",
    "Myoepi_ACTA2+": "Myoepithelial",
    "B_Cells": "B Cells",
    "LAMP3+_DCs": "Dendritic Cells",
    "DCIS_1": "DCIS",
    "Perivascular-Like": "Perivascular",
    "Stromal_&_T_Cell_Hybrid": "Stromal & T Cells",
    "Myoepi_KRT15+": "Myoepithelial",
    "IRF7+_DCs": "Dendritic Cells",
    "Mast_Cells": "Mast Cells",
}

# Store the coarse labels as a categorical "celltype_2" column on the Xenium
# factor object. Labels missing from `celltype_map` become NaN.
model._cache["factors"]["group_xenium"].obs["celltype_2"] = (
    model._cache["factors"]["group_xenium"]
    .obs["celltype"]
    .map(celltype_map)
    .astype("category")
)

In [None]:
# List the distinct coarse cell-type labels to check the mapping.
model._cache["factors"]["group_xenium"].obs["celltype_2"].unique().tolist()

In [None]:
# Rank the factors (the variables of the factor object) per coarse cell type
# with a Wilcoxon test. Only the classes listed in `groups` are tested; the
# "UNL", hybrid, "Perivascular" and "Mast Cells" classes are left out.
sc.tl.rank_genes_groups(
    model._cache["factors"]["group_xenium"],
    "celltype_2",
    method="wilcoxon",
    groups=[
        "DCIS",
        "Macrophages",
        "Invasive Tumor",
        "Stromal",
        "T Cells",
        "Endothelial",
        "Myoepithelial",
        "B Cells",
        "Dendritic Cells",
    ],
)

In [None]:
# sc.tl.rank_genes_groups(
#     prismo_model._cache["factors"]["group_chromium"], "celltype", method="wilcoxon"
# )

In [None]:
# Box plots of the top-ranked factor per cell-type group for the Xenium data.
# NOTE(review): presumably `groupplot_rank` draws into the figure created here
# and returns the plotted factor names first; confirm. Also note that binding
# `_` shadows IPython's last-output variable.
plt.figure(figsize=(24, 6))
relevant_factors, _ = plotting.groupplot_rank(
    model,
    group_idx="group_xenium",
    pl_type=plotting.BOXPLOT,
    top=1,
    gap=0.1,
    showfliers=False,
    rot=90,
)

In [None]:
# Xenium factor scores as a DataFrame.
z_df = model._cache["factors"]["group_xenium"].to_df()
z_df.head()

In [None]:
# Xenium covariates as a DataFrame indexed by sample name. The columns get
# default integer labels (0, 1, ...), which the scatter-grid helper uses as the
# x/y coordinates.
cov_df = pd.DataFrame(
    # Move the tensor to host memory and detach it before converting to NumPy.
    model.covariates["group_xenium"].cpu().detach().numpy(),
    index=model.sample_names["group_xenium"],
)
cov_df.head()

In [None]:
def plot_scatter_grid(data, relevant_cols, ncols=4, **kwargs):
    """Draw one scatter panel per entry of ``relevant_cols``, coloured by it.

    Every panel plots column ``0`` against column ``1`` of ``data`` and uses
    one column from ``relevant_cols`` as hue. Panels are laid out row-major on
    a grid with ``ncols`` columns, share both axes, and carry no axis ticks.
    The figure is shown with ``plt.show()``; nothing is returned.

    Args:
        data: DataFrame holding the coordinate columns ``0`` and ``1`` plus
            every column named in ``relevant_cols``.
        relevant_cols: Column names to colour by; one panel is drawn for each.
        ncols: Number of panels per row.
        **kwargs: Extra keyword arguments forwarded to ``sns.scatterplot``.
            They take precedence over the defaults set here (marker size,
            palette, line width, alpha, rasterisation, legend).

    Raises:
        ValueError: If ``relevant_cols`` is empty.
    """
    relevant_cols = list(relevant_cols)
    if not relevant_cols:
        raise ValueError("relevant_cols must contain at least one column name")

    # Ceiling division: enough rows to hold every requested panel.
    nrows = (len(relevant_cols) + ncols - 1) // ncols
    fig, axs = plt.subplots(
        nrows,
        ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols * 5, nrows * 5),
        squeeze=False,  # always a 2-D array, even for a single row/column
    )

    # Defaults first, caller overrides second. The original forwarded the
    # extras with a single star, which unpacked the dict keys as positional
    # arguments and failed as soon as any extra keyword was supplied.
    scatter_kwargs = {
        "s": 8,
        "palette": sns.color_palette("rocket_r", as_cmap=True),
        "linewidth": 0.0,
        "alpha": 1.0,
        "rasterized": True,
        "legend": False,
    }
    scatter_kwargs.update(kwargs)

    panels = axs.ravel()  # row-major, matching (i // ncols, i % ncols)
    for ax, rf in zip(panels, relevant_cols):
        sns.scatterplot(data=data, x=0, y=1, hue=rf, ax=ax, **scatter_kwargs)
        ax.set_title(rf)
        ax.set(xlabel=None, ylabel=None)

    # Hide grid cells that have no column assigned (last row may be partial).
    for ax in panels[len(relevant_cols):]:
        ax.set_visible(False)

    # Remove ticks on every panel explicitly instead of relying on the
    # "current axes" plus axis sharing.
    for ax in panels:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    plt.show()

In [None]:
# Replace the earlier `relevant_factors` with (at most) the first 12 factor
# names of the model.
relevant_factors = model.factor_names.tolist()[:12]
# celltype_factors = [fn for fn in model.factor_names if "HALLMARK" not in fn]
# relevant_factors += celltype_factors

In [None]:
# Plot the two estrogen-response factors over the covariate columns 0 and 1
# (one panel per factor, stacked vertically). `pd.concat(..., axis=1)` aligns
# factor scores and covariates on their index.
plot_scatter_grid(
    pd.concat([z_df, cov_df], axis=1),
    ["HALLMARK_ESTROGEN_RESPONSE_EARLY", "HALLMARK_ESTROGEN_RESPONSE_LATE"],
    ncols=1,
)

In [None]:
# Scatter the early vs. late estrogen-response factors for the Xenium group,
# grouped by the coarse "celltype_2" labels.
plotting.scatter(
    model,
    "HALLMARK_ESTROGEN_RESPONSE_EARLY",
    "HALLMARK_ESTROGEN_RESPONSE_LATE",
    group_idx="group_xenium",
    groupby="celltype_2",
    # groups=["CD4+_T_Cells", "CD8+_T_Cells", "B_Cells"],
    size=16,
)

In [None]:
# Top-25 weights for each of the two estrogen-response factors, one figure each.
for rf in ["HALLMARK_ESTROGEN_RESPONSE_EARLY", "HALLMARK_ESTROGEN_RESPONSE_LATE"]:
    plotting.plot_top_weights_muvi(model, rf, figsize=(6, 6), top=25)
    plt.show()

In [None]:
# Top-15 weights for every factor in `relevant_factors`, one figure each.
for rf in relevant_factors:
    plotting.plot_top_weights_muvi(model, rf, figsize=(4, 6), top=15)
    plt.show()