# SRResNet-GAN (SRGAN): 4× single-image super-resolution

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, PReLU, LeakyReLU, Layer, Conv2D, BatchNormalization, Flatten
from tensorflow.keras.applications.vgg16 import VGG16

## Residual Block

In [None]:
class ResidualBlock(Layer):
    """SRResNet residual block: Conv-BN-PReLU-Conv-BN plus an identity skip.

    The block returns ``x + F(x)``, so its input must already have ``channel``
    feature maps (the Generator guarantees this by using the same width
    everywhere in the residual trunk).
    """

    def __init__(self, channel=64, kernel_size=(3, 3)):
        """
        Args:
            channel: number of filters in both convolutions.
            kernel_size: spatial kernel size of both convolutions.
        """
        super().__init__()
        self.conv1 = Conv2D(channel, kernel_size, padding='same')
        self.bn1 = BatchNormalization()
        # One learnable slope per channel, shared over height/width, so the
        # block is independent of the spatial input size.
        # Assumes channels-last (NHWC) input, the Keras default.
        self.prelu = PReLU(shared_axes=[1, 2])
        self.conv2 = Conv2D(channel, kernel_size, padding='same')
        self.bn2 = BatchNormalization()

    def call(self, x, training=None, mask=None):
        """Return ``x + F(x)``.

        ``training`` is forwarded to BatchNormalization so it uses batch
        statistics while training and moving statistics at inference.
        """
        y = self.conv1(x)
        y = self.bn1(y, training=training)
        y = self.prelu(y)
        y = self.conv2(y)
        y = self.bn2(y, training=training)
        return x + y

## Conv-BN-LeakyReLU Block

In [None]:
class ConvBnLReluBlock(Layer):
    """Conv2D (stride 1, 'same' padding) -> BatchNormalization -> LeakyReLU(0.2).

    Used as the basic building block of the Discriminator. Spatial size is
    preserved; only the number of channels changes to ``channel``.
    """

    def __init__(self, kernel_size=(3, 3), channel=64):
        """
        Args:
            kernel_size: spatial kernel size of the convolution.
            channel: number of output filters.
        """
        super().__init__()
        self.conv = Conv2D(channel, kernel_size, padding='same')
        self.bn = BatchNormalization()
        # Slope 0.2 as in the SRGAN discriminator; passed positionally so it
        # works with both the `alpha` and `negative_slope` argument names.
        self.lrelu = LeakyReLU(0.2)

    def call(self, x, training=None, mask=None):
        """Apply conv, batch norm (mode controlled by ``training``) and LeakyReLU."""
        x = self.conv(x)
        x = self.bn(x, training=training)
        return self.lrelu(x)

## Generator (SRResNet)

In [None]:
class Generator(Model):
    """SRResNet generator that upscales its input by a factor of 4.

    Layout: 9x9 conv + PReLU -> ``num_resblock`` ResidualBlocks -> conv + BN
    with a long skip back to the head -> two (conv, pixel-shuffle x2, PReLU)
    upsampling stages -> 9x9 conv to 3 (RGB) channels.

    The output is linear (no final activation), so values are not
    guaranteed to lie in [0, 1].
    """

    def __init__(self, channel=64, num_resblock=5):
        """
        Args:
            channel: feature width used throughout the network.
            num_resblock: number of ResidualBlocks in the trunk.
        """
        super().__init__()
        self.conv_head = Conv2D(channel, (9, 9), padding='same')
        self.prelu_head = PReLU(shared_axes=[1, 2])

        self.res_blocks = [ResidualBlock(channel=channel) for _ in range(num_resblock)]
        self.conv_trunk = Conv2D(channel, (3, 3), padding='same')
        self.bn_trunk = BatchNormalization()

        # Each stage produces 4*channel maps that depth_to_space(2) rearranges
        # into a 2x larger map with `channel` maps (sub-pixel convolution).
        self.conv_up1 = Conv2D(channel * 4, (3, 3), padding='same')
        self.prelu_up1 = PReLU(shared_axes=[1, 2])
        self.conv_up2 = Conv2D(channel * 4, (3, 3), padding='same')
        self.prelu_up2 = PReLU(shared_axes=[1, 2])

        self.conv_out = Conv2D(3, (9, 9), padding='same')

    def call(self, x, training=None, mask=None):
        """Map an (N, H, W, 3) batch to an (N, 4H, 4W, 3) batch.

        Assumes channels-last input. ``training`` is forwarded to every
        BatchNormalization layer.
        """
        x = self.prelu_head(self.conv_head(x))
        skip = x

        for block in self.res_blocks:
            x = block(x, training=training)
        x = self.bn_trunk(self.conv_trunk(x), training=training)
        x = x + skip

        x = self.prelu_up1(tf.nn.depth_to_space(self.conv_up1(x), 2))
        x = self.prelu_up2(tf.nn.depth_to_space(self.conv_up2(x), 2))
        return self.conv_out(x)

