# Problem 2: Denoising Autoencoders

In [None]:
!pip3 install imagecorruptions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import random_split, Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from imagecorruptions import corrupt

In [None]:
class CustomMNIST(Dataset):
    """Dataset wrapper that pairs each MNIST image with a corrupted copy.

    Inputs:
        mnist_dataset: dataset whose items are (image, label); the image is
            assumed to be a 1 x 28 x 28 tensor normalized to [-1, 1].
        corruption: name of the imagecorruptions corruption to apply; it is
            forwarded to corrupt() as corruption_name (e.g. 'gaussian_blur').
        severity: corruption severity forwarded to corrupt(); defaults to 1.
    """

    def __init__(self, mnist_dataset, corruption, severity=1):
        self.mnist_dataset = mnist_dataset
        self.corruption = corruption
        self.severity = severity

    def __len__(self):
        """Outputs: the number of examples in the wrapped dataset."""
        return len(self.mnist_dataset)

    @staticmethod
    def _to_uint8_image(image):
        """Convert a normalized 1 x H x W tensor to a padded uint8 array.

        Outputs:
            (H + 4) x (W + 4) numpy uint8 array with values in [0, 255] and a
            2-pixel black border on every side (28x28 becomes 32x32).
        """
        # Denormalize BEFORE padding. Padding the normalized tensor with 0 and
        # denormalizing afterwards maps the border to (0 + 1) / 2 = 0.5, i.e.
        # a mid-gray frame (127) around every digit instead of background 0.
        pixels = ((image + 1) / 2.0).clamp(0.0, 1.0)  # assumes input in [-1, 1]

        # Pad the image tensor to get to 32x32
        padded = F.pad(pixels, (2, 2, 2, 2), 'constant', 0)

        # Round rather than truncate so float error cannot turn 254.99998
        # into 254 when casting to uint8.
        scaled = padded.squeeze(0).cpu().numpy() * 255
        return np.rint(scaled).astype(np.uint8)

    def __getitem__(self, idx):
        """Load one example.

        Inputs:
            idx: index of the example the dataloader is loading.
        Outputs:
            (X_corrupted, X_original) = (corrupted version of the original
            MNIST image, original). Both are float tensors of shape
            Channels x Height x Width with values in [0, 1]; the dataloader
            adds the batch dimension automatically.
        """
        image, _ = self.mnist_dataset[idx]

        # uint8 H x W array, which is what corrupt() is given
        image_np = self._to_uint8_image(image)

        # Corrupt the image
        corrupted_image_np = corrupt(
            image_np, corruption_name=self.corruption, severity=self.severity
        )
        # Take only one channel to make it grayscale
        corrupted_image_np = corrupted_image_np[:, :, 0]

        # Convert the numpy arrays back to torch tensors
        to_tensor = transforms.ToTensor()
        X_original = to_tensor(image_np).float()
        X_corrupted = to_tensor(corrupted_image_np).float()

        return (X_corrupted, X_original)
 

In [None]:
# MNIST images are converted to tensors and rescaled from [0, 1] to [-1, 1]
to_tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor_step, normalize_step])

trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Hold out 20% of the training images for validation (80-20 split)
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
trainset, valset = random_split(trainset, [train_size, val_size])

# Wrap every split so each item becomes (corrupted image, original image)
corruption = 'gaussian_blur'
trainset, valset, testset = (
    CustomMNIST(split, corruption) for split in (trainset, valset, testset)
)

# Only the training loader is shuffled; validation and test keep dataset order
batch_size = 8
trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
valloader = DataLoader(valset, shuffle=False, batch_size=batch_size)
testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

In [None]:
class DenoisingAutoencoder(nn.Module):
    """Convolutional autoencoder mapping a BSx1x32x32 image to a BSx1x32x32 image.

    Encoder (each stage = 3x3 conv -> ReLU -> 2x2 max-pool):
        BSx1x32x32 -> BSx32x16x16 -> BSx64x8x8 -> BSx128x4x4
    Decoder (each stage = 3x3 conv -> ReLU -> 2x upsample):
        BSx128x4x4 -> BSx128x8x8 -> BSx64x16x16 -> BSx32x32x32
    followed by a final 3x3 conv to one channel and a sigmoid, so the output
    is BSx1x32x32 with values in (0, 1).
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # padding=1 with a 3x3 kernel and stride 1 keeps height/width unchanged
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Encoder: channels grow 1 -> 32 -> 64 -> 128 while each max-pool
        # halves the spatial size (32 -> 16 -> 8 -> 4).
        encoder_layers = []
        for in_channels, out_channels in ((1, 32), (32, 64), (64, 128)):
            encoder_layers.append(conv3x3(in_channels, out_channels))
            encoder_layers.append(nn.ReLU(True))
            encoder_layers.append(nn.MaxPool2d(2, stride=2))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: channels shrink 128 -> 128 -> 64 -> 32 while each upsample
        # doubles the spatial size (4 -> 8 -> 16 -> 32).
        decoder_layers = []
        for in_channels, out_channels in ((128, 128), (128, 64), (64, 32)):
            decoder_layers.append(conv3x3(in_channels, out_channels))
            decoder_layers.append(nn.ReLU(True))
            decoder_layers.append(nn.Upsample(scale_factor=2))
        # Output head: collapse to a single channel and squash into (0, 1)
        decoder_layers.append(conv3x3(32, 1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode x to the BSx128x4x4 bottleneck, then decode back to BSx1x32x32."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
def train(model, device, trainloader, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Inputs:
        model: network mapping a corrupted batch to a reconstruction.
        device: device each batch is moved to before the forward pass.
        trainloader: iterable of (corrupted, original) batches; must support len().
        criterion: loss called as criterion(reconstructed, original).
        optimizer: optimizer stepped once per batch.
    Outputs:
        Sum of the batch losses divided by len(trainloader).
    """
    model.train()
    batch_losses = []

    for noisy, clean in trainloader:
        noisy = noisy.to(device)
        clean = clean.to(device)

        # Clear stale gradients, then forward / backward / update
        optimizer.zero_grad()
        batch_loss = criterion(model(noisy), clean)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    # Average training loss for the epoch
    return sum(batch_losses) / len(trainloader)

