In [1]:
import numpy as np
import os, sys
import random

In [2]:
sys.path.append('dependencies/')
import Loss
import Model
import Sampling
import Trainer

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
# Sizes handed to Model.VAE further down: hidden-layer widths and the
# size of the final entry in its dims list (presumably the latent width).
z_dim = 32
h_dims = [200]

# True when PyTorch can see a CUDA device; later cells use this to decide
# whether to move models to the GPU.
cuda = torch.cuda.is_available()

In [5]:
def flatten_bernoulli(img):
    """Turn an image into a flat tensor and binarise it by Bernoulli sampling.

    The image is converted with ``transforms.ToTensor()``, flattened to 1-D,
    and each value is then used as the probability of drawing a 1. The draw
    is random, so repeated calls on the same image can differ.
    """
    pixels = transforms.ToTensor()(img)
    return pixels.view(-1).bernoulli()


# MNIST train / test splits, downloaded into ../data/ when missing.
mnist = datasets.MNIST('../data/', train=True, transform=flatten_bernoulli, download=True)
mnist_val = datasets.MNIST('../data/', train=False, transform=flatten_bernoulli, download=True)

# Mini-batch loaders (100 samples per batch, shuffled, two worker processes).
unlabelled = DataLoader(mnist, batch_size=100, shuffle=True, num_workers=2)
validation = DataLoader(mnist_val, batch_size=100, shuffle=True, num_workers=2)

In [6]:
# Build the VAE from [input size, hidden widths, z_dim]. The model is moved
# to the GPU only when one is available; an unconditional .cuda() here would
# raise on CPU-only machines and make the guard below pointless.
vae = Model.VAE([28*28, h_dims, z_dim])
if cuda:
    vae = vae.cuda()

In [7]:
# Training objective assembled from the project's Loss module: a binary
# cross-entropy term combined with a KL-divergence term.
objective = Loss.VariationalInference(
    Loss.binary_cross_entropy,
    Loss.KL_divergence_normal,
)

# Adam over every VAE parameter.
opt = optim.Adam(vae.parameters(), lr=1e-5)

In [8]:
# Wire model, objective and optimizer into the project's VAE trainer.
trainer = Trainer.VAE_trainer(vae, objective, opt, cuda)

# NOTE(review): the first argument is None, presumably an optional labelled
# loader; confirm against Trainer.VAE_trainer.train. The last argument is
# 101 (= 100 + 1).
trainer.train(None, unlabelled, 101)

Epoch: 10, loss:147.240
Epoch: 20, loss:123.000
Epoch: 30, loss:97.601
Epoch: 40, loss:89.882
Epoch: 50, loss:89.191
Epoch: 60, loss:84.881
Epoch: 70, loss:83.219
Epoch: 80, loss:70.545
Epoch: 90, loss:73.234
Epoch: 100, loss:73.861


### Plotting reconstruction

In [None]:
# from Plotutils import latent_imshow
# %matplotlib inline
# latent_imshow(-10, 10, 25, vae)

### Classification

In [None]:
class Classifier(nn.Module):
    """Fully connected classifier that maps an input vector to class probabilities.

    Args:
        dims: three-element sequence ``[z_dim, h_dim, n_class]``. ``z_dim`` is the
            input width, ``h_dim`` is an iterable of hidden-layer widths and
            ``n_class`` is the number of classes.
        dataset: accepted for call-site compatibility; not used by this class.
    """

    def __init__(self, dims, dataset='mnist'):
        super(Classifier, self).__init__()
        [z_dim, h_dim, n_class] = dims
        neurons = [z_dim, *h_dim, n_class]
        # One linear layer per consecutive pair of widths: z_dim -> ... -> n_class.
        linear_layers = [nn.Linear(n_in, n_out) for n_in, n_out in zip(neurons, neurons[1:])]
        self.h = nn.ModuleList(modules=linear_layers)
        # NOTE(review): self.h already ends at n_class, so this is an extra
        # n_class -> n_class layer stacked with no non-linearity in between.
        # Kept as-is so the parameter layout (state_dict keys/shapes) is unchanged.
        self.output = nn.Linear(n_class, n_class)

    def forward(self, z):
        """Return class probabilities for ``z``; the last axis sums to 1."""
        for i, next_layer in enumerate(self.h):
            z = next_layer(z)
            # ReLU after every layer of self.h except the last one.
            if i < len(self.h) - 1:
                z = F.relu(z)
        # Normalise over the class axis explicitly. Calling softmax without
        # `dim` is deprecated and picks the axis from the input rank (axis 0
        # for 3-D input), which would normalise across the wrong dimension.
        return F.softmax(self.output(z), dim=-1)

