### Training MMIDAS - a coupled mixture VAE model
This notebook guides you through the process of training a mixture variational autoencoder.

In [1]:
import mmidas
from mmidas.nn_model import mixVAEConfig, loss_fn
from mmidas.cpl_mixvae import cpl_mixVAE
from mmidas.utils.tools import get_paths
from mmidas.utils.dataloader import load_data, get_loaders
import importlib
import os
import torch

import warnings
warnings.filterwarnings('ignore')

Specify the training parameters.

In [2]:
# Training hyperparameters. Each name is assigned exactly once so that editing
# a value here cannot be silently overridden further down the cell.
n_categories = 120 # upper bound of number of categories (clusters)
state_dim = 2 # continuous (state) variable dimensionality
n_arm = 2 # number of arms
latent_dim = 10 # latent dimensionality of the model
batch_size = 5000 # mini-batch size for training
n_epoch = 10 # number of epochs for training
n_epoch_p = 5 # number of epochs for pruning
min_con = 0.9 # minimum consensus among arms
max_prun_it = 2 # maximum number of pruning iterations
lr = 1e-3 # learning rate for training

Load the prepared data (as described in ```1_data_prep.ipynb```) and create training and validation sets.

In [3]:
# Resolve where the prepared dataset lives, using the project's path settings.
toml_file, sub_file = 'pyproject.toml', 'smartseq_files'
config = get_paths(sub_file=sub_file, toml_file=toml_file)

# `config` is indexed like a nested mapping; the 'main_dir' entry is combined
# with `/`, so it is presumably a pathlib.Path -- confirm in get_paths.
data_path = (
    config['paths']['main_dir']
    / config['paths']['data_path']
)
data_file = (
    data_path
    / config[sub_file]['anndata_file']
)

/allen/programs/celltypes/workgroups/mousecelltypes/Hilal/MMIDAS/pyproject.toml
Getting files directories belong to smartseq_files...


In [4]:
# Read the prepared dataset, then build the loaders from its 'log1p' entry.
data = load_data(datafile=data_file)

# get_loaders returns three values; the third one is not used in this notebook.
(trainloader, testloader, _) = get_loaders(
    dataset=data['log1p'],
    batch_size=batch_size,
)

data is loaded!
 --------- Data Summary --------- 
num cell types: 115, num cells: 22365, num genes:5032


Create a designated folder to store the training files.

In [5]:
# Run identifier and augmentation flag; both are encoded in the folder name.
n_run = 1
augmentation = False

# The folder name records the hyperparameters so separate runs do not collide.
folder_name = (
    f'run_{n_run}_K_{n_categories}_Sdim_{state_dim}_aug_{augmentation}_lr_{lr}_n_arm_{n_arm}_nbatch_{batch_size}'
    f'_nepoch_{n_epoch}_nepochP_{n_epoch_p}'
)

saving_folder = config['paths']['main_dir'] / config['paths']['saving_path']
saving_folder = saving_folder / folder_name

# Create the run folder and its 'model' sub-folder; existing folders are reused.
for target_dir in (saving_folder, saving_folder / 'model'):
    os.makedirs(target_dir, exist_ok=True)
del target_dir  # keep the notebook namespace free of the loop variable

# Later cells pass the folder on as a plain string.
saving_folder = str(saving_folder)

Construct a cpl-mixVAE object and launch its training on the prepared data.

In [14]:
# Reload the mmidas modules so that on-disk edits are picked up without
# restarting the kernel, then re-bind cpl_mixVAE from the reloaded module.
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE

# Create the training wrapper. The device is hard-coded to 'cuda'.
# NOTE(review): whether other device strings are accepted is not visible here.
cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
# input_dim is the size of the second axis of data['log1p']
# (presumably the number of genes -- confirm against load_data).
cplMixVAE.init(categories=n_categories,
                     state_dim=state_dim,
                     input_dim=data['log1p'].shape[1],
                     lowD_dim=latent_dim,
                     lr=lr,
                     arms=n_arm)

