In [1]:
import os
import torch
import scipy.stats as stats
import numpy as np

# CODE FILES HERE
from models.vae.vae import Encoder, Decoder, Vae, MODEL_NAME
from solver import Solver
from directories import Directories
from dataloader import DataLoader
from plots import plot_losses, plot_gaussian_distributions, plot_rl_kl, plot_latent_space, \
plot_latent_manifold, plot_faces_grid, plot_faces_samples_grid

# SETTINGS HERE
# Notebook-wide configuration: CUDA debugging flag, inline plotting, module auto-reload
# and warning suppression. These run once, right after the imports above.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # to see the CUDA stack when a CUDA call fails
%matplotlib inline
# for auto-reloading external modules (the project files imported above), so edits to them
# are picked up without restarting the kernel
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# suppress cluttering warnings in solutions
# NOTE(review): no category is given, so this silences every warning, including
# deprecation warnings from the libraries used below.
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Select the compute device: the first CUDA GPU if one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA (this branch is skipped on CPU-only machines).
if device.type == 'cuda':
    # Query the device that was actually selected instead of a hard-coded index.
    print(torch.cuda.get_device_name(device))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(device)/1024**3,1), 'GB')
    # torch.cuda.memory_cached() is deprecated; torch.cuda.memory_reserved() is its
    # replacement and reports the same quantity. The printed label is kept unchanged.
    print('Cached:   ', round(torch.cuda.memory_reserved(device)/1024**3,1), 'GB')

Using device: cpu



In [3]:
# Choose the dataset and tune hyperparameters here!
# Supported values for `dataset` are "MNIST", "LFW" and "FF"; anything else raises a
# ValueError below instead of leaving the hyperparameters undefined.
dataset = "FF"

# Settings shared by all datasets.
batch_size = 128
optimizer = torch.optim.Adam

# Every branch defines the same set of names: epochs, hidden_dim, z_dim, beta, batch_norm,
# lr_scheduler, step_config (scheduler arguments) and optim_config (optimizer arguments).
# NOTE(review): "weight_decay" is None in every branch; how that value is interpreted
# depends on Solver, which is not visible here.
if dataset == "MNIST":
    epochs = 20
    hidden_dim = 500
    z_dim = 2
    beta = 4  # was "4 if z_dim == 2 else 4": both arms were identical
    batch_norm = False
    lr_scheduler = torch.optim.lr_scheduler.StepLR
    step_config = {
        "step_size" : 300,
        "gamma" : 0.1 # or 0.75
    }
    optim_config = {
        "lr": 1e-3,
        "weight_decay": None
    }
elif dataset == "LFW":
    epochs = 30
    hidden_dim = 700
    z_dim = 40
    beta = 1  # was "1 if z_dim == 2 else 1": both arms were identical
    batch_norm = True
    lr_scheduler = torch.optim.lr_scheduler.StepLR
    step_config = {
        "step_size" : 30,
        "gamma" : 0.1
    }
    optim_config = {
        "lr": 1e-2,
        "weight_decay": None
    }
elif dataset == "FF":
    epochs = 1500
    hidden_dim = 200
    z_dim = 2
    beta = 1
    batch_norm = True
    lr_scheduler = torch.optim.lr_scheduler.StepLR
    step_config = {
        "step_size" : 300,
        "gamma" : 0.1
    }
    optim_config = {
        "lr": 1e-2,
        "weight_decay": None
    }
else:
    # Fail here with a clear message rather than with a NameError in a later cell.
    raise ValueError(
        "Unknown dataset '{}': expected one of 'MNIST', 'LFW' or 'FF'".format(dataset))

In [4]:
# Wire the pipeline together from the settings chosen above:
# output directories -> data loader -> VAE model -> solver.
directories = Directories(MODEL_NAME, dataset, z_dim)
data_loader = DataLoader(directories, batch_size, dataset)
model = Vae(data_loader.input_dim, hidden_dim, z_dim, batch_norm)
solver = Solver(
    model,
    data_loader,
    optimizer,
    z_dim,
    epochs,
    beta,
    step_config,
    optim_config,
    lr_scheduler=lr_scheduler,
)
# Run the solver; the output logged below reports train/test loss and timing per epoch.
solver.main()

