### Training MMIDAS - a coupled mixture VAE model
This notebook guides you through the process of training a mixture variational autoencoder.

In [2]:
import mmidas
from mmidas.nn_model import mixVAEConfig, loss_fn
from mmidas.cpl_mixvae import cpl_mixVAE
from mmidas.utils.tools import get_paths
from mmidas.utils.dataloader import load_data, get_loaders
import importlib
import os
import torch as t

import warnings
# Suppress every warning for the rest of the session: no category or module
# filter is passed, so the 'ignore' action applies to all of them.
# NOTE(review): this also hides deprecation and numerical warnings emitted
# during training — consider narrowing the filter to a specific category.
warnings.filterwarnings('ignore')

In [33]:
# Show the index of the CUDA device that is currently selected.
t.cuda.current_device()

0

In [2]:
# Sanity-check the GPU setup before training.
# Only the value of a cell's last expression is displayed, so the result of
# is_available() is evaluated but not shown; the cell output is device_count().
t.cuda.is_available()
t.cuda.device_count()

4

Specify the training parameters.

In [3]:
# Hyperparameters for the model and for the training / pruning schedule.
# Each name is assigned exactly once so there is a single place to tune it.
n_categories = 120  # upper bound on the number of categories (clusters)
state_dim = 2  # dimensionality of the continuous (state) variable
n_arm = 2  # number of arms
latent_dim = 10  # latent dimensionality of the model
batch_size = 5000  # mini-batch size for training
n_epoch = 10  # number of epochs for training
n_epoch_p = 5  # number of epochs for pruning
min_con = 0.9  # minimum consensus among arms
max_prun_it = 2  # maximum number of pruning iterations
lr = 1e-3  # learning rate for training

Load the prepared data (as described in ```1_data_prep.ipynb```) and create training and validation sets.

In [4]:
# Resolve the project paths from the TOML configuration file.
# `sub_file` names the configuration section that holds the dataset-specific
# file names; it is used again below to index into `config`.
toml_file = 'pyproject.toml'
sub_file = 'smartseq_files'
config = get_paths(toml_file=toml_file, sub_file=sub_file)
# The path entries are combined with `/` — presumably pathlib.Path objects; TODO confirm.
data_path = config['paths']['main_dir'] / config['paths']['data_path']
# Full path of the AnnData file that the next cell loads.
data_file = data_path / config[sub_file]['anndata_file']

/allen/programs/celltypes/workgroups/mousecelltypes/Hilal/MMIDAS/pyproject.toml
Getting files directories belong to smartseq_files...


In [5]:
# Read the prepared dataset from disk.
data = load_data(datafile=data_file)
# Build the data loaders from the 'log1p' entry of the dataset; the third
# value returned by get_loaders is not needed here and is discarded.
trainloader, testloader, _ = get_loaders(
    dataset=data['log1p'],
    batch_size=batch_size,
)

data is loaded!
 --------- Data Summary --------- 
num cell types: 115, num cells: 22365, num genes:5032


Create a dedicated folder in which to store the training files.

In [6]:
# Identify this run; the folder name encodes the run index and the main
# hyperparameters so that different configurations do not overwrite each other.
n_run = 1
augmentation = False
folder_name = (
    f'run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}_lr_{lr}_n_arm_{n_arm}_nbatch_{batch_size}'
    f'_nepoch_{n_epoch}_nepochP_{n_epoch_p}'
)
# <main_dir>/<saving_path>/<folder_name>
saving_folder = config['paths']['main_dir'] / config['paths']['saving_path'] / folder_name
# Create the run folder and its 'model' sub-folder (no error if they already exist).
os.makedirs(saving_folder, exist_ok=True)
os.makedirs(saving_folder / 'model', exist_ok=True)
# Later cells pass the folder on as a plain string.
saving_folder = str(saving_folder)

Construct a cpl-mixVAE object and launch its training on the prepared data.

In [31]:
# Reload the project modules so that edits made to them on disk are picked up
# without restarting the kernel, then re-import the class from the reloaded module.
# NOTE(review): if cpl_mixvae imports names from nn_model, nn_model should be
# reloaded first, otherwise cpl_mixvae keeps the pre-reload objects — confirm the order.
# NOTE(review): names imported from mmidas.nn_model in the first cell
# (mixVAEConfig, loss_fn) are not re-bound by these reloads.
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE

# Create the trainer on the GPU with a fixed seed; `saving_folder` is the run
# directory created in the previous cell.
cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
# Configure the model from the hyperparameters defined earlier. input_dim is the
# size of the second axis of data['log1p'] (presumably the number of genes; TODO confirm).
cplMixVAE.init(categories=n_categories,
                     state_dim=state_dim,
                     input_dim=data['log1p'].shape[1],
                     lowD_dim=latent_dim,
                     lr=lr,
                     arms=n_arm)