In [None]:
def validate(model, device, valloader, criterion):
    """Evaluate the model on a validation loader and return the mean per-batch loss.

    Inputs:
        model: network mapping a corrupted batch to a reconstruction.
        device: device each batch is moved to before the forward pass.
        valloader: iterable of (corrupted, original) batches; must support len().
        criterion: loss called as criterion(reconstructed, original).
    Outputs:
        Sum of the batch losses divided by len(valloader).
    """
    model.eval()
    val_loss = 0.0

    with torch.no_grad():  # no need to track gradients
        for corrupted, original in valloader:
            corrupted, original = corrupted.to(device), original.to(device)

            reconstructed = model(corrupted)
            loss = criterion(reconstructed, original)

            val_loss += loss.item()

    # Average validation loss
    val_loss /= len(valloader)
    return val_loss

In [None]:
learning_rate = 0.001

# Choose the device and move the model there BEFORE building the optimizer, so
# the optimizer is constructed over parameters already on their final device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenoisingAutoencoder().to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.MSELoss()
# Multiply the learning rate by 0.1 every 5 scheduler steps
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
num_epochs = 10

# Per-epoch history, used for the loss curves below
training_losses = []
validation_losses = []

for epoch in range(1, num_epochs + 1):
    train_loss = train(model, device, trainloader, criterion, optimizer)
    val_loss = validate(model, device, valloader, criterion)

    training_losses.append(train_loss)
    validation_losses.append(val_loss)

    # Update the learning rate (the scheduler is stepped once per epoch)
    scheduler.step()

    print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

# Plotting
# Use the same 1-based epoch numbers as the log above; plotting the bare lists
# would label the first epoch as 0 on the x-axis.
epoch_numbers = range(1, len(training_losses) + 1)
plt.plot(epoch_numbers, training_losses, label="Training Loss")
plt.plot(epoch_numbers, validation_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
def plotImages(corrupted, reconstructed, original):
    """Plot a corrupted / reconstructed / original image triple side by side.

    Inputs:
        corrupted, reconstructed, original: tensors of 1x32x32; each one is
        detached, moved to the CPU and squeezed to a 32x32 image before plotting.
    """
    def as_image(tensor):
        return tensor.cpu().detach().squeeze().numpy()

    panels = (
        ('Corrupted Image', as_image(corrupted)),
        ('Reconstructed Image', as_image(reconstructed)),
        ('Original Image', as_image(original)),
    )

    plt.figure(figsize=(10, 5))

    # One row, three columns: subplot positions are 1-based
    for position, (title, pixels) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, cmap='gray')
        plt.title(title)

    plt.tight_layout()
    plt.show()
    
def plotBatch(corrupted_batch, reconstructed_batch, original_batch):
    """Plot every (corrupted, reconstructed, original) triple in a batch.

    The number of figures is set by the first dimension of corrupted_batch;
    the other two batches are indexed with the same positions.
    """
    for position, noisy in enumerate(corrupted_batch):
        plotImages(noisy, reconstructed_batch[position], original_batch[position])

In [None]:
def test(model, device, testloader, criterion, num_plots=10):
    """Evaluate the model on the test loader, plot samples, and report the loss.

    Inputs:
        model: network mapping a corrupted batch to a reconstruction.
        device: device each batch is moved to before the forward pass.
        testloader: iterable of (corrupted, original) batches; must support len().
        criterion: loss called as criterion(reconstructed, original).
        num_plots: number of leading batches whose first example is plotted
            (defaults to 10).
    Outputs:
        The average per-batch test loss, which is also printed.
    """
    model.eval()
    test_loss = 0.0
    # (batch loss, corrupted, reconstructed, original) for the first example of
    # each of the first num_plots batches. Only what will be plotted is kept,
    # so the whole test set is not accumulated in memory.
    samples = []

    with torch.no_grad():
        for corrupted, original in testloader:
            corrupted, original = corrupted.to(device), original.to(device)

            reconstructed = model(corrupted)
            loss = criterion(reconstructed, original).item()

            test_loss += loss

            if len(samples) < num_plots:
                samples.append(
                    (loss, corrupted[0].cpu(), reconstructed[0].cpu(), original[0].cpu())
                )

    # Plot the first reconstruction from the first num_plots batches
    for loss, corrupted, reconstructed, original in samples:
        print(f"Loss = {loss}")
        plotImages(corrupted, reconstructed, original)

    # Average test loss
    test_loss /= len(testloader)
    print(f"Total Test Loss = {test_loss}")
    return test_loss

In [None]:
# Run the test-set evaluation: prints per-batch and average losses and plots
# sample reconstructions.
test(model, device, testloader, criterion)