# Simple Autoencoder for MNIST (with GPU implementation)

### Import necessary libraries

In [23]:
import os #will be used for creating directories, etc.

import torch
import torchvision.datasets as dsets             #for downloading dataset
import torchvision.transforms as transforms      #for transforming dataset into tensors

from torch import nn
import torchvision

from torch.autograd import Variable

from time import time

### Prepare MNIST dataset

In [24]:
#download MNIST dataset
dataset = dsets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)
print(dataset)

Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ../data
    Transforms (if any): ToTensor()
    Target Transforms (if any): None


### Utility functions

In [25]:
def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

### Training parameters

In [26]:
batch_size = 100
num_epochs = 50

# shuffle and prepare dataset with minibatches
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
print(data_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fbaf0887b50>


### Create Autoencoder NN

In [27]:
class Autoencoder(nn.Module):
    def __init__(self, in_dim=784, h_dim=100):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(h_dim, in_dim),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        out = self.encoder(x)
        out = self.decoder(out)
        return out

In [28]:
hidden_size = 30     #size of the bottleneck layer

# Create Autoencoder model
ae = Autoencoder(in_dim=784, h_dim=hidden_size)

if torch.cuda.is_available():
    ae.cuda()

### Optimizer and loss function

In [29]:
criterion = nn.BCELoss()     #Binary cross-entropy loss for each pixel
optimizer = torch.optim.Adam(ae.parameters(), lr=0.001)

### Let's train!!

In [30]:
# We want to go through all the examples in each epoch, so number of iterations per epoch
iter_per_epoch = len(data_loader)
data_iter      = iter(data_loader) # data_iter is an iterator object

#### Now we want to view and save results on a fixed batch to visualize Autoencoder's performance

In [43]:
dir_output = ('data/')
os.makedirs(dir_output, exist_ok=True)

# save fixed inputs for debugging
fixed_x, labels = next(data_iter)        #points to the first batch
# torchvision.utils.save_image(Variable(fixed_x).data.cpu(), './data/real_images.png')
print(fixed_x.size())
# print(fixed_x.view(fixed_x.size(0), -1).size())
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))
print(fixed_x.size())

# fixed_x

torch.Size([100, 1, 28, 28])
torch.Size([100, 784])


### Begin training

In [44]:
for epoch in range(1):
    t0 = time()
    for i, (images, _) in enumerate(data_loader):
        
        # flatten the image
        images = to_var(images.view(images.size(0), -1))
        out = ae(images)
        loss = criterion(out, images)
        
        optimizer.zero_grad()
        loss.backward()         # calculate gradients
        optimizer.step()        # update parameters
        
        # display training process after every 100 iterations
        if (i+1) % 100 == 0:
            print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f Time: %.2fs' 
                %(epoch+1, num_epochs, i+1, len(dataset)//batch_size, loss.item(), time()-t0))
            
    # save the reconstructed images
    reconst_images = ae(fixed_x)
    print(reconst_images.size())
    reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28)
#     torchvision.utils.save_image(reconst_images.data.cpu(), './data/reconst_images_%d.png' % (epoch+1))

Epoch [1/50], Iter [100/600] Loss: 0.1033 Time: 0.83s
Epoch [1/50], Iter [200/600] Loss: 0.1006 Time: 1.58s
Epoch [1/50], Iter [300/600] Loss: 0.1087 Time: 2.33s
Epoch [1/50], Iter [400/600] Loss: 0.0974 Time: 3.03s
Epoch [1/50], Iter [500/600] Loss: 0.1026 Time: 3.74s
Epoch [1/50], Iter [600/600] Loss: 0.0992 Time: 4.44s
torch.Size([100, 784])
