In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
# Extend the module search path with two hard-coded local directories.
# NOTE(review): presumably the first one provides `probtorch` and '../models/'
# provides the `vae_enc_digit` / `vae_dec_digit` modules imported further
# down -- confirm; the absolute path only exists on the author's machine.
sys.path.append('/home/hao/Research/probtorch/')
sys.path.append('../models/')

import probtorch
from probtorch.util import expand_inputs
# Print the library versions and whether CUDA is available, so the environment
# that produced the outputs is recorded alongside them.
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
# model parameters
# Flattened image size: passed to the encoder/decoder as `num_pixels` and used
# by train() to reshape each batch to (-1, NUM_PIXELS).
NUM_PIXELS = 784
# Passed to the encoder/decoder as `num_hidden`.
NUM_HIDDEN = 400
# Passed to the encoder/decoder as `z_what_dim`; also embedded in MODEL_NAME.
NUM_LATENT = 10   

# training parameters
# Forwarded as `num_samples` to the encoder/decoder; elbo() switches to
# sample_dim=None / batch_dim=0 when this is None.
NUM_SAMPLES = 10
# DataLoader batch size; train() skips any batch whose size differs from it.
NUM_BATCH = 100
# Number of passes over the training data in the training cell.
NUM_EPOCHS = 500
# Adam learning rate.
LEARNING_RATE = 1e-3
# First Adam beta coefficient (the second one is fixed to 0.999 at the call).
BETA1 = 0.90
# NOTE(review): not referenced in any of the visible cells.
EPS = 1e-9
# True when a CUDA device is available; gates every `.cuda()` transfer.
CUDA = torch.cuda.is_available()

# path parameters
# NOTE(review): not referenced in any of the visible cells (the weight files
# are saved under fixed names instead).
MODEL_NAME = 'vanilla-vae-%02ddim' % NUM_LATENT
# Root directory handed to torchvision's MNIST dataset (downloaded if absent).
# NOTE(review): absolute path specific to the author's machine.
DATA_PATH = '/home/hao/Research/apg_data/mnist/vanilla/data'
# Prefix prepended to the encoder/decoder state-dict file names when saving.
WEIGHTS_PATH = '../weights/'
# The training loop runs only when this is False.
RESTORE = False

In [3]:
from vae_dec_digit import Decoder
from vae_enc_digit import Encoder
# The encoder and decoder are constructed with the same three size arguments.
model_kwargs = dict(num_pixels=NUM_PIXELS, num_hidden=NUM_HIDDEN, z_what_dim=NUM_LATENT)
enc = Encoder(**model_kwargs)
dec = Decoder(**model_kwargs)

# Move both networks to the GPU when one is available.
if CUDA:
    for net in (enc, dec):
        net.cuda()

# A single Adam optimizer updates the parameters of both networks jointly.
optimizer = torch.optim.Adam(
    [*enc.parameters(), *dec.parameters()],
    lr=LEARNING_RATE,
    betas=(BETA1, 0.999),
)

In [4]:
def elbo(q, p, alpha=0.1):
    """Evaluate probtorch's Monte Carlo ELBO for the traces `q` and `p`.

    The dimension layout handed to probtorch depends on the module-level
    NUM_SAMPLES setting:

    * NUM_SAMPLES is None -> sample_dim=None, batch_dim=0
    * otherwise           -> sample_dim=0,    batch_dim=1

    Args:
        q: first argument forwarded to probtorch's elbo (the encoder output).
        p: second argument forwarded to probtorch's elbo (the decoder output).
        alpha: forwarded unchanged as probtorch's `alpha` keyword.

    Returns:
        Whatever `probtorch.objectives.montecarlo.elbo` returns.
    """
    sample_dim, batch_dim = (None, 0) if NUM_SAMPLES is None else (0, 1)
    return probtorch.objectives.montecarlo.elbo(
        q, p, sample_dim=sample_dim, batch_dim=batch_dim, alpha=alpha
    )

