# Denoising Autoencoder
_(Requires Python 3, PyTorch 1.0.1, TorchVision 0.2.2)_

### Libraries
Import torch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.utils.data as Data

We'll also need NumPy and Matplotlib.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Model
The model consists of an encoder block and a decoder block. The main difference from a plain autoencoder is that noise is added to the input data before encoding. The cost function is the binary cross-entropy between the noise-free input and the reconstructed output.

### Encoder

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder that maps an input vector to a latent code.

    The network is a stack of ``nn.Linear`` layers: one per entry of
    ``hidden_dims`` followed by a final linear layer of width ``lat_dim``.
    Hidden layers are followed by their activation; the latent layer is
    left linear (identity activation).

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of the hidden layer(s); a single int or a
            sequence of ints (list, tuple or NumPy array).
        lat_dim: Number of latent units.
        activations: A callable applied after every hidden layer, or a
            list/tuple holding one callable per hidden layer.

    Raises:
        ValueError: If ``activations`` is a sequence whose length differs
            from the number of hidden layers.
    """

    def __init__(self, input_dim, hidden_dims, lat_dim, activations=func.relu):
        super(Encoder, self).__init__()

        # Parse input arguments
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]
        # Plain Python ints, whatever sequence type the caller used
        hidden_dims = [int(d) for d in hidden_dims]

        if isinstance(activations, (list, tuple)):
            if len(activations) != len(hidden_dims):
                raise ValueError('activations and hidden_dims must have the same dimensions')
            # Work on a copy: the identity activation appended below must
            # not leak into the caller's list.
            activations = list(activations)
        else:
            activations = [activations] * len(hidden_dims)

        # Store arguments
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations

        # Create layers
        self.layers = nn.ModuleList()
        prev_d = input_dim
        for d in hidden_dims:
            self.layers.append(nn.Linear(prev_d, d))
            prev_d = d

        # Latent layer (linear output, hence the identity activation)
        self.layers.append(nn.Linear(prev_d, lat_dim))
        self.activations.append(lambda x: x)

    def forward(self, x):
        """Return the latent code obtained by passing ``x`` through all layers."""
        for layer, activation in zip(self.layers, self.activations):
            x = activation(layer(x))

        return x



### Decoder

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder that maps a latent code back to data space.

    The network is a stack of ``nn.Linear`` layers: one per entry of
    ``hidden_dims`` followed by an output layer of width ``output_dim``
    whose activation is a sigmoid, so outputs lie in [0, 1].

    Args:
        lat_dim: Number of latent units (input size of the decoder).
        hidden_dims: Width of the hidden layer(s); a single int or a
            sequence of ints (list, tuple or NumPy array).
        output_dim: Number of output features.
        activations: A callable applied after every hidden layer, or a
            list/tuple holding one callable per hidden layer.

    Raises:
        ValueError: If ``activations`` is a sequence whose length differs
            from the number of hidden layers.
    """

    def __init__(self, lat_dim, hidden_dims, output_dim, activations=func.relu):
        super(Decoder, self).__init__()

        # Parse input arguments
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]
        # Plain Python ints, whatever sequence type the caller used
        hidden_dims = [int(d) for d in hidden_dims]

        if isinstance(activations, (list, tuple)):
            if len(activations) != len(hidden_dims):
                raise ValueError('activations and hidden_dims must have the same dimensions')
            # Work on a copy: the sigmoid appended below must not leak
            # into the caller's list.
            activations = list(activations)
        else:
            activations = [activations] * len(hidden_dims)

        # Store arguments
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations
        # Output activation
        self.activations.append(torch.sigmoid)

        # Create layers
        self.layers = nn.ModuleList()
        prev_d = lat_dim
        for d in hidden_dims:
            self.layers.append(nn.Linear(prev_d, d))
            prev_d = d

        # Output layer
        self.layers.append(nn.Linear(prev_d, output_dim))

    def forward(self, x):
        """Return the reconstruction obtained by passing ``x`` through all layers."""
        for layer, activation in zip(self.layers, self.activations):
            x = activation(layer(x))

        return x

### Denoising autoencoder

In [None]:
class DenoisingAutoencoder(nn.Module):
    """Autoencoder that corrupts its input with noise before encoding it.

    The noise is applied here, in the wrapper, rather than inside the
    encoder, so ``self.encoder`` can be called directly on clean data.
    The decoder mirrors the encoder: it uses the hidden widths in reverse
    order and maps back to ``input_dim`` features.

    Args:
        input_dim: Number of input (and output) features.
        hidden_dims: Width of the encoder hidden layer(s); a single int or
            a sequence of ints.
        lat_dim: Number of latent units.
        activations: A callable applied after every hidden layer, or a
            list/tuple with one callable per hidden layer. A sequence is
            handed, in the given order, to both encoder and decoder.
        noise: Kind of corruption; only ``'salt-pepper'`` is supported.
        noise_p: Probability with which each input element is corrupted.

    Raises:
        ValueError: If ``noise`` is not a supported option.
    """

    def __init__(self, input_dim, hidden_dims, lat_dim, activations=func.relu, noise='salt-pepper', noise_p=0.3):
        super(DenoisingAutoencoder, self).__init__()

        # Parse arguments
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]

        if noise not in ['salt-pepper',]:
            raise ValueError('noise options: salt-pepper')

        # Store arguments
        self.noise = noise
        self.noise_p = noise_p

        # Encoder and decoder each get their own activation list, so that
        # one extending its list cannot break the other's length check.
        if isinstance(activations, (list, tuple)):
            enc_activations = list(activations)
            dec_activations = list(activations)
        else:
            enc_activations = dec_activations = activations

        # Build encoder and decoder
        h_dim = np.array(hidden_dims)
        self.encoder = Encoder(input_dim, h_dim, lat_dim, enc_activations)
        self.decoder = Decoder(lat_dim, np.flip(h_dim), input_dim, dec_activations)

    def forward(self, x):
        """Corrupt ``x``, encode it and return the decoder's reconstruction."""
        x_n = self.apply_noise(x)
        lat = self.encoder(x_n)
        out = self.decoder(lat)

        return out

    def apply_noise(self, x):
        """Return a corrupted copy of ``x``; ``x`` itself is left untouched.

        For salt-and-pepper noise each element is replaced, with
        probability ``noise_p``, by 0 or 1 (each with probability 0.5).
        The random draws use torch's generator on the device of ``x``,
        so seeding is controlled by ``torch.manual_seed``.
        """
        if self.noise == 'salt-pepper':
            # Elements selected for corruption
            corrupt = torch.rand(x.shape, device=x.device) < self.noise_p
            # Replacement values: 1 (salt) or 0 (pepper)
            salt = (torch.rand(x.shape, device=x.device) < 0.5).to(x.dtype)
            return torch.where(corrupt, salt, x)

        return x.clone()
            

