# Variational Autoencoder (VAE)
Autoencoder para FashionMNIST

In [4]:
# Import de paquetes
%matplotlib inline
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import sys
import os

# Numpy
import numpy as np
from skimage import color, io

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torchvision
import torchvision.utils
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Dataset
from torchvision.datasets import FashionMNIST

## Hiperparámetros de entrenamiento
latent_dims es el tamaño del bottleneck del autoencoder

In [11]:
# Size of the latent bottleneck (width of the encoder's mu / logvar heads).
latent_dims = 2
# Number of passes over the training set.
num_epochs = 100
# Mini-batch size used by the train and test DataLoaders.
batch_size = 128
# Base number of convolutional channels (the second conv layer uses 2x this).
capacity = 64
# Step size for the Adam optimizer.
learning_rate = 1e-3
# Weight of the KL-divergence term in the VAE loss.
variational_beta = 1
# Run on CUDA when it is available; otherwise the notebook falls back to CPU.
use_gpu = True

# Seed PyTorch's global RNG so that weight initialisation, DataLoader
# shuffling and the reparameterisation noise are repeatable across
# "Restart & Run All" (best effort: some GPU kernels are non-deterministic).
random_seed = 0
torch.manual_seed(random_seed)

## Fashion MNIST Dataset

60.000 imágenes en 10 categorías de ropa: Top/T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle Boot. Normalizamos el dataset.

In [12]:
# Convert PIL images to float tensors with pixel values in [0, 1].
# The pixels are deliberately NOT shifted to [-1, 1] with Normalize(0.5, 0.5):
# the decoder ends in a sigmoid and the loss is a binary cross-entropy, so the
# reconstruction targets must lie in [0, 1]. The plotting helpers also clamp
# images to [0, 1] before display.
img_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training split (downloaded to ./data/FashionMNIST on first use), reshuffled every epoch.
train_dataset = FashionMNIST(root='./data/FashionMNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split. Shuffling is kept so that the visualisation cell, which takes the
# first batch of this loader, draws a random sample of images.
test_dataset = FashionMNIST(root='./data/FashionMNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

## VAE Definition

Vamos a usar una arquitectura encoder/decoder convolucional.
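En los layers convolucionales, incrementamos la cantidad de canales a medida que nos acercamos al bottleneck. A pesar de eso, el número total de features se reduce entre un layer convolucional y el siguiente, ya que la cantidad de canales se incrementa en un factor de 2, pero el área espacial se reduce en un factor de 4 (stride 2 en cada dimensión). Usamos kernels de tamaño 4, divisible por el stride, para evitar los artefactos de tipo tablero de ajedrez descritos en https://distill.pub/2016/deconv-checkerboard/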

In [13]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the parameters of the latent Gaussian.

    Two stride-2 convolutions are followed by two parallel linear heads, one
    for the mean and one for the log-variance of the latent distribution.
    Layer sizes are taken from the module-level ``capacity`` and
    ``latent_dims`` settings when the encoder is constructed.
    """

    def __init__(self):
        super().__init__()
        base = capacity
        flat_features = base * 2 * 7 * 7
        # 1 -> base channels; a 28 x 28 input becomes base x 14 x 14
        self.conv1 = nn.Conv2d(1, base, kernel_size=4, stride=2, padding=1)
        # base -> 2*base channels; output is 2*base x 7 x 7
        self.conv2 = nn.Conv2d(base, base * 2, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(flat_features, latent_dims)
        self.fc_logvar = nn.Linear(flat_features, latent_dims)

    def forward(self, x):
        """Encode a batch of images and return the tuple ``(mu, logvar)``."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        # Flatten to one feature vector per sample
        hidden = hidden.view(len(hidden), -1)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to images.

    A linear layer expands the latent vector to a (2*capacity) x 7 x 7 feature
    map, which two stride-2 transposed convolutions upsample to a 1 x 28 x 28
    image. Layer sizes are taken from the module-level ``capacity`` and
    ``latent_dims`` settings when the decoder is constructed.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode latent vectors of shape (batch, latent_dims) into images with values in (0, 1)."""
        x = self.fc(x)
        # Reshape to a 4D feature map. The channel count comes from the layer
        # itself, not from the global `capacity`, so a model that was already
        # built keeps working if the hyperparameter cell is re-run with a
        # different value.
        x = x.view(x.size(0), self.conv2.in_channels, 7, 7)
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv1(x))  # sigmoid output, as required by the BCE loss
        return x
    
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encoder(x)
        z = self.latent_sample(mu, logvar)
        return self.decoder(z), mu, logvar

    def latent_sample(self, mu, logvar):
        """Draw a latent code from N(mu, exp(logvar)).

        In evaluation mode the mean is returned unchanged, so decoding is
        deterministic. In training mode the reparameterisation trick is used:
        the sample is ``mu + eps * std`` with ``eps`` drawn from a standard
        normal, which keeps the sample differentiable w.r.t. mu and logvar.
        """
        if not self.training:
            return mu
        sigma = torch.exp(0.5 * logvar)  # std = exp(logvar / 2)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma
    
# Loss definition
def vae_loss(recon_x, x, mu, logvar, beta=None):
    """Return the VAE loss: summed reconstruction BCE plus ``beta`` times the KL term.

    Args:
        recon_x: reconstructed images with values in [0, 1] (decoder output).
        x: target images with values in [0, 1], same shape as ``recon_x``.
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.
        beta: weight of the KL term. When ``None`` (the default) the
            module-level ``variational_beta`` is used.

    Returns:
        A scalar tensor; both terms are summed over the whole batch.
    """
    if beta is None:
        beta = variational_beta

    # Whether the binary cross-entropy is averaged over the pixels is a subtle
    # detail with a strong impact on training: it changes the magnitude of the
    # weights needed for the other loss terms by several orders of magnitude.
    # Summing is a direct implementation of the negative log-likelihood;
    # averaging would make the other weights independent of image resolution.
    # Flattening per sample (instead of hard-coding 784) keeps the function
    # usable for any image size; with reduction='sum' the value is unchanged.
    recon_loss = F.binary_cross_entropy(
        torch.flatten(recon_x, start_dim=1),
        torch.flatten(x, start_dim=1),
        reduction='sum',
    )

    # KL divergence between N(mu, exp(logvar)) and the standard normal prior
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + beta * kldivergence
    
    
# Select the compute device: the first CUDA GPU when requested and available,
# otherwise the CPU.
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the model and move its parameters to the selected device.
vae = VariationalAutoencoder().to(device)

# Count the trainable parameters.
num_params = 0
for param in vae.parameters():
    if param.requires_grad:
        num_params += param.numel()
print('Number of parameters: %d' % num_params)

Number of parameters: 308357


## Loop de entrenamiento
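Para monitorear el entrenamiento, en cada época reportamos el promedio por batch de la loss del VAE definida arriba: la entropía cruzada binaria de reconstrucción (https://pytorch.org/docs/stable/generated/torch.nn.functional.binary_cross_entropy.html) más la divergencia KL ponderada por `variational_beta`. No se utiliza una MSE loss.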

In [None]:
optimizer = torch.optim.Adam(params=vae.parameters(), lr=learning_rate, weight_decay=1e-5)

# Put the network in training mode (latent codes are sampled, not just the mean).
vae.train()

# One entry per *completed* epoch: the mean over batches of the VAE loss.
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    # Accumulate in a separate variable, not in train_loss_avg[-1]: if the cell
    # is interrupted mid-epoch, the history then holds finished epochs only,
    # never a raw un-averaged partial sum that would distort the loss plot.
    epoch_loss_sum = 0.0
    num_batches = 0

    for image_batch, _ in train_dataloader:

        image_batch = image_batch.to(device)

        # Forward pass: reconstruction plus the posterior parameters
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # Full VAE loss (reconstruction term + weighted KL term)
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        # Update the weights using the gradients computed by backprop
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    train_loss_avg.append(epoch_loss_sum / num_batches)
    # NOTE: the reported value is the full VAE loss summed over each batch and
    # averaged over batches, not only the reconstruction term.
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

Training ...


In [None]:
# Plot the per-epoch training loss history on a new figure.
fig = plt.figure()
fig.gca().plot(train_loss_avg)
fig.gca().set_xlabel('Epochs')
fig.gca().set_ylabel('Loss')
plt.show()

## Evaluación

In [None]:
# Switch the network to evaluation mode (latent codes are the posterior means).
vae.eval()

test_loss_avg, num_batches = 0, 0

# No gradients are needed for evaluation, so the whole loop runs under no_grad.
with torch.no_grad():
    for image_batch, _ in test_dataloader:
        image_batch = image_batch.to(device)

        # Forward pass through the VAE
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # Full VAE loss for this batch
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        test_loss_avg += loss.item()
        num_batches += 1

# Mean of the per-batch losses
test_loss_avg /= num_batches
print('average reconstruction error: %f' % (test_loss_avg))

## Visualizamos algunas reconstrucciones

In [None]:
# Make sure the network is in evaluation mode before visualising reconstructions.
vae.eval()

# Image plotting helpers
def to_img(x):
    """Return a copy of ``x`` with every element clamped to the range [0, 1]."""
    return torch.clamp(x, min=0, max=1)

def show_image(img):
    """Draw a channels-first image tensor with matplotlib.

    The tensor is clamped to [0, 1] and its axes are reordered from
    (channels, height, width) to (height, width, channels) for ``imshow``.
    """
    clipped = to_img(img)
    # imshow expects the channel axis last
    plt.imshow(clipped.numpy().transpose(1, 2, 0))

def visualise_output(images, model):
    """Reconstruct ``images`` with ``model`` and show the results as an image grid.

    ``model`` must return a 3-tuple whose first element is the reconstruction.
    Images 1..49 of the batch are shown, 10 per row with 5 pixels of padding.
    """
    with torch.no_grad():
        # Run the model on the configured device, then bring the result back to the CPU
        recon, _, _ = model(images.to(device))
        recon = to_img(recon.cpu())
        grid = torchvision.utils.make_grid(recon[1:50], 10, 5)
        # imshow expects the channel axis last
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()

# Take one batch from the test loader. The built-in next() is used because
# DataLoader iterators follow the Python 3 iterator protocol; the old
# `.next()` method raises AttributeError on current PyTorch releases.
images, labels = next(iter(test_dataloader))

# First, show the original (ground-truth) images
print('Original images')
show_image(torchvision.utils.make_grid(images[1:50],10,5))
plt.show()

# Then show the autoencoder's reconstructions of the same images
print('VAE reconstruction:')
visualise_output(images, vae)