In [1]:
import os
import json
import argparse
import datetime
import pickle

import numpy as np
import torch
from torch.optim import Adam, SGD, RAdam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tensorboardX import SummaryWriter
from torchvision import transforms
from sklearn.model_selection import train_test_split
from warmup_scheduler import GradualWarmupScheduler

# from model import *
from betaVAE import *
from read_data import *
from utils import *

In [24]:
# Simulating argparse functionality in a Jupyter Notebook
class Args:
    def __init__(self):
        self.config = '../configs/betavae_tissues.json'  # Example: replace with your default config file
        self.checkpoint = None       # Example: replace with your default checkpoint if any
        self.seed = 99
        self.log = 0
        self.parallel = None

# Instantiate the simulated args
args = Args()

# Access the arguments like this:
print(f"Config file: {args.config}")
print(f"Checkpoint: {args.checkpoint}")
print(f"Seed: {args.seed}")
print(f"Log: {args.log}")
print(f"Parallel: {args.parallel}")


Config file: ../configs/betavae_tissues.json
Checkpoint: None
Seed: 99
Log: 0
Parallel: None


In [25]:
with open(args.config) as f:
    config = json.load(f)

In [26]:
config

{'path_csv': ['../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv'],
 'patch_data_path': '../../Histology/Lung_Patches256x256/',
 'img_size': 256,
 'max_patch_per_wsi': 100,
 'rna_features': 19198,
 'weights_decay': 0,
 'lr': 5e-05,
 'num_epochs': 500,
 'n_workers': 4,
 'device': 0,
 'flag': 'betavae_proteincoding_tissues',
 'save_dir': '../checkpoints/betavae_training_tissues/',
 'summary_path': '../summaries_betavae_tissues/',
 'log_interval': 20,
 'bag_size': 40,
 'batch_size': 128,
 'beta': 0.0005,
 'quick': 0,
 'optimizer': 'Adam'}

In [5]:
print(10*'-')
print('Config for this experiment \n')
print(config)
print(10*'-')

if 'flag' in config:
    args.flag = config['flag']
else:
    args.flag = 'train_{date:%Y-%m-%d %H:%M:%S}'.format(date=datetime.datetime.now())

if not os.path.exists(config['save_dir']):
    os.mkdir(config['save_dir'])

----------
Config for this experiment 

{'path_csv': ['../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv'], 'patch_data_path': '../../Histology/Lung_Patches256x256/', 'img_size': 256, 'max_patch_per_wsi': 100, 'rna_features': 19198, 'weights_decay': 0, 'lr': 5e-05, 'num_epochs': 500, 'n_workers': 4, 'device': 0, 'flag': 'betavae_proteincoding_tissues', 'save_dir': '../checkpoints/betavae_training_tissues/', 'summary_path': '../summaries_betavae_tissues/', 'log_interval': 20, 'bag_size': 40, 'batch_size': 128, 'beta': 0.0005, 'quick': 0, 'optimizer': 'Adam'}
----------


In [6]:
path_csv = config['path_csv']
rna_features = config['rna_features']
batch_size = config.get('batch_size', 64)
encoder_checkpoint = config.get('encoder_checkpoint', None)
beta = config.get('beta', 2)
quick = config.get('quick', 0)
opt = config.get('optimizer', 'Adam')

print('Loading dataset...')

datasets = {
    'train': [],
    'test': [],
    'val': []
}

test_labels = []

Loading dataset...


