In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('/home/hao/Research/probtorch/')
sys.path.append('../models/')

import probtorch
from probtorch.util import expand_inputs
# Record the library versions and GPU availability next to the run's output.
print('probtorch: %s torch: %s cuda: %s'
      % (probtorch.__version__, torch.__version__, torch.cuda.is_available()))

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
import os

# ---- model architecture ----
NUM_PIXELS = 784   # flattened 28x28 MNIST image
NUM_HIDDEN = 400   # passed to Encoder/Decoder as num_hidden
NUM_LATENT = 10    # passed to Encoder/Decoder as z_what_dim

# ---- training parameters ----
NUM_SAMPLES = 10      # passed to enc/dec as num_samples; None -> elbo() uses sample_dim=None
NUM_BATCH = 100       # DataLoader batch size; train() skips batches of any other size
NUM_EPOCHS = 500
LEARNING_RATE = 1e-4  # Adam learning rate
BETA1 = 0.90          # Adam beta1 (beta2 is fixed at 0.999 where the optimizer is built)
EPS = 1e-9            # NOTE(review): not referenced by any cell shown here
CUDA = torch.cuda.is_available()

# ---- path parameters ----
# Defaults are the original machine-specific locations; set the environment
# variables to run elsewhere without editing the notebook.
DATA_PATH = os.environ.get('MNIST_DATA_PATH', '/home/hao/Research/apg_data/mnist/vanilla/data')
# os.path.join(..., '') guarantees a trailing separator, because checkpoint
# file names are appended to WEIGHTS_PATH with plain string concatenation.
WEIGHTS_PATH = os.path.join(os.environ.get('VAE_WEIGHTS_PATH', '../weights/'), '')

In [15]:
import torch
# Scratch check of concatenation: the values 1..10 laid out as a single
# (1, 10) row tensor in the default floating-point dtype.
a = {'ll': torch.arange(1., 11.).unsqueeze(0)}

In [16]:
# Display the tensor built in the previous cell (a single (1, 10) row).
a['ll']

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])

In [None]:
from vae_dec_digit import Decoder
from vae_enc_digit import Encoder

# Encoder and decoder are built with matching architecture sizes
# (encoder first, so parameter initialisation order is unchanged).
enc = Encoder(num_pixels=NUM_PIXELS,
              num_hidden=NUM_HIDDEN,
              z_what_dim=NUM_LATENT)
dec = Decoder(num_pixels=NUM_PIXELS,
              num_hidden=NUM_HIDDEN,
              z_what_dim=NUM_LATENT)
if CUDA:
    enc.cuda()
    dec.cuda()

# One Adam optimizer over both networks' parameters (encoder's first).
optimizer = torch.optim.Adam([*enc.parameters(), *dec.parameters()],
                             lr=LEARNING_RATE,
                             betas=(BETA1, 0.999))

In [None]:
def elbo(q, p, alpha=0.1):
    """Return probtorch's Monte Carlo ELBO for `q` against `p`.

    `q` and `p` are presumably the probtorch traces produced by `enc` and `dec`.
    The dimension layout is chosen from the global NUM_SAMPLES: when it is None
    there is no sample dimension and the batch is on dim 0; otherwise samples
    are on dim 0 and the batch is on dim 1. `alpha` is forwarded unchanged.
    """
    if NUM_SAMPLES is not None:
        sample_dim, batch_dim = 0, 1
    else:
        sample_dim, batch_dim = None, 0
    return probtorch.objectives.montecarlo.elbo(
        q, p, sample_dim=sample_dim, batch_dim=batch_dim, alpha=alpha)

In [None]:
from torchvision import datasets, transforms

# Shuffled MNIST training split, images converted to tensors, NUM_BATCH per batch.
train_data = torch.utils.data.DataLoader(
    datasets.MNIST(
        DATA_PATH,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=NUM_BATCH,
    shuffle=True,
)

In [None]:
def train(data, enc, dec, optimizer):
    """Run one training epoch over `data` and return the average ELBO per batch.

    Only batches holding exactly NUM_BATCH images are used; any other batch
    (e.g. a smaller trailing one) is skipped.

    Args:
        data: iterable of (images, labels) pairs; labels are ignored.
        enc: encoder module, called as enc(images, num_samples=NUM_SAMPLES).
        dec: decoder module, called as dec(images, q, num_samples=NUM_SAMPLES).
        optimizer: optimizer stepped once per used batch (presumably over the
            parameters of enc and dec).

    Returns:
        Mean of elbo(q, p) over the batches that were used, as a Python float,
        or NaN when no full batch was seen (instead of dividing by zero).
    """
    enc.train()
    dec.train()
    epoch_elbo = 0.0
    num_batches = 0
    for images, _ in data:
        # Skip incomplete batches so every step sees the same batch size.
        if images.size(0) != NUM_BATCH:
            continue
        num_batches += 1
        images = images.view(-1, NUM_PIXELS)
        if CUDA:
            images = images.cuda()
        optimizer.zero_grad()
        q = enc(images, num_samples=NUM_SAMPLES)
        p = dec(images, q, num_samples=NUM_SAMPLES)
        loss = -elbo(q, p)
        loss.backward()
        optimizer.step()
        # .item() returns a Python number from a one-element tensor on any
        # device, so no explicit .cpu() copy is needed.
        epoch_elbo -= loss.item()
    if num_batches == 0:
        return float('nan')
    return epoch_elbo / num_batches

In [None]:
import os
import time

# Resume from saved weights when both checkpoint files exist; otherwise keep
# the freshly initialised networks instead of crashing on a missing file.
ENC_WEIGHTS = WEIGHTS_PATH + 'vanilla-vae-enc'
DEC_WEIGHTS = WEIGHTS_PATH + 'vanilla-vae-dec'
if os.path.exists(ENC_WEIGHTS) and os.path.exists(DEC_WEIGHTS):
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the tensors into each module's parameters.
    enc.load_state_dict(torch.load(ENC_WEIGHTS, map_location='cpu'))
    dec.load_state_dict(torch.load(DEC_WEIGHTS, map_location='cpu'))
else:
    print('No saved weights under %s; training from scratch.' % WEIGHTS_PATH)

for e in range(NUM_EPOCHS):
    # perf_counter is the monotonic clock intended for measuring durations.
    train_start = time.perf_counter()
    train_elbo = train(train_data, enc, dec, optimizer)
    train_end = time.perf_counter()
    print('[Epoch %d] Train: ELBO %.4e (%ds)' % (e, train_elbo, train_end - train_start))

In [None]:
import os

# torch.save does not create missing directories, so make sure the
# checkpoint directory exists before writing the encoder/decoder weights.
os.makedirs(WEIGHTS_PATH, exist_ok=True)
torch.save(enc.state_dict(), WEIGHTS_PATH + 'vanilla-vae-enc')
torch.save(dec.state_dict(), WEIGHTS_PATH + 'vanilla-vae-dec')