In [368]:
import tensorflow as tf
from collections import OrderedDict
from functools import reduce
import numpy as np


shape_lr = (36,36,1)
shape_hr = (72,72,1)

class VGGLossNoActivation(object):
    """By ESRGAN a more effective perceptual loss constraining on features before activation rather than 
    after activation as practiced in SRGAN. 
    Reference: https://arxiv.org/abs/1809.00219"""

    def __init__(self, image_shape):
        self.model = self.create_model(image_shape)
        
    def create_model(self,image_shape):
        
        
        vgg19 = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet', input_shape=image_shape)

        x = tf.keras.layers.Conv2D(512, (3, 3),padding='same',
                                   name='block5_conv4')(vgg19.get_layer('block5_conv3').output)
        
        model = tf.keras.Model(inputs=vgg19.input, outputs=x)
        model.trainable = False
        return model
    
    def preprocess_vgg(self, x):
        if isinstance(x, np.ndarray):
            return tf.keras.applications.vgg19.preprocess_input((x))
        else:            
            return tf.keras.layers.Lambda(lambda x: tf.keras.applications.vgg19.preprocess_input((x)))(x)
        
    # computes VGG loss 
    def perceptual_loss(self, y_true, y_pred):
        return tf.math.reduce_mean(tf.math.square(self.model(self.preprocess_vgg(y_true)) - self.model(self.preprocess_vgg(y_pred))),None)
    
    def euclidean_content_loss(self, y_true, y_pred):
        return tf.math.sqrt(tf.math.reduce_sum(tf.math.square(self.model(self.preprocess_vgg(y_true)) - self.model(self.preprocess_vgg(y_pred))), axis=None))
    
    def compoundLoss(self, y_true, y_pred,alfa=10e-2,beta=10e0):
        return (alfa * tf.math.reduce_mean(tf.math.square(self.model(self.preprocess_vgg(y_true)) - self.model(self.preprocess_vgg(y_pred))),None) + beta * tf.math.sqrt(tf.math.reduce_sum(tf.math.square(y_pred - y_true), axis=None)))


def psnr(y, y_pred,max_val=1.0):
    y = tf.image.convert_image_dtype(y, tf.float32)
    y_pred = tf.image.convert_image_dtype(y_pred, tf.float32)
    values = tf.image.psnr(y, y_pred, max_val=max_val)
    return tf.reduce_mean(values)

def ssim(y, y_pred,max_val=1.0):
    y = tf.image.convert_image_dtype(y, tf.float32)
    y_pred = tf.image.convert_image_dtype(y_pred, tf.float32)
    values = tf.image.ssim(y, y_pred, max_val=max_val, filter_size=11,
                          filter_sigma=1.5, k1=0.01, k2=0.03)
    return tf.reduce_mean(values)


class Dataset:
    def __init__(self, batch_size, dataset_path, dataset_info_path, shuffle_buffer_size=0):
        self._batch_size = batch_size
        self._shuffle_buffer_size = shuffle_buffer_size
        self._dataset_path = dataset_path
        with open(dataset_info_path, 'r') as dataset_info:
            self.examples_num = int(dataset_info.readline())
            self.scale_factor = int(dataset_info.readline())
            self.input_info = OrderedDict()
            for line in dataset_info.readlines():
                items = line.split(',')
                self.input_info[items[0]] = [int(dim) for dim in items[1:]]

    def _parse_tf_example(self, example_proto):
        features = dict([(key, tf.io.FixedLenFeature([], tf.string)) for key, _ in self.input_info.items()])
        parsed_features = tf.io.parse_single_example(example_proto, features=features)

        return [tf.reshape(tf.cast(tf.io.decode_raw(parsed_features[key], tf.uint8), tf.float32), value)
                for key, value in self.input_info.items()]
    
    def get_data(self,epoch=None):
        dataset = tf.data.TFRecordDataset(self._dataset_path)
        dataset = dataset.map(self._parse_tf_example,num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self._batch_size,drop_remainder=True).repeat(epoch)
        return dataset