# Launch training followed by pruning, using the settings defined above.
# The return value is kept as `model_file`
# (presumably the path of the saved model -- confirm in cpl_mixVAE.train).
model_file = cplMixVAE.train(train_loader=trainloader,
                             test_loader=testloader,
                             n_epoch=n_epoch,
                             n_epoch_p=n_epoch_p,
                             min_con=min_con,
                             max_prun_it=max_prun_it)

device: Tesla V100-SXM2-32GB
Start training ...


  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:00<00:08,  1.06it/s, Total Loss=1.67e+8, Rec_arm_1=8.8, Rec_arm_2=8.8, Joint Loss=1.67e+8, Entropy=-6.82, Distance=0.513, Elapsed Time=0.944, Validation Loss=1.98e+10, Validation Rec. Loss=8.75]

           -- EPOCH 0 --           
Total Loss:                    167475746.0000                
Rec_arm_1:                     8.8015                        
Rec_arm_2:                     8.7973                        
Joint Loss:                    167387184.0000                
Entropy:                       -6.8200                       
Distance:                      0.5133                        
Elapsed Time:                  0.9438                        
Validation Loss:               19812003840.0000              
Validation Rec. Loss:          8.7490                        


 20%|██        | 2/10 [00:01<00:05,  1.50it/s, Total Loss=7.15e+7, Rec_arm_1=8.76, Rec_arm_2=8.74, Joint Loss=7.14e+7, Entropy=-6.96, Distance=0.484, Elapsed Time=0.469, Validation Loss=2.59e+10, Validation Rec. Loss=8.67]

           -- EPOCH 1 --           
Total Loss:                    71481054.0000                 
Rec_arm_1:                     8.7551                        
Rec_arm_2:                     8.7384                        
Joint Loss:                    71393024.0000                 
Entropy:                       -6.9623                       
Distance:                      0.4839                        
Elapsed Time:                  0.4693                        
Validation Loss:               25910358016.0000              
Validation Rec. Loss:          8.6732                        


 30%|███       | 3/10 [00:01<00:04,  1.74it/s, Total Loss=3.83e+7, Rec_arm_1=8.67, Rec_arm_2=8.63, Joint Loss=3.82e+7, Entropy=-7.11, Distance=0.468, Elapsed Time=0.465, Validation Loss=2.5e+10, Validation Rec. Loss=8.53] 

           -- EPOCH 2 --           
Total Loss:                    38269001.0000                 
Rec_arm_1:                     8.6707                        
Rec_arm_2:                     8.6306                        
Joint Loss:                    38181940.0000                 
Entropy:                       -7.1136                       
Distance:                      0.4680                        
Elapsed Time:                  0.4653                        
Validation Loss:               25026070528.0000              
Validation Rec. Loss:          8.5306                        


 40%|████      | 4/10 [00:02<00:03,  1.87it/s, Total Loss=2.35e+7, Rec_arm_1=8.51, Rec_arm_2=8.42, Joint Loss=2.34e+7, Entropy=-7.23, Distance=0.455, Elapsed Time=0.473, Validation Loss=2.39e+10, Validation Rec. Loss=8.26]

           -- EPOCH 3 --           
Total Loss:                    23461117.0000                 
Rec_arm_1:                     8.5125                        
Rec_arm_2:                     8.4234                        
Joint Loss:                    23375896.0000                 
Entropy:                       -7.2337                       
Distance:                      0.4553                        
Elapsed Time:                  0.4726                        
Validation Loss:               23899086848.0000              
Validation Rec. Loss:          8.2605                        


 50%|█████     | 5/10 [00:03<00:03,  1.64it/s, Total Loss=1.69e+7, Rec_arm_1=8.22, Rec_arm_2=8.04, Joint Loss=1.68e+7, Entropy=-7.33, Distance=0.446, Elapsed Time=0.736, Validation Loss=2.26e+10, Validation Rec. Loss=7.77]

           -- EPOCH 4 --           