# Launch training: n_epoch epochs, followed by up to max_prun_it pruning rounds
# of n_epoch_p epochs each (as reflected in the recorded output below).
# The returned value is kept in `model_file`.
model_file = cplMixVAE.train(train_loader=trainloader,
                             test_loader=testloader,
                             n_epoch=n_epoch,
                             n_epoch_p=n_epoch_p,
                             min_con=min_con,
                             max_prun_it=max_prun_it)

device: NVIDIA A100-PCIE-40GB
Start training ...


 10%|█         | 1/10 [00:00<00:06,  1.39it/s, Total Loss=1.7e+8, Rec_arm_1=8.8, Rec_arm_2=8.8, Joint Loss=1.7e+8, Entropy=-6.88, Distance=0.508, Elapsed Time=0.718, Validation Loss=2.67e+10, Validation Rec. Loss=8.77]

           -- EPOCH 0 --           
Total Loss:                    169748142.0000                
Rec_arm_1:                     8.7996                        
Rec_arm_2:                     8.7953                        
Joint Loss:                    169659600.0000                
Entropy:                       -6.8828                       
Distance:                      0.5078                        
Elapsed Time:                  0.7175                        
Validation Loss:               26710128640.0000              
Validation Rec. Loss:          8.7692                        


 20%|██        | 2/10 [00:01<00:04,  1.81it/s, Total Loss=6.55e+7, Rec_arm_1=8.75, Rec_arm_2=8.74, Joint Loss=6.54e+7, Entropy=-6.96, Distance=0.486, Elapsed Time=0.435, Validation Loss=2.86e+10, Validation Rec. Loss=8.69]

           -- EPOCH 1 --           
Total Loss:                    65482671.0000                 
Rec_arm_1:                     8.7534                        
Rec_arm_2:                     8.7366                        
Joint Loss:                    65394656.0000                 
Entropy:                       -6.9614                       
Distance:                      0.4856                        
Elapsed Time:                  0.4350                        
Validation Loss:               28589053952.0000              
Validation Rec. Loss:          8.6938                        


 30%|███       | 3/10 [00:01<00:04,  1.71it/s, Total Loss=3.31e+7, Rec_arm_1=8.67, Rec_arm_2=8.63, Joint Loss=3.3e+7, Entropy=-7.09, Distance=0.466, Elapsed Time=0.618, Validation Loss=3.72e+10, Validation Rec. Loss=8.55] 

           -- EPOCH 2 --           
Total Loss:                    33122370.5000                 
Rec_arm_1:                     8.6686                        
Rec_arm_2:                     8.6281                        
Joint Loss:                    33035336.0000                 
Entropy:                       -7.0932                       
Distance:                      0.4662                        
Elapsed Time:                  0.6178                        
Validation Loss:               37221650432.0000              
Validation Rec. Loss:          8.5526                        


 40%|████      | 4/10 [00:02<00:03,  1.92it/s, Total Loss=2.12e+7, Rec_arm_1=8.51, Rec_arm_2=8.42, Joint Loss=2.11e+7, Entropy=-7.24, Distance=0.453, Elapsed Time=0.426, Validation Loss=5.11e+10, Validation Rec. Loss=8.29]

           -- EPOCH 3 --           
Total Loss:                    21219178.0000                 
Rec_arm_1:                     8.5115                        
Rec_arm_2:                     8.4205                        
Joint Loss:                    21133976.0000                 
Entropy:                       -7.2435                       
Distance:                      0.4528                        
Elapsed Time:                  0.4264                        
Validation Loss:               51132751872.0000              
Validation Rec. Loss:          8.2870                        


 50%|█████     | 5/10 [00:02<00:02,  2.05it/s, Total Loss=1.34e+7, Rec_arm_1=8.22, Rec_arm_2=8.03, Joint Loss=1.33e+7, Entropy=-7.37, Distance=0.434, Elapsed Time=0.426, Validation Loss=5.58e+10, Validation Rec. Loss=7.8] 

           -- EPOCH 4 --           
Total Loss:                    13367817.0000                 
Rec_arm_1:                     8.2216                        
Rec_arm_2:                     8.0294                        
Joint Loss:                    13286041.0000                 
Entropy:                       -7.3693                       
Distance:                      0.4344                        
Elapsed Time:                  0.4256                        
Validation Loss:               55789199360.0000              
Validation Rec. Loss:          7.7991                        


 60%|██████    | 6/10 [00:03<00:02,  1.86it/s, Total Loss=1e+7, Rec_arm_1=7.71, Rec_arm_2=7.34, Joint Loss=9.97e+6, Entropy=-7.38, Distance=0.436, Elapsed Time=0.634, Validation Loss=5.52e+10, Validation Rec. Loss=6.98]  

           -- EPOCH 5 --           
