In [None]:
import os
import shutil
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import tensorflow as tf
from tensorboard.backend.event_processing import event_accumulator

import scvi.dataset
import scvi.models
import scvi.inference
import scvi.inference.annotation
import Cell_BLAST as cb

import exputils

In [None]:
# Notebook-wide configuration: output directory, Cell BLAST random seed,
# matplotlib text settings for exported figures, and GPU selection.
PATH = "./training_dynamics/"
os.makedirs(PATH, exist_ok=True)

cb.config.RANDOM_SEED = 0
# Keep text as text (not outlines) in SVG output, rendered with Arial.
plt.rcParams.update({"svg.fonttype": "none", "font.family": "Arial"})
# exputils.pick_gpu_lowest_memory() presumably returns the id of the least-loaded
# GPU; restricting CUDA_VISIBLE_DEVICES to it limits torch and TensorFlow to that device.
os.environ["CUDA_VISIBLE_DEVICES"] = exputils.pick_gpu_lowest_memory()

## Cell BLAST

In [None]:
# Load the Baron_human dataset and export only the genes listed in
# ds.uns["seurat_genes"] to h5ad, so scVI is trained on the same gene subset
# that is passed to Cell BLAST below.
ds = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Baron_human/data.h5")
ds_genes_subset = ds[:, ds.uns["seurat_genes"]]
ds_genes_subset.to_anndata().write_h5ad(os.path.join(PATH, "ds.h5ad"))
# The file name is given relative to save_path.
ds_scvi = scvi.dataset.AnnDataset("ds.h5ad", save_path=PATH)

In [None]:
# Train Cell BLAST's DIRECTi model from scratch. Any previous output under
# PATH/cb is deleted first, so event files read from PATH/cb afterwards are
# only those written by this run.
cb_model_path = os.path.join(PATH, "cb")
if os.path.exists(cb_model_path):
    shutil.rmtree(cb_model_path)
model = cb.directi.fit_DIRECTi(
    ds,
    ds.uns["seurat_genes"],
    batch_effect="donor",  # the "donor" column is treated as the batch variable
    latent_dim=10,
    cat_dim=20,
    random_seed=0,
    path=cb_model_path,
)

In [None]:
# Load every scalar summary that DIRECTi wrote to TensorBoard during training
# (size_guidance 0 = keep all events) and list the available scalar tags.
cb_event_files = sorted(glob.glob(os.path.join(PATH, "cb", "summary", "*.tfevents.*")))
if not cb_event_files:
    raise FileNotFoundError(
        "No TensorBoard event file found under " + os.path.join(PATH, "cb", "summary")
    )
# glob order is filesystem-dependent; sorting makes the chosen file deterministic
# if more than one event file is present.
ea = event_accumulator.EventAccumulator(
    cb_event_files[0],
    size_guidance={event_accumulator.SCALARS: 0}
).Reload()
ea.Tags()["scalars"]

In [None]:
# Cell BLAST reconstruction loss per epoch, one row per (epoch, partition).
# Building a single frame from one record list gives a unique RangeIndex
# (concatenating two frames would repeat index labels), and passing `columns`
# keeps the expected columns even if a tag has no events.
cb_loss_df = pd.DataFrame.from_records(
    [
        {"Epoch": item.step, "Model": "Cell BLAST", "Partition": partition, "Negative log-likelihood": item.value}
        for partition, tag in (
            ("Training", "decoder/NB/raw_loss:0 (train)"),
            ("Validation", "decoder/NB/raw_loss:0 (val)"),
        )
        for item in ea.Scalars(tag)
    ],
    columns=["Epoch", "Model", "Partition", "Negative log-likelihood"]
)
# NOTE(review): presumably the logged NB loss is averaged over genes; scaling by
# the number of modeled genes puts it on the per-cell scale reported by scVI — confirm.
cb_loss_df["Negative log-likelihood"] *= len(ds.uns["seurat_genes"])
cb_loss_df.head()

## scVI

In [None]:
# Encode each cell's donor as an integer batch label for scVI.
# encode_integer(...)[0] is presumably the integer codes (the second element is unused).
batch_indices = cb.utils.encode_integer(ds.obs["donor"])[0]
# Stored as an (n, 1) column vector, with the number of distinct donors as the batch count.
ds_scvi.batch_indices = batch_indices.reshape((-1, 1))
ds_scvi.n_batches = np.unique(batch_indices).size

In [None]:
# Reseed numpy and torch immediately before building the model so scVI's
# random initialization and sampling are reproducible.
np.random.seed(0)
torch.manual_seed(0)
scvi_model = scvi.models.VAE(ds_scvi.nb_genes, n_latent=10, n_batch=ds_scvi.n_batches)

# Early stopping on the "ll" metric: stop after 30 epochs without improvement
# (threshold 0), keeping the best state by the same metric.
scvi_early_stopping = dict(
    early_stopping_metric="ll",
    save_best_state_metric="ll",
    patience=30,
    threshold=0,
)
scvi_trainer = scvi.inference.UnsupervisedTrainer(
    scvi_model,
    ds_scvi,
    use_cuda=True,
    metrics_to_monitor=["ll"],
    frequency=1,  # evaluate the monitored metric every epoch
    early_stopping_kwargs=scvi_early_stopping,
)
# 1000 is an upper bound; early stopping may end training sooner.
scvi_trainer.train(n_epochs=1000)

