In [None]:
# type: ignore
import os
from collections import defaultdict

import numpy as np
import scanpy as sc
from tqdm import tqdm

from anndata import read_h5ad
from ..data.datasets import available_datasets
from sklearn.metrics.pairwise import rbf_kernel

# Load the AnnData object registered as "replogle_2022" from its local path.
replogle_dataset = available_datasets["replogle_2022"]
adata = read_h5ad(replogle_dataset.local_path)

In [None]:
# Re-index genes by the `gene_name` column. Note this mutates the frame held by
# adata.var in place; the following line then routes the index back through
# AnnData's `var_names` setter, cast to str.
adata.var.set_index("gene_name", inplace=True)
adata.var_names = adata.var_names.astype(str)
# Gene names may repeat; AnnData requires unique var names, so duplicates get suffixed.
adata.var_names_make_unique()

In [None]:
seed = 123

# Root directory holding each method's saved outputs for this dataset.
replogle_results_root = os.path.join("..", "results", "replogle_2022")


def load_method_outputs(adata, method, results_dir):
    """Attach one method's saved outputs to ``adata`` under method-prefixed keys.

    Reads ``salient_latent_rep.npy`` into ``adata.obsm[method]`` and
    ``pert_probs.npy`` into ``adata.obs[f"{method}_pert_probs"]``. Both files
    are assumed to be row-aligned with ``adata.obs`` (AnnData validates the
    length on assignment).

    Parameters
    ----------
    adata
        AnnData object to modify in place.
    method
        Key prefix for the loaded arrays, e.g. ``"mixscape"``.
    results_dir
        Directory containing the two ``.npy`` files.
    """
    adata.obsm[method] = np.load(os.path.join(results_dir, "salient_latent_rep.npy"))
    adata.obs[f"{method}_pert_probs"] = np.load(
        os.path.join(results_dir, "pert_probs.npy")
    )


mixscape_results_dir = os.path.join(replogle_results_root, "mixscape", f"seed_{seed}")
load_method_outputs(adata, "mixscape", mixscape_results_dir)

contrastive_vi_plus_results_dir = os.path.join(
    replogle_results_root,
    "contrastive_vi_plus",
    f"seed_{seed}",
    # Hyperparameter configuration of the ContrastiveVI+ run being evaluated.
    "inference_marginalize",
    "early_stopping_True",
    "learn_basal_mean_True",
    "n_classifier_layers_3",
    "mmd_penalty_1000.0",
)
load_method_outputs(adata, "contrastive_vi_plus", contrastive_vi_plus_results_dir)

# NOTE(review): the MMD comparison below operates on adata.X, not on the X_pca
# computed here; confirm whether PCA is still needed.
sc.pp.pca(adata)

In [None]:
# Using pertpy's implementation of MMD here (its kernel="rbf" variant).
def metric(X, Y, gamma=1.0):
    """Return the squared maximum mean discrepancy (MMD) between two samples.

    Uses an RBF kernel ``k(a, b) = exp(-gamma * ||a - b||^2)`` and the biased
    (V-statistic) estimator: kernel means are taken over the full matrices,
    including the diagonal terms of ``k(X, X)`` and ``k(Y, Y)``.

    Parameters
    ----------
    X, Y
        Sample matrices with observations in rows and the same feature columns;
        anything ``sklearn.metrics.pairwise.rbf_kernel`` accepts (dense or
        sparse). Both must be non-empty.
    gamma
        RBF bandwidth parameter. Defaults to 1.0, the value this notebook has
        always used.

    Returns
    -------
    float
        ``mean k(X, X) + mean k(Y, Y) - 2 * mean k(X, Y)``.

    Notes
    -----
    The full pairwise kernel matrices are materialized, so memory grows with
    ``len(X) ** 2``, ``len(Y) ** 2`` and ``len(X) * len(Y)``.
    """
    XX = rbf_kernel(X, X, gamma=gamma)
    YY = rbf_kernel(Y, Y, gamma=gamma)
    XY = rbf_kernel(X, Y, gamma=gamma)
    return XX.mean() + YY.mean() - 2 * XY.mean()


# For each perturbation, compare how each method's split of its cells relates to
# the non-targeting controls:
#   * energy_distance_diffs[method]: MMD(cells called perturbed, controls)
#     minus MMD(all cells of the perturbation, controls);
#   * energy_distances_np[method]: MMD(cells called escaping, controls).
# Despite the `energy_distance*` names, every value is the squared RBF-MMD
# returned by `metric`.
METHODS = ["contrastive_vi_plus", "mixscape"]
# Cells with pert prob >= threshold count as perturbed, < threshold as escaping.
PERT_PROB_THRESHOLD = 0.5

energy_distance_diffs = defaultdict(list)
# NOTE(review): never populated in this notebook; kept only so the name still exists.
energy_distances_pert = defaultdict(list)
energy_distances_np = defaultdict(list)