In [7]:
def normalize_dfs(train_df, val_df, test_df, labels=False, norm_type='standard'):
    def _get_log(x):
        # trick to take into account zeros
        x = np.log(x.replace(0, np.nan))
        return x.replace(np.nan, 0)
    # get list of columns to scale
    rna_columns = [x for x in train_df.columns if 'rna_' in x]
    
    
    # log transform
    train_df[rna_columns] = train_df[rna_columns].apply(_get_log)
    val_df[rna_columns] = val_df[rna_columns].apply(_get_log)
    test_df[rna_columns] = test_df[rna_columns].apply(_get_log)
    
    
    train_df = train_df[rna_columns+['wsi_file_name']]
    val_df = val_df[rna_columns+['wsi_file_name']]
    test_df = test_df[rna_columns+['wsi_file_name']]
    
    rna_values = train_df[rna_columns].values

    if norm_type == 'standard':
        scaler = StandardScaler()
    elif norm_type == 'minmax':
        scaler = MinMaxScaler(feature_range=(0,1))
    rna_values = scaler.fit_transform(rna_values)

    train_df[rna_columns] = rna_values
    test_df[rna_columns] = scaler.transform(test_df[rna_columns].values)
    val_df[rna_columns] = scaler.transform(val_df[rna_columns].values)

    return train_df, val_df, test_df, scaler

In [8]:
for id, dataset in enumerate(path_csv):
    print(dataset)
    df = pd.read_csv(dataset)

    df_transposed = df.set_index('gene_id').transpose().reset_index()
    df_transposed.rename(columns={'index': 'wsi_file_name'}, inplace=True)

    train_df, test_df = train_test_split(df_transposed, test_size=0.2)

    train_df, val_df = train_test_split(train_df, test_size=0.2)

    train_df, val_df, test_df, scaler = normalize_dfs(train_df, val_df, test_df, norm_type='minmax')

    datasets['train'].append(train_df)
    datasets['test'].append(test_df)
    datasets['val'].append(val_df)
    
    test_labels = test_labels + ([id] * test_df.shape[0])

../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv


In [10]:
if(len(datasets['train']) >=2):
    train_df = pd.concat([datasets['train'][0], datasets['train'][1]])
    val_df = pd.concat([datasets['val'][0], datasets['val'][1]])
    test_df = pd.concat([datasets['test'][0], datasets['test'][1]])
    for i in range(2, len(datasets['train'])):
        train_df = pd.concat([train_df, datasets['train'][i]])
        val_df = pd.concat([val_df, datasets['val'][i]])
        test_df = pd.concat([test_df, datasets['test'][i]])
else:
    train_df = datasets['train'][0]
    val_df = datasets['val'][0]
    test_df = datasets['test'][0]

print('Train shape {}'.format(train_df.shape))
print('Val shape {}'.format(val_df.shape))
print('Test shape {}'.format(test_df.shape))
train_df, val_df, test_df, scaler = normalize_dfs(train_df, val_df, test_df, norm_type='standard')

train_dataset = RNADataset([train_df], quick=quick)
val_dataset = RNADataset([val_df])
test_dataset = RNADataset([test_df])

train_dataloader = DataLoader(train_dataset,batch_size=batch_size, 
               num_workers=16, shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size, 
               num_workers=16, 
               shuffle=False)
test_dataloader = DataLoader(test_dataset,batch_size=1, 
               num_workers=16, 
               shuffle=False)

Train shape (1084, 19199)
Val shape (274, 19199)
Test shape (341, 19199)


  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
1084it [00:12, 89.38it/s]
274it [00:03, 81.05it/s]
341it [00:03, 88.82it/s]


In [23]:
encoder_checkpoint

In [29]:
print('Finished loading dataset and creating dataloader')

print('Initializing models')


if encoder_checkpoint:
    model = betaVAE(rna_features, 2048, [12000, 4096, 2048], [4096, 12000],
                      encoder_checkpoint=encoder_checkpoint)
    if args.checkpoint is not None:
        print('Restoring from checkpoint')
        print(args.checkpoint)
        model.load_state_dict(torch.load(args.checkpoint))
        print('Loaded model from checkpoint')
    else:
        model.z_mu.apply(init_weights_uniform)
        model.decoder.apply(init_weights_uniform)
        model.z_logvar.apply(init_weights_uniform)
else:
    model = betaVAE(rna_features, 2048, [6000, 4000, 2048], [4000, 6000], beta=beta)
    if args.checkpoint is not None:
        print('Restoring from checkpoint')
        print(args.checkpoint)
        model.load_state_dict(torch.load(args.checkpoint))
        print('Loaded model from checkpoint')
    else:
        model.apply(init_weights_xavier)


#torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
print('Model initialized')

if args.parallel:
    print('Using more than one gpu...')
    model = nn.DataParallel(model)

if torch.cuda.is_available():
    model = model.cuda()

lr = config.get('lr', 3e-3)

if opt == 'RAdam':
    optimizer = RAdam(model.parameters(), weight_decay = config['weights_decay'], lr=lr)
elif opt == 'SGD':
    optimizer = SGD(model.parameters(), weight_decay = config['weights_decay'], lr=lr)
else:
    optimizer = Adam(model.parameters(), weight_decay = config['weights_decay'], lr=lr)

#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=125, gamma=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=1000, after_scheduler=scheduler)
# train model

if args.log:
    summary_writer = SummaryWriter(
            os.path.join(config['summary_path'],
                datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + "_{0}".format(args.flag)))

    summary_writer.add_text('config', str(config))
else:
    summary_writer = None


Finished loading dataset and creating dataloader
Initializing models
Model initialized


In [None]:
dataloaders = {
    'train': train_dataloader,
    'val': val_dataloader
}
model, results = train_betaVAE(model, optimizer, dataloaders,
                               save_dir=config['save_dir'],
                               device=config['device'], 
                               log_interval=config['log_interval'],
                               summary_writer=summary_writer,
                               num_epochs=config['num_epochs'],
                               scheduler=scheduler_warmup)

Epoch 0/499
----------


9it [01:09,  7.68s/it]

train Total Loss: 1.5821 | Reconstruction Loss: 1.1560 | KL Loss: 852.1877



