In [None]:
import os
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
from anndata import read_h5ad
from tqdm import tqdm

from ..data.datasets import available_datasets

# Sets a private scanpy settings attribute directly.
# NOTE(review): presumably this makes scanpy rasterize scatter points when
# figures are saved in a vector format (the figure below is saved as PDF) --
# confirm against the scanpy settings documentation.
sc._settings.settings._vector_friendly = True

In [None]:
# Look up the registry entry for the dataset, then read the AnnData object
# from the local file path recorded on that entry.
replogle_dataset = available_datasets["replogle_2022"]
adata = read_h5ad(replogle_dataset.local_path)

In [None]:
# Index genes by the values of the "gene_name" column. The guard keeps this
# cell re-runnable: once "gene_name" has become the index it is no longer a
# column, and calling set_index again would raise a KeyError.
if "gene_name" in adata.var.columns:
    adata.var.set_index("gene_name", inplace=True)
adata.var.index = adata.var.index.astype(str)
adata.var_names_make_unique()

# Annotate cells with cell cycle phases/scores
CELL_CYCLE_GENES_URL = "https://raw.githubusercontent.com/scverse/scanpy_usage/master/180209_cell_cycle/data/regev_lab_cell_cycle_genes.txt"
N_S_PHASE_GENES = 43  # leading entries of the gene list used as S-phase genes
cell_cycle_genes_path = "regev_lab_cell_cycle_genes.txt"

# Download the gene list once. urllib is used instead of shelling out to
# `wget` so the cell does not depend on an external binary, and a failed
# download raises here instead of surfacing later as a missing file.
if not os.path.exists(cell_cycle_genes_path):
    urllib.request.urlretrieve(CELL_CYCLE_GENES_URL, cell_cycle_genes_path)

# One gene per line; the context manager closes the file handle and blank
# lines are skipped so they cannot end up in a gene list.
with open(cell_cycle_genes_path) as gene_file:
    cell_cycle_genes = [line.strip() for line in gene_file if line.strip()]

# The split is positional: the first N_S_PHASE_GENES entries are passed as
# S-phase genes and the remainder as G2/M genes.
s_genes = cell_cycle_genes[:N_S_PHASE_GENES]
g2m_genes = cell_cycle_genes[N_S_PHASE_GENES:]
# Only the combined list is restricted to genes present in the data; the two
# phase-specific lists are deliberately passed to scanpy unfiltered.
cell_cycle_genes = [gene for gene in cell_cycle_genes if gene in adata.var_names]

# The evaluation further down reads adata.obs["phase"] after this call.
sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)

In [None]:
# Result directories of the seed-123 runs, one per method.
contrastive_vi_plus_results_dir = "../results/replogle_2022/contrastive_vi_plus/seed_123/inference_marginalize/early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/mmd_penalty_1000.0"
contrastive_vi_results_dir = (
    "../results/replogle_2022/contrastive_vi/seed_123/early_stopping_True/"
)
mixscape_results_dir = "../results/replogle_2022/mixscape/seed_123"

# Salient latent representation of every method -> adata.obsm[<method>].
for rep_key, rep_dir in (
    ("contrastive_vi_plus", contrastive_vi_plus_results_dir),
    ("contrastive_vi", contrastive_vi_results_dir),
    ("mixscape", mixscape_results_dir),
):
    adata.obsm[rep_key] = np.load(os.path.join(rep_dir, "salient_latent_rep.npy"))

# Per-cell perturbation probabilities -> adata.obs["<method>_pert_probs"].
# Only these two result directories are read for a pert_probs.npy file.
for prob_key, prob_dir in (
    ("contrastive_vi_plus_pert_probs", contrastive_vi_plus_results_dir),
    ("mixscape_pert_probs", mixscape_results_dir),
):
    adata.obs[prob_key] = np.load(os.path.join(prob_dir, "pert_probs.npy"))

