In [1]:
import numpy as np
from itertools import cycle
import time

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
from Sampling import Gaussian_sample
from Model import VAE, Encoder, Decoder
from limitedmnist import LimitedMNIST
from torchvision.datasets import MNIST

In [4]:
cuda = torch.cuda.is_available()
h_dims = [200, 100]
z_dim = 32
batch_size=100
n_epochs=20

In [5]:
def generate_label(batch_size, label, nlabels=2):
    """
    Generates a `torch.Tensor` of size batch_size x n_labels of
    the given label.

    Example: generate_label(2, 1, 3) #=> torch.Tensor([[0, 1, 0],
                                                       [0, 1, 0]])
    :param batch_size: number of labels
    :param label: label to generate
    :param nlabels: number of total labels
    """
    labels = (torch.ones(batch_size, 1) * label).type(torch.LongTensor)
    y = torch.zeros((batch_size, nlabels))
    y.scatter_(1, labels, 1)
    return y.type(torch.LongTensor)


def onehot(k):
    """
    Converts a number to its one-hot or 1-of-k representation
    vector.
    :param k: (int) length of vector
    :return: onehot function
    """
    def hot_vector(label):
        y = torch.LongTensor(k)
        y.zero_()
        y[label] = 1
        return y
    return hot_vector


def log_sum_exp(tensor, dim=None, sum_op=torch.sum):
    """
    Uses the LogSumExp (LSE) as an approximation for the sum in a log-domain.
    :param tensor: Tensor to compute LSE over
    :param dim: dimension to perform operation over
    :param sum_op: reductive operation to be applied, e.g. torch.sum or torch.mean
    :return: LSE
    """
    max, _ = torch.max(tensor, dim=dim, keepdim=True)
    return torch.log(sum_op(torch.exp(tensor - max), dim=dim, keepdim=True)) + max

In [6]:
labels = np.arange(10)
n = len(labels)

# Load in data
mnist_lab = LimitedMNIST('./', train=True, transform=torch.bernoulli, target_transform=onehot(n), digits=labels, fraction=0.025)
mnist_ulab = LimitedMNIST('./', train=True, transform=torch.bernoulli, target_transform=onehot(n), digits=labels, fraction=1.0)
mnist_val = LimitedMNIST('./', train=False, transform=torch.bernoulli, target_transform=onehot(n), digits=labels)

# Unlabelled data
unlabeled = DataLoader(mnist_ulab, batch_size=batch_size, shuffle=True, num_workers=2)
# Validation data
validation = DataLoader(mnist_val, batch_size=batch_size, shuffle=True, num_workers=2)
# Labelled data
labeled = DataLoader(mnist_lab, batch_size=batch_size, shuffle=True, num_workers=2)

In [7]:
class Classifier(nn.Module):
    def __init__(self, dims):
        super(Classifier, self).__init__()
        [x_dim, h_dim, y_dim] = dims
        self.h = nn.Linear(x_dim, h_dim)
        self.logits = nn.Linear(h_dim, y_dim)
    
    def forward(self, x):
        x = F.relu(self.h(x))
        x = F.softmax(self.logits(x))
        return x

In [8]:
class DGM(VAE):
    def __init__(self, dims, ratio):
        self.alpha = 0.1*ratio
        [x_dim, h_dim, z_dim, self.y_dim] = dims
        
        super(DGM, self).__init__([x_dim, h_dim, z_dim])
        self.encoder = Encoder([x_dim+self.y_dim, h_dim, z_dim])
        self.decoder = Decoder([z_dim+self.y_dim, list(reversed(h_dim)), x_dim])
        self.classifier = Classifier([x_dim, h_dim[-1], self.y_dim])
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
    
    def forward(self, x, y=None):
        logits = self.classifier(x)
        if y is None:
            return logits
        z, z_mu, z_logvar = self.encoder(torch.cat([x,y], dim=1))
        reconstruction = self.decoder(torch.cat([z,y], dim=1))
        return reconstruction, logits, (z, z_mu, z_logvar)
    
    def sample(self, z, y):
        y = y.type(torch.FloatTensor)
        x = self.decoder(torch.cat([z,y], dim=1))
        return x