Total Loss:                    10043858.2500                 
Rec_arm_1:                     7.7055                        
Rec_arm_2:                     7.3390                        
Joint Loss:                    9968154.0000                  
Entropy:                       -7.3769                       
Distance:                      0.4358                        
Elapsed Time:                  0.6342                        
Validation Loss:               55231025152.0000              
Validation Rec. Loss:          6.9806                        


 70%|███████   | 7/10 [00:03<00:01,  1.99it/s, Total Loss=8.46e+6, Rec_arm_1=6.85, Rec_arm_2=6.26, Joint Loss=8.39e+6, Entropy=-7.43, Distance=0.433, Elapsed Time=0.425, Validation Loss=5.81e+10, Validation Rec. Loss=5.79]

           -- EPOCH 6 --           
Total Loss:                    8460194.3750                  
Rec_arm_1:                     6.8542                        
Rec_arm_2:                     6.2647                        
Joint Loss:                    8394180.0000                  
Entropy:                       -7.4284                       
Distance:                      0.4328                        
Elapsed Time:                  0.4248                        
Validation Loss:               58122186752.0000              
Validation Rec. Loss:          5.7926                        


 80%|████████  | 8/10 [00:04<00:01,  1.86it/s, Total Loss=6.83e+6, Rec_arm_1=5.64, Rec_arm_2=4.98, Joint Loss=6.78e+6, Entropy=-7.48, Distance=0.424, Elapsed Time=0.616, Validation Loss=6.04e+10, Validation Rec. Loss=4.58]

           -- EPOCH 7 --           
Total Loss:                    6829084.6250                  
Rec_arm_1:                     5.6395                        
Rec_arm_2:                     4.9762                        
Joint Loss:                    6775666.0000                  
Entropy:                       -7.4775                       
Distance:                      0.4240                        
Elapsed Time:                  0.6159                        
Validation Loss:               60393889792.0000              
Validation Rec. Loss:          4.5808                        


 90%|█████████ | 9/10 [00:04<00:00,  1.98it/s, Total Loss=6.25e+6, Rec_arm_1=4.42, Rec_arm_2=4.3, Joint Loss=6.21e+6, Entropy=-7.53, Distance=0.42, Elapsed Time=0.433, Validation Loss=5.73e+10, Validation Rec. Loss=4.2]   

           -- EPOCH 8 --           
Total Loss:                    6251051.0000                  
Rec_arm_1:                     4.4184                        
Rec_arm_2:                     4.3025                        
Joint Loss:                    6207167.0000                  
Entropy:                       -7.5261                       
Distance:                      0.4201                        
Elapsed Time:                  0.4334                        
Validation Loss:               57300779008.0000              
Validation Rec. Loss:          4.1969                        


100%|██████████| 10/10 [00:05<00:00,  1.92it/s, Total Loss=5.5e+6, Rec_arm_1=4.09, Rec_arm_2=4.25, Joint Loss=5.46e+6, Entropy=-7.57, Distance=0.414, Elapsed Time=0.465, Validation Loss=5.08e+10, Validation Rec. Loss=4.01]

           -- EPOCH 9 --           
Total Loss:                    5497651.3750                  
Rec_arm_1:                     4.0858                        
Rec_arm_2:                     4.2468                        
Joint Loss:                    5455721.5000                  
Entropy:                       -7.5713                       
Distance:                      0.4139                        
Elapsed Time:                  0.4655                        
Validation Loss:               50776092672.0000              
Validation Rec. Loss:          4.0118                        





Continue training with pruning ...
Pruned categories: [1]


 20%|██        | 1/5 [00:00<00:01,  2.04it/s, Total Loss=4.64e+6, Rec_arm_1=3.9, Rec_arm_2=3.74, Joint Loss=4.6e+6, Entropy=-7.57, Distance=0.415, Elapsed Time=0.489, Validation Loss=4.18e+10, Validation Rec. Loss=3.56]

           -- EPOCH 0 --           
Total Loss:                    4641829.3750                  
Rec_arm_1:                     3.8957                        
Rec_arm_2:                     3.7443                        
Joint Loss:                    4603384.5000                  
Entropy:                       -7.5727                       
Distance:                      0.4155                        
Elapsed Time:                  0.4889                        
Validation Loss:               41787138048.0000              
Validation Rec. Loss:          3.5618                        


 40%|████      | 2/5 [00:01<00:02,  1.35it/s, Total Loss=4.42e+6, Rec_arm_1=3.43, Rec_arm_2=3.57, Joint Loss=4.39e+6, Entropy=-7.61, Distance=0.41, Elapsed Time=0.913, Validation Loss=3.78e+10, Validation Rec. Loss=3.44]

           -- EPOCH 1 --           
