In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import Cell_BLAST as cb
import exputils

In [None]:
# Runtime configuration: GPU selection, Cell_BLAST global settings,
# matplotlib output style, and the output directory for saved figures.
# NOTE(review): Cell_BLAST is already imported at this point; presumably the
# GPU backend only reads CUDA_VISIBLE_DEVICES when it first initialises, so
# this must run before any model is built — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()
# Cell_BLAST package-level settings (presumably seed and worker count, per
# their names). This seed does not seed NumPy's global RNG.
cb.config.RANDOM_SEED = 0
cb.config.N_JOBS = 4
# Keep SVG text as text (not outlined paths) so labels stay editable.
plt.rcParams['svg.fonttype'] = "none"
plt.rcParams['font.family'] = "Arial"
PATH = "./n_posterior_sample"
# exist_ok makes re-running this cell a no-op and avoids the
# check-then-create race of a separate os.path.exists test.
os.makedirs(PATH, exist_ok=True)

In [None]:
# Load the Baron_human expression dataset used by every later cell.
# The path is relative, so it resolves against the notebook's working directory.
dataset = cb.data.ExprDataSet.read_dataset(
    "../../Datasets/data/Baron_human/data.h5"
)

In [None]:
# Train a single DIRECTi model on the genes listed in dataset.uns["seurat_genes"].
# Its files go under /tmp/cb/, in a sub-directory named by cb.utils.rand_hex().
rand_hex = cb.utils.rand_hex()
print("Training at {}...".format(rand_hex))
model = cb.directi.fit_DIRECTi(
    dataset,
    dataset.uns["seurat_genes"],
    latent_dim=10,
    cat_dim=20,
    path="/tmp/cb/{}".format(rand_hex),
)

In [None]:
# Draw disjoint row-index sets (rows are presumably cells): N_REF rows for the
# BLAST reference and N_QUERY rows for the query, sampled without replacement.
# A dedicated RandomState seeded from cb.config.RANDOM_SEED makes the split
# reproducible. The module-level np.random state is not seeded by any cell
# above, so sampling from it would give a different split on every run.
N_REF, N_QUERY = 1000, 200
sampling_rng = np.random.RandomState(cb.config.RANDOM_SEED)
all_indices = np.arange(dataset.shape[0])
ref_indices = sampling_rng.choice(all_indices, N_REF, replace=False)
# Restrict the query pool to rows not already in the reference set.
query_indices = sampling_rng.choice(
    np.setdiff1d(all_indices, ref_indices),
    N_QUERY, replace=False
)

In [None]:
# Row-subset the full dataset into the reference and query sets,
# keeping every column.
ref, query = dataset[ref_indices, :], dataset[query_indices, :]

In [None]:
# Sweep the number of posterior samples used by BLAST. For each value:
# - build a BLAST index over `ref` from the single trained model;
# - query with 50 neighbours;
# - store all returned distances as one flat array, keyed by that value.
# The largest value (1000) is used as the reference estimate downstream.
posterior_distance_dict = {}
for n_posterior in (5, 10, 20, 50, 100, 500, 1000):
    print("==== n_posterior: {} ====".format(n_posterior))
    blast = cb.blast.BLAST([model], ref, n_posterior=n_posterior)
    hits = blast.query(query, n_neighbors=50)
    # hits.dist is presumably one array per query row; join them end to end.
    flat_distances = np.concatenate(hits.dist, axis=0)
    posterior_distance_dict[n_posterior] = flat_distances
    print()

In [None]:
# For each posterior sample count, measure the spread of its distance
# estimates around the estimates at REFERENCE_N_POSTERIOR: the standard
# deviation of the element-wise differences.
REFERENCE_N_POSTERIOR = 1000
reference_npd = posterior_distance_dict[REFERENCE_N_POSTERIOR]
# Derive the compared counts from the sweep results rather than repeating
# them by hand, so this cell cannot drift out of sync with the sweep.
n_posterior_list = sorted(
    n for n in posterior_distance_dict if n != REFERENCE_N_POSTERIOR
)
std_list = []
for n_posterior in n_posterior_list:
    npd = posterior_distance_dict[n_posterior]
    # Element-wise differences only make sense when both runs return the same
    # number of distances. NOTE(review): this also assumes hits come back in
    # the same order for every n_posterior — confirm.
    assert npd.shape == reference_npd.shape, (
        "n_posterior=%d returned %s distances, reference returned %s"
        % (n_posterior, npd.shape, reference_npd.shape)
    )
    std_list.append(np.std((npd - reference_npd).ravel()))
std_df = pd.DataFrame({
    "Number of posterior samples": n_posterior_list,
    "Standard deviation of NPD": std_list
})

In [None]:
# Plot the NPD deviation against the number of posterior samples and save it
# as a PDF in PATH.
fig, ax = plt.subplots(figsize=(4.0, 4.0))
# marker="o" is forwarded to matplotlib and draws a point at each sample
# count. This replaces the style=1/markers=True trick, which only produced
# markers by broadcasting a constant "style" column.
ax = sns.lineplot(
    x="Number of posterior samples", y="Standard deviation of NPD",
    data=std_df, marker="o", ax=ax, legend=False
)
# Drop the top and right frame lines.
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
fig.savefig(os.path.join(PATH, "pd_std.pdf"), bbox_inches="tight")