In [None]:
from anndata import read_h5ad
from ..data.datasets import available_datasets
from ..constants import DEFAULT_RANDOM_SEEDS as seeds
import gc
import numpy as np
import os
import scanpy as sc
from tqdm import tqdm
from collections import defaultdict
from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import KMeans
from .metrics import entropy_batch_mixing
import matplotlib.pyplot as plt
import seaborn as sns

# Turn on scanpy's private `_vector_friendly` plotting flag for the whole session.
# NOTE(review): presumably this rasterizes scatter points when a figure is saved
# to a vector format such as PDF -- confirm against the installed scanpy version,
# since this is a private attribute rather than public API.
sc._settings.settings._vector_friendly = True

: 

In [None]:
adata = read_h5ad(available_datasets["norman_2019"].local_path)

# Index genes by the "gene_name" column so that `adata.var_names` can be
# compared against the gene names in the cell cycle list below.
adata.var.set_index("gene_name", inplace=True)

# Annotate cells with cell cycle phases/scores
cell_cycle_genes_path = "regev_lab_cell_cycle_genes.txt"

if not os.path.exists(cell_cycle_genes_path):
    # Download the gene list once; later runs reuse the file in the working directory.
    wget_status = os.system(
        "wget https://raw.githubusercontent.com/scverse/scanpy_usage/master/180209_cell_cycle/data/regev_lab_cell_cycle_genes.txt"
    )
    # Fail here with a clear message instead of a FileNotFoundError further down.
    if wget_status != 0:
        raise RuntimeError(
            f"Downloading {cell_cycle_genes_path} with wget failed "
            f"(exit status {wget_status})."
        )

# Read through a context manager so the file handle is closed deterministically.
with open(cell_cycle_genes_path) as cell_cycle_genes_file:
    cell_cycle_genes = [x.strip() for x in cell_cycle_genes_file]

# The first 43 entries of the list are used as S-phase genes, the rest as G2/M genes.
s_genes = cell_cycle_genes[:43]
g2m_genes = cell_cycle_genes[43:]
# Only `cell_cycle_genes` is narrowed to genes present in this dataset; the two
# phase lists above were sliced beforehand and are passed to scanpy unfiltered.
cell_cycle_genes = [x for x in cell_cycle_genes if x in adata.var_names]

sc.tl.score_genes_cell_cycle(adata, s_genes=s_genes, g2m_genes=g2m_genes)

In [None]:
# Attach the saved outputs of a single run (seed 123) of each model to `adata`.
# NOTE(review): the per-seed metrics loop later in this notebook reads Mixscape
# results from "mixscape_pertpy/" and ContrastiveVI+ results from
# "contrastive_vi_plus_t/", whereas this cell reads "mixscape/" and
# "contrastive_vi_plus/" -- confirm that both sets of directories are intended.
seed = 123

# ContrastiveVI+: salient latent representation and perturbation probabilities.
contrastive_vi_plus_results_dir = f"../results/norman_2019/contrastive_vi_plus/seed_{seed}/inference_marginalize/early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/mmd_penalty_1000.0"
contrastive_vi_plus_latent_file = os.path.join(
    contrastive_vi_plus_results_dir, "salient_latent_rep.npy"
)
contrastive_vi_plus_probs_file = os.path.join(
    contrastive_vi_plus_results_dir, "pert_probs.npy"
)
adata.obsm["contrastive_vi_plus"] = np.load(contrastive_vi_plus_latent_file)
adata.obs["contrastive_vi_plus_pert_probs"] = np.load(contrastive_vi_plus_probs_file)

# ContrastiveVI: only a salient latent representation is loaded.
contrastive_vi_results_dir = (
    f"../results/norman_2019/contrastive_vi/seed_{seed}/early_stopping_True/"
)
contrastive_vi_latent_file = os.path.join(
    contrastive_vi_results_dir, "salient_latent_rep.npy"
)
adata.obsm["contrastive_vi"] = np.load(contrastive_vi_latent_file)