In [9]:
dgm = DGM([784, [200,100], 20, 10], 0.1)
if cuda: dgm = dgm.cuda()

In [10]:
def binary_cross_entropy(r, x):
    epsilon=1e-7
    return -torch.sum((x * torch.log(r + epsilon) + (1 - x) * torch.log((1 - r) + epsilon)), dim=-1)

In [11]:
def cross_entropy(logits, y):
    return -y*torch.log(logits + 1e-8)
def loss_function(x_reconstructed, x,  mu, logvar):
    reconstruction_error = binary_cross_entropy(x_reconstructed, x)
    KL_div = 0.5*(1. + logvar - mu**2 - torch.exp(logvar))
    return reconstruction_error, torch.sum(KL_div)

In [12]:
def custom_logger(d):
    x, y = next(iter(validation))
    _, y_logits = torch.max(dgm.classifier(Variable(x)), 1)
    _, y = torch.max(y, 1)

    acc = torch.sum(y_logits.data == y)/len(y)
    d["Accuracy"] = acc
    
    print(d)

In [13]:
def calculate_loss(x, y=None):
    is_unlabeled = True if y is None else False
    x = Variable(x)
    if cuda:
        x = x.cuda()
    logits = dgm.forward(x)

    loss = 0
    if is_unlabeled==False:
        y = Variable(y.type(torch.FloatTensor))
        if cuda:
            y=y.cuda()
        x_recon, _, (z, z_mu, z_logvar) = dgm.forward(x, y)
        reconstruction_error, KL_div = loss_function(x_recon, x, z_mu, z_logvar)
        loss = torch.sum(reconstruction_error) - KL_div + torch.sum(dgm.alpha * -cross_entropy(logits, y))
    
    elif is_unlabeled:
        for i in range(dgm.y_dim):
            y = generate_label(batch_size, i, dgm.y_dim)
            y = Variable(y.type(torch.FloatTensor))
            if cuda:
                y = y.cuda()
            x_recon, _, (z, z_mu, z_logvar) = dgm.forward(x, y)
            reconstruction_error, KL_div = loss_function(x_recon, x, z_mu, z_logvar)
            loss += torch.sum(torch.mul(logits[:,i], reconstruction_error - KL_div))
            
    return loss

In [14]:
opt = optim.Adam(dgm.parameters(), lr = 3e-4)

In [15]:
for epoch in range(n_epochs):
    for (x, y), (u, _) in zip(cycle(labeled), unlabeled):
        U = calculate_loss(u)
        L = calculate_loss(x, y)
        J = L + U
        J.backward()
        opt.step()
        opt.zero_grad()
    print("epoch: {}, unlabeled loss: {:.3f}, labeled loss: {:.3f}, total loss: {:.3f}".format(epoch+1, U.data[0], L.data[0], J.data[0]))

epoch: 1, unlabeled loss: 21502.297, labeled loss: 21050.078, total loss: 42552.375
epoch: 2, unlabeled loss: 20983.039, labeled loss: 20427.412, total loss: 41410.453
epoch: 3, unlabeled loss: 21367.160, labeled loss: 18448.693, total loss: 39815.852
epoch: 4, unlabeled loss: 20775.436, labeled loss: 17873.164, total loss: 38648.602
epoch: 5, unlabeled loss: 20697.492, labeled loss: 17449.678, total loss: 38147.172
epoch: 6, unlabeled loss: 20700.348, labeled loss: 16777.822, total loss: 37478.172
epoch: 7, unlabeled loss: 21140.793, labeled loss: 17511.396, total loss: 38652.188
epoch: 8, unlabeled loss: 19919.182, labeled loss: 17545.756, total loss: 37464.938
epoch: 9, unlabeled loss: 20955.654, labeled loss: 17046.859, total loss: 38002.516
epoch: 10, unlabeled loss: 20347.039, labeled loss: 17500.830, total loss: 37847.867
epoch: 11, unlabeled loss: 20049.412, labeled loss: 17355.887, total loss: 37405.297
epoch: 12, unlabeled loss: 20482.496, labeled loss: 17840.299, total loss: