In [None]:
import os
import numpy as np
import pandas as pd
import sklearn.metrics
import matplotlib.pyplot as plt
import matplotlib.ticker as tkr
import seaborn as sns
import Cell_BLAST as cb
import exputils

In [None]:
# Restrict this process to a single GPU chosen by the project helper
# (presumably the device with the least memory in use -- confirm in exputils).
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()

# Global Cell BLAST settings: fixed random seed and number of parallel jobs.
cb.config.RANDOM_SEED = 0
cb.config.N_JOBS = 8

# Keep text as text (not outlines) in SVG output and use Arial for figures.
plt.rcParams['svg.fonttype'] = "none"
plt.rcParams['font.family'] = "Arial"

# Notebook-local random generator, used when subsampling the query datasets.
random_state = np.random.RandomState(0)

# Output directory for trained models, the metric table and the figure.
# exist_ok=True replaces the separate existence check, so there is no window
# between "check" and "create" in which the call could fail.
PATH = "./n_models"
os.makedirs(PATH, exist_ok=True)

In [None]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else np.nan


def pred_report(pred, true, positive_labels):
    """
    Summarise one vector of predicted labels against the ground truth.

    Parameters
    ----------
    pred
        Predicted label per cell (array-like). The literal string
        ``"rejected"`` marks cells that were not assigned any label.
    true
        Ground-truth label per cell (array-like), in the same order and of
        the same length as ``pred``.
    positive_labels
        Labels counted as "positive", i.e. cells that are expected to
        receive a prediction rather than be rejected.

    Returns
    -------
    acc
        Among positive cells that were not rejected, the fraction whose
        predicted label equals the true label.
    specificity
        Fraction of negative cells that were rejected.
    sensitivity
        Fraction of positive cells that were not rejected.

    Any of the three values is NaN when its denominator is empty
    (e.g. no positive cells at all), without emitting a NumPy warning.
    """
    # Coerce up front so that lists and pandas Series behave like arrays
    # under boolean-mask indexing below.
    pred = np.asarray(pred)
    true = np.asarray(true)

    # np.isin supersedes np.in1d, which is deprecated as of NumPy 2.0.
    true_positive_mask = np.isin(true, positive_labels)
    pred_positive_mask = ~np.isin(pred, ["rejected"])

    # Positive cells that actually received a label.
    positive_mask = np.logical_and(true_positive_mask, pred_positive_mask)
    # Negative cells that were correctly rejected.
    negative_mask = np.logical_and(~true_positive_mask, ~pred_positive_mask)

    sensitivity = _safe_ratio(positive_mask.sum(), true_positive_mask.sum())
    specificity = _safe_ratio(negative_mask.sum(), (~true_positive_mask).sum())
    acc = _safe_ratio(
        (true[positive_mask] == pred[positive_mask]).sum(),
        positive_mask.sum()
    )
    return acc, specificity, sensitivity

# Read data

## Reference

In [None]:
# Datasets that together form the reference, keyed by the short name passed
# to the merge.
reference_files = {
    "baron": "../../Datasets/data/Baron_human/data.h5",
    "xin": "../../Datasets/data/Xin_2016/data.h5",
    "lawlor": "../../Datasets/data/Lawlor/data.h5",
}
reference_parts = {
    name: cb.data.ExprDataSet.read_dataset(h5_path)
    for name, h5_path in reference_files.items()
}

# "seurat_genes" is requested as a merged uns slot because model training
# later reads ref.uns["seurat_genes"].
ref_merged = cb.data.ExprDataSet.merge_datasets(
    reference_parts, merge_uns_slots=["seurat_genes"]
)

# NOTE(review): clean_dataset is a project helper; presumably it removes cells
# without a usable "cell_ontology_class" annotation -- confirm in exputils.
ref = exputils.clean_dataset(ref_merged, "cell_ontology_class")

## Query

In [None]:
# Query datasets, read in this fixed order. "Positive" / "Negative" follow the
# original grouping (presumably: datasets whose cell types are / are not
# covered by the reference -- confirm against the dataset descriptions).
query_files = {
    # Positive
    "segerstolpe": "../../Datasets/data/Segerstolpe/data.h5",
    "enge": "../../Datasets/data/Enge/data.h5",
    "muraro": "../../Datasets/data/Muraro/data.h5",
    # Negative
    "wu_human": "../../Datasets/data/Wu_human/data.h5",
    "zheng": "../../Datasets/data/Zheng/data.h5",
    "philippeos": "../../Datasets/data/Philippeos/data.h5",
}
query_parts = {
    name: cb.data.ExprDataSet.read_dataset(h5_path)
    for name, h5_path in query_files.items()
}

# Per-dataset cap on the number of cells: the size of the smallest dataset,
# but never less than 2000.
smallest_size = min(part.shape[0] for part in query_parts.values())
min_size = max(smallest_size, 2000)

# Subsample each dataset without replacement down to the cap (datasets smaller
# than the cap are kept in full, in shuffled order). Datasets are visited in
# insertion order so the draws from random_state are reproducible.
query_subsampled = {}
for name, part in query_parts.items():
    n_cells = part.shape[0]
    keep_idx = random_state.choice(n_cells, min(min_size, n_cells), replace=False)
    query_subsampled[name] = part[keep_idx, :]

