<img src="convolutional_denoising_autoencoder.png">

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image


class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 (MNIST-sized) images.

    Encoder: two stride-2 convolutions (28x28 -> 14x14 -> 7x7) followed by a
    linear projection of the flattened 64*7*7 feature maps to ``latent_dim``
    units. Decoder: the mirror image -- a linear layer back to 64*7*7 and two
    stride-2 transposed convolutions (7x7 -> 14x14 -> 28x28). The final
    ``tanh`` keeps outputs in [-1, 1].

    Args:
        latent_dim: size of the bottleneck code (default 16).
    """

    def __init__(self, latent_dim=16):
        super().__init__()
        # ====== ENCODER PART ======
        # PyTorch expects input data as BxCxHxW:
        #   B: batch size
        #   C: number of channels (grayscale images have 1 channel)
        #   H: height of the image
        #   W: width of the image
        # An MNIST image is 1x28x28 (CxHxW).
        #
        # Output size of a convolution along one spatial axis:
        #   out = floor((in - kernel + 2*padding) / stride) + 1
        # With kernel=3, padding=1 and stride=2 every conv HALVES H and W:
        #   floor((28 - 3 + 2) / 2) + 1 = 14
        #   floor((14 - 3 + 2) / 2) + 1 = 7
        # (padding=1 with a 3x3 kernel preserves the size only when stride=1.)

        # Bx1x28x28 -> Bx32x14x14
        self.conv1 = nn.Conv2d(in_channels=1,      # 1 channel because grayscale image
                               out_channels=32,    # 32 filters -> one feature map each
                               kernel_size=3,      # filters are 3x3 weights
                               stride=2,           # halves H and W
                               padding=1)

        # Bx32x14x14 -> Bx64x7x7
        self.conv2 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=3,
                               stride=2,
                               padding=1)

        # Bottleneck: 64*7*7 = 3136 flattened features -> latent_dim units
        self.fc1 = nn.Linear(in_features=64 * 7 * 7,
                             out_features=latent_dim)

        # ====== DECODER PART ======
        # latent_dim units -> 64*7*7 features (reshaped to Bx64x7x7 in forward)
        self.fc2 = nn.Linear(in_features=latent_dim,
                             out_features=64 * 7 * 7)

        # Output size of a transposed convolution along one spatial axis:
        #   out = (in - 1)*stride - 2*padding + kernel + output_padding
        # output_padding=1 resolves the ambiguity of stride 2 so that
        #   (7 - 1)*2 - 2 + 3 + 1 = 14  and  (14 - 1)*2 - 2 + 3 + 1 = 28.

        # Bx64x7x7 -> Bx32x14x14
        self.conv_t1 = nn.ConvTranspose2d(in_channels=64,
                                          out_channels=32,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          output_padding=1)

        # Bx32x14x14 -> Bx1x28x28
        self.conv_t2 = nn.ConvTranspose2d(in_channels=32,
                                          out_channels=1,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1,
                                          output_padding=1)

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Assumes ``x`` is Bx1x28x28: the ``view`` below requires the encoder
        to produce 64x7x7 maps, which only holds for 28x28 inputs.
        Returns a Bx1x28x28 tensor with values in [-1, 1] (``tanh``).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)  # flatten feature maps, Bx(C*H*W)
        x = F.relu(self.fc1(x))            # latent code (non-negative due to ReLU)
        x = F.relu(self.fc2(x))
        x = x.view(-1, 64, 7, 7)           # reshape back to feature-map format
        x = F.relu(self.conv_t1(x))
        x = torch.tanh(self.conv_t2(x))    # squash to [-1, 1]
        return x


def to_img(x):
    """Convert a batch from the [-1, 1] range into [0, 1] image tensors.

    Values that fall outside [0, 1] after rescaling are clipped. The result
    is reshaped to B x 1 x 28 x 28 (MNIST layout), with B taken from
    ``x.size(0)``.
    """
    batch = x.size(0)
    rescaled = (x + 1) * 0.5          # [-1, 1] -> [0, 1]
    clipped = rescaled.clamp(0, 1)    # saturate anything below 0 / above 1
    return clipped.view(batch, 1, 28, 28)


num_epochs = 50
batch_size = 128
learning_rate = 1e-3

# ToTensor() scales pixels to [0, 1]; Normalize with mean 0.5 and std 0.5
# then maps them to [-1, 1], the same range as the decoder's tanh output.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# prepare data loader (downloads MNIST into ./data on first run)
dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Number of batches per epoch, used to turn the summed epoch loss into a mean.
# Taken from the loader instead of hard-coding 60000 // batch_size: the last
# partial batch is kept (drop_last defaults to False), so an epoch has
# ceil(60000 / 128) = 469 batches, not 468.
n_batches = len(dataloader)

# determine where to run the code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create an AutoEncoder network instance
net = AutoEncoder().to(device)
# print(net)  # display the architecture
loss_function = nn.MSELoss().to(device)
# weight_decay adds a small L2 penalty on the weights
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                             weight_decay=1e-5)


def train(net, loader, loss_func, optimizer):
    """Run one denoising-training epoch over ``loader``.

    Each clean batch is corrupted by setting to 1 every pixel where a
    standard-normal sample exceeds 0.5 (about 31% of pixels). The network
    reconstructs the clean batch from the corrupted one, and the
    reconstruction loss is backpropagated.

    Args:
        net: model to train; batches are moved to the device of its parameters.
        loader: iterable of ``(images, labels)`` pairs; labels are ignored.
        loss_func: callable ``(prediction, target) -> scalar loss tensor``.
        optimizer: optimizer over ``net``'s parameters.

    Returns:
        ``(noised_img, img, output, total_loss)``: the last batch's corrupted
        input, clean target and reconstruction, plus the loss summed over all
        batches as a detached 1-element tensor.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.train()                                      # put model in train mode
    device = next(net.parameters()).device           # run where the model lives
    total_loss = torch.zeros(1, device=device)
    noised_img = img = output = None
    for img, _ in loader:                            # next batch
        img = img.to(device)                         # move to gpu if available
        noise = torch.randn_like(img)                # noise generated directly on img's device
        noised_img = img.masked_fill(noise > 0.5, 1) # set pixels where noise > 0.5 to 1
        output = net(noised_img)                     # feed forward
        loss = loss_func(output, img)                # reconstruct the CLEAN image

        optimizer.zero_grad()                        # clear previous gradients
        loss.backward()                              # calculate new gradients
        optimizer.step()                             # update weights
        # detach: summing the raw loss would chain every batch's autograd
        # graph onto total_loss and keep them all alive for the whole epoch
        total_loss += loss.detach()
    if output is None:
        raise ValueError("train() received a loader that yielded no batches")
    return noised_img, img, output, total_loss

