In [None]:
from mudata import read
from ..data.datasets import available_datasets
import numpy as np
import os
import scanpy as sc
from tqdm import tqdm
import matplotlib.pyplot as plt
from metrics import entropy_batch_mixing

# Turn on scanpy's private `_vector_friendly` settings flag for the plots below.
# NOTE(review): this reaches into a private attribute; presumably it makes scanpy
# rasterize dense scatter layers when figures are saved to vector formats (the
# figures below are written as PDF) — confirm against scanpy's
# `set_figure_params(vector_friendly=...)` documentation.
sc._settings.settings._vector_friendly = True

In [None]:
# Read the "papalexi_2021" dataset from its registered local path and keep the
# entry stored under the "rna" key for all downstream analysis.
papalexi_entry = available_datasets["papalexi_2021"]
mdata = read(papalexi_entry.local_path)
adata = mdata["rna"]

In [None]:
# Seed whose saved outputs are attached to `adata` for the UMAP panels.
# All result directories below are derived from it, so changing `seed` here is
# enough to switch every path consistently.
seed = 123
results_root = "../results/papalexi_2021"

contrastive_vi_plus_results_dir = (
    f"{results_root}/contrastive_vi_plus/seed_{seed}/inference_marginalize/"
    "early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/"
    "mmd_penalty_1000.0"
)
contrastive_vi_results_dir = (
    f"{results_root}/contrastive_vi/seed_{seed}/early_stopping_True"
)
mixscape_results_dir = f"{results_root}/mixscape/seed_{seed}"

# Method name -> directory holding that method's saved arrays.
results_dir_by_method = {
    "mixscape": mixscape_results_dir,
    "contrastive_vi": contrastive_vi_results_dir,
    "contrastive_vi_plus": contrastive_vi_plus_results_dir,
}

# Salient latent representation of every method goes into `adata.obsm`.
for method_name, results_dir in results_dir_by_method.items():
    adata.obsm[method_name] = np.load(
        os.path.join(results_dir, "salient_latent_rep.npy")
    )

# Only these two methods have a saved `pert_probs.npy`; store it per cell in
# `adata.obs` under "<method>_pert_probs".
for method_name in ("mixscape", "contrastive_vi_plus"):
    adata.obs[f"{method_name}_pert_probs"] = np.load(
        os.path.join(results_dir_by_method[method_name], "pert_probs.npy")
    )

In [None]:
# Human-readable method names used for plot titles and bar labels.
method_formatted = {
    "mixscape": "Mixscape",
    "contrastive_vi": "ContrastiveVI",
    "contrastive_vi_plus": "ContrastiveVI+",
    "pca": "PCA",
}

# Representations to embed; "pca" is computed here, the others are expected to
# already be present in `adata.obsm`.
methods = ["pca", "mixscape", "contrastive_vi", "contrastive_vi_plus"]

sc.pp.pca(adata)
adata.obsm["pca"] = adata.obsm["X_pca"]

for method in tqdm(methods):
    umap_key = f"{method}_umap"

    # Build the neighbor graph on this representation, then run UMAP on it.
    sc.pp.neighbors(adata, use_rep=method)
    sc.tl.umap(adata)

    # `X_umap` is overwritten on the next iteration, so keep a copy per method.
    adata.obsm[umap_key] = adata.obsm["X_umap"].copy()

In [None]:
from collections import defaultdict
import gc
from tqdm import tqdm

# Seeds whose saved salient representations are scored.
seeds = [123, 42, 456, 46, 999]

# Per-seed result directory of each learned method; "{seed}" is filled in below.
salient_rep_dir_templates = {
    "mixscape": "../results/papalexi_2021/mixscape/seed_{seed}/",
    "contrastive_vi_plus": "../results/papalexi_2021/contrastive_vi_plus/seed_{seed}/inference_marginalize/early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/mmd_penalty_1000.0",
    "contrastive_vi": "../results/papalexi_2021/contrastive_vi/seed_{seed}/early_stopping_True/",
}

# metrics[metric_name][method] -> list with one score per seed.
metrics = defaultdict(lambda: defaultdict(list))