+++++ START RUN | saved files in vae/FF_z=2_8 +++++
====> Epoch: 1 train set loss avg: 398.8467
====> Test set loss avg: 518.7070
0.80 seconds for epoch 1
====> Epoch: 2 train set loss avg: 382.9495
====> Test set loss avg: 390.6739
0.50 seconds for epoch 2
====> Epoch: 3 train set loss avg: 373.3598
====> Test set loss avg: 369.8354
0.49 seconds for epoch 3
====> Epoch: 4 train set loss avg: 366.8275
====> Test set loss avg: 362.6427
0.48 seconds for epoch 4
====> Epoch: 5 train set loss avg: 362.2196
====> Test set loss avg: 355.9145
0.48 seconds for epoch 5
====> Epoch: 6 train set loss avg: 358.8214
====> Test set loss avg: 352.2199
0.47 seconds for epoch 6
====> Epoch: 7 train set loss avg: 356.4108
====> Test set loss avg: 350.1155
0.49 seconds for epoch 7
====> Epoch: 8 train set loss avg: 354.6864
====> Test set loss avg: 348.0011
0.60 seconds for epoch 8
====> Epoch: 9 train set loss avg: 353.3590
====> Test set loss avg: 346.5825
0.51 seconds for epoch 9
====> Epoch: 10 train

====> Epoch: 79 train set loss avg: 347.3156
====> Test set loss avg: 346.0152
0.57 seconds for epoch 79
====> Epoch: 80 train set loss avg: 347.3455
====> Test set loss avg: 345.1803
0.63 seconds for epoch 80
====> Epoch: 81 train set loss avg: 347.3257
====> Test set loss avg: 344.3811
0.65 seconds for epoch 81
====> Epoch: 82 train set loss avg: 347.3947
====> Test set loss avg: 344.7350
0.53 seconds for epoch 82
====> Epoch: 83 train set loss avg: 347.3260
====> Test set loss avg: 347.6412
0.53 seconds for epoch 83
====> Epoch: 84 train set loss avg: 347.4352
====> Test set loss avg: 343.7434
0.55 seconds for epoch 84
====> Epoch: 85 train set loss avg: 347.3822
====> Test set loss avg: 344.3548
0.54 seconds for epoch 85
====> Epoch: 86 train set loss avg: 347.3971
====> Test set loss avg: 344.5550
0.53 seconds for epoch 86
====> Epoch: 87 train set loss avg: 347.3850
====> Test set loss avg: 344.9804
0.54 seconds for epoch 87
====> Epoch: 88 train set loss avg: 347.5168
====> Test

====> Epoch: 156 train set loss avg: 347.2388
====> Test set loss avg: 345.3148
0.63 seconds for epoch 156
====> Epoch: 157 train set loss avg: 347.2555
====> Test set loss avg: 345.9615
0.84 seconds for epoch 157
====> Epoch: 158 train set loss avg: 347.2526
====> Test set loss avg: 344.7004
0.75 seconds for epoch 158
====> Epoch: 159 train set loss avg: 347.1441
====> Test set loss avg: 346.6191
0.64 seconds for epoch 159
====> Epoch: 160 train set loss avg: 346.9707
====> Test set loss avg: 344.4359
0.75 seconds for epoch 160
====> Epoch: 161 train set loss avg: 347.1112
====> Test set loss avg: 346.1953
0.62 seconds for epoch 161
====> Epoch: 162 train set loss avg: 347.2965
====> Test set loss avg: 345.9316
0.63 seconds for epoch 162
====> Epoch: 163 train set loss avg: 347.2223
====> Test set loss avg: 345.2987
0.62 seconds for epoch 163
====> Epoch: 164 train set loss avg: 347.1356
====> Test set loss avg: 344.9417
0.63 seconds for epoch 164
====> Epoch: 165 train set loss avg: 

