# Evaluation on Kang Data

In [1]:
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

# Do not show FutureWarnings
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

In [2]:
import os
import sys

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as p9
import scanpy as sc
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.neighbors import NearestNeighbors

import benchmark as bm
from prismo import PRISMO
from prismo import FeatureSets as fs
from prismo.pl import plot_weights, plot_factor_correlation

sys.path.append("..")
from plotting_settings import discrete_scale_fill, discrete_scale_color

%matplotlib
%matplotlib inline

# Settings
device = "cuda:1"
seed = 314159
rng = np.random.default_rng(seed)

  from .autonotebook import tqdm as notebook_tqdm
 captum (see https://github.com/pytorch/captum).


Spectra GPU support is still under development. Raise any issues on github 
 
 Changes from v1: 
 (1) GPU support [see tutorial] 
 (2) minibatching for local parameters and data 
 Note that minibatching may affect optimization results 
 Code will eventually be merged into spectra.py
Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.

Using matplotlib backend: module://matplotlib_inline.backend_inline


## Load data and gene sets (MSigDB Hallmark + Reactome)

In [3]:
# Noise rate for the prior-knowledge gene-set mask. It is passed to get_data as both
# the false-positive and the false-negative rate (0.0 presumably leaves the mask unchanged).
FPR = 0.0

In [4]:
def get_data(
    fpr,
    fnr,
    seed=None,
    data_path="data/kang_tutorial.h5ad",
    hallmark_gmt="../msigdb/h.all.v7.5.1.symbols.gmt",
    reactome_gmt="../msigdb/c2.cp.reactome.v7.5.1.symbols.gmt",
    n_top_genes=3000,
    noise_rng=None,
):
    """Load the Kang data, build the gene-set mask and a noisy copy of it.

    Parameters
    ----------
    fpr, fnr : float
        False-positive / false-negative rates forwarded to
        ``bm.get_rand_noisy_mask`` to corrupt the prior-knowledge mask.
    seed : int, optional
        Seed forwarded to ``bm.get_rand_noisy_mask``.
    data_path : str
        Path of the ``.h5ad`` file to load.
    hallmark_gmt, reactome_gmt : str
        GMT files whose filtered gene sets are combined with ``|``.
    n_top_genes : int
        Number of highest-variance genes to keep.
    noise_rng : numpy.random.Generator, optional
        Generator forwarded to ``bm.get_rand_noisy_mask``. When omitted the
        notebook-level ``rng`` is used, so existing calls behave as before.

    Returns
    -------
    adata : AnnData
        The data restricted to the ``n_top_genes`` highest-variance genes
        (gene names upper-cased).
    true_mask : numpy.ndarray
        Values of the gene-set mask (rows in the order of ``terms``).
    noisy_mask
        Output of ``bm.get_rand_noisy_mask`` for ``true_mask``.
    terms : list of str
        Gene-set names (row labels of the mask).
    true_mask_copy
        Copy of the labelled mask object returned by ``to_mask``.
    """
    adata = ad.read_h5ad(data_path)
    adata.var_names = adata.var_names.str.upper()

    # Keep the `n_top_genes` genes with the highest variance.
    top_genes = (
        adata.to_df().var().sort_values(ascending=False).iloc[:n_top_genes].index
    )
    adata = adata[:, top_genes].copy()

    # Both collections are filtered against the retained genes with the same thresholds.
    filter_kwargs = dict(min_fraction=0.4, min_count=40, max_count=200)

    def _load_gene_sets(path):
        # Read one GMT file and filter it against the retained genes.
        return fs.from_gmt(path).filter(adata.var_names, **filter_kwargs)

    gene_set_collection = _load_gene_sets(hallmark_gmt) | _load_gene_sets(reactome_gmt)

    # Print the similar gene-set pairs (as reported by find_similar_pairs) before merging.
    print(
        pd.DataFrame(
            fs(gene_set_collection).find_similar_pairs(),
            columns=["A", "B", "Similarity"],
        )
    )

    gene_set_collection = gene_set_collection.merge_similar(
        metric="jaccard",
        similarity_threshold=0.8,
        iteratively=True,
    )

    mask = gene_set_collection.to_mask(adata.var_names.tolist())
    terms = mask.index.tolist()

    # Keep a labelled copy, then work on the raw values to inject noise into the prior.
    true_mask_copy = mask.copy()
    true_mask = mask.values
    noisy_mask = bm.get_rand_noisy_mask(
        rng if noise_rng is None else noise_rng,
        true_mask,
        fpr=fpr,
        fnr=fnr,
        seed=seed,
    )

    return adata, true_mask, noisy_mask, terms, true_mask_copy


def preprocess(adata):
    """Build the model-specific float32 input matrices from ``adata.X``.

    Each feature is shifted so that its minimum is 0. The result is
    log1p-transformed and divided by its global standard deviation. The
    "centered" variant additionally subtracts the per-feature mean.

    Parameters
    ----------
    adata : AnnData-like
        Any object exposing ``.X`` as a dense array or a SciPy sparse matrix.

    Returns
    -------
    dict[str, numpy.ndarray]
        One float32 array per model key. Every entry is an independent copy
        (``astype`` copies), so in-place changes made for one model do not leak
        into another.
    """
    import scipy.sparse as sp  # local import: only needed to densify sparse input

    x = adata.X
    # Sparse matrices do not support the column-wise arithmetic below
    # (e.g. ``x.min(axis=0)`` returns a sparse matrix), so densify first.
    if sp.issparse(x):
        x = x.toarray()
    x = np.asarray(x)

    x = x - x.min(axis=0)  # shift every feature so its minimum is 0
    log_x = np.log1p(x)
    log_x = log_x / log_x.std()  # single global scale factor
    log_x_centered = log_x - log_x.mean(axis=0)

    return {
        "expimap": log_x_centered.astype(np.float32),
        "expimap_nb": x.astype(np.float32),
        "expimap_hardmask": log_x_centered.astype(np.float32),
        "expimap_hardmask_nb": x.astype(np.float32),
        "spectra": log_x.astype(np.float32),
        "prismo": log_x_centered.astype(np.float32),
        "prismo_nmf": log_x.astype(np.float32),
    }

In [None]:
adata, true_mask, noisy_mask, terms, true_mask_copy = get_data(FPR, FPR, seed=seed)
data_dict = preprocess(adata)

# Truncate the feature (gene) names to at most 40 characters. This relabels
# `adata.var_names` only; the gene-set names in `terms` are left unchanged.
adata.var_names = adata.var_names.str.slice(0, 40)
if not adata.var_names.is_unique:
    warnings.warn("adata.var_names are not unique after truncation to 40 characters")

## Train PRISMO

In [5]:
# Fitted models, keyed by model name.
model_dict = dict()

In [6]:
# Each configuration string encodes the PRISMO hyper-parameters as
# "prismo[_nmf]_<lr>_<early_stopper_patience>_<max_epochs>_<prior_penalty>_<n_factors>
#  _<gamma_prior_scale>_<init_scale>_<dense_factor_scale>[_...]".
# Fields after the ninth are ignored by the parser below.
# Previously explored configurations:
# "prismo_0.01_1000_5000_0.001_3_std_0.0_0.0_40_1",
# "prismo_0.01_1000_5000_0.001_5_std_0.0_0.0_40_1",
MODEL_CONFIGS = ["prismo_0.003_200_10000_0.005_1_std_0.0_0.0_40_1"]
MODEL_DIR = "/data/m015k/prismo/benchmark/prismo"
# Set to True to load a previously saved model instead of retraining.
LOAD_CACHED = False


def extract_params_from_string(s):
    """Parse the hyper-parameters encoded in a PRISMO configuration string.

    Returns
    -------
    tuple
        (lr, early_stopper_patience, max_epochs, prior_penalty, n_factors,
        gamma_prior_scale, init_scale, dense_factor_scale). Numeric fields are
        converted to float/int; ``gamma_prior_scale`` stays a string (e.g. "std").

    Raises
    ------
    ValueError
        If fewer than eight parameter fields are present or a numeric field
        cannot be parsed.
    """
    # The optional "_nmf" marker is not a parameter field.
    fields = s.replace("_nmf", "").split("_")[1:9]
    if len(fields) != 8:
        raise ValueError(f"Expected 8 parameter fields in {s!r}, got {len(fields)}")
    (
        lr,
        early_stopper_patience,
        max_epochs,
        prior_penalty,
        n_factors,
        gamma_prior_scale,
        init_scale,
        dense_factor_scale,
    ) = fields
    # The raw fields are strings; convert them so the trainer receives numbers.
    return (
        float(lr),
        int(early_stopper_patience),
        int(max_epochs),
        float(prior_penalty),
        int(n_factors),
        gamma_prior_scale,
        float(init_scale),
        float(dense_factor_scale),
    )


for model_name_params in MODEL_CONFIGS:
    for model_name in ["prismo"]:
        (
            lr,
            early_stopper_patience,
            max_epochs,
            prior_penalty,
            n_factors,
            gamma_prior_scale,
            init_scale,
            dense_factor_scale,
        ) = extract_params_from_string(model_name_params)

        print(
            f"Training with params: {lr}, {early_stopper_patience}, {max_epochs}, "
            f"{prior_penalty}, {n_factors}, {gamma_prior_scale}, {init_scale}, "
            f"{dense_factor_scale}"
        )

        model_path = f"{MODEL_DIR}/{model_name_params}.h5"
        if LOAD_CACHED and os.path.exists(model_path):
            print(f"Loading Model {model_name_params}")
            model_dict[model_name] = PRISMO.load(model_path)
        else:
            print("Training Model")
            model_dict[model_name] = bm.train_prismo(
                data_dict[model_name],
                noisy_mask,
                obs=adata.obs,
                var=adata.var,
                seed=seed,
                terms=terms,
                obs_names=adata.obs_names.copy(),
                var_names=adata.var_names.copy(),
                n_factors=n_factors,
                nmf="nmf" in model_name,
                prior_penalty=prior_penalty,
                max_epochs=max_epochs,
                batch_size=0,
                n_particles=1,
                lr=lr,
                early_stopper_patience=early_stopper_patience,
                device=device,
                true_mask=true_mask,
                init_factors="random",
                # NOTE(review): the parsed `init_scale` is not used; 0.1 is hard-coded — confirm intended.
                init_scale=0.1,
                save_path=model_path,
                dense_factor_scale=dense_factor_scale,
                # gamma_prior_scale is parsed but currently not forwarded:
                # gamma_prior_scale=gamma_prior_scale,
            )

Training with params: 0.003, 200, 10000, 0.005, 1, std, 0.0, 0.0
Training Model


NameError: name 'data_dict' is not defined

In [8]:
# Representative model for downstream analysis: the plain "prismo" entry of model_dict.
model = model_dict["prismo"]

In [None]:
# Overview of the fitted model: weights, factor scores (transposed) and the
# dispersion vector (shown as a single column).
fig, ax = plt.subplots(1, 3, figsize=(20, 6))
panels = [
    (model.get_weights("numpy")["view_1"], "Weights (view_1)"),
    (model.get_factors("numpy")["group_1"].T, "Factors (group_1, transposed)"),
    (model.get_dispersion("numpy")["view_1"][:, None], "Dispersion (view_1)"),
]
for axis, (values, title) in zip(ax, panels):
    sns.heatmap(values, ax=axis, cmap="RdBu_r", center=0)
    axis.set_title(title)
fig.tight_layout()
plt.show()

In [None]:
# Show the pairwise correlation between the model's factors.
plot_factor_correlation(model, figsize=(25, 25))

## Pathway analysis

### UMAP plots

In [None]:
# Latent factor scores as AnnData, annotated with the cell metadata of `adata`.
# NOTE(review): assumes the factor rows are in the same order as `adata.obs` — confirm.
adata_latent = model.get_factors("anndata")["group_1"].copy()
adata_latent.obs = adata.obs.copy()
adata_latent

In [12]:
# Neighbour graph on the raw factor scores, then Leiden clustering and UMAP on that graph.
sc.pp.neighbors(adata=adata_latent, use_rep="X")
sc.tl.leiden(adata=adata_latent)
sc.tl.umap(adata=adata_latent)

In [None]:
# UMAP coordinates with the stimulation and cell-type labels, one row per cell.
umap_coords = adata_latent.obsm["X_umap"]
df_plot = pd.DataFrame(
    {
        "UMAP1": umap_coords[:, 0],
        "UMAP2": umap_coords[:, 1],
        "stim": adata_latent.obs["stim"].values,
        "cell_type": adata_latent.obs["cell_type"].values,
    }
)
df_plot

In [None]:
# UMAP embedding of the latent factors, coloured by stimulation condition.
p = (
    p9.ggplot(df_plot, p9.aes(x="UMAP1", y="UMAP2", color="stim"))
    + p9.geom_point(size=1.5, alpha=0.5)
    + p9.theme(legend_position="right")
    + p9.labs(title="Conditions")
    + discrete_scale_color
)
p

In [15]:
# UMAP of the latent factors, coloured by condition and by cell type.
df_plot = pd.DataFrame(
    {
        "x": adata_latent.obsm["X_umap"][:, 0],
        "y": adata_latent.obsm["X_umap"][:, 1],
        "cell type": adata_latent.obs["cell_type"],
        "condition": adata_latent.obs["condition"],
    }
)

# The condition plot uses the shared discrete colour scale from plotting_settings,
# as the other condition plots in this notebook do.
p_condition = (
    p9.ggplot(df_plot, p9.aes(x="x", y="y", color="condition"))
    + p9.geom_point(size=0.5, alpha=0.25)
    + p9.labs(title="Conditions", x="UMAP1", y="UMAP2")
    + discrete_scale_color
)
# Only a cell's last expression is rendered automatically, so show this plot explicitly
# (`display` is provided by the IPython kernel).
display(p_condition)

# Cell types use plotnine's default discrete palette.
(
    p9.ggplot(df_plot, p9.aes(x="x", y="y", color="cell type"))
    + p9.geom_point(size=0.33, alpha=0.5)
    + p9.labs(title="Cell Types", x="UMAP1", y="UMAP2")
    + p9.guides(color=p9.guide_legend(override_aes={"size": 3, "alpha": 1}))
)

In [None]:
# Rank the latent factors for the "stimulated" condition (results stored on adata_latent).
sc.tl.rank_genes_groups(adata=adata_latent, groups=["stimulated"], groupby="condition")

In [None]:
# Dot plot of the ranked factors, grouped by cell type and standardised per variable.
# With return_fig=True scanpy returns the DotPlot object and does not act on
# `save`/`show`, so the figure is shown and saved explicitly below.
dp = sc.pl.rank_genes_groups_dotplot(
    adata_latent,
    return_fig=True,
    groupby="cell_type",
    standard_scale="var",
)
dp.add_totals().style(dot_edge_color="black", dot_edge_lw=0.5).show()

# Save a copy of the styled plot, then close that extra figure so it is not rendered twice.
os.makedirs(sc.settings.figdir, exist_ok=True)
dp.savefig(os.path.join(sc.settings.figdir, "rank_genes_groups_dotplot.png"))
plt.close()

In [None]:
# Ranked-factor dot plots: first without an explicit groupby (standardised per variable),
# then grouped by condition. `dp` keeps the last (condition-grouped) plot.
for dotplot_kwargs in ({"standard_scale": "var"}, {"groupby": "condition"}):
    dp = sc.pl.rank_genes_groups_dotplot(adata_latent, return_fig=True, **dotplot_kwargs)
    dp.add_totals().style(dot_edge_color="black", dot_edge_lw=0.5).show()

In [None]:
# Pairwise scatter plots of all interferon-related factors, coloured by condition.
# Sorted so the panel order does not depend on (per-session randomised) set ordering.
factors = sorted({vn for vn in adata_latent.var_names if "INTERFER" in vn})

factors_adata = model.get_factors("anndata")["group_1"]
df_plot = pd.DataFrame(factors_adata.X, columns=factors_adata.var_names)
df_plot["condition"] = adata.obs["condition"].values

sns.pairplot(df_plot, hue="condition", vars=factors, plot_kws={"s": 3, "alpha": 0.25})

# How well do individual factors separate the conditions?
# NOTE(review): these factor columns must exist in the fitted model (see n_factors) — confirm.
KMEANS_FACTORS = ["Factor 1", "Factor 2"]
KNN_FACTORS = ["Factor 1", "Factor 3"]
N_NEIGHBORS = 5  # neighbours per cell, excluding the cell itself

# (1) Adjusted Rand index between a 2-means clustering and the condition labels.
kmeans = KMeans(n_clusters=2, random_state=42)
df_plot["cluster"] = kmeans.fit_predict(df_plot[KMEANS_FACTORS])
condition_codes, _ = pd.factorize(df_plot["condition"])
kmeans_ari = adjusted_rand_score(condition_codes, df_plot["cluster"])

# (2) Mean fraction of each cell's nearest neighbours that share its condition.
features = df_plot[KNN_FACTORS].to_numpy()
conditions = df_plot["condition"].to_numpy()
nbrs = NearestNeighbors(n_neighbors=N_NEIGHBORS + 1, algorithm="auto").fit(features)
_, indices = nbrs.kneighbors(features)
# Column 0 is the query cell itself, so it is excluded.
same_condition = conditions[indices[:, 1:]] == conditions[:, None]
knn_purity = same_condition.mean(axis=1).mean()

# Display both scores; only a cell's last expression is rendered.
pd.Series({"kmeans_ari": kmeans_ari, "knn_condition_purity": knn_purity})

In [None]:
# Pairwise scatter plots of the interferon-related factors plus those shown in the
# last dot plot (`dp`) whose names contain "INTERFER" or "Factor". Sorted so the
# panel order does not depend on (per-session randomised) set ordering.
factors = sorted(
    {vn for vn in adata_latent.var_names if "INTERFER" in vn}
    | {x for x in dp.var_names if "INTERFER" in x or "Factor" in x}
)

factors_adata = model.get_factors("anndata")["group_1"]
df_plot = pd.DataFrame(factors_adata.X, columns=factors_adata.var_names)
df_plot["condition"] = adata.obs["condition"].values

sns.pairplot(df_plot, hue="condition", vars=factors, plot_kws={"s": 3, "alpha": 0.25});

In [None]:
# Scatter of "Factor 1" against the REACTOME_INTERFERON_SIGNALING factor, coloured by condition.
(
    p9.ggplot(
        df_plot,
        p9.aes(x="Factor 1", y="REACTOME_INTERFERON_SIGNALING", color="condition"),
    )
    + p9.geom_point(alpha=0.25)
    + p9.labs(title="Factor 1 vs. REACTOME_INTERFERON_SIGNALING")
    + discrete_scale_color
)

In [None]:
# Weights of the ISG* genes; show the ten rows (factors) with the largest ISG15 weight.
dfx = model.get_weights("anndata")["view_1"]
isg_mask = dfx.var_names.str.contains("ISG")
isg_weights = dfx[:, isg_mask].to_df()
isg_weights.sort_values("ISG15", ascending=False).head(10)

In [23]:
# Factors to inspect: all interferon-related factors plus the ten factors with the
# largest ISG15 weight. Sorted so the facet order in plot_weights is reproducible
# (set iteration order varies between sessions).
ifn_factors = [vn for vn in adata_latent.var_names if "INTERFER" in vn]
top_isg15_factors = (
    dfx[:, dfx.var_names.str.contains("ISG")]
    .to_df()
    .sort_values("ISG15", ascending=False)
    .head(10)
    .index
)
factor_list = sorted(set(top_isg15_factors) | set(ifn_factors))

In [None]:
# Weight plots (n_features=20) for the factors selected for the pair plot, two facet rows.
weights_plot = plot_weights(
    model,
    views="view_1",
    factors=factors,
    n_features=20,
    figsize=(20, 14),
    pointsize=2,
)
weights_plot + p9.facet_wrap("factor", nrow=2)

In [None]:
# Weight plots (n_features=20) for the interferon / top-ISG15 factors, two facet rows.
two_row_facets = p9.facet_wrap("factor", nrow=2)
plot_weights(
    model,
    factors=factor_list,
    views="view_1",
    pointsize=2,
    n_features=20,
    figsize=(20, 14),
) + two_row_facets

In [26]:
# Factor scores per cell plus the stimulation label and a running sample index.
df_plot = pd.DataFrame(
    adata_latent.X, columns=adata_latent.var_names.tolist()
).assign(
    stim=adata_latent.obs["stim"].values,
    sample_id=range(adata_latent.n_obs),
)

In [None]:
# Distribution of the REACTOME_INTERFERON_ALPHA_BETA_SIGNALING factor per condition.
# The legend title is set through labs(color=...), the standard way to title a scale.
(
    p9.ggplot(
        df_plot,
        p9.aes(x="stim", y="REACTOME_INTERFERON_ALPHA_BETA_SIGNALING", color="stim"),
    )
    + p9.geom_boxplot()
    + p9.labs(
        x="Condition",
        y="REACTOME INTERFERON ALPHA BETA SIGNALING",
        color="Condition",
    )
    + p9.theme(figure_size=(6, 4))
)