Total Loss:                    4420528.2500                  
Rec_arm_1:                     3.4265                        
Rec_arm_2:                     3.5658                        
Joint Loss:                    4385342.5000                  
Entropy:                       -7.6094                       
Distance:                      0.4100                        
Elapsed Time:                  0.9133                        
Validation Loss:               37824999424.0000              
Validation Rec. Loss:          3.4384                        


 60%|██████    | 3/5 [00:01<00:01,  1.50it/s, Total Loss=3.95e+6, Rec_arm_1=3.34, Rec_arm_2=3.55, Joint Loss=3.92e+6, Entropy=-7.62, Distance=0.414, Elapsed Time=0.581, Validation Loss=3.56e+10, Validation Rec. Loss=3.4]

           -- EPOCH 2 --           
Total Loss:                    3950579.4375                  
Rec_arm_1:                     3.3368                        
Rec_arm_2:                     3.5465                        
Joint Loss:                    3915942.5000                  
Entropy:                       -7.6212                       
Distance:                      0.4141                        
Elapsed Time:                  0.5814                        
Validation Loss:               35610374144.0000              
Validation Rec. Loss:          3.3982                        


 80%|████████  | 4/5 [00:02<00:00,  1.32it/s, Total Loss=3.85e+6, Rec_arm_1=3.3, Rec_arm_2=3.47, Joint Loss=3.81e+6, Entropy=-7.63, Distance=0.416, Elapsed Time=0.892, Validation Loss=3.33e+10, Validation Rec. Loss=3.34]

           -- EPOCH 3 --           
Total Loss:                    3845911.7500                  
Rec_arm_1:                     3.2980                        
Rec_arm_2:                     3.4702                        
Joint Loss:                    3811853.5000                  
Entropy:                       -7.6339                       
Distance:                      0.4159                        
Elapsed Time:                  0.8921                        
Validation Loss:               33270839296.0000              
Validation Rec. Loss:          3.3357                        


100%|██████████| 5/5 [00:03<00:00,  1.46it/s, Total Loss=3.48e+6, Rec_arm_1=3.22, Rec_arm_2=3.44, Joint Loss=3.44e+6, Entropy=-7.64, Distance=0.406, Elapsed Time=0.542, Validation Loss=3.3e+10, Validation Rec. Loss=3.33]

           -- EPOCH 4 --           
Total Loss:                    3475475.8125                  
Rec_arm_1:                     3.2165                        
Rec_arm_2:                     3.4411                        
Joint Loss:                    3441974.7500                  
Entropy:                       -7.6382                       
Distance:                      0.4062                        
Elapsed Time:                  0.5416                        
Validation Loss:               33047609344.0000              
Validation Rec. Loss:          3.3272                        





Continue training with pruning ...
Pruned categories: [1 2]


 20%|██        | 1/5 [00:00<00:01,  2.18it/s, Total Loss=3.28e+6, Rec_arm_1=3.19, Rec_arm_2=3.44, Joint Loss=3.25e+6, Entropy=-7.63, Distance=0.411, Elapsed Time=0.457, Validation Loss=3.17e+10, Validation Rec. Loss=3.31]

           -- EPOCH 0 --           
Total Loss:                    3284169.5000                  
Rec_arm_1:                     3.1944                        
Rec_arm_2:                     3.4369                        
Joint Loss:                    3250800.0000                  
Entropy:                       -7.6314                       
Distance:                      0.4106                        
Elapsed Time:                  0.4570                        
Validation Loss:               31715076096.0000              
Validation Rec. Loss:          3.3118                        


 40%|████      | 2/5 [00:00<00:01,  2.19it/s, Total Loss=3.04e+6, Rec_arm_1=3.18, Rec_arm_2=3.4, Joint Loss=3.01e+6, Entropy=-7.65, Distance=0.41, Elapsed Time=0.454, Validation Loss=2.98e+10, Validation Rec. Loss=3.27]  

           -- EPOCH 1 --           
Total Loss:                    3043093.3750                  
Rec_arm_1:                     3.1751                        
Rec_arm_2:                     3.3986                        
Joint Loss:                    3010014.5000                  
Entropy:                       -7.6484                       
Distance:                      0.4098                        
Elapsed Time:                  0.4542                        
Validation Loss:               29846585344.0000              
Validation Rec. Loss:          3.2655                        


 60%|██████    | 3/5 [00:01<00:01,  1.80it/s, Total Loss=2.95e+6, Rec_arm_1=3.13, Rec_arm_2=3.38, Joint Loss=2.92e+6, Entropy=-7.66, Distance=0.405, Elapsed Time=0.673, Validation Loss=2.68e+10, Validation Rec. Loss=3.25]

           -- EPOCH 2 --           
