# Super-Resolution Using `torchsr`

## Overview

Super-resolution (SR) is a technique used to enhance the resolution of an image. In this notebook, we'll use `torchsr`, a PyTorch-based library, to perform super-resolution. `torchsr` provides a variety of models that are pre-trained and can be used for high-quality image upscaling.

## Key Concepts

1. **Super-Resolution Models**: These models learn to generate high-resolution images from low-resolution inputs. `torchsr` includes several popular models such as EDSR, RCAN, and CARN, which are the three compared in this notebook.

2. **Preprocessing**: Images are often converted to tensors and normalized before being fed into the model.

3. **Postprocessing**: After the model generates the super-resolved image, it is converted back to a standard image format for visualization or saving.

## Installation

First, install `torchsr` if you haven't already:

In [None]:
!pip install torchsr

## Example Code: Using `torchsr` for Super-Resolution

Below is a detailed example demonstrating how to use `torchsr` for image super-resolution.

In [None]:
# Importing libraries
import torch
from torchsr.models import edsr, rcan, carn
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import cv2

In [None]:
# Load a pre-trained x4 EDSR network from torchsr
# see: https://github.com/Coloquinte/torchSR
model = edsr(scale=4, pretrained=True)
model.eval()  # inference mode

# Load the low-resolution image and add a leading batch dimension
image_path = 'images/image_lr_4.png'  # Replace with your image path
lr_image = Image.open(image_path).convert('RGB')
lr_tensor = ToTensor()(lr_image).unsqueeze(0)

# Perform super-resolution without tracking gradients
with torch.no_grad():
    sr_tensor = model(lr_tensor)

# Convert the result back to an image.
# The network output is not guaranteed to stay inside [0, 1]; clamp it so the
# float -> 8-bit conversion cannot wrap around and produce speckle artefacts.
# squeeze(0) drops only the batch dimension added above.
sr_image = ToPILImage()(sr_tensor.squeeze(0).clamp(0, 1))

plt.imshow(sr_image)
plt.title('Recovered Image')
plt.axis('off')  # Hide axes
plt.show()

In [None]:
# Pre-trained x4 networks to compare, keyed by the label used in the plot
# titles. `carn` builds a CARN network, so it is labelled as such.
models = {
    'edsr': edsr(scale=4, pretrained=True),
    'rcan': rcan(scale=4, pretrained=True),
    'carn': carn(scale=4, pretrained=True),
}

# Set all models to evaluation mode
for net in models.values():
    net.eval()

# Load and preprocess the low-resolution and high-resolution images
lr_image_path = 'images/image_lr_4.png'  # Replace with your image path
hr_image_path = 'images/image_hr.png'    # Replace with your high-resolution image path

lr_image = Image.open(lr_image_path).convert('RGB')
hr_image = Image.open(hr_image_path).convert('RGB')
lr_tensor = ToTensor()(lr_image).unsqueeze(0)

# Convert high-resolution image to numpy for PSNR and SSIM calculation
# NOTE(review): the metrics below assume the HR image has exactly the size of
# the x4 network output - confirm for your own images.
hr_image_np = np.array(hr_image)

# One subplot per model. squeeze=False keeps `axs` two-dimensional even when
# only a single model is listed, so the indexing below always works.
fig, axs = plt.subplots(1, len(models), figsize=(15, 5), squeeze=False)

# Iterate through the models and perform super-resolution
for idx, (model_name, model) in enumerate(models.items()):
    with torch.no_grad():
        sr_tensor = model(lr_tensor)

    # Clamp to [0, 1] before the 8-bit conversion so out-of-range network
    # outputs cannot wrap around, then convert to numpy for the metrics
    sr_image = ToPILImage()(sr_tensor.squeeze(0).clamp(0, 1))
    sr_image_np = np.array(sr_image)

    # Calculate PSNR and SSIM against the high-resolution reference
    psnr_value = psnr(hr_image_np, sr_image_np)
    ssim_value = ssim(hr_image_np, sr_image_np, win_size=3, channel_axis=-1)

    # Display the results in the subplot
    ax = axs[0, idx]
    ax.imshow(sr_image)
    ax.set_title(f'{model_name.upper()}\nPSNR: {psnr_value:.2f}, SSIM: {ssim_value:.3f}')
    ax.axis('off')  # Hide axes

# Adjust layout and display the figure
plt.tight_layout()
plt.show()

In [None]:
# Methods to compare, keyed by the label used in the plot titles:
#   None                -> show the high-resolution reference itself
#   Image.Resampling.*  -> classical PIL interpolation
#   torch.nn.Module     -> pre-trained x4 super-resolution network
# `carn` builds a CARN network, so it is labelled as such.
models = {
    'Original': None,  # This will be used to compare the original image to itself
    'Bilinear': Image.Resampling.BILINEAR,
    'Bicubic': Image.Resampling.BICUBIC,
    'Nearest': Image.Resampling.NEAREST,
    'Lanczos': Image.Resampling.LANCZOS,
    'EDSR': edsr(scale=4, pretrained=True),
    'RCAN': rcan(scale=4, pretrained=True),
    'CARN': carn(scale=4, pretrained=True),
}

