# Posterior visualization

## Preparation

In [None]:
import os
import functools
import numpy as np
import pandas as pd
import scipy.sparse
import sklearn.svm
import sklearn.metrics
import matplotlib.pyplot as plt
import torch
import scvi.dataset
import scvi.models
import scvi.inference
import Cell_BLAST as cb
import exputils

In [None]:
# Expose only the device returned by exputils.pick_gpu_lowest_memory() to CUDA,
# and fix the Cell BLAST random seed.
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()
cb.config.RANDOM_SEED = 0

# Figure defaults shared by every plot saved from this notebook.
plt.rcParams.update({
    'svg.fonttype': "none",
    'font.family': "Arial",
})

# Number of posterior samples requested per cell, and the output directory.
N_POSTERIOR = 200
PATH = "./posterior_visualization_adam/"
os.makedirs(PATH, exist_ok=True)

Prepare data

In [None]:
# Load the dataset and store each cell's row sum of the expression matrix
# (taken before gene selection) in obs["__libsize__"].
ds = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Adam/data.h5")
row_totals = ds.exprs.sum(axis=1)
ds.obs["__libsize__"] = np.asarray(row_totals).ravel()

# Keep only the genes listed under uns["seurat_genes"].
selected_genes = ds.uns["seurat_genes"]
ds = ds[:, selected_genes]

# Export to h5ad under PATH, then load that same file as an scVI dataset.
ds.to_anndata().write_h5ad(
    os.path.join(PATH, "ds.h5ad")
)
ds_scvi = scvi.dataset.AnnDataset("ds.h5ad", save_path=PATH)

In [None]:
# The two cell types compared in the posterior analysis; index 0 and index 1
# each get their own variational / MCMC plots further down.
FOCUS_CTs = ["Distal tubule", "Ureteric bud"]

## Train models

### Cell BLAST

In [None]:
# Fit DIRECTi on the selected genes with a 2-dimensional latent space so the
# latent coordinates can be plotted directly.
cb_model = cb.directi.fit_DIRECTi(
    ds,
    ds.uns["seurat_genes"],
    latent_dim=2,
    cat_dim=20,
    random_seed=4,
)

In [None]:
# Store the Cell BLAST latent representation of every cell on the dataset.
ds.latent = exputils.get_cb_latent(cb_model, ds)

In [None]:
# Plot the Cell BLAST latent space coloured by cell type and save it as PDF.
ax = ds.visualize_latent(
    "cell_type1",
    method=None,
    sort=True,
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    os.path.join(PATH, "cb_latent.pdf"),
    dpi=300,
    bbox_inches="tight",
)

### scVI

In [None]:
# Seed NumPy and PyTorch before the scVI model is built and trained.
np.random.seed(0)
torch.manual_seed(0)

# VAE with a 2-dimensional latent space, matching the Cell BLAST model.
scvi_model = scvi.models.VAE(ds_scvi.nb_genes, n_latent=2)

# The "ll" metric is monitored (frequency=5) and is used both as the
# early-stopping criterion and for saving the best state.
scvi_trainer = scvi.inference.UnsupervisedTrainer(
    scvi_model,
    ds_scvi,
    use_cuda=True,
    metrics_to_monitor=["ll"],
    frequency=5,
    early_stopping_kwargs={
        "early_stopping_metric": "ll",
        "save_best_state_metric": "ll",
        "patience": 30,
        "threshold": 0,
    },
)
scvi_trainer.train(n_epochs=1000)

In [None]:
# Replace ds.latent with the scVI latent representation; this overwrites the
# Cell BLAST latent assigned earlier.
ds.latent = exputils.get_scvi_latent(scvi_model, ds_scvi)

In [None]:
# Plot the scVI latent space coloured by cell type and save it as PDF.
ax = ds.visualize_latent(
    "cell_type1",
    method=None,
    sort=True,
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    os.path.join(PATH, "scvi_latent.pdf"),
    dpi=300,
    bbox_inches="tight",
)

## Comparison