def discriminator(filters=64,input_shape=shape_hr):

    def conv2d_block(input, filters, strides=1, bn=True):
        d = tf.keras.layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input)
        if bn:
            d = tf.keras.layers.BatchNormalization(momentum=0.8)(d)
        d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
        return d
    
    input = tf.keras.layers.Input(shape=input_shape)
    x = conv2d_block(input, filters, bn=False)
    x = conv2d_block(x, filters, strides=2)
    x = conv2d_block(x, filters*2)
    x = conv2d_block(x, filters*2, strides=2)
    x = conv2d_block(x, filters*4)
    x = conv2d_block(x, filters*4, strides=2)
    x = conv2d_block(x, filters*8)
    x = conv2d_block(x, filters*8, strides=2)
    x = tf.keras.layers.Dense(filters*16)(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
    x = tf.keras.layers.Dense(0.4)(x)
    x = tf.keras.layers.Dense(1,activation='sigmoid')(x)
    model = tf.keras.Model(inputs=input, outputs=x,name='Discriminator')
    return model


def generator(scale_factor=2):   
    inputs = tf.keras.layers.Input(shape=(None,None,1),name='input')
    
    net = tf.pad(inputs, [[0, 0], [0, 0], [0, 0], [0, 0]], 'SYMMETRIC')
    net = tf.keras.layers.Conv2D(32, 3,padding='same',strides=(1, 1), name='conv1',
                                kernel_initializer=tf.keras.initializers.HeNormal())(net)
    net = tf.keras.layers.LeakyReLU(alpha=0.2)(net)
    net1 = net
    net = tf.keras.layers.Conv2D(32, 3,padding='same',strides=(1, 1), name='conv2',
                                kernel_initializer=tf.keras.initializers.HeNormal())(net)
    net = tf.keras.layers.LeakyReLU(alpha=0.2)(net)
    net2 = net 
    net = tf.keras.layers.add([net1, net2])
    
    net = tf.keras.layers.Conv2D(32, 3,padding='same',strides=(1, 1), name='conv3',
                                kernel_initializer=tf.keras.initializers.HeNormal())(net)
    net = tf.keras.layers.LeakyReLU(alpha=0.2)(net)
    net3 = net
    net = tf.keras.layers.concatenate([net1, net2, net3],axis=3)
    
    net = tf.keras.layers.Conv2D(scale_factor ** 2, 3,activation='tanh', 
                            padding='same',strides=(1, 1), name='conv4',
                            kernel_initializer=tf.keras.initializers.HeNormal())(net)
    outputs = tf.keras.layers.Lambda(lambda x:tf.nn.depth_to_space(x,scale_factor),
                                        name = 'prediction')(net)
    model = tf.keras.Model(inputs=inputs, outputs=outputs,name='rtvsrgan')
    return model

import tensorflow as tf
from functools import reduce

class GAN(tf.keras.Model):
    def __init__(self, discriminator, generator):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        
    def compile(self, d_optimizer, g_optimizer, d_loss, g_loss, metrics):
        super(GAN, self).compile(metrics = metrics)
        self.d_optimizer = d_optimizer
        self.d_loss = d_loss 
        self.g_optimizer = g_optimizer
        self.g_loss = g_loss   
        
    def load_weights_gen(self,checkpoint_filepath):
        self.generator.load_weights(checkpoint_filepath)
    
    def load_weights_dis(self,checkpoint_filepath):
        self.discriminator.load_weights(checkpoint_filepath)
    
    def save_weights_gen(self,checkpoint_filepath):
        # Save the weights
        self.generator.save_weights(checkpoint_filepath)
             
        
    def train_step(self, data):
        if isinstance(data, tuple):
            img_lr, img_hr = data
        
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            img_sr = self.generator(img_lr, training=True)

            real_output = self.discriminator(img_hr, training=True)
            fake_output = self.discriminator(img_sr, training=True)

            g_loss,c_loss, a_loss, p_loss = self.g_loss(fake_output,img_hr,img_sr)
            d_loss = self.d_loss(real_output, fake_output)
            
        gradients_of_generator = gen_tape.gradient(g_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)

        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
        
        self.compiled_metrics.update_state(img_hr, img_sr) 
        
        return reduce(lambda x,y: dict(x, **y), 
                      ({"d_loss": d_loss, "g_loss": g_loss,"a_loss": a_loss, "c_loss": c_loss, "p_loss": p_loss },
                       {m.name: m.result() for m in self.metrics})) 
     

In [376]:
from save_img_callback import SaveImageCallback
import argparse

MODEL='rtvsrgan'
BATCH_SIZE = 32
TEST_BATCH_SIZE = 4
SHUFFLE_BUFFER_SIZE = 45150
OPTIMIZER='adam'
LEARNING_RATE = 1e-4
LEARNING_DECAY_RATE = 1e-1
LEARNING_DECAY_EPOCHS = 40
MOMENTUM = 0.9
NUM_EPOCHS = 100
STEPS_PER_EPOCH = 500
SAVE_NUM = 2
STEPS_PER_LOG = 1000
EPOCHS_PER_SAVE = 1
LOGDIR = 'logdir'
CHECKPOINT = 'checkpoint/'

TEST_LOGDIR='test_logdir/'


TRAINING_DATASET_PATH='/home/joao/Documentos/projetos/ssd/dataset/train_football-qp17/dataset.tfrecords'
TRAINING_DATASET_INFO_PATH='/home/joao/Documentos/projetos/ssd/dataset/train_football-qp17/dataset_info.txt'

TESTING_DATASET_PATH='/home/joao/Documentos/projetos/ssd/dataset/test_football-qp17/dataset.tfrecords'
TESTING_DATASET_INFO_PATH='/home/joao/Documentos/projetos/ssd/dataset/test_football-qp17/dataset_info.txt'

def get_arguments(string=""):
    parser = argparse.ArgumentParser(description='train one of the models for image and video super-resolution')
    parser.add_argument('--model', type=str, default=MODEL, choices=['espcn','rtvsresnt','rtvsrgan'],
                        help='What model to train')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='Number of images in batch')
    parser.add_argument('--valid_batch_size', type=int, default=TEST_BATCH_SIZE,
                        help='Number of images in test batch')
    parser.add_argument('--train_dataset_path', type=str, default=TRAINING_DATASET_PATH,
                        help='Path to the train dataset')
    parser.add_argument('--train_dataset_info_path', type=str, default=TRAINING_DATASET_INFO_PATH,
                        help='Path to the train dataset info')
    parser.add_argument('--valid_dataset_path', type=str, default=TESTING_DATASET_PATH,
                        help='Path to the test dataset')
    parser.add_argument('--valid_dataset_info_path', type=str, default=TESTING_DATASET_INFO_PATH,
                        help='Path to the train dataset info')
    parser.add_argument('--ckpt_path', default=CHECKPOINT,
                        help='Path to the model checkpoint to evaluate')
    parser.add_argument('--load_weights', action='store_true',
                        help='Path to the model checkpoint to evaluate')
    parser.add_argument('--shuffle_buffer_size', type=int, default=SHUFFLE_BUFFER_SIZE,
                        help='Buffer size used for shuffling examples in dataset')
    parser.add_argument('--optimizer', type=str, default=OPTIMIZER, choices=['adam', 'momentum', 'sgd'],
                        help='What optimizer to use for training')
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate used for training')
    parser.add_argument('--use_lr_decay', action='store_true',
                        help='Whether to apply exponential decay to the learning rate')
    parser.add_argument('--lr_decay_rate', type=float, default=LEARNING_DECAY_RATE,
                        help='Learning rate decay rate used in exponential decay')
    parser.add_argument('--lr_decay_epochs', type=int, default=LEARNING_DECAY_EPOCHS,
                        help='Number of epochs before full decay rate tick used in exponential decay')
    parser.add_argument('--staircase_lr_decay', action='store_true',
                        help='Whether to decay the learning rate at discrete intervals')
    parser.add_argument('--num_epochs', type=int, default=NUM_EPOCHS,
                        help='Number of training epochs')
    parser.add_argument('--steps_per_epochs', type=int, default=STEPS_PER_EPOCH,
                        help='How many steps per epochs')
    parser.add_argument('--save_num', type=int, default=SAVE_NUM,
                        help='How many images to write to summary')
    parser.add_argument('--steps_per_log', type=int, default=STEPS_PER_LOG,
                        help='How often to save summaries')
    parser.add_argument('--epochs_per_save', type=int, default=EPOCHS_PER_SAVE,
                        help='How often to save checkpoints')
    parser.add_argument('--use_mc', action='store_true',
                        help='Whether to use motion compensation in video super resolution model')
    parser.add_argument('--mc_independent', action='store_true',
                        help='Whether to train motion compensation network independent from super resolution network')
    parser.add_argument('--logdir', type=str, default=LOGDIR,
                        help='Where to save checkpoints and summaries')
    parser.add_argument('--test_logdir', type=str, default=TEST_LOGDIR,
                        help='Where to save tests images')

    return parser.parse_args(string)