# Mixscape: representation and perturbation probabilities.
mixscape_results_dir = f"../results/norman_2019/mixscape/seed_{seed}"
mixscape_latent_file = os.path.join(mixscape_results_dir, "salient_latent_rep.npy")
mixscape_probs_file = os.path.join(mixscape_results_dir, "pert_probs.npy")
adata.obsm["mixscape"] = np.load(mixscape_latent_file)
adata.obs["mixscape_pert_probs"] = np.load(mixscape_probs_file)

In [None]:
# Representations to compare; each name is also the `adata.obsm` key that holds it.
methods = ["pca", "mixscape", "contrastive_vi", "contrastive_vi_plus"]

# PCA baseline: run scanpy's PCA and expose its "X_pca" result under the method name.
sc.pp.pca(adata)
adata.obsm["pca"] = adata.obsm["X_pca"]

for method in tqdm(methods):
    # Build the neighbor graph on this representation, then embed it with UMAP.
    sc.pp.neighbors(adata, use_rep=method)
    sc.tl.umap(adata)

    # "X_umap" is read again on every iteration, so keep a per-method copy.
    umap_key = f"{method}_umap"
    adata.obsm[umap_key] = adata.obsm["X_umap"].copy()

In [None]:
# Display names used for plot labels, keyed by the internal method identifier.
method_formatted = dict(
    [
        ("pca", "PCA"),
        ("mixscape", "Mixscape"),
        ("contrastive_vi", "ContrastiveVI"),
        ("contrastive_vi_plus", "ContrastiveVI+"),
    ]
)

# Method identifiers in plotting order (dicts preserve insertion order).
methods = list(method_formatted)

In [None]:
# metrics[metric_name][method] -> list holding one score per random seed.
metrics = defaultdict(lambda: defaultdict(list))

# NOTE(review): this loop overwrites `seed` and the "mixscape",
# "contrastive_vi_plus" and "contrastive_vi" entries of `adata.obsm` / `adata.obs`.
# After it finishes they hold the results of the LAST seed in `seeds`, while the
# "<method>_umap" embeddings computed earlier are not refreshed. Later cells that
# read these entries therefore see last-seed values -- confirm this is intended.
for seed in tqdm(seeds):
    gc.collect()
    mixscape_results_dir = f"../results/norman_2019/mixscape_pertpy/seed_{seed}/"
    adata.obsm["mixscape"] = np.load(
        os.path.join(mixscape_results_dir, "salient_latent_rep.npy")
    )
    adata.obs["mixscape_pert_probs"] = np.load(
        os.path.join(mixscape_results_dir, "pert_probs.npy")
    )

    contrastive_vi_plus_results_dir = f"../results/norman_2019/contrastive_vi_plus_t/seed_{seed}/inference_marginalize/early_stopping_True/learn_basal_mean_True/n_classifier_layers_3/mmd_penalty_1000.0"
    adata.obsm["contrastive_vi_plus"] = np.load(
        os.path.join(contrastive_vi_plus_results_dir, "salient_latent_rep.npy")
    )
    adata.obs["contrastive_vi_plus_pert_probs"] = np.load(
        os.path.join(contrastive_vi_plus_results_dir, "pert_probs.npy")
    )

    contrastive_vi_results_dir = (
        f"../results/norman_2019/contrastive_vi/seed_{seed}/early_stopping_True/"
    )
    adata.obsm["contrastive_vi"] = np.load(
        os.path.join(contrastive_vi_results_dir, "salient_latent_rep.npy")
    )

    # Control cells are left out of the clustering metric. This subset does not
    # depend on the method, so it is built once per seed (after the loads above).
    adata_non_control = adata[adata.obs["gene_program"] != "Ctrl"]

    for method in methods:
        # Mixing of cell cycle phases in each representation, over all cells.
        metrics["entropy_phase_mixing"][method].append(
            entropy_batch_mixing(adata.obsm[method], adata.obs["phase"])
        )

        adata_ = adata_non_control
        if method in ["contrastive_vi_plus", "mixscape"]:
            # These two methods provide perturbation probabilities; keep only
            # cells whose probability exceeds 0.5.
            adata_ = adata_[adata_.obs[f"{method}_pert_probs"] > 0.5]

        # One cluster per gene program that remains after filtering.
        n_programs = len(adata_.obs["gene_program"].unique())
        # Seed KMeans so the ARI values are reproducible between notebook runs.
        # Assumes each entry of `seeds` is an integer -- TODO confirm.
        cluster_labels = KMeans(n_programs, random_state=seed).fit_predict(
            adata_.obsm[method]
        )
        metrics["gene_program_ari"][method].append(
            adjusted_rand_score(adata_.obs["gene_program"], cluster_labels)
        )

