# Train VAE model

This notebook first tries to train the current VAE model, then modifies the loss function to work with count data. Specifically, the modified loss function uses the loss defined in [Eraslan et al.](https://www.nature.com/articles/s41467-018-07931-2). This publication uses the zero-inflated negative binomial (ZINB) distribution, which models highly sparse and overdispersed count data. ZINB is a mixture model composed of:
   1. A point mass at 0 to represent the excess zeros
   2. A negative binomial (NB) distribution to represent the count distribution

The ZINB parameters are estimated conditioned on the input data. They include the mean and dispersion of the NB component (μ and θ) and the mixture coefficient (π), which is the weight of the point mass at zero.
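
To make the mixture concrete, here is a minimal sketch of the ZINB log-probability and the resulting negative log-likelihood, using only the Python standard library. It is not the DCA implementation linked below, which operates on tensors and adds numerical safeguards. The function names `nb_log_pmf`, `zinb_log_pmf`, and `zinb_nll` are illustrative, and the NB component is assumed to be parameterized by mean μ and dispersion θ, so that its variance is μ + μ²/θ.

```python
import math


def nb_log_pmf(x, mu, theta):
    """Log-probability of count x under an NB with mean mu and dispersion theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_pmf(x, mu, theta, pi):
    """Log-probability of count x under ZINB; pi is the weight of the point mass at 0."""
    nb = nb_log_pmf(x, mu, theta)
    if x == 0:
        # A zero can come from the point mass or from the NB component
        return math.log(pi + (1.0 - pi) * math.exp(nb))
    # A non-zero count can only come from the NB component
    return math.log(1.0 - pi) + nb


def zinb_nll(counts, mu, theta, pi):
    """Negative log-likelihood of a list of counts sharing one (mu, theta, pi)."""
    return -sum(zinb_log_pmf(x, mu, theta, pi) for x in counts)


# Example: a sparse vector of counts with assumed parameters
print(zinb_nll([0, 0, 0, 5, 12], mu=3.0, theta=2.0, pi=0.3))
```

In the model itself, μ, θ, and π would be predicted per sample and per taxon by the decoder rather than shared across observations as in this sketch.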

We adopted code from: https://github.com/theislab/dca/blob/master/dca/loss.py

More details about the new model can be found at: https://docs.google.com/presentation/d/1Q_0BUbfg51OicxY4MdI0IwhdhFfJmzX0f8VyuyGNXrw/edit#slide=id.ge45eb3c133_0_56

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import matplotlib.pyplot as plt
import pandas as pd
from cm_modules import paths, utils, train_vae_modules

Using TensorFlow backend.


In [2]:
# Set seeds to get reproducible VAE trained models
# train_vae_modules.set_all_seeds()

In [3]:
base_dir = os.path.abspath(os.path.join(os.getcwd(), "../"))

# Read in config variables
config_filename = os.path.abspath(
    os.path.join(base_dir, "test_vae_training", "config_current_vae.tsv")
)

params = utils.read_config(config_filename)

dataset_name = params["dataset_name"]

raw_compendium_filename = params["raw_compendium_filename"]
normalized_compendium_filename = params["normalized_compendium_filename"]

In [4]:
raw_compendium = pd.read_csv(raw_compendium_filename, sep="\t", index_col=0, header=0)

In [5]:
print(raw_compendium.shape)
raw_compendium.head()

(11857, 1232)


Unnamed: 0,Bacteria Actinobacteriota Actinobacteria Bifidobacteriales Bifidobacteriaceae Bifidobacterium,Bacteria Bacteroidota Bacteroidia Bacteroidales Bacteroidaceae Bacteroides,Bacteria Actinobacteriota Coriobacteriia Coriobacteriales Coriobacteriaceae Collinsella,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Agathobacter,Bacteria Firmicutes Negativicutes Veillonellales-Selenomonadales Selenomonadaceae Megamonas,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Blautia,Bacteria Firmicutes Clostridia Oscillospirales Ruminococcaceae Faecalibacterium,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Anaerostipes,Bacteria Bacteroidota Bacteroidia Bacteroidales Prevotellaceae Prevotella,Bacteria Firmicutes Bacilli Lactobacillales Streptococcaceae Streptococcus,...,Bacteria Actinobacteriota Acidimicrobiia Microtrichales Ilumatobacteraceae NA,Bacteria Verrucomicrobiota Verrucomicrobiae Pedosphaerales Pedosphaeraceae ADurb.Bin063-1,Bacteria Proteobacteria Alphaproteobacteria Caulobacterales Caulobacteraceae PMMR1,Bacteria Bacteroidota Bacteroidia Flavobacteriales Cryomorphaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Pseudofulvibacter,Bacteria Proteobacteria Alphaproteobacteria Rickettsiales Rickettsiaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Gelidibacter,Bacteria Proteobacteria Gammaproteobacteria Burkholderiales Comamonadaceae Ideonella,Bacteria Proteobacteria Alphaproteobacteria Rhizobiales Xanthobacteraceae Rhodoplanes,Bacteria Proteobacteria Alphaproteobacteria Sphingomonadales Sphingomonadaceae Rhizorhapis
PRJDB5310_DRR077057,311,423,0,0,0,0,429,0,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077058,243,313,0,0,0,13,239,0,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077059,0,255,0,196,0,477,297,51,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077060,698,0,167,159,0,158,151,23,0,127,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077061,39,335,64,96,61,189,104,68,0,23,...,0,0,0,0,0,0,0,0,0,0
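
The preview above contains many zeros, which is the property ZINB is designed to handle. One way to quantify this is sketched below on a small toy table whose values are copied from the first three taxa in the preview. `summarize_counts` is a hypothetical helper and is not part of `cm_modules`.

```python
import pandas as pd


def summarize_counts(counts: pd.DataFrame) -> dict:
    """Return the overall fraction of zeros and per-taxon variance-to-mean ratios."""
    zero_fraction = float((counts == 0).to_numpy().mean())
    means = counts.mean(axis=0)
    variances = counts.var(axis=0)  # sample variance (ddof=1)
    # Ratios well above 1 suggest overdispersion relative to a Poisson model.
    # Taxa that are zero in every sample give NaN here (0 / 0).
    dispersion_ratio = variances / means
    return {"zero_fraction": zero_fraction, "dispersion_ratio": dispersion_ratio}


# Toy table: values taken from the first three taxa of the preview
toy = pd.DataFrame(
    {
        "Bifidobacterium": [311, 243, 0, 698, 39],
        "Bacteroides": [423, 313, 255, 0, 335],
        "Collinsella": [0, 0, 0, 167, 64],
    },
    index=[f"PRJDB5310_DRR07705{i}" if i >= 7 else f"PRJDB5310_DRR07706{i}"
           for i in (7, 8, 9, 0, 1)],
)

summary = summarize_counts(toy)
print(summary["zero_fraction"])
print(summary["dispersion_ratio"])
```

The same call on `raw_compendium` would give the sparsity of the full matrix.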


In [6]:
raw_compendium.T.to_csv(
    os.path.join(paths.LOCAL_DATA_DIR, "raw_microbiome_transposed.tsv"), sep="\t"
)

In [7]:
# TO DO:
# In the DCA paper, they log2 transformed and z-score normalized their data

# Try normalizing the data
# Here we normalize the microbiome count data per taxon
# so that each taxon is in the range 0-1
train_vae_modules.normalize_expression_data(
    base_dir, config_filename, raw_compendium_filename, normalized_compendium_filename
)

input: dataset contains 11857 samples and 1232 genes
Output: normalized dataset contains 11857 samples and 1232 genes
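
The implementation of `normalize_expression_data` is not shown in this notebook. As an illustration of the two options mentioned in the comments, the sketch below scales each taxon to the 0-1 range and, alternatively, applies a log2 transform followed by a per-taxon z-score as in the TODO. Min-max scaling is an assumption about how the 0-1 range is reached, the pseudocount of 1 is an assumption, and both helper names are hypothetical.

```python
import numpy as np
import pandas as pd


def minmax_per_taxon(counts: pd.DataFrame) -> pd.DataFrame:
    """Scale each column (taxon) to the range 0-1."""
    col_min = counts.min(axis=0)
    col_range = counts.max(axis=0) - col_min
    # Constant columns have zero range; map them to 0 instead of dividing by zero
    col_range = col_range.replace(0, 1)
    return (counts - col_min) / col_range


def log2_zscore_per_taxon(counts: pd.DataFrame) -> pd.DataFrame:
    """Apply log2(x + 1), then z-score each column (taxon)."""
    logged = np.log2(counts + 1)
    std = logged.std(axis=0, ddof=0).replace(0, 1)
    return (logged - logged.mean(axis=0)) / std


# Toy table: values taken from two taxa in the preview of the raw compendium
toy_counts = pd.DataFrame(
    {"Bifidobacterium": [311, 243, 0, 698, 39], "Prevotella": [0, 0, 0, 0, 0]}
)

scaled = minmax_per_taxon(toy_counts)
zscored = log2_zscore_per_taxon(toy_counts)
print(scaled)
print(zscored)
```

Note that a ZINB loss is defined on raw counts, so which of these transforms (if any) to apply to the model input versus the reconstruction target is a separate design decision.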


In [8]:
# test1 = pd.read_csv(raw_compendium_filename, sep="\t")
# test2 = pd.read_csv(normalized_compendium_filename, sep="\t")

In [9]:
# test1.head()

In [10]:
# test2.head()
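
The commented-out cells above inspect the raw and normalized files by eye. A programmatic comparison is sketched below. `check_same_layout` is a hypothetical helper, and the toy frames stand in for the two compendia as they would look when read with `index_col=0`.

```python
import pandas as pd


def check_same_layout(raw: pd.DataFrame, normalized: pd.DataFrame) -> None:
    """Raise ValueError if two tables differ in shape, row labels, or column labels."""
    if raw.shape != normalized.shape:
        raise ValueError(f"Shape mismatch: {raw.shape} vs {normalized.shape}")
    if not raw.index.equals(normalized.index):
        raise ValueError("Sample IDs (rows) differ or are ordered differently")
    if not raw.columns.equals(normalized.columns):
        raise ValueError("Taxa (columns) differ or are ordered differently")


# Toy stand-ins for the raw and normalized compendia (samples x taxa)
raw_toy = pd.DataFrame(
    {"taxon_a": [311, 243], "taxon_b": [423, 313], "taxon_c": [0, 0]},
    index=["sample_1", "sample_2"],
)
norm_toy = raw_toy / raw_toy.max().replace(0, 1)

check_same_layout(raw_toy, norm_toy)  # passes silently
```

Passing a transposed table, such as `raw_toy.T`, raises a `ValueError` because the shapes no longer agree.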

In [11]:
# Create VAE directories if needed
output_dirs = [
    os.path.join(base_dir, dataset_name, "models"),
    os.path.join(base_dir, dataset_name, "logs"),
]

NN_architecture = params["NN_architecture"]

# Create the NN architecture subdirectory in each output directory if it does not exist
for each_dir in output_dirs:
    sub_dir = os.path.join(each_dir, NN_architecture)
    os.makedirs(sub_dir, exist_ok=True)

In [13]:
# Train VAE on new compendium data
train_vae_modules.train_vae(config_filename, raw_compendium_filename)

ValueError: Variables annot. `var` must have number of columns of `X` (1232), but has 7391 rows.

In [None]:
# Plot training and validation loss separately
stat_logs_filename = "logs/DCA/tybalt_2layer_30latent_stats.tsv"

# stats = pd.read_csv(stat_logs_filename, sep="\t", index_col=None, header=0)

In [None]:
# plt.plot(stats["loss"])

In [None]:
# plt.plot(stats["val_loss"], color="orange")