In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

import torch
print(torch.__version__)

# Dataset y Dataloaders

Usaremos la base de datos MNIST de dígitos manuscritos.

Corresponden a imágenes de 28x28 píxeles en escala de grises.

Esta base de datos viene incluida en el [módulo datasets de la librería torchvision](https://pytorch.org/vision/stable/datasets.html)

In [None]:
from torchvision import datasets, transforms

# MNIST train and test splits, both downloaded to (and read from) the same
# directory. ToTensor converts every image into a float tensor.
mnist_root = '~/datasets/'
mnist_train_data = datasets.MNIST(root=mnist_root, train=True, download=True,
                                  transform=transforms.ToTensor())
mnist_test_data = datasets.MNIST(root=mnist_root, train=False, download=True,
                                 transform=transforms.ToTensor())

# Quick look at the data: training-set size and the types of one sample.
image, label = mnist_train_data[0]
display(len(mnist_train_data), type(image), image.dtype, type(label))

# Ten random training digits, each titled with its label.
fig, ax = plt.subplots(1, 10, figsize=(8, 1.5), tight_layout=True)
random_indices = np.random.permutation(len(mnist_train_data))[:10]
for axis, sample_index in zip(ax, random_indices):
    image, label = mnist_train_data[sample_index]
    axis.matshow(image[0].numpy(), cmap=plt.cm.Greys_r)
    axis.axis('off')
    axis.set_title(label)
    
from torch.utils.data import Subset, DataLoader
import sklearn.model_selection

# Stratified train/validation split (75% / 25%) of the MNIST training set, so
# both subsets keep the class proportions. A fixed random_state makes the split,
# and therefore the reported validation loss, reproducible across runs.
splitter = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, train_size=0.75,
                                                          random_state=1234)
train_idx, valid_idx = next(splitter.split(mnist_train_data.data,
                                           mnist_train_data.targets))

# Training loader: shuffled small minibatches for the optimiser.
train_dataset = Subset(mnist_train_data, train_idx)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)

# Validation loader: fixed order, larger batches (evaluation only).
valid_dataset = Subset(mnist_train_data, valid_idx)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=256)

# Autoencoder 

- [Revisar material teórico entre las láminas 96 y 102](https://docs.google.com/presentation/d/1IJ2n8X4w8pvzNLmpJB-ms6-GDHWthfsJTFuyUqHfXg8/present#slide=id.g3d5022dff0_1_100)
- [Revisar tutorial de Pytorch](https://github.com/phuijse/INFO257/blob/master/notebooks/clases/1_pytorch_tutorial.ipynb) y [clase redes convolucionales](https://github.com/phuijse/INFO257/blob/master/notebooks/clases/4_red_convolucional.ipynb)


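Completemos el modelo de autoencoder que se propone a continuación. El modelo recibe imágenes aplanadas de 28x28 = 784 valores. El codificador (`encode`) produce un código de 2 dimensiones, que luego graficaremos para visualizar el espacio latente. El decodificador (`decode`) entrega 784 logits sin aplicar sigmoide, ya que el entrenamiento usa `BCEWithLogitsLoss`.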

In [None]:
from torch import nn

class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 MNIST images.

    The encoder maps each flattened image to a low-dimensional code; the
    decoder maps the code back to one value per pixel. The decoder returns
    logits (no final sigmoid): training uses ``BCEWithLogitsLoss`` and the
    visualisation cells apply the sigmoid themselves.

    Args:
        n_input: number of input features (28*28 for flattened MNIST).
        n_hidden: width of the hidden layer in both encoder and decoder
            (a design choice; tune freely).
        n_latent: size of the code. Defaults to 2 so the latent space can be
            plotted directly as a 2-D scatter.
    """

    def __init__(self, n_input=28*28, n_hidden=128, n_latent=2):
        # Zero-argument super() also works for subclasses; super(type(self), self)
        # would recurse forever when called from a subclass instance.
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_input, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_latent),
        )
        self.decoder = nn.Sequential(
            nn.Linear(n_latent, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_input),  # logits, one per pixel
        )

    def encode(self, x):
        """Map flattened images ``(batch, n_input)`` to codes ``(batch, n_latent)``."""
        return self.encoder(x)

    def decode(self, z):
        """Map codes ``(batch, n_latent)`` to reconstruction logits ``(batch, n_input)``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode and then decode ``x``; returns reconstruction logits."""
        z = self.encode(x)
        return self.decode(z)
    
model = Autoencoder()
display(model)

# Sanity check of the (untrained) model on one training image. The model
# outputs logits, so apply a sigmoid (as the later visualisation cells do) to
# show the reconstruction in the same range as the input. no_grad: display only.
x = mnist_train_data[0][0]
with torch.no_grad():
    xhat = torch.sigmoid(model(x.view(1, 28*28))).view(1, 28, 28).numpy()

fig, ax = plt.subplots(1, 2, figsize=(6, 2), tight_layout=True)
ax[0].matshow(x[0].numpy(), cmap=plt.cm.Greys_r)
ax[0].set_title('Original'); ax[0].axis('off')
ax[1].matshow(xhat[0], cmap=plt.cm.Greys_r)
ax[1].set_title('Reconstruida'); ax[1].axis('off')

# Entrenamiento

- Para actualizar los parámetros usaremos el optimizador Adam
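- Transformamos las imágenes de MNIST al rango [0, 1] y usamos la entropía cruzada binaria como función de costo. Interpretamos la salida del modelo como logits, es decir, logaritmos de la razón de probabilidades (*log-odds*) y no logaritmos de probabilidades. Por eso usamos `BCEWithLogitsLoss`, que aplica la sigmoide internamente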
- Utilizaremos la librería de alto nivel [`ignite`](https://pytorch.org/ignite/) para entrenar y el dashboard `tensorboard` para visualizar los entrenamientos

No olvides lanzar el dashboard antes de entrenar:

    tensorboard --logdir /tmp/tensorboard

In [None]:
torch.manual_seed(1234)  # for reproducibility
max_epochs = 100
device = torch.device('cpu')
# device = torch.device('cuda:0')  # alternative: train on the first GPU
# Move the parameters to `device` before building the optimizer. Otherwise,
# switching `device` to a GPU breaks training: the train/evaluate steps move
# every minibatch to `device` while the model would stay on the CPU.
model = Autoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')  # model outputs logits

def train_one_step(engine, batch):
    """Ignite training step: one optimisation update on a single minibatch.

    The autoencoder is trained to reproduce its input, so the labels in
    ``batch`` are moved to ``device`` but never used.

    Args:
        engine: the ignite ``Engine`` running this step (unused).
        batch: ``(images, labels)`` tuple produced by the DataLoader.

    Returns:
        The minibatch reconstruction loss as a Python float.
    """
    model.train()
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    flat_images = images.view(-1, 28*28)  # one 784-vector per image

    optimizer.zero_grad()  # clear gradients left over from the previous step
    reconstruction = model.forward(flat_images)
    # The target of the reconstruction loss is the input itself.
    reconstruction_loss = criterion(reconstruction, flat_images)
    reconstruction_loss.backward()
    optimizer.step()
    return reconstruction_loss.item()

def evaluate_one_step(engine, batch):
    """Ignite evaluation step: reconstruct one minibatch without tracking gradients.

    Args:
        engine: the ignite ``Engine`` running this step (unused).
        batch: ``(images, labels)`` tuple from a DataLoader; the labels are
            ignored because the reconstruction target is the input itself.

    Returns:
        ``(xhat, x)``: reconstruction logits and the flattened input images,
        i.e. the ``(y_pred, y)`` pair consumed by ignite metrics such as ``Loss``.
    """
    model.eval()
    x, _ = batch
    x = x.to(device).view(-1, 28*28)  # flatten images to match the model input
    with torch.no_grad():
        xhat = model(x)
    # No loss is computed here: the attached ``Loss`` metric computes it from
    # the returned pair, so doing it here as well would be wasted work.
    return xhat, x
    
from ignite.engine import Engine, Events
from ignite.metrics import Loss, Accuracy

# `trainer` performs one optimisation step per minibatch; `evaluator` only
# reconstructs minibatches so the attached metrics accumulate over a loader.
trainer = Engine(train_one_step)
evaluator = Engine(evaluate_one_step)
metrics = {'Loss': Loss(criterion)}
for metric_name, metric in metrics.items():
    metric.attach(evaluator, metric_name)

In [None]:
import time
from torch.utils.tensorboard import SummaryWriter
from ignite.handlers import ModelCheckpoint

# Tensorboard writer: one sub-directory per run so runs can be compared.
with SummaryWriter(log_dir=f'/tmp/tensorboard/run{time.time_ns()}') as writer:

    # Keep only the checkpoint with the lowest validation loss. The score is
    # negated because ModelCheckpoint keeps the highest scores.
    best_model_handler = ModelCheckpoint(dirname='.', require_empty=False, filename_prefix="best", n_saved=1,
                                         score_function=lambda engine: -engine.state.metrics['Loss'],
                                         score_name="val_loss")

    # NOTE: handlers accumulate on `trainer`; re-run the cell that creates
    # `trainer`/`evaluator` before re-running this cell.
    @trainer.on(Events.EPOCH_COMPLETED(every=1))  # after every epoch
    def log_results(engine):
        evaluator.run(train_loader)  # loss over the training set
        writer.add_scalar("train/loss", evaluator.state.metrics['Loss'], engine.state.epoch)

        evaluator.run(valid_loader)  # loss over the validation set
        writer.add_scalar("valid/loss", evaluator.state.metrics['Loss'], engine.state.epoch)
        # Checkpoint right after the validation pass only. Attaching the handler
        # to evaluator's COMPLETED event would also fire after the training-set
        # pass and rank checkpoints by training loss despite the "val_loss" name.
        best_model_handler(evaluator, {'ae': model})

        print(f"Epoca: {engine.state.epoch} Valid loss: {evaluator.state.metrics['Loss']:.4f}")

    trainer.run(train_loader, max_epochs=max_epochs)

# Inspección de resultados

Primero recuperamos el mejor modelo

In [None]:
import glob
import os

# ModelCheckpoint embeds the score in the file name (e.g.
# "best_ae_val_loss=-0.1856.pt"), so the exact name changes from run to run.
# Load the most recently written "best" checkpoint instead of hard-coding one.
checkpoint_paths = glob.glob('best_ae_val_loss=*.pt')
if not checkpoint_paths:
    raise FileNotFoundError("No 'best_ae_val_loss=*.pt' checkpoint found; run the training cell first.")
best_checkpoint = max(checkpoint_paths, key=os.path.getmtime)

model = Autoencoder()
model.eval()  # inspection only
# map_location='cpu': the inspection cells run the model on CPU tensors, even
# if the checkpoint was saved while training on a GPU.
model.load_state_dict(torch.load(best_checkpoint, map_location='cpu'))

## Visualización de las reconstrucciones

In [None]:
# Random test digits (top row) and their reconstructions (bottom row).
n_examples = 10
fig, axs = plt.subplots(2, n_examples, figsize=(8, 3))
P = np.random.permutation(len(mnist_test_data))

model.eval()
with torch.no_grad():  # display only, no autograd graph needed
    for i in range(n_examples):
        image, label = mnist_test_data[P[i]]
        # The model outputs logits; the sigmoid maps them to pixel intensities.
        hat_image = torch.sigmoid(model(image.view(1, 28*28))).view(1, 28, 28)
        axs[0, i].matshow(image.numpy()[0], cmap=plt.cm.Greys_r)
        axs[0, i].axis('off')
        axs[0, i].set_title(label)
        axs[1, i].matshow(hat_image.numpy()[0], cmap=plt.cm.Greys_r)
        axs[1, i].axis('off')
plt.tight_layout();

## Visualización del espacio latente

In [None]:
fig = plt.figure(figsize=(10, 4), dpi=80)
ax_main = plt.subplot2grid((2, 3), (0, 0), colspan=2, rowspan=2)
ax_ori = plt.subplot2grid((2, 3), (0, 2))
ax_rec = plt.subplot2grid((2, 3), (1, 2))
ax_ori.axis('off'); ax_rec.axis('off');

# Test-set loader. shuffle=False is required: the codes below are matched to
# `test_targets` and to `mnist_test_data` by position.
test_loader = DataLoader(mnist_test_data, shuffle=False, batch_size=256)
test_targets = mnist_test_data.targets.numpy()

model.eval()
with torch.no_grad():
    code_targets = torch.cat([model.encode(x.view(-1, 28*28)) for x, _ in test_loader]).numpy()

# Scatter of the first two code dimensions, one colour per digit.
for digit in range(10):
    mask = test_targets == digit
    ax_main.scatter(code_targets[mask, 0], code_targets[mask, 1],
                    s=5, alpha=0.5, label=str(digit))
ax_main.legend();

def onclick(event):
    """Show the test image whose code is closest to the clicked point, plus its reconstruction."""
    # Clicks outside the scatter have xdata/ydata = None (or coordinates of the
    # image axes), so ignore them.
    if event.inaxes is not ax_main:
        return
    clicked = np.array([event.xdata, event.ydata])
    # Compare in the same two dimensions that are plotted.
    idx = int(np.argmin(np.sum((code_targets[:, :2] - clicked)**2, axis=1)))
    image, label = mnist_test_data[idx]
    with torch.no_grad():
        hat_image = torch.sigmoid(model(image.view(1, 28*28))).view(1, 28, 28)
    # Clear the previous images instead of stacking a new one on every click.
    for axis, img in ((ax_ori, image), (ax_rec, hat_image)):
        axis.clear()
        axis.matshow(img.numpy()[0], cmap=plt.cm.Greys_r)
        axis.axis('off')
    ax_ori.set_title(label)
    fig.canvas.draw_idle()  # request a redraw of the interactive figure

cid = fig.canvas.mpl_connect('button_press_event', onclick)