<a href="https://colab.research.google.com/github/martinpius/GANS/blob/main/Photo_Realistic_Super_Resolution_(SR)_GAN%2C_Pytorch_implementation_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)} {e}\n>>>> please correct {type(e)} and reload your drive")
  COLAB = False
def time_fmt(t: float = 123.189)->float():
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"hrs: {h} min: {m:>02} sec: {s:>05.2f}"
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(f">>>> testing the time formating function........\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are on CoLaB with torch version 1.9.0+cu102
>>>> testing the time formating function........
>>>> time elapsed	hrs: 0 min: 02 sec: 03.00


In [2]:
# in this notbook we are going to implement the SR-GAN to construct highly resolution images
# from low resolution images. More detail about the paper can be found here: 
# paper url (https://arxiv.org/abs/1609.04802)  

In [3]:
!pip install albumentations==0.4.6

Collecting albumentations==0.4.6
[?25l  Downloading https://files.pythonhosted.org/packages/92/33/1c459c2c9a4028ec75527eff88bc4e2d256555189f42af4baf4d7bd89233/albumentations-0.4.6.tar.gz (117kB)
[K     |████████████████████████████████| 122kB 23.6MB/s 
Collecting imgaug>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948kB)
[K     |████████████████████████████████| 952kB 42.2MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-cp37-none-any.whl size=65175 sha256=b5cff43b884ba54c5a14df1ee105f8283abfeac89016e52986e474c45741e830
  Stored in directory: /root/.cache/pip/wheels/c7/f4/89/56d1bee5c421c36c1a951eeb4adcc32fbb82f5344c086efa14
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Found

In [4]:
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tensorflow import summary
from tqdm import tqdm
from PIL import Image
import albumentations as B
from albumentations.pytorch import ToTensorV2
import numpy as np
import pandas as pd
import time, datetime, random, os
%load_ext tensorboard

In [5]:
# setup the seed value for reproducability and gpu to deterministic:
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [6]:
# The SR-GAN consists of two subnetwork as like in the ussual GANS
# The generator follow a ResNet architecture where we have 16 residual
# blocks. The start is the conv block and after residual blocks we use
# the upsample block (x 4) followed by a conv block.
# The discriminantor is a convlution network with different kernels to 
# classify high resolution (original image) from the upscaled (generated)
# high resolution images.
# The generator use PReLU() activation while the discriminator use LeakyReLU()
# The loss function consists of 3 components (For the generator we have component)
# related to perceptual loss and content loss (we use vgg19 to achieve this) and the
# discriminator employ the ussual GANS loss.

In [7]:
class CNNBLOCK(nn.Module):
  ''' 
  we will use this block to build every conv layer for the generator and
  the discriminator.
  '''
  def __init__(self, 
               in_channels,
               out_channels,
               discriminator = False,
               use_bn = True,
               use_act = True,
               **kwargs):
    super(CNNBLOCK, self).__init__()
    self.use_act = use_act
    self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias = not use_bn)
    self.bnorm = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
    self.activation = (nn.LeakyReLU(0.2, inplace = True) 
    if discriminator
    else nn.PReLU(num_parameters = out_channels))
  
  def forward(self, input_tensor):
    return (self.activation(self.bnorm(self.cnn(input_tensor))) 
    if self.use_act else self.bnorm(self.cnn(input_tensor)))



class UPSAMPLEBLOCK(nn.Module):
  def __init__(self, in_channels, scale_factor):
    super(UPSAMPLEBLOCK, self).__init__()
    self.conv = nn.Conv2d(in_channels, 
                          in_channels * scale_factor **2,
                          kernel_size = 3,
                          stride = 1,
                          padding = 1)
    self.pxs = nn.PixelShuffle(scale_factor)
    self.act = nn.PReLU(num_parameters = in_channels)
  
  def forward(self,input_tensor):
    return self.act(self.pxs(self.conv(input_tensor)))


class RESIDUALBLOCK(nn.Module):
  ''' Residual block will be used only for building the generator
  network
  '''
  def __init__(self, in_channels):
    super(RESIDUALBLOCK, self).__init__()
    self.conv1 = CNNBLOCK(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1)
    self.conv2 = CNNBLOCK(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1, use_act = False)
  
  def forward(self, input_tensor):
    out = input_tensor
    x = self.conv1(input_tensor)
    x = self.conv2(x)
    return x + out

class GENERATOR(nn.Module):
  ''' This is the ResNet with 16 residual blocks'''
  def __init__(self, in_channels = 3, num_channels = 64, num_blocks = 16):
    super(GENERATOR, self).__init__()
    self.initial_block = CNNBLOCK(in_channels, 
                                  num_channels, 
                                  kernel_size = 9, 
                                  stride = 1,
                                  padding = 4, 
                                  use_bn = False)
    self.residuals = nn.Sequential(
        *[RESIDUALBLOCK(num_channels) for _ in range(num_blocks)] # this create all 16 resblocks
    )
    self.convblock = CNNBLOCK(num_channels,
                              num_channels, 
                              kernel_size = 3, 
                              stride = 1, 
                              padding = 1, 
                              use_act = False)
    self.upsamples = nn.Sequential(UPSAMPLEBLOCK(in_channels = num_channels, scale_factor = 2),
                                   UPSAMPLEBLOCK(in_channels = num_channels, scale_factor = 2))
    self.finalconv = nn.Conv2d(in_channels = num_channels,
                               out_channels = in_channels, 
                               kernel_size = 9,
                               stride = 1,
                               padding = 4)
  
  def forward(self, input_tensor):
    initial = self.initial_block(input_tensor)
    x = self.residuals(initial)
    x = self.convblock(x) + initial
    x = self.upsamples(x)
    x = self.finalconv(x)
    return torch.tanh(x)


class DISCRIMINATOR(nn.Module):
  ''' has similar architecture as vggnet'''
  def __init__(self, in_channels = 3, 
               features = [64, 64, 128, 128, 256,256, 512,512]):
    super(DISCRIMINATOR, self).__init__()
    blocks = []
    for idx, feature in enumerate(features):
      blocks.append(
          CNNBLOCK(in_channels,
                   feature,
                   kernel_size = 3,
                   stride = 1 + idx % 2,
                   padding = 1,
                   discriminator = True,
                   use_act = True,
                   use_bn = False if idx == 0 else True))
      in_channels = feature
    self.blocks = nn.Sequential(*blocks)
    self.classifier = nn.Sequential(
        nn.AdaptiveAvgPool2d((6,6)),
        nn.Flatten(),
        nn.Linear(in_features = 512*6*6, out_features = 1024),
        nn.LeakyReLU(0.2, inplace = True),
        nn.Linear(in_features = 1024, out_features = 1))
  
  def forward(self, input_tensor):
    x = self.blocks(input_tensor)
    return self.classifier(x)
  


In [8]:
# testing the network if it gives the desired outputs shapes:
def __test__():
  lR = 24 # we need to upscale to 3, 96, 96
  with torch.cuda.amp.autocast():
    lr_imgs = torch.randn(size = (32, 3, lR, lR))
    generator = GENERATOR()
    discriminator = DISCRIMINATOR()
    gen_out = generator(lr_imgs)
    disc_out = discriminator(gen_out)
    print(f">>>> The generator's shape: {gen_out.shape}\n>>>> discriminator's shape: {disc_out.shape}")

In [9]:
__test__()

>>>> The generator's shape: torch.Size([32, 3, 96, 96])
>>>> discriminator's shape: torch.Size([32, 1])


In [10]:
# This model is trained using combination of loss components. 
# the perceptual loss is computed using vgg19-features and is 
# included  with GANS loss to optimize the network.


In [11]:
# for vgg19 pretrained network layers considered are 0:36  i.e
vgg = torchvision.models.vgg19(pretrained = True).features

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




In [12]:
print(vgg)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [13]:
class PerceptualLoss(nn.Module):
  ''' this class will compute the percetual loss using pretrained vgg19'''
  def __init__(self):
    super(PerceptualLoss, self).__init__()
    self.vgg19 = torchvision.models.vgg19(pretrained = True).features[:36].eval().to(device = device)
    self.loss = nn.MSELoss()
     # we need not to train the vggnet again
    for pars in self.vgg19.parameters():
      pars.requires_grad = False
  
  def forward(self, LR, HR):
    ''' 
    inputs is low resolution image (which is upgraded with the generator)
    and output is the high resolution image (original images without distortion)
    '''
    vgg_LR_inputs = self.vgg19(LR)
    vgg_HR_target = self.vgg19(HR)
    # now we get the perceptual loss using vggnet as 
    return self.loss(vgg_LR_inputs, vgg_HR_target)




In [14]:
# Tensorboard directories:
current_time = datetime.datetime.now().timestamp()
fake_path = "logs/tensorboard/generator_srgan/"+str(current_time)
real_path = "logs/tensorboard/discriminator_srgan/"+str(current_time)
fake_writer = summary.create_file_writer(fake_path)
real_writer = summary.create_file_writer(real_path)

In [16]:
# testing the perceptual loss function if its gives what we needed:
lR = torch.randn(size = (32, 3, 24,24)).to(device = device)
hR = torch.randn(size = (32, 3, 96,96)).to(device = device)
perceptual_loss = PerceptualLoss()
print(f">>>> The perceptual loss: {perceptual_loss.forward(lR, hR)}")

>>>> The perceptual loss: 0.26543715596199036


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  return F.mse_loss(input, target, reduction=self.reduction)


In [17]:
# Hyperparameters to be used for training and preprocessing images
learning_rate = 1e-4
EPOCHS = 10
batch_size = 16
HR = 96
LR = HR//4
img_channels = 3

transforms_hR = transforms.Compose(
    [
     transforms.Normalize(np.array([0.5 for _ in range(img_channels)]),np.array([0.5 for _ in range(img_channels)]))
    
    ]
)

transforms_lR = transforms.Compose(
    [transforms.Resize(size = (LR, LR), interpolation = Image.BICUBIC),
    transforms.Normalize(np.array([0 for _ in range(img_channels)]),np.array([1 for _ in range(img_channels)]))
])

transform_General = transforms.Compose([transforms.ToTensor(),
                                      transforms.RandomCrop((HR, HR))                                       
])

test_transform = transforms.Compose(
    [
     transforms.Normalize(mean = [0,0,0], std = [1,1,1]),
     transforms.ToTensor()
    ]
)


  "Argument interpolation should be of type InterpolationMode instead of int. "


In [18]:
os.chdir("/content/drive/MyDrive/flickr30k_images/flickr30k_images")
# ploting examples after training:
def plot_examples(low_res_folder, gen):
    files = os.listdir(low_res_folder)

    gen.eval()
    for file in files:
        image = Image.open("/test1_SRGAN" + file)
        with torch.no_grad():
          upscaled_img = gen(test_transform(np.asarray(image))
                .unsqueeze(0)
                .to(device = device)
            )
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")

In [19]:
highres_transform = B.Compose(
    [
        B.Normalize(mean=np.array([0.5, 0.5, 0.5]), std=np.array([0.5, 0.5, 0.5])),
        ToTensorV2(),
    ]
)

lowres_transform = B.Compose(
    [
        B.Resize(width=LR, height=LR, interpolation=Image.BICUBIC),
        B.Normalize(mean=np.array([0, 0, 0]), std=np.array([1, 1, 1])),
        ToTensorV2(),
    ]
)

both_transforms = B.Compose(
    [
        B.RandomCrop(width=HR, height= HR),
        B.HorizontalFlip(p=0.5),
        B.RandomRotate90(p=0.5),
    ]
)

In [20]:
# We use the flickr30k dataset for demo:
class ImageFolder(Dataset):
  def __init__(self, root_dir, csv_file, tr1 = None, tr2 = None, tr3 = None):
    super(ImageFolder, self).__init__()
    self.root_dir = root_dir
    self.data = pd.read_csv(csv_file, error_bad_lines = False)
    self.tr = True
    self.tr1 = tr1
    self.tr2 = tr2
    self.tr3 = tr3

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self, index):
    img_path = os.path.join(self.root_dir, self.data.iloc[index, 0])
    image = Image.open(img_path)
    if self.tr:
      image = self.tr1(image)
      lr_image = self.tr2((image))
      hr_image = self.tr3(image)
    return lr_image, hr_image

dataset = ImageFolder(root_dir = "/content/drive/MyDrive/flickr30k_images/flickr8k/images",
                      csv_file = "/content/drive/MyDrive/flickr30k_images/flickr8k/captions.txt",
                      tr1 = transform_General,
                      tr2 = transforms_lR,
                      tr3 = transforms_hR)

loader = DataLoader(dataset = dataset, shuffle = True, batch_size = batch_size)
LR_image, HR_image = next(iter(loader))
print(f">>>> LR_image shape: {LR_image.shape}\n>>>> HR_image shape = {HR_image.shape}")

>>>> LR_image shape: torch.Size([16, 3, 24, 24])
>>>> HR_image shape = torch.Size([16, 3, 96, 96])


In [21]:
# Instantiating the model class and save to device
generator = GENERATOR().to(device = device)
discriminator = DISCRIMINATOR().to(device = device)


In [22]:
# Get the losses and optimizers objects:
bce = nn.BCEWithLogitsLoss()
vgg_loss = PerceptualLoss()
mse = nn.MSELoss()
gen_optimimizer = optim.Adam(params = generator.parameters(), lr = learning_rate, betas = (0.9, 0.999))
disc_optimizer = optim.Adam(params = discriminator.parameters(), lr = learning_rate, betas = (0.9, 0.999))


In [None]:
# The training loop:
step = 0
tic = time.time()
for epoch in range(EPOCHS):
  print(f"\n>>>> training starts for epoch <{epoch + 1}>\n>>>> please wait while the model is training.................................")

  for idx, (lR, hR) in enumerate(tqdm(loader)):
    lR = lR.to(device = device)
    hR = hR.to(device = device)
    # training the discriminator on both hR and lR images
    # we still have to maximize log(D(x) + log(1 - D(G(x))))

    lR_fake = generator(lR) 
    lR_disc_preds = discriminator(lR_fake.detach()) # to re-use in the generator we detach the gradients
    hR_disc_preds = discriminator(hR)
    disc_real_loss = bce(hR_disc_preds, torch.ones_like(hR_disc_preds) - torch.rand_like(hR_disc_preds)*0.1)
    disc_fake_loss = bce(lR_disc_preds, torch.zeros_like(lR_disc_preds))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2
    discriminator.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # training the generator: 
    # Here is when we utilize the vgg-loss but the main idea remain the same. [Max(log(D(G(x))))]
    gen_out = discriminator(lR_fake)
    #mse_loss = mse(lR, hR) # main purpose of SRGAN is to replace this component with the vgg-loss
    adversarial_loss = 1e-3 * bce(gen_out, torch.ones_like(gen_out))
    gen_vgg_loss = 0.006* vgg_loss(lR_fake, hR) # why 0.006????Not clear for me
    gen_loss = adversarial_loss + gen_vgg_loss # leave out the mse_loss or take mse loss with the vgg-loss
    generator.zero_grad()
    gen_loss.backward()
    gen_optimimizer.step()

    if idx% 200 == 0:
      print(f"\n>>> end of epoch <{epoch + 1}>: generator loss:=>=>=>{gen_loss:.4f}: discriminator loss=>=>=>{disc_loss:.4f}")
      #plot_examples("/content/drive/MyDrive/flickr30k_images/flickr30k_images/", generator)
      with fake_writer.as_default():
        summary.scalar("generator_loss", gen_loss.cpu().detach().numpy(), step = step)
      with real_writer.as_default():
        summary.scalar('discriminator_loss', disc_loss.cpu().detach().numpy(), step = step)
      step += 1
%tensorboard --logdir logs/tensorboard
toc = time.time()
print(f"\n>>>> time elapsed for training srgan for 10 epochs: {time_fmt(toc - tic)}")



  0%|          | 0/2529 [00:00<?, ?it/s]


>>>> training starts for epoch <1>
>>>> please wait while the model is training.................................


  0%|          | 1/2529 [00:04<3:20:49,  4.77s/it]


>>> end of epoch <1>: generator loss:=>=>=>0.0070: discriminator loss=>=>=>0.6795


  8%|▊         | 201/2529 [12:41<2:18:57,  3.58s/it]


>>> end of epoch <1>: generator loss:=>=>=>0.0050: discriminator loss=>=>=>0.4164


 12%|█▏        | 316/2529 [18:25<1:50:29,  3.00s/it]