In [None]:
# https://www.sc-best-practices.org/conditions/gsea_pathway.html#id380
# Kang HM, Subramaniam M, Targ S, et al. Multiplexed droplet single-cell RNA-sequencing using natural genetic variation.
#   Nat Biotechnol. 2018;36(1):89-94. doi:10.1038/nbt.4042
#   [published correction appears in Nat Biotechnol. 2020 Nov;38(11):1356].

In [None]:
%load_ext autoreload
%autoreload 2

import scanpy as sc
import tensorflow as tf
from tensorflow.keras.models import Model
from ivae_scorer.datasets import load_kang
from tensorflow.keras import callbacks
from ivae_scorer.utils import set_all_seeds
from ivae_scorer.bio import get_reactome_adj
from ivae_scorer.bio import sync_gexp_adj
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale
import matplotlib.pyplot as plt
import seaborn as sns
import dotenv
from pathlib import Path

# Resolve the project root from the location of the nearest `.env` file and
# lay out the directories the rest of the notebook reads from / writes to.
# NOTE(review): when no `.env` is found, `find_dotenv()` returns "" and the
# root silently becomes the current working directory.
project_path = Path(dotenv.find_dotenv()).parent
results_path = project_path / "results"
data_path = project_path / "data"
figs_path = results_path / "figs"
tables_path = results_path / "tables"

for output_dir in (results_path, data_path, figs_path, tables_path):
    output_dir.mkdir(parents=True, exist_ok=True)

# Reproducibility: seed all RNGs and force deterministic TensorFlow ops.
set_all_seeds(seed=42)
tf.config.experimental.enable_op_determinism()

# Scanpy plotting / logging configuration.
sc.set_figure_params(dpi=300, color_map="viridis")
sc.settings.verbosity = 1
sc.logging.print_header()

In [None]:
# Kang et al. PBMC data (conditions "control" / "stimulated"), normalized by the
# loader; `n_genes` presumably limits the gene set — see `load_kang`.
adata = load_kang(
    data_folder=data_path,
    n_genes=4000,
    normalize=True,
)

In [None]:
# Expression matrix as a DataFrame: rows = obs_names (cells), columns = var_names.
x_trans = adata.to_df(layer=None)

In [None]:
# Reactome adjacency; columns are pathway identifiers ("REACTOME_..."),
# rows presumably genes — confirm against `get_reactome_adj`.
reactome = get_reactome_adj()
reactome.head(5)

In [None]:
# Align the expression matrix with the Reactome adjacency (presumably restricts
# both to their shared genes — see `sync_gexp_adj`).
x_trans, reactome = sync_gexp_adj(x_trans, reactome)

In [None]:
# Sanity check of the synced shapes.
(x_trans.shape, reactome.shape)

In [None]:
obs = adata.obs.copy()

# Scale every gene (column) to [0, 1].
# NOTE(review): the scaling is fitted on all cells, including the held-out
# ones; fitting on the training split only would avoid test-set leakage.
x_scaled = x_trans.apply(minmax_scale)

x_train, x_test = train_test_split(
    x_scaled,
    test_size=0.33,
    # Look labels up by cell id so stratification stays aligned with the rows
    # of `x_trans`, even if `sync_gexp_adj` reordered or dropped cells.
    stratify=obs.loc[x_trans.index, "cell_type"],
    random_state=42,
)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

In [None]:
# Training matrix vs. adjacency shape (the gene dimension should agree).
(x_train.shape, reactome.shape)

In [None]:
from ivae_scorer.models import build_reactome_vae

# Build the Reactome-informed VAE; the builder returns the full model together
# with its encoder and decoder sub-models.
vae, encoder, decoder = build_reactome_vae(reactome)

In [None]:
batch_size = 32

# Early stopping on the validation loss: training stops once `val_loss` has not
# decreased by at least 0.1 (absolute) for 30 consecutive epochs.
# NOTE(review): `restore_best_weights` is left at its Keras default (False), so
# the model keeps the weights of the last epoch run, not of the best epoch.
callback = callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=1e-1,
    patience=30,
    verbose=0,
)

