# SRGAN

Original paper is here:
[Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)

Architecture
![SRGAN Architecture](srgan-model.jpeg)


A keras implementation of SRGAN.

In [1]:
import os, time, glob, gc

import tensorflow.keras as keras
import tensorflow as tf
import tensorlayer as tl
import numpy as np
import matplotlib.pyplot as plt
import h5py

from tensorflow.python.keras.utils.generic_utils import Progbar

from models.srgan import SRGAN

%matplotlib inline

# Autoreload modules
%load_ext autoreload
%autoreload 2

### Hyperparameters

In [2]:
# Height and width (in pixels) of the low-resolution patches that
# prepare_data() cuts from each LR image and that predict() tiles its input into.
lr_height = 24
lr_width = 24
# Factor used to map LR patch coordinates onto the HR image; also passed to
# SRGAN as `upscaling_factor`.
# NOTE(review): prepare_data() and the sample cell load "...x4" LR files while
# this value is 8 -- confirm the LR images really are 8x smaller than the HR ones.
upscaling_rate = 8

# Number of patch pairs written to each HDF5 pack by prepare_data().
pack_size = 2000
# Number of patch pairs per batch yielded by data_generator().
batch_size = 4
# Presumably the initial learning rate; not referenced by the visible code.
lr_init = 1e-4

# initialize G: number of epochs for the generator-only pre-training run
num_epoch_init_g = 100

# adversarial learning
# steps_per_epoch is also used by the generator pre-training run;
# num_epoch_gan and checkpoint_step are only used by the disabled GAN loop.
steps_per_epoch = 5000
num_epoch_gan = 20
checkpoint_step = 10
#n_epoch = 2000
#lr_decay = 0.1
#decay_step = int(n_epoch/2)

# Not referenced by the visible code.
verbose_interval = 1

# image paths (DIV2K high-resolution images and their low-resolution versions)
train_hr_img_path = 'datasets/DIV2K/DIV2K_train_HR'
train_lr_img_path = 'datasets/DIV2K/DIV2K_train_LR'

valid_hr_img_path = 'datasets/DIV2K/DIV2K_valid_HR'
valid_lr_img_path = 'datasets/DIV2K/DIV2K_valid_LR'

# checkpoints and saved model
# checkpoint_dir also receives the pre-processed HDF5 packs;
# sample_dir receives images produced by predict().
checkpoint_dir = 'checkpoints/srgan/'
sample_dir = 'samples/srgan/'

In [3]:
def prepare_data(hr_img_path, lr_img_path, output_file):
    """Cut aligned LR/HR patch pairs from an image folder and save them as HDF5 packs.

    Every PNG found in ``hr_img_path`` is paired with the file
    ``<name>x4<ext>`` in ``lr_img_path``. The LR image is tiled into
    non-overlapping ``lr_height`` x ``lr_width`` patches; the HR patch for each
    tile is taken at the same coordinates multiplied by ``upscaling_rate``.
    Patches are written in packs of ``pack_size`` pairs to
    ``<output_file>_<n>.hdf5`` (datasets ``'hr'`` and ``'lr'``); the last pack
    holds whatever is left over and may be smaller.

    NOTE(review): the LR file suffix is hard-coded to "x4" while the HR crop is
    scaled by ``upscaling_rate``; if the two disagree the pairs are misaligned.
    Confirm which scale the dataset actually uses.

    Args:
        hr_img_path: directory containing the high-resolution PNG images.
        lr_img_path: directory containing the matching low-resolution images.
        output_file: path prefix of the HDF5 packs to write.
    """
    hr_file_list = np.array(tl.files.load_file_list(path=hr_img_path, regx='.*.png', printable=False))
    np.random.shuffle(hr_file_list)

    r = upscaling_rate
    hr_patch_shape = (lr_height * r, lr_width * r)
    HR = []
    LR = []
    n = 0

    load_prog = Progbar(hr_file_list.shape[0])

    for prog, hr_img_file in enumerate(hr_file_list, start=1):

        # splitext keeps working for file names that contain extra dots.
        file_name, file_ext = os.path.splitext(hr_img_file)
        hr_file = os.path.join(hr_img_path, hr_img_file)
        lr_file = os.path.join(lr_img_path, file_name + "x4" + file_ext)

        hr_img = plt.imread(hr_file)
        lr_img = plt.imread(lr_file)

        lr_shape = lr_img.shape

        for y in range(0, lr_shape[0] - lr_height + 1, lr_height):
            y_end = y + lr_height
            for x in range(0, lr_shape[1] - lr_width + 1, lr_width):
                x_end = x + lr_width

                hr_data = hr_img[y * r:y_end * r, x * r:x_end * r, :]
                # A crop that runs past the HR image border is truncated by
                # slicing; skip it so every stored patch has the same shape.
                if hr_data.shape[:2] != hr_patch_shape:
                    continue

                HR.append(hr_data)
                LR.append(lr_img[y:y_end, x:x_end, :])

        # One image can contribute more than one pack worth of patches.
        while len(HR) >= pack_size:
            _save_pack(HR[:pack_size], LR[:pack_size], "%s_%d.hdf5" % (output_file, n))
            del HR[:pack_size]
            del LR[:pack_size]
            n += 1

        load_prog.update(prog)

    # Write the remaining patches so they are not lost (and so that a small
    # dataset still produces at least one pack).
    if HR:
        _save_pack(HR, LR, "%s_%d.hdf5" % (output_file, n))


