In [1]:
import argparse
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import os
import random
import sys
sys.path.append('dependencies/')
import loss
from model import VAE
from sampling import Gaussian_sample

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
# Global notebook configuration.
# `cuda` records whether a CUDA device is usable; later cells consult it
# before moving the model or batches onto the GPU.
cuda = torch.cuda.is_available()

# Use the ggplot look for every matplotlib figure drawn in this notebook.
plt.style.use("ggplot")

In [4]:
def flatten_bernoulli(img):
    """Turn an image into a flat, binarised tensor.

    The image is converted with ``transforms.ToTensor()``, flattened to one
    dimension with ``view(-1)``, and then ``bernoulli()`` draws a fresh 0/1
    sample per element. Because the sampling is random, the same image
    yields a different binarisation each time it is loaded.
    """
    as_tensor = transforms.ToTensor()(img)
    return as_tensor.view(-1).bernoulli()


# MNIST train / test splits, downloaded into data/ if not already present.
mnist = datasets.MNIST('data/', train=True, transform=flatten_bernoulli, download=True)
mnist_val = datasets.MNIST('data/', train=False, transform=flatten_bernoulli, download=True)

# Both loaders share the same batching options: shuffled mini-batches of 100
# images, fetched by 2 worker processes.
_loader_opts = dict(batch_size=100, shuffle=True, num_workers=2)
unlabelled = DataLoader(mnist, **_loader_opts)
validation = DataLoader(mnist_val, **_loader_opts)

In [5]:
# Build the VAE for flattened 28x28 MNIST images (784 inputs).
# NOTE(review): the argument looks like [input_dim, [hidden sizes], latent_dim]
# -- confirm against the VAE definition in dependencies/model.py.
model = VAE([28*28, [128], 32])

# Move the model to the GPU only when one is available. The constructor used
# to be followed by an unconditional .cuda(), which crashed on CPU-only
# machines before this guard was ever reached.
if cuda:
    model = model.cuda()

In [6]:
# Training objective: a VariationalInference loss assembled from the project's
# binary cross-entropy and Normal KL-divergence helpers.
# `loss` here must be the module imported at the top of the notebook.
objective = loss.VariationalInference(
    loss.binary_cross_entropy,
    loss.KL_divergence_normal,
)

# Adam over all model parameters with a learning rate of 1e-5.
opt = optim.Adam(model.parameters(), lr=1e-5)

In [7]:
# Train the VAE for 200 passes over the `unlabelled` loader. Labels are
# ignored; each batch is reconstructed by the model and scored by `objective`.
# Every 5th epoch the loss of the LAST batch of that epoch is printed (it is
# not an epoch average).
for epoch in range(200):
    for unlabelled_image, _ in unlabelled:
        u = Variable(unlabelled_image)
        if cuda:
            u = u.cuda()

        # Clear gradients before the backward pass so that anything left over
        # from an interrupted previous run of this cell cannot leak into the
        # first optimisation step.
        opt.zero_grad()

        reconstruction, (_, z_mu, z_logvar) = model(u)

        # Deliberately NOT named `loss`: that name is the imported `loss`
        # module, and rebinding it here would break any later re-run of the
        # cell that builds `objective` from `loss.VariationalInference`.
        batch_loss = objective(reconstruction, u, z_mu, z_logvar)
        batch_loss.backward()
        opt.step()

    if epoch % 5 == 0:
        # backward() without arguments requires a single-element loss, so
        # taking the first element of the flattened data is the loss value.
        # This form works on old PyTorch (where `.data[0]` was the idiom) and
        # on >= 0.4, where indexing a 0-dim tensor with [0] raises.
        print_loss = float(batch_loss.data.view(-1)[0])
        print("Epoch: {0:}, loss:{1:.3f}".format(epoch, print_loss))

Epoch: 0, loss:41982.887
Epoch: 5, loss:22524.299
Epoch: 10, loss:19817.111
Epoch: 15, loss:17865.240
Epoch: 20, loss:17177.883
Epoch: 25, loss:15818.608
Epoch: 30, loss:15780.653
Epoch: 35, loss:16280.032
Epoch: 40, loss:15474.046
Epoch: 45, loss:14968.266
Epoch: 50, loss:14359.342
Epoch: 55, loss:14284.939
Epoch: 60, loss:14115.750
Epoch: 65, loss:14336.035
Epoch: 70, loss:14086.359
Epoch: 75, loss:14080.630
Epoch: 80, loss:13676.011
Epoch: 85, loss:13489.354
Epoch: 90, loss:13154.350
Epoch: 95, loss:13220.627
Epoch: 100, loss:14185.450
Epoch: 105, loss:13024.311
Epoch: 110, loss:13450.775
Epoch: 115, loss:12689.738
Epoch: 120, loss:12915.957
Epoch: 125, loss:12490.836
Epoch: 130, loss:12846.779
Epoch: 135, loss:12612.048
Epoch: 140, loss:12419.067
Epoch: 145, loss:11956.010
Epoch: 150, loss:12674.398
Epoch: 155, loss:12228.896
Epoch: 160, loss:12288.304
Epoch: 165, loss:12471.174
Epoch: 170, loss:12400.033
Epoch: 175, loss:12132.865
Epoch: 180, loss:12003.592
Epoch: 185, loss:12556.

In [14]:
# Smoke test: print the shape of one batch from `unlabelled` and run a single
# forward pass through the model. The forward result is discarded.
for unlabelled_image, _ in unlabelled:
    print(unlabelled_image.size())

    batch = Variable(unlabelled_image)
    # Respect the `cuda` flag instead of calling .cuda() unconditionally, so
    # this cell also runs on machines without a GPU.
    if cuda:
        batch = batch.cuda()

    # Call the model itself rather than .forward(), matching how the training
    # cell invokes it.
    model(batch)
    break

torch.Size([100, 784])
