# 1 - Edge Filter
En este primer ejercicio crearemos una red convolucional mínima (una única capa de convolución 3×3, sin bias) para aprender un filtro de bordes a partir de imágenes de CIFAR10 en escala de grises.

In [2]:
# Import de paquetes
%matplotlib inline
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import sys
import os

# Numpy
import numpy as np
from skimage import color, io

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torchvision
import torchvision.utils
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Dataset: CIFAR10 https://www.cs.toronto.edu/~kriz/cifar.html
from torchvision.datasets import CIFAR10

## Hiperparámetros de entrenamiento

In [3]:
# Training hyperparameters
num_epochs = 10        # number of full passes over the training set
batch_size = 512       # samples per batch handed to the DataLoaders
learning_rate = 1e-1   # step size given to the optimizer
use_gpu = True         # train on CUDA when it is available, otherwise on CPU

# Fix the global random seeds so that weight initialisation, batch shuffling
# and the random image selection give the same results after a kernel restart.
random_seed = 0
torch.manual_seed(random_seed)
np.random.seed(random_seed)

## Clase dataset

In [None]:
# Reference ("ground truth") edge detector: a single 3x3 convolution without
# bias whose kernel is set by hand to a Sobel-style horizontal-gradient filter.
# The nested literal already has the (out_channels, in_channels, 3, 3) layout.
gt_edge_filter = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)
gt_edge_filter.weight.data = torch.FloatTensor([[[
    [1, 0, -1],
    [2, 0, -2],
    [1, 0, -1],
]]])


def edge_filter(img):
    """Apply the fixed reference kernel to one image tensor.

    Args:
        img: a single-image tensor; in this notebook it is the (1, H, W)
            output of Grayscale + ToTensor.

    Returns:
        The filtered tensor with the temporary batch dimension removed again.
        It is computed without gradient tracking.
    """
    batched = img.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        filtered = gt_edge_filter(batched)
    return filtered.squeeze(0)

class EdgeImageDataset(Dataset):
    """CIFAR10 wrapper that yields (grayscale image, edge map) pairs.

    The CIFAR10 class labels are discarded; the regression target for each
    image is the output of `edge_filter` applied to its grayscale version.
    """

    def __init__(self, root, train=True):
        """Load (downloading if necessary) CIFAR10 and build the transforms.

        Args:
            root: directory where the CIFAR10 files are stored / downloaded to.
            train: True selects the training split, False the test split.
        """
        self.cifar = CIFAR10(root=root, train=train, transform=None, target_transform=None, download=True)

        # Raw CIFAR10 image -> single-channel tensor (network input)
        self.img_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

        # Raw CIFAR10 image -> edge map (network target). Kept as a public
        # attribute for existing users; __getitem__ produces the same result
        # without repeating the grayscale conversion.
        self.img2edge_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Lambda(edge_filter)
        ])

    def __getitem__(self, index):
        """Return the (input, target) tensor pair for the sample at `index`."""
        img = self.cifar[index][0]  # element 1 is the class label, unused here
        # Convert once and derive the target from the same tensor instead of
        # running Grayscale + ToTensor a second time for the edge map.
        gray = self.img_transform(img)
        return (gray, edge_filter(gray))

    def __len__(self):
        """Number of samples, i.e. the size of the wrapped CIFAR10 split."""
        return len(self.cifar)

# Single location for the CIFAR10 files, shared by both splits
data_root = './data/CIFAR10'

