# **Final Project: Image super resolution**

## ESRGAN definition

In [1]:
## Imports, Google Drive mount and project paths (the dataset used below is STL-10)
import scipy.io as sio
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as tf
import torch.nn.functional as F

# Mount Google Drive so the project's shared-drive folder is accessible from this Colab runtime
drive.mount('/content/drive')
# Dataset directory: passed as the root argument of torchvision.datasets.STL10 in the next cells
data_path = '/content/drive/Shareddrives/DeepLearning/DeepLearning_2021/finalProject/Data/'
# Checkpoint directory: the generator checkpoint is written here during training and read back for evaluation
results_path = '/content/drive/Shareddrives/DeepLearning/DeepLearning_2021/finalProject/Results/'

Mounted at /content/drive


In [2]:
# STL-10 training split as tensors; downloaded into data_path unless already present
STL = torchvision.datasets.STL10(root=data_path, split='train', download=True,
                                 transform=tf.ToTensor())


Files already downloaded and verified


In [3]:
# STL-10 test split as tensors, used only by the evaluation cell
STLTest = torchvision.datasets.STL10(root=data_path, split='test', download=True,
                                     transform=tf.ToTensor())

Files already downloaded and verified


In [4]:
# Shuffled mini-batches of 128 training images (the class labels are ignored by train_GAN)
train_loader = torch.utils.data.DataLoader(STL, batch_size=128, shuffle=True)

In [5]:
import torch
import torch.nn as nn