Total Loss:                    2954351.1875                  
Rec_arm_1:                     3.1344                        
Rec_arm_2:                     3.3814                        
Joint Loss:                    2921563.5000                  
Entropy:                       -7.6620                       
Distance:                      0.4047                        
Elapsed Time:                  0.6734                        
Validation Loss:               26805893120.0000              
Validation Rec. Loss:          3.2454                        


 80%|████████  | 4/5 [00:02<00:00,  1.93it/s, Total Loss=2.77e+6, Rec_arm_1=3.12, Rec_arm_2=3.37, Joint Loss=2.73e+6, Entropy=-7.67, Distance=0.404, Elapsed Time=0.454, Validation Loss=2.47e+10, Validation Rec. Loss=3.24]

           -- EPOCH 3 --           
Total Loss:                    2767516.3125                  
Rec_arm_1:                     3.1211                        
Rec_arm_2:                     3.3730                        
Joint Loss:                    2734837.5000                  
Entropy:                       -7.6743                       
Distance:                      0.4038                        
Elapsed Time:                  0.4542                        
Validation Loss:               24659744768.0000              
Validation Rec. Loss:          3.2373                        


100%|██████████| 5/5 [00:02<00:00,  1.83it/s, Total Loss=2.72e+6, Rec_arm_1=3.11, Rec_arm_2=3.36, Joint Loss=2.69e+6, Entropy=-7.69, Distance=0.402, Elapsed Time=0.677, Validation Loss=2.32e+10, Validation Rec. Loss=3.23]

           -- EPOCH 4 --           
Total Loss:                    2724886.8750                  
Rec_arm_1:                     3.1126                        
Rec_arm_2:                     3.3619                        
Joint Loss:                    2692306.7500                  
Entropy:                       -7.6895                       
Distance:                      0.4022                        
Elapsed Time:                  0.6765                        
Validation Loss:               23217799168.0000              
Validation Rec. Loss:          3.2327                        





No more pruning!
Training is done!


In [14]:
# Scratch check: accumulate an integer tensor into a float tensor in place.
x = t.zeros(1)
# In-place addition; the result keeps the dtype of `x`.
x.add_(t.tensor([5]))
x

tensor([5.])

In [32]:
# Reload the project modules to pick up on-disk edits, then re-import the class.
# NOTE(review): same reload-order caveat as in the training cell — if cpl_mixvae
# imports from nn_model, nn_model should be reloaded first; confirm.
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE

# Build and configure a fresh trainer with the same settings and seed as the
# train() run above.
cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
cplMixVAE.init(categories=n_categories,
                     state_dim=state_dim,
                     input_dim=data['log1p'].shape[1],
                     lowD_dim=latent_dim,
                     lr=lr,
                     arms=n_arm,)
# Call the underscore-prefixed `_fsdp` method directly, handing it the trainer's
# own device, model and optimizer. The result is kept in `fsdp_dct`.
# NOTE(review): `rank` receives the same value as `device`; confirm that this is
# what `_fsdp` expects.
# NOTE(review): in the recorded output of this cell "Distance" is nearly equal to
# "Joint Loss" (~1.7e+08 at epoch 0), whereas train() reports ~0.5 for the same
# seed — looks like a logging mix-up inside `_fsdp`; verify.
fsdp_dct = cplMixVAE._fsdp(train_loader=trainloader,
                           val_loader=testloader,
                           epochs=n_epoch,
                           n_epoch_p=n_epoch_p,
                           min_con=min_con,
                           max_prun_it=max_prun_it,
                           device=cplMixVAE.device,
                           model=cplMixVAE.model,
                           opt=cplMixVAE.optimizer,
                           rank=cplMixVAE.device)

device: NVIDIA A100-PCIE-40GB


 10%|█         | 1/10 [00:00<00:07,  1.25it/s, Total Loss=tensor(1.6975e+08, device='cuda:0'), Rec_arm_1=tensor(8.7996, device='cuda:0'), Rec_arm_2=tensor(8.7953, device='cuda:0'), Joint Loss=tensor(1.6966e+08, device='cuda:0'), Entropy=tensor(-6.8828, device='cuda:0'), Distance=tensor(1.6966e+08, device='cuda:0')]

           -- EPOCH 0 --           