# Training split: reshuffled every epoch
train_dataset = EdgeImageDataset(data_root, train=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: evaluation does not benefit from shuffling, and a fixed order
# keeps the evaluation batches identical from run to run
test_dataset = EdgeImageDataset(data_root, train=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/CIFAR10/cifar-10-python.tar.gz


## Clase CNN

In [None]:
class EdgeNet(nn.Module):
    """Minimal CNN: one 3x3 convolution mapping 1 channel to 1 channel."""

    def __init__(self, d=128):
        """Create the single convolution layer.

        Args:
            d: accepted but not used by this constructor.
        """
        super().__init__()

        # No bias term, so the learned kernel should end up similar to the
        # kernel that was used to generate the target edge maps.
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)

    def forward(self, input):
        """Return the convolution of `input` with the layer's 3x3 kernel."""
        edges = self.conv(input)
        return edges

# Build the network
enet = EdgeNet()

# Choose the training device: the first CUDA GPU when it was requested and is
# available, otherwise the CPU
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
enet = enet.to(device)
print('Using device: %s' % (device))

# Number of trainable parameters in the network
num_params = sum(weight.numel() for weight in enet.parameters() if weight.requires_grad)
print('Number of parameters: %d' % (num_params))

## Loop de entrenamiento

In [None]:
# Adam optimizer over the network parameters
optimizer = torch.optim.Adam(enet.parameters(), lr=learning_rate)

# Put the network in training mode
enet.train()

# One entry per epoch: mean of the per-batch losses
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    # Open a new slot that accumulates this epoch's batch losses
    train_loss_avg.append(0)
    num_batches = 0

    for image_batch, edge_batch in train_dataloader:

        # Move inputs and targets to the training device
        image_batch, edge_batch = image_batch.to(device), edge_batch.to(device)

        # Clear the gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass, then L2 distance between prediction and target edges
        predicted_edge_batch = enet(image_batch)
        loss = F.mse_loss(predicted_edge_batch, edge_batch)

        # Backpropagate and take one optimizer step along the gradients
        loss.backward()
        optimizer.step()

        # Bookkeeping for the loss curve
        train_loss_avg[-1] += loss.item()
        num_batches += 1

    # Turn the accumulated sum into the epoch average
    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average loss: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

## Plot del histórico de losses

In [None]:
# Plot the per-epoch average training loss
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(train_loss_avg)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Si queremos guardar los parámetros de la red
# torch.save(enet.state_dict(), './detector_de_bordes.pth')

## Testing

In [None]:
# Performance on the test set
enet.eval()

# Accumulate the loss weighted by batch size: a plain mean of per-batch means
# would over-weight the final batch whenever the dataset size is not a
# multiple of batch_size.
test_loss_sum, num_samples = 0.0, 0

# No gradients are needed for evaluation
with torch.no_grad():
    for image_batch, edge_batch in test_dataloader:

        image_batch = image_batch.to(device)
        edge_batch = edge_batch.to(device)

        # Predict the edge maps
        predicted_edge_batch = enet(image_batch)

        # L2 distance (mean over the batch)
        loss = F.mse_loss(predicted_edge_batch, edge_batch)

        batch_len = image_batch.size(0)
        test_loss_sum += loss.item() * batch_len
        num_samples += batch_len

test_loss_avg = test_loss_sum / num_samples
print('average loss: %f' % (test_loss_avg))

In [None]:
# Compare some predicted edge maps against the ones from the dataset

def _show_grid(ax, batch, title):
    """Draw a batch of image tensors on `ax` as a grid with 5 images per row."""
    grid = torchvision.utils.make_grid(batch, nrow=5).cpu().numpy()
    # make_grid returns channels-first data; imshow expects channels-last
    ax.imshow(np.transpose(grid, (1, 2, 0)))
    ax.title.set_text(title)


with torch.no_grad():

    # Pick 25 distinct random test samples; fetch each one only once
    image_inds = np.random.choice(len(test_dataset), 25, replace=False)
    samples = [test_dataset[i] for i in image_inds]
    image_batch = torch.stack([sample[0] for sample in samples])
    edge_batch = torch.stack([sample[1] for sample in samples])

    # Predict the edges on the training device and bring the result back to
    # the CPU for plotting. Tensor.cpu() returns a tensor instead of moving
    # in place, so its result has to be kept; image_batch and edge_batch
    # never leave the CPU.
    predicted_edge_batch = enet(image_batch.to(device)).cpu()

    # Map the signed filter responses to edge magnitudes in [0, 1]:
    # squaring drops the sign, clamp cuts everything above 1
    predicted_edge_batch = ((predicted_edge_batch/4.0).pow(2)*4.0).clamp(min=0, max=1)
    edge_batch           = ((edge_batch/4.0).pow(2)*4.0).clamp(min=0, max=1)

    # detach() gives a tensor that is not tracked by autograd, which is
    # required before converting a parameter to a numpy array
    # http://www.bnikolic.co.uk/blog/pytorch-detach.html
    print('learned edge filter')
    print(enet.conv.weight.cpu().detach().numpy())

    print('ground truth edge filter')
    print(gt_edge_filter.weight.cpu().detach().numpy())

    # Plot the input images
    fig, ax = plt.subplots(figsize=(7, 7), nrows=1, ncols=1)
    _show_grid(ax, image_batch, 'images')

    # Plot predicted and reference edges side by side
    fig, ax = plt.subplots(figsize=(15, 15), nrows=1, ncols=2)
    _show_grid(ax[0], predicted_edge_batch, 'predicted edges')
    _show_grid(ax[1], edge_batch, 'ground truth edges')
    plt.show()