In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
from tqdm.auto import tqdm  # progress bars (widget in Jupyter, text elsewhere)

In [None]:
# Create dataloaders for training and testing.
# Seed torch's global RNG first: both the model's weight initialisation and the
# training loader's shuffle order draw from it, so Restart & Run All reproduces
# the same run.
SEED = 0
BATCH_SIZE = 32
torch.manual_seed(SEED)

# ToTensor turns each PIL image into a float tensor scaled to [0, 1].
to_tensor = tv.transforms.ToTensor()
train_dataset = tv.datasets.MNIST(root='./data', train=True, transform=to_tensor, download=True)
test_dataset = tv.datasets.MNIST(root='./data', train=False, transform=to_tensor, download=True)

# Shuffle only the training set; the test order stays fixed so the example
# reconstructions shown later are always the same images.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Output-size rules used in the shape comments below (stride 1, no padding, no dilation):
#   Conv2d with a k x k kernel:           h x h -> (h - k + 1) x (h - k + 1)
#   ConvTranspose2d with a k x k kernel:  h x h -> (h + k - 1) x (h + k - 1)
#   MaxPool2d(2, stride=2):               h x h -> floor(h / 2) x floor(h / 2)

class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel 28 x 28 images (e.g. MNIST).

    The encoder compresses an (N, 1, 28, 28) batch into an (N, latent_dim) code.
    The decoder maps that code back to an (N, 1, 28, 28) reconstruction. The
    layer sizes are hard-wired to 28 x 28 inputs: the encoder's Linear layer
    expects 8 * 4 * 4 features, and the decoder reshapes to 8 x 6 x 6. No output
    activation is applied, so reconstructed values are not constrained to [0, 1].

    Args:
        latent_dim: Size of the bottleneck code. Defaults to 16.
    """

    def __init__(self, latent_dim: int = 16):
        super().__init__()
        self.latent_dim = latent_dim

        # Sequential applies its layers in order, so forward() can call it directly.
        self.encoder = nn.Sequential(                   # (1, 28, 28)
            nn.Conv2d(1, 16, kernel_size=3),            # (16, 26, 26)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                  # (16, 13, 13)
            nn.Conv2d(16, 32, kernel_size=4),           # (32, 10, 10)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                  # (32, 5, 5)
            nn.Conv2d(32, 8, kernel_size=2),            # (8, 4, 4)
            nn.ReLU(True),
            nn.Flatten(),                               # (128,)
            nn.Linear(8 * 4 * 4, latent_dim),           # (latent_dim,)
            nn.ReLU(True),
        )

        # A ModuleDict is just a named container: it registers the layers'
        # parameters but imposes no order. forward() wires the layers up by hand
        # because the decoder needs a reshape and an upsampling step in between.
        self.decoder = nn.ModuleDict({
            'fc': nn.Linear(latent_dim, 8 * 6 * 6),                # fully connected layer
            'Deconv1': nn.ConvTranspose2d(8, 32, kernel_size=3),   # transposed convolution
            'Deconv2': nn.ConvTranspose2d(32, 16, kernel_size=5),  # transposed convolution
            'Deconv3': nn.ConvTranspose2d(16, 1, kernel_size=5),   # transposed convolution
        })

    def forward(self, x):
        """Encode and then decode a batch of images.

        Args:
            x: Tensor of shape (N, 1, 28, 28).

        Returns:
            Reconstruction tensor of shape (N, 1, 28, 28).
        """
        z = self.encoder(x)                      # (N, latent_dim)
        z = F.relu(self.decoder['fc'](z))        # (N, 8 * 6 * 6)
        z = z.view(-1, 8, 6, 6)                  # (N, 8, 6, 6)
        z = F.relu(self.decoder['Deconv1'](z))   # (N, 32, 8, 8)
        z = F.relu(self.decoder['Deconv2'](z))   # (N, 16, 12, 12)
        z = F.interpolate(z, scale_factor=2)     # (N, 16, 24, 24): doubles height and width
        return self.decoder['Deconv3'](z)        # (N, 1, 28, 28)


In [None]:
# Create an instance of the autoencoder and sanity-check it on one training batch.
# The reconstruction must have the same shape as the input, because the loss
# below compares the two directly.
autoencoder = AutoEncoder()
sample_images, _ = next(iter(train_loader))  # labels are not needed by an autoencoder
with torch.no_grad():  # shape check only; no need to build an autograd graph
    sample_output = autoencoder(sample_images)
assert sample_output.shape == sample_images.shape, (sample_output.shape, sample_images.shape)
sample_images.shape, sample_output.shape


In [None]:
# Training setup: a pixel-wise mean-squared-error reconstruction loss, and an
# Adam optimiser over every parameter of the autoencoder.
loss_function = nn.MSELoss()
optimiser = optim.Adam(params=autoencoder.parameters(), lr=1e-3)

In [None]:
# Training loop: the target for each batch is the batch itself (reconstruction loss).
# NOTE: re-running this cell keeps training the existing model but starts a fresh
# `losses` list. Re-run the model and optimiser cells first for a clean run.
NUM_EPOCHS = 2

autoencoder.train()  # explicit training mode (no dropout/batch-norm here, but keeps intent clear)
losses = []  # one float per mini-batch, plotted below
for epoch in range(NUM_EPOCHS):
    for images, _ in tqdm(train_loader, desc=f'epoch {epoch + 1}/{NUM_EPOCHS}'):
        optimiser.zero_grad()
        reconstruction = autoencoder(images)
        loss = loss_function(reconstruction, images)
        loss.backward()
        optimiser.step()
        # .item() returns a plain Python float, detached from the autograd graph.
        losses.append(loss.item())

In [None]:
# Plot the training loss smoothed with a moving average, since individual
# mini-batch losses are noisy.
SMOOTHING_WINDOW = 100  # number of mini-batches averaged per plotted point

window = min(SMOOTHING_WINDOW, len(losses))  # never longer than the recorded history
smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')

fig, ax = plt.subplots()
# With mode='valid', point i averages batches i .. i + window - 1. Plot each
# point at the last batch of its window so the x-axis is real iteration numbers.
ax.plot(np.arange(window - 1, len(losses)), smoothed)
ax.set(title=f'Training loss ({window}-batch moving average)', xlabel='Iterations', ylabel='Loss')
plt.show()

In [None]:
# Look at an image reconstruction: the first test image next to the model's output.
# Both panels use the same [0, 1] intensity range as the ToTensor inputs, so they
# are directly comparable (imshow would otherwise rescale each image separately).
autoencoder.eval()
example_data, example_targets = next(iter(test_loader))
with torch.no_grad():
    output = autoencoder(example_data)

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(6, 3))
ax_in.imshow(example_data[0][0], cmap='gray', vmin=0, vmax=1)
ax_in.set_title('Input')
ax_out.imshow(output[0][0], cmap='gray', vmin=0, vmax=1)
ax_out.set_title('Reconstruction')
for ax in (ax_in, ax_out):
    ax.axis('off')
plt.show()