# The control cells are the same for every perturbation, so select them once.
# NOTE(review): each `metric` call recomputes the controls-vs-controls kernel;
# with many control cells this dominates runtime and could be cached.
adata_nt = adata[adata.obs["gene"] == "non-targeting"]

perts = adata.obs["gene"].unique()

for pert in tqdm(perts):
    if pert == "non-targeting":
        continue

    adata_pert = adata[adata.obs["gene"] == pert]

    # Both masks come from explicit comparisons (not `~`), so a NaN probability
    # falls into neither group, matching the original checks.
    perturbed_masks = {
        method: (
            adata_pert.obs[f"{method}_pert_probs"] >= PERT_PROB_THRESHOLD
        ).to_numpy()
        for method in METHODS
    }
    escaping_masks = {
        method: (
            adata_pert.obs[f"{method}_pert_probs"] < PERT_PROB_THRESHOLD
        ).to_numpy()
        for method in METHODS
    }

    # Only keep perturbations where every method leaves both groups non-empty;
    # otherwise one of the MMDs below would be computed on zero cells.
    if not all(
        perturbed_masks[method].any() and escaping_masks[method].any()
        for method in METHODS
    ):
        continue

    energy_distance_all = metric(adata_pert.X, adata_nt.X)

    for method in METHODS:
        energy_distance_method = metric(
            adata_pert[perturbed_masks[method]].X,
            adata_nt.X,
        )
        energy_distance_diffs[method].append(
            energy_distance_method - energy_distance_all
        )
        energy_distances_np[method].append(
            metric(
                adata_pert[escaping_masks[method]].X,
                adata_nt.X,
            )
        )

# Convert to arrays so the cells below can compare methods element-wise
# (one entry per retained perturbation, aligned across methods).
for method in METHODS:
    energy_distance_diffs[method] = np.array(energy_distance_diffs[method])
    energy_distances_np[method] = np.array(energy_distances_np[method])

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import seaborn as sns

def plot_method_comparison(ax, values, title):
    """Scatter ContrastiveVI+ against Mixscape per perturbation, coloured by density.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on.
    values
        Mapping with keys ``"mixscape"`` and ``"contrastive_vi_plus"``, each a
        1-D array aligned across methods (one entry per perturbation).
    title
        Axes title.

    Points are coloured by a 2-D Gaussian KDE evaluated at each point. Both axes
    share the same limits and a y = x reference line is drawn, so points above
    the line are perturbations where ContrastiveVI+ has the larger value.
    """
    x = values["mixscape"]
    y = values["contrastive_vi_plus"]
    xy = np.vstack([x, y])
    density = gaussian_kde(xy)(xy)
    ax.scatter(x=x, y=y, c=density, s=5)

    # Square limits covering both autoscaled axes, so the diagonal is exactly y = x.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    ax.plot(lims, lims, "k-", alpha=0.75, zorder=1)
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect("equal")
    ax.set_title(title)
    ax.set_xlabel("Mixscape")
    ax.set_ylabel("ContrastiveVI+")
    return ax


fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=200)
plot_method_comparison(axes[0], energy_distance_diffs, "Change in MMD after filtering")
plot_method_comparison(axes[1], energy_distances_np, "MMD between escaping and controls")

sns.despine()
fig.savefig("replogle_mmd.pdf")

In [None]:
from scipy.stats import binomtest

In [None]:
# Sign test: across perturbations, does filtering with ContrastiveVI+ increase the
# MMD to controls more often than filtering with Mixscape? Exact ties (e.g. both
# methods keep the same cells) carry no information about which method is better.
# They are dropped from n, as is standard for a sign test, rather than being
# counted as ContrastiveVI+ losses.
mmd_change_cvi = energy_distance_diffs["contrastive_vi_plus"]
mmd_change_mixscape = energy_distance_diffs["mixscape"]
binomtest(
    n=int(np.count_nonzero(mmd_change_cvi != mmd_change_mixscape)),
    k=int(np.count_nonzero(mmd_change_cvi > mmd_change_mixscape)),
    p=0.5,
)

In [None]:
# Sign test: across perturbations, are ContrastiveVI+'s escaping cells closer
# (lower MMD) to controls more often than Mixscape's? Exact ties carry no
# information about which method is better. They are dropped from n, as is
# standard for a sign test, rather than being counted as ContrastiveVI+ losses.
escaping_mmd_cvi = energy_distances_np["contrastive_vi_plus"]
escaping_mmd_mixscape = energy_distances_np["mixscape"]
binomtest(
    n=int(np.count_nonzero(escaping_mmd_cvi != escaping_mmd_mixscape)),
    k=int(np.count_nonzero(escaping_mmd_cvi < escaping_mmd_mixscape)),
    p=0.5,
)