# Denoising Auto Encoder

In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import utils
import numpy as np
from matplotlib import pyplot as plt
import wandb
# Start a Weights & Biases run in this project; the training loop later logs its loss to this run.
wandb.init(project='denoising_auto_encoder_class')

[34m[1mwandb[0m: Currently logged in as: [33mingambe[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


We can also use an autoencoder to reconstruct a partially corrupted image.

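The model maps the noisy input image to the latent space and back to the input space, removing the noise along the way. Because it is trained to reproduce the clean image from a corrupted copy, it cannot simply learn the identity mapping.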

<center>
    <img src='images/15_denoising_ae.png' width=55% style="margin-left:auto; margin-right:auto"/>
    <p style="font-size:14px;">Source: <a href='https://atcold.github.io/pytorch-Deep-Learning/en/week07/07-3/'>NYU Deep Learning</a></p>
</center>

Below is a plot of the distance travelled by each input point in a denoising autoencoder:

<center>
    <img src='images/17_distance.png' width=55% style="margin-left:auto; margin-right:auto"/>
    <p style="font-size:14px;">Source: <a href='https://atcold.github.io/pytorch-Deep-Learning/en/week07/07-3/'>NYU Deep Learning</a></p>
</center>

The lighter the colour, the longer the distance a point travelled. From the diagram, we can tell that the points at the corners travelled close to 1 unit, whereas the points within the 2 branches didn’t move at all since they are attracted by the top and bottom branches during the training process.

In [2]:
def to_img(x, height=28, width=28):
    """Map a batch of normalised images back to display range, one 2-D grid per item.

    Inverts ``transforms.Normalize((0.5,), (0.5,))`` via ``x -> 0.5 * (x + 1)``,
    which sends the nominal range [-1, 1] back to [0, 1]. Values outside
    [-1, 1] (e.g. raw model outputs) are NOT clipped.

    Args:
        x: tensor whose first dimension is the batch and whose remaining
           elements hold ``height * width`` pixels per item,
           e.g. (N, 1, 28, 28) or (N, 784).
        height: height of each output image (default 28, MNIST).
        width: width of each output image (default 28, MNIST).

    Returns:
        Tensor of shape (N, height, width).
    """
    x = 0.5 * (x + 1)
    # reshape (unlike view) also works when x is not contiguous in memory.
    return x.reshape(x.size(0), height, width)

In [3]:
def layer_init(m, gain=None):
    """Initialise a layer's parameters in place and return the same layer.

    Weights use Xavier/Glorot uniform initialisation scaled by ``gain``; the
    bias, when the layer has one, is set to zero.

    Args:
        m: module with a ``weight`` tensor and an optional ``bias``
           (e.g. nn.Conv2d, nn.ConvTranspose2d, nn.Linear).
        gain: scale passed to ``xavier_uniform_``. ``None`` (default) uses
              ``nn.init.calculate_gain('relu')``, matching the previous behaviour.

    Returns:
        ``m`` itself, so the call can be inlined inside ``nn.Sequential(...)``.
    """
    if gain is None:
        gain = nn.init.calculate_gain('relu')
    torch.nn.init.xavier_uniform_(m.weight, gain=gain)
    # Layers built with bias=False have m.bias == None; skip instead of crashing.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)
    return m

class CNNDenosingAutoencoder(nn.Module):
    """Fully convolutional denoising autoencoder for single-channel images.

    Encoder: four unpadded, stride-1 ``nn.Conv2d`` layers (kernels 5, 4, 3, 3;
    channels 1 -> 32 -> 32 -> 64 -> 64) with a ReLU between consecutive layers.
    Decoder: the mirror image built from ``nn.ConvTranspose2d``
    (channels 64 -> 64 -> 32 -> 32 -> 1, kernels 3, 3, 4, 5).
    For a 28x28 input the spatial size goes 28 -> 24 -> 21 -> 19 -> 17 and back
    up to 28. Neither the encoder's nor the decoder's final layer is followed by
    an activation, so the reconstruction is unbounded.

    Every convolution is initialised with ``layer_init``.
    """

    # (in_channels, out_channels, kernel_size) of each encoder conv, in order.
    _ENCODER_SPECS = ((1, 32, 5), (32, 32, 4), (32, 64, 3), (64, 64, 3))

    def __init__(self):
        super().__init__()
        self.encoder = self._stack(nn.Conv2d, self._ENCODER_SPECS)
        # Mirror the encoder: reverse the layer order and swap in/out channels.
        mirrored = tuple(
            (c_out, c_in, k) for c_in, c_out, k in reversed(self._ENCODER_SPECS)
        )
        self.decoder = self._stack(nn.ConvTranspose2d, mirrored)

    @staticmethod
    def _stack(conv_cls, specs):
        """Build ``conv_cls`` layers from ``specs`` with a ReLU only between consecutive layers."""
        layers = []
        for c_in, c_out, k in specs:
            if layers:
                layers.append(nn.ReLU())
            layers.append(layer_init(conv_cls(c_in, c_out, k)))
        return nn.Sequential(*layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [4]:
batch_size = 128
num_samples = 5000  # train on the first `num_samples` MNIST training images
seed = 0

# Seed the global torch RNG so the shuffle order below, and the torch randomness
# used later when the notebook runs top to bottom (weight init, noise masks),
# are reproducible on Restart & Run All.
torch.manual_seed(seed)

img_transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: maps ToTensor's [0, 1] pixels to [-1, 1].
    transforms.Normalize((0.5,), (0.5,))
])

dataset = MNIST('./data', train=True, transform=img_transform, download=True)
indices = torch.arange(num_samples)
mnist_5k = torch.utils.data.Subset(dataset, indices)
dataloader = DataLoader(mnist_5k, batch_size=batch_size, shuffle=True)

In [5]:
# Reconstruction loss: mean squared error between the output and the clean image.
criterion = nn.MSELoss()
model = CNNDenosingAutoencoder()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
%matplotlib inline
from tqdm.notebook import tqdm

noise = torch.nn.Dropout(p=0.5)

num_epochs = 10
for epoch in range(num_epochs):
    losses = 0
    for img, label in tqdm(dataloader, unit='batch'):
        with torch.no_grad():
            noisy_imgs = noise(img)
        output = model(noisy_imgs)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()
    wandb.log({'loss': losses / len(dataloader)})

  0%|          | 0/40 [00:00<?, ?batch/s]

In [None]:
def show_row(images, title, start=4, count=4):
    """Draw ``count`` images from ``images`` side by side in a new figure.

    Args:
        images: tensor of shape (N, H, W), e.g. the result of ``to_img``.
        title: figure title.
        start: index of the first image to show. It is shifted down when the
            batch holds fewer than ``start + count`` images (the final batch of
            an epoch can be short).
        count: number of images to show, capped at N.

    Returns:
        The matplotlib Figure.
    """
    count = min(count, len(images))
    start = max(0, min(start, len(images) - count))
    fig, axes = plt.subplots(1, count, figsize=(18, 6), squeeze=False)
    for ax, pic in zip(axes[0], images[start:start + count]):
        ax.imshow(pic, cmap='gray')
        ax.axis('off')
    fig.suptitle(title)
    return fig

# `noisy_imgs` and `output` are left over from the last batch of the training loop.
in_pic = to_img(noisy_imgs.detach().cpu())
show_row(in_pic, 'Noisy input');

out_pic = to_img(output.detach().cpu())
show_row(out_pic, 'Reconstruction');