# SRGAN

Original paper is here:
[Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)

Architecture
![SRGAN Architecture](srgan-model.jpeg)


A Keras implementation of SRGAN, trained on the DIV2K dataset.

In [1]:
import os, time

import tensorflow.keras as keras
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import h5py

from tensorflow.python.keras.utils.generic_utils import Progbar

from models.srgan import SRGAN

%matplotlib inline

# Autoreload modules
%load_ext autoreload
%autoreload 2

### Hyperparameters

In [2]:
# Patch geometry: LR patches are lr_height x lr_width pixels, and the matching
# HR patches are upscaling_rate times larger in each dimension.
lr_height = 96
lr_width = 96
upscaling_rate = 4

batch_size = 2
lr_init = 1e-4  # NOTE(review): not used by the cells shown here; presumably consumed by models.srgan -- confirm

# Seed NumPy's global RNG so the np.random shuffles in this notebook are
# reproducible. TensorFlow weight initialisation is NOT seeded here.
SEED = 42
np.random.seed(SEED)

# Epochs spent pre-training the generator on its own before adversarial training
num_epoch_init_g = 3

# Epochs of adversarial (GAN) training
num_epoch_gan = 10
# TODO: learning-rate decay is not implemented. It was previously sketched as
# n_epoch = 2000, lr_decay = 0.1, decay_step = n_epoch // 2.

verbose_interval = 1  # NOTE(review): not used by the cells shown here

# image paths (DIV2K layout: LR files are named "<name>x<upscaling_rate>.png")
train_hr_img_path = 'datasets/DIV2K/DIV2K_train_HR'
train_lr_img_path = 'datasets/DIV2K/DIV2K_train_LR'

valid_hr_img_path = 'datasets/DIV2K/DIV2K_valid_HR'
valid_lr_img_path = 'datasets/DIV2K/DIV2K_valid_LR'

# checkpoints and saved model
checkpoint_dir = 'checkpoints/srgan/'


In [3]:
def _read_rgb(path):
    """Read an image with ``plt.imread`` and return it with exactly 3 channels.

    A 2-D (greyscale) image is replicated into three channels, and any channel
    beyond the third (e.g. alpha) is dropped, so that all patches can be stacked.
    """
    img = plt.imread(path)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    return img[..., :3]


def load_data(hr_img_path, lr_img_path):
    """Cut aligned low-/high-resolution training patches from paired image folders.

    Every ``*.png`` file in ``hr_img_path`` is paired with the file
    ``<name>x<upscaling_rate><ext>`` in ``lr_img_path``. The LR image is tiled
    into non-overlapping ``lr_height x lr_width`` patches, and tiles that would
    run past the right/bottom edge are dropped. The HR region, ``upscaling_rate``
    times larger, is cut from the same position in the HR image.

    Reads the module-level ``lr_height``, ``lr_width`` and ``upscaling_rate``.

    Args:
        hr_img_path: directory containing the high-resolution PNG images.
        lr_img_path: directory containing the matching low-resolution images.

    Returns:
        ``(HR, LR, VGG)`` float32 arrays, all shuffled with the same permutation:
        HR is ``(N, lr_height*r, lr_width*r, 3)``, LR is ``(N, lr_height, lr_width, 3)``
        and VGG is ``(N, 224, 224, 3)``, i.e. each HR patch bicubically resized to
        224x224 for the VGG feature extractor. Pixel values are scaled like
        ``plt.imread`` output (floats in [0, 1] for PNG).
    """
    r = upscaling_rate
    hr_patches, lr_patches, vgg_patches = [], [], []

    # Anchored ".png" suffix match; the old regex '.*.png' matched any name containing "png".
    hr_file_list = sorted(
        f for f in os.listdir(hr_img_path)
        if f.endswith('.png') and os.path.isfile(os.path.join(hr_img_path, f))
    )

    for hr_img_file in hr_file_list:
        # splitext keeps extra dots in the stem ("a.b.png" -> "a.b", ".png")
        file_name, file_ext = os.path.splitext(hr_img_file)
        lr_img_file = "%sx%d%s" % (file_name, r, file_ext)

        hr_img = _read_rgb(os.path.join(hr_img_path, hr_img_file))
        lr_img = _read_rgb(os.path.join(lr_img_path, lr_img_file))

        for y in range(0, lr_img.shape[0] - lr_height + 1, lr_height):
            y_end = y + lr_height
            for x in range(0, lr_img.shape[1] - lr_width + 1, lr_width):
                x_end = x + lr_width

                # np.array(..., dtype=...) copies the patch, so the full-size
                # images are not kept alive by slice views after this file.
                lr_data = np.array(lr_img[y:y_end, x:x_end], dtype=np.float32)
                hr_data = np.array(hr_img[y*r:y_end*r, x*r:x_end*r], dtype=np.float32)

                # Create a 224x224 image for the VGG features by resizing the HR patch
                vgg_img = Image.fromarray(np.uint8(hr_data * 255))
                vgg_data = np.asarray(vgg_img.resize((224, 224), Image.BICUBIC), dtype=np.float32) / 255

                hr_patches.append(hr_data)
                lr_patches.append(lr_data)
                vgg_patches.append(vgg_data)

    if not hr_patches:
        # Keep the documented shapes even when no patch could be cut
        return (np.zeros((0, lr_height * r, lr_width * r, 3), dtype=np.float32),
                np.zeros((0, lr_height, lr_width, 3), dtype=np.float32),
                np.zeros((0, 224, 224, 3), dtype=np.float32))

    # One stack at the end instead of np.append per patch (which re-copied the
    # whole array every time, i.e. O(n^2) copying).
    HR = np.stack(hr_patches)
    LR = np.stack(lr_patches)
    VGG = np.stack(vgg_patches)

    indices = np.random.permutation(HR.shape[0])
    return HR[indices], LR[indices], VGG[indices]