# Train for num_epochs epochs. Each epoch logs the mean per-batch loss.
# Sample images (clean / corrupted / reconstructed, from the last batch) are
# dumped every 10 epochs and after the final epoch.
for epoch in range(num_epochs):
    noised_img, img, output, loss = train(net, dataloader, loss_function, optimizer)
    # log: train() returns the loss summed over batches, so divide by n_batches
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item() / n_batches))
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        # .detach() instead of the legacy .data: same values, autograd-safe
        pic_org = to_img(img.detach().cpu())
        pic_noised = to_img(noised_img.detach().cpu())
        pic_pred = to_img(output.detach().cpu())
        save_image(pic_org, './denoise_image_org__{}.png'.format(epoch))
        save_image(pic_noised, './denoise_image_noised__{}.png'.format(epoch))
        save_image(pic_pred, './denoise_image_pred__{}.png'.format(epoch))

# save the model
torch.save(net.state_dict(), './conv_autoencoder.pth')

0it [00:00, ?it/s]Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
9920512it [00:08, 1154180.38it/s]                             
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
0it [00:00, ?it/s]Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
32768it [00:00, 73669.16it/s]            
0it [00:00, ?it/s]Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
1654784it [00:01, 914824.71it/s]                              
0it [00:00, ?it/s]Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
8192it [00:01, 7985.65it/s]             
Extracting ./data/MNIST

KeyboardInterrupt: 