# Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
## Instrustion on running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to 'datasets/'
4. Run the sript using command 'python srgan.py'

In [0]:
! mkdir datasets
! unzip /content/drive/'My Drive'/img_align_celeba.zip -q -d datasets

mkdir: cannot create directory ‘datasets’: File exists
Archive:  /content/drive/My Drive/img_align_celeba.zip
caution: filename not matched:  -q


In [0]:
! pip install git+https://www.github.com/keras-team/keras-contrib.git

Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-m2w2lqh_
  Running command git clone -q https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-m2w2lqh_
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-cp36-none-any.whl size=101064 sha256=218fab403e6b5c9b9440ac3e251fd392dcb6e1c0542103ac9c1f7970fa067daf
  Stored in directory: /tmp/pip-ephem-wheel-cache-bigv3lu5/wheels/11/27/c8/4ed56de7b55f4f61244e2dc6ef3cdbaff2692527a2ce6502ba
Successfully built keras-contrib


In [0]:
# ! pip install scipy==1.1.0

Collecting scipy==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/a8/0b/f163da98d3a01b3e0ef1cab8dd2123c34aee2bafbb1c5bffa354cc8a1730/scipy-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (31.2MB)
[K     |████████████████████████████████| 31.2MB 109kB/s 
[31mERROR: tensorflow 2.2.0rc3 has requirement scipy==1.4.1; python_version >= "3", but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: plotnine 0.6.0 has requirement scipy>=1.2.0, but you'll have scipy 1.1.0 which is incompatible.[0m
[31mERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.[0m
Installing collected packages: scipy
  Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
Successfully installed scipy-1.1.0


In [0]:
import scipy

from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os
from glob import glob
import keras.backend as K
import scipy.misc

In [0]:
class DataLoader():
  def __init__(self, dataset_name, img_res=(128, 128)):
    self.dataset_name = dataset_name
    self.img_res = img_res

  def load_data(self, batch_size=1, is_testing=False):
    data_type = "train" if not is_testing else "test"
    
    path = glob('./datasets/%s/*' % (self.dataset_name))

    batch_images = np.random.choice(path, size=batch_size)

    imgs_hr = []
    imgs_lr = []
    for img_path in batch_images:
      img = self.imread(img_path)

      h, w = self.img_res
      low_h, low_w = int(h / 4), int(w / 4)

      img_hr = scipy.misc.imresize(img, self.img_res)
      img_lr = scipy.misc.imresize(img, (low_h, low_w))

      # If training => do random flip
      if not is_testing and np.random.random() < 0.5:
        img_hr = np.fliplr(img_hr)
        img_lr = np.fliplr(img_lr)

      imgs_hr.append(img_hr)
      imgs_lr.append(img_lr)

    imgs_hr = np.array(imgs_hr) / 127.5 - 1.
    imgs_lr = np.array(imgs_lr) / 127.5 - 1.

    return imgs_hr, imgs_lr


  def imread(self, path):
    return scipy.misc.imread(path, mode='RGB').astype(np.float)

In [0]:
channels = 3
lr_height = 64
lr_width = 64
lr_shape = (lr_height, lr_width, channels)
hr_height = lr_height * 4
hr_width = lr_width * 4
hr_shape = (hr_height, hr_width, channels)

n_residual_blocks = 16
optimizer = Adam(0.0002, 0.5)

We use a pre-trained VGG19 model to extract image features from the high resolution and the generated high resolution images and minimize the mse between them

In [0]:
def build_vgg():
  vgg = VGG19(weights='imagenet')
  vgg.outputs = [vgg.layers[9].output]
  img = Input(shape=hr_shape)
  img_features = vgg(img)
  return Model(img, img_features)

In [0]:
vgg = build_vgg()
vgg.trainable = False
vgg.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

In [0]:
# configure data loader
dataset_name = 'img_align_celeba'
data_loader = DataLoader(dataset_name=dataset_name, img_res=(hr_height, hr_width))

In [0]:
# calculate output shape of D (PatchGAN
patch = int(hr_height / 2**4)
disc_patch = (patch, patch, 1)

In [0]:
# Number of filters in the first layer of G and D
gf = 64
df = 64

build and compile the discriminator

In [0]:
def build_discriminator():
  def d_block(layer_input, filters, strides=1, bn=True):
    """Discriminator layer"""
    d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
    d = LeakyReLU(alpha=0.2)(d)
    if bn:
        d = BatchNormalization(momentum=0.8)(d)
    return d

  # Input img
  d0 = Input(shape=hr_shape)

  d1 = d_block(d0, df, bn=False)
  d2 = d_block(d1, df, strides=2)
  d3 = d_block(d2, df*2)
  d4 = d_block(d3, df*2, strides=2)
  d5 = d_block(d4, df*4)
  d6 = d_block(d5, df*4, strides=2)
  d7 = d_block(d6, df*8)
  d8 = d_block(d7, df*8, strides=2)

  d9 = Dense(df*16)(d8)
  d10 = LeakyReLU(alpha=0.2)(d9)
  validity = Dense(1, activation='sigmoid')(d10)

  return Model(d0, validity)

In [0]:
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

Build the generator

In [0]:
def build_generator():

  def residual_block(layer_input, filters):
    """Residual block described in paper"""
    d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
    d = Activation('relu')(d)
    d = BatchNormalization(momentum=0.8)(d)
    d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
    d = BatchNormalization(momentum=0.8)(d)
    d = Add()([d, layer_input])
    return d

  def deconv2d(layer_input):
    """Layers used during upsampling"""
    u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
    u = Activation('relu')(u)
    return u

  # Low resolution image input
  img_lr = Input(shape=lr_shape)

  # Pre-residual block
  c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
  c1 = Activation('relu')(c1)

  # Propogate through residual blocks
  r = residual_block(c1, gf)
  for _ in range(n_residual_blocks - 1):
      r = residual_block(r, gf)

  # Post-residual block
  c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
  c2 = BatchNormalization(momentum=0.8)(c2)
  c2 = Add()([c2, c1])

  # Upsampling
  u1 = deconv2d(c2)
  u2 = deconv2d(u1)

  # Generate high resolution output
  gen_hr = Conv2D(channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

  return Model(img_lr, gen_hr)

In [0]:
generator = build_generator()

In [0]:
# High res. and low res. images
img_hr = Input(shape=hr_shape)
img_lr = Input(shape=lr_shape)

In [0]:
# generate high res. version from low res.
fake_hr = generator(img_lr)

In [0]:
# extract image features of the generated img
fake_features = vgg(fake_hr)

In [0]:
# for the combined model we will only train the generator
discriminator.trainable = False

In [0]:
# Discriminator determines validity of generated high res. images
validity = discriminator(fake_hr)
combined = Model([img_lr, img_hr], [validity, fake_features])
combined.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=optimizer)

In [0]:
def sample_images(epoch):
  os.makedirs('images/%s' % dataset_name, exist_ok=True)
  r,c =2,2
  imgs_hr, imgs_lr = data_loader.load_data(batch_size=2, is_testing=True)
  fake_hr = generator.predict(imgs_lr)

  # rescale imgaes 0 - 1
  imgs_lr = 0.5 * imgs_lr + 0.5
  fake_hr = 0.5 * fake_hr + 0.5
  imgs_hr = 0.5 * imgs_hr + 0.5

  # save generated imgages and the high resolution originals
  titles = ["Generated", "Original"]
  fig, axs = plt.subplots(r, c)
  cnt = 0
  for row in range(r):
    for col, image in enumerate([fake_hr, imgs_hr]):
      axs[row, col].imshow(image[row])
      axs[row, col].set_title(titles[col])
      axs[row, col].axis('off')
    cnt += 1
  fig.savefig('images/%s/%d.png' % (dataset_name, epoch))
  plt.close()

  # Save low resolution images for comparison

  for i in range(r):
    fig = plt.figure()
    plt.imshow(imgs_lr[i])
    fig.savefig('images/%s/%d_lowres%d.png'%(dataset_name, epoch, i))
    plt.close()

## Train

In [0]:
epochs=30000
batch_size=1
sample_interval=50

In [0]:
start_time = datetime.datetime.now()
for epoch in range(epochs):
  # ----------------------
  #  Train Discriminator
  # ----------------------

  # Sample images and their conditioning counterparts
  imgs_hr, imgs_lr = data_loader.load_data(batch_size)

  # From low res. image generate high res. version
  fake_hr = generator.predict(imgs_lr)

  valid = np.ones((batch_size,) + disc_patch)
  fake = np.zeros((batch_size,) + disc_patch)

  # Train the discriminators (original images = real / generated = Fake)
  d_loss_real = discriminator.train_on_batch(imgs_hr, valid)
  d_loss_fake = discriminator.train_on_batch(fake_hr, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

  # ------------------
  #  Train Generator
  # ------------------

  # Sample images and their conditioning counterparts
  imgs_hr, imgs_lr = data_loader.load_data(batch_size)

  # The generators want the discriminators to label the generated images as real
  valid = np.ones((batch_size,) + disc_patch)

  # Extract ground truth image features using pre-trained VGG19 model
  image_features = vgg.predict(imgs_hr)

  # Train the generators
  g_loss = combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

  elapsed_time = datetime.datetime.now() - start_time
  # Plot the progress
  print ("%d time: %s" % (epoch, elapsed_time))

  # If at save interval => save generated image samples
  if epoch % sample_interval == 0:
      sample_images(epoch)

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  'Discrepancy between trainable weights and collected trainable'


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
995 time: 0:18:44.888837
996 time: 0:18:45.965280
997 time: 0:18:47.040771
998 time: 0:18:48.120014
999 time: 0:18:49.197854
1000 time: 0:18:50.274799
1001 time: 0:18:52.085480
1002 time: 0:18:53.172072
1003 time: 0:18:54.264865
1004 time: 0:18:55.339617
1005 time: 0:18:56.417147
1006 time: 0:18:57.506680
1007 time: 0:18:58.599876
1008 time: 0:18:59.677276
1009 time: 0:19:00.758482
1010 time: 0:19:01.839145
1011 time: 0:19:02.914705
1012 time: 0:19:04.004658
1013 time: 0:19:05.088366
1014 time: 0:19:06.172966
1015 time: 0:19:07.267206
1016 time: 0:19:08.359933
1017 time: 0:19:09.471681
1018 time: 0:19:10.577645
1019 time: 0:19:11.675177
1020 time: 0:19:12.780263
1021 time: 0:19:13.871003
1022 time: 0:19:14.968034
1023 time: 0:19:16.052691
1024 time: 0:19:17.140028
1025 time: 0:19:18.220037
1026 time: 0:19:19.310812
1027 time: 0:19:20.395252
1028 time: 0:19:21.479478
1029 time: 0:19:22.565049
1030 time: 0:19:23.665164
1031

KeyboardInterrupt: ignored