### Training MMIDAS - a coupled mixture VAE model
This notebook guides you through training a coupled mixture variational autoencoder (cpl-mixVAE).

In [1]:
import os
from mmidas.cpl_mixvae import cpl_mixVAE
from mmidas.utils.tools import get_paths
from mmidas.utils.dataloader import load_data, get_loaders

import warnings
warnings.filterwarnings('ignore')

Specify the training parameters.

In [2]:
n_categories = 120 # upper bound on the number of categories (clusters)
state_dim = 2 # continuous (state) variable dimensionality
n_arm = 2 # number of arms
latent_dim = 10 # latent dimensionality of the model
batch_size = 5000 # mini-batch size for training
n_epoch = 10 # number of epochs for training
n_epoch_p = 5 # number of epochs for pruning
min_con = 0.9 # minimum consensus among arms
max_prun_it = 2 # maximum number of pruning iterations
lr = 1e-3 # learning rate for training

Load the prepared data (as described in ```1_DataPrep.ipynb```) and create training and validation sets.

In [3]:
toml_file = 'pyproject.toml'
sub_file = 'smartseq_files'
config = get_paths(toml_file=toml_file, sub_file=sub_file)
data_path = config['paths']['main_dir'] / config['paths']['data_path']
data_file = data_path / config[sub_file]['anndata_file']

/Users/yeganeh.marghi/github/MMIDAS/pyproject.toml
Getting files directories belong to smartseq_files...


In [4]:
data = load_data(datafile=data_file)
trainloader, testloader, _, _, _ = get_loaders(dataset=data['log1p'], batch_size=batch_size)

data is loaded!
 --------- Data Summary --------- 
num cell types: 115, num cells: 22365, num genes:5032
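
To relate `batch_size` to the dataset size reported above, the short sketch below estimates how many mini-batches one epoch might contain. The 80/20 split is a hypothetical value chosen only for illustration, and whether `get_loaders` keeps the last partial batch is not stated here, so treat the numbers as rough estimates rather than what the loaders actually produce.

```python
import math

n_cells = 22365    # taken from the data summary printed above
batch_size = 5000  # as set in the parameters cell

# Upper bound: all cells are used for training and the last partial batch is kept.
batches_if_all_cells = math.ceil(n_cells / batch_size)

# Hypothetical 80/20 train/validation split (an assumption, not the library default).
train_fraction = 0.8
n_train = round(n_cells * train_fraction)
batches_with_split = math.ceil(n_train / batch_size)

print(f"batches per epoch (all cells): {batches_if_all_cells}")
print(f"batches per epoch (assumed 80% train split): {batches_with_split}")
```

With a batch size this close to the dataset size, each epoch performs only a handful of optimizer steps, which is worth keeping in mind when choosing `n_epoch`.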


Create a dedicated folder to store the training files.

In [5]:
n_run = 1
augmentation = False
folder_name = f'run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}_lr_{lr}_n_arm_{n_arm}_nbatch_{batch_size}' + \
            f'_nepoch_{n_epoch}_nepochP_{n_epoch_p}'
saving_folder = config['paths']['main_dir'] / config['paths']['saving_path']
saving_folder = saving_folder / folder_name
os.makedirs(saving_folder, exist_ok=True)
os.makedirs(saving_folder / 'model', exist_ok=True)
saving_folder = str(saving_folder)
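
As a self-contained illustration of the folder layout built above, the sketch below reproduces the same naming scheme with `pathlib` alone. The temporary base directory is a stand-in for `config['paths']['main_dir'] / config['paths']['saving_path']`, so it can be run without the MMIDAS configuration file.

```python
import tempfile
from pathlib import Path

# Stand-in for the configured saving path (assumption: any writable directory works).
base_dir = Path(tempfile.mkdtemp())

# Same values as in the parameters cell.
n_run, n_categories, state_dim, n_arm = 1, 120, 2, 2
augmentation, lr, batch_size = False, 1e-3, 5000
n_epoch, n_epoch_p = 10, 5

folder_name = (
    f"run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}"
    f"_lr_{lr}_n_arm_{n_arm}_nbatch_{batch_size}"
    f"_nepoch_{n_epoch}_nepochP_{n_epoch_p}"
)

run_dir = base_dir / folder_name
# parents=True creates run_dir and its 'model' subfolder in a single call.
(run_dir / "model").mkdir(parents=True, exist_ok=True)

print(run_dir.name)
```

Encoding the hyperparameters in the folder name keeps runs with different settings from overwriting each other.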

Construct a cpl-mixVAE object and launch its training on the prepared data.

In [6]:
cplMixVAE = cpl_mixVAE(saving_folder=saving_folder)
cplMixVAE.init_model(n_categories=n_categories,
                     state_dim=state_dim,
                     input_dim=data['log1p'].shape[1],
                     lowD_dim=latent_dim,
                     lr=lr,
                     n_arm=n_arm)

---> Computional node is not assigned, using CPU!


In [7]:
model_file = cplMixVAE.train(train_loader=trainloader,
                             test_loader=testloader,
                             n_epoch=n_epoch,
                             n_epoch_p=n_epoch_p,
                             min_con=min_con,
                             max_prun_it=max_prun_it)

Start training ...
====> Epoch:0, Total Loss: 137667456.0000, Rec_arm_1: 8.7999, Rec_arm_2: 8.7966, Joint Loss: 137578912.0000, Entropy: -7.1332, Distance: 0.4805, Elapsed Time:4.13
====> Validation Total Loss: 28064712704.0000, Rec. Loss: 8.7721
====> Epoch:1, Total Loss: 69825152.0000, Rec_arm_1: 8.7497, Rec_arm_2: 8.7451, Joint Loss: 69737120.0000, Entropy: -7.3747, Distance: 0.4512, Elapsed Time:3.64
====> Validation Total Loss: 25941938176.0000, Rec. Loss: 8.6993
====> Epoch:2, Total Loss: 42043478.0000, Rec_arm_1: 8.6571, Rec_arm_2: 8.6533, Joint Loss: 41956372.0000, Entropy: -7.3914, Distance: 0.4467, Elapsed Time:3.58
====> Validation Total Loss: 22861494272.0000, Rec. Loss: 8.5640
====> Epoch:3, Total Loss: 26674182.0000, Rec_arm_1: 8.4823, Rec_arm_2: 8.4839, Joint Loss: 26588808.0000, Entropy: -7.5388, Distance: 0.4278, Elapsed Time:3.60
====> Validation Total Loss: 20230934528.0000, Rec. Loss: 8.3146
====> Epoch:4, Total Loss: 17169558.2500, Rec_arm_1: 8.1604, Rec_arm_2: 8.1
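
The log lines above follow a regular `key: value` layout, so they can be turned into numbers for plotting or comparison. The helper below is a minimal sketch, not part of MMIDAS: `parse_log_line` is a hypothetical name, and it assumes the `====> ` prefix and comma-separated fields seen in the printed output.

```python
def parse_log_line(line):
    """Parse one '====> key: value, key: value' log line into a dict of floats.

    Returns None for lines that do not start with the '====> ' prefix.
    """
    prefix = "====> "
    if not line.startswith(prefix):
        return None
    fields = {}
    for item in line[len(prefix):].split(", "):
        key, sep, value = item.partition(":")
        if not sep:
            continue  # no 'key: value' pair in this chunk
        try:
            fields[key.strip()] = float(value)
        except ValueError:
            continue  # skip truncated or non-numeric values
    return fields


# Two lines copied from the training output shown above.
train_line = (
    "====> Epoch:0, Total Loss: 137667456.0000, Rec_arm_1: 8.7999, "
    "Rec_arm_2: 8.7966, Joint Loss: 137578912.0000, Entropy: -7.1332, "
    "Distance: 0.4805, Elapsed Time:4.13"
)
val_line = "====> Validation Total Loss: 28064712704.0000, Rec. Loss: 8.7721"

train_record = parse_log_line(train_line)
val_record = parse_log_line(val_line)
print(train_record["Rec_arm_1"], val_record["Rec. Loss"])
```

Tracking the per-arm reconstruction terms (`Rec_arm_1`, `Rec_arm_2`) separately from the total loss is useful here, because the total is dominated by the much larger joint term.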

When working from the command line, you can instead train the model with a Python script such as ```tutorial/train_unimodal.py```, as follows.

```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2
```
or, to select a CUDA device:
```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2 --device 'cuda'
```
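
For readers unfamiliar with how such flags are typically consumed, here is a minimal `argparse` sketch that accepts the four options used in the commands above. It is a hypothetical illustration, not the contents of `train_unimodal.py`: the defaults, help strings, and the `build_parser` name are assumptions.

```python
import argparse


def build_parser():
    """Build a parser for the flags shown in the commands above (illustrative only)."""
    parser = argparse.ArgumentParser(description="Sketch of a training CLI")
    parser.add_argument("--n_epoch", type=int, default=10,
                        help="number of epochs for training")
    parser.add_argument("--n_epoch_p", type=int, default=5,
                        help="number of epochs for pruning")
    parser.add_argument("--max_prun_it", type=int, default=2,
                        help="maximum number of pruning iterations")
    # Assumed default: None, meaning no device was requested explicitly.
    parser.add_argument("--device", type=str, default=None,
                        help="compute device, e.g. 'cuda'")
    return parser


# Parse the same arguments as the second command, passed as a list for demonstration.
args = build_parser().parse_args(
    ["--n_epoch", "10", "--n_epoch_p", "5", "--max_prun_it", "2", "--device", "cuda"]
)
print(args.n_epoch, args.n_epoch_p, args.max_prun_it, args.device)
```

Note that the shell strips the quotes in `--device 'cuda'`, so the script receives the plain string `cuda`.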