In [1]:
import torch
from torch import nn, cuda, optim, tensor
from torch.nn import functional as F
from torchvision import datasets, transforms

import numpy as np
import pandas as pd
import math
from sklearn.metrics.pairwise import euclidean_distances
import matplotlib.pyplot as plt
import numpy as np

In [2]:
test_dataset = datasets.MNIST(root='../app/mnist_data/', train=False, transform=transforms.ToTensor(), download=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

In [3]:
class EncoderDecoder(nn.Module):
    def __init__(self):
        super(EncoderDecoder, self).__init__()
        
        # encoder part
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), 
            nn.Linear(64, 12), 
            nn.ReLU(True), 
            nn.Linear(12, 2)
        )
        # decoder part
        self.decoder = nn.Sequential(
            nn.Linear(2, 12),
            nn.LeakyReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), 
            nn.Linear(128, 28 * 28), 
            nn.Tanh()
        )
        
    def forward(self, x):
        z = self.encoder(x)
        x = self.decoder(z)
        return z, x

In [4]:
class AutoEncoder:
    
    def fn_train(self, data, num_epochs):
        
        self.model = EncoderDecoder()        
        if torch.cuda.is_available():
            self.model.cuda()
            
        batch_size = 100
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        
        for epoch in range(num_epochs):
            for i in range(int(data.shape[0] / batch_size)):
                batch = data[i * batch_size:(i + 1) * batch_size]
                batch = tensor(batch, dtype=torch.float)
                if cuda.is_available():
                    batch = batch.cuda()
                # ===================forward=====================
                # forward pass: compute predicted outputs by passing inputs to the model
                encoded, decoded = self.model(batch)
                # calculate the loss
                loss = criterion(decoded, batch)
                # ===================backward====================
                # clear the gradients of all optimized variables
                optimizer.zero_grad()
                # backward pass: compute gradient of the loss with respect to model parameters
                loss.backward()
                # perform a single optimization step (parameter update)
                optimizer.step()
        data = tensor(data, dtype=torch.float)
        if cuda.is_available():
            data = data.cuda()
        encoded, decoded = self.model(data)
        if cuda.is_available():
            encoded = encoded.cpu()
            decoded = decoded.cpu()
        return encoded.detach().numpy(), decoded.detach().numpy()

    def data_projection(self, x_data):
        encoded, decoded = self.model(x_data)
        return encoded, decoded
    
    def img_fn(self):
        # obtain one batch of test images
        dataiter = iter(test_loader)
        images, labels = dataiter.next()

        with torch.no_grad():
            images_flatten = images.view(images.size(0), -1)
            images_flatten = images_flatten.cuda()
            # get sample outputs
            encoded, decoded = self.model(images_flatten)
            # prep images for display
            return decoded, images.numpy()

In [5]:
import timeit

train_dataset = datasets.MNIST(root='../app/mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=10000000)

print('Start projection training')
ae = AutoEncoder()

for data in train_loader:
    data = data[0].numpy()
    data = data
    data = data.reshape((data.shape[0], -1))
    data = data.astype(np.float32)
    
    start_time = timeit.default_timer()
    dim_reducted = ae.fn_train(data,  1)
    print('Training done', timeit.default_timer() - start_time)
    print(dim_reducted)

Start projection training
Training done 5.849467336000089
(array([[ 0.6833365 , -1.1703335 ],
       [ 1.6832031 , -0.42967546],
       [ 0.14373209, -0.8071306 ],
       ...,
       [-0.17171516, -1.3081797 ],
       [ 0.8014802 , -0.6141271 ],
       [-1.3357732 , -1.173598  ]], dtype=float32), array([[-9.05147847e-03, -3.01710726e-03,  1.08721033e-02, ...,
         5.23970881e-03, -5.43999206e-03,  6.00954751e-03],
       [-1.01989992e-02, -3.59269930e-03, -1.22585120e-02, ...,
         1.70369875e-02, -1.09606115e-02,  4.10939101e-03],
       [-2.03931406e-02, -3.28817149e-03, -1.04504405e-02, ...,
        -5.28144790e-03, -8.29700287e-03, -9.66230407e-03],
       ...,
       [-6.15663768e-04,  5.70841075e-04, -4.84520523e-03, ...,
         3.09552997e-05, -9.51757561e-03, -9.84099344e-04],
       [-8.96656699e-03, -2.21567973e-03, -1.11383321e-02, ...,
         7.22291134e-03, -8.49894062e-03,  1.12378560e-02],
       [ 3.54703143e-03,  8.68346728e-03, -7.77987344e-03, ...,
      

In [None]:
decoded, images = ae.img_fn()
# output is resized into a batch of images
output = decoded.view(128, 1, 28, 28)
# use detach when it's an output that requires_grad
output = output.detach().cpu().numpy()

# plot the first ten input images and then reconstructed images
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))

# input images on top row, reconstructions on bottom
for images, row in zip([images, output], axes):
    for img, ax in zip(images, row):
        ax.imshow(np.squeeze(img), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

In [6]:
encoded_t = dim_reducted[0]
encoded_t = encoded_t[0:64]
z_t = tensor(encoded_t, dtype=torch.float)
z_t.shape

torch.Size([64, 2])

In [7]:
from torchvision.utils import save_image
with torch.no_grad():    
    z_t = z_t.cuda()
    sample = ae.model.decoder(z_t).cpu()    
    save_image(sample.view(64, 1, 28, 28), './results/vae_sample_' + '.png')