Total Loss:                    16900663.2500                 
Rec_arm_1:                     8.2216                        
Rec_arm_2:                     8.0354                        
Joint Loss:                    16818858.0000                 
Entropy:                       -7.3337                       
Distance:                      0.4458                        
Elapsed Time:                  0.7360                        
Validation Loss:               22597525504.0000              
Validation Rec. Loss:          7.7666                        


 60%|██████    | 6/10 [00:03<00:02,  1.78it/s, Total Loss=1.24e+7, Rec_arm_1=7.71, Rec_arm_2=7.35, Joint Loss=1.24e+7, Entropy=-7.36, Distance=0.444, Elapsed Time=0.466, Validation Loss=2.26e+10, Validation Rec. Loss=6.94]

           -- EPOCH 5 --           
Total Loss:                    12439736.7500                 
Rec_arm_1:                     7.7065                        
Rec_arm_2:                     7.3522                        
Joint Loss:                    12363961.0000                 
Entropy:                       -7.3568                       
Distance:                      0.4436                        
Elapsed Time:                  0.4658                        
Validation Loss:               22570041344.0000              
Validation Rec. Loss:          6.9416                        


 70%|███████   | 7/10 [00:04<00:01,  1.88it/s, Total Loss=9.6e+6, Rec_arm_1=6.86, Rec_arm_2=6.29, Joint Loss=9.53e+6, Entropy=-7.42, Distance=0.437, Elapsed Time=0.469, Validation Loss=2.31e+10, Validation Rec. Loss=5.75] 

           -- EPOCH 6 --           
Total Loss:                    9596163.2500                  
Rec_arm_1:                     6.8626                        
Rec_arm_2:                     6.2879                        
Joint Loss:                    9529990.0000                  
Entropy:                       -7.4152                       
Distance:                      0.4365                        
Elapsed Time:                  0.4693                        
Validation Loss:               23116150784.0000              
Validation Rec. Loss:          5.7474                        


 80%|████████  | 8/10 [00:04<00:01,  1.88it/s, Total Loss=7.34e+6, Rec_arm_1=5.66, Rec_arm_2=5, Joint Loss=7.29e+6, Entropy=-7.45, Distance=0.433, Elapsed Time=0.533, Validation Loss=2.24e+10, Validation Rec. Loss=4.54]  

           -- EPOCH 7 --           
Total Loss:                    7341423.0000                  
Rec_arm_1:                     5.6580                        
Rec_arm_2:                     4.9977                        
Joint Loss:                    7287803.0000                  
Entropy:                       -7.4505                       
Distance:                      0.4330                        
Elapsed Time:                  0.5334                        
Validation Loss:               22374895616.0000              
Validation Rec. Loss:          4.5392                        


 90%|█████████ | 9/10 [00:05<00:00,  1.57it/s, Total Loss=6.75e+6, Rec_arm_1=4.43, Rec_arm_2=4.3, Joint Loss=6.71e+6, Entropy=-7.5, Distance=0.423, Elapsed Time=0.86, Validation Loss=2.09e+10, Validation Rec. Loss=4.22]

           -- EPOCH 8 --           
Total Loss:                    6752244.6250                  
Rec_arm_1:                     4.4336                        
Rec_arm_2:                     4.2963                        
Joint Loss:                    6708315.0000                  
Entropy:                       -7.4996                       
Distance:                      0.4225                        
Elapsed Time:                  0.8601                        
Validation Loss:               20890984448.0000              
Validation Rec. Loss:          4.2198                        


100%|██████████| 10/10 [00:05<00:00,  1.69it/s, Total Loss=5.58e+6, Rec_arm_1=4.06, Rec_arm_2=4.25, Joint Loss=5.53e+6, Entropy=-7.54, Distance=0.423, Elapsed Time=0.47, Validation Loss=2e+10, Validation Rec. Loss=4.14]

           -- EPOCH 9 --           
Total Loss:                    5576656.6250                  
Rec_arm_1:                     4.0647                        
Rec_arm_2:                     4.2469                        
Joint Loss:                    5534832.5000                  
Entropy:                       -7.5353                       
Distance:                      0.4233                        
Elapsed Time:                  0.4701                        
Validation Loss:               19991003136.0000              
Validation Rec. Loss:          4.1372                        





