In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, MNIST, CIFAR100
from torchvision.transforms import ToTensor
from sklearn.manifold import TSNE

  from .autonotebook import tqdm as notebook_tqdm


## Analiza działania modelu


In [6]:

# Fixed seed so the train/validation split is reproducible across kernel restarts.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

train = CIFAR10('./data', train=True, download=True)
# train.data is the raw uint8 image array (N, 32, 32, 3); cast to float for the VAE.
# NOTE(review): float64 costs ~1.2 GB for 50k images — consider float32 if the model accepts it.
train_data, val_data = torch.utils.data.random_split(
    train.data.astype(float), [40000, 10000], generator=split_generator
)
test_data = CIFAR10('./data', train=False, download=True)

training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
# test_data has no transform, so its items are PIL images that the default collate
# function cannot batch. Iterate over its raw array instead, which matches the
# representation the training/validation loaders use.
test_loader = DataLoader(test_data.data.astype(float), batch_size=64, shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


In [19]:
from models.GMMPrior import GMMPrior
from models.VAE import Encoder, Decoder, VAE
from src.utils import training, plot_curve
from models.GaussianPrior import StandardPrior
from src.utils import VAEAnalyzer

def analyze(prior_name='gmm', hid_dim=128, lr=1e-3, lat_features=10, max_patience=5, optim_type='adamax',
            num_components=4**2, num_epochs=50, val_indices=None, analysis_transform=None):
    """Train a VAE on the global CIFAR-10 loaders and score its latent space.

    Training uses the module-level ``training_loader`` / ``val_loader``. Analysis
    uses the CIFAR-10 training images selected by ``val_indices``, which defaults
    to the indices of the module-level ``val_data`` split.

    Args:
        prior_name: 'gmm' selects GMMPrior; any other value selects StandardPrior.
        hid_dim: hidden-layer width passed to Encoder and Decoder.
        lr: optimizer learning rate.
        lat_features: latent dimensionality.
        max_patience: patience forwarded to ``training`` (presumably for early stopping).
        optim_type: 'adam' selects Adam; any other value selects Adamax.
        num_components: number of GMM components. It is only used by the GMM
            prior, but it is always part of the run name.
        num_epochs: maximum number of training epochs.
        val_indices: indices into the CIFAR-10 training set to analyse. The default
            reuses ``val_data.indices``, so the analysis never sees images the model
            was trained on.
        analysis_transform: transform for the analysis dataset. Defaults to ToTensor().

    Returns:
        The pair returned by ``VAEAnalyzer.get_metrics()``. The values are named
        ``sil`` / ``dav``, presumably the silhouette score and the Davies-Bouldin index.
    """
    in_features = 32 * 32 * 3  # flattened CIFAR-10 image size

    if prior_name == 'gmm':
        prior = GMMPrior(lat_features, num_components)
    else:
        prior = StandardPrior(L=lat_features)
    encoder = Encoder(n_input_features=in_features, n_hidden_neurons=hid_dim, n_latent_features=lat_features)
    decoder = Decoder(n_hidden_neurons=hid_dim, n_latent_features=lat_features, n_output_features=in_features)
    vae = VAE(encoder=encoder, decoder=decoder, prior=prior)

    if optim_type == 'adam':
        optimizer = torch.optim.Adam(vae.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adamax(vae.parameters(), lr=lr)

    name = f'vae_{prior_name}_{num_components}_{lat_features}_{hid_dim}_{max_patience}_{optim_type}'
    result_dir = 'results/' + name + '/'
    # makedirs also creates the parent 'results/' on a fresh checkout; os.mkdir would fail there.
    os.makedirs(result_dir, exist_ok=True)

    nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=vae,
                       optimizer=optimizer, training_loader=training_loader, val_loader=val_loader)
    plot_curve(nll_val=nll_val, name=result_dir + name)
    # NOTE(review): training() prints "saved!" for improving epochs. Whether `vae`
    # holds those best weights (rather than the last epoch's) is not visible here — TODO confirm.

    if val_indices is None:
        # Reuse the exact validation indices from the data-loading cell. A fresh
        # random_split here would draw a different permutation and mix training
        # images into the evaluation set.
        val_indices = val_data.indices
    if analysis_transform is None:
        analysis_transform = ToTensor()
    # NOTE(review): the training loaders yield raw float arrays (HWC, 0-255), while
    # ToTensor() yields CHW tensors scaled to [0, 1]. Confirm that the model and
    # VAEAnalyzer handle both, or pass a matching `analysis_transform`.
    cifar_train = CIFAR10('./data', train=True, download=True, transform=analysis_transform)
    analysis_data = torch.utils.data.Subset(cifar_train, val_indices)

    analyzer = VAEAnalyzer(model=vae, dataset=analysis_data, n_samplings=1)
    analyzer._retrieve_reconstructions()
    sil, dav = analyzer.get_metrics()
    return sil, dav





In [17]:

# Prior comparison: any prior_name other than 'gmm' trains with StandardPrior.
priors = ['gmm', 'standard']
# Filled inside the loop so partial results survive an interrupted run.
prior_metrics = {}
for prior in priors:
    prior_metrics[prior] = analyze(prior_name=prior)
prior_metrics



Epoch: 0, val nll=-365400.50420776586
saved!
Epoch: 1, val nll=-365417.4485550267
saved!
Epoch: 2, val nll=-365428.10960796685
saved!
Epoch: 3, val nll=-365432.35110249027
saved!
Epoch: 4, val nll=-365439.5002988087
saved!
Epoch: 5, val nll=-365445.4116087231
saved!
Epoch: 6, val nll=-365445.1865387744
Epoch: 7, val nll=-365453.42471001507
saved!
Epoch: 8, val nll=-365455.210897406
saved!
Epoch: 9, val nll=-365456.02645176376
saved!
Epoch: 10, val nll=-365458.9413004676
saved!
Epoch: 11, val nll=-365459.6469614624
saved!
Epoch: 12, val nll=-365460.717474936
saved!
Epoch: 13, val nll=-365462.1363500895
saved!
Epoch: 14, val nll=-365462.65165666654
saved!
Epoch: 15, val nll=-365462.97604232625
saved!
Epoch: 16, val nll=-365463.1432030609
saved!
Epoch: 17, val nll=-365463.08470499254
Epoch: 18, val nll=-365463.86579904455
saved!
Epoch: 19, val nll=-365463.917910085
saved!
Epoch: 20, val nll=-365464.3682982973
saved!
Epoch: 21, val nll=-365465.03875804826
saved!
Epoch: 22, val nll=-365465.

In [20]:
# Optimizer comparison with the (default) GMM prior.
optims = ['adam', 'adamax']
# Filled inside the loop so partial results survive an interrupted run.
optim_metrics = {}
for opt in optims:
    optim_metrics[opt] = analyze(optim_type=opt)
optim_metrics

Epoch: 0, val nll=-365393.7816345271
saved!
Epoch: 1, val nll=-365418.11860488314
saved!
Epoch: 2, val nll=-365442.31973475945
saved!
Epoch: 3, val nll=-365449.6507096431
saved!
Epoch: 4, val nll=-365455.125542443
saved!
Epoch: 5, val nll=-365458.185442523
saved!
Epoch: 6, val nll=-365452.0345740907
Epoch: 7, val nll=-365459.6649279623
saved!
Epoch: 8, val nll=-365460.3751367239
saved!
Epoch: 9, val nll=-365461.6175499301
saved!
Epoch: 10, val nll=-365463.3950219687
saved!
Epoch: 11, val nll=-365464.88438596926
saved!
Epoch: 12, val nll=-365465.7939912447
saved!
Epoch: 13, val nll=-365466.42011839116
saved!
Epoch: 14, val nll=-365466.58551939984
saved!
Epoch: 15, val nll=-365462.1067409418
Epoch: 16, val nll=-365465.05429334915
Epoch: 17, val nll=-365465.71840509627
Epoch: 18, val nll=-365466.0572771459
Epoch: 19, val nll=-365463.400490008
Epoch: 20, val nll=-365466.4251177083
Files already downloaded and verified
Silhouette score: -0.0004800753958988935
Davies Bouldin Index: 45.602748

In [21]:
# Hidden-layer width comparison (GMM prior, other settings at their defaults).
hid_dims = [64, 128, 256]
# Filled inside the loop so partial results survive an interrupted run.
hid_dim_metrics = {}
for dim in hid_dims:
    hid_dim_metrics[dim] = analyze(hid_dim=dim)
hid_dim_metrics



Epoch: 0, val nll=-362732.60391843366
saved!
Epoch: 1, val nll=-364725.6788017733
saved!
Epoch: 2, val nll=-365075.7321015392
saved!
Epoch: 3, val nll=-365200.5589430454
saved!
Epoch: 4, val nll=-365323.8375849147
saved!
Epoch: 5, val nll=-365329.77009407274
saved!
Epoch: 6, val nll=-365333.9545794487
saved!
Epoch: 7, val nll=-365335.24197962944
saved!
Epoch: 8, val nll=-365449.5635685859
saved!
Epoch: 9, val nll=-365456.31647424813
saved!
Epoch: 10, val nll=-365457.69426584966
saved!
Epoch: 11, val nll=-365459.2557355795
saved!
Epoch: 12, val nll=-365456.4332432104
Epoch: 13, val nll=-365460.65011695493
saved!
Epoch: 14, val nll=-365461.1260278342
saved!
Epoch: 15, val nll=-365459.66492600396
Epoch: 16, val nll=-365461.3158626093
saved!
Epoch: 17, val nll=-365462.28440746426
saved!
Epoch: 18, val nll=-365461.7767736568
Epoch: 19, val nll=-365462.53932010266
saved!
Epoch: 20, val nll=-365464.8666513177
saved!
Epoch: 21, val nll=-365465.034080436
saved!
Epoch: 22, val nll=-365465.136092

In [22]:
# Latent dimensionality comparison (GMM prior, other settings at their defaults).
lat_dims = [2, 4, 10, 20]
# Filled inside the loop so partial results survive an interrupted run.
lat_dim_metrics = {}
for dim in lat_dims:
    lat_dim_metrics[dim] = analyze(lat_features=dim)
lat_dim_metrics

Epoch: 0, val nll=-364840.1103461473
saved!
Epoch: 1, val nll=-365192.9008570006
saved!
Epoch: 2, val nll=-365206.51971520274
saved!
Epoch: 3, val nll=-365343.5583795231
saved!
Epoch: 4, val nll=-365456.59518944804
saved!
Epoch: 5, val nll=-365458.90921677864
saved!
Epoch: 6, val nll=-365460.971769711
saved!
Epoch: 7, val nll=-365462.36509425775
saved!
Epoch: 8, val nll=-365464.4229799238
saved!
Epoch: 9, val nll=-365465.0908277675
saved!
Epoch: 10, val nll=-365465.4326643424
saved!
Epoch: 11, val nll=-365465.96576594916
saved!
Epoch: 12, val nll=-365466.34494492976
saved!
Epoch: 13, val nll=-365458.4256535253
Epoch: 14, val nll=-365466.3579247368
saved!
Epoch: 15, val nll=-365466.6081009941
saved!
Epoch: 16, val nll=-365466.632042262
saved!
Epoch: 17, val nll=-365466.8162312633
saved!
Epoch: 18, val nll=-365466.8978751409
saved!
Epoch: 19, val nll=-365466.93986614124
saved!
Epoch: 20, val nll=-365467.17803661164
saved!
Epoch: 21, val nll=-365467.242734999
saved!
Epoch: 22, val nll=-36

In [32]:
# Single run with a non-GMM prior; any prior_name other than 'gmm' selects StandardPrior.
normal_prior_metrics = analyze(prior_name='normal')
normal_prior_metrics

Epoch: 0, val nll=-364161.54447734525
saved!
Epoch: 1, val nll=-364945.389046538
saved!
Epoch: 2, val nll=-365184.25840648246
saved!
Epoch: 3, val nll=-365275.2683646134
saved!
Epoch: 4, val nll=-365315.8020459612
saved!
Epoch: 5, val nll=-365346.1996019605
saved!
Epoch: 6, val nll=-365347.52604728617
saved!
Epoch: 7, val nll=-365366.4805798652
saved!
Epoch: 8, val nll=-365378.3688869441
saved!
Epoch: 9, val nll=-365379.4710680113
saved!
Epoch: 10, val nll=-365377.98926340166
Epoch: 11, val nll=-365381.98916615726
saved!
Epoch: 12, val nll=-365386.29280197853
saved!
Epoch: 13, val nll=-365383.6569097798
Epoch: 14, val nll=-365362.9379117611
Epoch: 15, val nll=-365326.3774989139
Epoch: 16, val nll=-365425.18653197365
saved!
Epoch: 17, val nll=-365456.08284946193
saved!
Epoch: 18, val nll=-365458.71145744383
saved!
Epoch: 19, val nll=-365460.56635318557
saved!
Epoch: 20, val nll=-365459.67750267964
Epoch: 21, val nll=-365458.30645469564
Epoch: 22, val nll=-365462.83377423947
saved!
Epoch

(-0.0005095536, 45.773779188887985)