====> Epoch: 233 train set loss avg: 347.2057
====> Test set loss avg: 345.9751
0.74 seconds for epoch 233
====> Epoch: 234 train set loss avg: 347.0901
====> Test set loss avg: 346.4258
0.86 seconds for epoch 234
====> Epoch: 235 train set loss avg: 347.0907
====> Test set loss avg: 346.5020
0.79 seconds for epoch 235
====> Epoch: 236 train set loss avg: 347.1595
====> Test set loss avg: 346.3227
0.77 seconds for epoch 236
====> Epoch: 237 train set loss avg: 346.9691
====> Test set loss avg: 344.7961
0.73 seconds for epoch 237
====> Epoch: 238 train set loss avg: 347.0742
====> Test set loss avg: 345.2456
0.72 seconds for epoch 238
====> Epoch: 239 train set loss avg: 347.1862
====> Test set loss avg: 344.9291
0.72 seconds for epoch 239
====> Epoch: 240 train set loss avg: 347.1348
====> Test set loss avg: 345.3102
0.74 seconds for epoch 240
====> Epoch: 241 train set loss avg: 347.1241
====> Test set loss avg: 344.7534
0.74 seconds for epoch 241
====> Epoch: 242 train set loss avg: 

====> Epoch: 310 train set loss avg: 346.6792
====> Test set loss avg: 345.5456
0.84 seconds for epoch 310
====> Epoch: 311 train set loss avg: 346.6534
====> Test set loss avg: 345.6719
0.89 seconds for epoch 311
====> Epoch: 312 train set loss avg: 346.7506
====> Test set loss avg: 345.8716
0.87 seconds for epoch 312
====> Epoch: 313 train set loss avg: 346.9054
====> Test set loss avg: 345.9931
0.86 seconds for epoch 313
====> Epoch: 314 train set loss avg: 346.9145
====> Test set loss avg: 345.9256
0.87 seconds for epoch 314
====> Epoch: 315 train set loss avg: 346.7912
====> Test set loss avg: 345.7046
0.90 seconds for epoch 315
====> Epoch: 316 train set loss avg: 346.7750
====> Test set loss avg: 345.4841
0.83 seconds for epoch 316
====> Epoch: 317 train set loss avg: 346.7621
====> Test set loss avg: 345.5594
0.84 seconds for epoch 317
====> Epoch: 318 train set loss avg: 346.7957
====> Test set loss avg: 345.6527
0.83 seconds for epoch 318
====> Epoch: 319 train set loss avg: 

====> Epoch: 387 train set loss avg: 346.7526
====> Test set loss avg: 345.7603
0.90 seconds for epoch 387
====> Epoch: 388 train set loss avg: 346.8232
====> Test set loss avg: 346.0781
0.91 seconds for epoch 388
====> Epoch: 389 train set loss avg: 346.8500
====> Test set loss avg: 346.2636
0.90 seconds for epoch 389
====> Epoch: 390 train set loss avg: 346.8326
====> Test set loss avg: 346.2324
0.91 seconds for epoch 390
====> Epoch: 391 train set loss avg: 346.7493
====> Test set loss avg: 345.7602
0.91 seconds for epoch 391
====> Epoch: 392 train set loss avg: 346.6581
====> Test set loss avg: 345.6522
0.93 seconds for epoch 392
====> Epoch: 393 train set loss avg: 346.6915
====> Test set loss avg: 345.8542
1.07 seconds for epoch 393
====> Epoch: 394 train set loss avg: 346.6551
====> Test set loss avg: 345.8440
0.91 seconds for epoch 394
====> Epoch: 395 train set loss avg: 346.6109
====> Test set loss avg: 345.7378
0.86 seconds for epoch 395
====> Epoch: 396 train set loss avg: 

====> Epoch: 464 train set loss avg: 346.7069
====> Test set loss avg: 346.1276
0.89 seconds for epoch 464
====> Epoch: 465 train set loss avg: 346.7950
====> Test set loss avg: 346.2082
0.88 seconds for epoch 465
====> Epoch: 466 train set loss avg: 346.8481
====> Test set loss avg: 346.1253
0.91 seconds for epoch 466
====> Epoch: 467 train set loss avg: 346.6526
====> Test set loss avg: 345.7796
0.93 seconds for epoch 467
====> Epoch: 468 train set loss avg: 346.7414
====> Test set loss avg: 346.0855
0.89 seconds for epoch 468
====> Epoch: 469 train set loss avg: 346.7279
====> Test set loss avg: 346.2694
0.88 seconds for epoch 469
====> Epoch: 470 train set loss avg: 346.7885
====> Test set loss avg: 346.4020
0.88 seconds for epoch 470
====> Epoch: 471 train set loss avg: 346.6992
====> Test set loss avg: 346.2384
0.96 seconds for epoch 471
====> Epoch: 472 train set loss avg: 346.8388
====> Test set loss avg: 346.2150
0.88 seconds for epoch 472
====> Epoch: 473 train set loss avg: 