# Load and preprocess the low-resolution and high-resolution images
lr_image_path = 'images/image_lr_4.png'  # Replace with your image path
hr_image_path = 'images/image_hr.png'    # Replace with your high-resolution image path

lr_image = Image.open(lr_image_path).convert('RGB')
hr_image = Image.open(hr_image_path).convert('RGB')
lr_tensor = ToTensor()(lr_image).unsqueeze(0)

# Convert high-resolution image to numpy for PSNR and SSIM calculation
# NOTE(review): the metrics below assume the HR image has exactly the size of
# the x4 network output - confirm for your own images.
hr_image_np = np.array(hr_image)

# Lay the results out on a two-row grid. The column count is derived from the
# number of entries (rounded up) so adding or removing a method keeps working.
n_rows = 2
n_cols = (len(models) + n_rows - 1) // n_rows
fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)

# Iterate through the models and perform super-resolution or interpolation
for idx, (model_name, model) in enumerate(models.items()):
    if model is None:
        sr_image = hr_image  # Compare the original image to itself
    elif isinstance(model, torch.nn.Module):
        # Super-resolution with model
        model.eval()
        with torch.no_grad():
            sr_tensor = model(lr_tensor)
        # Clamp to [0, 1] so the 8-bit conversion cannot wrap around
        sr_image = ToPILImage()(sr_tensor.squeeze(0).clamp(0, 1))
    else:
        # Interpolation with PIL, up to the size of the reference image
        sr_image = lr_image.resize(hr_image.size, resample=model)

    # Convert the result to numpy for metric calculations
    sr_image_np = np.array(sr_image)

    # Calculate PSNR and SSIM against the high-resolution reference
    psnr_value = psnr(hr_image_np, sr_image_np)
    ssim_value = ssim(hr_image_np, sr_image_np, win_size=3, channel_axis=-1)

    # Fill the grid row by row
    row, col = divmod(idx, n_cols)
    ax = axs[row, col]
    ax.imshow(sr_image)
    ax.set_title(f'{model_name}\nPSNR: {psnr_value:.2f}, SSIM: {ssim_value:.3f}')
    ax.axis('off')  # Hide axes

# Blank out any grid cells left over when the entry count is odd
for ax in axs.flat[len(models):]:
    ax.axis('off')

# Adjust layout and display the figure
plt.tight_layout()
plt.show()


To visualize the differences between the super-resolved (or interpolated) images and the original image, you can use several techniques, such as:

- **Absolute Difference Image**: Subtract the original image from the super-resolved image pixel by pixel and display the absolute difference. This will highlight areas where there are greater discrepancies.

- **Residual Image**: Similar to the difference image, but without taking the absolute value. This can show areas that are overestimated or underestimated.


Below is an example of how to implement and visualize both the absolute difference image and the residual image in your code. This will highlight the most significant differences between each super-resolved (or interpolated) image and the original.

In [None]:
# Define function to calculate absolute difference
def absolute_difference(image1, image2):
    """Return the element-wise absolute difference of two images.

    Both inputs are converted to float32 arrays before subtracting, so
    unsigned 8-bit pixel values cannot wrap around.

    Args:
        image1: First image; anything ``np.array`` accepts.
        image2: Second image, broadcast-compatible with ``image1``.

    Returns:
        A float32 array holding ``|image1 - image2|``.
    """
    first = np.array(image1, dtype=np.float32)
    second = np.array(image2, dtype=np.float32)
    delta = first - second
    return np.absolute(delta)

# Define function to calculate residual
def residual(image1, image2):
    """Return the signed element-wise difference ``image1 - image2``.

    Both inputs are converted to float32 arrays first, so negative
    differences are preserved instead of wrapping around.

    Args:
        image1: Image to subtract from; anything ``np.array`` accepts.
        image2: Image to subtract, broadcast-compatible with ``image1``.

    Returns:
        A float32 array; positive where ``image1`` is larger, negative
        where ``image2`` is larger.
    """
    minuend = np.array(image1, dtype=np.float32)
    subtrahend = np.array(image2, dtype=np.float32)
    return np.subtract(minuend, subtrahend)