In [None]:
# Container for posterior samples: the cell metadata is stacked N_POSTERIOR
# times row-wise, one full copy of ds.obs per posterior sample.
posterior_obs = pd.concat(
    [ds.obs for _ in range(N_POSTERIOR)], axis=0
)
# Sham expression matrix: an empty sparse matrix of the matching shape, since
# only the latent coordinates of this dataset are used.
posterior_exprs = scipy.sparse.csr_matrix(
    (posterior_obs.shape[0], ds.shape[1])
)
posterior_ds = cb.data.ExprDataSet(
    posterior_exprs, posterior_obs, ds.var, {}
)

### Cell BLAST

In [None]:
# Posterior samples from Cell BLAST; tmp is indexed as tmp[:, sample, :].
tmp = exputils.get_cb_latent(cb_model, ds, n_posterior=N_POSTERIOR)
# Flatten sample by sample along axis 0, which is the same row order used to
# build posterior_ds (one copy of ds.obs per sample).
posterior_ds.latent = np.concatenate(
    [tmp[:, k, :] for k in range(tmp.shape[1])],
    axis=0,
)
# Background layer: every posterior sample, drawn almost transparent.
ax = posterior_ds.visualize_latent(
    method=None,
    size=3,
    width=4.5,
    height=4.5,
    scatter_kws={"alpha": 1 / N_POSTERIOR, "rasterized": True},
)
# Foreground layer: the regular latent of each cell, coloured by cell type and
# drawn on the same axes.
ds.latent = exputils.get_cb_latent(cb_model, ds)
ax = ds.visualize_latent(
    "cell_type1",
    method=None,
    sort=True,
    scatter_kws={"rasterized": True},
    ax=ax,
)
ax.get_figure().savefig(
    os.path.join(PATH, "cb_posterior.png"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Step 1: keep cells of the focus cell types whose latent coordinates fall in
# a fixed rectangular window (0 < latent_1 < 4, -3.8 < latent_2 < 0).
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; for this 1-D
# input the result is the same boolean array.
mask = functools.reduce(np.logical_and, [
    np.isin(ds.obs["cell_type1"], FOCUS_CTs),
    ds.obs["latent_1"] > 0.0,
    ds.obs["latent_1"] < 4.0,
    ds.obs["latent_2"] > -3.8,
    ds.obs["latent_2"] < 0.0
])
sub_ds = ds[mask, :]
# posterior_ds stacks one copy of the cells per posterior sample, so the cell
# mask is repeated N_POSTERIOR times to select the matching rows.
posterior_mask = np.concatenate([mask] * N_POSTERIOR, axis=0)
sub_posterior_ds = posterior_ds[posterior_mask, :]

# Step 2: within that subset, keep cells whose silhouette score with respect
# to the cell type labels exceeds 0.1, and filter the posterior samples alike.
mask = sklearn.metrics.silhouette_samples(sub_ds.latent, sub_ds.obs["cell_type1"]) > 0.1
sub_ds = sub_ds[mask, :]
posterior_mask = np.concatenate([mask] * N_POSTERIOR, axis=0)
sub_posterior_ds = sub_posterior_ds[posterior_mask, :]

In [None]:
# Fit an SVC that separates the two focus cell types in the latent space.
svc = sklearn.svm.SVC(random_state=0, gamma=0.01).fit(sub_ds.latent, sub_ds.obs["cell_type1"])
# Flag cells that are support vectors AND have |decision function| > 0.5.
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; the result is
# the same for this 1-D index array.
sub_ds.obs["support"] = np.logical_and(
    np.isin(np.arange(sub_ds.shape[0]), svc.support_),
    np.abs(svc.decision_function(sub_ds.latent)) > 0.5
)
# Class predicted by the SVC; later cells split on this, not on the original label.
sub_ds.obs["class"] = pd.Categorical(svc.predict(sub_ds.latent))
ax = sub_ds.visualize_latent(
    hue="support", style="class",
    method=None, sort=True, size=30, width=4.5, height=4.5,
    scatter_kws=dict(markers=["s", "^"], rasterized=True)
)
ax.get_figure().savefig(os.path.join(PATH, "cb_support.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Selected cells: flagged as "support" and predicted as the first focus cell type.
mask = np.logical_and(
    sub_ds.obs["support"],
    sub_ds.obs["class"] == FOCUS_CTs[0],
)
centers = sub_ds.latent[mask]
# The posterior rows are stacked sample by sample, so the mask and the centers
# are both repeated N_POSTERIOR times to line up with them.
posterior_mask = np.tile(mask, N_POSTERIOR)
posterior = sub_posterior_ds.latent[posterior_mask]
# Offset of every variational posterior sample from its own cell's latent.
deviation = posterior - np.tile(centers, (N_POSTERIOR, 1))
ax = exputils.aligned_posterior_plot(deviation, lim=0.2)
ax.set_title(f"{FOCUS_CTs[0]} (variational)")
ax.get_figure().savefig(
    os.path.join(PATH, f"cb_{FOCUS_CTs[0]}_variational.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Slice the selected cells once. The original re-evaluated sub_ds[mask, :]
# inside the lambda, i.e. on every log-posterior evaluation; hoisting it also
# mirrors how the scVI MCMC cells prepare sub_ds_use.
sub_ds_use = sub_ds[mask, :]
# Metropolis-Hastings over the unnormalized Cell BLAST log posterior, started
# from the selected cells' latent coordinates.
tmp = exputils.metropolis_hastings(
    centers,
    lambda latent: exputils.get_cb_log_unnormalized_posterior(cb_model, sub_ds_use, latent),
    target=N_POSTERIOR
)
# Flatten sample by sample along axis 0 so the rows line up with the repeated centers.
posterior = np.concatenate([tmp[:, i, :] for i in range(tmp.shape[1])], axis=0)
deviation = posterior - np.concatenate([centers] * N_POSTERIOR, axis=0)
ax = exputils.aligned_posterior_plot(deviation, lim=0.4)
ax.set_title(f"{FOCUS_CTs[0]} (MCMC)")
ax.get_figure().savefig(os.path.join(PATH, f"cb_{FOCUS_CTs[0]}_mcmc.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Selected cells: flagged as "support" and predicted as the second focus cell type.
mask = np.logical_and(
    sub_ds.obs["support"],
    sub_ds.obs["class"] == FOCUS_CTs[1],
)
centers = sub_ds.latent[mask]
# The posterior rows are stacked sample by sample, so the mask and the centers
# are both repeated N_POSTERIOR times to line up with them.
posterior_mask = np.tile(mask, N_POSTERIOR)
posterior = sub_posterior_ds.latent[posterior_mask]
# Offset of every variational posterior sample from its own cell's latent.
deviation = posterior - np.tile(centers, (N_POSTERIOR, 1))
ax = exputils.aligned_posterior_plot(deviation, lim=0.2)
ax.set_title(f"{FOCUS_CTs[1]} (variational)")
ax.get_figure().savefig(
    os.path.join(PATH, f"cb_{FOCUS_CTs[1]}_variational.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Slice the selected cells once. The original re-evaluated sub_ds[mask, :]
# inside the lambda, i.e. on every log-posterior evaluation; hoisting it also
# mirrors how the scVI MCMC cells prepare sub_ds_use.
sub_ds_use = sub_ds[mask, :]
# Metropolis-Hastings over the unnormalized Cell BLAST log posterior, started
# from the selected cells' latent coordinates.
tmp = exputils.metropolis_hastings(
    centers,
    lambda latent: exputils.get_cb_log_unnormalized_posterior(cb_model, sub_ds_use, latent),
    target=N_POSTERIOR
)
# Flatten sample by sample along axis 0 so the rows line up with the repeated centers.
posterior = np.concatenate([tmp[:, i, :] for i in range(tmp.shape[1])], axis=0)
deviation = posterior - np.concatenate([centers] * N_POSTERIOR, axis=0)
ax = exputils.aligned_posterior_plot(deviation, lim=0.4)
ax.set_title(f"{FOCUS_CTs[1]} (MCMC)")
ax.get_figure().savefig(os.path.join(PATH, f"cb_{FOCUS_CTs[1]}_mcmc.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Compare Euclidean distance with cb.blast.npd_v1 on randomly drawn cell pairs.
latent = sub_ds.latent
# Undo the sample-by-sample flattening: split into N_POSTERIOR equal chunks and
# stack them on axis 1, so posterior[c] holds all samples of cell c.
posterior = np.stack(np.split(sub_posterior_ds.latent, N_POSTERIOR), axis=1)
# Loop invariants hoisted out of the 10000-iteration loop: the predicted-class
# labels as a plain array, and the number of cells.
labels = np.asarray(sub_ds.obs["class"])
n_cells = sub_ds.shape[0]
eud, npd, correctness = [], [], []
random_state = np.random.RandomState(2020)  # fixed seed: same pairs on every run
for _ in range(10000):
    # NOTE(review): choice() draws with replacement by default, so i == j
    # self-pairs can occur — confirm this is intended.
    i, j = random_state.choice(n_cells, size=2)
    # A pair counts as "correct" when both cells share the SVC-predicted class.
    correctness.append(labels[i] == labels[j])
    eud.append(np.sqrt(np.square(latent[i] - latent[j]).sum()))
    npd.append(cb.blast.npd_v1(
        latent[i], latent[j],
        posterior[i], posterior[j]
    ))
eud = np.array(eud)
npd = np.array(npd)
correctness = np.array(correctness)

In [None]:
# Plot the two distance measures against each other for the sampled pairs
# (with the same-class flag passed along) and save the figure.
ax = exputils.distance_pair_plot(eud, npd, correctness)
ax.get_figure().savefig(
    os.path.join(PATH, "cb_distance_cmp.pdf"),
    dpi=300,
    bbox_inches="tight",
)

### scVI

In [None]:
# Posterior samples from scVI; tmp is indexed as tmp[:, sample, :].
tmp = exputils.get_scvi_latent(scvi_model, ds_scvi, n_posterior=N_POSTERIOR)
# Flatten sample by sample along axis 0, which is the same row order used to
# build posterior_ds (one copy of ds.obs per sample).
posterior_ds.latent = np.concatenate(
    [tmp[:, k, :] for k in range(tmp.shape[1])],
    axis=0,
)
# Background layer: every posterior sample, drawn almost transparent.
ax = posterior_ds.visualize_latent(
    method=None,
    size=3,
    width=4.5,
    height=4.5,
    scatter_kws={"alpha": 1 / N_POSTERIOR, "rasterized": True},
)
# Foreground layer: the regular scVI latent of each cell. The library value
# returned alongside it is kept in obs["library"] for the MCMC cells.
ds.latent, ds.obs["library"] = exputils.get_scvi_latent(
    scvi_model, ds_scvi, return_library=True
)
ax = ds.visualize_latent(
    "cell_type1",
    method=None,
    sort=True,
    scatter_kws={"rasterized": True},
    ax=ax,
)
ax.get_figure().savefig(
    os.path.join(PATH, "scvi_posterior.png"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Step 1: keep cells of the focus cell types whose scVI latent coordinates fall
# in a fixed rectangular window (0 < latent_1 < 2.2, -1.2 < latent_2 < 0.9).
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; for this 1-D
# input the result is the same boolean array.
mask = functools.reduce(np.logical_and, [
    np.isin(ds.obs["cell_type1"], FOCUS_CTs),
    ds.obs["latent_1"] > 0.0,
    ds.obs["latent_1"] < 2.2,
    ds.obs["latent_2"] > -1.2,
    ds.obs["latent_2"] < 0.9
])
sub_ds = ds[mask, :]
# posterior_ds stacks one copy of the cells per posterior sample, so the cell
# mask is repeated N_POSTERIOR times to select the matching rows.
posterior_mask = np.concatenate([mask] * N_POSTERIOR, axis=0)
sub_posterior_ds = posterior_ds[posterior_mask, :]

# Step 2: within that subset, keep cells whose silhouette score with respect
# to the cell type labels exceeds 0.1, and filter the posterior samples alike.
mask = sklearn.metrics.silhouette_samples(sub_ds.latent, sub_ds.obs["cell_type1"]) > 0.1
sub_ds = sub_ds[mask, :]
posterior_mask = np.concatenate([mask] * N_POSTERIOR, axis=0)
sub_posterior_ds = sub_posterior_ds[posterior_mask, :]

In [None]:
# Fit an SVC that separates the two focus cell types in the scVI latent space.
svc = sklearn.svm.SVC(random_state=0, gamma=0.05).fit(sub_ds.latent, sub_ds.obs["cell_type1"])
# Flag cells that are support vectors AND have |decision function| > 0.5.
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; the result is
# the same for this 1-D index array.
sub_ds.obs["support"] = np.logical_and(
    np.isin(np.arange(sub_ds.shape[0]), svc.support_),
    np.abs(svc.decision_function(sub_ds.latent)) > 0.5
)
# Class predicted by the SVC; later cells split on this, not on the original label.
sub_ds.obs["class"] = pd.Categorical(svc.predict(sub_ds.latent))
ax = sub_ds.visualize_latent(
    hue="support", style="class",
    method=None, sort=True, size=30, width=4.5, height=4.5,
    scatter_kws=dict(markers=["s", "^"], rasterized=True)
)
ax.get_figure().savefig(os.path.join(PATH, "scvi_support.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Selected cells: flagged as "support" and predicted as the first focus cell type.
mask = np.logical_and(
    sub_ds.obs["support"],
    sub_ds.obs["class"] == FOCUS_CTs[0],
)
centers = sub_ds.latent[mask]
# The posterior rows are stacked sample by sample, so the mask and the centers
# are both repeated N_POSTERIOR times to line up with them.
posterior_mask = np.tile(mask, N_POSTERIOR)
posterior = sub_posterior_ds.latent[posterior_mask]
# Offset of every variational posterior sample from its own cell's latent.
deviation = posterior - np.tile(centers, (N_POSTERIOR, 1))
ax = exputils.aligned_posterior_plot(deviation, lim=0.2)
ax.set_title(f"{FOCUS_CTs[0]} (variational)")
ax.get_figure().savefig(
    os.path.join(PATH, f"scvi_{FOCUS_CTs[0]}_variational.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Export the selected cells to h5ad and load them as an scVI dataset.
# The subset gets its own file name: the original wrote to "ds.h5ad" and so
# overwrote the full-dataset export that ds_scvi was loaded from, which left a
# truncated file on disk for any later re-run of the data preparation cell.
sub_ds_use = sub_ds[mask, :]
sub_ds_use.to_anndata().write_h5ad(os.path.join(PATH, "sub_ds_use.h5ad"))
sub_ds_use_scvi = scvi.dataset.AnnDataset("sub_ds_use.h5ad", save_path=PATH)
# Metropolis-Hastings over two state arrays (latent and library), each with its
# own proposal standard deviation; only the latent chain is kept.
tmp, _ = exputils.metropolis_hastings(
    [centers, sub_ds_use.obs[["library"]].to_numpy()],
    lambda latent, library: exputils.get_scvi_log_unnormalized_posterior(scvi_model, sub_ds_use_scvi, latent, library),
    target=N_POSTERIOR, proposal_std=[0.02, 0.1]
)
# Flatten sample by sample along axis 0 so the rows line up with the repeated centers.
posterior = np.concatenate([tmp[:, i, :] for i in range(tmp.shape[1])], axis=0)
deviation = posterior - np.concatenate([centers] * N_POSTERIOR, axis=0)
ax = exputils.aligned_posterior_plot(deviation, lim=0.4)
ax.set_title(f"{FOCUS_CTs[0]} (MCMC)")
ax.get_figure().savefig(os.path.join(PATH, f"scvi_{FOCUS_CTs[0]}_mcmc.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Selected cells: flagged as "support" and predicted as the second focus cell type.
mask = np.logical_and(
    sub_ds.obs["support"],
    sub_ds.obs["class"] == FOCUS_CTs[1],
)
centers = sub_ds.latent[mask]
# The posterior rows are stacked sample by sample, so the mask and the centers
# are both repeated N_POSTERIOR times to line up with them.
posterior_mask = np.tile(mask, N_POSTERIOR)
posterior = sub_posterior_ds.latent[posterior_mask]
# Offset of every variational posterior sample from its own cell's latent.
deviation = posterior - np.tile(centers, (N_POSTERIOR, 1))
ax = exputils.aligned_posterior_plot(deviation, lim=0.2)
ax.set_title(f"{FOCUS_CTs[1]} (variational)")
ax.get_figure().savefig(
    os.path.join(PATH, f"scvi_{FOCUS_CTs[1]}_variational.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Export the selected cells to h5ad and load them as an scVI dataset.
# The subset gets its own file name: the original wrote to "ds.h5ad" and so
# overwrote the full-dataset export that ds_scvi was loaded from, which left a
# truncated file on disk for any later re-run of the data preparation cell.
sub_ds_use = sub_ds[mask, :]
sub_ds_use.to_anndata().write_h5ad(os.path.join(PATH, "sub_ds_use.h5ad"))
sub_ds_use_scvi = scvi.dataset.AnnDataset("sub_ds_use.h5ad", save_path=PATH)
# Metropolis-Hastings over two state arrays (latent and library), each with its
# own proposal standard deviation; only the latent chain is kept.
tmp, _ = exputils.metropolis_hastings(
    [centers, sub_ds_use.obs[["library"]].to_numpy()],
    lambda latent, library: exputils.get_scvi_log_unnormalized_posterior(scvi_model, sub_ds_use_scvi, latent, library),
    target=N_POSTERIOR, proposal_std=[0.02, 0.1]
)
# Flatten sample by sample along axis 0 so the rows line up with the repeated centers.
posterior = np.concatenate([tmp[:, i, :] for i in range(tmp.shape[1])], axis=0)
deviation = posterior - np.concatenate([centers] * N_POSTERIOR, axis=0)
ax = exputils.aligned_posterior_plot(deviation, lim=0.4)
ax.set_title(f"{FOCUS_CTs[1]} (MCMC)")
ax.get_figure().savefig(os.path.join(PATH, f"scvi_{FOCUS_CTs[1]}_mcmc.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Compare Euclidean distance with cb.blast.npd_v1 on randomly drawn cell pairs,
# using the scVI latent and posterior samples (both cast to float32).
latent = sub_ds.latent.astype(np.float32)
# Undo the sample-by-sample flattening: split into N_POSTERIOR equal chunks and
# stack them on axis 1, so posterior[c] holds all samples of cell c.
posterior = np.stack(np.split(sub_posterior_ds.latent, N_POSTERIOR), axis=1).astype(np.float32)
# Loop invariants hoisted out of the 10000-iteration loop: the predicted-class
# labels as a plain array, and the number of cells.
labels = np.asarray(sub_ds.obs["class"])
n_cells = sub_ds.shape[0]
eud, npd, correctness = [], [], []
random_state = np.random.RandomState(2020)  # fixed seed: same pairs on every run
for _ in range(10000):
    # NOTE(review): choice() draws with replacement by default, so i == j
    # self-pairs can occur — confirm this is intended.
    i, j = random_state.choice(n_cells, size=2)
    # A pair counts as "correct" when both cells share the SVC-predicted class.
    correctness.append(labels[i] == labels[j])
    eud.append(np.sqrt(np.square(latent[i] - latent[j]).sum()))
    npd.append(cb.blast.npd_v1(
        latent[i], latent[j],
        posterior[i], posterior[j]
    ))
eud = np.array(eud)
npd = np.array(npd)
correctness = np.array(correctness)

In [None]:
# Plot the two distance measures against each other for the sampled pairs
# (with the same-class flag passed along) and save the figure.
ax = exputils.distance_pair_plot(eud, npd, correctness)
ax.get_figure().savefig(
    os.path.join(PATH, "scvi_distance_cmp.pdf"),
    dpi=300,
    bbox_inches="tight",
)