In [1]:
import torch
from torch import nn, cuda, optim, tensor
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
class EncoderDecoder(nn.Module):
    """Fully connected variational autoencoder network.

    The encoder maps a flattened sample to the mean and log-variance of a
    diagonal Gaussian over the latent code. The decoder maps a latent code
    back to ``input_dim`` values squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: Number of values per flattened sample. Defaults to 784.
        hidden_dim: Width of the single hidden layer used by both the
            encoder and the decoder. Defaults to 256.
        latent_dim: Size of the latent code. Defaults to 2.
    """

    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=2):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # encoder part
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc31 = nn.Linear(hidden_dim, latent_dim)  # head producing the mean
        self.fc32 = nn.Linear(hidden_dim, latent_dim)  # head producing the log-variance

        # decoder part
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.fc6 = nn.Linear(hidden_dim, input_dim)

    def encoder(self, x):
        """Encode flattened samples.

        Args:
            x: Tensor of shape (batch, input_dim).

        Returns:
            Tuple ``(mu, log_var)``, each of shape (batch, latent_dim).
        """
        h = F.relu(self.fc1(x))
        return self.fc31(h), self.fc32(h)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``log_var``. Every call draws fresh noise.
        """
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Decode latent codes of shape (batch, latent_dim) to (batch, input_dim)."""
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Tensor whose element count is a multiple of ``input_dim``;
                it is flattened to (-1, input_dim) before encoding.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var

In [3]:
class VAE:
    """Trains an ``EncoderDecoder`` network and projects data into its latent space.

    ``fn_train`` creates the network and stores it on ``self.model``, so it
    must be called before ``data_projection``.
    """

    def loss_function(self, recon_x, x, mu, log_var):
        """Return reconstruction error plus KL divergence, summed over the batch.

        Args:
            recon_x: Decoder output of shape (batch, features) with values in [0, 1].
            x: Target data; reshaped to (-1, features) to line up with ``recon_x``.
            mu: Latent means.
            log_var: Latent log-variances.

        Returns:
            Scalar tensor ``BCE + KLD``.
        """
        # Take the feature count from the reconstruction instead of assuming 784.
        target = x.reshape(-1, recon_x.size(-1))
        BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD

    def fn_train(self, data, num_epochs):
        """Train a fresh ``EncoderDecoder`` on ``data`` and project ``data`` with it.

        Args:
            data: Array or tensor of samples, indexed by sample along the
                first axis.
            num_epochs: Number of full passes over ``data``.

        Returns:
            Tuple ``(encoded, decoded)`` for the whole of ``data``, computed
            without gradient tracking. ``encoded`` is a fresh latent sample
            drawn after the forward pass, so it is not the exact sample that
            produced ``decoded``.
        """
        device = torch.device('cuda' if cuda.is_available() else 'cpu')
        self.model = EncoderDecoder().to(device)

        batch_size = 100
        optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

        # Convert and move the data once rather than once per batch per epoch.
        data = torch.as_tensor(data, dtype=torch.float, device=device)
        num_samples = data.shape[0]

        # train the model
        for epoch in range(num_epochs):
            # Stepping by batch_size also covers a trailing partial batch.
            for start in range(0, num_samples, batch_size):
                batch = data[start:start + batch_size]
                optimizer.zero_grad()
                recon_batch, mu, log_var = self.model(batch)
                loss = self.loss_function(recon_batch, batch, mu, log_var)
                loss.backward()
                optimizer.step()

        # The final projection is inference only; skip building an autograd graph.
        with torch.no_grad():
            decoded, mu, log_var = self.model(data)
            encoded = self.model.reparameterize(mu, log_var)
        return encoded, decoded

    def data_projection(self, data):
        """Project ``data`` with the trained model and save decoded prior samples.

        Side effect: decodes 64 latent vectors drawn from a standard normal
        and writes them as an image grid to ``./results/sample_test.png``,
        creating the ``./results`` directory if needed.

        Args:
            data: Array or tensor of samples accepted by the model.

        Returns:
            Tuple ``(encoded, decoded)``, computed without gradient tracking.
        """
        import os  # local import: only needed to create the output directory

        # Follow whatever device the model lives on instead of assuming CUDA.
        device = next(self.model.parameters()).device
        with torch.no_grad():
            data = torch.as_tensor(data, dtype=torch.float, device=device)
            decoded, mu, log_var = self.model(data)
            encoded = self.model.reparameterize(mu, log_var)

            # Latent width is read from the encoder output rather than hard-coded.
            z = torch.randn(64, mu.size(-1), device=device)
            sample = self.model.decoder(z).cpu()

        out_dir = './results'
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): the reshape assumes each decoded sample holds 28 * 28 values.
        save_image(sample.view(64, 1, 28, 28), os.path.join(out_dir, 'sample_test.png'))
        return encoded, decoded

In [4]:
import numpy as np
import pandas as pd
import math
import timeit

# Seed torch's global RNG so shuffling, weight initialisation and latent
# sampling start from a known state on every run.
SEED = 0
torch.manual_seed(SEED)

train_dataset = datasets.MNIST(root='../app/mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
# A single batch spanning the whole training set: VAE.fn_train does its own
# mini-batching, so the loader is only used to materialise the images.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=True, num_workers=1)

# number of epochs
n_epochs = 1

# train model
vae = VAE()

for data in train_loader:
    # Each loader item is an (images, labels) pair; only the images are used.
    data = data[0].numpy()
    print('Start projection training')
    start_time = timeit.default_timer()
    encoded, decoded = vae.fn_train(data, n_epochs)
    print('Training done', timeit.default_timer() - start_time)

Start projection training
Training done 9.22427196700005


### Latent Space Exploration

In [5]:
# load one big batch for visualization
train_loader_batch = torch.utils.data.DataLoader(dataset=train_dataset,
                                                 batch_size=10000,
                                                 shuffle=False)
one_batch = next(iter(train_loader_batch))
img, labels = one_batch
# Flatten every image to a single row of pixel values.
images_flatten = img.view(img.size(0), -1)
# Move to the GPU only when one is present, so the cell also runs on CPU-only hosts.
all_data = images_flatten.cuda() if torch.cuda.is_available() else images_flatten

In [6]:
# Project the visualisation batch through the trained VAE, then pull the
# latent codes back to host memory as a NumPy array.
projection = vae.data_projection(all_data)
encoded, decoded = projection
z = encoded.detach().cpu().numpy()

In [6]:
# Keep the first 64 latent codes.
encoded_t = encoded[:64]

In [7]:
# Show the dimensions of the latent slice.
encoded_t.size()

torch.Size([64, 2])