In [None]:
import sys
import os
import collections
import numpy as np
import pandas as pd
import sklearn.neighbors
import sklearn.metrics
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import Cell_BLAST as cb
import scvi.dataset
import scvi.models
import scvi.inference
import exputils

In [None]:
# Pick a GPU (presumably the one with the lowest memory use); set before any CUDA work so it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()
cb.config.RANDOM_SEED = 0

# Export figure text as text (not paths), set in Arial.
for _rc_key, _rc_value in (("svg.fonttype", "none"), ("font.family", "Arial")):
    plt.rcParams[_rc_key] = _rc_value

N_POSTERIOR = 50   # posterior samples drawn per cell
N_MODELS = 16      # models (random seeds) trained per hyperparameter value
N_QUERIES = 1000   # query cells sampled for the nearest-neighbor analysis
N_NEIGHBORS = 50   # reference neighbors retrieved per sampled query cell

# Every intermediate file, table and figure of this notebook lives here.
PATH = "./distance_comparison/"
os.makedirs(PATH, exist_ok=True)

## Read data

### Reference

In [None]:
# Reference = Baron (human) + Xin + Lawlor, merged on the union of their Seurat genes.
_ref_paths = dict(
    baron="../../Datasets/data/Baron_human/data.h5",
    xin="../../Datasets/data/Xin_2016/data.h5",
    lawlor="../../Datasets/data/Lawlor/data.h5",
)
# Both Cell BLAST and scVI get normalized data: without it scVI does not generalize.
ref = cb.data.ExprDataSet.merge_datasets(
    {name: cb.data.ExprDataSet.read_dataset(path) for name, path in _ref_paths.items()},
    merge_uns_slots=["seurat_genes"],
).normalize()
# Per-cell total over ALL genes of the normalized matrix, recorded before the gene subset.
ref.obs["__libsize__"] = np.array(ref.exprs.sum(axis=1)).ravel()
ref = exputils.clean_dataset(ref[:, ref.uns["seurat_genes"]], "cell_ontology_class")

In [None]:
# scVI consumes the reference through an h5ad copy written under PATH.
_ref_h5ad = "ref.h5ad"
ref.to_anndata().write_h5ad(os.path.join(PATH, _ref_h5ad))
ref_scvi = scvi.dataset.AnnDataset(_ref_h5ad, save_path=PATH)

# One integer batch code per cell, derived from the source dataset name (column vector).
_batch_codes = cb.utils.encode_integer(ref.obs["dataset_name"])[0]
ref_scvi.batch_indices = _batch_codes.reshape((-1, 1))
n_batch = np.unique(ref_scvi.batch_indices).size

### Query

In [None]:
_query_paths = dict(
    # Positive
    segerstolpe="../../Datasets/data/Segerstolpe/data.h5",
    muraro="../../Datasets/data/Muraro/data.h5",
    enge="../../Datasets/data/Enge/data.h5",
    # Negative
    wu_human="../../Datasets/data/Wu_human/data.h5",
    zheng="../../Datasets/data/Zheng/data.h5",
    philippeos="../../Datasets/data/Philippeos/data.h5",
)
query = {name: cb.data.ExprDataSet.read_dataset(path) for name, path in _query_paths.items()}

# Same preprocessing as the reference: normalize (scVI does not generalize otherwise),
# record the full-gene total, then align genes to the reference.
for name, dataset in query.items():
    dataset = dataset.normalize()
    dataset.obs["__libsize__"] = np.array(dataset.exprs.sum(axis=1)).ravel()
    query[name] = dataset[:, ref.var_names]

In [None]:
random_state = np.random.RandomState(0)
# Per-dataset cap: the size of the smallest query dataset, but never below 2000 cells.
min_size = max(min(val.shape[0] for val in query.values()), 2000)

_subsampled = {}
for key, val in query.items():
    _n_keep = min(min_size, val.shape[0])
    _rows = random_state.choice(val.shape[0], _n_keep, replace=False)
    _subsampled[key] = val[_rows, :]

query = exputils.clean_dataset(
    cb.data.ExprDataSet.merge_datasets(_subsampled), "cell_ontology_class"
)

In [None]:
# scVI consumes the query through an h5ad copy restricted to the reference's Seurat genes.
# NOTE(review): unlike ref_scvi, no batch_indices are assigned here — confirm scVI's
# default batch assignment is acceptable for query inference.
_query_h5ad = "query.h5ad"
_query_seurat = query[:, ref.uns["seurat_genes"]]
_query_seurat.to_anndata().write_h5ad(os.path.join(PATH, _query_h5ad))
query_scvi = scvi.dataset.AnnDataset(_query_h5ad, save_path=PATH)

## Model training

### Cell BLAST

In [None]:
# Cell BLAST latent regularization strengths (passed as lambda_reg) to sweep.
regs = (0.0001, 0.001, 0.01, 0.1, 1, 10)

In [None]:
# N_MODELS seeds per regularization strength, keyed by reg.
cb_models = collections.defaultdict(list)
for reg in regs:
    for seed in range(N_MODELS):
        print(f"==== Cell BLAST model {seed} with reg = {reg} ====")
        _directi = cb.directi.fit_DIRECTi(
            ref,
            ref.uns["seurat_genes"],
            batch_effect="dataset_name",
            latent_dim=10,
            cat_dim=20,
            latent_module_kwargs={"lambda_reg": reg},
            random_seed=seed,
        )
        cb_models[reg].append(_directi)

### scVI

In [None]:
# scVI KL weights to sweep; "dynamic" is turned into kl=None for the trainer.
KLs = ("dynamic", 0.01, 0.1, 1, 10, 50)

In [None]:
# N_MODELS seeds per KL setting, keyed by the entry of KLs (including "dynamic").
scvi_models = collections.defaultdict(list)
for kl in KLs:
    # "dynamic" is passed to the trainer as kl=None (presumably scVI's own KL schedule — confirm).
    _kl_weight = None if kl == "dynamic" else kl
    for seed in range(N_MODELS):
        print(f"==== scVI model {seed} with KL = {kl} ====")
        np.random.seed(seed)
        torch.manual_seed(seed)
        vae = scvi.models.VAE(ref_scvi.nb_genes, n_latent=10, n_batch=n_batch)
        # Stop once the monitored log-likelihood has not improved for 30 checks.
        _stopping = {
            "early_stopping_metric": "ll",
            "save_best_state_metric": "ll",
            "patience": 30,
            "threshold": 0,
        }
        trainer = scvi.inference.annotation.UnsupervisedTrainer(
            vae,
            ref_scvi,
            kl=_kl_weight,
            use_cuda=True,
            metrics_to_monitor=["ll"],
            frequency=5,
            early_stopping_kwargs=_stopping,
        )
        trainer.train(n_epochs=1000)
        scvi_models[kl].append(vae)

## Distance metric ROC

### Preparation

In [None]:
def get_cb_latent_and_posterior(model):
    """Embed the global ``ref`` and ``query`` datasets with a Cell BLAST model.

    Returns ``(ref_latent, query_latent, ref_posterior, query_posterior)``:
    the results of ``model.inference(ds)`` followed by those of
    ``model.inference(ds, n_posterior=N_POSTERIOR, progress_bar=True)``.
    The four calls are made in exactly this order.
    """
    ref_latent = model.inference(ref)
    query_latent = model.inference(query)
    ref_posterior = model.inference(ref, n_posterior=N_POSTERIOR, progress_bar=True)
    query_posterior = model.inference(query, n_posterior=N_POSTERIOR, progress_bar=True)
    return ref_latent, query_latent, ref_posterior, query_posterior


def get_scvi_latent_and_posterior(model):
    """Embed ``ref_scvi`` and ``query_scvi`` with a trained scVI VAE.

    Returns ``(ref_latent, query_latent, ref_posterior, query_posterior)`` as
    float32 arrays. Latents are the trainer's ``"latent"`` values. Posteriors
    stack ``N_POSTERIOR`` draws per cell from a diagonal Gaussian built from
    ``"latent"`` (mean) and ``"latent_var"``. Each dataset uses its own
    ``RandomState(0)``.
    """
    def _embed(vae, dataset):
        # Returns (means, samples) for one dataset; samples has one block of draws per cell.
        trainer = scvi.inference.annotation.UnsupervisedTrainer(vae, dataset)
        values = trainer.get_all_latent_and_imputed_values()
        means = values["latent"]
        # NOTE(review): assumes "latent_var" is a variance (not a log-variance) — confirm.
        variances = values["latent_var"]
        rng = np.random.RandomState(0)
        draws = []
        for mean, var in zip(means, variances):
            draws.append(rng.multivariate_normal(mean, np.diag(var), size=N_POSTERIOR))
        return means.astype(np.float32), np.stack(draws, axis=0).astype(np.float32)

    ref_latent, ref_posterior = _embed(model, ref_scvi)
    query_latent, query_posterior = _embed(model, query_scvi)
    return ref_latent, query_latent, ref_posterior, query_posterior


