### Training MMIDAS - a coupled mixture VAE model
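This notebook walks through training a coupled mixture variational autoencoder (cpl-mixVAE): setting the hyperparameters, loading the prepared data, and running the initial training followed by the pruning iterations.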

In [1]:
import mmidas
from mmidas.nn_model import mixVAEConfig, loss_fn
from mmidas.cpl_mixvae import cpl_mixVAE
from mmidas.utils.tools import get_paths
from mmidas.utils.dataloader import load_data, get_loaders
import importlib
import os
import torch as t

import warnings
warnings.filterwarnings('ignore')

In [2]:
t.cuda.is_available()
t.cuda.device_count()

4
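
The training cells in this notebook pass `device='cuda'` explicitly. On a machine without a GPU, a guard like the one below can choose a fallback instead. This is a minimal sketch: `pick_device` is a hypothetical helper, not part of `mmidas`, and whether `cpl_mixVAE` accepts `'cpu'` is an assumption to check against the package.

```python
def pick_device(cuda_available: bool, n_devices: int) -> str:
    """Return 'cuda' when at least one GPU is usable, otherwise 'cpu'."""
    return 'cuda' if cuda_available and n_devices > 0 else 'cpu'


try:
    import torch as t
    device = pick_device(t.cuda.is_available(), t.cuda.device_count())
except ImportError:
    # torch is not installed, so no GPU backend is reachable from Python
    device = pick_device(False, 0)

print(f'selected device: {device}')
```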

Specify the training parameters.

In [3]:
n_categories = 120 # upper bound of number of categories (clusters)
state_dim = 2 # continuous (state) variable dimensionality 
n_arm = 2 # number of arms
latent_dim = 10 # latent dimensionality of the model
batch_size = 5000 # mini-batch size for training
n_epoch = 10 # number of epochs for training
n_epoch_p = 5 # number of epochs for pruning
min_con = 0.9 # minimum consensus among arms
max_prun_it = 2 # maximum number of pruning iterations
lr = 1e-3 # learning rate for training
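
`min_con` and `max_prun_it` govern the pruning phase: categories on which the arms do not reach the required consensus are, presumably, the ones removed. The sketch below shows one way a per-category consensus score between two arms could be computed from hard cluster assignments. It is an illustration under stated assumptions, not the `mmidas` implementation: `category_consensus` and `categories_to_prune` are hypothetical helpers, and the normalization (matches divided by the larger of the two arms' counts) is an assumed choice.

```python
from collections import Counter


def category_consensus(labels_a, labels_b, n_categories):
    """Per-category agreement between two arms' hard assignments.

    For category k: (number of samples both arms assign to k) divided by the
    larger of the two arms' counts for k. Unused categories get None.
    """
    both = Counter(a for a, b in zip(labels_a, labels_b) if a == b)
    count_a, count_b = Counter(labels_a), Counter(labels_b)
    scores = []
    for k in range(n_categories):
        denom = max(count_a[k], count_b[k])
        scores.append(both[k] / denom if denom else None)
    return scores


def categories_to_prune(scores, min_con):
    """Indices of used categories whose consensus falls below min_con."""
    return [k for k, s in enumerate(scores) if s is not None and s < min_con]


# Toy example: 3 categories, 6 samples, two arms.
arm_1 = [0, 0, 1, 1, 2, 2]
arm_2 = [0, 0, 1, 2, 2, 2]
scores = category_consensus(arm_1, arm_2, n_categories=3)
print(scores)
print(categories_to_prune(scores, min_con=0.9))
```

In this toy case the arms agree fully on category 0 and only partly on categories 1 and 2, so only category 0 clears a threshold of 0.9.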

Load the prepared data (as described in ```1_data_prep.ipynb```) and create training and validation sets.

In [4]:
toml_file = 'pyproject.toml'
sub_file = 'smartseq_files'
config = get_paths(toml_file=toml_file, sub_file=sub_file)
data_path = config['paths']['main_dir'] / config['paths']['data_path']
data_file = data_path / config[sub_file]['anndata_file']

/allen/programs/celltypes/workgroups/mousecelltypes/Hilal/MMIDAS/pyproject.toml
Getting files directories belong to smartseq_files...
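
The lookups in this notebook imply a particular shape for the object returned by `get_paths`: a `paths` table holding `main_dir`, `data_path` and `saving_path`, plus a table named after `sub_file` holding `anndata_file`. The sketch below mimics that shape with a plain dictionary so the path arithmetic can be checked in isolation. All directory and file names in it are placeholders, and the assumption that `main_dir` is a `pathlib.Path` is inferred only from the use of the `/` operator.

```python
from pathlib import Path

# Placeholder values; the real ones come from pyproject.toml via get_paths.
config = {
    'paths': {
        'main_dir': Path('/path/to/MMIDAS'),
        'data_path': 'data',
        'saving_path': 'results',
    },
    'smartseq_files': {
        'anndata_file': 'example.h5ad',
    },
}

sub_file = 'smartseq_files'
data_path = config['paths']['main_dir'] / config['paths']['data_path']
data_file = data_path / config[sub_file]['anndata_file']
print(data_file)
```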


In [5]:
data = load_data(datafile=data_file)
trainloader, testloader, _ = get_loaders(dataset=data['log1p'], batch_size=batch_size)

data is loaded!
 --------- Data Summary --------- 
num cell types: 115, num cells: 22365, num genes:5032


Create a dedicated folder, named after the run's hyperparameters, to store the training outputs.

In [6]:
n_run = 1
augmentation = False
folder_name = (
    f'run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}_lr_{lr}'
    f'_n_arm_{n_arm}_nbatch_{batch_size}_nepoch_{n_epoch}_nepochP_{n_epoch_p}'
)
saving_folder = config['paths']['main_dir'] / config['paths']['saving_path']
saving_folder = saving_folder / folder_name
os.makedirs(saving_folder, exist_ok=True)
os.makedirs(saving_folder / 'model', exist_ok=True)
saving_folder = str(saving_folder)
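
Because the folder name encodes every hyperparameter, it can help to build it in one place so that runs stay comparable. A minimal sketch, where `run_folder_name` is a hypothetical helper (not part of `mmidas`) that reproduces the naming pattern used above, and a temporary directory stands in for the real saving path:

```python
import tempfile
from pathlib import Path


def run_folder_name(n_run, n_categories, state_dim, augmentation, lr,
                    n_arm, batch_size, n_epoch, n_epoch_p):
    """Reproduce the run-folder naming pattern used in this notebook."""
    return (
        f'run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}'
        f'_lr_{lr}_n_arm_{n_arm}_nbatch_{batch_size}'
        f'_nepoch_{n_epoch}_nepochP_{n_epoch_p}'
    )


name = run_folder_name(n_run=1, n_categories=120, state_dim=2,
                       augmentation=False, lr=1e-3, n_arm=2,
                       batch_size=5000, n_epoch=10, n_epoch_p=5)

# A temporary directory stands in for the real saving path.
base = Path(tempfile.mkdtemp())
model_dir = base / name / 'model'
model_dir.mkdir(parents=True, exist_ok=True)  # also creates base / name
print(name)
```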

Construct a cpl-mixVAE object and launch its training on the prepared data.

In [37]:
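# Reload the modules so that local edits to the mmidas source are picked up
# without restarting the kernel (only needed during development).
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE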

cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
cplMixVAE.init(categories=n_categories,
               state_dim=state_dim,
               input_dim=data['log1p'].shape[1],
               lowD_dim=latent_dim,
               lr=lr,
               arms=n_arm)

model_file = cplMixVAE.train(train_loader=trainloader,
                             test_loader=testloader,
                             n_epoch=n_epoch,
                             n_epoch_p=n_epoch_p,
                             min_con=min_con,
                             max_prun_it=max_prun_it)

device: Tesla V100-SXM2-32GB
Start training ...


 10%|█         | 1/10 [00:00<00:05,  1.66it/s, Total Loss=1.7e+8, Rec_arm_1=8.8, Rec_arm_2=8.8, Joint Loss=1.69e+8, Entropy=-6.86, Distance=0.51, Elapsed Time=0.603, Validation Loss=3.66e+10, Validation Rec. Loss=8.77]

           -- EPOCH 0 --           
Total Loss:                    169525010.0000                
Rec_arm_1:                     8.7993                        
Rec_arm_2:                     8.7950                        
Joint Loss:                    169436480.0000                
Entropy:                       -6.8606                       
Distance:                      0.5097                        
Elapsed Time:                  0.6027                        
Validation Loss:               36639199232.0000              
Validation Rec. Loss:          8.7700                        


 20%|██        | 2/10 [00:03<00:17,  2.14s/it, Total Loss=6.44e+7, Rec_arm_1=8.75, Rec_arm_2=8.74, Joint Loss=6.43e+7, Entropy=-7.02, Distance=0.481, Elapsed Time=3.21, Validation Loss=4.44e+10, Validation Rec. Loss=8.7]

           -- EPOCH 1 --           
Total Loss:                    64385470.0000                 
Rec_arm_1:                     8.7533                        
Rec_arm_2:                     8.7364                        
Joint Loss:                    64297464.0000                 
Entropy:                       -7.0197                       
Distance:                      0.4807                        
Elapsed Time:                  3.2087                        
Validation Loss:               44432412672.0000              
Validation Rec. Loss:          8.6952                        


 30%|███       | 3/10 [00:04<00:10,  1.44s/it, Total Loss=3.65e+7, Rec_arm_1=8.67, Rec_arm_2=8.63, Joint Loss=3.64e+7, Entropy=-7.17, Distance=0.462, Elapsed Time=0.6, Validation Loss=4.48e+10, Validation Rec. Loss=8.55]

           -- EPOCH 2 --           
Total Loss:                    36484357.0000                 
Rec_arm_1:                     8.6678                        
Rec_arm_2:                     8.6273                        
Joint Loss:                    36397324.0000                 
Entropy:                       -7.1734                       
Distance:                      0.4620                        
Elapsed Time:                  0.6003                        
Validation Loss:               44797005824.0000              
Validation Rec. Loss:          8.5529                        


 40%|████      | 4/10 [00:04<00:06,  1.09s/it, Total Loss=2.11e+7, Rec_arm_1=8.51, Rec_arm_2=8.42, Joint Loss=2.1e+7, Entropy=-7.28, Distance=0.45, Elapsed Time=0.552, Validation Loss=6.94e+10, Validation Rec. Loss=8.29]

           -- EPOCH 3 --           
Total Loss:                    21116691.5000                 
Rec_arm_1:                     8.5100                        
Rec_arm_2:                     8.4189                        
Joint Loss:                    21031506.0000                 
Entropy:                       -7.2842                       
Distance:                      0.4500                        
Elapsed Time:                  0.5519                        
Validation Loss:               69400813568.0000              
Validation Rec. Loss:          8.2856                        


 50%|█████     | 5/10 [00:05<00:04,  1.12it/s, Total Loss=1.44e+7, Rec_arm_1=8.22, Rec_arm_2=8.03, Joint Loss=1.43e+7, Entropy=-7.34, Distance=0.442, Elapsed Time=0.55, Validation Loss=8.72e+10, Validation Rec. Loss=7.8]

           -- EPOCH 4 --           
Total Loss:                    14382812.7500                 
Rec_arm_1:                     8.2198                        
Rec_arm_2:                     8.0259                        
Joint Loss:                    14301064.0000                 
Entropy:                       -7.3414                       
Distance:                      0.4416                        
Elapsed Time:                  0.5497                        
Validation Loss:               87159382016.0000              
Validation Rec. Loss:          7.7980                        


 60%|██████    | 6/10 [00:06<00:03,  1.29it/s, Total Loss=1.16e+7, Rec_arm_1=7.7, Rec_arm_2=7.33, Joint Loss=1.15e+7, Entropy=-7.4, Distance=0.436, Elapsed Time=0.549, Validation Loss=7.31e+10, Validation Rec. Loss=6.98]

           -- EPOCH 5 --           
Total Loss:                    11607130.7500                 
Rec_arm_1:                     7.7031                        
Rec_arm_2:                     7.3309                        
Joint Loss:                    11531480.0000                 
Entropy:                       -7.3986                       
Distance:                      0.4365                        
Elapsed Time:                  0.5487                        
Validation Loss:               73135783936.0000              
Validation Rec. Loss:          6.9796                        


 70%|███████   | 7/10 [00:06<00:02,  1.42it/s, Total Loss=9.36e+6, Rec_arm_1=6.86, Rec_arm_2=6.25, Joint Loss=9.3e+6, Entropy=-7.44, Distance=0.431, Elapsed Time=0.55, Validation Loss=5.75e+10, Validation Rec. Loss=5.79]

           -- EPOCH 6 --           
Total Loss:                    9361855.0000                  
Rec_arm_1:                     6.8556                        
Rec_arm_2:                     6.2485                        
Joint Loss:                    9295914.0000                  
Entropy:                       -7.4392                       
Distance:                      0.4306                        
Elapsed Time:                  0.5496                        
Validation Loss:               57542168576.0000              
Validation Rec. Loss:          5.7899                        


 80%|████████  | 8/10 [00:07<00:01,  1.52it/s, Total Loss=7.49e+6, Rec_arm_1=5.65, Rec_arm_2=4.95, Joint Loss=7.44e+6, Entropy=-7.49, Distance=0.428, Elapsed Time=0.552, Validation Loss=5.08e+10, Validation Rec. Loss=4.57]

           -- EPOCH 7 --           
Total Loss:                    7491239.2500                  
Rec_arm_1:                     5.6494                        
Rec_arm_2:                     4.9548                        
Joint Loss:                    7437878.5000                  
Entropy:                       -7.4865                       
Distance:                      0.4278                        
Elapsed Time:                  0.5524                        
Validation Loss:               50779258880.0000              
Validation Rec. Loss:          4.5688                        


 90%|█████████ | 9/10 [00:07<00:00,  1.55it/s, Total Loss=6.18e+6, Rec_arm_1=4.43, Rec_arm_2=4.3, Joint Loss=6.14e+6, Entropy=-7.52, Distance=0.419, Elapsed Time=0.616, Validation Loss=4.72e+10, Validation Rec. Loss=4.18] 

           -- EPOCH 8 --           
Total Loss:                    6183121.6250                  
Rec_arm_1:                     4.4263                        
Rec_arm_2:                     4.2979                        
Joint Loss:                    6139221.0000                  
Entropy:                       -7.5208                       
Distance:                      0.4193                        
Elapsed Time:                  0.6163                        
Validation Loss:               47226417152.0000              
Validation Rec. Loss:          4.1849                        


100%|██████████| 10/10 [00:08<00:00,  1.19it/s, Total Loss=5.43e+6, Rec_arm_1=4.07, Rec_arm_2=4.24, Joint Loss=5.38e+6, Entropy=-7.54, Distance=0.419, Elapsed Time=0.606, Validation Loss=4.31e+10, Validation Rec. Loss=3.93]

           -- EPOCH 9 --           
Total Loss:                    5425349.5000                  
Rec_arm_1:                     4.0713                        
Rec_arm_2:                     4.2377                        
Joint Loss:                    5383538.5000                  
Entropy:                       -7.5447                       
Distance:                      0.4187                        
Elapsed Time:                  0.6056                        
Validation Loss:               43112144896.0000              
Validation Rec. Loss:          3.9337                        





Continue training with pruning ...
Pruned categories: [1]


 20%|██        | 1/5 [00:00<00:02,  1.79it/s, Total Loss=4.73e+6, Rec_arm_1=3.9, Rec_arm_2=3.73, Joint Loss=4.69e+6, Entropy=-7.56, Distance=0.418, Elapsed Time=0.558, Validation Loss=4.05e+10, Validation Rec. Loss=3.55]

           -- EPOCH 0 --           
Total Loss:                    4730384.7500                  
Rec_arm_1:                     3.8995                        
Rec_arm_2:                     3.7282                        
Joint Loss:                    4692002.0000                  
Entropy:                       -7.5631                       
Distance:                      0.4178                        
Elapsed Time:                  0.5580                        
Validation Loss:               40470106112.0000              
Validation Rec. Loss:          3.5467                        


 40%|████      | 2/5 [00:01<00:01,  1.87it/s, Total Loss=4.19e+6, Rec_arm_1=3.43, Rec_arm_2=3.56, Joint Loss=4.16e+6, Entropy=-7.59, Distance=0.414, Elapsed Time=0.516, Validation Loss=3.93e+10, Validation Rec. Loss=3.46]

           -- EPOCH 1 --           
Total Loss:                    4193511.7500                  
Rec_arm_1:                     3.4267                        
Rec_arm_2:                     3.5601                        
Joint Loss:                    4158353.7500                  
Entropy:                       -7.5918                       
Distance:                      0.4141                        
Elapsed Time:                  0.5156                        
Validation Loss:               39319117824.0000              
Validation Rec. Loss:          3.4641                        


 60%|██████    | 3/5 [00:04<00:03,  1.69s/it, Total Loss=4.06e+6, Rec_arm_1=3.33, Rec_arm_2=3.54, Joint Loss=4.03e+6, Entropy=-7.6, Distance=0.417, Elapsed Time=3.07, Validation Loss=3.9e+10, Validation Rec. Loss=3.42]   

           -- EPOCH 2 --           
Total Loss:                    4061759.6250                  
Rec_arm_1:                     3.3324                        
Rec_arm_2:                     3.5391                        
Joint Loss:                    4027181.5000                  
Entropy:                       -7.6043                       
Distance:                      0.4170                        
Elapsed Time:                  3.0675                        
Validation Loss:               38958002176.0000              
Validation Rec. Loss:          3.4183                        


 80%|████████  | 4/5 [00:04<00:01,  1.25s/it, Total Loss=3.7e+6, Rec_arm_1=3.3, Rec_arm_2=3.46, Joint Loss=3.67e+6, Entropy=-7.63, Distance=0.414, Elapsed Time=0.569, Validation Loss=3.9e+10, Validation Rec. Loss=3.34]

           -- EPOCH 3 --           
Total Loss:                    3703567.3750                  
Rec_arm_1:                     3.2965                        
Rec_arm_2:                     3.4604                        
Joint Loss:                    3669566.0000                  
Entropy:                       -7.6331                       
Distance:                      0.4143                        
Elapsed Time:                  0.5690                        
Validation Loss:               38964834304.0000              
Validation Rec. Loss:          3.3427                        


100%|██████████| 5/5 [00:05<00:00,  1.05s/it, Total Loss=3.35e+6, Rec_arm_1=3.22, Rec_arm_2=3.43, Joint Loss=3.32e+6, Entropy=-7.65, Distance=0.408, Elapsed Time=0.524, Validation Loss=3.89e+10, Validation Rec. Loss=3.32]

           -- EPOCH 4 --           
Total Loss:                    3351141.9375                  
Rec_arm_1:                     3.2169                        
Rec_arm_2:                     3.4335                        
Joint Loss:                    3317677.0000                  
Entropy:                       -7.6535                       
Distance:                      0.4082                        
Elapsed Time:                  0.5238                        
Validation Loss:               38904201216.0000              
Validation Rec. Loss:          3.3232                        





Continue training with pruning ...
Pruned categories: [1 2]


 20%|██        | 1/5 [00:00<00:03,  1.09it/s, Total Loss=3.08e+6, Rec_arm_1=3.2, Rec_arm_2=3.43, Joint Loss=3.05e+6, Entropy=-7.64, Distance=0.412, Elapsed Time=0.906, Validation Loss=3.65e+10, Validation Rec. Loss=3.3]

           -- EPOCH 0 --           
Total Loss:                    3084450.5000                  
Rec_arm_1:                     3.1967                        
Rec_arm_2:                     3.4279                        
Joint Loss:                    3051115.2500                  
Entropy:                       -7.6365                       
Distance:                      0.4119                        
Elapsed Time:                  0.9060                        
Validation Loss:               36537438208.0000              
Validation Rec. Loss:          3.3048                        


 40%|████      | 2/5 [00:01<00:02,  1.47it/s, Total Loss=3.04e+6, Rec_arm_1=3.18, Rec_arm_2=3.39, Joint Loss=3e+6, Entropy=-7.65, Distance=0.411, Elapsed Time=0.517, Validation Loss=3.64e+10, Validation Rec. Loss=3.27] 

           -- EPOCH 1 --           
Total Loss:                    3036263.1875                  
Rec_arm_1:                     3.1830                        
Rec_arm_2:                     3.3880                        
Joint Loss:                    3003197.2500                  
Entropy:                       -7.6541                       
Distance:                      0.4108                        
Elapsed Time:                  0.5172                        
Validation Loss:               36379058176.0000              
Validation Rec. Loss:          3.2726                        


 60%|██████    | 3/5 [00:01<00:01,  1.69it/s, Total Loss=2.96e+6, Rec_arm_1=3.15, Rec_arm_2=3.37, Joint Loss=2.93e+6, Entropy=-7.68, Distance=0.405, Elapsed Time=0.48, Validation Loss=3.52e+10, Validation Rec. Loss=3.26]

           -- EPOCH 2 --           
Total Loss:                    2961454.0625                  
Rec_arm_1:                     3.1451                        
Rec_arm_2:                     3.3717                        
Joint Loss:                    2928661.0000                  
Entropy:                       -7.6801                       
Distance:                      0.4051                        
Elapsed Time:                  0.4800                        
Validation Loss:               35229962240.0000              
Validation Rec. Loss:          3.2585                        


 80%|████████  | 4/5 [00:02<00:00,  1.74it/s, Total Loss=2.81e+6, Rec_arm_1=3.13, Rec_arm_2=3.36, Joint Loss=2.77e+6, Entropy=-7.69, Distance=0.404, Elapsed Time=0.549, Validation Loss=3.36e+10, Validation Rec. Loss=3.25]

           -- EPOCH 3 --           
Total Loss:                    2806129.3125                  
Rec_arm_1:                     3.1326                        
Rec_arm_2:                     3.3629                        
Joint Loss:                    2773443.7500                  
Entropy:                       -7.6921                       
Distance:                      0.4040                        
Elapsed Time:                  0.5494                        
Validation Loss:               33568557056.0000              
Validation Rec. Loss:          3.2475                        


100%|██████████| 5/5 [00:02<00:00,  1.67it/s, Total Loss=2.51e+6, Rec_arm_1=3.13, Rec_arm_2=3.35, Joint Loss=2.48e+6, Entropy=-7.69, Distance=0.405, Elapsed Time=0.518, Validation Loss=3.17e+10, Validation Rec. Loss=3.24]

           -- EPOCH 4 --           
Total Loss:                    2511441.1875                  
Rec_arm_1:                     3.1257                        
Rec_arm_2:                     3.3513                        
Joint Loss:                    2478848.0000                  
Entropy:                       -7.6928                       
Distance:                      0.4046                        
Elapsed Time:                  0.5182                        
Validation Loss:               31711250432.0000              
Validation Rec. Loss:          3.2406                        





No more pruning!
Training is done!
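
The logged values suggest how the total loss is assembled: in the epochs checked, `Total Loss` is close to `Joint Loss` plus the two reconstruction terms scaled by the number of genes (5032). This is an observation from the printed numbers, not a statement about the `mmidas` source, and the small residuals are consistent with single-precision rounding at this magnitude. A quick arithmetic check on two epochs copied from the log above:

```python
n_genes = 5032  # "num genes" from the data summary

# (total, rec_arm_1, rec_arm_2, joint) copied from the log above
logged = {
    'first epoch': (169525010.0, 8.7993, 8.7950, 169436480.0),
    'last pruning epoch': (2511441.1875, 3.1257, 3.3513, 2478848.0),
}

for label, (total, rec_1, rec_2, joint) in logged.items():
    rebuilt = n_genes * (rec_1 + rec_2) + joint
    rel_err = abs(rebuilt - total) / total
    print(f'{label}: rebuilt={rebuilt:.1f}, logged={total:.1f}, rel_err={rel_err:.1e}')
```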


In [34]:
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE

In [35]:

cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
cplMixVAE.init(categories=n_categories,
               state_dim=state_dim,
               input_dim=data['log1p'].shape[1],
               lowD_dim=latent_dim,
               lr=lr,
               arms=n_arm)

# cplMixVAE.model = t.compile(cplMixVAE.model)

device: Tesla V100-SXM2-32GB


In [36]:
fsdp_dct = cplMixVAE._fsdp(train_loader=trainloader,
                           val_loader=testloader,
                           epochs=n_epoch,
                           n_epoch_p=n_epoch_p,
                           min_con=min_con,
                           max_prun_it=max_prun_it,
                           device=cplMixVAE.device,
                           model=cplMixVAE.model,
                           opt=cplMixVAE.optimizer)

 10%|█         | 1/10 [00:00<00:06,  1.49it/s, Total Loss=tensor(1.6953e+08), Rec_arm_1=tensor(8.7993), Rec_arm_2=tensor(8.7950), Joint Loss=tensor(1.6944e+08), Entropy=tensor(-6.8606), Distance=tensor(1.6944e+08), Validation Loss=tensor(3.6639e+10), Validation Rec. Loss=tensor(8.7700)]

           -- EPOCH 0 --           
Total Loss:                    169525008.0000                
Rec_arm_1:                     8.7993                        
Rec_arm_2:                     8.7950                        
Joint Loss:                    169436480.0000                
Entropy:                       -6.8606                       
Distance:                      169436368.0000                
Validation Loss:               36639199232.0000              
Validation Rec. Loss:          8.7700                        


 20%|██        | 2/10 [00:01<00:04,  1.65it/s, Total Loss=tensor(64385472.), Rec_arm_1=tensor(8.7533), Rec_arm_2=tensor(8.7364), Joint Loss=tensor(64297460.), Entropy=tensor(-7.0197), Distance=tensor(64297360.), Validation Loss=tensor(4.4432e+10), Validation Rec. Loss=tensor(8.6952)]   

           -- EPOCH 1 --           
Total Loss:                    64385472.0000                 
Rec_arm_1:                     8.7533                        
Rec_arm_2:                     8.7364                        
Joint Loss:                    64297460.0000                 
Entropy:                       -7.0197                       
Distance:                      64297360.0000                 
Validation Loss:               44432412672.0000              
Validation Rec. Loss:          8.6952                        


 30%|███       | 3/10 [00:01<00:04,  1.70it/s, Total Loss=tensor(36484356.), Rec_arm_1=tensor(8.6678), Rec_arm_2=tensor(8.6273), Joint Loss=tensor(36397328.), Entropy=tensor(-7.1734), Distance=tensor(36397224.), Validation Loss=tensor(4.4797e+10), Validation Rec. Loss=tensor(8.5529)]

           -- EPOCH 2 --           
Total Loss:                    36484356.0000                 
Rec_arm_1:                     8.6678                        
Rec_arm_2:                     8.6273                        
Joint Loss:                    36397328.0000                 
Entropy:                       -7.1734                       
Distance:                      36397224.0000                 
Validation Loss:               44797005824.0000              
Validation Rec. Loss:          8.5529                        


 40%|████      | 4/10 [00:02<00:03,  1.73it/s, Total Loss=tensor(21116692.), Rec_arm_1=tensor(8.5100), Rec_arm_2=tensor(8.4189), Joint Loss=tensor(21031504.), Entropy=tensor(-7.2842), Distance=tensor(21031404.), Validation Loss=tensor(6.9401e+10), Validation Rec. Loss=tensor(8.2856)]

           -- EPOCH 3 --           
Total Loss:                    21116692.0000                 
Rec_arm_1:                     8.5100                        
Rec_arm_2:                     8.4189                        
Joint Loss:                    21031504.0000                 
Entropy:                       -7.2842                       
Distance:                      21031404.0000                 
Validation Loss:               69400813568.0000              
Validation Rec. Loss:          8.2856                        


 50%|█████     | 5/10 [00:02<00:02,  1.74it/s, Total Loss=tensor(14382813.), Rec_arm_1=tensor(8.2198), Rec_arm_2=tensor(8.0259), Joint Loss=tensor(14301064.), Entropy=tensor(-7.3414), Distance=tensor(14300961.), Validation Loss=tensor(8.7159e+10), Validation Rec. Loss=tensor(7.7980)]

           -- EPOCH 4 --           
Total Loss:                    14382813.0000                 
Rec_arm_1:                     8.2198                        
Rec_arm_2:                     8.0259                        
Joint Loss:                    14301064.0000                 
Entropy:                       -7.3414                       
Distance:                      14300961.0000                 
Validation Loss:               87159382016.0000              
Validation Rec. Loss:          7.7980                        


 60%|██████    | 6/10 [00:03<00:02,  1.75it/s, Total Loss=tensor(11607131.), Rec_arm_1=tensor(7.7031), Rec_arm_2=tensor(7.3309), Joint Loss=tensor(11531479.), Entropy=tensor(-7.3986), Distance=tensor(11531376.), Validation Loss=tensor(7.3136e+10), Validation Rec. Loss=tensor(6.9796)]

           -- EPOCH 5 --           
Total Loss:                    11607131.0000                 
Rec_arm_1:                     7.7031                        
Rec_arm_2:                     7.3309                        
Joint Loss:                    11531479.0000                 
Entropy:                       -7.3986                       
Distance:                      11531376.0000                 
Validation Loss:               73135783936.0000              
Validation Rec. Loss:          6.9796                        


 70%|███████   | 7/10 [00:04<00:01,  1.73it/s, Total Loss=tensor(9361855.), Rec_arm_1=tensor(6.8556), Rec_arm_2=tensor(6.2485), Joint Loss=tensor(9295914.), Entropy=tensor(-7.4392), Distance=tensor(9295811.), Validation Loss=tensor(5.7542e+10), Validation Rec. Loss=tensor(5.7899)]   

           -- EPOCH 6 --           
Total Loss:                    9361855.0000                  
Rec_arm_1:                     6.8556                        
Rec_arm_2:                     6.2485                        
Joint Loss:                    9295914.0000                  
Entropy:                       -7.4392                       
Distance:                      9295811.0000                  
Validation Loss:               57542168576.0000              
Validation Rec. Loss:          5.7899                        


 80%|████████  | 8/10 [00:04<00:01,  1.49it/s, Total Loss=tensor(7491239.), Rec_arm_1=tensor(5.6494), Rec_arm_2=tensor(4.9548), Joint Loss=tensor(7437878.), Entropy=tensor(-7.4865), Distance=tensor(7437776.), Validation Loss=tensor(5.0779e+10), Validation Rec. Loss=tensor(4.5688)]

           -- EPOCH 7 --           
Total Loss:                    7491239.0000                  
Rec_arm_1:                     5.6494                        
Rec_arm_2:                     4.9548                        
Joint Loss:                    7437878.0000                  
Entropy:                       -7.4865                       
Distance:                      7437776.0000                  
Validation Loss:               50779258880.0000              
Validation Rec. Loss:          4.5688                        


 90%|█████████ | 9/10 [00:05<00:00,  1.51it/s, Total Loss=tensor(6183121.5000), Rec_arm_1=tensor(4.4263), Rec_arm_2=tensor(4.2979), Joint Loss=tensor(6139221.), Entropy=tensor(-7.5208), Distance=tensor(6139118.5000), Validation Loss=tensor(4.7226e+10), Validation Rec. Loss=tensor(4.1849)]

           -- EPOCH 8 --           
Total Loss:                    6183121.5000                  
Rec_arm_1:                     4.4263                        
Rec_arm_2:                     4.2979                        
Joint Loss:                    6139221.0000                  
Entropy:                       -7.5208                       
Distance:                      6139118.5000                  
Validation Loss:               47226417152.0000              
Validation Rec. Loss:          4.1849                        


100%|██████████| 10/10 [00:06<00:00,  1.66it/s, Total Loss=tensor(5425349.5000), Rec_arm_1=tensor(4.0713), Rec_arm_2=tensor(4.2377), Joint Loss=tensor(5383538.5000), Entropy=tensor(-7.5447), Distance=tensor(5383436.), Validation Loss=tensor(4.3112e+10), Validation Rec. Loss=tensor(3.9337)]

           -- EPOCH 9 --           
Total Loss:                    5425349.5000                  
Rec_arm_1:                     4.0713                        
Rec_arm_2:                     4.2377                        
Joint Loss:                    5383538.5000                  
Entropy:                       -7.5447                       
Distance:                      5383436.0000                  
Validation Loss:               43112144896.0000              
Validation Rec. Loss:          3.9337                        





In [11]:
# Start training ...
# ====> Epoch:0, Total Loss: 137667456.0000, Rec_arm_1: 8.7999, Rec_arm_2: 8.7966, Joint Loss: 137578912.0000, Entropy: -7.1332, Distance: 0.4805, Elapsed Time:4.13
# ====> Validation Total Loss: 28064712704.0000, Rec. Loss: 8.7721
# ====> Epoch:1, Total Loss: 69825152.0000, Rec_arm_1: 8.7497, Rec_arm_2: 8.7451, Joint Loss: 69737120.0000, Entropy: -7.3747, Distance: 0.4512, Elapsed Time:3.64
# ====> Validation Total Loss: 25941938176.0000, Rec. Loss: 8.6993
# ====> Epoch:2, Total Loss: 42043478.0000, Rec_arm_1: 8.6571, Rec_arm_2: 8.6533, Joint Loss: 41956372.0000, Entropy: -7.3914, Distance: 0.4467, Elapsed Time:3.58
# ====> Validation Total Loss: 22861494272.0000, Rec. Loss: 8.5640
# ====> Epoch:3, Total Loss: 26674182.0000, Rec_arm_1: 8.4823, Rec_arm_2: 8.4839, Joint Loss: 26588808.0000, Entropy: -7.5388, Distance: 0.4278, Elapsed Time:3.60
# ====> Validation Total Loss: 20230934528.0000, Rec. Loss: 8.3146
# ====> Epoch:4, Total Loss: 17169558.2500, Rec_arm_1: 8.1604, Rec_arm_2: 8.1765, Joint Loss: 17087350.0000, Entropy: -7.6280, Distance: 0.4243, Elapsed Time:3.65
# ====> Validation Total Loss: 20135974912.0000, Rec. Loss: 7.8776
# ====> Epoch:5, Total Loss: 11332364.7500, Rec_arm_1: 7.5881, Rec_arm_2: 7.6376, Joint Loss: 11255748.0000, Entropy: -7.7191, Distance: 0.4143, Elapsed Time:3.53
# ====> Validation Total Loss: 23999512576.0000, Rec. Loss: 7.1436
# ====> Epoch:6, Total Loss: 8132679.8750, Rec_arm_1: 6.6561, Rec_arm_2: 6.7642, Joint Loss: 8065148.5000, Entropy: -7.7605, Distance: 0.4065, Elapsed Time:3.60
# ====> Validation Total Loss: 31686617088.0000, Rec. Loss: 6.0345
# ====> Epoch:7, Total Loss: 6335305.1250, Rec_arm_1: 5.3905, Rec_arm_2: 5.5418, Joint Loss: 6280293.0000, Entropy: -7.8043, Distance: 0.4005, Elapsed Time:3.63
# ====> Validation Total Loss: 36792733696.0000, Rec. Loss: 4.7336
# ====> Epoch:8, Total Loss: 5184599.3750, Rec_arm_1: 4.3659, Rec_arm_2: 4.3467, Joint Loss: 5140757.0000, Entropy: -7.8469, Distance: 0.3987, Elapsed Time:3.51
# ====> Validation Total Loss: 38023708672.0000, Rec. Loss: 4.0776
# ====> Epoch:9, Total Loss: 4790232.1250, Rec_arm_1: 4.2830, Rec_arm_2: 4.0376, Joint Loss: 4748362.0000, Entropy: -7.8604, Distance: 0.3998, Elapsed Time:3.57
# ====> Validation Total Loss: 36604887040.0000, Rec. Loss: 3.9312
# Training with pruning...
# Purned categories: [1]
# ====> Epoch:0, Total Loss: 4001048.3125, Rec_arm_1: 3.8611, Rec_arm_2: 3.8509, Joint Loss: 3962241.0000, Entropy: -7.8664, Distance: 0.3989, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.5026
# ====> Epoch:1, Total Loss: 3639977.4375, Rec_arm_1: 3.5320, Rec_arm_2: 3.3895, Joint Loss: 3605148.2500, Entropy: -7.8894, Distance: 0.3974, Elapsed Time:3.83
# ====> Validation Total Loss: nan, Rec. Loss: 3.4378
# ====> Epoch:2, Total Loss: 3534881.8125, Rec_arm_1: 3.4955, Rec_arm_2: 3.3030, Joint Loss: 3500671.7500, Entropy: -7.9030, Distance: 0.3960, Elapsed Time:4.10
# ====> Validation Total Loss: nan, Rec. Loss: 3.4195
# ====> Epoch:3, Total Loss: 3155906.8125, Rec_arm_1: 3.4152, Rec_arm_2: 3.2708, Joint Loss: 3122262.7500, Entropy: -7.9187, Distance: 0.3939, Elapsed Time:4.40
# ====> Validation Total Loss: nan, Rec. Loss: 3.3164
# ====> Epoch:4, Total Loss: 2953709.7500, Rec_arm_1: 3.3518, Rec_arm_2: 3.1916, Joint Loss: 2920782.5000, Entropy: -7.9333, Distance: 0.3947, Elapsed Time:3.98
# ====> Validation Total Loss: nan, Rec. Loss: 3.2596
# Training with pruning...
# Purned categories: [1 2]
# ====> Epoch:0, Total Loss: 2872254.8750, Rec_arm_1: 3.3458, Rec_arm_2: 3.1647, Joint Loss: 2839493.0000, Entropy: -7.9148, Distance: 0.3927, Elapsed Time:3.81
# ====> Validation Total Loss: nan, Rec. Loss: 3.2327
# ====> Epoch:1, Total Loss: 2638138.1875, Rec_arm_1: 3.3082, Rec_arm_2: 3.1392, Joint Loss: 2605694.2500, Entropy: -7.9227, Distance: 0.3921, Elapsed Time:3.87
# ====> Validation Total Loss: nan, Rec. Loss: 3.1979
# ====> Epoch:2, Total Loss: 2455072.6250, Rec_arm_1: 3.2809, Rec_arm_2: 3.0947, Joint Loss: 2422990.5000, Entropy: -7.9277, Distance: 0.3917, Elapsed Time:3.78
# ====> Validation Total Loss: nan, Rec. Loss: 3.1861
# ====> Epoch:3, Total Loss: 2434174.7500, Rec_arm_1: 3.2733, Rec_arm_2: 3.0776, Joint Loss: 2402216.5000, Entropy: -7.9398, Distance: 0.3898, Elapsed Time:3.97
# ====> Validation Total Loss: nan, Rec. Loss: 3.1811
# ====> Epoch:4, Total Loss: 2520542.7500, Rec_arm_1: 3.2612, Rec_arm_2: 3.0682, Joint Loss: 2488693.0000, Entropy: -7.9395, Distance: 0.3899, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.1711
# No more pruning!
# Training is done!

In [12]:
def join(s, *args, **kwargs): 
    return s.join(*args, **kwargs)

join('', ['a', 'b', 'c'])

'abc'

When working from the command line, you can instead train the model by running a Python script such as ```tutorial/train_unimodal.py```:

```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2
```
or
```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2 --device 'cuda'
```
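
For orientation, the flags used above could be declared with `argparse` roughly as follows. This is a sketch under assumptions, not the contents of `train_unimodal.py`: only the four flag names come from the commands above, while the types, the defaults (taken from the values used in this notebook) and the `build_parser` helper are illustrative.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags shown in the commands above."""
    parser = argparse.ArgumentParser(description='Train a coupled mixture VAE.')
    parser.add_argument('--n_epoch', type=int, default=10,
                        help='number of epochs for training')
    parser.add_argument('--n_epoch_p', type=int, default=5,
                        help='number of epochs for pruning')
    parser.add_argument('--max_prun_it', type=int, default=2,
                        help='maximum number of pruning iterations')
    parser.add_argument('--device', type=str, default='cuda',
                        help="device to train on, e.g. 'cuda' or 'cpu'")
    return parser


# Parse an explicit argument list instead of sys.argv so the example runs anywhere.
args = build_parser().parse_args(
    ['--n_epoch', '10', '--n_epoch_p', '5', '--max_prun_it', '2', '--device', 'cuda']
)
print(args.n_epoch, args.n_epoch_p, args.max_prun_it, args.device)
```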