Continue training with pruning ...
Pruned categories: [1]


 20%|██        | 1/5 [00:00<00:02,  2.00it/s, Total Loss=5.04e+6, Rec_arm_1=3.91, Rec_arm_2=3.75, Joint Loss=5e+6, Entropy=-7.55, Distance=0.423, Elapsed Time=0.499, Validation Loss=1.9e+10, Validation Rec. Loss=3.63]

           -- EPOCH 0 --           
Total Loss:                    5035232.2500                  
Rec_arm_1:                     3.9061                        
Rec_arm_2:                     3.7475                        
Joint Loss:                    4996719.0000                  
Entropy:                       -7.5505                       
Distance:                      0.4229                        
Elapsed Time:                  0.4990                        
Validation Loss:               18995996672.0000              
Validation Rec. Loss:          3.6262                        


 40%|████      | 2/5 [00:01<00:01,  1.52it/s, Total Loss=4.35e+6, Rec_arm_1=3.43, Rec_arm_2=3.56, Joint Loss=4.31e+6, Entropy=-7.58, Distance=0.418, Elapsed Time=0.762, Validation Loss=1.84e+10, Validation Rec. Loss=3.44]

           -- EPOCH 1 --           
Total Loss:                    4346429.7500                  
Rec_arm_1:                     3.4303                        
Rec_arm_2:                     3.5615                        
Joint Loss:                    4311246.5000                  
Entropy:                       -7.5817                       
Distance:                      0.4178                        
Elapsed Time:                  0.7618                        
Validation Loss:               18447450112.0000              
Validation Rec. Loss:          3.4425                        


 60%|██████    | 3/5 [00:01<00:01,  1.64it/s, Total Loss=4.22e+6, Rec_arm_1=3.33, Rec_arm_2=3.54, Joint Loss=4.19e+6, Entropy=-7.62, Distance=0.418, Elapsed Time=0.548, Validation Loss=1.8e+10, Validation Rec. Loss=3.39] 

           -- EPOCH 2 --           
Total Loss:                    4221521.7500                  
Rec_arm_1:                     3.3278                        
Rec_arm_2:                     3.5443                        
Joint Loss:                    4186941.2500                  
Entropy:                       -7.6183                       
Distance:                      0.4182                        
Elapsed Time:                  0.5483                        
Validation Loss:               17952966656.0000              
Validation Rec. Loss:          3.3890                        


 80%|████████  | 4/5 [00:02<00:00,  1.71it/s, Total Loss=3.69e+6, Rec_arm_1=3.3, Rec_arm_2=3.47, Joint Loss=3.66e+6, Entropy=-7.62, Distance=0.412, Elapsed Time=0.546, Validation Loss=1.69e+10, Validation Rec. Loss=3.33]

           -- EPOCH 3 --           
Total Loss:                    3691101.8750                  
Rec_arm_1:                     3.2962                        
Rec_arm_2:                     3.4684                        
Joint Loss:                    3657062.2500                  
Entropy:                       -7.6197                       
Distance:                      0.4116                        
Elapsed Time:                  0.5464                        
Validation Loss:               16942145536.0000              
Validation Rec. Loss:          3.3310                        


100%|██████████| 5/5 [00:02<00:00,  1.70it/s, Total Loss=3.6e+6, Rec_arm_1=3.22, Rec_arm_2=3.44, Joint Loss=3.56e+6, Entropy=-7.64, Distance=0.412, Elapsed Time=0.56, Validation Loss=1.6e+10, Validation Rec. Loss=3.33]  

           -- EPOCH 4 --           
Total Loss:                    3596937.8750                  
Rec_arm_1:                     3.2177                        
Rec_arm_2:                     3.4383                        
Joint Loss:                    3563444.0000                  
Entropy:                       -7.6351                       
Distance:                      0.4121                        
Elapsed Time:                  0.5600                        
Validation Loss:               16008756224.0000              
Validation Rec. Loss:          3.3345                        