def get_nn_idx(ref_latent, query_latent, ref_label, query_label,
               n_queries=None, n_neighbors=None):
    """Pair randomly sampled query cells with their nearest reference cells.

    Parameters
    ----------
    ref_latent, query_latent : ndarray, shape (n_cells, n_dims)
        Embeddings; nearest neighbors are searched among ``ref_latent`` rows.
    ref_label, query_label : array-like or pandas.Series
        Per-cell labels aligned POSITIONALLY with the latent rows.
    n_queries : int, optional
        Query cells to sample (default ``N_QUERIES``). Capped at the number of
        available query cells.
    n_neighbors : int, optional
        Reference neighbors per sampled query cell (default ``N_NEIGHBORS``).

    Returns
    -------
    ref_idx, query_idx : ndarray
        Flattened, aligned row indices of every (reference, query) pair. Each
        sampled query index is repeated ``n_neighbors`` times.
    correctness : ndarray of bool
        Whether the two cells of each pair carry the same label.
    """
    n_queries = N_QUERIES if n_queries is None else n_queries
    n_neighbors = N_NEIGHBORS if n_neighbors is None else n_neighbors

    random_state = np.random.RandomState(0)
    nn = sklearn.neighbors.NearestNeighbors().fit(ref_latent)
    # Sampling without replacement raises if size > population; cap it instead.
    n_available = query_latent.shape[0]
    sampled = random_state.choice(n_available, size=min(n_queries, n_available), replace=False)
    ref_idx = nn.kneighbors(query_latent[sampled], n_neighbors=n_neighbors)[1].ravel()
    query_idx = np.repeat(sampled, n_neighbors)
    # Convert to plain arrays before indexing: Series[int_array] is label-based on an
    # integer index and only a deprecated positional fallback otherwise.
    correctness = np.asarray(ref_label)[ref_idx] == np.asarray(query_label)[query_idx]
    return ref_idx, query_idx, correctness


def compute_distances(ref_latent, query_latent, ref_posterior, query_posterior):
    """Compute Euclidean and posterior distances for aligned cell pairs.

    Row ``i`` of every argument belongs to the i-th (reference, query) pair.
    ``*_latent`` are 2-D (n_pairs, n_dims). ``*_posterior`` hold the posterior
    samples per pair (presumably (n_pairs, n_samples, n_dims) — confirm).

    Returns
    -------
    edist : ndarray, shape (n_pairs,)
        Euclidean distance between the paired point embeddings.
    pdist : ndarray, shape (n_pairs,)
        ``cb.blast.npd_v1`` of each pair, from the point embeddings and
        posterior samples.
    """
    # The pair count comes from the arguments themselves, not from a global index array.
    n_pairs = len(ref_latent)
    edist = np.sqrt(np.square(ref_latent - query_latent).sum(axis=1))
    pdist = np.array([
        cb.blast.npd_v1(
            query_latent[i], ref_latent[i],
            query_posterior[i], ref_posterior[i]
        )
        for i in cb.utils.smart_tqdm()(range(n_pairs))
    ])
    return edist, pdist


def distance_pair_plot(edist, pdist, correctness):
    """Joint plot of Euclidean vs. posterior distance, split by correctness.

    For each correctness group it draws marginal KDEs and joint KDE contours.
    It then overlays a shuffled scatter of all pairs, colored by correctness,
    on the joint axes. Returns the joint axes of the ``JointGrid``.
    """
    x_col, y_col, hue_col = "Euclidean distance", "Posterior distance", "Correctness"
    pairs = pd.DataFrame({x_col: edist, y_col: pdist, hue_col: correctness})

    grid = sns.JointGrid(x=x_col, y=y_col, data=pairs)
    for _, group in pairs.groupby(hue_col):
        xs, ys = group[x_col], group[y_col]
        sns.kdeplot(xs, ax=grid.ax_marg_x, legend=False, shade=True)
        sns.kdeplot(ys, ax=grid.ax_marg_y, vertical=True, legend=False, shade=True)
        sns.kdeplot(xs, ys, n_levels=10, ax=grid.ax_joint)

    # Rows are shuffled so neither class is systematically drawn on top of the other.
    shuffled = pairs.sample(frac=1, random_state=0)
    joint_ax = sns.scatterplot(
        x=x_col, y=y_col, hue=hue_col, data=shuffled,
        s=5, edgecolor=None, alpha=0.5, rasterized=True, ax=grid.ax_joint,
    )
    grid.ax_joint.legend(frameon=False)
    return joint_ax

In [None]:
# Long-format accumulator: each key collects one array chunk per (model, metric) pass.
df = {column: [] for column in ("distance", "correctness", "model", "metric", "seed")}

### Cell BLAST