In [None]:
# Reshape scVI's per-epoch log-likelihood history into the same long format as
# the Cell BLAST loss table: one row per (epoch, partition).
scvi_loss_df = (
    pd.DataFrame(scvi_trainer.history)
    .rename(columns={"ll_train_set": "Training", "ll_test_set": "Validation"})
    .rename_axis("Epoch")
    .reset_index()
    .melt(id_vars="Epoch", var_name="Partition", value_name="Negative log-likelihood")
    .assign(Model="scVI")
    .loc[:, ["Epoch", "Model", "Partition", "Negative log-likelihood"]]
)
scvi_loss_df.head()

## Comparison: negative log-likelihood of Cell BLAST vs. scVI

In [None]:
# Stack both models' loss curves. ignore_index avoids duplicate row labels,
# since each input frame carries its own 0-based index.
loss_df = pd.concat([cb_loss_df, scvi_loss_df], ignore_index=True)

In [None]:
# Color palette keyed by method name. It is passed as `palette` with hue="Model",
# so it presumably contains entries for "Cell BLAST" and "scVI".
with open("../../Evaluation/palette_method.json", "r") as palette_file:
    palette = json.load(palette_file)

In [None]:
# Training/validation negative log-likelihood curves of Cell BLAST vs. scVI.
fig, ax = plt.subplots(figsize=(4.5, 3.5))
ax = sns.lineplot(
    data=loss_df, x="Epoch", y="Negative log-likelihood",
    hue="Model", style="Partition", palette=palette, ax=ax
)
# Fixed y-range; values outside 300-1000 fall out of view.
ax.set_ylim(300, 1000)
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
# Legend outside the axes, vertically centered on the right.
ax.legend(
    bbox_to_anchor=(1.05, 0.5), loc="center left",
    borderaxespad=0.0, frameon=False
)
fig.savefig(os.path.join(PATH, "ll_cmp.pdf"), bbox_inches="tight")

## Other Cell BLAST losses: discriminators ($D_z$, $D_c$, $D_b$)

In [None]:
# Per-epoch loss of the D_c discriminator (tag .../CatGau/cat/..., presumably on the
# categorical part of the latent space), one row per (epoch, partition).
# A single from_records call yields a unique index and keeps the expected
# columns even when a tag has no events.
cat_d_loss_df = pd.DataFrame.from_records(
    [
        {"Epoch": item.step, "Loss": r"$D_c\ loss$", "Partition": partition, "Value": item.value}
        for partition, tag in (
            ("Training", "discriminator/CatGau/cat/d_loss/d_loss:0 (train)"),
            ("Validation", "discriminator/CatGau/cat/d_loss/d_loss:0 (val)"),
        )
        for item in ea.Scalars(tag)
    ],
    columns=["Epoch", "Loss", "Partition", "Value"]
)
cat_d_loss_df.head()

In [None]:
# Per-epoch loss of the D_z discriminator (tag .../CatGau/gau/..., presumably on the
# Gaussian part of the latent space), one row per (epoch, partition).
# A single from_records call yields a unique index and keeps the expected
# columns even when a tag has no events.
gau_d_loss_df = pd.DataFrame.from_records(
    [
        {"Epoch": item.step, "Loss": r"$D_z\ loss$", "Partition": partition, "Value": item.value}
        for partition, tag in (
            ("Training", "discriminator/CatGau/gau/d_loss/d_loss:0 (train)"),
            ("Validation", "discriminator/CatGau/gau/d_loss/d_loss:0 (val)"),
        )
        for item in ea.Scalars(tag)
    ],
    columns=["Epoch", "Loss", "Partition", "Value"]
)
gau_d_loss_df.head()

In [None]:
# Per-epoch loss of the donor (batch) discriminator D_b, one row per
# (epoch, partition). A single from_records call yields a unique index and
# keeps the expected columns even when a tag has no events.
batch_d_loss_df = pd.DataFrame.from_records(
    [
        {"Epoch": item.step, "Loss": r"$D_b\ loss$", "Partition": partition, "Value": item.value}
        for partition, tag in (
            ("Training", "discriminator/donor/d_loss:0 (train)"),
            ("Validation", "discriminator/donor/d_loss:0 (val)"),
        )
        for item in ea.Scalars(tag)
    ],
    columns=["Epoch", "Loss", "Partition", "Value"]
)
# NOTE(review): entries whose loss is exactly 0 are dropped — presumably epochs in
# which the batch discriminator was not active; confirm.
batch_d_loss_df = batch_d_loss_df.loc[batch_d_loss_df["Value"] != 0].reset_index(drop=True)
batch_d_loss_df.head()

In [None]:
# Cell BLAST discriminator losses (D_z, D_c, D_b) over training epochs.
# ignore_index avoids duplicate row labels from the three separately indexed frames.
fig, ax = plt.subplots(figsize=(4.5, 3.5))
ax = sns.lineplot(
    x="Epoch", y="Value", hue="Loss", style="Partition",
    data=pd.concat([gau_d_loss_df, cat_d_loss_df, batch_d_loss_df], ignore_index=True),
    ax=ax
)
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
# Legend outside the axes, vertically centered on the right.
ax.legend(
    bbox_to_anchor=(1.05, 0.5), loc="center left",
    borderaxespad=0.0, frameon=False
)
fig.savefig(os.path.join(PATH, "cb_d_loss.pdf"), bbox_inches="tight")