## Discriminator

In [None]:
class Discriminator(Model):
    """SRGAN-style discriminator returning one raw logit per image.

    Layout: 3x3 conv + LeakyReLU -> four ConvBnLReluBlocks with widths
    channel, 2*channel, 4*channel, 8*channel, each followed by 2x2 average
    pooling -> Dense(1024) + LeakyReLU -> Dense(1).

    Downsampling is done with average pooling because ConvBnLReluBlock is a
    stride-1 block. No sigmoid is applied: the output is a logit, so losses
    should use a from-logits cross entropy.
    """

    def __init__(self, channel=64):
        """
        Args:
            channel: width of the first convolution; doubled at every stage.
        """
        super().__init__()
        self.conv_in = Conv2D(channel, (3, 3), padding='same')
        self.lrelu_in = LeakyReLU(0.2)
        self.blocks = [
            ConvBnLReluBlock(kernel_size=(3, 3), channel=channel * 2 ** stage)
            for stage in range(4)
        ]
        self.flatten = Flatten()
        self.fc = Dense(1024)
        self.lrelu_fc = LeakyReLU(0.2)
        self.logit = Dense(1)

    def call(self, x, training=None, mask=None):
        """Return an (N, 1) tensor of logits for an (N, H, W, C) batch.

        The Dense(1024) input size is fixed by the spatial size seen on the
        first call. ``training`` is forwarded to the BatchNorm blocks.
        """
        x = self.lrelu_in(self.conv_in(x))
        for block in self.blocks:
            x = block(x, training=training)
            x = tf.nn.avg_pool2d(x, ksize=2, strides=2, padding='SAME')
        x = self.flatten(x)
        x = self.lrelu_fc(self.fc(x))
        return self.logit(x)

## Dataset (Caltech101: 8×8 LR / 32×32 HR image pairs)

In [None]:
def _to_lr_hr_pair(example, lr_size=(8, 8), hr_size=(32, 32)):
    """Turn one Caltech101 example into an (LR, HR) pair of float images in [0, 1]."""
    image = tf.cast(example['image'], tf.float32) / 255.0
    # antialias=True low-pass filters before shrinking; without it, bicubic
    # downscaling by large factors aliases heavily.
    image_lr = tf.image.resize(image, lr_size, method=tf.image.ResizeMethod.BICUBIC,
                               antialias=True)
    image_hr = tf.image.resize(image, hr_size, method=tf.image.ResizeMethod.BICUBIC,
                               antialias=True)
    # Bicubic kernels overshoot, so clamp both images back into [0, 1].
    return (tf.clip_by_value(image_lr, 0.0, 1.0),
            tf.clip_by_value(image_hr, 0.0, 1.0))


dataset = tfds.load(name='caltech101', split='train')
dataset = (dataset
           # Parallel map still preserves element order by default.
           .map(_to_lr_hr_pair, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           # The resized pairs are small; cache them so later epochs skip
           # decoding and resizing.
           .cache()
           .batch(1)
           .prefetch(tf.data.experimental.AUTOTUNE))

## Models, VGG Feature Extractor, and Hyperparameters

In [None]:
# Both networks are built with 16 base feature channels.
generator = Generator(channel=16)
discriminator = Discriminator(channel=16)

# ImageNet-pretrained VGG16 cut at block3_conv3 and frozen; it serves as a
# fixed feature extractor for 32x32 RGB inputs.
vgg = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
vgg = Model(vgg.input, vgg.get_layer('block3_conv3').output)
vgg.trainable = False

# Loss-term weights (presumably for the adversarial and VGG terms of the
# generator objective in train_step — confirm against the training step).
w_gan, w_vgg = 1e-2, 1e-5

# Separate Adam optimizers for discriminator and generator, same learning rate.
optim_d = tf.optimizers.Adam(learning_rate=1e-4)
optim_g = tf.optimizers.Adam(learning_rate=1e-4)

# Running means of each reported loss, created in the order d, g, vgg, l1.
d_mean, g_mean, vgg_mean, l1_mean = (tf.metrics.Mean() for _ in range(4))

## Losses

In [None]:
@tf.function
def l1_loss_func(y, y_):
    """Pixel-wise L1 loss: mean absolute difference between target ``y`` and prediction ``y_``.

    Returns a scalar tensor averaged over all elements.
    """
    return tf.reduce_mean(tf.abs(y - y_))

@tf.function
def vgg_loss_func(y, y_):
    """Perceptual loss: MSE between frozen-VGG feature maps of ``y`` and ``y_``.

    Uses the module-level ``vgg`` feature extractor. Returns a scalar tensor.
    """
    # Assumes y / y_ are RGB images scaled to [0, 1] (as produced by the dataset
    # pipeline). VGG16's ImageNet weights expect caffe-style input (BGR,
    # mean-subtracted, [0, 255] scale), hence the rescale + preprocess_input.
    preprocess = tf.keras.applications.vgg16.preprocess_input
    features = vgg(preprocess(y * 255.0))
    features_ = vgg(preprocess(y_ * 255.0))
    return tf.reduce_mean(tf.square(features - features_))

@tf.function
def discriminator_loss(real, fake):
    """Standard GAN discriminator loss on raw logits.

    Args:
        real: discriminator logits for real (HR) images; target label 1.
        fake: discriminator logits for generated (SR) images; target label 0.

    Returns:
        Scalar sum of the mean real and mean fake binary cross-entropies.
    """
    # Expects logits (no sigmoid applied); the from-logits form is numerically
    # stable for large-magnitude outputs.
    real_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real)
    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake)
    return tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)