Continue training with pruning ...
Pruned categories: [1 2]


 20%|██        | 1/5 [00:00<00:01,  2.04it/s, Total Loss=3.4e+6, Rec_arm_1=3.2, Rec_arm_2=3.43, Joint Loss=3.37e+6, Entropy=-7.62, Distance=0.41, Elapsed Time=0.489, Validation Loss=1.48e+10, Validation Rec. Loss=3.34]

           -- EPOCH 0 --           
Total Loss:                    3399608.1875                  
Rec_arm_1:                     3.1951                        
Rec_arm_2:                     3.4345                        
Joint Loss:                    3366247.7500                  
Entropy:                       -7.6223                       
Distance:                      0.4105                        
Elapsed Time:                  0.4888                        
Validation Loss:               14773789696.0000              
Validation Rec. Loss:          3.3356                        


 40%|████      | 2/5 [00:00<00:01,  2.06it/s, Total Loss=3.11e+6, Rec_arm_1=3.18, Rec_arm_2=3.4, Joint Loss=3.07e+6, Entropy=-7.63, Distance=0.415, Elapsed Time=0.479, Validation Loss=1.35e+10, Validation Rec. Loss=3.29]

           -- EPOCH 1 --           
Total Loss:                    3107735.1875                  
Rec_arm_1:                     3.1847                        
Rec_arm_2:                     3.3950                        
Joint Loss:                    3074625.7500                  
Entropy:                       -7.6286                       
Distance:                      0.4151                        
Elapsed Time:                  0.4791                        
Validation Loss:               13520855040.0000              
Validation Rec. Loss:          3.2950                        


 60%|██████    | 3/5 [00:01<00:01,  1.63it/s, Total Loss=2.93e+6, Rec_arm_1=3.15, Rec_arm_2=3.38, Joint Loss=2.9e+6, Entropy=-7.64, Distance=0.409, Elapsed Time=0.763, Validation Loss=1.29e+10, Validation Rec. Loss=3.27]

           -- EPOCH 2 --           
Total Loss:                    2931390.5625                  
Rec_arm_1:                     3.1476                        
Rec_arm_2:                     3.3765                        
Joint Loss:                    2898560.5000                  
Entropy:                       -7.6435                       
Distance:                      0.4095                        
Elapsed Time:                  0.7627                        
Validation Loss:               12892220416.0000              
Validation Rec. Loss:          3.2663                        


 80%|████████  | 4/5 [00:02<00:00,  1.77it/s, Total Loss=2.99e+6, Rec_arm_1=3.13, Rec_arm_2=3.37, Joint Loss=2.96e+6, Entropy=-7.66, Distance=0.408, Elapsed Time=0.486, Validation Loss=1.22e+10, Validation Rec. Loss=3.25]

           -- EPOCH 3 --           
Total Loss:                    2993743.8750                  
Rec_arm_1:                     3.1343                        
Rec_arm_2:                     3.3677                        
Joint Loss:                    2961025.2500                  
Entropy:                       -7.6569                       
Distance:                      0.4081                        
Elapsed Time:                  0.4856                        
Validation Loss:               12197145600.0000              
Validation Rec. Loss:          3.2504                        


100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Total Loss=2.7e+6, Rec_arm_1=3.13, Rec_arm_2=3.36, Joint Loss=2.66e+6, Entropy=-7.67, Distance=0.409, Elapsed Time=0.481, Validation Loss=1.17e+10, Validation Rec. Loss=3.24] 

           -- EPOCH 4 --           
Total Loss:                    2697108.5000                  
Rec_arm_1:                     3.1282                        
Rec_arm_2:                     3.3559                        
Joint Loss:                    2664480.0000                  
Entropy:                       -7.6653                       
Distance:                      0.4091                        
Elapsed Time:                  0.4805                        
Validation Loss:               11679062016.0000              
Validation Rec. Loss:          3.2436                        





No more pruning!
Training is done!


