# Autoencoder
Autoencoder para FashionMNIST

In [1]:
# Import de paquetes
%matplotlib inline
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import sys
import os

# Numpy
import numpy as np
from skimage import color, io

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torchvision
import torchvision.utils
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Dataset
from torchvision.datasets import FashionMNIST

## Hiperparámetros de entrenamiento
`latent_dims` es el tamaño del bottleneck del autoencoder, es decir, la dimensión del vector latente.

In [2]:
latent_dims = 10      # size of the autoencoder bottleneck (latent vector)
num_epochs = 50
batch_size = 128
capacity = 64         # channel hyperparameter, read as `c` inside Encoder/Decoder
learning_rate = 1e-3
use_gpu = True
seed = 0              # fixed seed so weight init and data shuffling are reproducible

# Seed PyTorch's RNGs (CPU and CUDA) before any model is built or any loader shuffles.
torch.manual_seed(seed)

## Fashion MNIST Dataset

60.000 imágenes de entrenamiento (y 10.000 de test) en 10 categorías de ropa: T-shirt/Top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle Boot. Normalizamos las imágenes al rango [-1, 1].

In [4]:
# ToTensor maps pixels to [0, 1]; Normalize then applies (x - 0.5) / 0.5, giving [-1, 1].
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # image = (image - mean) / std
])

data_root = './data/FashionMNIST'

train_dataset = FashionMNIST(root=data_root, download=True, train=True, transform=img_transform)
# Shuffle the training set so each epoch visits the batches in a different order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = FashionMNIST(root=data_root, download=True, train=False, transform=img_transform)
# No shuffling for evaluation: the test metric does not need it, and a fixed order
# makes the batch used for visualisation the same on every run.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Autoencoder Definition