In [None]:
def reformat_pert(x):
    """Normalize a perturbation label so that "ctrl" never comes first in a pair.

    Args:
        x: Perturbation label, e.g. "ctrl", "GENE+ctrl", "ctrl+GENE" or "A+B".

    Returns:
        "GENE+ctrl" when ``x`` is "ctrl+GENE"; otherwise ``x`` unchanged. Labels
        that do not consist of exactly two "+"-separated parts (including the
        bare "ctrl" label) are returned as-is rather than raising.
    """
    parts = x.split("+")
    # Only two-part labels can need reordering; anything else is passed through.
    if len(parts) != 2:
        return x
    g1, g2 = parts
    return f"{g2}+ctrl" if g1 == "ctrl" else x

In [None]:
# Store a normalized copy of the "guide_merged" labels, one value per row of `adata.obs`.
adata.obs["guide_merged_fixed"] = list(map(reformat_pert, adata.obs["guide_merged"]))

In [None]:
program = "Granulocyte/apoptosis"

# Keep cells annotated with this gene program whose ContrastiveVI+ perturbation
# probability exceeds 0.5.
in_program = adata.obs["gene_program"].isin([program])
predicted_perturbed = adata.obs["contrastive_vi_plus_pert_probs"] > 0.5

# Take an explicit copy: the subset is modified in place below (neighbors/UMAP)
# and by later cells, which should not happen on a view of `adata`.
adata_program = adata[in_program & predicted_perturbed].copy()

# Neighbor graph and UMAP of the subset, computed on the ContrastiveVI+ representation.
sc.pp.neighbors(adata_program, use_rep="contrastive_vi_plus")
sc.tl.umap(adata_program)

In [None]:
import hotspot

# Drop genes detected in no cell of the subset (necessary for Hotspot).
sc.pp.filter_genes(adata_program, min_cells=1)

# Hotspot reads expression from the "counts" layer and uses the ContrastiveVI+
# representation stored in `adata_program.obsm` as its latent space.
hotspot_options = {
    "layer_key": "counts",
    "model": "danb",
    "latent_obsm_key": "contrastive_vi_plus",
}
hs = hotspot.Hotspot(adata_program, **hotspot_options)

In [None]:
# Unweighted 30-nearest-neighbor graph, then per-gene autocorrelation statistics.
hs.create_knn_graph(weighted_graph=False, n_neighbors=30)
hs_results = hs.compute_autocorrelations()

# Display the 20 rows with the largest "C" value.
hs_results_by_c = hs_results.sort_values(by="C", ascending=False)
hs_results_by_c.head(20)

In [None]:
import matplotlib.gridspec as gridspec

# Figure layout: 2 rows x 5 columns. Columns 1 and 3 are narrow; only the
# bottom-middle sub-grid extends into column 3.
fig = plt.figure(figsize=(17.5, 10), dpi=200)
gs = gridspec.GridSpec(2, 5, width_ratios=[1, 0.05, 1, 0.15, 1])

# --- Top row: ContrastiveVI+ UMAP of all cells, one panel per annotation ---------
top_row_panels = [
    (0, "phase", "Cell cycle phase"),
    (2, "gene_program", "Gene program"),
    (4, "contrastive_vi_plus_pert_probs", "Probability of perturbation"),
]
top_row_axes = []
for grid_column, obs_key, panel_title in top_row_panels:
    panel_ax = fig.add_subplot(gs[0, grid_column])
    sc.pl.embedding(
        adata,
        basis="contrastive_vi_plus_umap",
        ax=panel_ax,
        show=False,
        color=obs_key,
        title=panel_title,
    )
    top_row_axes.append(panel_ax)

