In [6]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Flatten, Dense, LeakyReLU, PReLU, Add, Input, Lambda
from tensorflow.keras.activations import sigmoid
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LambdaCallback
import tensorflow.keras.backend as K

In [21]:
# Training configuration: batch size per update step, and total number of
# images to train on (presumably the full coco_val2017 split — TODO confirm).
BATCHSIZE, NUMIMAGES = 10, 5000

In [8]:
class DataLoader:
    """Sequentially load (high-res, low-res) image pairs from a directory of numbered JPEGs.

    File ``i`` is read from ``<path>/<i zero-padded to 4 digits>.jpg`` (e.g.
    ``0001.jpg``). Each ``get_next`` call reads the current index and then
    advances it by one. There is no wrap-around, so reading past the last file
    makes ``tf.io.read_file`` fail.

    NOTE(review): stock COCO val2017 file names are not 4-digit sequences, so
    this presumably expects a renamed copy of the dataset. Confirm whether the
    numbering starts at 0 or 1 and set ``start_idx`` accordingly.

    Args:
        path: directory containing the numbered ``.jpg`` files.
        batch_size: number of pairs returned by ``get_next_batch``.
        start_idx: index of the first file; ``reset`` rewinds to this index so
            that construction and reset agree.
    """

    def __init__(self, path, batch_size=BATCHSIZE, start_idx=1):
        self.path = path
        self.batch_size = batch_size
        self.start_idx = start_idx
        self.idx = start_idx

    def reset(self):
        """Rewind so the next ``get_next`` call reads file ``start_idx`` again."""
        self.idx = self.start_idx

    def get_next(self, num_blurs=1):
        """Read the next image and return ``(hr_img, lr_img)``.

        ``hr_img`` is the image forced to 3 channels, converted to float32 and
        resized to 416x416. ``lr_img`` is ``hr_img`` resized to 104x104 (4x
        smaller). For 8-bit JPEGs, values lie in [0, 1].

        Args:
            num_blurs: currently ignored. Blur-based degradation via
                ``self.blur`` is disabled, and only downscaling is applied.
        """
        impath = os.path.join(self.path, str(self.idx).zfill(4) + ".jpg")

        raw = tf.io.read_file(impath)
        # channels=3: some JPEGs (COCO includes grayscale ones) would otherwise
        # decode with a single channel. That breaks batching and the 3-channel
        # generator/discriminator inputs.
        # expand_animations=False: always return a rank-3 (H, W, C) tensor.
        hr_img = tf.image.decode_image(raw, channels=3, expand_animations=False)
        hr_img = tf.image.convert_image_dtype(hr_img, tf.float32)
        hr_img = tf.image.resize(hr_img, (416, 416))

        lr_img = self.downscale(hr_img)

        self.idx += 1
        return hr_img, lr_img

    def get_next_batch(self, num_blurs=1):
        """Return the next ``batch_size`` pairs stacked as ``(hrs, lrs)``.

        Shapes are (batch_size, 416, 416, 3) and (batch_size, 104, 104, 3).
        """
        pairs = [self.get_next(num_blurs) for _ in range(self.batch_size)]
        hrs = tf.stack([hr for hr, _ in pairs])
        lrs = tf.stack([lr for _, lr in pairs])
        return hrs, lrs

    def __call__(self):
        """Generator that yields a single ``(hr, lr)`` pair and then stops.

        NOTE(review): if this is meant for ``tf.data.Dataset.from_generator``,
        it probably should loop over the dataset. Confirm the intended length.
        """
        yield self.get_next()

    def blur(self, img, n=1):
        """Blur each channel of an (H, W, C) image ``n`` times with a 3x3 binomial kernel.

        The kernel ``[[1,2,1],[2,4,2],[1,2,1]] / 16`` is applied per channel with
        zero ``SAME`` padding, so the borders darken slightly on every pass.
        Returns a float32 (H, W, C) tensor.
        """
        img = tf.image.convert_image_dtype(img, tf.float32)

        # Treat channels as the batch axis: (H, W, C) -> (C, H, W, 1). A single
        # one-channel conv2d then blurs every channel independently.
        img = tf.transpose(img, perm=[2, 0, 1])
        img = tf.expand_dims(img, -1)

        kernel = tf.constant(np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]), dtype=tf.float32)
        kernel = tf.expand_dims(kernel, -1)  # (3, 3, in_channels=1)
        kernel = tf.expand_dims(kernel, -1)  # (3, 3, 1, out_channels=1)

        for _ in range(n):
            img = tf.nn.conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME') / (16)
        # (C, H, W, 1) -> (1, H, W, C) -> (H, W, C)
        img = tf.transpose(img, perm=[3, 1, 2, 0])
        img = tf.squeeze(img, 0)
        return img

    def downscale(self, img):
        """Resize ``img`` to 104x104 (the generator's input size) with the default bilinear method."""
        return tf.image.resize(img, (104, 104))