# Cache the extracted training patches in an HDF5 file, because load_data is slow.
data_filename = os.path.join(checkpoint_dir, "train.hdf5")
expected_hr_shape = (lr_height * upscaling_rate, lr_width * upscaling_rate, 3)
expected_lr_shape = (lr_height, lr_width, 3)

train_hr = train_lr = train_vgg = None
if os.path.isfile(data_filename):
    # `with` closes the file even if a read fails; hf[key][()] raises KeyError
    # for a missing dataset instead of silently producing np.array(None).
    with h5py.File(data_filename, 'r') as hf:
        train_hr = hf['train_hr'][()]
        train_lr = hf['train_lr'][()]
        train_vgg = hf['train_vgg'][()]
    # A cache built with different patch settings is stale: rebuild it
    if train_hr.shape[1:] != expected_hr_shape or train_lr.shape[1:] != expected_lr_shape:
        print("Cached patches do not match the current patch settings; rebuilding.")
        train_hr = train_lr = train_vgg = None

if train_hr is None:
    train_hr, train_lr, train_vgg = load_data(train_hr_img_path, train_lr_img_path)

    # Save to disk (creating the checkpoint directory on first run)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    with h5py.File(data_filename, 'w') as hf:
        hf.create_dataset('train_hr', data=train_hr)
        hf.create_dataset('train_lr', data=train_lr)
        hf.create_dataset('train_vgg', data=train_vgg)

print(train_hr.shape)
print(train_lr.shape)
print(train_vgg.shape)

(380, 384, 384, 3)
(380, 96, 96, 3)
(380, 224, 224, 3)


In [4]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
keras.backend.set_session(sess)

# DEFINE MODEL
srgan = SRGAN()

# Discriminator targets as NumPy arrays, built once for a full batch and sliced
# for a shorter final batch. The previous tf.ones/tf.zeros tensors had a fixed
# batch dimension. NOTE(review): symbolic targets fed to train_on_batch may also
# add graph ops on every call, a plausible cause of the MemoryError seen at epoch 6.
d_target_shape = (batch_size,) + tuple(srgan.discriminator.output_shape[1:])
real = np.ones(d_target_shape, dtype=np.float32)
fake = np.zeros(d_target_shape, dtype=np.float32)

# train set size
size = train_hr.shape[0]

# Pre-train the generator alone (its loss is whatever models.srgan compiled it with)
srgan.generator.fit(train_lr, train_hr, batch_size=batch_size, epochs=num_epoch_init_g)

# save model
srgan.save_weights(os.path.join(checkpoint_dir, "init"))

# train gan
g_losses = []
d_losses = []

for epoch in range(num_epoch_gan):

    print()
    print("Epoch %d/%d" % (epoch+1, num_epoch_gan))

    # A fresh progress bar per epoch; a single shared Progbar does not restart at 0
    bar = Progbar(size)
    # Reshuffle every epoch so batch composition varies between epochs
    order = np.random.permutation(size)

    for i in range(0, size, batch_size):
        idx = order[i:i+batch_size]
        batch_hr = train_hr[idx]
        batch_lr = train_lr[idx]
        batch_vgg = train_vgg[idx]

        # The final batch can be smaller than batch_size
        n = len(idx)
        real_batch, fake_batch = real[:n], fake[:n]

        # --- train discriminator on real HR patches and on generated ones
        # NOTE(review): whether the discriminator is frozen inside the combined
        # model (trainable flag) depends on how models.srgan compiles it.
        generated_hr = srgan.generator.predict(batch_lr)
        real_loss = srgan.discriminator.train_on_batch(batch_hr, real_batch)
        fake_loss = srgan.discriminator.train_on_batch(generated_hr, fake_batch)
        # Average the two results so any extra metric (e.g. accuracy) is not
        # double-counted; atleast_1d also copes with a scalar-only loss.
        d_loss = np.atleast_1d(0.5 * np.add(real_loss, fake_loss))

        # --- train generator through the combined model
        # Targets follow the combined model's outputs: [discriminator, VGG features, HR image]
        features_hr = srgan.vgg.predict(batch_vgg)
        g_loss = srgan.srgan.train_on_batch(batch_lr, [real_batch, features_hr, batch_hr])

        # Save losses
        g_losses.append(g_loss[0])
        d_losses.append(d_loss[0])

        # g_loss[1:] are the per-output losses in the target order above
        bar.update(i + n, [('g_loss', g_loss[0]), ('adv_loss', g_loss[1]),
                           ('vgg_loss', g_loss[2]), ('pixel_loss', g_loss[3]),
                           ('d_loss', d_loss[0])])


# Save weights
srgan.save_weights(os.path.join(checkpoint_dir, "trained"))

# Plot the per-batch training losses
fig, (ax_g, ax_d) = plt.subplots(2, 1, sharex=True)
ax_g.plot(g_losses)
ax_g.set(title="Generator (combined model) loss", ylabel="loss")
ax_d.plot(d_losses)
ax_d.set(title="Discriminator loss", xlabel="batch", ylabel="loss")
fig.tight_layout()
plt.show()

Epoch 1/3
Epoch 2/3
Epoch 3/3

Epoch 1/10

Epoch 2/10

Epoch 3/10

Epoch 4/10

Epoch 5/10

Epoch 6/10

MemoryError: 

In [None]:
# Overlay the loss curves recorded by the training cell. The old cell
# referenced an undefined `losses` and raised NameError on a fresh kernel.
fig, ax = plt.subplots()
ax.plot(g_losses, label="generator")
ax.plot(d_losses, label="discriminator")
ax.set(title="SRGAN training losses", xlabel="batch", ylabel="loss")
ax.legend()
plt.show()