history = vae.fit(
    x_train.values,
    shuffle=True,
    verbose=0,
    epochs=100,
    batch_size=batch_size,
    callbacks=[callback],
    # No explicit targets: the VAE presumably reconstructs its own input — confirm
    # against `build_reactome_vae`.
    validation_data=(x_test.values, None),
)

In [None]:
# Training vs. validation loss per epoch.
with sns.plotting_context("paper"):
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.plot(history.history["loss"], label="train")
    ax.plot(history.history["val_loss"], label="val")
    ax.set(title="model loss", xlabel="epoch", ylabel="loss")
    ax.legend(loc="upper right")

In [None]:
import pandas as pd

# Embed the training cells with the encoder.
# NOTE(review): `[0]` selects the encoder's first output — presumably the
# latent mean; confirm against `build_reactome_vae`.
x_train_encoded = encoder.predict(x_train, batch_size=batch_size)[0]
x_train_encoded = pd.DataFrame(x_train_encoded, index=x_train.index)

# NOTE: this rebinds `adata` (previously the expression AnnData) to an AnnData
# of the latent embedding, carrying the training cells' metadata.
adata = sc.AnnData(X=x_train_encoded)
adata.obs = obs.loc[x_train.index]

# Neighbour graph, clustering and UMAP computed directly on the latent values.
sc.pp.neighbors(adata, use_rep="X")
sc.tl.leiden(adata)
sc.tl.umap(adata)

In [None]:
# Auxiliary model that returns the output of every encoder layer, so an
# intermediate activation can be selected by layer index.
activation_model = Model(
    inputs=encoder.input,
    outputs=[enc_layer.output for enc_layer in encoder.layers],
)

In [None]:
# Pathway activity read from encoder layer 1 (presumably the first layer after
# the input). Its output has one unit per Reactome pathway, i.e. per column of
# `reactome`, which is why the columns below can be labelled with pathway names.
layer_id = 1

# Readable pathway labels: "REACTOME_FOO_BAR" -> "FOO BAR".
entitie_names = reactome.columns.str.replace("REACTOME_", "", regex=False).str.replace(
    "_", " ", regex=False
)

# NOTE: `x_train_encoded` is rebound here to the pathway-level activations;
# later cells rely on this meaning.
x_train_encoded = activation_model.predict(x_train, batch_size=batch_size)[layer_id]
x_train_encoded = pd.DataFrame(
    x_train_encoded, index=x_train.index, columns=entitie_names
)

# Absolute activations are used as per-cell pathway scores.
entities_adata = sc.AnnData(X=x_train_encoded.abs())
entities_adata.obs = obs.loc[x_train.index]

sc.pp.neighbors(entities_adata, use_rep="X")
sc.tl.leiden(entities_adata)
sc.tl.umap(entities_adata)

In [None]:
method = "wilcoxon"

# Differential pathway activity of stimulated vs. control cells.
# `reference` must name an existing `condition` category (assumed "control").
# A misspelled keyword is silently absorbed by `**kwds` and the default
# reference ("rest") would be used instead.
sc.tl.rank_genes_groups(
    entities_adata,
    "condition",
    reference="control",
    key_added=method,
    method=method,
)

# Flatten the per-group ranking into one table, with columns "<group>_<field>".
result = entities_adata.uns[method]
groups = ["stimulated"]
result_fields = ["names", "scores", "pvals", "pvals_adj", "logfoldchanges"]
dacs = pd.DataFrame(
    {
        f"{group}_{field}": result[field][group]
        for group in groups
        for field in result_fields
    }
)

dacs.head(10)

In [None]:
# Export the top-10 differentially active pathways as a LaTeX table.
# `str.title()` lower-cases acronyms and gene symbols, so restore them afterwards.
acronym_fixes = {
    "Ifn": "IFN",
    "Adn": "ADN",
    "Oas": "OAS",
    "Ddx58": "DDX58",
    "Ns1": "NS1",
    "Ifih1": "IFIH1",
}
# Characters that must be escaped in LaTeX text mode.
latex_escapes = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)