3it [00:08,  2.81s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 1/499
----------



9it [01:06,  7.40s/it]

train Total Loss: 1.5811 | Reconstruction Loss: 1.1632 | KL Loss: 835.7233



3it [00:09,  3.06s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 2/499
----------



9it [01:06,  7.35s/it]

train Total Loss: 1.5729 | Reconstruction Loss: 1.1597 | KL Loss: 826.4209



3it [00:08,  2.91s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 3/499
----------



9it [01:06,  7.44s/it]

train Total Loss: 1.5553 | Reconstruction Loss: 1.1515 | KL Loss: 807.5507



3it [00:09,  3.05s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 4/499
----------



9it [01:07,  7.46s/it]

train Total Loss: 1.5381 | Reconstruction Loss: 1.1465 | KL Loss: 783.1877



3it [00:08,  2.78s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 5/499
----------



9it [01:05,  7.30s/it]

train Total Loss: 1.5156 | Reconstruction Loss: 1.1367 | KL Loss: 757.9327



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 6/499
----------



9it [01:05,  7.31s/it]

train Total Loss: 1.5080 | Reconstruction Loss: 1.1420 | KL Loss: 732.0149



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 7/499
----------



9it [01:05,  7.31s/it]

train Total Loss: 1.4912 | Reconstruction Loss: 1.1372 | KL Loss: 707.9204



3it [00:09,  3.02s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 8/499
----------



9it [01:06,  7.44s/it]

train Total Loss: 1.4626 | Reconstruction Loss: 1.1224 | KL Loss: 680.4498



3it [00:09,  3.00s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 9/499
----------



9it [01:06,  7.43s/it]

train Total Loss: 1.4493 | Reconstruction Loss: 1.1230 | KL Loss: 652.5200



3it [00:08,  2.85s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 10/499
----------



9it [01:07,  7.49s/it]

train Total Loss: 1.4187 | Reconstruction Loss: 1.1061 | KL Loss: 625.1059



3it [00:09,  3.16s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 11/499
----------



9it [01:07,  7.51s/it]

train Total Loss: 1.4024 | Reconstruction Loss: 1.1034 | KL Loss: 597.9636



3it [00:09,  3.02s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 12/499
----------



9it [01:06,  7.35s/it]

train Total Loss: 1.3817 | Reconstruction Loss: 1.0971 | KL Loss: 569.1502



3it [00:09,  3.06s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 13/499
----------



9it [01:06,  7.43s/it]

train Total Loss: 1.3639 | Reconstruction Loss: 1.0941 | KL Loss: 539.6028



3it [00:09,  3.16s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 14/499
----------



9it [01:05,  7.28s/it]

train Total Loss: 1.3205 | Reconstruction Loss: 1.0647 | KL Loss: 511.4879



3it [00:08,  2.89s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 15/499
----------



9it [01:07,  7.51s/it]

train Total Loss: 1.2915 | Reconstruction Loss: 1.0501 | KL Loss: 482.7318



3it [00:08,  2.86s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 16/499
----------



9it [01:05,  7.31s/it]

train Total Loss: 1.2421 | Reconstruction Loss: 1.0139 | KL Loss: 456.3419



3it [00:09,  3.05s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 17/499
----------



9it [01:07,  7.53s/it]

train Total Loss: 1.2020 | Reconstruction Loss: 0.9872 | KL Loss: 429.4897



3it [00:09,  3.05s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 18/499
----------



9it [01:06,  7.40s/it]

train Total Loss: 1.1524 | Reconstruction Loss: 0.9527 | KL Loss: 399.3086



3it [00:08,  2.80s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 19/499
----------



9it [01:06,  7.35s/it]

train Total Loss: 1.1163 | Reconstruction Loss: 0.9320 | KL Loss: 368.7169



3it [00:08,  2.97s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 20/499
----------



9it [01:07,  7.51s/it]

train Total Loss: 1.0729 | Reconstruction Loss: 0.9066 | KL Loss: 332.5914



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 21/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 1.0471 | Reconstruction Loss: 0.8977 | KL Loss: 298.6544



3it [00:08,  2.93s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 22/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 1.0230 | Reconstruction Loss: 0.8889 | KL Loss: 268.2698



3it [00:08,  2.88s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 23/499
----------



9it [01:08,  7.63s/it]

train Total Loss: 1.0099 | Reconstruction Loss: 0.8931 | KL Loss: 233.4272



3it [00:09,  3.04s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 24/499
----------



9it [01:06,  7.39s/it]

train Total Loss: 0.9853 | Reconstruction Loss: 0.8823 | KL Loss: 206.0254



3it [00:08,  2.98s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 25/499
----------



9it [01:07,  7.53s/it]

train Total Loss: 0.9706 | Reconstruction Loss: 0.8783 | KL Loss: 184.6163



3it [00:08,  2.93s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 26/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 0.9649 | Reconstruction Loss: 0.8842 | KL Loss: 161.5191



3it [00:08,  2.92s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 27/499
----------



9it [01:07,  7.46s/it]

train Total Loss: 0.9472 | Reconstruction Loss: 0.8734 | KL Loss: 147.7120



3it [00:09,  3.13s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 28/499
----------



9it [01:07,  7.55s/it]

train Total Loss: 0.9383 | Reconstruction Loss: 0.8719 | KL Loss: 132.8155



3it [00:08,  2.94s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 29/499
----------



9it [01:06,  7.40s/it]

train Total Loss: 0.9209 | Reconstruction Loss: 0.8620 | KL Loss: 117.8949



3it [00:08,  2.91s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 30/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 0.9150 | Reconstruction Loss: 0.8610 | KL Loss: 107.8617



3it [00:09,  3.02s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 31/499
----------



9it [01:07,  7.48s/it]

train Total Loss: 0.9057 | Reconstruction Loss: 0.8562 | KL Loss: 99.1122



3it [00:08,  2.73s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 32/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 0.9086 | Reconstruction Loss: 0.8618 | KL Loss: 93.6613



3it [00:08,  2.94s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 33/499
----------



9it [01:06,  7.37s/it]

train Total Loss: 0.8990 | Reconstruction Loss: 0.8559 | KL Loss: 86.2585



3it [00:08,  2.75s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 34/499
----------



9it [01:07,  7.48s/it]

train Total Loss: 0.8957 | Reconstruction Loss: 0.8557 | KL Loss: 80.0493



3it [00:09,  3.02s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 35/499
----------



9it [01:06,  7.38s/it]

train Total Loss: 0.8880 | Reconstruction Loss: 0.8499 | KL Loss: 76.2401



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 36/499
----------



9it [01:07,  7.48s/it]

train Total Loss: 0.8871 | Reconstruction Loss: 0.8516 | KL Loss: 70.9259



3it [00:09,  3.12s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 37/499
----------



9it [01:05,  7.26s/it]

train Total Loss: 0.8755 | Reconstruction Loss: 0.8413 | KL Loss: 68.3827



3it [00:08,  2.87s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 38/499
----------



9it [01:06,  7.40s/it]

train Total Loss: 0.8667 | Reconstruction Loss: 0.8339 | KL Loss: 65.6652



3it [00:08,  2.81s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 39/499
----------



9it [01:07,  7.53s/it]

train Total Loss: 0.8676 | Reconstruction Loss: 0.8365 | KL Loss: 62.1588



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 40/499
----------



9it [01:07,  7.53s/it]

train Total Loss: 0.8649 | Reconstruction Loss: 0.8350 | KL Loss: 59.7078



3it [00:09,  3.07s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 41/499
----------



9it [01:05,  7.24s/it]

train Total Loss: 0.8591 | Reconstruction Loss: 0.8299 | KL Loss: 58.3469



3it [00:08,  2.86s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 42/499
----------



9it [01:06,  7.41s/it]

train Total Loss: 0.8584 | Reconstruction Loss: 0.8298 | KL Loss: 57.2327



3it [00:08,  2.87s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 43/499
----------



9it [01:07,  7.50s/it]

train Total Loss: 0.8554 | Reconstruction Loss: 0.8269 | KL Loss: 57.0028



3it [00:08,  2.78s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 44/499
----------



9it [01:06,  7.40s/it]

train Total Loss: 0.8533 | Reconstruction Loss: 0.8263 | KL Loss: 54.0912



3it [00:08,  2.89s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 45/499
----------



9it [01:04,  7.21s/it]

train Total Loss: 0.8413 | Reconstruction Loss: 0.8152 | KL Loss: 52.3147



3it [00:09,  3.06s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 46/499
----------



9it [01:07,  7.55s/it]

train Total Loss: 0.8378 | Reconstruction Loss: 0.8107 | KL Loss: 54.2033



3it [00:08,  2.97s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 47/499
----------



9it [01:06,  7.34s/it]

train Total Loss: 0.8355 | Reconstruction Loss: 0.8096 | KL Loss: 51.7993



3it [00:08,  2.92s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 48/499
----------



9it [01:09,  7.72s/it]

train Total Loss: 0.8447 | Reconstruction Loss: 0.8172 | KL Loss: 54.8543



3it [00:08,  2.98s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 49/499
----------



9it [01:07,  7.49s/it]

train Total Loss: 0.8302 | Reconstruction Loss: 0.8035 | KL Loss: 53.3855



3it [00:08,  2.90s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 50/499
----------



9it [01:07,  7.52s/it]

train Total Loss: 0.8222 | Reconstruction Loss: 0.7942 | KL Loss: 56.1456



3it [00:09,  3.07s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 51/499
----------



9it [01:07,  7.46s/it]

train Total Loss: 0.8088 | Reconstruction Loss: 0.7795 | KL Loss: 58.5348



3it [00:08,  2.82s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 52/499
----------



9it [01:07,  7.45s/it]

train Total Loss: 0.8070 | Reconstruction Loss: 0.7792 | KL Loss: 55.6925



3it [00:08,  2.85s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 53/499
----------



9it [01:07,  7.49s/it]

train Total Loss: 0.8008 | Reconstruction Loss: 0.7723 | KL Loss: 57.0289



3it [00:09,  3.14s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 54/499
----------



9it [01:07,  7.53s/it]

train Total Loss: 0.8120 | Reconstruction Loss: 0.7850 | KL Loss: 53.9399



3it [00:08,  2.99s/it]

val Total Loss: nan | Reconstruction Loss: nan | KL Loss: inf
Epoch 55/499
----------



2it [00:14,  7.44s/it]