In [15]:
# Reload the mmidas modules so that on-disk edits are picked up without
# restarting the kernel. Names imported earlier with `from ... import ...`
# keep pointing at the old objects until they are imported again.
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)

<module 'mmidas.nn_model' from '/allen/programs/celltypes/workgroups/mousecelltypes/Hilal/MMIDAS/mmidas/nn_model.py'>

In [16]:
# Reload the mmidas modules and re-bind cpl_mixVAE so this cell uses the
# current on-disk implementation.
importlib.reload(mmidas.cpl_mixvae)
importlib.reload(mmidas.nn_model)
from mmidas.cpl_mixvae import cpl_mixVAE

# Build a fresh wrapper with the same settings as the training cell above.
cplMixVAE = cpl_mixVAE(saving_folder=saving_folder, device='cuda', seed=546)
cplMixVAE.init(categories=n_categories,
                     state_dim=state_dim,
                     input_dim=data['log1p'].shape[1],
                     lowD_dim=latent_dim,
                     lr=lr,
                     arms=n_arm,)

# Call the private `_fsdp` entry point, passing the wrapper's own device,
# model and optimizer back in explicitly. The result is kept in `fsdp_dct`.
# NOTE(review): the output saved with this cell is a BackendCompilerFailed
# error from the 'inductor' backend (a triton import failure), so this call
# did not complete in the recorded run.
fsdp_dct = cplMixVAE._fsdp(train_loader=trainloader,
                           val_loader=testloader,
                           epochs=n_epoch,
                           n_epoch_p=n_epoch_p,
                           min_con=min_con,
                           max_prun_it=max_prun_it,
                           device=cplMixVAE.device,
                           model=cplMixVAE.model,
                           opt=cplMixVAE.optimizer)

# TODO:
# fix nans
# (presumably the `nan` validation total loss that appears after pruning in
# the log pasted in the next cell -- confirm)

device: Tesla V100-SXM2-32GB


  0%|          | 0/10 [00:03<?, ?it/s]