Total Loss:                    169748144.0000                
Rec_arm_1:                     8.7996                        
Rec_arm_2:                     8.7953                        
Joint Loss:                    169659600.0000                
Entropy:                       -6.8828                       
Distance:                      169659488.0000                


 20%|██        | 2/10 [00:01<00:04,  1.65it/s, Total Loss=tensor(65578188., device='cuda:0'), Rec_arm_1=tensor(8.7534, device='cuda:0'), Rec_arm_2=tensor(8.7366, device='cuda:0'), Joint Loss=tensor(65490176., device='cuda:0'), Entropy=tensor(-6.9559, device='cuda:0'), Distance=tensor(65490076., device='cuda:0')]   

           -- EPOCH 1 --           
Total Loss:                    65578188.0000                 
Rec_arm_1:                     8.7534                        
Rec_arm_2:                     8.7366                        
Joint Loss:                    65490176.0000                 
Entropy:                       -6.9559                       
Distance:                      65490076.0000                 


 30%|███       | 3/10 [00:01<00:03,  1.83it/s, Total Loss=tensor(35376064., device='cuda:0'), Rec_arm_1=tensor(8.6682, device='cuda:0'), Rec_arm_2=tensor(8.6280, device='cuda:0'), Joint Loss=tensor(35289028., device='cuda:0'), Entropy=tensor(-7.1085, device='cuda:0'), Distance=tensor(35288928., device='cuda:0')]

           -- EPOCH 2 --           
Total Loss:                    35376064.0000                 
Rec_arm_1:                     8.6682                        
Rec_arm_2:                     8.6280                        
Joint Loss:                    35289028.0000                 
Entropy:                       -7.1085                       
Distance:                      35288928.0000                 


 40%|████      | 4/10 [00:02<00:03,  1.68it/s, Total Loss=tensor(20463328., device='cuda:0'), Rec_arm_1=tensor(8.5101, device='cuda:0'), Rec_arm_2=tensor(8.4201, device='cuda:0'), Joint Loss=tensor(20378136., device='cuda:0'), Entropy=tensor(-7.2124, device='cuda:0'), Distance=tensor(20378034., device='cuda:0')]

           -- EPOCH 3 --           
Total Loss:                    20463328.0000                 
Rec_arm_1:                     8.5101                        
Rec_arm_2:                     8.4201                        
Joint Loss:                    20378136.0000                 
Entropy:                       -7.2124                       
Distance:                      20378034.0000                 


 50%|█████     | 5/10 [00:02<00:02,  1.81it/s, Total Loss=tensor(14095076., device='cuda:0'), Rec_arm_1=tensor(8.2182, device='cuda:0'), Rec_arm_2=tensor(8.0282, device='cuda:0'), Joint Loss=tensor(14013324., device='cuda:0'), Entropy=tensor(-7.2897, device='cuda:0'), Distance=tensor(14013221., device='cuda:0')]

           -- EPOCH 4 --           
Total Loss:                    14095076.0000                 
Rec_arm_1:                     8.2182                        
Rec_arm_2:                     8.0282                        
Joint Loss:                    14013324.0000                 
Entropy:                       -7.2897                       
Distance:                      14013221.0000                 


 60%|██████    | 6/10 [00:03<00:02,  1.90it/s, Total Loss=tensor(10539843., device='cuda:0'), Rec_arm_1=tensor(7.6979, device='cuda:0'), Rec_arm_2=tensor(7.3367, device='cuda:0'), Joint Loss=tensor(10464188., device='cuda:0'), Entropy=tensor(-7.3604, device='cuda:0'), Distance=tensor(10464085., device='cuda:0')]

           -- EPOCH 5 --           
Total Loss:                    10539843.0000                 
Rec_arm_1:                     7.6979                        
Rec_arm_2:                     7.3367                        
Joint Loss:                    10464188.0000                 
Entropy:                       -7.3604                       
Distance:                      10464085.0000                 


 70%|███████   | 7/10 [00:04<00:01,  1.73it/s, Total Loss=tensor(8475192., device='cuda:0'), Rec_arm_1=tensor(6.8419, device='cuda:0'), Rec_arm_2=tensor(6.2607, device='cuda:0'), Joint Loss=tensor(8409260., device='cuda:0'), Entropy=tensor(-7.4328, device='cuda:0'), Distance=tensor(8409157., device='cuda:0')]   

           -- EPOCH 6 --           
Total Loss:                    8475192.0000                  
Rec_arm_1:                     6.8419                        
Rec_arm_2:                     6.2607                        
Joint Loss:                    8409260.0000                  
Entropy:                       -7.4328                       
Distance:                      8409157.0000                  


 80%|████████  | 8/10 [00:04<00:01,  1.84it/s, Total Loss=tensor(7336511., device='cuda:0'), Rec_arm_1=tensor(5.6224, device='cuda:0'), Rec_arm_2=tensor(4.9703, device='cuda:0'), Joint Loss=tensor(7283208.5000, device='cuda:0'), Entropy=tensor(-7.4925, device='cuda:0'), Distance=tensor(7283106., device='cuda:0')]

           -- EPOCH 7 --           