def show_hrlr(pair):
    """Plot a (high-res, low-res) image pair side by side.

    Args:
        pair: ``(hr, lr)`` image tensors. These are either (H, W, C), as
            returned by ``DataLoader.get_next``, or carry a leading batch axis
            of size 1 (e.g. ``get_next_batch`` with ``batch_size=1``), which is
            squeezed away.
    """
    hr, lr = pair
    if len(hr.shape) == 4:
        hr = tf.squeeze(hr, 0)
        lr = tf.squeeze(lr, 0)
    # Plot exactly the pair that was passed in; no other data source is consulted.
    _, (ax_hr, ax_lr) = plt.subplots(1, 2)
    ax_hr.imshow(hr)
    ax_hr.set_title("HR")
    ax_lr.imshow(lr)
    ax_lr.set_title("LR")

In [9]:
def ResBlock(x):
    """SRGAN residual block: (conv3x3 -> BN -> PReLU -> conv3x3 -> BN) + identity skip.

    Both convolutions use 64 filters, so ``x`` must itself have 64 channels for
    the final ``Add`` to match shapes.
    """
    shortcut = x
    y = x
    for stage in range(2):
        y = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(y)
        y = BatchNormalization(momentum=0.5)(y)
        if stage == 0:
            # Only the first conv/BN pair is followed by an activation.
            y = PReLU(alpha_initializer="zeros", alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])(y)
    return Add()([shortcut, y])

def PixelShuffle(x):
    """2x spatial upsampling: conv3x3 to 256 channels, depth_to_space(2), then PReLU.

    ``depth_to_space`` with block size 2 turns the 256 channels into 64
    channels at twice the height and width.
    """
    expanded = Conv2D(filters=256, kernel_size=3, strides=1, padding="same")(x)
    shuffled = Lambda(lambda t: tf.nn.depth_to_space(t, 2))(expanded)
    activation = PReLU(alpha_initializer="zeros", alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])
    return activation(shuffled)

def Generator():
    """Build the SRGAN-style generator mapping 104x104x3 inputs to 416x416x3 outputs.

    Layout:
      - conv9x9(64) + PReLU;
      - 16 residual blocks, then conv3x3 + BN with a long skip back to the head;
      - two 2x PixelShuffle stages (4x total);
      - a final conv9x9(3).

    The output layer has no activation, so values are unbounded.
    """
    lr_input = Input((104, 104, 3))

    head = Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(lr_input)
    head = PReLU(alpha_initializer="zeros", alpha_regularizer=None, alpha_constraint=None, shared_axes=[1, 2])(head)

    body = head
    for _ in range(16):
        body = ResBlock(body)
    body = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(body)
    body = BatchNormalization(momentum=0.5)(body)
    # Long skip connection around the whole residual trunk.
    body = Add()([head, body])

    upsampled = body
    for _ in range(2):
        upsampled = PixelShuffle(upsampled)

    sr_output = Conv2D(filters=3, kernel_size=9, strides=1, padding="same")(upsampled)
    return Model(inputs=lr_input, outputs=sr_output)