@tf.function
def generator_loss(fake):
    """Non-saturating adversarial loss for the generator.

    Args:
        fake: discriminator logits for generated images (no sigmoid applied).

    Returns:
        Scalar mean binary cross-entropy against the "real" label 1.
    """
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))


## Training Step

In [None]:
@tf.function
def train_step(image_lr, image_hr, optim_d, optim_g):
    """Run one adversarial update of the discriminator and the generator.

    Uses the module-level ``generator``, ``discriminator``, loss functions and
    the weights ``w_vgg`` / ``w_gan``.

    Args:
        image_lr: batch of low-resolution input images.
        image_hr: matching batch of high-resolution target images.
        optim_d: optimizer applied to the discriminator's trainable variables.
        optim_g: optimizer applied to the generator's trainable variables.

    Returns:
        Tuple ``(d_loss, g_loss, vgg_loss, l1_loss)`` of scalar tensors, where
        ``g_loss`` is the unweighted adversarial term.
    """
    with tf.GradientTape() as tape_d, tf.GradientTape() as tape_g:
        image_sr = generator(image_lr, training=True)

        logits_real = discriminator(image_hr, training=True)
        logits_fake = discriminator(image_sr, training=True)

        d_loss = discriminator_loss(logits_real, logits_fake)
        g_loss = generator_loss(logits_fake)
        vgg_loss = vgg_loss_func(image_hr, image_sr)
        l1_loss = l1_loss_func(image_hr, image_sr)
        # Generator objective: pixel L1 content loss plus weighted perceptual
        # and adversarial terms.
        total_g_loss = l1_loss + w_vgg * vgg_loss + w_gan * g_loss

    d_vars = discriminator.trainable_variables
    g_vars = generator.trainable_variables
    optim_d.apply_gradients(zip(tape_d.gradient(d_loss, d_vars), d_vars))
    optim_g.apply_gradients(zip(tape_g.gradient(total_g_loss, g_vars), g_vars))

    return d_loss, g_loss, vgg_loss, l1_loss

## Training Loop

In [None]:
def _show_samples(num_samples=10):
    """Plot LR (nearest-neighbour upscaled), SR and HR rows for the first dataset samples."""
    lr_row, sr_row, hr_row = [], [], []
    for sample_lr, sample_hr in dataset.take(num_samples):
        # training=False: BatchNormalization uses its moving statistics.
        sample_sr = generator(sample_lr, training=False)
        # Upscale the LR input to the HR size taken from the data (not a
        # hard-coded 32) so the three rows line up.
        hr_size = tf.shape(sample_hr)[1:3]
        lr_row.append(tf.image.resize(sample_lr[0], hr_size,
                                      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR))
        sr_row.append(sample_sr[0])
        hr_row.append(sample_hr[0])

    if not lr_row:
        # Empty dataset: nothing to show (np.concatenate would raise).
        return

    grid = np.concatenate([np.concatenate(lr_row, axis=1),
                           np.concatenate(sr_row, axis=1),
                           np.concatenate(hr_row, axis=1)], axis=0)
    # Generator output is unbounded; clip explicitly instead of relying on
    # imshow's implicit (warning-emitting) clipping of float images.
    plt.imshow(np.clip(grid, 0.0, 1.0))
    plt.show()


for epoch in range(100):
    for img_lr, img_hr in dataset:
        d_loss, g_loss, vgg_loss, l1_loss = train_step(img_lr, img_hr, optim_d, optim_g)

        d_mean.update_state(d_loss)
        g_mean.update_state(g_loss)
        vgg_mean.update_state(vgg_loss)
        l1_mean.update_state(l1_loss)

    # float(...) so the log shows plain numbers rather than tf.Tensor reprs.
    print('epoch: {}, d_loss: {:.4f}, g_loss: {:.4f}, vgg_loss: {:.4f}, l1_loss: {:.4f}'.format(
        epoch + 1,
        float(d_mean.result()),
        float(g_mean.result()),
        float(vgg_mean.result()),
        float(l1_mean.result())))

    _show_samples()

    # Per-epoch averages: clear the running means before the next epoch.
    d_mean.reset_states()
    g_mean.reset_states()
    vgg_mean.reset_states()
    l1_mean.reset_states()