In [1]:
import torch
import torch.nn
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torchvision
import torch
from torch import nn, relu
from torchvision import datasets, transforms
from torch.distributions.kl import kl_divergence
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from time import time
from model import LinearVAE, pPCA,DeepNonLinearVAE

In [2]:
class config:
    latent_dim = 200
    data_size = 1000
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(config.device)

def seed_torch(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)

cuda:0


In [3]:
seed_torch(0)
latent_dim = config.latent_dim
data_size = config.data_size
device = config.device

In [4]:
mnist_data = datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
]))
def preprocess(x, eps):
    x = x/ 255.0
    x = eps + (1 - 2 * eps) * x
    x = np.log(x / (1.0 - x))
    x = x.to(torch.float32)
    return x


loader = mnist_data.data[:data_size].view(-1, 784)/255
loader = preprocess(loader,1e-6)
loader_numpy = loader.numpy()
loader = loader.to(device)

In [5]:
# Deep Non Linear VAE
model = DeepNonLinearVAE(784, latent_dim, 784, 1.0, True).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Probability PCA
pcaModel = pPCA(loader_numpy, latent_dim)
W_mle, sigma_mle, loglikelihood = pcaModel.get_result()

elbo_result = []
for i in range(1000000):
    elbo = model(loader)[0]
    optimizer.zero_grad()
    elbo.backward()
    optimizer.step()
#     if np.abs(-elbo.detach().cpu().numpy() - loglikelihood) < 1:
#         print("End Epoch:{}".format(i))

#         break
    if i % 500 == 0:
        elbo_result.append(-elbo)
        print("Epoch:{},     \t exact loglikelihood:{},   \t ELBO:{}".format(i, loglikelihood, -elbo))

Epoch:0,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-24710.447265625
Epoch:500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1759.8199462890625
Epoch:1000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1698.4896240234375
Epoch:1500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1503.6546630859375
Epoch:2000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1450.7763671875
Epoch:2500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1406.395263671875
Epoch:3000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1836.894775390625
Epoch:3500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1545.78466796875
Epoch:4000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1487.7342529296875
Epoch:4500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1451.33544921875
Epoch:5000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-1414.8568115234375
Epoch:5500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-138

Epoch:48000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-61.734527587890625
Epoch:48500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:57.98548889160156
Epoch:49000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:153.5625762939453
Epoch:49500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:105.79316711425781
Epoch:50000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:97.8600082397461
Epoch:50500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:6.7633819580078125
Epoch:51000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:144.56015014648438
Epoch:51500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-24.345352172851562
Epoch:52000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:156.97055053710938
Epoch:52500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:62.4056396484375
Epoch:53000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:198.38525390625
Epoch:53500,     	 exact loglikelihood:-932.918761060843,   	

Epoch:96000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:61.49598693847656
Epoch:96500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:514.3396606445312
Epoch:97000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:-124.78450775146484
Epoch:97500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:342.4851379394531
Epoch:98000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:533.8673706054688
Epoch:98500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:8.2392578125
Epoch:99000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:463.0679931640625
Epoch:99500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:517.8394165039062
Epoch:100000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:438.5930480957031
Epoch:100500,     	 exact loglikelihood:-932.918761060843,   	 ELBO:457.67755126953125
Epoch:101000,     	 exact loglikelihood:-932.918761060843,   	 ELBO:410.1942138671875
Epoch:101500,     	 exact loglikelihood:-932.918761060843,   	 E

KeyboardInterrupt: 