In [10]:
def ConvBlock(x, filters, strides):
    """Discriminator block: conv3x3 (given filters/strides, same padding) -> BN -> LeakyReLU(0.2)."""
    conv = Conv2D(filters=filters, kernel_size=3, strides=strides, padding="same")
    norm = BatchNormalization(momentum=0.5)
    act = LeakyReLU(alpha=0.2)
    return act(norm(conv(x)))

def Discriminator():
    """Build the SRGAN-style discriminator for 416x416x3 images.

    Returns:
        A Model mapping a batch of images to a (batch, 1) sigmoid probability
        that each image is real. Because the output is already a probability,
        pair it with a ``BinaryCrossentropy`` loss using the default
        ``from_logits=False``.
    """
    x = x_input = Input((416, 416, 3))

    x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(x)
    x = LeakyReLU(alpha=0.2)(x)

    # (filters, strides) for each ConvBlock. The four stride-2 blocks shrink
    # 416 -> 208 -> 104 -> 52 -> 26.
    block_specs = [(64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2)]
    for filters, strides in block_specs:
        x = ConvBlock(x, filters=filters, strides=strides)

    x = Flatten()(x)
    # NOTE(review): at 416x416 input the flattened features are 26*26*512 =
    # 346,112 wide, so this Dense layer alone holds ~354M weights. Consider
    # global pooling or smaller (cropped) inputs if memory is tight.
    x = Dense(units=1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # The sigmoid is the Dense layer's own activation rather than a raw
    # function applied to the symbolic tensor. This keeps it a regular,
    # serializable layer.
    x = Dense(1, activation="sigmoid")(x)

    return Model(inputs=x_input, outputs=x)

In [22]:
class GAN(Model):
    """SRGAN-style trainer holding a generator, a discriminator and a frozen VGG19 feature extractor.

    ``train_step`` performs one simultaneous update of G and D from a single
    forward pass over an ``(hr, lr)`` batch.
    """

    def __init__(self):
        super(GAN, self).__init__()

        self.G = Generator()
        self.D = Discriminator()

        # ImageNet VGG19 truncated at block5_conv4, with every layer frozen.
        # The content loss uses it; only G receives gradients through it.
        self.vgg = VGG19(include_top=False, weights="imagenet", input_shape=(416, 416, 3))
        self.vgg.trainable = False
        for layer in self.vgg.layers:
            layer.trainable = False
        self.vgg = Model(inputs=self.vgg.input, outputs=self.vgg.get_layer("block5_conv4").output)

    def compile(self, content_loss, adversarial_loss, gen_optimizer, disc_optimizer, adversarial_weight=1e-3):
        """Store the losses and optimizers used by ``train_step``.

        Args:
            content_loss: callable ``(vgg, hr, sr) -> scalar`` loss.
            adversarial_loss: Keras-style loss ``(y_true, y_pred) -> scalar``
                applied to discriminator outputs, e.g. ``BinaryCrossentropy()``.
            gen_optimizer: optimizer applied to ``self.G``'s weights.
            disc_optimizer: optimizer applied to ``self.D``'s weights.
            adversarial_weight: weight of the generator's adversarial term in
                its total loss (1e-3 in the SRGAN paper).
        """
        super(GAN, self).compile()
        self.content_loss = content_loss
        self.adversarial_loss = adversarial_loss
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.adversarial_weight = adversarial_weight

    def train_step(self, data):
        """Run one G and one D update on ``data = (hr, lr)``.

        Returns:
            dict with ``content_loss`` (the VGG term of G's loss),
            ``adversarial_loss`` (D's loss), ``gen_adversarial_loss`` (G's
            unweighted adversarial term) and ``gen_loss`` (G's total loss).
        """
        hr, lr = data

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # training=True makes BatchNormalization use batch statistics and
            # update its moving averages. Without it, direct calls run BN in
            # inference mode.
            sr = self.G(lr, training=True)
            discriminator_output_fake = self.D(sr, training=True)
            discriminator_output_real = self.D(hr, training=True)

            real_labels = tf.ones_like(discriminator_output_real)
            fake_labels = tf.zeros_like(discriminator_output_fake)

            content = self.content_loss(self.vgg, hr, sr)
            # The generator is rewarded when D labels its output as real.
            gen_adversarial = self.adversarial_loss(tf.ones_like(discriminator_output_fake), discriminator_output_fake)
            gen_loss = content + self.adversarial_weight * gen_adversarial

            # Keras losses take (y_true, y_pred): real images -> 1, generated -> 0.
            disc_loss = (self.adversarial_loss(real_labels, discriminator_output_real)
                         + self.adversarial_loss(fake_labels, discriminator_output_fake))

        gen_grads = gen_tape.gradient(gen_loss, self.G.trainable_weights)
        disc_grads = disc_tape.gradient(disc_loss, self.D.trainable_weights)

        self.gen_optimizer.apply_gradients(zip(gen_grads, self.G.trainable_weights))
        self.disc_optimizer.apply_gradients(zip(disc_grads, self.D.trainable_weights))

        return {
            "content_loss": content,
            "adversarial_loss": disc_loss,
            "gen_adversarial_loss": gen_adversarial,
            "gen_loss": gen_loss,
        }

In [23]:
def content_loss(vgg, hr, sr):
    """VGG feature-space mean squared error between target and super-resolved images.

    Args:
        vgg: frozen feature extractor built from Keras' ImageNet VGG19 weights.
        hr: high-resolution target batch.
        sr: generator output batch, same shape as ``hr``.
            Both are assumed to be RGB on a [0, 1] scale, as produced by
            ``DataLoader``. TODO confirm for other data sources.

    Returns:
        Scalar mean of the squared difference between ``vgg`` features of
        ``hr`` and ``sr``.
    """
    # Keras' ImageNet VGG19 weights expect vgg19.preprocess_input() applied to
    # 0-255 RGB input (RGB->BGR plus ImageNet mean subtraction). Raw [0, 1]
    # images yield badly mis-scaled features.
    # NOTE(review): SRGAN additionally rescales VGG features by ~1/12.75 to
    # balance this term against the adversarial loss; no rescaling is applied
    # here.
    hr_features = vgg(preprocess_input(hr * 255.0))
    sr_features = vgg(preprocess_input(sr * 255.0))
    return tf.reduce_mean(tf.square(hr_features - sr_features))

In [24]:
def train(data_dir="/Users/anonymous/Documents/warudo/datasets/coco_val2017/",
          num_images=NUMIMAGES, batch_size=BATCHSIZE):
    """Build a fresh GAN, train it for one pass over ``num_images`` images, and return it.

    Args:
        data_dir: directory of numbered ``.jpg`` files read by ``DataLoader``.
            The default is the original machine-specific absolute path; pass
            your own path instead.
        num_images: number of images to consume. Only whole batches are used,
            so any remainder is skipped.
        batch_size: images per training step.

    Returns:
        The trained ``GAN`` instance.
    """
    dl = DataLoader(path=data_dir, batch_size=batch_size)
    gan = GAN()
    gan.compile(
        content_loss=content_loss,
        adversarial_loss=BinaryCrossentropy(),
        # NOTE(review): SRGAN uses a 1e-4 learning rate; 3e-3 is much higher.
        gen_optimizer=Adam(0.003),
        disc_optimizer=Adam(0.003),
    )
    # Integer count of whole batches: range() needs an int, and `/` gives a float.
    num_batches = num_images // batch_size
    for _ in range(num_batches):
        gan.train_step(dl.get_next_batch())
    return gan