args = get_arguments()

train_dataset = Dataset(args.batch_size,
        args.train_dataset_path,
        args.train_dataset_info_path,
        args.shuffle_buffer_size)

steps_per_epoch = train_dataset.examples_num // args.batch_size \
        if train_dataset.examples_num % args.batch_size != 0 else 0

train_dataset = train_dataset.get_data(args.num_epochs)
train_batch = train_dataset.map(lambda x0,x1,x2,y: (x1/255.0,y/255.0))


valid_dataset = Dataset(args.valid_batch_size,
        args.valid_dataset_path,
        args.valid_dataset_info_path)

valid_dataset = valid_dataset.get_data()
    
test_batch = valid_dataset.map(lambda x0,x1,x2,y: (x1,y))
test_batch = iter(test_batch).get_next() 


g=generator()
d=discriminator()

save_img_callback = SaveImageCallback(
            dataset=test_batch,
            model=g,
            model_name=args.model,
            epochs_per_save=args.epochs_per_save,
            log_dir=args.test_logdir)

callbacks=[save_img_callback]

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

adv_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
#cont_loss = tf.keras.losses.MeanAbsoluteError()
cont_loss = tf.keras.losses.Huber()
#cont_loss = tf.keras.losses.MeanSquaredError()


shape_hr = (72,72,3)    
vgg_loss = VGGLossNoActivation(shape_hr)
perc_loss = vgg_loss.content_loss
 