In [None]:
# For every Cell BLAST model: sample query cells, pair them with their reference
# neighbors, and record Euclidean (EuD) and posterior (NPD) distances plus whether
# each pair's cell ontology classes agree.
for reg, models in cb_models.items():
    for i, model in enumerate(models):
        print(f"Dealing with model {i} of reg = {reg}...")
        ref_latent, query_latent, ref_posterior, query_posterior = get_cb_latent_and_posterior(model)
        ref_idx, query_idx, correctness = get_nn_idx(
            ref_latent, query_latent,
            ref.obs["cell_ontology_class"], query.obs["cell_ontology_class"]
        )
        edist, pdist = compute_distances(
            ref_latent[ref_idx], query_latent[query_idx],
            ref_posterior[ref_idx], query_posterior[query_idx]
        )
        # Export the pair plot for one representative model only: the first seed of
        # reg = 1e-3, the setting highlighted in the AUC figure. Testing i == 0 alone
        # rewrote the PDF once per reg and left the last (reg = 10) plot on disk.
        if reg == 1e-3 and i == 0:
            ax = distance_pair_plot(edist, pdist, correctness)
            ax.get_figure().savefig(os.path.join(PATH, "cb_distance_cmp.pdf"), dpi=300, bbox_inches="tight")

        model_name = f"Cell BLAST (reg = {reg})"
        for metric, distance in (("EuD", edist), ("NPD", pdist)):
            df["distance"].append(distance)
            df["correctness"].append(correctness)
            df["model"].append(np.repeat(model_name, distance.size))
            df["metric"].append(np.repeat(metric, distance.size))
            df["seed"].append(np.repeat(i, distance.size))

### scVI

In [None]:
# For every scVI model: sample query cells, pair them with their reference
# neighbors, and record Euclidean (EuD) and posterior (NPD) distances plus whether
# each pair's cell ontology classes agree.
for kl, models in scvi_models.items():
    for i, model in enumerate(models):
        print(f"Dealing with model {i} of KL = {kl}...")
        ref_latent, query_latent, ref_posterior, query_posterior = get_scvi_latent_and_posterior(model)
        ref_idx, query_idx, correctness = get_nn_idx(
            ref_latent, query_latent,
            ref.obs["cell_ontology_class"], query.obs["cell_ontology_class"]
        )
        edist, pdist = compute_distances(
            ref_latent[ref_idx], query_latent[query_idx],
            ref_posterior[ref_idx], query_posterior[query_idx]
        )
        # KLs spells scVI's default setting as the string "dynamic" (never None),
        # so that is the key to match for the representative pair plot.
        if kl == "dynamic" and i == 0:
            ax = distance_pair_plot(edist, pdist, correctness)
            ax.get_figure().savefig(os.path.join(PATH, "scvi_distance_cmp.pdf"), dpi=300, bbox_inches="tight")

        model_name = f"scVI (KL = {kl})"
        for metric, distance in (("EuD", edist), ("NPD", pdist)):
            df["distance"].append(distance)
            df["correctness"].append(correctness)
            df["model"].append(np.repeat(model_name, distance.size))
            df["metric"].append(np.repeat(metric, distance.size))
            df["seed"].append(np.repeat(i, distance.size))

## Save results

In [None]:
# Flatten the accumulated per-model chunks into one long-format table and persist it.
_csv_path = os.path.join(PATH, "distance.csv")
df = pd.DataFrame(
    {column: np.concatenate(chunks) for column, chunks in df.items()}
)
df.to_csv(_csv_path, index=False)

## Plotting

In [None]:
df = pd.read_csv(os.path.join(PATH, "distance.csv"))
# Fix the plotting order: all Cell BLAST regs (in sweep order), then all scVI KL settings.
_model_order = [f"Cell BLAST (reg = {reg})" for reg in regs]
_model_order.extend(f"scVI (KL = {kl})" for kl in KLs)
df["model"] = pd.Categorical(df["model"], categories=_model_order)

In [None]:
# One ROC AUC per (model, metric, seed). The score is the negated distance, so a
# smaller distance ranks a pair as more likely to share the cell ontology class.
# observed=True keeps unobserved categories of the categorical "model" column from
# producing empty groups, on which roc_auc_score would raise.
auc_df = df.groupby(["model", "metric", "seed"], observed=True).apply(
    lambda x: sklearn.metrics.roc_auc_score(x["correctness"], -x["distance"])
).reset_index(name="AUC")

In [None]:
# Split violins of per-seed AUC: one violin per model (x), halves per distance metric (hue).
fig, ax = plt.subplots(figsize=(7.0, 4.0))
ax = sns.violinplot(
    x="model", y="AUC", hue="metric", data=auc_df,
    split=True, inner="quartile", width=0.9, linewidth=0.8, ax=ax
)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# The x axis enumerates models; the distance metric is encoded by the hue/legend.
ax.set(xlabel="Model")
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, horizontalalignment="right")
ax.legend(
    bbox_to_anchor=(1.05, 0.5), loc="center left",
    borderaxespad=0.0, frameon=False
)
# Highlight the default configuration of each method in red.
for xtick in ax.get_xticklabels():
    if xtick.get_text() in ("Cell BLAST (reg = 0.001)", "scVI (KL = dynamic)"):
        xtick.set_color("red")
fig.savefig(os.path.join(PATH, "auc.pdf"), bbox_inches="tight")