for panel_ax in top_row_axes:
    panel_ax.set_xlabel("UMAP1")
    panel_ax.set_ylabel("UMAP2")

# --- Bottom left: grouped bars, mean +/- std over seeds for every metric ---------
metric_names = list(metrics.keys())
x = np.arange(len(metric_names))  # the label locations
width = 0.1  # the width of the bars

metrics_formatted = {
    "gene_program_ari": "Gene program ARI",
    "entropy_phase_mixing": "Entropy of mixing",
}

# The last four colors of the "Dark2" palette, one per method.
dark2_palette = sns.color_palette("Dark2")
colors = {
    "pca": dark2_palette[-4],
    "mixscape": dark2_palette[-3],
    "contrastive_vi": dark2_palette[-2],
    "contrastive_vi_plus": dark2_palette[-1],
}

metrics_ax = fig.add_subplot(gs[1, 0])
for bar_position, method in enumerate(methods):
    per_metric_scores = [metrics[name][method] for name in metric_names]
    metrics_ax.bar(
        x + width * bar_position,  # shift each method's bars sideways
        [np.mean(scores) for scores in per_metric_scores],
        width,
        yerr=[np.std(scores) for scores in per_metric_scores],
        capsize=5,
        label=method_formatted[method],
        color=colors[method],
    )

# Place the tick labels under the middle of each group of four bars.
metrics_ax.set_xticks(
    x + 1.5 * width, [metrics_formatted[name] for name in metric_names]
)
metrics_ax.legend(
    ncol=1,
    frameon=False,
    bbox_to_anchor=(1.0, 0.5),
    loc="center left",
)
sns.despine(ax=metrics_ax)

# --- Bottom right: 2x2 grid of gene-expression UMAPs for the program subset ------
gs00 = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, 4])
right_gene_panels = [
    ((0, 0), "BTG1"),
    ((0, 1), "BTG2"),
    ((1, 0), "BIRC5"),
    ((1, 1), "YBX1"),
]
right_gene_axes = []
for (grid_row, grid_column), gene in right_gene_panels:
    gene_ax = fig.add_subplot(gs00[grid_row, grid_column])
    sc.pl.umap(
        adata_program,
        color=gene,
        ax=gene_ax,
        show=False,
        cmap="magma",
        colorbar_loc=None,
    )
    right_gene_axes.append(gene_ax)

plt.subplots_adjust(wspace=0.5)

# --- Bottom middle: perturbation labels on top, three gene panels underneath -----
gs1 = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[1, 2:4])
perturbation_ax = fig.add_subplot(gs1[0, 0])
sc.pl.umap(
    adata_program,
    ax=perturbation_ax,
    show=False,
    color="guide_merged_fixed",
    title="Perturbation",
)

middle_gene_axes = []
for grid_column, gene in enumerate(["LST1", "CSF3R", "ITGAM"]):
    gene_ax = fig.add_subplot(gs1[1, grid_column])
    sc.pl.umap(
        adata_program,
        color=gene,
        ax=gene_ax,
        show=False,
        cmap="magma",
        colorbar_loc=None,
    )
    middle_gene_axes.append(gene_ax)

# The small subset panels carry no axis labels.
for small_ax in [*right_gene_axes, perturbation_ax, *middle_gene_axes]:
    small_ax.set_xlabel("")
    small_ax.set_ylabel("")

# One colorbar for the 2x2 gene grid. Its color mapping is taken from the last
# panel of that grid (YBX1); numeric ticks are hidden and replaced by High/Low.
cbar = fig.colorbar(
    right_gene_axes[-1].collections[0],
    ax=right_gene_axes,
    fraction=0.04,
    aspect=30,
)
cbar.ax.set_yticks([])
cbar.ax.set_title("High")
cbar.ax.set_ylabel("Log normalized expression", rotation=270, labelpad=15)
cbar.ax.set_xlabel("Low")

plt.savefig("norman.pdf", bbox_inches="tight")