lbd = 1 * 1e0
eta = 1 * 1e0
mu = 1 * 1e-2

def discriminator_loss(real_output, fake_output):
    noise = 0.05 * tf.random.uniform(tf.shape(real_output))
    real_loss = cross_entropy(tf.ones_like(real_output)-noise, real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output)+noise, fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output,img_hr,img_sr):
    noise = 0.05 * tf.random.uniform(tf.shape(fake_output))
    a_loss = adv_loss(tf.ones_like(fake_output)-noise, fake_output) * lbd
    c_loss = cont_loss(img_hr,img_sr) * eta
    img_hr = tf.keras.layers.Concatenate()([img_hr, img_hr, img_hr])
    img_sr = tf.keras.layers.Concatenate()([img_sr, img_sr, img_sr])
    p_loss = perc_loss(img_hr,img_sr) * mu
    total_loss = c_loss + a_loss + p_loss
    return total_loss, c_loss , a_loss , p_loss


gan = GAN(discriminator = d, generator = g)

gan.compile(d_optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
            g_optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
            d_loss = discriminator_loss,
            g_loss = generator_loss,
            metrics=[psnr,ssim])

In [377]:
gan.fit(train_batch, epochs=args.num_epochs,verbose=1,steps_per_epoch=steps_per_epoch,callbacks=callbacks)

Epoch 1/100
 146/1410 [==>...........................] - ETA: 42:00 - d_loss: 1.3863 - g_loss: 0.7057 - a_loss: 0.6931 - c_loss: 0.0124 - p_loss: 2.1462e-04 - psnr: 19.0049 - ssim: 0.3825

KeyboardInterrupt: 