====> Epoch: 541 train set loss avg: 346.6477
====> Test set loss avg: 346.1325
0.99 seconds for epoch 541
====> Epoch: 542 train set loss avg: 346.6600
====> Test set loss avg: 346.1423
0.96 seconds for epoch 542
====> Epoch: 543 train set loss avg: 346.7720
====> Test set loss avg: 346.3196
1.13 seconds for epoch 543
====> Epoch: 544 train set loss avg: 346.7486
====> Test set loss avg: 346.4977
1.10 seconds for epoch 544
====> Epoch: 545 train set loss avg: 346.8044
====> Test set loss avg: 346.5201
0.95 seconds for epoch 545
====> Epoch: 546 train set loss avg: 346.7718
====> Test set loss avg: 346.2849
0.94 seconds for epoch 546
====> Epoch: 547 train set loss avg: 346.7293
====> Test set loss avg: 346.3902
1.15 seconds for epoch 547
====> Epoch: 548 train set loss avg: 346.7045
====> Test set loss avg: 346.0755
1.05 seconds for epoch 548
====> Epoch: 549 train set loss avg: 346.7754
====> Test set loss avg: 346.2267
1.17 seconds for epoch 549
====> Epoch: 550 train set loss avg: 

====> Epoch: 618 train set loss avg: 346.6450
====> Test set loss avg: 346.1478
0.95 seconds for epoch 618
====> Epoch: 619 train set loss avg: 346.7650
====> Test set loss avg: 346.3915
0.94 seconds for epoch 619
====> Epoch: 620 train set loss avg: 346.6545
====> Test set loss avg: 346.3679
0.93 seconds for epoch 620
====> Epoch: 621 train set loss avg: 346.5942
====> Test set loss avg: 346.3957
1.16 seconds for epoch 621
====> Epoch: 622 train set loss avg: 346.6747
====> Test set loss avg: 346.2793
0.93 seconds for epoch 622
====> Epoch: 623 train set loss avg: 346.6345
====> Test set loss avg: 346.1736
1.00 seconds for epoch 623
====> Epoch: 624 train set loss avg: 346.5445
====> Test set loss avg: 346.3102
0.96 seconds for epoch 624
====> Epoch: 625 train set loss avg: 346.6175
====> Test set loss avg: 346.3724
0.97 seconds for epoch 625
====> Epoch: 626 train set loss avg: 346.6208
====> Test set loss avg: 346.2337
0.93 seconds for epoch 626
====> Epoch: 627 train set loss avg: 

====> Epoch: 695 train set loss avg: 346.7527
====> Test set loss avg: 346.3749
0.94 seconds for epoch 695
====> Epoch: 696 train set loss avg: 346.7478
====> Test set loss avg: 346.3137
0.93 seconds for epoch 696
====> Epoch: 697 train set loss avg: 346.5778
====> Test set loss avg: 346.2292
0.96 seconds for epoch 697
====> Epoch: 698 train set loss avg: 346.6153
====> Test set loss avg: 346.3706
0.96 seconds for epoch 698
====> Epoch: 699 train set loss avg: 346.6423
====> Test set loss avg: 346.4571
0.95 seconds for epoch 699
====> Epoch: 700 train set loss avg: 346.7181
====> Test set loss avg: 346.3874
1.14 seconds for epoch 700
====> Epoch: 701 train set loss avg: 346.6309
====> Test set loss avg: 346.2555
1.05 seconds for epoch 701
====> Epoch: 702 train set loss avg: 346.6448
====> Test set loss avg: 346.0799
1.03 seconds for epoch 702
====> Epoch: 703 train set loss avg: 346.7352
====> Test set loss avg: 346.2344
1.00 seconds for epoch 703
====> Epoch: 704 train set loss avg: 

