# Image Super-Resolution with PyTorch

In this notebook, we'll implement an image super-resolution model in PyTorch. Super-resolution reconstructs a higher-resolution image from a lower-resolution input; here we learn that mapping with a convolutional neural network.

## Importing Libraries

We'll import the necessary libraries for this project.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np

## Defining the Model

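We will define a simple convolutional neural network for super-resolution based on the SRCNN (Super-Resolution Convolutional Neural Network) architecture. Unlike the original SRCNN, which first upscales its input with bicubic interpolation, this variant works directly on the low-resolution image and ends with a stride-2 transposed convolution that doubles the image's height and width.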

In [None]:
class SRCNN(nn.Module):
    """SRCNN-style network that upsamples single-channel images by a factor of 2.

    Three stride-1 convolutions (9x9 -> 1x1 -> 5x5, each padded to keep the
    spatial size) extract and map features at the input resolution. A final
    stride-2 transposed convolution then doubles height and width and maps
    the 3 intermediate channels back to 1.

    For an input of shape (N, 1, H, W) the output has shape (N, 1, 2H, 2W).
    """

    def __init__(self):
        super().__init__()
        # Feature extraction: 1 -> 64 channels with a 9x9 receptive field.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU()
        # Non-linear mapping: 64 -> 32 channels, pointwise (1x1).
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.relu2 = nn.ReLU()
        # Reconstruction: 32 -> 3 channels with a 5x5 kernel; no activation follows.
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        # Upsampling: out = (H - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 2H, so the
        # transposed convolution exactly doubles the spatial size.
        self.deconv = nn.ConvTranspose2d(3, 1, kernel_size=4, stride=2, padding=1, output_padding=0)

    def forward(self, x):
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        return self.deconv(self.conv3(features))


# Instantiate the network and print its layer summary.
model = SRCNN()
print(model)

## Preparing the Data

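For this example, we'll use the MNIST dataset of small grayscale handwritten digits. Each digit is resized to 32×32 to serve as the high-resolution target, and a 16×16 downsampled copy serves as the low-resolution input. To keep training fast, we train on a random subset of 560 images.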

In [None]:
# Low-resolution inputs. The training loop applies this transform on the fly to
# PIL images rebuilt from the high-resolution tensors.
transform_lr = transforms.Compose([
    transforms.Resize((16, 16)),  # Downsample to low resolution
    transforms.ToTensor()
])

# High resolution target images: the digits are resized to 32x32, i.e. twice
# the 16x16 low-resolution size above.
transform_hr = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to high resolution
    transforms.ToTensor()
])

# Load MNIST dataset
full_dataset_hr = MNIST(root='./data', train=True, download=True, transform=transform_hr)

# Create a reproducible random subset of SUBSET_SIZE images.
# A seeded generator (instead of the unseeded global np.random.shuffle) picks the
# same images on every run. A seeded torch.Generator fixes the DataLoader's
# shuffling order as well.
SUBSET_SIZE = 560
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.permutation(len(full_dataset_hr)).tolist()
subset_indices = indices[:SUBSET_SIZE]
hr_subset = Subset(full_dataset_hr, subset_indices)
trainloader_hr = DataLoader(
    hr_subset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

## Defining the Loss Function and Optimizer

We'll use mean squared error (MSE) as the loss function and the Adam optimizer to update the model weights.

In [None]:
# Pixel-wise mean squared error between the super-resolved output and the
# high-resolution target (nn.L1Loss() is a drop-in alternative).
criterion = nn.MSELoss()
# Adam over all of the model's parameters, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Training the Model

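We'll train the model on the MNIST subset for 50 epochs. In each batch, the low-resolution inputs are generated on the fly by downsampling the high-resolution images. The loss then compares the model's output with the original high-resolution images.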

In [None]:
num_epochs = 50

# The tensor -> PIL converter is stateless, so build it once instead of once
# per image in every batch.
to_pil = transforms.ToPILImage()

# Put the model in training mode explicitly; this only matters if dropout or
# batch-norm layers are present, but it keeps re-runs of this cell correct.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for idx, (hr_images, _) in enumerate(trainloader_hr):
        # Generate low-resolution inputs by round-tripping each high-resolution
        # tensor through PIL and applying transform_lr.
        lr_images = torch.stack([transform_lr(to_pil(img)) for img in hr_images])

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(lr_images)

        # Sanity-check the tensor shapes once, on the very first batch.
        if epoch == 0 and idx == 0:
            print(f'LR Images Size: {lr_images.size()}')
            print(f'HR Images Size: {hr_images.size()}')
            print(f'Model Output Size: {outputs.size()}')

        # The output is compared directly with the high-resolution batch, so
        # the model must produce tensors of the same shape as hr_images.
        loss = criterion(outputs, hr_images)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(trainloader_hr)}')

print('Training completed')

In [None]:
def show_images(lr_images, hr_images, sr_images, num_images=5):
    """Plot high-resolution, low-resolution and super-resolved images side by side.

    Draws a grid with 3 rows and one sample per column:
    row 0 shows the high-resolution targets, row 1 the low-resolution
    inputs, and row 2 the model outputs. Pixel values are clipped to [0, 1]
    before display.

    Args:
        lr_images: Low-resolution batch, a tensor in (N, C, H, W) layout.
        hr_images: High-resolution batch, a tensor in (N, C, H, W) layout.
        sr_images: Super-resolved batch, a tensor in (N, C, H, W) layout.
        num_images: Maximum number of samples (columns) to plot. It is capped
            at the size of the smallest of the three batches, so batches with
            fewer than ``num_images`` samples are handled.

    Raises:
        ValueError: If there is nothing to plot (``num_images`` < 1 or an
            empty batch).
    """
    def to_nhwc(images):
        # detach()/cpu() allow converting tensors that track gradients or live
        # on a GPU; permute turns NCHW into the NHWC layout matplotlib expects.
        return images.detach().cpu().permute(0, 2, 3, 1).numpy()

    rows = [
        ('High Resolution', to_nhwc(hr_images)),
        ('Low Resolution', to_nhwc(lr_images)),
        ('Super-Resolved', to_nhwc(sr_images)),
    ]
    count = min([num_images] + [len(batch) for _, batch in rows])
    if count < 1:
        raise ValueError('show_images needs at least one image per batch to plot')

    # squeeze=False keeps axs two-dimensional even when count == 1.
    _, axs = plt.subplots(3, count, figsize=(3 * count, 9), squeeze=False)
    for row, (title, batch) in enumerate(rows):
        for col in range(count):
            img = np.clip(batch[col], 0, 1)
            ax = axs[row, col]
            if img.shape[-1] == 1:
                # Single-channel images: drop the channel axis and draw them in
                # grayscale on a fixed [0, 1] scale rather than with the
                # default auto-scaled colour map.
                ax.imshow(img[..., 0], cmap='gray', vmin=0, vmax=1)
            else:
                ax.imshow(img)
            ax.set_title(title)
            ax.axis('on')

    plt.show()

# Visualization: take one shuffled batch from the training loader and rebuild
# its low-resolution inputs with the same PIL round trip + transform_lr used
# for training. Then run the model without tracking gradients and plot all
# three versions.
first_batch = next(iter(trainloader_hr))
hr_images = first_batch[0]
to_pil_image = transforms.ToPILImage()
lr_images = torch.stack([transform_lr(to_pil_image(image)) for image in hr_images])
with torch.no_grad():
    sr_images = model(lr_images)
show_images(lr_images, hr_images, sr_images)