Total Loss:                    7336511.0000                  
Rec_arm_1:                     5.6224                        
Rec_arm_2:                     4.9703                        
Joint Loss:                    7283208.5000                  
Entropy:                       -7.4925                       
Distance:                      7283106.0000                  


 90%|█████████ | 9/10 [00:05<00:00,  1.72it/s, Total Loss=tensor(5922421., device='cuda:0'), Rec_arm_1=tensor(4.4031, device='cuda:0'), Rec_arm_2=tensor(4.2988, device='cuda:0'), Joint Loss=tensor(5878633., device='cuda:0'), Entropy=tensor(-7.5280, device='cuda:0'), Distance=tensor(5878530., device='cuda:0')]    

           -- EPOCH 8 --           
Total Loss:                    5922421.0000                  
Rec_arm_1:                     4.4031                        
Rec_arm_2:                     4.2988                        
Joint Loss:                    5878633.0000                  
Entropy:                       -7.5280                       
Distance:                      5878530.0000                  


100%|██████████| 10/10 [00:05<00:00,  1.76it/s, Total Loss=tensor(5494302.5000, device='cuda:0'), Rec_arm_1=tensor(4.0783, device='cuda:0'), Rec_arm_2=tensor(4.2416, device='cuda:0'), Joint Loss=tensor(5452436.5000, device='cuda:0'), Entropy=tensor(-7.5488, device='cuda:0'), Distance=tensor(5452334., device='cuda:0')]

           -- EPOCH 9 --           
Total Loss:                    5494302.5000                  
Rec_arm_1:                     4.0783                        
Rec_arm_2:                     4.2416                        
Joint Loss:                    5452436.5000                  
Entropy:                       -7.5488                       
Distance:                      5452334.0000                  





