In [1]:
import numpy as np
import os, sys
import random

In [2]:
sys.path.append('dependencies/')
import Loss
import Model
import Sampling
import Trainer

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
# Run configuration: device availability, VAE layer sizes, and RNG seeds.
cuda = torch.cuda.is_available()
h_dims = [200]  # passed to Model.VAE as the middle entry of its size list
z_dim = 32      # passed to Model.VAE as the last entry of its size list

# The pipeline is stochastic (DataLoader shuffling, per-access Bernoulli
# binarisation, torch-based parameter init), so seed every RNG it may touch
# to make re-runs of the notebook repeatable. GPU kernels may still add
# some nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if cuda:
    torch.cuda.manual_seed_all(SEED)

In [5]:
# Load MNIST with a transform that flattens each image and re-binarises it by
# Bernoulli sampling every time the sample is accessed.
_to_tensor = transforms.ToTensor()  # built once instead of on every call


def flatten_bernoulli(img):
    """Convert an MNIST image to a flat tensor of 0/1 Bernoulli samples.

    ToTensor scales pixel values to [0, 1]; ``bernoulli()`` then draws each
    pixel as 1 with that probability. The binarisation therefore differs
    each time the same image is loaded.
    """
    return _to_tensor(img).view(-1).bernoulli()


mnist = datasets.MNIST('../data/', train=True, transform=flatten_bernoulli, download=True)
mnist_val = datasets.MNIST('../data/', train=False, transform=flatten_bernoulli, download=True)

unlabelled = torch.utils.data.DataLoader(mnist, batch_size=100, shuffle=True, num_workers=2)
# Evaluation gains nothing from shuffling; a fixed order keeps validation
# passes repeatable (assumes the trainer evaluates the whole loader).
validation = torch.utils.data.DataLoader(mnist_val, batch_size=100, shuffle=False, num_workers=2)

In [6]:
# Build the VAE from its size list [28*28 input, h_dims, z_dim], and move it
# to the GPU only when one is available so the notebook also runs on
# CPU-only machines.
vae = Model.VAE([28*28, h_dims, z_dim])
if cuda:
    vae = vae.cuda()

In [7]:
# Variational objective assembled from a reconstruction term (binary
# cross-entropy) and a KL-divergence term, both supplied by the Loss module.
reconstruction_loss = Loss.binary_cross_entropy
kl_term = Loss.KL_divergence_normal
objective = Loss.VariationalInference(reconstruction_loss, kl_term)

# Adam over all VAE parameters; the step size is named rather than inlined.
learning_rate = 1e-5
opt = torch.optim.Adam(vae.parameters(), lr=learning_rate)

In [8]:
# Train the VAE on the training loader. The meaning of the leading `None`
# argument is defined by Trainer.VAE_trainer.train, which is not shown here.
vae_epochs = 100 + 1  # the captured log below ends at "Epoch: 100"
trainer = Trainer.VAE_trainer(vae, objective, opt, cuda)
trainer.train(None, unlabelled, vae_epochs)

Epoch: 10, loss:147.240
Epoch: 20, loss:123.000
Epoch: 30, loss:97.601
Epoch: 40, loss:89.882
Epoch: 50, loss:89.191
Epoch: 60, loss:84.881
Epoch: 70, loss:83.219
Epoch: 80, loss:70.545
Epoch: 90, loss:73.234
Epoch: 100, loss:73.861


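### Plotting reconstruction

The plotting cell below is disabled (commented out). Uncomment it to call `Plotutils.latent_imshow` on the trained `vae`; this requires the `Plotutils` module to be importable.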

In [None]:
# from Plotutils import latent_imshow
# %matplotlib inline
# latent_imshow(-10, 10, 25, vae)

### Classification

In [9]:
# Classifier size list, presumably [input, [hidden], n_classes]. The input
# width is tied to z_dim instead of a second hard-coded 32, on the assumption
# that the classifier consumes the VAE's latent code (Classifier_trainer is
# given `vae`); TODO confirm against Model.Classifier. With z_dim = 32 this
# builds exactly the same network as before.
classifier = Model.Classifier([z_dim, [32], 10], dataset='mnist')

In [10]:
# Fit the classifier with the trained `vae` available to the trainer. The
# first loader is used for training; the second is presumably the one behind
# the reported accuracy.
classifier_epochs = 50 + 1  # the captured log below ends at "Epoch: 50"
classifier_trainer = Trainer.Classifier_trainer(vae, classifier, cuda)
classifier_trainer.train(unlabelled, validation, classifier_epochs)

Epoch: 10 loss: 1.517, accuracy: 0.943
Epoch: 20 loss: 1.509, accuracy: 0.955
Epoch: 30 loss: 1.515, accuracy: 0.952
Epoch: 40 loss: 1.497, accuracy: 0.954
Epoch: 50 loss: 1.527, accuracy: 0.959