====> Epoch: 772 train set loss avg: 346.7015
====> Test set loss avg: 346.2677
1.09 seconds for epoch 772
====> Epoch: 773 train set loss avg: 346.6809
====> Test set loss avg: 346.5389
1.11 seconds for epoch 773
====> Epoch: 774 train set loss avg: 346.6453
====> Test set loss avg: 346.2631
1.14 seconds for epoch 774
====> Epoch: 775 train set loss avg: 346.7059
====> Test set loss avg: 346.2054
1.13 seconds for epoch 775
====> Epoch: 776 train set loss avg: 346.5261
====> Test set loss avg: 346.1717
1.04 seconds for epoch 776
====> Epoch: 777 train set loss avg: 346.7098
====> Test set loss avg: 346.3194
1.10 seconds for epoch 777
====> Epoch: 778 train set loss avg: 346.6955
====> Test set loss avg: 346.1202
1.04 seconds for epoch 778
====> Epoch: 779 train set loss avg: 346.6251
====> Test set loss avg: 346.1556
1.36 seconds for epoch 779
====> Epoch: 780 train set loss avg: 346.5839
====> Test set loss avg: 346.1960
1.39 seconds for epoch 780
====> Epoch: 781 train set loss avg: 

====> Epoch: 849 train set loss avg: 346.7206
====> Test set loss avg: 346.3219
1.17 seconds for epoch 849
====> Epoch: 850 train set loss avg: 346.5197
====> Test set loss avg: 346.2785
1.14 seconds for epoch 850
====> Epoch: 851 train set loss avg: 346.6222
====> Test set loss avg: 346.2696
1.03 seconds for epoch 851
====> Epoch: 852 train set loss avg: 346.6697
====> Test set loss avg: 346.2206
1.12 seconds for epoch 852
====> Epoch: 853 train set loss avg: 346.5819
====> Test set loss avg: 346.2437
1.19 seconds for epoch 853
====> Epoch: 854 train set loss avg: 346.6054
====> Test set loss avg: 346.2719
1.11 seconds for epoch 854
====> Epoch: 855 train set loss avg: 346.7596
====> Test set loss avg: 346.3006
1.38 seconds for epoch 855
====> Epoch: 856 train set loss avg: 346.7696
====> Test set loss avg: 346.3463
1.53 seconds for epoch 856
====> Epoch: 857 train set loss avg: 346.7224
====> Test set loss avg: 346.2504
1.31 seconds for epoch 857
====> Epoch: 858 train set loss avg: 

1.22 seconds for epoch 925
====> Epoch: 926 train set loss avg: 346.6679
====> Test set loss avg: 346.2914
1.12 seconds for epoch 926
====> Epoch: 927 train set loss avg: 346.6752
====> Test set loss avg: 346.1983
1.34 seconds for epoch 927
====> Epoch: 928 train set loss avg: 346.6621
====> Test set loss avg: 346.2948
1.17 seconds for epoch 928
====> Epoch: 929 train set loss avg: 346.6146
====> Test set loss avg: 346.3648
1.62 seconds for epoch 929
====> Epoch: 930 train set loss avg: 346.7105
====> Test set loss avg: 346.2662
1.71 seconds for epoch 930
====> Epoch: 931 train set loss avg: 346.6548
====> Test set loss avg: 346.2579
1.13 seconds for epoch 931
====> Epoch: 932 train set loss avg: 346.6925
====> Test set loss avg: 346.2098
1.36 seconds for epoch 932
====> Epoch: 933 train set loss avg: 346.6517
====> Test set loss avg: 346.2232
1.45 seconds for epoch 933
====> Epoch: 934 train set loss avg: 346.6099
====> Test set loss avg: 346.2729
1.19 seconds for epoch 934
====> Epoc

