In [None]:
import os
import Cell_BLAST as cb
import utils
# Point CUDA at the device id returned by utils.pick_gpu_lowest_memory().
os.environ["CUDA_VISIBLE_DEVICES"] = utils.pick_gpu_lowest_memory()

# Cell_BLAST global settings. N_JOBS is also reused further down to decide
# how many models go into each BLAST ensemble.
cb.config.RANDOM_SEED = 0
cb.config.N_JOBS = 4

# Keyword arguments forwarded unchanged to every fit_DIRECTi call below.
fixed_model_kwargs = {
    "latent_dim": 10,
    "cat_dim": 100,
    "epoch": 500,
    "patience": 20,
}

In [None]:
# Show the Cell_BLAST version string in the cell output.
cb.__version__

---

# Smart-seq2

In [None]:
# Read the Quake Smart-seq2 dataset from its HDF5 file.
quake_smart_seq2 = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Quake_Smart-seq2/data.h5")
# utils.peek receives the dataset and this section's output directory.
# NOTE(review): later cells write into that directory without creating it,
# so peek presumably creates it -- confirm in utils.
utils.peek(quake_smart_seq2, "build/quake/Quake_Smart-seq2")
# Preview the first rows of the obs table.
quake_smart_seq2.obs.head()

In [None]:
# Turn the numeric cluster ids into string labels of the form "cluster_<n>".
# The conversion is skipped when the column already carries the prefix, so
# re-running this cell is a no-op instead of failing on int("cluster_<n>").
if not quake_smart_seq2.obs["cluster"].astype(str).str.startswith("cluster_").any():
    quake_smart_seq2.obs["cluster"] = (
        "cluster_" + quake_smart_seq2.obs["cluster"].astype(int).astype(str)
    )
# Display the column dtypes of the obs table after the conversion.
quake_smart_seq2.obs.dtypes

In [None]:
# Dataset-specific hyperparameters; combined with fixed_model_kwargs in the
# call below and reused by the ensemble cell of this section.
opt_model_kwargs = {
    "h_dim": 512,
    "prob_module_kwargs": {"lambda_reg": 0.001},
}

# Fit one DIRECTi model. The second positional argument is the dataset's
# uns["seurat_genes"] entry (presumably the gene set to train on).
quake_smart_seq2_model = cb.directi.fit_DIRECTi(
    quake_smart_seq2,
    quake_smart_seq2.uns["seurat_genes"],
    **fixed_model_kwargs,
    **opt_model_kwargs
)

# Store the model's inference output on the dataset as its latent.
quake_smart_seq2.latent = quake_smart_seq2_model.inference(quake_smart_seq2)

