# Deep Image Prior for Super-Resolution

Deep Image Prior (DIP) can be applied to super-resolution, where the goal is to reconstruct a high-resolution (HR) image from a low-resolution (LR) one using a neural network. Unlike learning-based methods that rely on a model pre-trained on a large dataset, DIP starts from a randomly initialized network and optimizes its weights on a single image, so the structure of the convolutional network itself acts as the prior.

In this notebook, we will:
1. Set up the environment
2. Define the neural network architecture
3. Load and preprocess images
4. Train the network for super-resolution
5. Evaluate and display the results

Let's dive in!

## 1. Set Up the Environment

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 2. Define the Neural Network

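We will use a small convolutional neural network for super-resolution. The network consists of five 5x5 convolutional layers, each followed by a ReLU activation, and a final 4x bilinear upsampling layer so that the output matches the size of the HR image.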

Let's define the network.

In [None]:
class DIPNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, padding=2),  # Input: 3 x H x W (LR image)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)  # Up-sample to match HR image size
        )
    
    def forward(self, x):
        return self.model(x)
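
A quick shape check can confirm that this layer pattern preserves the spatial size in the convolutions (kernel size 5 with padding 2) and enlarges it by 4x in the final upsampling step. The sketch below uses a hypothetical helper, `make_small_sr_net`, which is not part of the notebook: it mirrors the structure of `DIPNet` with fewer layers and channels so that it runs quickly on CPU.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in with the same layer pattern as DIPNet
# (5x5 conv + ReLU, then bilinear upsampling), but fewer layers and channels.
def make_small_sr_net(channels=8, scale_factor=4):
    return nn.Sequential(
        nn.Conv2d(3, channels, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.Conv2d(channels, 3, kernel_size=5, padding=2),
        nn.ReLU(),
        nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
    )

net = make_small_sr_net()
dummy_lr = torch.rand(1, 3, 64, 64)  # one 64x64 RGB image
with torch.no_grad():
    dummy_out = net(dummy_lr)

print(tuple(dummy_lr.shape), '->', tuple(dummy_out.shape))
```

Because the last activation is a ReLU and bilinear interpolation is a weighted average, the output is non-negative but is not bounded above by 1.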


## 3. Load and Preprocess Images

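We need to load the high-resolution image, resize it to 256x256, and create a low-resolution version by downsampling it by a factor of 4 (to 64x64). We then add Gaussian noise to the low-resolution image, which will serve as the network input.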

Let's load and preprocess the images.

In [None]:
def load_image(image_path, size=(256, 256)):
    """Load an image as an RGB tensor of shape (1, 3, H, W) with values in [0, 1]."""
    image = Image.open(image_path).convert('RGB')
    image = image.resize(size, Image.BICUBIC)
    transform = transforms.ToTensor()
    return transform(image).unsqueeze(0).to(device)

def downsample(image, scale_factor):
    """
    Downsample an image by a given scale factor.
    
    Parameters:
    - image: The image tensor to downsample.
    - scale_factor: The factor by which to downsample.
    
    Returns:
    - The downsampled image tensor.
    """
    downsampled = torch.nn.functional.interpolate(image, scale_factor=1/scale_factor, mode='bilinear', align_corners=False)
    return downsampled

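def add_noise(image, noise_level=0.1):
    """
    Add zero-mean Gaussian noise with standard deviation noise_level.

    The result is not clamped, so values may fall outside [0, 1].
    """
    noise = noise_level * torch.randn_like(image)
    return image + noise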

# Load and preprocess the image
image_path = 'images/image_hr.png'  # Replace with your image path
hr_image = load_image(image_path)

# Create a low-resolution image by downsampling
lr_image = downsample(hr_image, scale_factor=4)
noisy_lr_image = add_noise(lr_image, noise_level=0.1)
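
To make the effect of these two steps concrete, the following sketch applies the same operations to a random stand-in tensor instead of a real image (the tensor `hr` and the clamping step are assumptions for illustration, not part of the notebook). It shows that 4x bilinear downsampling turns a 256x256 image into a 64x64 one, and that adding Gaussian noise typically pushes some pixel values slightly outside [0, 1].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

hr = torch.rand(1, 3, 256, 256)  # random stand-in for an HR image in [0, 1]

# Same operation as downsample(hr, scale_factor=4)
lr = F.interpolate(hr, scale_factor=1 / 4, mode='bilinear', align_corners=False)

# Same operation as add_noise(lr, noise_level=0.1)
noisy = lr + 0.1 * torch.randn_like(lr)

# Optional: clamp back to the valid range before displaying or saving
noisy_clamped = noisy.clamp(0.0, 1.0)

print('LR shape:', tuple(lr.shape))
print('Noisy range:', noisy.min().item(), noisy.max().item())
print('Clamped range:', noisy_clamped.min().item(), noisy_clamped.max().item())
```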


In [None]:
def display_images(original, lr, lrnoisy):
    """
    Display the original and noisy images side by side.
    
    Parameters:
    - original: The original high-resolution image tensor.
    - noisy: The noisy image tensor.
    """
    # Convert tensors to numpy arrays and transpose dimensions for plotting
    original = original.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    lr = lr.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    # Clip the noisy image to [0, 1] so that imshow receives values in its valid range
    lrnoisy = lrnoisy.squeeze().cpu().detach().numpy().transpose(1, 2, 0).clip(0, 1)
    
    # Plot the images
    fig, axs = plt.subplots(1, 3, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 0.5, 0.5]})
    axs[0].imshow(original)
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    
    axs[1].imshow(lr)
    axs[1].set_title('Low Res. Image')
    axs[1].axis('off')
    
    axs[2].imshow(lrnoisy)
    axs[2].set_title('Noisy Low Res. Image')
    axs[2].axis('off') 

    plt.show()

# Display the images
display_images(hr_image, lr_image, noisy_lr_image)

## 4. Train the Network for Super-Resolution

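We will train the network to upscale the noisy low-resolution image to its high-resolution counterpart by minimizing the Mean Squared Error (MSE) between the network output and the original high-resolution image. Note that this simplified setup uses the HR image as the training target. In the original DIP formulation the HR image is unknown: the network maps a fixed random input to an HR estimate, and the loss is computed between the downsampled estimate and the LR observation.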

Let's define the training loop.

In [None]:
def train(model, noisy_lr_image, hr_image, num_epochs=2000, learning_rate=0.001):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        # Forward pass
        restored_image = model(noisy_lr_image)
        # Compute loss
        loss = criterion(restored_image, hr_image)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
    
    return model

# Initialize and train the model
model = DIPNet().to(device)
trained_model = train(model, noisy_lr_image, hr_image)
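
For comparison, here is a minimal sketch of a training loop in which the HR image is never used: the loss is measured in LR space, between the downsampled network output and the LR observation. Everything in it is an assumption made for illustration, including the tiny `generator`, the random stand-in `lr_observation`, the fixed noise input `z`, and the choice of bilinear downsampling (picked for consistency with `downsample` above). It is not the notebook's method and makes no claim about reconstruction quality.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small generator: maps a fixed noise tensor to an HR-sized image.
generator = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
    nn.Sigmoid(),  # keeps the output in [0, 1]
)

lr_observation = torch.rand(1, 3, 16, 16)  # stand-in for the observed LR image
z = 0.1 * torch.rand(1, 8, 64, 64)         # fixed random input at HR size (4x)

optimizer = torch.optim.Adam(generator.parameters(), lr=0.01)
losses = []

for step in range(50):
    optimizer.zero_grad()
    hr_estimate = generator(z)
    # Downsample the HR estimate and compare it with the LR observation
    lr_estimate = F.interpolate(hr_estimate, scale_factor=0.25,
                                mode='bilinear', align_corners=False)
    loss = F.mse_loss(lr_estimate, lr_observation)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print('HR estimate shape:', tuple(hr_estimate.shape))
print('First and last loss:', losses[0], losses[-1])
```

In this variant the number of iterations acts as a regularizer: running for too long lets the network fit the noise in the observation, so the iteration count is usually tuned or combined with early stopping.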

## 5. Evaluate and Display the Results

After training, we will use the model to upscale the noisy low-resolution image and visualize the result next to the original high-resolution image and the clean low-resolution image. This will help us assess how well the network has restored the image.

Let's visualize the original, low-resolution, and restored images.

In [None]:
def display_images(original, lr, restored):
    """
    Display the original high-resolution image, low-resolution image, and the restored image side by side.
    
    Parameters:
    - original: The original high-resolution image tensor.
    - lr: The low-resolution image tensor.
    - restored: The restored high-resolution image tensor.
    """
    original = original.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    lr = lr.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    # Clip the network output to [0, 1] so that imshow receives values in its valid range
    restored = restored.squeeze().cpu().detach().numpy().transpose(1, 2, 0).clip(0, 1)
    
    fig, axs = plt.subplots(1, 3, figsize=(15, 5), gridspec_kw={'width_ratios': [1, 0.5, 1]})
    axs[0].imshow(original)
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    
    axs[1].imshow(lr)
    axs[1].set_title('Low Res Image')
    axs[1].axis('off')
    
    axs[2].imshow(restored)
    axs[2].set_title('Restored Image')
    axs[2].axis('off')
    
    plt.show()

# Restore the low-resolution image
model.eval()
with torch.no_grad():
    restored_image = model(noisy_lr_image)

# Display the images
display_images(hr_image, lr_image, restored_image)
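
Visual inspection can be complemented with a numeric score. A common choice is the peak signal-to-noise ratio (PSNR), defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value. The helper `psnr` below is a hypothetical addition, not part of the notebook, and assumes image tensors with values in [0, 1].

```python
import torch

def psnr(estimate, reference, max_value=1.0):
    """Return the PSNR in dB between two tensors with values in [0, max_value]."""
    mse = torch.mean((estimate - reference) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return (10.0 * torch.log10(max_value ** 2 / mse)).item()

# Toy check: a constant error of 0.1 gives MSE = 0.01, i.e. 20 dB
reference = torch.zeros(1, 3, 8, 8)
estimate = torch.full((1, 3, 8, 8), 0.1)
print(f'{psnr(estimate, reference):.2f} dB')
```

In the notebook, this could be called as `psnr(restored_image.clamp(0, 1), hr_image)`, and compared against the score of a plain bilinear upscaling of `noisy_lr_image` as a baseline.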

## Conclusion

In this notebook, we implemented Deep Image Prior (DIP) for super-resolution. We defined a simple neural network, prepared the data, trained the network to upscale low-resolution images, and visualized the results.

Feel free to modify the network architecture, training parameters, or test with different images to explore the capabilities of Deep Image Prior.