BackendCompilerFailed: backend='inductor' raised:
LoweringException: ImportError: cannot import name 'add_external_libs' from 'triton._C.libtriton.triton' (unknown location)
  target: aten.addmm.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cuda', torch.float32, size=[100], stride=[1]))
  ))
  args[1]: TensorBox(StorageBox(
    ComputedBuffer(name='buf2', layout=FixedLayout('cuda', torch.float32, size=[5000, 5032], stride=[5032, 1]), data=Pointwise(
      'cuda',
      torch.float32,
      def inner_fn(index):
          i0, i1 = index
          tmp0 = ops.load(buf1, i1 + 5032 * i0)
          tmp1 = ops.constant(0.5, torch.float32)
          tmp2 = tmp0 > tmp1
          tmp3 = ops.to_dtype(tmp2, torch.float32)
          tmp4 = ops.load(primals_87, i1 + 5032 * i0)
          tmp5 = tmp3 * tmp4
          tmp6 = ops.constant(2.0, torch.float32)
          tmp7 = tmp5 * tmp6
          return tmp7
      ,
      ranges=[5000, 5032],
      origin_node=mul_1,
      origins={gt, mul_1, mul}
    ))
  ))
  args[2]: TensorBox(
    ReinterpretView(
      StorageBox(
        InputBuffer(name='primals_1', layout=FixedLayout('cuda', torch.float32, size=[100, 5032], stride=[5032, 1]))
      ),
      FixedLayout('cuda', torch.float32, size=[5032, 100], stride=[1, 5032]),
      origins={permute}
    )
  )

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [None]:
# Start training ...
# ====> Epoch:0, Total Loss: 137667456.0000, Rec_arm_1: 8.7999, Rec_arm_2: 8.7966, Joint Loss: 137578912.0000, Entropy: -7.1332, Distance: 0.4805, Elapsed Time:4.13
# ====> Validation Total Loss: 28064712704.0000, Rec. Loss: 8.7721
# ====> Epoch:1, Total Loss: 69825152.0000, Rec_arm_1: 8.7497, Rec_arm_2: 8.7451, Joint Loss: 69737120.0000, Entropy: -7.3747, Distance: 0.4512, Elapsed Time:3.64
# ====> Validation Total Loss: 25941938176.0000, Rec. Loss: 8.6993
# ====> Epoch:2, Total Loss: 42043478.0000, Rec_arm_1: 8.6571, Rec_arm_2: 8.6533, Joint Loss: 41956372.0000, Entropy: -7.3914, Distance: 0.4467, Elapsed Time:3.58
# ====> Validation Total Loss: 22861494272.0000, Rec. Loss: 8.5640
# ====> Epoch:3, Total Loss: 26674182.0000, Rec_arm_1: 8.4823, Rec_arm_2: 8.4839, Joint Loss: 26588808.0000, Entropy: -7.5388, Distance: 0.4278, Elapsed Time:3.60
# ====> Validation Total Loss: 20230934528.0000, Rec. Loss: 8.3146
# ====> Epoch:4, Total Loss: 17169558.2500, Rec_arm_1: 8.1604, Rec_arm_2: 8.1765, Joint Loss: 17087350.0000, Entropy: -7.6280, Distance: 0.4243, Elapsed Time:3.65
# ====> Validation Total Loss: 20135974912.0000, Rec. Loss: 7.8776
# ====> Epoch:5, Total Loss: 11332364.7500, Rec_arm_1: 7.5881, Rec_arm_2: 7.6376, Joint Loss: 11255748.0000, Entropy: -7.7191, Distance: 0.4143, Elapsed Time:3.53
# ====> Validation Total Loss: 23999512576.0000, Rec. Loss: 7.1436
# ====> Epoch:6, Total Loss: 8132679.8750, Rec_arm_1: 6.6561, Rec_arm_2: 6.7642, Joint Loss: 8065148.5000, Entropy: -7.7605, Distance: 0.4065, Elapsed Time:3.60
# ====> Validation Total Loss: 31686617088.0000, Rec. Loss: 6.0345
# ====> Epoch:7, Total Loss: 6335305.1250, Rec_arm_1: 5.3905, Rec_arm_2: 5.5418, Joint Loss: 6280293.0000, Entropy: -7.8043, Distance: 0.4005, Elapsed Time:3.63
# ====> Validation Total Loss: 36792733696.0000, Rec. Loss: 4.7336
# ====> Epoch:8, Total Loss: 5184599.3750, Rec_arm_1: 4.3659, Rec_arm_2: 4.3467, Joint Loss: 5140757.0000, Entropy: -7.8469, Distance: 0.3987, Elapsed Time:3.51
# ====> Validation Total Loss: 38023708672.0000, Rec. Loss: 4.0776
# ====> Epoch:9, Total Loss: 4790232.1250, Rec_arm_1: 4.2830, Rec_arm_2: 4.0376, Joint Loss: 4748362.0000, Entropy: -7.8604, Distance: 0.3998, Elapsed Time:3.57
# ====> Validation Total Loss: 36604887040.0000, Rec. Loss: 3.9312
# Training with pruning...
# Purned categories: [1]
# ====> Epoch:0, Total Loss: 4001048.3125, Rec_arm_1: 3.8611, Rec_arm_2: 3.8509, Joint Loss: 3962241.0000, Entropy: -7.8664, Distance: 0.3989, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.5026
# ====> Epoch:1, Total Loss: 3639977.4375, Rec_arm_1: 3.5320, Rec_arm_2: 3.3895, Joint Loss: 3605148.2500, Entropy: -7.8894, Distance: 0.3974, Elapsed Time:3.83
# ====> Validation Total Loss: nan, Rec. Loss: 3.4378
# ====> Epoch:2, Total Loss: 3534881.8125, Rec_arm_1: 3.4955, Rec_arm_2: 3.3030, Joint Loss: 3500671.7500, Entropy: -7.9030, Distance: 0.3960, Elapsed Time:4.10
# ====> Validation Total Loss: nan, Rec. Loss: 3.4195
# ====> Epoch:3, Total Loss: 3155906.8125, Rec_arm_1: 3.4152, Rec_arm_2: 3.2708, Joint Loss: 3122262.7500, Entropy: -7.9187, Distance: 0.3939, Elapsed Time:4.40
# ====> Validation Total Loss: nan, Rec. Loss: 3.3164
# ====> Epoch:4, Total Loss: 2953709.7500, Rec_arm_1: 3.3518, Rec_arm_2: 3.1916, Joint Loss: 2920782.5000, Entropy: -7.9333, Distance: 0.3947, Elapsed Time:3.98
# ====> Validation Total Loss: nan, Rec. Loss: 3.2596
# Training with pruning...
# Purned categories: [1 2]
# ====> Epoch:0, Total Loss: 2872254.8750, Rec_arm_1: 3.3458, Rec_arm_2: 3.1647, Joint Loss: 2839493.0000, Entropy: -7.9148, Distance: 0.3927, Elapsed Time:3.81
# ====> Validation Total Loss: nan, Rec. Loss: 3.2327
# ====> Epoch:1, Total Loss: 2638138.1875, Rec_arm_1: 3.3082, Rec_arm_2: 3.1392, Joint Loss: 2605694.2500, Entropy: -7.9227, Distance: 0.3921, Elapsed Time:3.87
# ====> Validation Total Loss: nan, Rec. Loss: 3.1979
# ====> Epoch:2, Total Loss: 2455072.6250, Rec_arm_1: 3.2809, Rec_arm_2: 3.0947, Joint Loss: 2422990.5000, Entropy: -7.9277, Distance: 0.3917, Elapsed Time:3.78
# ====> Validation Total Loss: nan, Rec. Loss: 3.1861
# ====> Epoch:3, Total Loss: 2434174.7500, Rec_arm_1: 3.2733, Rec_arm_2: 3.0776, Joint Loss: 2402216.5000, Entropy: -7.9398, Distance: 0.3898, Elapsed Time:3.97
# ====> Validation Total Loss: nan, Rec. Loss: 3.1811
# ====> Epoch:4, Total Loss: 2520542.7500, Rec_arm_1: 3.2612, Rec_arm_2: 3.0682, Joint Loss: 2488693.0000, Entropy: -7.9395, Distance: 0.3899, Elapsed Time:3.85
# ====> Validation Total Loss: nan, Rec. Loss: 3.1711
# No more pruning!
# Training is done!