In [None]:
# Run scanpy's PCA with default arguments, then expose the result stored under
# "X_pca" under the key "pca" as well, so it can be addressed by the same kind
# of name as the other representations in adata.obsm.
sc.pp.pca(adata)
pca_rep = adata.obsm["X_pca"]
adata.obsm["pca"] = pca_rep

In [None]:
# For nicer titles when plotting
method_formatted = dict(
    pca="PCA",
    mixscape="Mixscape",
    contrastive_vi="ContrastiveVI",
    contrastive_vi_plus="ContrastiveVI+",
)

# Representations to embed; each name is a key of adata.obsm.
methods = ["pca", "mixscape", "contrastive_vi", "contrastive_vi_plus"]

for method in tqdm(methods):
    # Neighbor graph on this representation, followed by a UMAP of that graph.
    sc.pp.neighbors(adata, use_rep=method)
    sc.tl.umap(adata)

    # "X_umap" is read back after every run, so store a per-method copy under
    # its own key before the next iteration computes a new embedding.
    umap_coords = adata.obsm["X_umap"].copy()
    adata.obsm[f"{method}_umap"] = umap_coords

In [None]:
from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import KMeans
from collections import defaultdict
import gc
from metrics import entropy_batch_mixing
import seaborn as sns

In [None]:
# Seeds identifying the `seed_{seed}` result directories that are evaluated.
seeds = [123, 42, 456, 46, 999]


def _scores_by_method():
    """Return an empty mapping from method name to a list of scores."""
    return defaultdict(list)


# metrics[metric_name][method] -> list of scores; missing keys are created on
# first access at both levels.
metrics = defaultdict(_scores_by_method)

In [None]:
# Evaluate every method on the results of every seed.
#
# The per-seed arrays are kept in local variables instead of being written
# into `adata`: overwriting adata.obsm / adata.obs here would leave the last
# seed's representations and perturbation probabilities in `adata`, while the
# UMAP coordinates stored earlier were computed from the seed-123 results.

# Start from an empty container so re-running this cell does not append a
# second set of scores to the existing lists.
metrics.clear()

pathway = adata.obs["pathway"]
phase = adata.obs["phase"]
# Cells labelled "Ctrl" are excluded from the pathway clustering metric.
is_non_control = (pathway != "Ctrl").to_numpy()

for seed in tqdm(seeds):
    gc.collect()

    mixscape_results_dir = f"../results/replogle_2022/mixscape/seed_{seed}/"
    contrastive_vi_plus_results_dir = f"../results/replogle_2022/contrastive_vi_plus/seed_{seed}/inference_marginalize/early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/mmd_penalty_1000.0"
    contrastive_vi_results_dir = (
        f"../results/replogle_2022/contrastive_vi/seed_{seed}/early_stopping_True/"
    )

    # Latent representations produced by this seed's runs.
    seed_latent_reps = {
        "mixscape": np.load(
            os.path.join(mixscape_results_dir, "salient_latent_rep.npy")
        ),
        "contrastive_vi_plus": np.load(
            os.path.join(contrastive_vi_plus_results_dir, "salient_latent_rep.npy")
        ),
        "contrastive_vi": np.load(
            os.path.join(contrastive_vi_results_dir, "salient_latent_rep.npy")
        ),
    }
    # Per-cell perturbation probabilities, flattened to one value per cell so
    # they can be combined with the boolean cell mask below.
    seed_pert_probs = {
        "mixscape": np.load(
            os.path.join(mixscape_results_dir, "pert_probs.npy")
        ).ravel(),
        "contrastive_vi_plus": np.load(
            os.path.join(contrastive_vi_plus_results_dir, "pert_probs.npy")
        ).ravel(),
    }

    for method in methods:
        # Methods without per-seed results (e.g. "pca") use the representation
        # already stored in adata.obsm.
        if method in seed_latent_reps:
            latent_rep = seed_latent_reps[method]
        else:
            latent_rep = adata.obsm[method]

        metrics["entropy_phase_mixing"][method].append(
            entropy_batch_mixing(latent_rep, phase)
        )

        # Non-control cells only; for methods that provide perturbation
        # probabilities, additionally keep only cells with probability > 0.5.
        keep = is_non_control
        if method in seed_pert_probs:
            keep = keep & (seed_pert_probs[method] > 0.5)
        pathway_labels = pathway[keep]

        # One cluster per pathway present among the kept cells. KMeans is
        # seeded with the run's seed so the score is reproducible.
        cluster_labels = KMeans(
            n_clusters=len(pathway_labels.unique()), random_state=seed
        ).fit_predict(latent_rep[keep])
        metrics["pathway_ari"][method].append(
            adjusted_rand_score(pathway_labels, cluster_labels)
        )