====> Epoch: 1003 train set loss avg: 346.6177
====> Test set loss avg: 346.2748
1.02 seconds for epoch 1003
====> Epoch: 1004 train set loss avg: 346.6510
====> Test set loss avg: 346.3557
1.00 seconds for epoch 1004
====> Epoch: 1005 train set loss avg: 346.7594
====> Test set loss avg: 346.1955
0.98 seconds for epoch 1005
====> Epoch: 1006 train set loss avg: 346.7077
====> Test set loss avg: 346.4008
0.98 seconds for epoch 1006
====> Epoch: 1007 train set loss avg: 346.7187
====> Test set loss avg: 346.4316
1.00 seconds for epoch 1007
====> Epoch: 1008 train set loss avg: 346.5886
====> Test set loss avg: 346.4169
1.05 seconds for epoch 1008
====> Epoch: 1009 train set loss avg: 346.6167
====> Test set loss avg: 346.2357
0.99 seconds for epoch 1009
====> Epoch: 1010 train set loss avg: 346.6592
====> Test set loss avg: 346.6006
1.04 seconds for epoch 1010
====> Epoch: 1011 train set loss avg: 346.6446
====> Test set loss avg: 346.1998
0.99 seconds for epoch 1011
====> Epoch: 1012 t

====> Epoch: 1079 train set loss avg: 346.8419
====> Test set loss avg: 346.2088
0.95 seconds for epoch 1079
====> Epoch: 1080 train set loss avg: 346.5810
====> Test set loss avg: 346.2391
0.94 seconds for epoch 1080
====> Epoch: 1081 train set loss avg: 346.6777
====> Test set loss avg: 346.2870
0.94 seconds for epoch 1081
====> Epoch: 1082 train set loss avg: 346.7110
====> Test set loss avg: 346.2441
0.93 seconds for epoch 1082
====> Epoch: 1083 train set loss avg: 346.6364
====> Test set loss avg: 346.3740
1.00 seconds for epoch 1083
====> Epoch: 1084 train set loss avg: 346.5931
====> Test set loss avg: 346.3525
0.95 seconds for epoch 1084
====> Epoch: 1085 train set loss avg: 346.6375
====> Test set loss avg: 346.4601
0.94 seconds for epoch 1085
====> Epoch: 1086 train set loss avg: 346.5981
====> Test set loss avg: 346.2516
0.95 seconds for epoch 1086
====> Epoch: 1087 train set loss avg: 346.6655
====> Test set loss avg: 346.2100
0.95 seconds for epoch 1087
====> Epoch: 1088 t

====> Epoch: 1155 train set loss avg: 346.7177
====> Test set loss avg: 346.2998
0.96 seconds for epoch 1155
====> Epoch: 1156 train set loss avg: 346.7046
====> Test set loss avg: 346.2952
0.94 seconds for epoch 1156
====> Epoch: 1157 train set loss avg: 346.6429
====> Test set loss avg: 346.4127
0.95 seconds for epoch 1157
====> Epoch: 1158 train set loss avg: 346.8859
====> Test set loss avg: 346.3960
0.95 seconds for epoch 1158
====> Epoch: 1159 train set loss avg: 346.6360
====> Test set loss avg: 346.3358
0.95 seconds for epoch 1159
====> Epoch: 1160 train set loss avg: 346.6446
====> Test set loss avg: 346.1394
0.95 seconds for epoch 1160
====> Epoch: 1161 train set loss avg: 346.7019
====> Test set loss avg: 346.3262
0.95 seconds for epoch 1161
====> Epoch: 1162 train set loss avg: 346.6598
====> Test set loss avg: 346.3318
0.96 seconds for epoch 1162
====> Epoch: 1163 train set loss avg: 346.5181
====> Test set loss avg: 346.1714
0.96 seconds for epoch 1163
====> Epoch: 1164 t

====> Epoch: 1231 train set loss avg: 346.6506
====> Test set loss avg: 346.4222
0.96 seconds for epoch 1231
====> Epoch: 1232 train set loss avg: 346.7350
====> Test set loss avg: 346.3962
0.97 seconds for epoch 1232
====> Epoch: 1233 train set loss avg: 346.7940
====> Test set loss avg: 346.2930
0.96 seconds for epoch 1233
====> Epoch: 1234 train set loss avg: 346.7004
====> Test set loss avg: 346.3628
0.97 seconds for epoch 1234
====> Epoch: 1235 train set loss avg: 346.7875
====> Test set loss avg: 346.2978
0.95 seconds for epoch 1235
====> Epoch: 1236 train set loss avg: 346.6480
====> Test set loss avg: 346.2862
0.96 seconds for epoch 1236
====> Epoch: 1237 train set loss avg: 346.5647
====> Test set loss avg: 346.3595
0.96 seconds for epoch 1237
====> Epoch: 1238 train set loss avg: 346.7049
====> Test set loss avg: 346.2385
0.96 seconds for epoch 1238
====> Epoch: 1239 train set loss avg: 346.6365
====> Test set loss avg: 346.3275
0.97 seconds for epoch 1239
====> Epoch: 1240 t