Vamos a usar una arquitectura encoder/decoder convolucional.
En las capas convolucionales incrementamos la cantidad de canales a medida que nos acercamos al bottleneck. Aun así, el número total de features se reduce: la cantidad de canales se duplica, pero el área espacial (alto × ancho) se reduce en un factor de 4, porque cada dimensión se divide por 2 (ver https://distill.pub/2016/deconv-checkerboard/).

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a batch of 1x28x28 images to latent vectors.

    Two stride-2 convolutions halve each spatial dimension twice (28 -> 14 -> 7)
    while doubling the channels (c -> 2c); a linear layer then projects the
    flattened 2c x 7 x 7 feature map onto the bottleneck.

    Args:
        latent_size: size of the latent vector; defaults to the notebook-level
            ``latent_dims`` hyperparameter.
        channels: number of channels of the first conv layer (``c``); defaults to
            the notebook-level ``capacity`` hyperparameter.

    Input shape: (N, 1, 28, 28). Output shape: (N, latent_size).
    """

    def __init__(self, latent_size=None, channels=None):
        super(Encoder, self).__init__()
        # Fall back to the global hyperparameters so `Encoder()` keeps working unchanged.
        c = capacity if channels is None else channels
        latent_size = latent_dims if latent_size is None else latent_size
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1)      # -> c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c * 2, kernel_size=4, stride=2, padding=1)  # -> 2c x 7 x 7
        self.fc = nn.Linear(in_features=c * 2 * 7 * 7, out_features=latent_size)

    def forward(self, x):
        """Encode images ``x`` of shape (N, 1, 28, 28) into (N, latent_size) codes."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(start_dim=1)  # (N, 2c, 7, 7) -> (N, 2c*7*7) for the linear layer
        return self.fc(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring ``Encoder``: latent vector -> 1x28x28 image.

    A linear layer expands the code to a 2c x 7 x 7 feature map; two stride-2
    transposed convolutions then upsample it back to 28 x 28 (7 -> 14 -> 28)
    while halving the channels (2c -> c -> 1).

    Args:
        latent_size: size of the latent vector; defaults to the notebook-level
            ``latent_dims`` hyperparameter.
        channels: channel count ``c`` (must match the encoder's); defaults to the
            notebook-level ``capacity`` hyperparameter.

    Input shape: (N, latent_size). Output shape: (N, 1, 28, 28), values in [-1, 1].
    """

    def __init__(self, latent_size=None, channels=None):
        super(Decoder, self).__init__()
        # Fall back to the global hyperparameters so `Decoder()` keeps working unchanged.
        c = capacity if channels is None else channels
        latent_size = latent_dims if latent_size is None else latent_size
        self.channels = c
        self.fc = nn.Linear(in_features=latent_size, out_features=c * 2 * 7 * 7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c * 2, out_channels=c, kernel_size=4, stride=2, padding=1)  # -> c x 14 x 14
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)      # -> 1 x 28 x 28

    def forward(self, x):
        """Decode codes ``x`` of shape (N, latent_size) into (N, 1, 28, 28) images."""
        x = self.fc(x)
        x = x.view(x.size(0), self.channels * 2, 7, 7)  # un-flatten into feature maps
        x = F.relu(self.conv2(x))
        # tanh keeps outputs in [-1, 1], the same range as the normalized input images.
        return torch.tanh(self.conv1(x))
    
class Autoencoder(nn.Module):
    """Encoder followed by decoder: image batch -> latent code -> reconstruction.

    The two halves are registered as the ``encoder`` and ``decoder`` submodules.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return the reconstruction of ``x`` produced by encoding then decoding it."""
        return self.decoder(self.encoder(x))
    
# Use the first GPU only when it was requested AND one is actually available.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
autoencoder = Autoencoder().to(device)

# Count only the trainable (requires_grad) parameters.
num_params = 0
for param in autoencoder.parameters():
    if param.requires_grad:
        num_params += param.numel()
print('Number of parameters: %d' % num_params)

## Loop de entrenamiento
Para evaluar la calidad de la reconstrucción a cada paso, utilizamos MSE loss https://pytorch.org/docs/stable/nn.html#torch.nn.MSELoss

In [None]:
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=learning_rate, weight_decay=1e-5)

# Put the network in training mode.
autoencoder.train()

# One entry per epoch: the mean of that epoch's per-batch reconstruction losses.
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    epoch_loss, num_batches = 0, 0

    for image_batch, _ in train_dataloader:
        image_batch = image_batch.to(device)

        # Forward pass: reconstruct the batch and measure the reconstruction error.
        image_batch_recon = autoencoder(image_batch)
        loss = F.mse_loss(image_batch_recon, image_batch)

        # Backward pass, then update the weights with the back-propagated gradients.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    train_loss_avg.append(epoch_loss / num_batches)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch + 1, num_epochs, train_loss_avg[-1]))

In [None]:
# Plot against 1-based epoch numbers so the x axis matches the "Epoch [k / N]" training log.
epochs = range(1, len(train_loss_avg) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, train_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Reconstruction error')
ax.set_title('Average training reconstruction error (MSE)')
plt.show()

## Evaluación

In [None]:
# Put the network in evaluation mode.
autoencoder.eval()

# Accumulate each batch's mean loss weighted by its number of images, so the final
# value is the MSE over the whole test set. Averaging per-batch means equally would
# over-weight the smaller last batch.
test_loss_sum, num_samples, num_batches = 0.0, 0, 0
with torch.no_grad():  # no gradients are needed for evaluation
    for image_batch, _ in test_dataloader:
        image_batch = image_batch.to(device)

        # reconstruct the batch with the autoencoder
        image_batch_recon = autoencoder(image_batch)

        # mean reconstruction error over this batch
        loss = F.mse_loss(image_batch_recon, image_batch)

        n_images = image_batch.size(0)
        test_loss_sum += loss.item() * n_images
        num_samples += n_images
        num_batches += 1

test_loss_avg = test_loss_sum / num_samples
print('average reconstruction error: %f' % (test_loss_avg))

## Visualizamos algunas reconstrucciones

In [None]:
autoencoder.train(False)  # evaluation mode before plotting; equivalent to autoencoder.eval()

# Plotting helpers for the images.
def to_img(x):
    """Map a tensor normalized to [-1, 1] back to [0, 1], clamping anything outside."""
    return (0.5 * (x + 1)).clamp(0, 1)

def show_image(img):
    """Draw a (C, H, W) normalized image tensor with matplotlib (caller calls plt.show())."""
    pixels = to_img(img).numpy().transpose(1, 2, 0)  # CHW -> HWC layout expected by imshow
    plt.imshow(pixels)

def visualise_output(images, model, num_images=50):
    """Plot the model's reconstructions of the first ``num_images`` images as a grid.

    Args:
        images: batch of normalized images, as yielded by the dataloaders.
        model: autoencoder mapping an image batch to its reconstruction.
        num_images: how many leading images to reconstruct and show (10 per grid row).
    """
    with torch.no_grad():
        # Only reconstruct the images that will actually be displayed.
        recon = model(images[:num_images].to(device)).cpu()
        recon = to_img(recon)
        np_imagegrid = torchvision.utils.make_grid(recon, nrow=10, padding=5).numpy()
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
        plt.show()

# Number of test images shown, both as originals and as reconstructions (5 rows of 10).
num_shown = 50

# DataLoader iterators have no `.next()` method in Python 3 / current PyTorch;
# use the built-in next() instead.
images, labels = next(iter(test_dataloader))

# First, show the original (ground-truth) images.
print('Original images')
show_image(torchvision.utils.make_grid(images[:num_shown], nrow=10, padding=5))
plt.show()

# Then show the autoencoder's reconstructions of the same images.
print('Autoencoder reconstruction:')
visualise_output(images, autoencoder, num_images=num_shown)