## Training example

### Training Data
Let's train the denoising autoencoder on the MNIST dataset.

In [None]:
import torchvision
# Training split of MNIST, downloaded into the current directory if missing
train = torchvision.datasets.MNIST(
    root='./',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Held-out test split, loaded with the same transform
test = torchvision.datasets.MNIST(
    root='./',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

### Instantiate the Autoencoder

In [None]:
# Network geometry: 28*28 input features, three hidden layers, 2 latent units
input_dim = 28 * 28
layers = [200, 100, 30]
latent_dim = 2

# Noise level handed to the denoising autoencoder
noise_p = 0.3

autoencoder = DenoisingAutoencoder(
    input_dim,
    layers,
    latent_dim,
    noise_p=noise_p,
)

### Optimizer and Loss function

In [None]:
# Adam over all trainable parameters of the autoencoder
optimizer = torch.optim.Adam(
    autoencoder.parameters(),
    lr=0.0005,
)

In [None]:
# Binary cross-entropy between the reconstruction and its target
loss_func = nn.functional.binary_cross_entropy

### Training loop

In [None]:
%matplotlib notebook

n_test_img = 4
epochs     = 10
batch_size = 100

# Reshape data
train_samples = train.data.view(-1, 28*28).type(torch.float32)/255.0
test_samples  = test.data.view(-1,28*28).type(torch.float32)/255.0

# Lists to store training losses
train_loss = []
test_loss  = []

# Plot test input images
test_imgs = test.data[0:n_test_img,:].type(torch.float32).view(-1,28*28)/255.0
f, a = plt.subplots(2, n_test_img, figsize=(8, 3))
for i in range(n_test_img):
    a[0][i].imshow(255-np.reshape(test_imgs.data.numpy()[i], (28,28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())
    
loss_text = f.text(0, 0, "epoch: 0, loss: 0")

# Data iterator
train_batches = Data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True)

for e in np.arange(epochs):
    batch_loss = 0
    for batch_no, (batch, batch_labels) in enumerate(train_batches):
        # Input and target data (flatten)
        b_in = batch.view(-1, 28*28)
        target = batch.view(-1, 28*28)
        # Forward pass of the data through the network
        out  = autoencoder(b_in)
        # Compute the Loss
        loss = loss_func(out, target)
        batch_loss += loss
        # Reset the gradients
        optimizer.zero_grad()
        loss.backward()
        # Update the gradients
        optimizer.step()
        
        # Test images
        if batch_no % 50 == 0:
            test_out = autoencoder(test_imgs)
            for i in range(n_test_img):
                a[1][i].imshow(1.0-np.reshape(test_out.data.numpy()[i], (28,28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            loss_text.set_text("epoch: {}, loss: {:.3f}".format(e+1, loss))
            f.canvas.draw()
            
    # Compute batch loss
    train_loss.append(batch_loss/batch_no)
    test_loss.append(func.binary_cross_entropy(autoencoder(test_samples), test_samples))

### Loss plots

In [None]:
# The loss lists may hold 0-dim tensors that still require grad, which
# cannot be converted to NumPy for plotting; turn them into floats first.
plt.figure()
plt.plot([float(v) for v in train_loss])
plt.plot([float(v) for v in test_loss])
plt.legend(["train", "test"])

## Results
### Loss on test dataset
We can easily compute the Loss by evaluating the loss function on the output data:

#### Train Loss

In [None]:
# float() prints the bare number even when the stored loss is a 0-dim tensor
print("Train loss: {}".format(float(train_loss[-1])))

#### Test Loss

In [None]:
# float() prints the bare number even when the stored loss is a 0-dim tensor
print("Test loss: {}".format(float(test_loss[-1])))

### Latent space
The projections of the test images in the latent space after training look like this:

In [None]:
# Encode the clean test samples; gradients are not needed for plotting, so
# skip autograd tracking for this pass over the whole test set.
with torch.no_grad():
    test_lat = autoencoder.encoder(test_samples).numpy()
# Scatter the first two latent coordinates, coloured by the test targets
plt.figure(figsize=(10,7))
plt.scatter(test_lat[:,0], test_lat[:,1], c=test.targets.numpy(), s=1.7)