====> Epoch: 1307 train set loss avg: 346.6938
====> Test set loss avg: 346.2925
0.98 seconds for epoch 1307
====> Epoch: 1308 train set loss avg: 346.6982
====> Test set loss avg: 346.2983
0.97 seconds for epoch 1308
====> Epoch: 1309 train set loss avg: 346.5582
====> Test set loss avg: 346.4678
0.99 seconds for epoch 1309
====> Epoch: 1310 train set loss avg: 346.5239
====> Test set loss avg: 346.2714
0.99 seconds for epoch 1310
====> Epoch: 1311 train set loss avg: 346.6369
====> Test set loss avg: 346.3650
0.97 seconds for epoch 1311
====> Epoch: 1312 train set loss avg: 346.7885
====> Test set loss avg: 346.3553
0.98 seconds for epoch 1312
====> Epoch: 1313 train set loss avg: 346.6555
====> Test set loss avg: 346.2748
0.98 seconds for epoch 1313
====> Epoch: 1314 train set loss avg: 346.7552
====> Test set loss avg: 346.2819
0.98 seconds for epoch 1314
====> Epoch: 1315 train set loss avg: 346.5757
====> Test set loss avg: 346.3602
1.03 seconds for epoch 1315
====> Epoch: 1316 t

====> Epoch: 1383 train set loss avg: 346.6760
====> Test set loss avg: 346.1910
1.17 seconds for epoch 1383
====> Epoch: 1384 train set loss avg: 346.6220
====> Test set loss avg: 346.3395
1.13 seconds for epoch 1384
====> Epoch: 1385 train set loss avg: 346.7214
====> Test set loss avg: 346.4444
1.12 seconds for epoch 1385
====> Epoch: 1386 train set loss avg: 346.6497
====> Test set loss avg: 346.3228
1.18 seconds for epoch 1386
====> Epoch: 1387 train set loss avg: 346.7011
====> Test set loss avg: 346.3929
1.23 seconds for epoch 1387
====> Epoch: 1388 train set loss avg: 346.7557
====> Test set loss avg: 346.2857
1.15 seconds for epoch 1388
====> Epoch: 1389 train set loss avg: 346.6461
====> Test set loss avg: 346.2695
1.07 seconds for epoch 1389
====> Epoch: 1390 train set loss avg: 346.6702
====> Test set loss avg: 346.2541
1.25 seconds for epoch 1390
====> Epoch: 1391 train set loss avg: 346.7129
====> Test set loss avg: 346.3433
1.14 seconds for epoch 1391
====> Epoch: 1392 t

====> Epoch: 1459 train set loss avg: 346.6234
====> Test set loss avg: 346.4313
0.99 seconds for epoch 1459
====> Epoch: 1460 train set loss avg: 346.7386
====> Test set loss avg: 346.3679
1.04 seconds for epoch 1460
====> Epoch: 1461 train set loss avg: 346.5696
====> Test set loss avg: 346.4057
0.99 seconds for epoch 1461
====> Epoch: 1462 train set loss avg: 346.5456
====> Test set loss avg: 346.2527
0.98 seconds for epoch 1462
====> Epoch: 1463 train set loss avg: 346.6885
====> Test set loss avg: 346.4356
1.02 seconds for epoch 1463
====> Epoch: 1464 train set loss avg: 346.6576
====> Test set loss avg: 346.2875
0.99 seconds for epoch 1464
====> Epoch: 1465 train set loss avg: 346.7046
====> Test set loss avg: 346.2109
0.99 seconds for epoch 1465
====> Epoch: 1466 train set loss avg: 346.4751
====> Test set loss avg: 346.3108
0.98 seconds for epoch 1466
====> Epoch: 1467 train set loss avg: 346.7162
====> Test set loss avg: 346.2840
0.97 seconds for epoch 1467
====> Epoch: 1468 t

