# Image Super Resolution with ESRGAN
### Using PyTorch



In [2]:
# importing libraries
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.utils import save_image

## GAN using VGG model

In [10]:
# define dataset class, transformations, and dataloader
class ImageDataset(Dataset):
    """Dataset of RGB images read from a single flat directory.

    Only regular files whose extension looks like an image are indexed, so
    sub-directories and stray files (``.DS_Store``, ``Thumbs.db``, notes, ...)
    cannot make ``__getitem__`` fail in the middle of an epoch.  File names are
    sorted so the index -> file mapping is the same on every run.

    Args:
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each PIL image before it is
            returned.
    """

    # Extensions (lower-case) accepted as image files.
    IMAGE_EXTENSIONS = (
        ".bmp", ".gif", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp",
    )

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = sorted(
            name
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply ``transform`` when one is set."""
        img_path = os.path.join(self.root_dir, self.image_files[index])
        # convert() returns an in-memory copy, so the file can be closed
        # deterministically as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
    
# Preprocessing: resize every image to 128x128, then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# Dataset over the training folder, wrapped in a shuffling loader that yields
# batches of 4 images.
dataset = ImageDataset(root_dir="/original_images", transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

In [11]:
# define esrgan model architecture
class RRDB(nn.Module):
    """Residual block of three 3x3 convolutions with a skip connection.

    Every convolution keeps the channel count (``in_channels``) and uses
    ``padding=1``, so the output has the same shape as the input.  The first
    two convolutions are followed by LeakyReLU(0.2); the third is added to the
    block input without an activation.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.conv3 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` plus the output of the three-convolution branch."""
        residual = x
        # The in-place activation only ever touches a fresh conv output,
        # never ``x`` itself, so the skip connection stays intact.
        for conv in (self.conv1, self.conv2):
            residual = self.relu(conv(residual))
        residual = self.conv3(residual)
        return x + residual
# define generator and discriminator   
class Generator(nn.Module):
    """Image-to-image network: conv -> stack of RRDB blocks -> conv.

    All layers are 3x3 convolutions with ``padding=1`` and there is no
    up-sampling layer, so the output has the same spatial size and channel
    count as the input.

    Args:
        in_channels: Channels of the input and output image.
        num_rrdb: Number of ``RRDB`` blocks in the trunk.
    """

    def __init__(self, in_channels=3, num_rrdb=23):
        super().__init__()
        width = 64  # channel count of the internal feature maps
        self.initial_conv = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        trunk = [RRDB(width) for _ in range(num_rrdb)]
        self.rrdb_blocks = nn.Sequential(*trunk)
        self.final_conv = nn.Conv2d(width, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an image batch to an output batch of the same shape."""
        features = self.initial_conv(x)
        features = self.rrdb_blocks(features)
        return self.final_conv(features)

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of unnormalised scores.

    Four stride-2 4x4 convolution stages (64, 128, 256, 512 channels), each
    followed by LeakyReLU(0.2); every stage except the first also has batch
    normalisation.  A final 3x3 convolution reduces the result to one channel.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            # The first stage works on raw pixels and is left un-normalised.
            if stage > 0:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(widths[-1], 1, 3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map for an image batch."""
        return self.model(img)


In [12]:
# loss function

class ContentLoss(nn.Module):
    """Pixel-wise content loss: mean squared error between two image batches."""

    def __init__(self):
        super().__init__()

    def forward(self, sr, hr):
        """Return the mean squared error between ``sr`` and ``hr``."""
        pixel_error = F.mse_loss(input=sr, target=hr, reduction="mean")
        return pixel_error
    
class PerceptualLoss(nn.Module):
    """Mean squared error between feature maps of a fixed feature extractor.

    The extractor is the first 36 layers of ``vgg_model.features``.  It is put
    in eval mode and its parameters are frozen: it is only a measuring device,
    so computing or storing gradients for its weights would waste memory and
    compute on every backward pass.

    Note that the frozen layers are shared with ``vgg_model`` (slicing a
    ``Sequential`` does not copy its modules), so ``vgg_model`` itself is
    frozen as a side effect.

    Args:
        vgg_model: Model exposing a sliceable ``features`` container
            (for example a torchvision VGG network).
    """

    def __init__(self, vgg_model):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_model.features[:36]
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, sr, hr):
        """Return the MSE between the extractor features of ``sr`` and ``hr``.

        Gradients flow back to ``sr`` only; ``hr`` is treated as a constant
        target, so its features are computed without building a graph.
        """
        sr_features = self.vgg(sr)
        with torch.no_grad():
            hr_features = self.vgg(hr)
        return F.mse_loss(sr_features, hr_features)

In [13]:
# define training loop
def _degrade(img, lr_scale):
    """Return a blurred, same-size copy of ``img`` (down- then up-sampled)."""
    if lr_scale == 1:
        return img
    height, width = img.shape[-2:]
    low_size = (max(1, height // lr_scale), max(1, width // lr_scale))
    # Area averaging for the reduction, bilinear for the enlargement: neither
    # can overshoot the value range of the input.
    low_res = F.interpolate(img, size=low_size, mode="area")
    return F.interpolate(
        low_res, size=(height, width), mode="bilinear", align_corners=False
    )


def train(generator, discriminator, dataloader, num_epochs, optimizer_G, optimizer_D,
          criterion_content, criterion_perceptual, device,
          lr_scale=4, adversarial_weight=5e-3):
    """Train the generator/discriminator pair on batches of sharp images.

    Each batch from ``dataloader`` is treated as the high-quality target.  The
    generator input is a degraded copy of that batch (see ``lr_scale``);
    feeding the target itself would only teach the generator an identity map.

    The generator loss is ``content + perceptual + adversarial_weight * adv``,
    where ``adv`` rewards output the discriminator scores as real.  Without
    that term the discriminator is trained but never influences the generator.

    Args:
        generator: Network mapping a degraded batch to a restored batch of the
            same shape.
        discriminator: Network returning unnormalised real/fake scores.
        dataloader: Iterable yielding 4-D image batches (batch, channel, H, W).
        num_epochs: Number of passes over ``dataloader``.
        optimizer_G: Optimizer for the generator parameters.
        optimizer_D: Optimizer for the discriminator parameters.
        criterion_content: Callable ``(sr, hr) -> loss``.
        criterion_perceptual: Callable ``(sr, hr) -> loss``.
        device: Device the networks and batches are moved to.
        lr_scale: Integer factor by which the target is shrunk and re-enlarged
            to build the generator input.  ``1`` disables the degradation.
        adversarial_weight: Weight of the adversarial term in the generator
            loss.  ``0`` disables it.

    Raises:
        ValueError: If ``lr_scale`` is smaller than 1.
    """
    if lr_scale < 1:
        raise ValueError(f"lr_scale must be >= 1, got {lr_scale}")

    generator.to(device)
    discriminator.to(device)
    # A previous evaluation may have left the networks in eval mode.
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        for i, img in enumerate(dataloader):
            img = img.to(device)
            lr_img = _degrade(img, lr_scale)

            # generator
            sr_image = generator(lr_img)

            # train generator
            optimizer_G.zero_grad()
            content_loss = criterion_content(sr_image, img)
            perceptual_loss = criterion_perceptual(sr_image, img)
            g_loss = content_loss + perceptual_loss
            if adversarial_weight:
                fake_logits = discriminator(sr_image)
                adversarial_loss = F.binary_cross_entropy_with_logits(
                    fake_logits, torch.ones_like(fake_logits)
                )
                g_loss = g_loss + adversarial_weight * adversarial_loss
            g_loss.backward()
            optimizer_G.step()

            # train discriminator
            # zero_grad() also discards the gradients the generator step left
            # on the discriminator parameters.
            optimizer_D.zero_grad()
            real_output = discriminator(img)
            fake_output = discriminator(sr_image.detach())
            d_loss = F.binary_cross_entropy_with_logits(real_output, torch.ones_like(real_output)) + \
                     F.binary_cross_entropy_with_logits(fake_output, torch.zeros_like(fake_output))
            d_loss.backward()
            optimizer_D.step()

            if i % 40 == 0:
                print(f"Epoch {epoch +1}/{num_epochs}, Step {i}, G Loss: {g_loss.item()}, D Loss: {d_loss.item()}")

In [None]:
# train
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Networks and one Adam optimizer per network, both with the same learning rate.
generator = Generator()
discriminator = Discriminator()

optimizer_G = optim.Adam(params=generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=0.0002)

# Pretrained VGG19 used as the feature extractor of the perceptual loss.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- confirm against the installed version.
vgg = models.vgg19(pretrained=True).to(device)
criterion_content = ContentLoss()
criterion_perceptual = PerceptualLoss(vgg)

# Run the training loop for 60 epochs.
train(
    generator,
    discriminator,
    dataloader,
    num_epochs=60,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    criterion_content=criterion_content,
    criterion_perceptual=criterion_perceptual,
    device=device,
)

In [15]:
# save model
def save_gan_model(generator, discriminator, path):
    """Write the state dicts of both networks to one checkpoint file.

    The checkpoint is a dict with the keys ``'generator_state_dict'`` and
    ``'discriminator_state_dict'``.

    Args:
        generator: Module whose ``state_dict()`` is stored.
        discriminator: Module whose ``state_dict()`` is stored.
        path: Destination passed to ``torch.save``.
    """
    checkpoint = {}
    checkpoint['generator_state_dict'] = generator.state_dict()
    checkpoint['discriminator_state_dict'] = discriminator.state_dict()
    torch.save(checkpoint, path)

# Store both trained networks in a single checkpoint file.
save_gan_model(generator=generator, discriminator=discriminator, path='esrgan_model.pth')

In [None]:
# test with sample image
# Shared preprocessing: resize to 256x256 and convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Load the blurred input and its sharp reference as single-image batches.
test_image = Image.open("/blurred_test_image.jpg").convert('RGB')
test_image = transform(test_image).unsqueeze(0).to(device)
original_image = Image.open("/original_image.jpg").convert('RGB')
original_image = transform(original_image).unsqueeze(0).to(device)

# Run the generator in eval mode without tracking gradients.
generator.eval()
with torch.no_grad():
    sr_image = generator(test_image)

# Write the three images to disk.
for tensor, filename in (
    (original_image, 'original_image.png'),
    (sr_image, 'sr_image.png'),
    (test_image, 'lr_image.png'),
):
    save_image(tensor, filename)

# Show the three images side by side.
plt.figure(figsize=(10, 5))
panels = (
    ('Original Image', original_image),
    ('Low Resolution Image', test_image),
    ('Super Resolution Image', sr_image),
)
for position, (title, tensor) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    # Drop the batch axis and move channels last, as imshow expects.
    plt.imshow(np.transpose(tensor.squeeze().cpu().numpy(), (1, 2, 0)))
    plt.axis('off')
plt.show()

In [None]:
# calculate metrics
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import cv2

def _read_rgb(path, size=(256, 256)):
    """Read ``path`` with OpenCV, resize it to ``size`` and return it as RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read the file.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error would only surface inside cv2.resize.
        raise FileNotFoundError(f"Could not read image file: {path}")
    image = cv2.resize(image, size)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _scores(reference, candidate):
    """Return ``(mse, ssim, psnr)`` of ``candidate`` measured against ``reference``."""
    return (
        mse(reference, candidate),
        ssim(reference, candidate, channel_axis=-1),
        psnr(reference, candidate),
    )


# load images, resize them to a common size and convert to RGB
sharp_image = _read_rgb('original_image.png')
blurred_image = _read_rgb('lr_image.png')
output_image = _read_rgb('sr_image.png')

# create figure
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = axes.ravel()

# calculate mse, ssim and psnr against the sharp reference
# (the first row compares the reference with itself as a baseline)
mse_sharp, ssim_sharp, psnr_sharp = _scores(sharp_image, sharp_image)
mse_blurred, ssim_blurred, psnr_blurred = _scores(sharp_image, blurred_image)
mse_output, ssim_output, psnr_output = _scores(sharp_image, output_image)

# plot images with their metrics in the title
panels = (
    ("Sharp Image", sharp_image, mse_sharp, ssim_sharp, psnr_sharp),
    ("Blurred Image", blurred_image, mse_blurred, ssim_blurred, psnr_blurred),
    ("Output Image", output_image, mse_output, ssim_output, psnr_output),
)
for axis, (label, image, mse_value, ssim_value, psnr_value) in zip(ax, panels):
    axis.axis('off')
    axis.imshow(image)
    axis.set_title(f"{label}\nMSE: {mse_value}\nSSIM: {ssim_value}\nPSNR: {psnr_value}")

# show figure
plt.show()