In [None]:
def join(s, *args, **kwargs):
    """Call ``s.join`` with the given arguments and return its result.

    Args:
        s: The separator object; it must provide a ``join`` method.
        *args: Positional arguments forwarded unchanged to ``s.join``.
        **kwargs: Keyword arguments forwarded unchanged to ``s.join``.

    Returns:
        Whatever ``s.join(*args, **kwargs)`` returns.
    """
    joined = s.join(*args, **kwargs)
    return joined


# Demo: an empty separator concatenates the items.
join('', ['a', 'b', 'c'])

'abc'

In [None]:
# TODO: [] make more robust
def pprint(dct):
    """Print a two-level mapping as brace-delimited groups.

    Each outer key is printed as ``key: {``, followed by one tab-indented
    line per inner key in the form ``inner: [v1, v2, ...]``, and closed with
    ``}``.

    Args:
        dct: Mapping whose values are mappings from a key to an iterable of
            items; each item is converted with ``str`` before printing.

    Returns:
        None. The text is written to standard output.
    """
    # Iterate over the argument rather than the notebook-level `fsdp_dct`, so
    # the function works for any mapping and does not need that global.
    for k, v in dct.items():
        print(f"{k}: {{")
        for kk, vv in v.items():
            print(f"\t{kk}: [{', '.join(map(str, vv))}]")
        print("}")

If you prefer to work from the command line, you can train the model with a Python script such as ```tutorial/train_unimodal.py```, as follows.

```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2
```
or
```
python train_unimodal.py --n_epoch 10 --n_epoch_p 5 --max_prun_it 2 --device 'cuda'
```