In [11]:
# Start training ...
# ====> Epoch:0, Total Loss: 137667456.0000, Rec_arm_1: 8.7999, Rec_arm_2: 8.7966, Joint Loss: 137578912.0000, Entropy: -7.1332, Distance: 0.4805, Elapsed Time:4.13
# ====> Validation Total Loss: 28064712704.0000, Rec. Loss: 8.7721
# ====> Epoch:1, Total Loss: 69825152.0000, Rec_arm_1: 8.7497, Rec_arm_2: 8.7451, Joint Loss: 69737120.0000, Entropy: -7.3747, Distance: 0.4512, Elapsed Time:3.64
# ====> Validation Total Loss: 25941938176.0000, Rec. Loss: 8.6993
# ====> Epoch:2, Total Loss: 42043478.0000, Rec_arm_1: 8.6571, Rec_arm_2: 8.6533, Joint Loss: 41956372.0000, Entropy: -7.3914, Distance: 0.4467, Elapsed Time:3.58
# ====> Validation Total Loss: 22861494272.0000, Rec. Loss: 8.5640
# ====> Epoch:3, Total Loss: 26674182.0000, Rec_arm_1: 8.4823, Rec_arm_2: 8.4839, Joint Loss: 26588808.0000, Entropy: -7.5388, Distance: 0.4278, Elapsed Time:3.60
# ====> Validation Total Loss: 20230934528.0000, Rec. Loss: 8.3146
# ====> Epoch:4, Total Loss: 17169558.2500, Rec_arm_1: 8.1604, Rec_arm_2: 8.1765, Joint Loss: 17087350.0000, Entropy: -7.6280, Distance: 0.4243, Elapsed Time:3.65
# ====> Validation Total Loss: 20135974912.0000, Rec. Loss: 7.8776
# ====> Epoch:5, Total Loss: 11332364.7500, Rec_arm_1: 7.5881, Rec_arm_2: 7.6376, Joint Loss: 11255748.0000, Entropy: -7.7191, Distance: 0.4143, Elapsed Time:3.53
# ====> Validation Total Loss: 23999512576.0000, Rec. Loss: 7.1436
# ====> Epoch:6, Total Loss: 8132679.8750, Rec_arm_1: 6.6561, Rec_arm_2: 6.7642, Joint Loss: 8065148.5000, Entropy: -7.7605, Distance: 0.4065, Elapsed Time:3.60
# ====> Validation Total Loss: 31686617088.0000, Rec. Loss: 6.0345
# ====> Epoch:7, Total Loss: 6335305.1250, Rec_arm_1: 5.3905, Rec_arm_2: 5.5418, Joint Loss: 6280293.0000, Entropy: -7.8043, Distance: 0.4005, Elapsed Time:3.63
# ====> Validation Total Loss: 36792733696.0000, Rec. Loss: 4.7336
# ====> Epoch:8, Total Loss: 5184599.3750, Rec_arm_1: 4.3659, Rec_arm_2: 4.3467, Joint Loss: 5140757.0000, Entropy: -7.8469, Distance: 0.3987, Elapsed Time:3.51
# ====> Validation Total Loss: 38023708672.0000, Rec. Loss: 4.0776
# ====> Epoch:9, Total Loss: 4790232.1250, Rec_arm_1: 4.2830, Rec_arm_2: 4.0376, Joint Loss: 4748362.0000, Entropy: -7.8604, Distance: 0.3998, Elapsed Time:3.57
# ====> Validation Total Loss: 36604887040.0000, Rec. Loss: 3.9312
# Training with pruning...
# Purned categories: [1]
# ====> Epoch:0, Total Loss: 4001048.3125, Rec_arm_1: 3.8611, Rec_arm_2: 3.8509, Joint Loss: 3962241.0000, Entropy: -7.8664, Distance: 0.3989, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.5026
# ====> Epoch:1, Total Loss: 3639977.4375, Rec_arm_1: 3.5320, Rec_arm_2: 3.3895, Joint Loss: 3605148.2500, Entropy: -7.8894, Distance: 0.3974, Elapsed Time:3.83
# ====> Validation Total Loss: nan, Rec. Loss: 3.4378
# ====> Epoch:2, Total Loss: 3534881.8125, Rec_arm_1: 3.4955, Rec_arm_2: 3.3030, Joint Loss: 3500671.7500, Entropy: -7.9030, Distance: 0.3960, Elapsed Time:4.10
# ====> Validation Total Loss: nan, Rec. Loss: 3.4195
# ====> Epoch:3, Total Loss: 3155906.8125, Rec_arm_1: 3.4152, Rec_arm_2: 3.2708, Joint Loss: 3122262.7500, Entropy: -7.9187, Distance: 0.3939, Elapsed Time:4.40
# ====> Validation Total Loss: nan, Rec. Loss: 3.3164
# ====> Epoch:4, Total Loss: 2953709.7500, Rec_arm_1: 3.3518, Rec_arm_2: 3.1916, Joint Loss: 2920782.5000, Entropy: -7.9333, Distance: 0.3947, Elapsed Time:3.98
# ====> Validation Total Loss: nan, Rec. Loss: 3.2596
# Training with pruning...
# Purned categories: [1 2]
# ====> Epoch:0, Total Loss: 2872254.8750, Rec_arm_1: 3.3458, Rec_arm_2: 3.1647, Joint Loss: 2839493.0000, Entropy: -7.9148, Distance: 0.3927, Elapsed Time:3.81
# ====> Validation Total Loss: nan, Rec. Loss: 3.2327
# ====> Epoch:1, Total Loss: 2638138.1875, Rec_arm_1: 3.3082, Rec_arm_2: 3.1392, Joint Loss: 2605694.2500, Entropy: -7.9227, Distance: 0.3921, Elapsed Time:3.87
# ====> Validation Total Loss: nan, Rec. Loss: 3.1979
# ====> Epoch:2, Total Loss: 2455072.6250, Rec_arm_1: 3.2809, Rec_arm_2: 3.0947, Joint Loss: 2422990.5000, Entropy: -7.9277, Distance: 0.3917, Elapsed Time:3.78
# ====> Validation Total Loss: nan, Rec. Loss: 3.1861
# ====> Epoch:3, Total Loss: 2434174.7500, Rec_arm_1: 3.2733, Rec_arm_2: 3.0776, Joint Loss: 2402216.5000, Entropy: -7.9398, Distance: 0.3898, Elapsed Time:3.97
# ====> Validation Total Loss: nan, Rec. Loss: 3.1811
# ====> Epoch:4, Total Loss: 2520542.7500, Rec_arm_1: 3.2612, Rec_arm_2: 3.0682, Joint Loss: 2488693.0000, Entropy: -7.9395, Distance: 0.3899, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.1711
# No more pruning!
# Training is done!

In [12]:
def join(s, *args, **kwargs):
    """Join items using ``s`` as the separator.

    Thin wrapper around ``s.join``: every positional and keyword argument
    after ``s`` is forwarded unchanged, and the result of that call is returned.
    """
    joined = s.join(*args, **kwargs)
    return joined

# Example call: concatenate three one-character strings with an empty separator.
join('', ['a', 'b', 'c'])

'abc'

If you prefer to work from the command line, you can train the model with a Python script such as ```tutorial/train_unimodal.py```, as follows.

```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2
```
or
```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2 --device 'cuda'
```