# Reuse the PCA representation if an earlier cell already stored it; only
# compute it when missing.
if "pca" not in adata.obsm:
    sc.pp.pca(adata)
    adata.obsm["pca"] = adata.obsm["X_pca"]

for metrics_seed in tqdm(seeds):
    gc.collect()

    # Keep the per-seed arrays local instead of writing them into `adata.obsm`:
    # overwriting `adata.obsm` would leave `adata` holding the last seed's
    # embeddings next to UMAPs that were computed from a different seed.
    representations = {"pca": adata.obsm["pca"]}
    for method_name, dir_template in salient_rep_dir_templates.items():
        latent_rep = np.load(
            os.path.join(
                dir_template.format(seed=metrics_seed), "salient_latent_rep.npy"
            )
        )
        # Fail loudly on a row-count mismatch between the saved array and `adata`.
        if latent_rep.shape[0] != adata.n_obs:
            raise ValueError(
                f"{method_name} (seed {metrics_seed}) has {latent_rep.shape[0]} rows, "
                f"expected {adata.n_obs}"
            )
        representations[method_name] = latent_rep

    for method in methods:
        latent_rep = representations[method]

        metrics["entropy_replicate_mixing"][method].append(
            entropy_batch_mixing(latent_rep, adata.obs["replicate"])
        )

        metrics["entropy_phase_mixing"][method].append(
            entropy_batch_mixing(latent_rep, adata.obs["Phase"])
        )

In [None]:
import seaborn as sns

# Figure layout: one row per confounder (cell-cycle phase, replicate).
# Columns: PCA UMAP, ContrastiveVI+ UMAP, bar chart of the mixing entropy
# (mean over seeds, error bar = standard deviation over seeds).
fig, axes = plt.subplots(2, 3, figsize=(16, 10), dpi=200)

bar_labels = [method_formatted[method] for method in methods]
bar_colors = sns.color_palette("Dark2")[-4:]

# (obs column, legend title, key in `metrics`, bar-chart title, scatter palette)
# A palette of None leaves scanpy's default colors in place.
confounder_rows = [
    (
        "Phase",
        "Cell cycle",
        "entropy_phase_mixing",
        "Entropy of mixing (cell cycle)",
        None,
    ),
    (
        "replicate",
        "Replicate",
        "entropy_replicate_mixing",
        "Entropy of mixing (replicate)",
        sns.color_palette("Set1")[-3:],
    ),
]

for row_axes, row_spec in zip(axes, confounder_rows):
    pca_ax, cvi_plus_ax, bar_ax = row_axes
    obs_key, legend_title, metric_key, bar_title, scatter_palette = row_spec

    # PCA panel: legend suppressed, the neighbouring panel carries it.
    sc.pl.embedding(
        adata,
        basis="pca_umap",
        ax=pca_ax,
        show=False,
        color=obs_key,
        title="PCA",
        legend_loc=None,
        palette=scatter_palette,
    )

    sc.pl.embedding(
        adata,
        basis="contrastive_vi_plus_umap",
        ax=cvi_plus_ax,
        show=False,
        color=obs_key,
        title="ContrastiveVI+",
        palette=scatter_palette,
    )
    cvi_plus_ax.get_legend().set_title(legend_title)

    scores_by_method = metrics[metric_key]
    bar_ax.bar(
        bar_labels,
        [np.mean(scores_by_method[method]) for method in methods],
        yerr=[np.std(scores_by_method[method]) for method in methods],
        capsize=5,
        color=bar_colors,
    )
    bar_ax.spines[["right", "top"]].set_visible(False)
    bar_ax.set_title(bar_title)
    # Rotate the tick labels in place; this avoids re-installing a fixed label
    # list via set_xticklabels(get_xticklabels(), ...).
    bar_ax.tick_params(axis="x", labelrotation=15)

    for umap_ax in (pca_ax, cvi_plus_ax):
        umap_ax.set_xlabel("UMAP1")
        umap_ax.set_ylabel("UMAP2")

plt.subplots_adjust(wspace=0.5, hspace=0.25)
plt.savefig("papalexi_confounders.pdf", bbox_inches="tight")

In [None]:
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(17.5, 10), dpi=200)

# 2 x 4 outer grid; the third column is narrower and gets no top-row axes.
gs = gridspec.GridSpec(2, 4, width_ratios=[1, 1, 0.5, 1])

