# 3. AutoEncoder Demo (ConvTranspose)

### About this notebook

This notebook was used in the 50.039 Deep Learning course at the Singapore University of Technology and Design.

**Author:** Matthieu DE MARI (matthieu_demari@sutd.edu.sg)

**Version:** 1.1 (05/04/2022)

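**Requirements:**
- Python 3 (tested on v3.9.6)
- Matplotlib (tested on v3.5.1)
- Numpy (tested on v1.22.1)
- PyTorch and Torchvision (imported below as `torch` and `torchvision`; tested versions not recorded)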


### Imports

In [None]:
# Imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [None]:
# Device selection
# - CUDA is a manual switch: set it to False to force CPU execution
# - The GPU is used only when the switch is on AND PyTorch reports one
CUDA = True
gpu_available = torch.cuda.is_available()
device = "cuda" if (CUDA and gpu_available) else "cpu"
print(gpu_available)
print(device)

### Dataset and dataloaders

In [None]:
# Preprocessing pipeline applied to every MNIST image
# - ToTensor: turns the image into a tensor
# - Normalize: standardizes the single channel with mean 0.1307 and std 0.3081
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# Train datasets/dataloaders
# - MNIST training split, downloaded to ./data when not already present
# - batch_size = 128, the batch size announced in the training parameters
# - shuffle = True so that every epoch visits the samples in a new order
#   (a fixed order makes the mini-batches identical from epoch to epoch)
train_set = torchvision.datasets.MNIST(root = './data',
                                       train = True,
                                       download = True,
                                       transform = transform)
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size = 128,
                                           shuffle = True)

In [None]:
# Held-out MNIST split and its loader
# - Uses the same `transform` pipeline as the training data
# - Small batches of 4, served in dataset order (no shuffling)
test_set = torchvision.datasets.MNIST(
    root = './data',
    train = False,
    download = True,
    transform = transform,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size = 4,
    shuffle = False,
)

### Model

The model uses `Conv2d` layers in the encoder and `ConvTranspose2d` layers in the decoder.

In [None]:
# Define AutoEncoder Model for MNIST
class MNIST_Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder halves the spatial resolution twice with strided
    convolutions (28 -> 14 -> 7) and then applies a final convolution
    whose kernel size is `hidden_layer`. The decoder mirrors it with
    transposed convolutions (... -> 7 -> 14 -> 28), so the output has the
    same shape as the input.

    The last decoder layer has no activation on purpose. The notebook
    normalizes its images with mean 0.1307 and std 0.3081, so targets
    contain negative values (dark pixels) and values above 1 (bright
    pixels). A final ReLU cannot produce the former and a final Sigmoid
    cannot produce the latter; a linear output can reproduce both.

    Args:
        hidden_layer: kernel size of the bottleneck convolution and of the
            matching first transposed convolution. With the default of 7
            the 7x7 feature map is reduced to a 64-channel 1x1 code.
    """

    def __init__(self, hidden_layer = 7):
        
        # Init from nn.Module
        super().__init__()
        
        # Encoder part is a simple sequence of Conv2d+ReLU.
        # No ReLU after the last convolution, so the code is unconstrained.
        self.encoder = nn.Sequential(nn.Conv2d(1, 16, 3, stride = 2, padding = 1),
                                     nn.ReLU(),
                                     nn.Conv2d(16, 32, 3, stride = 2, padding = 1),
                                     nn.ReLU(),
                                     nn.Conv2d(32, 64, hidden_layer))
        
        # Decoder part is a simple sequence of ConvTranspose2d+ReLU.
        # output_padding = 1 makes each stride-2 layer exactly double the
        # spatial size (7 -> 14 -> 28).
        # The final layer is left linear (see class docstring).
        self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, hidden_layer),
                                     nn.ReLU(),
                                     nn.ConvTranspose2d(32, 16, 3, stride = 2, padding = 1, output_padding = 1),
                                     nn.ReLU(),
                                     nn.ConvTranspose2d(16, 1, 3, stride = 2, padding = 1, output_padding = 1))

    def forward(self, x):
        """Encode `x` and decode it back; returns the reconstruction."""
        
        # Forward is encoder into decoder
        code = self.encoder(x)
        return self.decoder(code)

### Train model

In [None]:
# Build the autoencoder
# - Seed PyTorch's RNG first so the initial weights are reproducible
# - Then move the model to the selected device
torch.manual_seed(10)
model = MNIST_Autoencoder(hidden_layer = 7)
model = model.to(device)

In [None]:
# Training configuration
# - distance: MSE loss, used as the reconstruction loss
# - optimizer: Adam over all model parameters, with weight decay 1e-5
# - num_epochs: 20 passes over the training data
# - batch_size: declared as 128
#   NOTE(review): the train DataLoader is built in an earlier cell with a
#   literal batch size, so this variable is not what the loader uses.
distance = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay = 1e-5)
num_epochs = 20
batch_size = 128

In [None]:
# Train
# - outputs_list keeps, for each epoch, (epoch, last batch, its reconstruction)
#   as detached CPU tensors, so neither device memory nor the autograd graph
#   is kept alive between epochs
# - loss_list keeps the mean reconstruction loss over each epoch
outputs_list = []
loss_list = []
model.train()
for epoch in range(num_epochs):
    
    # Sample-weighted running sum, so a smaller final batch is not over-counted
    running_loss = 0.0
    num_samples = 0
    
    for img, _ in train_loader:
        
        # Send data to device (labels are not needed for reconstruction)
        img = img.to(device)
        
        # Forward pass: the target of the autoencoder is its own input
        output = model(img)
        loss = distance(output, img)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate as a tensor to avoid a host/device sync on every step
        running_loss += loss.detach() * img.size(0)
        num_samples += img.size(0)
        
    # Display
    epoch_loss = float(running_loss) / num_samples
    print('epoch {}/{}, loss {:.4f}'.format(epoch + 1, num_epochs, epoch_loss))
    outputs_list.append((epoch, img.detach().cpu(), output.detach().cpu()))
    loss_list.append(epoch_loss)

In [None]:
# Display loss
# - One point per training epoch, taken from loss_list
# - Epochs are numbered from 1, like in the training printout
plt.figure()
plt.plot(range(1, len(loss_list) + 1), loss_list)
plt.xlabel("Epoch")
plt.ylabel("MSE reconstruction loss")
plt.title("Training loss")
plt.show()

In [None]:
# Show recorded batches for every 4th epoch
# - Top row: input images, bottom row: their reconstructions
# - At most 9 images per row; only the first channel is drawn
images_per_row = 9
for k in range(0, num_epochs, 4):
    plt.figure(figsize = (9, 2))
    _, batch_in, batch_out = outputs_list[k]
    imgs = batch_in.cpu().detach().numpy()
    recon = batch_out.cpu().detach().numpy()
    for row, images in enumerate((imgs, recon)):
        for col, image in enumerate(images[:images_per_row]):
            plt.subplot(2, images_per_row, row * images_per_row + col + 1)
            plt.imshow(image[0])