# NOTE(review): clean_dataset is a project helper; presumably it removes cells
# without a usable "cell_ontology_class" annotation -- confirm in exputils.
query = exputils.clean_dataset(
    cb.data.ExprDataSet.merge_datasets(query_subsampled),
    "cell_ontology_class"
)

# Train models

In [None]:
# Collect 128 DIRECTi models, one per random seed. Each model is loaded from
# PATH/model_<i> when that succeeds; otherwise it is trained on the reference
# and saved, so a re-run of this cell reuses earlier results.
models = []
for i in range(128):
    print("==== Model: %d ====" % i)
    model_path = os.path.join(PATH, "model_%d" % i)
    try:
        model = cb.directi.DIRECTi.load(
            model_path,
            _mode=cb.directi.DIRECTi._TEST
        )
    except Exception as err:
        # Loading is best-effort (the model may simply not exist yet), but say
        # why it failed so an unexpected error is not hidden behind a silent,
        # expensive retraining run.
        print("Could not load %s (%s: %s), training instead" % (
            model_path, type(err).__name__, err
        ))
        model = cb.directi.fit_DIRECTi(
            ref, ref.uns["seurat_genes"], batch_effect="dataset_name",
            latent_dim=10, cat_dim=20, epoch=300, patience=20,
            random_seed=i, path=model_path
        )
        model.save()
    models.append(model)

# Test BLAST with different numbers of models

In [None]:
# Seed the global NumPy generator so the model subsets drawn below are
# reproducible.
np.random.seed(0)

N_MODEL_GRID = (1, 2, 4, 8, 16)  # ensemble sizes to compare
N_TRIALS = 8                     # repetitions per ensemble size

# Within one ensemble size, trials draw disjoint subsets of models, so the
# largest setting needs max(N_MODEL_GRID) * N_TRIALS distinct models.
assert max(N_MODEL_GRID) * N_TRIALS <= len(models), \
    "Not enough trained models for disjoint trials"

# hits_dict[n_model] holds one query result per trial.
hits_dict = {}
for n_model in N_MODEL_GRID:
    print("==== Number of models: %d ====" % n_model)
    hits_dict[n_model] = []
    # Pool size follows the models actually available instead of repeating
    # the literal count used when training.
    available_models = np.arange(len(models))
    for trial in range(N_TRIALS):
        used_models = np.random.choice(available_models, n_model, replace=False)
        # Remove the models just used so later trials cannot pick them again.
        available_models = np.setdiff1d(available_models, used_models)
        blast = cb.blast.BLAST([models[idx] for idx in used_models], ref)
        hits_dict[n_model].append(blast.query(query))

In [None]:
# For every ensemble size, turn each trial's hits into one predicted
# "cell_ontology_class" per query cell. The chain reconciles the models,
# filters on "pval" with cutoff 0.05 and annotates the cell type
# (exact semantics are defined by Cell BLAST's hits object -- see its docs).
pred_dict = {
    n_model: [
        trial_hits.reconcile_models().filter(
            "pval", 0.05
        ).annotate("cell_ontology_class")["cell_ontology_class"]
        for trial_hits in trial_hits_list
    ]
    for n_model, trial_hits_list in hits_dict.items()
}

In [None]:
# Labels present in the reference count as "positive" for pred_report.
reference_labels = np.unique(ref.obs["cell_ontology_class"])
true_labels = query.obs["cell_ontology_class"]

# One row per (ensemble size, trial): (n_model, accuracy, specificity, sensitivity).
report_rows = []
for n_model, trial_preds in pred_dict.items():
    for trial_pred in trial_preds:
        metrics = pred_report(trial_pred.values, true_labels, reference_labels)
        report_rows.append((n_model, *metrics))

# Transpose rows into columns, then reshape to long format
# (one row per ensemble size / trial / metric) for plotting.
n_model_col, acc_col, spec_col, sens_col = zip(*report_rows)
report_df = pd.DataFrame({
    "Number of models": n_model_col,
    "Accuracy": acc_col,
    "Specificity": spec_col,
    "Sensitivity": sens_col
}).melt(id_vars="Number of models", var_name="Metric", value_name="Value")

# Persist the table so the figure can be regenerated without re-running BLAST.
report_df.to_csv(os.path.join(PATH, "n_models.csv"))

In [None]:
# Reload the saved metrics so this cell can be run on its own.
report_df = pd.read_csv(os.path.join(PATH, "n_models.csv"))

# One line per metric, plotted against the number of models in the ensemble.
fig, ax = plt.subplots(figsize=(4.0, 4.0))
sns.lineplot(
    data=report_df, x="Number of models", y="Value",
    hue="Metric", style="Metric", markers=True, dashes=False, ax=ax
)

# Drop the top/right frame lines and force integer ticks on the x axis.
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.xaxis.set_major_locator(tkr.MaxNLocator(integer=True))

# Legend sits outside the axes, vertically centred on the right-hand side.
ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), borderaxespad=0.0, frameon=False)
fig.savefig(os.path.join(PATH, "n_models.pdf"), bbox_inches="tight")