# --- Top row: UMAP panels -------------------------------------------------
ax1 = fig.add_subplot(gs[0, 0])
sc.pl.embedding(
    adata,
    basis="contrastive_vi_umap",
    ax=ax1,
    show=False,
    color="gene",
    title="ContrastiveVI",
    legend_loc=None,
)

ax2 = fig.add_subplot(gs[0, 1])
sc.pl.embedding(
    adata,
    basis="contrastive_vi_plus_umap",
    ax=ax2,
    show=False,
    color="gene",
    title="ContrastiveVI+",
)

ax3 = fig.add_subplot(gs[0, 3])
sc.pl.embedding(
    adata,
    basis="contrastive_vi_plus_umap",
    ax=ax3,
    show=False,
    color="contrastive_vi_plus_pert_probs",
    title="ContrastiveVI+",
)

for ax in [ax1, ax2, ax3]:
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")

# --- Bottom row: violin panels --------------------------------------------
# The whole bottom row is re-split into 4 equal cells: PDL1 violins take the
# first two, the two marker-gene panels take one each.
gs00 = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=gs[1, :])

ax4 = fig.add_subplot(gs00[:2])

# Left-to-right order of the `gene` groups in the PDL1 violin plot.
violin_order = [
    "STAT1",
    "JAK2",
    "IFNGR2",
    "IFNGR1",
    "IRF1",
    "BRD4",
    "CUL3",
    "STAT2",
    "SMAD4",
    "NT",
]

sc.pl.violin(
    adata,
    groupby="gene",
    keys="PDL1",
    stripplot=False,
    inner="box",
    ax=ax4,
    order=violin_order,
    show=False,
    palette="tab10_r",
)

ax4.set_ylabel("PDL1 expression")
ax4.set_xlabel("Targeted Gene")

# Reuse the PDL1 panel's colors in the marker-gene panels. The child artist at
# index 4 * (position in `violin_order`) is read, which yields indices 36 (NT),
# 28 (STAT2) and 32 (SMAD4).
# NOTE(review): presumably each violin contributes four child artists to ax4
# with the filled body first — re-check this stride if seaborn/scanpy change.
children_per_violin = 4
ax4_children = ax4.get_children()
violin_facecolor = {
    group: ax4_children[
        children_per_violin * violin_order.index(group)
    ].get_facecolor()
    for group in ("NT", "STAT2", "SMAD4")
}

# (targeted gene, expression features to show, column of `gs00` to draw in)
marker_panels = [
    ("STAT2", ["IFI6", "ISG15"], 2),
    ("SMAD4", ["APOC1", "FN1"], 3),
]

marker_axes = []
for gene, marker_genes, gs_col in marker_panels:
    # Keep all "NT" cells plus cells targeting `gene` whose ContrastiveVI+
    # pert_probs value exceeds 0.5.
    adata_gene = adata[
        (adata.obs["gene"] == "NT")
        | (
            (adata.obs["gene"] == gene)
            & (adata.obs["contrastive_vi_plus_pert_probs"] > 0.5)
        )
    ]

    marker_ax = fig.add_subplot(gs00[gs_col])

    # Long format: one row per (cell, feature). After the reshape, column
    # "gene" holds the targeted-gene label and "pert" holds the feature name.
    df = sc.get.obs_df(adata_gene, marker_genes + ["gene"])
    df = df.set_index("gene").stack().reset_index()
    df.columns = ["gene", "pert", "value"]

    sns.violinplot(
        data=df,
        x="pert",
        y="value",
        hue="gene",
        inner="box",
        ax=marker_ax,
        linewidth=1,
        palette=[violin_facecolor["NT"], violin_facecolor[gene]],
        cut=0,
    )
    marker_ax.legend(frameon=False, title="Perturbation", loc="upper right")
    marker_ax.set_xlabel("")
    marker_ax.set_ylabel("Log library size normalized Expression")

    marker_axes.append(marker_ax)

ax5, ax6 = marker_axes

for panel_ax in (ax4, ax5, ax6):
    sns.despine(ax=panel_ax)

plt.savefig("papalexi_genes.pdf", bbox_inches="tight")