In [None]:
class Classifier_trainer(nn.Module):
    """Trains a classifier on the codes ``z`` produced by an existing model.

    Only the classifier's parameters are optimised (Adam, lr=1e-3); the model
    is used purely as a feature extractor.

    Args:
        model: module whose forward pass returns ``(_, (z, _, _))``.
        classifier: module mapping ``z`` to per-class probabilities (e.g. a
            softmax output). The loss below takes the log of these outputs.
        cuda: when truthy, model, classifier and every batch are moved to the GPU.

    Note:
        ``train`` here has a different signature from ``nn.Module.train(mode)``,
        so ``nn.Module.eval()`` (which calls ``train(False)``) must not be
        invoked on the trainer itself.
    """

    def __init__(self, model, classifier, cuda):
        super(Classifier_trainer, self).__init__()
        self.model = model
        self.classifier = classifier
        # This bool shadows nn.Module.cuda() on the instance; the attribute
        # name is kept so existing readers of `trainer.cuda` keep working.
        self.cuda = cuda
        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)
        if self.cuda:
            self.model = self.model.cuda()
            self.classifier = self.classifier.cuda()

    def _calculate_z(self, x):
        """Run the model on ``x`` and return the first element of its second output."""
        _, (z, _, _) = self.model(x)
        return z

    def _calculate_logits(self, z):
        """Return the classifier's output for ``z``.

        Despite the historical name, the classifier is expected to return
        probabilities rather than raw logits.
        """
        logits = self.classifier(z)
        return logits

    def _evaluate(self, validation_loader):
        """Return the mean of per-batch accuracies, or nan for an empty loader."""
        batch_accuracies = []
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for val_x, val_y in validation_loader:
                if self.cuda:
                    val_x, val_y = val_x.cuda(), val_y.cuda()
                probs = self._calculate_logits(self._calculate_z(val_x))
                _, val_y_pred = torch.max(probs, 1)
                # .item() yields a Python float, so averaging works for CPU
                # and CUDA tensors alike.
                batch_accuracies.append((val_y_pred == val_y).float().mean().item())
        if not batch_accuracies:
            return float('nan')
        return float(np.mean(batch_accuracies))

    def train(self, train_loader, validation_loader, n_epochs):
        """Train the classifier and print loss/accuracy every 10th epoch.

        Args:
            train_loader: iterable of ``(x, y)`` batches used for optimisation.
            validation_loader: iterable of ``(x, y)`` batches used for accuracy.
            n_epochs: number of passes over ``train_loader``.
        """
        # Loss of the most recent batch; stays nan if no batch is ever seen.
        last_loss = float('nan')
        for epoch in range(n_epochs):
            for trn_x, trn_y in train_loader:
                if self.cuda:
                    trn_x, trn_y = trn_x.cuda(), trn_y.cuda()
                # The optimizer only updates the classifier, so skip building
                # a graph through the model.
                with torch.no_grad():
                    z = self._calculate_z(trn_x)
                probs = self._calculate_logits(z)
                # The classifier outputs probabilities. F.cross_entropy would
                # apply log-softmax a second time, so take the log here and
                # use NLL. The clamp guards against log(0).
                loss = F.nll_loss(torch.log(probs.clamp(min=1e-8)), trn_y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                last_loss = loss.item()
            if (epoch+1) % 10 == 0:
                accuracy = self._evaluate(validation_loader)
                print("Epoch: {0:} loss: {1:.3f}, accuracy: {2:.3f}".format(epoch+1, last_loss, accuracy))

In [9]:
# The classifier consumes the VAE's codes, so its input width is tied to
# z_dim rather than a second hard-coded 32: one hidden layer of 32 units,
# 10 output classes.
classifier = Model.Classifier([z_dim, [32], 10], dataset='mnist')

In [10]:
# Pair the trained VAE with the classifier in the project's trainer.
classifier_trainer = Trainer.Classifier_trainer(vae, classifier, cuda)

# 51 epochs (= 50 + 1). The `unlabelled` loader also yields MNIST labels,
# which are what the classifier is trained against.
classifier_trainer.train(unlabelled, validation, 51)

Epoch: 10 loss: 1.517, accuracy: 0.943
Epoch: 20 loss: 1.509, accuracy: 0.955
Epoch: 30 loss: 1.515, accuracy: 0.952
Epoch: 40 loss: 1.497, accuracy: 0.954
Epoch: 50 loss: 1.527, accuracy: 0.959
