In [None]:
import sys
import os
import matplotlib.pyplot as plt
import Cell_BLAST as cb

sys.path.insert(0, "../../Evaluation")
import utils

In [None]:
# Figure styling: keep SVG text as real text (not paths) and use Arial.
plt.rcParams.update({
    "svg.fonttype": "none",
    "font.family": "Arial",
})

# os.environ values must be str; str() is a no-op if the helper already
# returns a string and avoids a TypeError if it returns an integer index.
# NOTE(review): Cell_BLAST (and its backend) is imported in the cell above;
# this assumes GPU devices are not initialised until the first fit — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = str(utils.pick_gpu_lowest_memory())

# Cell_BLAST global configuration: fixed seed for reproducible fits and the
# number of parallel jobs it is allowed to use.
cb.config.RANDOM_SEED = 0
cb.config.N_JOBS = 4

# Output directory for every figure exported by this notebook.
PATH = "zifp"
os.makedirs(PATH, exist_ok=True)

In [None]:
def load_clean_dataset(path, label="cell_ontology_class"):
    """Read an ExprDataSet from ``path`` and clean it with respect to ``label``.

    Both source datasets go through this single helper so they are guaranteed
    to receive identical preprocessing.

    Parameters
    ----------
    path : str
        Path to the dataset's ``data.h5`` file.
    label : str
        Annotation column passed to ``utils.clean_dataset``.

    Returns
    -------
    The object returned by ``utils.clean_dataset``. NOTE(review): its exact
    filtering rules live in ../../Evaluation/utils — presumably it drops cells
    with missing/unusable ``label`` values; confirm there.
    """
    raw = cb.data.ExprDataSet.read_dataset(path)
    return utils.clean_dataset(raw, label)


ds1 = load_clean_dataset("../../Datasets/data/Quake_Smart-seq2_Fat/data.h5")
ds2 = load_clean_dataset("../../Datasets/data/Quake_Smart-seq2_Brain_Non-Myeloid/data.h5")

In [None]:
# Combine both tissues into one dataset. The "seurat_genes" uns slot is carried
# over because it supplies the gene list for the DIRECTi fits below.
# NOTE(review): the dict keys presumably become the "dataset_name" labels used
# as the batch variable and plot hue later — confirm in Cell_BLAST.
merge_inputs = {"ds1": ds1, "ds2": ds2}
ds = cb.data.ExprDataSet.merge_datasets(merge_inputs, merge_uns_slots=["seurat_genes"])

## Negative binomial (default)

In [None]:
# DIRECTi with the negative-binomial likelihood, using "dataset_name" as the
# batch-effect variable.
nb_model = cb.directi.fit_DIRECTi(
    ds,
    ds.uns["seurat_genes"],
    batch_effect="dataset_name",
    latent_dim=10,
    cat_dim=20,
    prob_module="NB",
    rmbatch_module_kwargs={"lambda_reg": 0.02},
)

# The plotting cells read ds.latent. The ZINB section overwrites this same
# attribute, so re-run this cell before re-plotting the NB figures.
ds.latent = nb_model.inference(ds)

In [None]:
# NB latent space colored by cell type; scatter markers rasterized at 300 dpi.
ax = ds.visualize_latent(
    "cell_ontology_class",
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
fig = ax.get_figure()
fig.savefig(os.path.join(PATH, "nb_ct.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# NB latent space colored by source dataset, to inspect batch mixing.
ax = ds.visualize_latent(
    "dataset_name",
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
fig = ax.get_figure()
fig.savefig(os.path.join(PATH, "nb_ds.pdf"), dpi=300, bbox_inches="tight")

## Zero-inflated negative binomial

In [None]:
# Same configuration as the NB fit; only the likelihood changes, to the
# zero-inflated negative binomial.
zinb_model = cb.directi.fit_DIRECTi(
    ds,
    ds.uns["seurat_genes"],
    batch_effect="dataset_name",
    latent_dim=10,
    cat_dim=20,
    prob_module="ZINB",
    rmbatch_module_kwargs={"lambda_reg": 0.02},
)

# Replaces the NB embedding stored in ds.latent.
ds.latent = zinb_model.inference(ds)

In [None]:
# ZINB latent space colored by cell type; scatter markers rasterized at 300 dpi.
# NOTE(review): if visualize_latent caches its 2-D coordinates, confirm they
# are recomputed after ds.latent was replaced (otherwise NB coords are reused).
ax = ds.visualize_latent(
    "cell_ontology_class",
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
fig = ax.get_figure()
fig.savefig(os.path.join(PATH, "zinb_ct.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# ZINB latent space colored by source dataset, to inspect batch mixing.
ax = ds.visualize_latent(
    "dataset_name",
    width=4.5,
    height=4.5,
    scatter_kws={"rasterized": True},
)
fig = ax.get_figure()
fig.savefig(os.path.join(PATH, "zinb_ds.pdf"), dpi=300, bbox_inches="tight")