class RDB(nn.Module):
  """Residual Dense Block body (the residual connection itself is added by RRDB).

  Five 3x3 convolutions (stride 1, padding 1, so the spatial size is preserved) with
  dense connectivity: each convolution sees the block input concatenated with the
  outputs of all previous convolutions. The first four produce ``growth_channels``
  maps and are followed by LeakyReLU(0.2); the fifth fuses everything back to
  ``num_feat`` channels with no activation.

  Args:
    num_feat: channels of the input and output feature maps (default 64).
    growth_channels: channels produced by each intermediate convolution (default 32).
  """
  def __init__(self, num_feat=64, growth_channels=32):
    super(RDB, self).__init__()

    # Attribute names conv1..conv5 are kept so existing checkpoints still load.
    # Input width of conv k is num_feat + (k - 1) * growth_channels (dense concatenation).
    self.conv1 = nn.Conv2d(num_feat, growth_channels, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(num_feat + growth_channels, growth_channels, kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(num_feat + 2 * growth_channels, growth_channels, kernel_size=3, stride=1, padding=1)
    self.conv4 = nn.Conv2d(num_feat + 3 * growth_channels, growth_channels, kernel_size=3, stride=1, padding=1)
    self.conv5 = nn.Conv2d(num_feat + 4 * growth_channels, num_feat, kernel_size=3, stride=1, padding=1)

    self.leakyRelu = nn.LeakyReLU(inplace=True, negative_slope=0.2)

  def forward(self, x):
    """Apply the dense block; the output has the same shape as ``x``."""
    features = [x]
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
      # in-place LeakyReLU only touches the fresh conv output, never a stored feature
      features.append(self.leakyRelu(conv(torch.cat(features, 1))))
    return self.conv5(torch.cat(features, 1))

class RRDB(nn.Module):
  """Residual-in-Residual Dense Block.

  Applies three RDBs in sequence, each wrapped in its own scaled residual
  connection: ``y <- y + scalingParam * RDB_k(y)`` for k = 1, 2, 3.

  Args:
    scalingParam: factor applied to every RDB output before it is added back (default 0.5).
  """
  def __init__(self, scalingParam = 0.5):
    super(RRDB, self).__init__()
    self.scalingParam = scalingParam
    # Three separate attributes (rather than a ModuleList) keep the checkpoint keys RDB1/RDB2/RDB3
    self.RDB1, self.RDB2, self.RDB3 = RDB(), RDB(), RDB()

  def forward(self,x):
    y = x
    for dense_block in (self.RDB1, self.RDB2, self.RDB3):
      y = y + dense_block(y) * self.scalingParam
    return y

class Generator(nn.Module):
  """ESRGAN-style generator that enlarges its input 3x.

  Pipeline: a 3->64 conv stem, three RRDBs on the stem features, concatenation of
  stem and RRDB features (128 channels), nearest-neighbour x3 upsampling, then three
  3x3 convolutions (128->64->64->3), each followed by LeakyReLU(0.2).
  """
  def __init__(self):
    super(Generator, self).__init__()
    # Construction order is unchanged so seeded weight initialisation is unchanged too
    self.layer1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

    self.RRDB1 = RRDB()
    self.RRDB2 = RRDB()
    self.RRDB3 = RRDB()

    # Despite the name, this is an x3 nearest-neighbour upsampler, not a pooling layer
    self.pool = nn.UpsamplingNearest2d(scale_factor=3)

    self.conv1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)

    self.leakyRelu = nn.LeakyReLU(inplace=True, negative_slope=0.2)

  def forward(self, x):
    stem = self.layer1(x)
    trunk = self.RRDB3(self.RRDB2(self.RRDB1(stem)))

    hidden = self.pool(torch.cat((stem, trunk), 1))
    # NOTE(review): the LeakyReLU after conv3 lets output pixels go below 0; ESRGAN's
    # last conv has no activation. Kept as-is because the trained checkpoint relies on it.
    for conv in (self.conv1, self.conv2, self.conv3):
      hidden = self.leakyRelu(conv(hidden))
    return hidden

  # Sample a set of images
  def sample(self, realImage):
    """Downsample a batch of images to 32x32 (F.interpolate's default 'nearest' mode)."""
    return F.interpolate(realImage, size=32)


# Convolution + BatchNormalization + ReLU block, optionally preceded by 2x2 average pooling
class ConvBNReLU(nn.Module):
  """Optional AvgPool2d(2, 2), then 3x3 conv (padding 1) -> BatchNorm2d -> ReLU.

  Args:
    in_channels: input channels of the convolution.
    out_channels: output channels of the convolution and the batch norm.
    pooling: if True, halve the spatial size before the convolution.
  """
  def __init__(self,in_channels, out_channels, pooling=False):
    super(ConvBNReLU, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    self.bn = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.pool = nn.AvgPool2d(2, 2) if pooling else None

  def forward(self,x):
    if self.pool is not None:
      x = self.pool(x)
    return self.relu(self.bn(self.conv(x)))

# TODO: make this network match the discriminator described in the ESRGAN paper
# (original note, in Catalan: "make the decoder like the one in the paper").
class Discriminator(nn.Module):
  """Small CNN discriminator returning per-image probabilities.

  Three ConvBNReLU stages, each preceded by 2x2 average pooling (channels
  3 -> base_channels -> 2x -> 4x), then a linear layer and a sigmoid.

  Args:
    out_features: number of sigmoid outputs per image (1 for a real/fake score).
    base_channels: channels of the first stage; doubled at each following stage.
    input_size: height/width of the square input images (default 96, the size the
      original hard-coded 12x12 flatten size assumed).
  """
  def __init__(self,out_features,base_channels=16,input_size=96):
    super(Discriminator, self).__init__()
    self.layer1 = ConvBNReLU(3,base_channels,pooling=True)
    self.layer2 = ConvBNReLU(base_channels,base_channels*2,pooling=True)
    self.layer3 = ConvBNReLU(base_channels*2,base_channels*4,pooling=True)
    # Each AvgPool2d(2, 2) floors the side length by half, so three stages give input_size // 8
    final_size = input_size // 8
    self.fc = nn.Linear(final_size*final_size*base_channels*4,out_features)

  def forward(self,x):
    out = self.layer1(x)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.fc(out.view(x.shape[0],-1))
    return torch.sigmoid(out)

In [6]:
# GAN train function: a generator and a discriminator, each with its own optimizer.
def train_GAN(gen, disc,  train_loader, optimizer_gen, optim_disc,
              num_epochs=10, model_name='gan_ESRGAN.ckpt', device='cpu', gan_weight = 0.1):
    """Train a super-resolution GAN and checkpoint the generator after every epoch.

    Each batch of images is downsampled with ``gen.sample`` and reconstructed by
    ``gen``. Updates alternate within an epoch, starting with the generator:
    the generator minimises ``gan_weight * adversarial_loss + L1_loss`` and the
    discriminator minimises the binary cross-entropy GAN loss.

    Args:
        gen: generator module; must provide ``sample(images)`` and ``forward``.
        disc: discriminator module returning probabilities, one row per image.
        train_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer_gen: optimizer over ``gen``'s parameters.
        optim_disc: optimizer over ``disc``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        model_name: checkpoint file name, written under the notebook's ``results_path``.
        device: device the models and batches are moved to.
        gan_weight: weight of the adversarial term in the generator objective.

    Returns:
        List with the average discriminator loss of each epoch.
    """
    eps = 1e-8  # keeps torch.log finite if the discriminator's sigmoid saturates at 0 or 1
    log_fmt = 'Epoch [{}/{}], Step [{}/{}], Gen. Loss: {:.4f}, Disc Loss: {:.4f}, L1 Loss: {:.4f}'

    gen = gen.to(device)
    gen.train()  # Set the generator in train mode
    disc = disc.to(device)
    disc.train()  # Set the discriminator in train mode

    total_step = len(train_loader)
    losses_list = []

    for epoch in range(num_epochs):
        disc_loss_avg = 0
        gen_loss_avg = 0
        l1_loss_avg = 0
        nBatches = 0
        update_generator = True  # every epoch starts with a generator update

        for i, (real_images, _) in enumerate(train_loader):
            real_images = real_images.to(device)

            # Reconstruct the batch from its downsampled version
            fake_images = gen(gen.sample(real_images))

            # Discriminator probabilities for real and generated images
            prob_real = disc(real_images)
            prob_fake = disc(fake_images)

            # Pixel-wise L1 reconstruction loss
            l1_loss = (real_images - fake_images).abs().mean()

            # Adversarial generator loss. The prob_real term has no gradient w.r.t. the
            # generator's parameters; it only contributes to the logged value.
            gen_loss = -torch.log(1 - prob_real + eps).mean() - torch.log(prob_fake + eps).mean()

            # Discriminator loss (binary cross-entropy with targets real -> 1, fake -> 0)
            disc_loss = -torch.log(prob_real + eps).mean() - torch.log(1 - prob_fake + eps).mean()

            # Alternate updates. Only one backward pass runs per iteration, so no
            # retain_graph is needed. Gradients that also reach the other network are
            # discarded by that network's zero_grad() before its own step.
            if update_generator:
                optimizer_gen.zero_grad()
                (gan_weight * gen_loss + l1_loss).backward()
                optimizer_gen.step()
            else:
                optim_disc.zero_grad()  # the optimizer passed in, not a notebook global
                disc_loss.backward()
                optim_disc.step()
            update_generator = not update_generator

            disc_loss_avg += disc_loss.item()
            gen_loss_avg += gen_loss.item()
            l1_loss_avg += l1_loss.item()
            nBatches += 1
            if (i + 1) % 20 == 0:
                print(log_fmt.format(epoch + 1, num_epochs, i + 1, total_step,
                                     gen_loss_avg / nBatches, disc_loss_avg / nBatches,
                                     l1_loss_avg / nBatches))

        denom = max(nBatches, 1)  # an empty loader must not raise ZeroDivisionError
        print(log_fmt.format(epoch + 1, num_epochs, nBatches, total_step,
                             gen_loss_avg / denom, disc_loss_avg / denom, l1_loss_avg / denom))
        losses_list.append(disc_loss_avg / denom)
        # Save the generator being trained (the `gen` argument, not a notebook-level global)
        torch.save(gen.state_dict(), results_path + '/' + model_name)

    return losses_list

In [7]:
# Seed torch's global RNG before anything random happens: this fixes the weight
# initialisation and the shuffling order of train_loader across Restart & Run All.
# (CUDA kernels may still be nondeterministic on GPU.)
SEED = 0
torch.manual_seed(SEED)

# Define Generator and Discriminator networks
gan_gen = Generator()
gan_disc = Discriminator(1)

# Initialize independent optimizers for both networks
learning_rate = .0005
optimizer_gen = torch.optim.Adam(gan_gen.parameters(), lr=learning_rate, weight_decay=1e-5)
optimizer_disc = torch.optim.Adam(gan_disc.parameters(), lr=learning_rate, weight_decay=1e-5)

# Train the GAN
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loss_list = train_GAN(gan_gen, gan_disc, train_loader, optimizer_gen, optimizer_disc,
                      num_epochs=50, model_name='gan_ESRGAN.ckpt', device=device)

Epoch [1/50], Step [40/40], Gen. Loss: 1.7343, Disc Loss: 1.2924, L1 Loss: 0.1916
Epoch [2/50], Step [40/40], Gen. Loss: 1.4071, Disc Loss: 1.5525, L1 Loss: 0.0950
Epoch [3/50], Step [40/40], Gen. Loss: 1.3139, Disc Loss: 1.6424, L1 Loss: 0.0903
Epoch [4/50], Step [40/40], Gen. Loss: 1.4034, Disc Loss: 1.5182, L1 Loss: 0.0793
Epoch [5/50], Step [40/40], Gen. Loss: 1.3821, Disc Loss: 1.4573, L1 Loss: 0.0710
Epoch [6/50], Step [40/40], Gen. Loss: 1.4216, Disc Loss: 1.3857, L1 Loss: 0.0593
Epoch [7/50], Step [40/40], Gen. Loss: 1.3988, Disc Loss: 1.3962, L1 Loss: 0.0581
Epoch [8/50], Step [40/40], Gen. Loss: 1.4094, Disc Loss: 1.3818, L1 Loss: 0.0556
Epoch [9/50], Step [40/40], Gen. Loss: 1.4049, Disc Loss: 1.3910, L1 Loss: 0.0539
Epoch [10/50], Step [40/40], Gen. Loss: 1.4047, Disc Loss: 1.3829, L1 Loss: 0.0520
Epoch [11/50], Step [40/40], Gen. Loss: 1.4093, Disc Loss: 1.3770, L1 Loss: 0.0522
Epoch [12/50], Step [40/40], Gen. Loss: 1.4228, Disc Loss: 1.3670, L1 Loss: 0.0509
Epoch [13/50]

In [9]:
def _show_image(image, title):
    """Display the single image in a batch-of-one tensor, clamped to [0, 1] for imshow."""
    # Clamping only affects the displayed copy; the generator's final LeakyReLU can emit negatives
    plt.imshow(np.moveaxis(image.detach().cpu().squeeze().clamp(0, 1).numpy(), 0, 2))
    plt.axis('off')
    plt.title(title)
    plt.show()


def superResolution(num_images=10, model_name='gan_ESRGAN.ckpt'):
    """Visualise the trained generator on test images and print the mean L1 error.

    For each of the first ``num_images`` test images (batch size 1, shuffled), it shows:
    - the generator output from the 32x32 downsampled image;
    - the downsampled image itself;
    - the generator applied directly to the full-resolution image;
    - the original image.
    The printed reconstruction error is the mean absolute difference between the
    original image and the generator output from its downsampled version, averaged
    over the displayed images.

    Args:
        num_images: number of test images to display and evaluate (default 10).
        model_name: generator checkpoint file name inside ``results_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load generator; map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    gan_gen = Generator()
    gan_gen.load_state_dict(torch.load(results_path + model_name, map_location=device))
    gan_gen = gan_gen.to(device)
    gan_gen.eval()  # Put in eval mode

    # Load test dataset
    test_loader = torch.utils.data.DataLoader(dataset=STLTest,
                                              batch_size=1,
                                              shuffle=True)

    reconstruction_error_sum = 0.0
    n_batches = 0
    with torch.no_grad():  # inference only: no autograd graph is needed
        for i, (real_images, _) in enumerate(test_loader):
            if i >= num_images:
                break
            real_images = real_images.to(device)

            low_res = gan_gen.sample(real_images)
            x_gen = gan_gen(low_res)
            _show_image(x_gen, "Generated Image low resolution")
            _show_image(low_res, "Real Image low resolution")
            _show_image(gan_gen(real_images), "Generated Image high resolution")
            _show_image(real_images, "Real Image high resolution")

            reconstruction_error_sum += (real_images - x_gen).abs().mean().item()
            n_batches += 1

    if n_batches == 0:
        print("Reconstruction error: no test images evaluated")
        return
    print("Reconstruction error: ", str(reconstruction_error_sum / n_batches))


superResolution()

Output hidden; open in https://colab.research.google.com to view.