In [None]:
# 2 x 4 grid: UMAP panels in columns 0-1, a narrow empty spacer in column 2,
# and the metric bar chart (top) / perturbation-probability UMAP (bottom) in
# column 3.
fig, axes = plt.subplots(2, 4, figsize=(18, 10), dpi=200, width_ratios=[1, 1, 0.25, 1])

# --- UMAP panels -----------------------------------------------------------
# (axis, adata.obsm basis, adata.obs column used for coloring, method key)
# The perturbation-probability panel is drawn last on purpose: the color bar
# label further down is set on the last axis of the figure.
embedding_panels = [
    (axes[0][0], "pca_umap", "phase", "pca"),
    (axes[0][1], "pca_umap", "pathway", "pca"),
    (axes[1][0], "contrastive_vi_plus_umap", "phase", "contrastive_vi_plus"),
    (axes[1][1], "contrastive_vi_plus_umap", "pathway", "contrastive_vi_plus"),
    (
        axes[1][3],
        "contrastive_vi_plus_umap",
        "contrastive_vi_plus_pert_probs",
        "contrastive_vi_plus",
    ),
]

for panel_ax, basis, color_key, title_key in embedding_panels:
    sc.pl.embedding(
        adata,
        basis=basis,
        ax=panel_ax,
        show=False,
        color=color_key,
        title=method_formatted[title_key],
    )
    # Label only the panels that are actually drawn (not the hidden spacer).
    panel_ax.set_xlabel("UMAP1")
    panel_ax.set_ylabel("UMAP2")

# --- Grouped bar chart of the metrics (mean +/- std over seeds) -------------
bar_ax = axes[0][3]
metric_names = list(metrics.keys())
x = np.arange(len(metric_names))  # the label locations
width = 0.1  # the width of the bars

metrics_formatted = {
    "pathway_ari": "Pathway ARI",
    "entropy_phase_mixing": "Entropy of mixing",
}

palette = sns.color_palette("Dark2")
colors = {
    "pca": palette[-4],
    "mixscape": palette[-3],
    "contrastive_vi": palette[-2],
    "contrastive_vi_plus": palette[-1],
}

for bar_index, method in enumerate(methods):
    scores_per_metric = [metrics[metric][method] for metric in metric_names]
    bar_ax.bar(
        x + bar_index * width,
        [np.mean(scores) for scores in scores_per_metric],
        width,
        yerr=[np.std(scores) for scores in scores_per_metric],
        capsize=5,
        label=method_formatted[method],
        color=colors[method],
    )

# Center each tick label under its group of bars: the bars of a group sit at
# offsets 0 .. (len(methods) - 1) * width, so the middle is half of that span.
bar_ax.set_xticks(
    x + width * (len(methods) - 1) / 2,
    [metrics_formatted[metric] for metric in metric_names],
)
bar_ax.legend(frameon=False)
bar_ax.spines[["right", "top"]].set_visible(False)

# --- Layout ----------------------------------------------------------------
# Column 2 only provides horizontal spacing.
axes[0][2].set_visible(False)
axes[1][2].set_visible(False)

# NOTE(review): assumes the last axis of the figure is the color bar created
# by the final embedding call above -- confirm if the panel order changes.
fig.get_axes()[-1].set_ylabel("Probability of perturbation", rotation=270, labelpad=15)

plt.subplots_adjust(wspace=0.5)
plt.savefig("replogle_latent_space.pdf", bbox_inches="tight")