In [5]:
from torchvision import datasets, transforms
# MNIST training split (downloaded into DATA_PATH when missing), with each
# image converted to a tensor.
mnist_train = datasets.MNIST(
    DATA_PATH,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of NUM_BATCH images.
train_data = torch.utils.data.DataLoader(mnist_train, batch_size=NUM_BATCH, shuffle=True)

In [6]:
def train(data, enc, dec, optimizer):
    """Run one optimisation pass over `data` and return the mean batch ELBO.

    Only batches containing exactly NUM_BATCH images are used; any other
    batch (e.g. a short final one) is skipped and does not count towards
    the average.

    Args:
        data: iterable yielding `(images, labels)` pairs; the labels are ignored.
        enc: encoder module, called as `enc(images, num_samples=NUM_SAMPLES)`.
        dec: decoder module, called as `dec(images, q, num_samples=NUM_SAMPLES)`.
        optimizer: optimizer whose gradients are zeroed and which is stepped
            once per processed batch.

    Returns:
        float: the sum of the per-batch ELBO values divided by the number of
        processed batches.

    Raises:
        ValueError: if `data` yields no batch of exactly NUM_BATCH images
            (previously this surfaced as a bare ZeroDivisionError).
    """
    enc.train()
    dec.train()
    epoch_elbo = 0.0
    num_batches = 0
    for images, _ in data:
        # Skip batches that do not have the configured size.
        if images.size(0) != NUM_BATCH:
            continue
        num_batches += 1
        # Flatten each image into a vector of NUM_PIXELS values.
        images = images.view(-1, NUM_PIXELS)
        if CUDA:
            images = images.cuda()
        optimizer.zero_grad()
        q = enc(images, num_samples=NUM_SAMPLES)
        p = dec(images, q, num_samples=NUM_SAMPLES)
        # The optimizer minimises, so the loss is the negated ELBO.
        loss = -elbo(q, p)
        loss.backward()
        optimizer.step()
        # .item() already returns a host-side Python number, so no explicit
        # .cpu() transfer or float() conversion is needed.
        epoch_elbo -= loss.item()

    if num_batches == 0:
        raise ValueError(
            'train() saw no batch with exactly %d images; cannot average the ELBO'
            % NUM_BATCH
        )
    return epoch_elbo / num_batches

In [None]:
import os
import time
from random import random

# File names of the saved encoder / decoder state dicts.
enc_weights_file = WEIGHTS_PATH + 'vanilla-vae-enc'
dec_weights_file = WEIGHTS_PATH + 'vanilla-vae-dec'

if not RESTORE:
    # Create the output directory up front so that a missing folder cannot
    # throw away a finished training run when the weights are saved.
    weights_dir = os.path.dirname(enc_weights_file)
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    # NOTE(review): `mask` is not read by any visible cell; kept in case a
    # later cell relies on it.
    mask = {}
    for e in range(NUM_EPOCHS):
        train_start = time.time()
        train_elbo = train(train_data, enc, dec, optimizer)
        train_end = time.time()
        print('[Epoch %d] Train: ELBO %.4e (%ds)' % (e, train_elbo, train_end - train_start))

    # Save only after training. Previously these calls ran unconditionally,
    # so RESTORE = True overwrote the stored weights with an untrained model.
    torch.save(enc.state_dict(), enc_weights_file)
    torch.save(dec.state_dict(), dec_weights_file)
else:
    # Restore previously trained weights instead of training. Loading onto the
    # CPU first keeps this working without a GPU; load_state_dict then copies
    # the values into the parameters on whatever device the modules live on.
    enc.load_state_dict(torch.load(enc_weights_file, map_location='cpu'))
    dec.load_state_dict(torch.load(dec_weights_file, map_location='cpu'))

[Epoch 0] Train: ELBO -1.6539e+02 (10s)
[Epoch 1] Train: ELBO -1.2132e+02 (10s)
[Epoch 2] Train: ELBO -1.1335e+02 (10s)
[Epoch 3] Train: ELBO -1.1013e+02 (10s)
[Epoch 4] Train: ELBO -1.0815e+02 (10s)
[Epoch 5] Train: ELBO -1.0677e+02 (10s)
[Epoch 6] Train: ELBO -1.0573e+02 (10s)
[Epoch 7] Train: ELBO -1.0493e+02 (9s)
[Epoch 8] Train: ELBO -1.0425e+02 (9s)
[Epoch 9] Train: ELBO -1.0371e+02 (9s)
[Epoch 10] Train: ELBO -1.0325e+02 (9s)
[Epoch 11] Train: ELBO -1.0281e+02 (9s)
[Epoch 12] Train: ELBO -1.0246e+02 (9s)
[Epoch 13] Train: ELBO -1.0213e+02 (9s)
[Epoch 14] Train: ELBO -1.0185e+02 (10s)
[Epoch 15] Train: ELBO -1.0152e+02 (10s)
[Epoch 16] Train: ELBO -1.0130e+02 (10s)
[Epoch 17] Train: ELBO -1.0108e+02 (10s)
[Epoch 18] Train: ELBO -1.0087e+02 (10s)
[Epoch 19] Train: ELBO -1.0067e+02 (10s)
[Epoch 20] Train: ELBO -1.0047e+02 (10s)
[Epoch 21] Train: ELBO -1.0032e+02 (9s)
[Epoch 22] Train: ELBO -1.0016e+02 (9s)
[Epoch 23] Train: ELBO -9.9992e+01 (9s)
[Epoch 24] Train: ELBO -9.9855e+01 (