dacs_to_write = dacs.head(10).copy()
dacs_to_write.columns = dacs_to_write.columns.str.replace(
    "stimulated_", "", regex=False
)
dacs_to_write = dacs_to_write.rename(columns={"names": "pathways"}).drop(
    columns="pvals"
)

pathways = dacs_to_write["pathways"].str.title()
for title_cased, acronym in acronym_fixes.items():
    pathways = pathways.str.replace(title_cased, acronym, regex=False)
# Escape the plain text first, then insert the math-mode Greek letters.
# With `escape=True`, `to_latex` would also escape the `$` and `\` of the math.
pathways = pathways.str.translate(latex_escapes)
pathways = pathways.str.replace("Alpha Beta", r"$\alpha, \beta$", regex=False)
dacs_to_write["pathways"] = pathways
dacs_to_write.columns = dacs_to_write.columns.str.translate(latex_escapes)

dacs_to_write.to_latex(
    tables_path.joinpath("ivae_scorer_reactome.tex"),
    float_format="%.2f",
    index=False,
    escape=False,  # cell text and headers were escaped manually above
)

In [None]:
# From the top-10 ranked pathways, keep those mentioning interferon / IFN.
top10_names = dacs["stimulated_names"].head(10)
dacs_top = top10_names[top10_names.str.contains("interferon|ifn", case=False)]
dacs_top

In [None]:
# Attach the absolute pathway activities to the latent-space AnnData.
# `x_train_encoded` holds the pathway-level layer activations at this point,
# not the latent embedding. The selected IFN pathways are copied into `.obs`
# so they can be used as UMAP colours.
adata.obsm["pathways"] = x_train_encoded.abs()
ifn_columns = dacs_top.tolist()
adata.obs[ifn_columns] = adata.obsm["pathways"][ifn_columns]

In [None]:
# Latent-space UMAPs: metadata only, IFN pathway activity only, and both.
# Each panel is saved to its own PDF right after being drawn.
ifn_panels = [
    (["condition", "cell_type"], 2, "ivae_scorer_reactome_latent.pdf"),
    (dacs_top, 4, "ivae_scorer_reactome_activity.pdf"),
    (
        ["condition", "cell_type"] + dacs_top.tolist(),
        2,
        "ivae_scorer_reactome_ifn.pdf",
    ),
]
for panel_colors, panel_ncols, panel_file in ifn_panels:
    sc.pl.umap(
        adata,
        color=panel_colors,
        frameon=False,
        ncols=panel_ncols,
        wspace=0.3,
        show=False,
    )
    plt.savefig(figs_path.joinpath(panel_file), bbox_inches="tight")

In [None]:
# Pathways related to viral infection (influenza, HIV, SARS-CoV-2, HCMV).
is_infection = entitie_names.str.contains(
    "influenza infection|hiv infection|SARS COV 2 INFECTION|hcmv infection", case=False
)
ifn_circuits = entitie_names[is_infection].tolist()
adata.obs[ifn_circuits] = adata.obsm["pathways"][ifn_circuits]

# Latent-space UMAPs: metadata only, infection pathways only, and both.
infection_panels = [
    (["condition", "cell_type"], 2, "ivae_scorer_reactome_infection_latent.pdf"),
    (ifn_circuits, 4, "ivae_scorer_reactome_infection_activity.pdf"),
    (
        ["condition", "cell_type"] + ifn_circuits,
        2,
        "ivae_scorer_reactome_infection_all.pdf",
    ),
]
for panel_colors, panel_ncols, panel_file in infection_panels:
    sc.pl.umap(
        adata,
        color=panel_colors,
        frameon=False,
        ncols=panel_ncols,
        wspace=0.3,
        show=False,
    )
    plt.savefig(figs_path.joinpath(panel_file), bbox_inches="tight")

In [None]:
# Mean "INFLUENZA INFECTION" pathway activity per cell type, in ascending order.
# If `cell_type` is categorical, `observed=True` keeps only the cell types present
# in the training split and avoids pandas' FutureWarning about the `observed` default.
adata.obs.groupby("cell_type", observed=True)["INFLUENZA INFECTION"].mean().sort_values()