In [None]:
# Latent-space plot for the "cell_ontology_class" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_smart_seq2.visualize_latent(
    "cell_ontology_class",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_Smart-seq2/cell_ontology_class.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "cell_type1" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_smart_seq2.visualize_latent(
    "cell_type1",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_Smart-seq2/cell_type1.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "cluster" label. Scatter points are rasterized
# (scatter_kws) while the figure itself is saved as SVG.
ax = quake_smart_seq2.visualize_latent(
    "cluster",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_Smart-seq2/cluster.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "free_annotation" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_smart_seq2.visualize_latent(
    "free_annotation",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_Smart-seq2/free_annotation.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Write the dataset in its current state to this section's build directory.
quake_smart_seq2.write_dataset("build/quake/Quake_Smart-seq2/Quake_Smart-seq2.h5")

In [None]:
%%capture capio
# Ensemble for BLAST: start from the model fitted earlier in this section and
# add one more per seed in 1 .. N_JOBS - 1, giving N_JOBS models in total.
# Everything printed here is captured into `capio` by the cell magic.
quake_smart_seq2_models = [quake_smart_seq2_model]
for i in range(1, cb.config.N_JOBS):
    # Separator line so each model's section can be found in the captured log.
    print("==== Model %d ====" % i)
    # Same arguments as the first fit, plus an explicit per-model random_seed.
    quake_smart_seq2_models.append(cb.directi.fit_DIRECTi(
        quake_smart_seq2, quake_smart_seq2.uns["seurat_genes"],
        **fixed_model_kwargs, **opt_model_kwargs,
        random_seed=i
    ))
# Build a BLAST object from the model list and the dataset, then save it into
# this section's build directory.
quake_smart_seq2_blast = cb.blast.BLAST(quake_smart_seq2_models, quake_smart_seq2)
quake_smart_seq2_blast.save("build/quake/Quake_Smart-seq2")

In [None]:
# Dump the stdout/stderr captured in `capio` by the preceding %%capture cell
# to text files in this section's build directory.
with open("build/quake/Quake_Smart-seq2/stdout.txt", "w") as f:
    f.write(capio.stdout)
with open("build/quake/Quake_Smart-seq2/stderr.txt", "w") as f:
    f.write(capio.stderr)

In [None]:
# utils.self_projection receives the BLAST object and the output directory.
# NOTE(review): judging by its name it queries the dataset against its own
# BLAST object and writes the results there -- confirm in utils.
utils.self_projection(quake_smart_seq2_blast, "build/quake/Quake_Smart-seq2")

In [None]:
%%writefile build/quake/Quake_Smart-seq2/predictable.txt
cell_ontology_class
cell_type1
cluster
free_annotation

---

# 10x

In [None]:
# Read the Quake 10x dataset from its HDF5 file.
quake_10x = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Quake_10x/data.h5")
# utils.peek receives the dataset and this section's output directory.
# NOTE(review): later cells write into that directory without creating it,
# so peek presumably creates it -- confirm in utils.
utils.peek(quake_10x, "build/quake/Quake_10x")
# Preview the first rows of the obs table.
quake_10x.obs.head()

In [None]:
# Turn the numeric cluster ids into string labels of the form "cluster_<n>".
# The conversion is skipped when the column already carries the prefix, so
# re-running this cell is a no-op instead of failing on int("cluster_<n>").
if not quake_10x.obs["cluster"].astype(str).str.startswith("cluster_").any():
    quake_10x.obs["cluster"] = (
        "cluster_" + quake_10x.obs["cluster"].astype(int).astype(str)
    )
# Display the column dtypes of the obs table after the conversion.
quake_10x.obs.dtypes

In [None]:
# Dataset-specific hyperparameters; combined with fixed_model_kwargs in the
# call below and reused by the ensemble cell of this section.
opt_model_kwargs = {
    "h_dim": 512,
    "prob_module_kwargs": {"lambda_reg": 0.001},
}

# Fit one DIRECTi model. The second positional argument is the dataset's
# uns["seurat_genes"] entry (presumably the gene set to train on).
quake_10x_model = cb.directi.fit_DIRECTi(
    quake_10x,
    quake_10x.uns["seurat_genes"],
    **fixed_model_kwargs,
    **opt_model_kwargs
)

# Store the model's inference output on the dataset as its latent.
quake_10x.latent = quake_10x_model.inference(quake_10x)

In [None]:
# Latent-space plot for the "cell_ontology_class" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_10x.visualize_latent(
    "cell_ontology_class",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_10x/cell_ontology_class.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "cell_type1" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_10x.visualize_latent(
    "cell_type1",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_10x/cell_type1.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "cluster" label. Scatter points are rasterized
# (scatter_kws) while the figure itself is saved as SVG.
ax = quake_10x.visualize_latent(
    "cluster",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_10x/cluster.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "free_annotation" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = quake_10x.visualize_latent(
    "free_annotation",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/Quake_10x/free_annotation.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Write the dataset in its current state to this section's build directory.
quake_10x.write_dataset("build/quake/Quake_10x/Quake_10x.h5")

In [None]:
%%capture capio
# Ensemble for BLAST: start from the model fitted earlier in this section and
# add one more per seed in 1 .. N_JOBS - 1, giving N_JOBS models in total.
# Everything printed here is captured into `capio` by the cell magic.
quake_10x_models = [quake_10x_model]
for i in range(1, cb.config.N_JOBS):
    # Separator line so each model's section can be found in the captured log.
    print("==== Model %d ====" % i)
    # Same arguments as the first fit, plus an explicit per-model random_seed.
    quake_10x_models.append(cb.directi.fit_DIRECTi(
        quake_10x, quake_10x.uns["seurat_genes"],
        **fixed_model_kwargs, **opt_model_kwargs,
        random_seed=i
    ))
# Build a BLAST object from the model list and the dataset, then save it into
# this section's build directory.
quake_10x_blast = cb.blast.BLAST(quake_10x_models, quake_10x)
quake_10x_blast.save("build/quake/Quake_10x")

In [None]:
# Dump the stdout/stderr captured in `capio` by the preceding %%capture cell
# to text files in this section's build directory.
with open("build/quake/Quake_10x/stdout.txt", "w") as f:
    f.write(capio.stdout)
with open("build/quake/Quake_10x/stderr.txt", "w") as f:
    f.write(capio.stderr)

In [None]:
# utils.self_projection receives the BLAST object and the output directory.
# NOTE(review): judging by its name it queries the dataset against its own
# BLAST object and writes the results there -- confirm in utils.
utils.self_projection(quake_10x_blast, "build/quake/Quake_10x")

In [None]:
%%writefile build/quake/Quake_10x/predictable.txt
cell_ontology_class
cell_type1
cluster
free_annotation

---

# Aligned

In [None]:
# Create this section's output directory explicitly; unlike the two sections
# above, no utils.peek call precedes the writes here. exist_ok=True makes the
# call a no-op when the directory already exists, without a separate
# check-then-create step.
os.makedirs("build/quake/ALIGNED_Tabula_Muris", exist_ok=True)
# Merge the two datasets loaded above into one; the dict keys name the source
# of each dataset, and merge_uns_slots lists the uns entry ("seurat_genes")
# to be merged as well.
# NOTE(review): the model cell below uses batch_effect="dataset_name", so
# merge_datasets presumably records these keys in a "dataset_name" column --
# confirm against Cell_BLAST.
tabula_muris = cb.data.ExprDataSet.merge_datasets({
    "Smart-seq2": quake_smart_seq2,
    "10x": quake_10x
}, merge_uns_slots=["seurat_genes"])

In [None]:
# Display the column dtypes of the merged obs table.
tabula_muris.obs.dtypes

In [None]:
# Hyperparameters for the merged dataset. In addition to the settings used
# for the single-platform models, "dataset_name" is passed as batch_effect
# together with rmbatch_module_kwargs.
opt_model_kwargs = {
    "h_dim": 512,
    "batch_effect": "dataset_name",
    "prob_module_kwargs": {"lambda_reg": 0.001},
    "rmbatch_module_kwargs": {"lambda_reg": 0.005},
}

# Fit one DIRECTi model. The second positional argument is the merged
# dataset's uns["seurat_genes"] entry (presumably the gene set to train on).
tabula_muris_model = cb.directi.fit_DIRECTi(
    tabula_muris,
    tabula_muris.uns["seurat_genes"],
    **fixed_model_kwargs,
    **opt_model_kwargs
)

# Store the model's inference output on the dataset as its latent.
tabula_muris.latent = tabula_muris_model.inference(tabula_muris)

In [None]:
# Latent-space plot for the "cell_type1" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = tabula_muris.visualize_latent(
    "cell_type1",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/ALIGNED_Tabula_Muris/cell_type1.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "cell_ontology_class" label. Scatter points are
# rasterized (scatter_kws) while the figure itself is saved as SVG.
ax = tabula_muris.visualize_latent(
    "cell_ontology_class",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/ALIGNED_Tabula_Muris/cell_ontology_class.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Latent-space plot for the "platform" label. Scatter points are rasterized
# (scatter_kws) while the figure itself is saved as SVG.
ax = tabula_muris.visualize_latent(
    "platform",
    scatter_kws={"rasterized": True},
)
ax.get_figure().savefig(
    "build/quake/ALIGNED_Tabula_Muris/platform.svg",
    dpi=utils.DPI,
    bbox_inches="tight",
)

In [None]:
# Write the merged dataset in its current state to this section's build directory.
tabula_muris.write_dataset("build/quake/ALIGNED_Tabula_Muris/ALIGNED_Tabula_Muris.h5")

In [None]:
%%capture capio
# Ensemble for BLAST: start from the model fitted earlier in this section and
# add one more per seed in 2 .. N_JOBS, which again gives N_JOBS models in
# total. Everything printed here is captured into `capio` by the cell magic.
tabula_muris_models = [tabula_muris_model]
for i in range(2, cb.config.N_JOBS + 1):  # seed 1 requires extra regularization, skipped here
    # Separator line so each model's section can be found in the captured log.
    print("==== Model %d ====" % i)
    # Same arguments as the first fit, plus an explicit per-model random_seed.
    tabula_muris_models.append(cb.directi.fit_DIRECTi(
        tabula_muris, tabula_muris.uns["seurat_genes"],
        **fixed_model_kwargs, **opt_model_kwargs,
        random_seed=i
    ))
# Build a BLAST object from the model list and the merged dataset, then save
# it into this section's build directory.
tabula_muris_blast = cb.blast.BLAST(tabula_muris_models, tabula_muris)
tabula_muris_blast.save("build/quake/ALIGNED_Tabula_Muris")

In [None]:
# Dump the stdout/stderr captured in `capio` by the preceding %%capture cell
# to text files in this section's build directory.
with open("build/quake/ALIGNED_Tabula_Muris/stdout.txt", "w") as f:
    f.write(capio.stdout)
with open("build/quake/ALIGNED_Tabula_Muris/stderr.txt", "w") as f:
    f.write(capio.stderr)

In [None]:
# utils.self_projection receives the BLAST object and the output directory.
# NOTE(review): judging by its name it queries the dataset against its own
# BLAST object and writes the results there -- confirm in utils.
utils.self_projection(tabula_muris_blast, "build/quake/ALIGNED_Tabula_Muris")

In [None]:
%%writefile build/quake/ALIGNED_Tabula_Muris/predictable.txt
cell_ontology_class