def _save_pack(hr_patches, lr_patches, filename):
    """Shuffle one pack of patch pairs in unison and write it to ``filename``."""
    indices = np.arange(len(hr_patches))
    np.random.shuffle(indices)
    hr_pack = np.array(hr_patches)[indices]
    lr_pack = np.array(lr_patches)[indices]

    # The context manager closes the file even if writing a dataset fails.
    with h5py.File(filename, 'w') as hf:
        hf.create_dataset('hr', data=hr_pack)
        hf.create_dataset('lr', data=lr_pack)


def data_generator(file_list, train=True):
    """Yield ``(batch_lr, batch_hr)`` tuples forever from HDF5 patch packs.

    Each file is expected to hold the datasets ``'hr'`` and ``'lr'``. A pack is
    read fully into memory, then sliced into consecutive batches of
    ``batch_size`` pairs; a trailing partial batch is dropped. After the last
    file the generator starts over with the first one.

    Args:
        file_list: paths of the HDF5 pack files to cycle through.
        train: currently unused; kept so existing callers keep working.

    Yields:
        Tuple ``(batch_lr, batch_hr)`` of arrays with ``batch_size`` items each.

    Raises:
        ValueError: if ``file_list`` is empty, or if a full pass over all packs
            produced no batch (every pack smaller than ``batch_size``). Without
            these checks the generator would loop forever without yielding.
    """
    files = list(file_list)
    if not files:
        raise ValueError("data_generator() needs at least one HDF5 pack file")

    while True:
        yielded = False
        for file in files:
            # The arrays are copied into memory, so the file can be closed
            # before the first batch is handed out.
            with h5py.File(file, 'r') as hf:
                pack_hr = np.array(hf.get('hr'))
                pack_lr = np.array(hf.get('lr'))

            num_in_pack = pack_hr.shape[0]

            for i in range(num_in_pack // batch_size):
                batch_hr = pack_hr[i * batch_size:(i + 1) * batch_size]
                batch_lr = pack_lr[i * batch_size:(i + 1) * batch_size]
                yielded = True
                yield (batch_lr, batch_hr)

        if not yielded:
            raise ValueError("no pack holds a full batch of %d items" % batch_size)

# The HDF5 packs are stored in checkpoint_dir; make sure it exists before
# anything is written there.
os.makedirs(checkpoint_dir, exist_ok=True)

# Pre-process each dataset only once: the first pack acts as the marker.
if not os.path.isfile(os.path.join(checkpoint_dir, "train_0.hdf5")):
    print("Pre-processing training data...")
    prepare_data(train_hr_img_path, train_lr_img_path, os.path.join(checkpoint_dir, "train"))

if not os.path.isfile(os.path.join(checkpoint_dir, "valid_0.hdf5")):
    print("Pre-processing validation data...")
    prepare_data(valid_hr_img_path, valid_lr_img_path, os.path.join(checkpoint_dir, "valid"))

# Sorted so the pack order is the same on every run.
train_files = sorted(glob.glob(os.path.join(checkpoint_dir, "train_*.hdf5")))
valid_files = sorted(glob.glob(os.path.join(checkpoint_dir, "valid_*.hdf5")))

# Fail early with a clear message instead of handing an empty list to the
# endless batch generators.
if not train_files or not valid_files:
    raise FileNotFoundError("no pre-processed HDF5 packs found in %s" % checkpoint_dir)

tdg = data_generator(train_files)
vdg = data_generator(valid_files, train=False)

### Problems

#### OOM
At the beginning, the model was too large to train. Every time, I got an OOM error, even during the G initialization training. So I did some searching and made some changes. Now the model is running.
1. I added another GPU (MSI GTX 1070 8GB), which brings in more memory. This could be the most important change. Actually, not just the GPU: I changed the PSU (850W) as well, to supply enough power to run two GPUs.
1. With an extra GPU, I moved the VGG network, which is not trained but consumes lots of memory, to /device:gpu:1. Because of some weird DINT loading-to-GPU problem in TensorFlow (or Keras), G and D have to sit on the same GPU to avoid loading/unloading data to and from the GPU. I could not fix that, so I could only move VGG.
1. Use `allow_growth`. According to my understanding, TensorFlow always tries to allocate the maximum memory it needs before actually running the forward or backward pass. This can be changed by setting this parameter to `True`, which, in my case, saved some memory. https://github.com/keras-team/keras/issues/4161
1. Reduce input size from 96x96 to 48x48. This helps a lot.

#### Trainable warning

Another problem is this annoying warning:

`Warning: Discrepancy between trainable weights and collected trainable weights, did you set 'model.trainable' without calling 'model.compile' after ?`

It looks like this is common in implementations of GAN networks, because D needs to be frozen when training the GAN, then unfrozen for its own training, toggling between the two states in each batch.

The solution is to use tensorflow.python.keras.engine.network.Network to wrap the D network as a frozen model, which is not updated during GAN training but whose weights still get updated when the original network is trained. Details are here: https://github.com/keras-team/keras/issues/8585

#### Negative loss

The loss functions used by G and D were wrong: G was using binary cross-entropy and D was using MSE. Swapping them fixed the problem.

In [4]:
# Set allow_growth so TensorFlow grows its GPU memory allocation as needed
# instead of reserving it up front (see the OOM notes above).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
keras.backend.set_session(sess)

# initialize tensorboard
#tensorboard = keras.callbacks.TensorBoard(log_dir="logs/{}".format(time.time()))

# DEFINE MODEL
srgan = SRGAN(lr_height, lr_width, upscaling_factor=upscaling_rate)

#print(srgan.generator.summary())
#print(srgan.discriminator.summary())

# Load previous training result, but only when both weight files are present
saved_model_path = os.path.join(checkpoint_dir, "trained")
generator_weights = saved_model_path + "_0_generator.h5"
discriminator_weights = saved_model_path + "_0_discriminator.h5"
if os.path.isfile(generator_weights) and os.path.isfile(discriminator_weights):
    srgan.load_weights(generator_weights, discriminator_weights)


# Shape of output from discriminator
#d_output_shape = list(srgan.discriminator.output_shape)
#d_output_shape[0] = batch_size

# VALID / FAKE targets for discriminator
#real = tf.ones(d_output_shape)
#fake = tf.zeros(d_output_shape)

# initialize generator: train G on its own, saving weights every 20 epochs
checkpoint = keras.callbacks.ModelCheckpoint(os.path.join(checkpoint_dir, "weights.{epoch:03d}.hdf5"),
                                             verbose=1, period=20)
history = srgan.generator.fit_generator(tdg, steps_per_epoch=steps_per_epoch, validation_data=vdg,
                                        validation_steps=5, epochs=num_epoch_init_g, callbacks=[checkpoint])

training_loss = history.history['loss']
test_loss = history.history['val_loss']

# Visualize loss history
# The x axis must have one entry per recorded epoch; the epoch count itself is
# a single number and cannot be plotted against the loss lists.
epochs_run = range(1, len(training_loss) + 1)
plt.plot(epochs_run, training_loss, 'r--')
plt.plot(epochs_run, test_loss, 'b-')
plt.legend(['Training Loss', 'Test Loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()


# save model
#srgan.save_weights(os.path.join(checkpoint_dir, "init"))

'''

# train gan
g_losses = []
d_losses = []
bar = Progbar(steps_per_epoch)

#print(srgan.discriminator.summary())

for epoch in range(num_epoch_gan):
    
    print("Epoch %d/%d" % (epoch+1, num_epoch_gan))

    # batches
    for i in range(steps_per_epoch):

        (batch_lr, batch_hr) = next(tdg)

        # train discriminator
        generated_hr = srgan.generator.predict(batch_lr)

        real_loss = srgan.discriminator.train_on_batch(batch_hr, real)
        fake_loss = srgan.discriminator.train_on_batch(generated_hr, fake)
        d_loss = np.add(real_loss, fake_loss)

        # train generator
        features_hr = srgan.vgg.predict(batch_hr)
        g_loss = srgan.srgan.train_on_batch(batch_lr, [real, features_hr, batch_hr])
        
        # Save losses
        g_losses.append(g_loss[0])
        d_losses.append(d_loss[0])
        
        bar.update(i+1,
                   [('d_loss', d_loss[0]),
                    ('gan_loss', g_loss[0]),
                    ('gan_d_loss', g_loss[1]),
                    ('vgg_loss', g_loss[2]),
                    ('g_loss', g_loss[3])])
        
    if epoch % checkpoint_step == 0:
        srgan.save_weights(os.path.join(checkpoint_dir, "trained_%d" % epoch))


# Save weights
srgan.save_weights(os.path.join(checkpoint_dir, "trained"))

# Plot g_lost
plt.figure()
plt.subplot(211)
plt.plot(g_losses)

plt.subplot(212)
plt.plot(d_losses)

plt.show()
'''

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 00020: saving model to checkpoints/srgan/weights.020.hdf5
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
 569/5000 [==>...........................] - ETA: 10:47 - loss: 0.0185

KeyboardInterrupt: 

In [4]:

# Load valid dataset: reuse the cached HDF5 file when it exists, otherwise
# build the arrays from the image folders and cache them.
data_filename = os.path.join(checkpoint_dir, "valid.hdf5")
if os.path.isfile(data_filename):
    # Context managers close the file even if reading or writing fails.
    with h5py.File(data_filename, 'r') as hf:
        valid_hr = np.array(hf.get('valid_hr'))
        valid_lr = np.array(hf.get('valid_lr'))
else:
    # NOTE(review): load_data is not defined in the visible part of this
    # notebook -- confirm it still exists before relying on this branch.
    valid_hr, valid_lr = load_data(valid_hr_img_path, valid_lr_img_path)

    # Save to disk
    with h5py.File(data_filename, 'w') as hf:
        hf.create_dataset('valid_hr', data=valid_hr)
        hf.create_dataset('valid_lr', data=valid_lr)

print(valid_hr.shape)
print(valid_lr.shape)

(160, 384, 384, 3)
(160, 96, 96, 3)


In [3]:
# Build a fresh SRGAN and restore only the trained generator weights
# (no discriminator weight file is passed).
srgan = SRGAN()
saved_model_path = os.path.join(checkpoint_dir, "trained")
generator_weights = saved_model_path + "_generator.h5"
srgan.load_weights(generator_weights, None)

# Validation
#valid_loss = srgan.generator.evaluate(valid_lr, valid_hr)

#print(valid_loss)

In [6]:

def predict(input_image, *, model=None, tile_height=None, tile_width=None, scale=None):
    """Super-resolve an image tile by tile and stitch the tiles back together.

    The image is cut into non-overlapping tiles, the whole batch of tiles is
    passed through ``model.generator.predict`` and every output tile is written
    to the position of its input tile multiplied by ``scale``. Rows and columns
    that do not fill a complete tile are not processed and stay zero in the
    result.

    Args:
        input_image: array indexed as (height, width, channels).
        model: object exposing ``generator.predict``; defaults to the
            module-level ``srgan``.
        tile_height: tile height in pixels; defaults to ``lr_height``.
        tile_width: tile width in pixels; defaults to ``lr_width``.
        scale: upscaling factor; defaults to ``upscaling_rate``.

    Returns:
        Float array of shape (height * scale, width * scale, output channels).

    Raises:
        ValueError: if the image is smaller than a single tile.
    """
    model = srgan if model is None else model
    tile_height = lr_height if tile_height is None else tile_height
    tile_width = lr_width if tile_width is None else tile_width
    r = upscaling_rate if scale is None else scale

    input_shape = input_image.shape
    row_starts = range(0, input_shape[0] - tile_height + 1, tile_height)
    col_starts = range(0, input_shape[1] - tile_width + 1, tile_width)

    # Collect the tiles row by row and stack them once, instead of growing
    # an array with a copy per tile.
    tiles = [input_image[y:y + tile_height, x:x + tile_width, :]
             for y in row_starts for x in col_starts]
    if not tiles:
        raise ValueError("input image %s is smaller than one %dx%d tile"
                         % (input_shape[:2], tile_height, tile_width))

    # The batch is float64 regardless of the image dtype.
    LR = np.stack(tiles).astype(np.float64)
    print(LR.shape)

    output = model.generator.predict(LR)

    output_image = np.zeros((input_shape[0] * r, input_shape[1] * r, output.shape[3]))

    # Walk the tile origins in the same order the tiles were collected, so each
    # output tile lands exactly where its input tile came from, including when
    # the width is an exact multiple of the tile width.
    origins = ((y, x) for y in row_starts for x in col_starts)
    for tile, (y, x) in zip(output, origins):
        output_image[y * r:(y + tile_height) * r, x * r:(x + tile_width) * r, :] = tile

    return output_image


# Super-resolve one validation image and store the result as a sample.
result = predict(plt.imread(os.path.join(valid_lr_img_path, "0801x4.png")))

# imsave does not create the target directory.
os.makedirs(sample_dir, exist_ok=True)

# Float RGB data is only valid in [0, 1]; clip explicitly instead of relying
# on matplotlib's clipping (which emits a warning).
plt.imsave(os.path.join(sample_dir, "0801.png"), np.clip(result, 0.0, 1.0))

(70, 48, 48, 3)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