# Define function to calculate heatmap
def prepare_heatmap(heatmap, scaleHM=False):
    """Convert a difference array into a single-channel uint8 image.

    Args:
        heatmap: Numeric numpy array, either 2-D or 3-D. A 3-D array is
            passed to ``cv2.cvtColor`` with ``cv2.COLOR_RGB2GRAY``.
        scaleHM: If true, min-max rescale the values onto [0, 255] before
            quantising. A constant array is left unscaled, which avoids
            dividing by zero.

    Returns:
        A uint8 array with values clipped to [0, 255]. No colour map is
        applied here.
    """
    if scaleHM:
        lowest, highest = heatmap.min(), heatmap.max()
        # Identical min and max means a flat image: nothing to stretch
        if highest != lowest:
            heatmap = (heatmap - lowest) / (highest - lowest) * 255

    # Anything outside the 8-bit range (e.g. negative residuals) is clipped
    quantised = np.clip(heatmap, 0, 255).astype(np.uint8)

    if quantised.ndim != 3:
        return quantised
    # Collapse the colour channels into one intensity channel
    return cv2.cvtColor(quantised, cv2.COLOR_RGB2GRAY)


In [None]:
# Methods to compare, keyed by the label used in the plot titles:
#   None                -> show the high-resolution reference itself
#   Image.Resampling.*  -> classical PIL interpolation
#   torch.nn.Module     -> pre-trained x4 super-resolution network
# `carn` builds a CARN network, so it is labelled as such.
models = {
    'Original': None,  # This will be used to compare the original image to itself
    'Bilinear': Image.Resampling.BILINEAR,
    'Bicubic': Image.Resampling.BICUBIC,
    'Nearest': Image.Resampling.NEAREST,
    'Lanczos': Image.Resampling.LANCZOS,
    'EDSR': edsr(scale=4, pretrained=True),
    'RCAN': rcan(scale=4, pretrained=True),
    'CARN': carn(scale=4, pretrained=True),
}

# Load and preprocess the low-resolution and high-resolution images
lr_image_path = 'images/image_lr_4.png'  # Replace with your image path
hr_image_path = 'images/image_hr.png'    # Replace with your high-resolution image path

lr_image = Image.open(lr_image_path).convert('RGB')
hr_image = Image.open(hr_image_path).convert('RGB')
lr_tensor = ToTensor()(lr_image).unsqueeze(0)

# Convert high-resolution image to numpy for PSNR, SSIM, and difference calculation
# NOTE(review): the metrics below assume the HR image has exactly the size of
# the x4 network output - confirm for your own images.
hr_image_np = np.array(hr_image)

# Three rows per method: the result, its absolute difference to the reference,
# and its signed residual. squeeze=False keeps `axs` two-dimensional even for
# a single method.
fig, axs = plt.subplots(3, len(models), figsize=(20, 10), squeeze=False)

# Iterate through the models and perform super-resolution or interpolation
for idx, (model_name, model) in enumerate(models.items()):
    if model is None:
        sr_image = hr_image  # Compare the original image to itself
    elif isinstance(model, torch.nn.Module):
        # Super-resolution with model
        model.eval()
        with torch.no_grad():
            sr_tensor = model(lr_tensor)
        # Clamp to [0, 1] so the 8-bit conversion cannot wrap around
        sr_image = ToPILImage()(sr_tensor.squeeze(0).clamp(0, 1))
    else:
        # Interpolation with PIL, up to the size of the reference image
        sr_image = lr_image.resize(hr_image.size, resample=model)

    # Convert the result to numpy for metric calculations
    sr_image_np = np.array(sr_image)

    # Calculate PSNR and SSIM against the high-resolution reference
    psnr_value = psnr(hr_image_np, sr_image_np)
    ssim_value = ssim(hr_image_np, sr_image_np, win_size=3, channel_axis=-1)

    # Absolute difference is shown on its natural 0-255 scale; the signed
    # residual is min-max rescaled so negative values remain visible.
    abs_diff = absolute_difference(hr_image, sr_image)
    residual_img = residual(hr_image, sr_image)
    heatmap_abs_diff = prepare_heatmap(abs_diff)
    heatmap_residual_img = prepare_heatmap(residual_img, scaleHM=True)

    # Display the super-resolved or interpolated image
    axs[0, idx].imshow(sr_image)
    axs[0, idx].set_title(f'{model_name}\nPSNR: {psnr_value:.2f}, SSIM: {ssim_value:.3f}')
    axs[0, idx].axis('off')  # Hide axes

    # Display the absolute difference image
    axs[1, idx].imshow(heatmap_abs_diff, cmap='gray', vmin=0, vmax=255)
    axs[1, idx].set_title(f'Absolute ({model_name})')
    axs[1, idx].axis('off')  # Hide axes

    # Display the residual image
    axs[2, idx].imshow(heatmap_residual_img, cmap='gray', vmin=0, vmax=255)
    axs[2, idx].set_title(f'Residual ({model_name})')
    axs[2, idx].axis('off')  # Hide axes

# Adjust layout and display the figure
plt.tight_layout()
plt.show()