In [None]:
# Insert the name of a saved model here if you want to load one, e.g. solver.save_model_dir + "/VAE_MNIST_train_loss=151.39_z=2.pt"
#solver = torch.load(solver.save_model_dir + "/VAE_MNIST_train_loss=97.15_z=20.pt")
#solver.model.eval()

In [None]:
# Plot the train and test loss curves over all epochs
train_losses = solver.train_loss_history["train_loss_acc"]
test_losses = solver.test_loss_history
plot_losses(solver, train_losses, test_losses)

In [None]:
# Plot the gaussian of the z space together with some metrics about that space.
# The second argument is the number of recorded training-loss entries.
n_train_loss_entries = len(solver.train_loss_history["train_loss_acc"])
plot_gaussian_distributions(solver, n_train_loss_entries)

In [None]:
# Monitor the reconstruction loss (likelihood lower bound) and the KL divergence.
# With DEBUG enabled, one summary line per recorded epoch is printed before plotting.
DEBUG = 1
history = solver.train_loss_history
if DEBUG:
    summary_template = "epoch: {}, train_loss: {:.2f}, test_loss: {:.2f}, recon. loss: {:.2f}, KL div.: {:.2f}"
    # Walk the per-epoch histories in lockstep; zip stops at the shortest one.
    records = zip(
        history["epochs"],
        history["train_loss_acc"],
        solver.test_loss_history,
        history["recon_loss_acc"],
        history["kl_diverg_acc"],
    )
    for epoch, train_loss, test_loss, rl, kl in records:
        print(summary_template.format(epoch, train_loss, test_loss, rl, kl))
        # Absolute gap between test and train loss, reported as "overfitting"
        print("overfitting: {:.2f}".format(abs(test_loss-train_loss)))
plot_rl_kl(solver, history["recon_loss_acc"], history["kl_diverg_acc"])

In [None]:
# Visualize q(z) (the latent space z); only possible for a two-dimensional z.
if solver.z_dim != 2:
    print("Plot of latent space not possible as dimension of z is not 2")
elif solver.data_loader.with_labels:
    # Labelled data: pass the labels along so the plot can be titled "classes"
    plot_latent_space(solver, solver.z_space, var="z", title="classes", labels=solver.data_labels)
else:
    plot_latent_space(solver, solver.z_space, var="z")

In [None]:
# Visualize the learned data manifold for generative models with a two-dimensional latent space.
if solver.z_dim != 2:
    print("Plot is not possible as dimension of z is not 2")
else:
    dataset_name = solver.data_loader.dataset
    if dataset_name == "MNIST":
        # n grid points per axis on [-3, 3]
        n = 20
        grid_x = grid_y = np.linspace(-3, 3, n)
        plot_latent_manifold(solver, "bone", grid_x, grid_y, n)
    if dataset_name in ("LFW", "FF"):
        # Face datasets use a coarser grid and an explicit figure size
        n = 10
        grid_x = grid_y = np.linspace(-3, 3, n)
        plot_latent_manifold(solver, "gray", grid_x, grid_y, n, fig_size=(10, 8))

In [None]:
# Plot real faces and sampled faces as grids (face datasets only).
if solver.data_loader.dataset in ("LFW", "FF"):
    n_cols = 15
    n = 225
    plot_faces_grid(n, n_cols, solver)
    plot_faces_samples_grid(n, n_cols, solver)

In [None]:
# Save the whole solver object to the result directory. The file name encodes the
# dataset, the last recorded training loss (two decimals) and the latent dimension.
last_train_loss = solver.train_loss_history["train_loss_acc"][-1]
model_filename = "model_VAE_{}_train_loss={:.2f}_z={}.pt".format(
    solver.data_loader.dataset, last_train_loss, solver.z_dim)
# os.path.join replaces the manual "/" concatenation used previously.
torch.save(solver, os.path.join(solver.data_loader.directories.result_dir, model_filename))