# Train VAE model

This notebook first tries to train the current VAE model and then modifies the loss function to work with count data. Specifically, the new loss is the one defined in [Eraslan et al.](https://www.nature.com/articles/s41467-018-07931-2), which uses the zero-inflated negative binomial (ZINB) distribution to model highly sparse and overdispersed count data. ZINB is a mixture model composed of:
   1. A point mass at 0 to represent the excess of zeros
   2. A negative binomial (NB) distribution to represent the counts

The ZINB parameters are estimated conditioned on the input data. They are the mean and dispersion of the NB component (μ and θ) and the mixture coefficient π, which gives the weight of the point mass.
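For reference, the ZINB likelihood from Eraslan et al. can be written as below, with $\Gamma$ the gamma function and $\delta_0$ the point mass at zero. In the DCA formulation, the loss minimized during training is the negative log of this likelihood.

$$
\mathrm{NB}(x;\mu,\theta)=\frac{\Gamma(x+\theta)}{\Gamma(\theta)\,\Gamma(x+1)}\left(\frac{\theta}{\theta+\mu}\right)^{\theta}\left(\frac{\mu}{\theta+\mu}\right)^{x}
$$

$$
\mathrm{ZINB}(x;\pi,\mu,\theta)=\pi\,\delta_0(x)+(1-\pi)\,\mathrm{NB}(x;\mu,\theta)
$$

Setting $x=0$ shows that a zero can come from either component:

$$
\mathrm{ZINB}(0;\pi,\mu,\theta)=\pi+(1-\pi)\left(\frac{\theta}{\theta+\mu}\right)^{\theta}
$$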

We adapted the loss code from: https://github.com/theislab/dca/blob/master/dca/loss.py
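The per-entry ZINB negative log-likelihood can be sketched in plain Python as follows. This is a minimal, non-vectorized illustration that assumes μ > 0, θ > 0 and 0 ≤ π < 1; the DCA code linked above computes the same quantity in TensorFlow on whole tensors (adding a small epsilon for numerical stability). The function names `nb_log_pmf`, `zinb_log_pmf` and `zinb_nll` are hypothetical and are not part of `cm_modules` or DCA.

```python
import math
from typing import Sequence


def nb_log_pmf(x: int, mu: float, theta: float) -> float:
    """Log-probability of count x under NB(mean=mu, dispersion=theta); assumes mu, theta > 0."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_pmf(x: int, mu: float, theta: float, pi: float) -> float:
    """Log-probability of count x under the zero-inflated NB; assumes 0 <= pi < 1."""
    if x == 0:
        # A zero can come from the point mass (weight pi) or from the NB component
        nb_zero = (theta / (theta + mu)) ** theta
        return math.log(pi + (1.0 - pi) * nb_zero)
    # A nonzero count can only come from the NB component
    return math.log(1.0 - pi) + nb_log_pmf(x, mu, theta)


def zinb_nll(
    counts: Sequence[int],
    mu: Sequence[float],
    theta: Sequence[float],
    pi: Sequence[float],
) -> float:
    """Mean negative log-likelihood over all entries (the quantity a ZINB loss minimizes)."""
    terms = [-zinb_log_pmf(x, m, t, p) for x, m, t, p in zip(counts, mu, theta, pi)]
    return sum(terms) / len(terms)
```

With π = 0 the ZINB reduces to a plain NB, and for any valid parameters the probabilities over all counts sum to 1, which makes both good sanity checks for a vectorized implementation.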

More details about the new model can be found here: https://docs.google.com/presentation/d/1Q_0BUbfg51OicxY4MdI0IwhdhFfJmzX0f8VyuyGNXrw/edit#slide=id.ge45eb3c133_0_56

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import matplotlib.pyplot as plt
import pandas as pd
from cm_modules import paths, utils, train_vae_modules
import scanpy as sc
import anndata

In [2]:
# Set seeds so that trained VAE models are reproducible
# train_vae_modules.set_all_seeds()
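The body of `train_vae_modules.set_all_seeds()` is not shown in this notebook. A helper of this kind typically seeds every random number generator that training touches; the sketch below is a hypothetical stand-in (the name `set_all_seeds_sketch` and the default seed are assumptions) that seeds Python's `random`, and NumPy and TensorFlow when they are installed. Even with fixed seeds, some GPU operations may remain nondeterministic.

```python
import os
import random


def set_all_seeds_sketch(seed: int = 42) -> None:
    """Hypothetical helper: seed every RNG that VAE training might use."""
    # Only affects hash randomization in subprocesses started after this point
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed
    try:
        import tensorflow as tf

        tf.random.set_seed(seed)  # TensorFlow 2.x API
    except ImportError:
        pass  # TensorFlow not installed
```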

In [3]:
base_dir = os.path.abspath(os.path.join(os.getcwd(), "../"))

# Read in config variables
config_filename = os.path.abspath(
    os.path.join(base_dir, "test_vae_training", "config_current_vae.tsv")
)

params = utils.read_config(config_filename)

dataset_name = params["dataset_name"]

raw_compendium_filename = params["raw_compendium_filename"]
normalized_compendium_filename = params["normalized_compendium_filename"]
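`utils.read_config` is not shown here either. Assuming the config TSV stores one `key<TAB>value` pair per line with each value written as a Python literal (for example, a quoted string for `dataset_name`), a reader could look like the hypothetical sketch below; `ast.literal_eval` parses the literals without executing arbitrary code. The real file format may differ.

```python
import ast
from typing import Any, Dict


def read_config_sketch(filename: str) -> Dict[str, Any]:
    """Hypothetical reader for a two-column key/value TSV config."""
    config: Dict[str, Any] = {}
    with open(filename) as f:
        for line in f:
            line = line.rstrip("\n")
            if not line.strip():
                continue  # skip blank lines
            key, value = line.split("\t", 1)
            # Values are assumed to be Python literals: quoted strings, numbers, lists, ...
            config[key] = ast.literal_eval(value.strip())
    return config
```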

In [4]:
raw_compendium = pd.read_csv(raw_compendium_filename, sep="\t", index_col=0, header=0)

In [5]:
print(raw_compendium.shape)
raw_compendium.head()

(11857, 1232)


Unnamed: 0,Bacteria Actinobacteriota Actinobacteria Bifidobacteriales Bifidobacteriaceae Bifidobacterium,Bacteria Bacteroidota Bacteroidia Bacteroidales Bacteroidaceae Bacteroides,Bacteria Actinobacteriota Coriobacteriia Coriobacteriales Coriobacteriaceae Collinsella,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Agathobacter,Bacteria Firmicutes Negativicutes Veillonellales-Selenomonadales Selenomonadaceae Megamonas,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Blautia,Bacteria Firmicutes Clostridia Oscillospirales Ruminococcaceae Faecalibacterium,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Anaerostipes,Bacteria Bacteroidota Bacteroidia Bacteroidales Prevotellaceae Prevotella,Bacteria Firmicutes Bacilli Lactobacillales Streptococcaceae Streptococcus,...,Bacteria Actinobacteriota Acidimicrobiia Microtrichales Ilumatobacteraceae NA,Bacteria Verrucomicrobiota Verrucomicrobiae Pedosphaerales Pedosphaeraceae ADurb.Bin063-1,Bacteria Proteobacteria Alphaproteobacteria Caulobacterales Caulobacteraceae PMMR1,Bacteria Bacteroidota Bacteroidia Flavobacteriales Cryomorphaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Pseudofulvibacter,Bacteria Proteobacteria Alphaproteobacteria Rickettsiales Rickettsiaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Gelidibacter,Bacteria Proteobacteria Gammaproteobacteria Burkholderiales Comamonadaceae Ideonella,Bacteria Proteobacteria Alphaproteobacteria Rhizobiales Xanthobacteraceae Rhodoplanes,Bacteria Proteobacteria Alphaproteobacteria Sphingomonadales Sphingomonadaceae Rhizorhapis
PRJDB5310_DRR077057,311,423,0,0,0,0,429,0,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077058,243,313,0,0,0,13,239,0,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077059,0,255,0,196,0,477,297,51,0,0,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077060,698,0,167,159,0,158,151,23,0,127,...,0,0,0,0,0,0,0,0,0,0
PRJDB5310_DRR077061,39,335,64,96,61,189,104,68,0,23,...,0,0,0,0,0,0,0,0,0,0
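The preview above is dominated by zeros, the sparsity that motivates the ZINB loss. One way to quantify it is to compute the fraction of zero entries overall and per taxon; the sketch below uses plain pandas on a toy frame, and the helper name `zero_fraction` is hypothetical. In this notebook it would be called as `zero_fraction(raw_compendium)`.

```python
from typing import Tuple

import pandas as pd


def zero_fraction(counts: pd.DataFrame) -> Tuple[float, pd.Series]:
    """Return the overall fraction of zero entries and the per-taxon (column) fraction."""
    is_zero = counts.eq(0)
    overall = float(is_zero.to_numpy().mean())
    # Most sparse taxa first
    per_taxon = is_zero.mean(axis=0).sort_values(ascending=False)
    return overall, per_taxon


# Toy samples-by-taxa matrix standing in for raw_compendium
toy = pd.DataFrame(
    {"taxon_a": [5, 3, 0, 7], "taxon_b": [0, 0, 0, 2]},
    index=["s1", "s2", "s3", "s4"],
)
overall, per_taxon = zero_fraction(toy)
print(overall)  # 0.5
```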


In [6]:
raw_compendium.T.to_csv(
    os.path.join(paths.LOCAL_DATA_DIR, "raw_microbiome_transposed.tsv"), sep="\t"
)

In [7]:
# TODO:
# In the DCA paper, they log2 transformed and z-score normalized their data

# Try normalizing the data
# Here we normalize the microbiome count data per taxon
# so that each taxon is in the range 0-1
train_vae_modules.normalize_expression_data(
    base_dir, config_filename, raw_compendium_filename, normalized_compendium_filename
)

input: dataset contains 11857 samples and 1232 genes
Output: normalized dataset contains 11857 samples and 1232 genes
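The body of `normalize_expression_data` is not shown in this notebook. Per-taxon scaling to the range 0-1 is commonly done with min-max scaling, and the TODO above mentions a log2 transform followed by z-scoring; both are sketched below with pandas, assuming rows are samples and columns are taxa. The function names and the pseudocount of 1 in the log transform are assumptions, and the actual `cm_modules` implementation may differ.

```python
import numpy as np
import pandas as pd


def minmax_per_taxon(counts: pd.DataFrame) -> pd.DataFrame:
    """Scale each taxon (column) to [0, 1]; constant columns map to 0."""
    col_min = counts.min(axis=0)
    col_range = counts.max(axis=0) - col_min
    # Avoid division by zero for taxa that take a single constant value
    col_range = col_range.replace(0, 1)
    return (counts - col_min) / col_range


def log_zscore_per_taxon(counts: pd.DataFrame) -> pd.DataFrame:
    """Apply log2(x + 1), then z-score each taxon; constant columns map to 0."""
    logged = np.log2(counts + 1)
    col_std = logged.std(axis=0, ddof=0).replace(0, 1)
    return (logged - logged.mean(axis=0)) / col_std
```

For nonnegative counts, min-max scaling keeps zeros at zero and so preserves sparsity, whereas z-scoring shifts zeros to negative values. In DCA, the normalized matrix is used as the network input while the ZINB loss is evaluated against the raw counts.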


In [8]:
# test1 = pd.read_csv(raw_compendium_filename, sep="\t")
test2 = pd.read_csv(normalized_compendium_filename, sep="\t", index_col=0, header=0)

In [9]:
# test1.head()

In [10]:
test2.shape

(11857, 1232)

In [11]:
test2.head()

Unnamed: 0,Bacteria Actinobacteriota Actinobacteria Bifidobacteriales Bifidobacteriaceae Bifidobacterium,Bacteria Bacteroidota Bacteroidia Bacteroidales Bacteroidaceae Bacteroides,Bacteria Actinobacteriota Coriobacteriia Coriobacteriales Coriobacteriaceae Collinsella,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Agathobacter,Bacteria Firmicutes Negativicutes Veillonellales-Selenomonadales Selenomonadaceae Megamonas,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Blautia,Bacteria Firmicutes Clostridia Oscillospirales Ruminococcaceae Faecalibacterium,Bacteria Firmicutes Clostridia Lachnospirales Lachnospiraceae Anaerostipes,Bacteria Bacteroidota Bacteroidia Bacteroidales Prevotellaceae Prevotella,Bacteria Firmicutes Bacilli Lactobacillales Streptococcaceae Streptococcus,...,Bacteria Actinobacteriota Acidimicrobiia Microtrichales Ilumatobacteraceae NA,Bacteria Verrucomicrobiota Verrucomicrobiae Pedosphaerales Pedosphaeraceae ADurb.Bin063-1,Bacteria Proteobacteria Alphaproteobacteria Caulobacterales Caulobacteraceae PMMR1,Bacteria Bacteroidota Bacteroidia Flavobacteriales Cryomorphaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Pseudofulvibacter,Bacteria Proteobacteria Alphaproteobacteria Rickettsiales Rickettsiaceae NA,Bacteria Bacteroidota Bacteroidia Flavobacteriales Flavobacteriaceae Gelidibacter,Bacteria Proteobacteria Gammaproteobacteria Burkholderiales Comamonadaceae Ideonella,Bacteria Proteobacteria Alphaproteobacteria Rhizobiales Xanthobacteraceae Rhodoplanes,Bacteria Proteobacteria Alphaproteobacteria Sphingomonadales Sphingomonadaceae Rhizorhapis
PRJDB5310_DRR077057,0.001496,0.000948,0.0,0.0,0.0,0.0,0.004278,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PRJDB5310_DRR077058,0.001169,0.000702,0.0,0.0,0.0,4.6e-05,0.002383,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PRJDB5310_DRR077059,0.0,0.000572,0.0,0.001081,0.0,0.001694,0.002962,0.001032,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PRJDB5310_DRR077060,0.003358,0.0,0.004034,0.000877,0.0,0.000561,0.001506,0.000465,0.0,0.00062,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
PRJDB5310_DRR077061,0.000188,0.000751,0.001546,0.000529,0.003158,0.000671,0.001037,0.001376,0.0,0.000112,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [12]:
# Convert the normalized compendium (samples x taxa) to an AnnData object
test2_anndata = anndata.AnnData(test2)

In [13]:
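# Save the AnnData object
# Note: despite its name, this file holds the normalized compendium (test2),
# not raw or transposed counts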
raw_compendium_anndata_filename = os.path.join(
    paths.LOCAL_DATA_DIR, "raw_microbiome_transposed_anndata.h5ad"
)
test2_anndata.write(raw_compendium_anndata_filename)

In [14]:
# Create VAE directories if needed
output_dirs = [
    os.path.join(base_dir, dataset_name, "models"),
    os.path.join(base_dir, dataset_name, "logs"),
]

NN_architecture = params["NN_architecture"]

# Create the NN architecture subdirectory if it does not already exist
for each_dir in output_dirs:
    sub_dir = os.path.join(each_dir, NN_architecture)
    os.makedirs(sub_dir, exist_ok=True)

In [25]:
# Train VAE on new compendium data
train_vae_modules.train_vae(config_filename, raw_compendium_anndata_filename)

dca: Successfully preprocessed 1232 genes and 11857 cells.
Successfully read in data
(11857, 1232)
Normalized input data
(2625, 1232)
input dataset contains 2625 rows and 1232 columns
built network
about to compiile
optimizer <tensorflow.python.keras.optimizer_v2.adam.Adam object at 0x7fee1e7168d0>
result Tensor("zinb_loss/SelectV2_26:0", shape=(), dtype=float32)
Model: "model_52"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
count (InputLayer)              [(None, 1232)]       0                                            
__________________________________________________________________________________________________
enc0 (Dense)                    (None, 2500)         3082500     count[0][0]                      
__________________________________________________________________________________________________
batch_normalization_24

FailedPreconditionError: Could not find variable training_16/Adam/learning_rate. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Container localhost does not exist. (Could not find resource: localhost/training_16/Adam/learning_rate)
	 [[node training_16/Adam/learning_rate/Read/ReadVariableOp (defined at /home/alexandra/Documents/Repos/common-microbes/cm_modules/train.py:120) ]]

Original stack trace for 'training_16/Adam/learning_rate/Read/ReadVariableOp':
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/traitlets/config/application.py", line 845, in launch_instance
    app.start()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 505, in start
    self.io_loop.start()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/ioloop.py", line 688, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/ioloop.py", line 741, in _run_callback
    ret = callback()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/gen.py", line 814, in inner
    self.ctx_run(self.run)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/gen.py", line 775, in run
    yielded = self.gen.send(value)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 365, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 272, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 542, in execute_request
    user_expressions, allow_stdin,
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tornado/gen.py", line 234, in wrapper
    yielded = ctx_run(next, result)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2899, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2944, in _run_cell
    return runner(coro)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3170, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3361, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-25-0ffc79af1531>", line 2, in <module>
    train_vae_modules.train_vae(config_filename, raw_compendium_anndata_filename)
  File "/home/alexandra/Documents/Repos/common-microbes/cm_modules/train_vae_modules.py", line 216, in train_vae
    tensorboard=False
  File "/home/alexandra/Documents/Repos/common-microbes/cm_modules/train.py", line 120, in train
    **kwds)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_v1.py", line 814, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays_v1.py", line 661, in fit
    steps_name='steps_per_epoch')
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays_v1.py", line 181, in model_iteration
    f = _make_execution_function(model, mode)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_arrays_v1.py", line 551, in _make_execution_function
    return model._make_execution_function(mode)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_v1.py", line 2097, in _make_execution_function
    self._make_train_function()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_v1.py", line 2029, in _make_train_function
    params=self._collected_trainable_weights, loss=self.total_loss)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 773, in get_updates
    return [self.apply_gradients(grads_and_vars)]
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 636, in apply_gradients
    self._create_all_weights(var_list)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 822, in _create_all_weights
    self._create_hypers()
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 975, in _create_hypers
    aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 1194, in add_weight
    aggregation=aggregation)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py", line 815, in _add_variable_with_custom_getter
    **kwargs_for_getter)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_utils.py", line 139, in make_variable
    shape=variable_shape if variable_shape else None)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/variables.py", line 260, in __call__
    return cls._variable_v1_call(*args, **kwargs)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/variables.py", line 221, in _variable_v1_call
    shape=shape)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/variables.py", line 199, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/variable_scope.py", line 2626, in default_variable_creator
    shape=shape)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/variables.py", line 264, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1595, in __init__
    distribute_strategy=distribute_strategy)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 1777, in _init_from_args
    value = gen_resource_variable_ops.read_variable_op(handle, dtype)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/ops/gen_resource_variable_ops.py", line 485, in read_variable_op
    "ReadVariableOp", resource=resource, dtype=dtype, name=name)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 750, in _apply_op_helper
    attrs=attr_protos, op_def=op_def)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3565, in _create_op_internal
    op_def=op_def)
  File "/home/alexandra/anaconda3/envs/common_microbe/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2045, in __init__
    self._traceback = tf_stack.extract_stack_for_node(self._c_op)


In [17]:
(2142 + 238) / 11857

0.2007253099434933

In [None]:
# Plot training and validation loss separately
# stat_logs_filename = "logs/DCA/tybalt_2layer_30latent_stats.tsv"

# stats = pd.read_csv(stat_logs_filename, sep="\t", index_col=None, header=0)

In [None]:
# plt.plot(stats["loss"])

In [None]:
# plt.plot(stats["val_loss"], color="orange")