# SRResnet-GAN

In [24]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, PReLU, LeakyReLU, Layer, Conv2D, BatchNormalization, Flatten
from tensorflow.keras.applications.vgg16 import VGG16

## Residual Block

In [25]:
class ResidualBlock(Layer):
    def __init__(self, channel=64, kernel_size=(3, 3)):
        super().__init__()
        self.conv1 = Conv2D(channel,kernel_size,padding='same',use_bias=False)
        self.bn1 = BatchNormalization()
        self.prelu = PReLU()
        self.conv2 = Conv2D(channel,kernel_size,padding='same',use_bias=False)
        self.bn2 = BatchNormalization()

    def call(self, x, training=None, mask=None):
        h = self.prelu(self.bn1(self.conv1(x),training))
        return x+ self.bn2(self.conv2(h),training)

## Conv-Bn-Relu Block

In [26]:
class ConvBnLReluBlock(Layer):
    def __init__(self, kernel_size=(3, 3), channel=64):
        super().__init__()
        self.conv1 = Conv2D(channel,kernel_size,padding='same',use_bias=False)
        self.bn1 = BatchNormalization()
        self.lrelu = LeakyReLU()

    def call(self, x, training=None, mask=None):
        return self.lrelu(self.bn1(self.conv1(x),training))

## Generator (SRResnet)

In [27]:
class Generator(Model):
    def __init__(self, channel=64, num_resblock=5):
        super().__init__()
        self.conv1 = Conv2D(channel,(9,9),padding='same')
        self.prelu1 = PReLU()
        
        self.resblock_list = [ResidualBlock(channel,(3,3)) for _ in range(num_resblock)]
        
        self.conv2 = Conv2D(channel,(3,3),padding='same',use_bias=False)
        self.bn = BatchNormalization()
        self.prelu2 =PReLU()
        
        self.conv3 = Conv2D(channel*4,(3,3),padding='same')
        self.prelu3= PReLU()
        
        self.conv4 = Conv2D(channel*4,(3,3),padding='same')
        self.prelu4= PReLU()
        self.conv5 = Conv2D(3,(9,9),padding='same')
        

    def call(self, x, training=None, mask=None):
        h = self.prelu1(self.conv1(x))
        h_skip = h
        for resblock in self.resblock_list:
            h = resblock(h,training)
        
        h = self.prelu2(self.bn(self.conv2(h),training))
        h = h+h_skip
        
        h = self.prelu3(tf.nn.depth_to_space(self.conv3(h),2))
        h = self.prelu4(tf.nn.depth_to_space(self.conv4(h),2))
        
        return self.conv5(h)
        

## Discriminator

In [28]:
class Discriminator(Model):
    def __init__(self, channel=64):
        super().__init__()
        self.conv1 = Conv2D(channel,(3,3),padding='same')
        self.lrelu1 = LeakyReLU()
        
        self.block_list = list()
        self.block_list.append(ConvBnLReluBlock(channel))
        self.block_list.append(ConvBnLReluBlock(channel*2))
        self.block_list.append(ConvBnLReluBlock(channel*2))
        self.block_list.append(ConvBnLReluBlock(channel*4))
        self.block_list.append(ConvBnLReluBlock(channel*4))
        self.block_list.append(ConvBnLReluBlock(channel*8))
        self.block_list.append(ConvBnLReluBlock(channel*8))
        
        self.flatten = Flatten()
        self.dense1 = Dense(1024)
        self.lrelu2 = LeakyReLU()
        self.dense2 = Dense(1,activation="sigmoid")
        

    def call(self, x, training=None, mask=None):
        h = self.lrelu1(self.conv1(x))
        
        for blocklist in self.block_list:
            h = blocklist(h,training)
        
        h = self.flatten(h)
        h = self.lrelu2(self.dense1(h))
        return self.dense2(h)

## Dataset (Caltech101)

In [29]:
dataset = tfds.load(name='downloads', split='train')
dataset = dataset.map(lambda x: (tf.image.resize(tf.cast(x['image'], tf.float32), (8, 8), tf.image.ResizeMethod.BICUBIC) / 255.0,
                                 tf.image.resize(tf.cast(x['image'], tf.float32), (32, 32), tf.image.ResizeMethod.BICUBIC) / 255.0)).batch(1)

DatasetNotFoundError: Dataset downloads not found.
Available datasets:
	- abstract_reasoning
	- accentdb
	- aeslc
	- aflw2k3d
	- ag_news_subset
	- ai2_arc
	- ai2_arc_with_ir
	- amazon_us_reviews
	- anli
	- arc
	- bair_robot_pushing_small
	- bccd
	- beans
	- big_patent
	- bigearthnet
	- billsum
	- binarized_mnist
	- binary_alpha_digits
	- blimp
	- bool_q
	- c4
	- caltech101
	- caltech_birds2010
	- caltech_birds2011
	- cars196
	- cassava
	- cats_vs_dogs
	- celeb_a
	- celeb_a_hq
	- cfq
	- cherry_blossoms
	- chexpert
	- cifar10
	- cifar100
	- cifar10_1
	- cifar10_corrupted
	- citrus_leaves
	- cityscapes
	- civil_comments
	- clevr
	- clic
	- clinc_oos
	- cmaterdb
	- cnn_dailymail
	- coco
	- coco_captions
	- coil100
	- colorectal_histology
	- colorectal_histology_large
	- common_voice
	- coqa
	- cos_e
	- cosmos_qa
	- covid19sum
	- crema_d
	- curated_breast_imaging_ddsm
	- cycle_gan
	- dart
	- davis
	- deep_weeds
	- definite_pronoun_resolution
	- dementiabank
	- diabetic_retinopathy_detection
	- div2k
	- dmlab
	- downsampled_imagenet
	- drop
	- dsprites
	- dtd
	- duke_ultrasound
	- e2e_cleaned
	- emnist
	- eraser_multi_rc
	- esnli
	- eurosat
	- fashion_mnist
	- flic
	- flores
	- food101
	- forest_fires
	- fuss
	- gap
	- geirhos_conflict_stimuli
	- genomics_ood
	- german_credit_numeric
	- gigaword
	- glue
	- goemotions
	- gpt3
	- groove
	- gtzan
	- gtzan_music_speech
	- hellaswag
	- higgs
	- horses_or_humans
	- howell
	- i_naturalist2017
	- imagenet2012
	- imagenet2012_corrupted
	- imagenet2012_real
	- imagenet2012_subset
	- imagenet_a
	- imagenet_r
	- imagenet_resized
	- imagenet_v2
	- imagenette
	- imagewang
	- imdb_reviews
	- irc_disentanglement
	- iris
	- kitti
	- kmnist
	- lambada
	- lfw
	- librispeech
	- librispeech_lm
	- libritts
	- ljspeech
	- lm1b
	- lost_and_found
	- lsun
	- lvis
	- malaria
	- math_dataset
	- mctaco
	- mlqa
	- mnist
	- mnist_corrupted
	- movie_lens
	- movie_rationales
	- movielens
	- moving_mnist
	- multi_news
	- multi_nli
	- multi_nli_mismatch
	- natural_questions
	- natural_questions_open
	- newsroom
	- nsynth
	- nyu_depth_v2
	- omniglot
	- open_images_challenge2019_detection
	- open_images_v4
	- openbookqa
	- opinion_abstracts
	- opinosis
	- opus
	- oxford_flowers102
	- oxford_iiit_pet
	- para_crawl
	- patch_camelyon
	- paws_wiki
	- paws_x_wiki
	- pet_finder
	- pg19
	- piqa
	- places365_small
	- plant_leaves
	- plant_village
	- plantae_k
	- qa4mre
	- qasc
	- quac
	- quickdraw_bitmap
	- race
	- radon
	- reddit
	- reddit_disentanglement
	- reddit_tifu
	- resisc45
	- robonet
	- rock_paper_scissors
	- rock_you
	- s3o4d
	- salient_span_wikipedia
	- samsum
	- savee
	- scan
	- scene_parse150
	- scicite
	- scientific_papers
	- sentiment140
	- shapes3d
	- siscore
	- smallnorb
	- snli
	- so2sat
	- speech_commands
	- spoken_digit
	- squad
	- stanford_dogs
	- stanford_online_products
	- starcraft_video
	- stl10
	- story_cloze
	- sun397
	- super_glue
	- svhn_cropped
	- ted_hrlr_translate
	- ted_multi_translate
	- tedlium
	- tf_flowers
	- the300w_lp
	- tiny_shakespeare
	- titanic
	- trec
	- trivia_qa
	- tydi_qa
	- uc_merced
	- ucf101
	- vctk
	- vgg_face2
	- visual_domain_decathlon
	- voc
	- voxceleb
	- voxforge
	- waymo_open_dataset
	- web_nlg
	- web_questions
	- wider_face
	- wiki40b
	- wiki_bio
	- wiki_table_questions
	- wiki_table_text
	- wikihow
	- wikipedia
	- wikipedia_toxicity_subtypes
	- wine_quality
	- winogrande
	- wmt14_translate
	- wmt15_translate
	- wmt16_translate
	- wmt17_translate
	- wmt18_translate
	- wmt19_translate
	- wmt_t2t_translate
	- wmt_translate
	- wordnet
	- wsc273
	- xnli
	- xquad
	- xsum
	- xtreme_pawsx
	- xtreme_xnli
	- yelp_polarity_reviews
	- yes_no

Check that:
    - if dataset was added recently, it may only be available
      in `tfds-nightly`
    - the dataset name is spelled correctly
    - dataset class defines all base class abstract methods
    - the module defining the dataset class is imported


## VGG Model, Hyperparameters

In [30]:
generator = Generator(16)
discriminator = Discriminator(16)

vgg = VGG16(include_top=False, weights='imagenet', input_shape=(32, 32, 3))
vgg = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
vgg.trainable = False

w_gan = 1e-2
w_vgg = 1e-5

optim_d = tf.optimizers.Adam(1e-4)
optim_g = tf.optimizers.Adam(1e-4)

d_mean = tf.metrics.Mean()
g_mean = tf.metrics.Mean()
vgg_mean = tf.metrics.Mean()
l1_mean = tf.metrics.Mean()

## Losses

In [31]:
@tf.function
def l1_loss_func(y, y_):
    return tf.reduce_mean(tf.abs(y-y_))

@tf.function
def vgg_loss_func(y, y_):
    return tf.reduce_mean(vgg(y)-vgg(y_)**2)

@tf.function
def discriminator_loss(real, fake):
    real_loss = tf.keras.losses.BinaryCrossentropy(tf.ones_like(real),real)
    fake_loss = tf.keras.losses.BinaryCrossentropy(tf.zeros_like(fake),fake)
    return real_loss + fake_loss

@tf.function
def generator_loss(fake):
    fake_loss = tf.keras.losses.BinaryCrossentropy(tf.ones_like(fake),fake)
    return fake_loss


## Training Step

In [32]:
@tf.function
def train_step(image_lr, image_hr, optim_d, optim_g):
    with tf.GradientTape() as tape_d, tf.GradientTape() as tape_g:
        image_sr = generator(image_lr, training=True)
        
        d_real = discriminator(image_hr, training=True)
        d_fake = discriminator(image_sr, training=True)
        
        d_loss = discriminator_loss(d_real, d_fake)
        g_loss = generator_loss(d_fake)
        
        vgg_loss = vgg_loss_func(image_hr, image_sr)
        l1_loss = l1_loss_func(image_hr, image_sr)
        
        loss = w_gan * g_loss + w_vgg * vgg_loss + l1_loss
        
        gradients_d = tape_d.gradient(d_loss, discriminator.trainable_weights)
        gradients_g = tape_g.gradient(loss, generator.trainable_weights)
    
    optim_d.apply_gradients(zip(gradients_d, discriminator.trainable_weights))
    optim_g.apply_gradients(zip(gradients_g, generator.trainable_weights))
    return d_loss, g_loss, vgg_loss, l1_loss

## Training Loop

In [33]:
for epoch in range(100):
    for img_lr, img_hr in dataset:
        d_loss, g_loss, vgg_loss, l1_loss = train_step(img_lr, img_hr, optim_d, optim_g)

        d_mean.update_state(d_loss)
        g_mean.update_state(g_loss)
        vgg_mean.update_state(vgg_loss)
        l1_mean.update_state(l1_loss)

    print('epoch: {}, d_loss: {}, g_loss: {}, vgg_loss: {}, l1_loss: {}'.format(epoch+1,
                                                                d_mean.result(),
                                                                g_mean.result(),
                                                                vgg_mean.result(),
                                                                l1_mean.result()))
    img_sr_list = list()
    img_lr_list = list()
    img_hr_list = list()
    for img_lr, img_hr in dataset.take(10):
        img_sr = generator(img_lr)
        
        img_lr_list.append(tf.image.resize(img_lr[0], (32, 32),
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR))
        img_sr_list.append(img_sr[0])
        img_hr_list.append(img_hr[0])
    
    img_lr = np.concatenate(img_lr_list, axis=1)
    img_sr = np.concatenate(img_sr_list, axis=1)
    img_hr = np.concatenate(img_hr_list, axis=1)
    img = np.concatenate([img_lr, img_sr, img_hr], axis=0)
    
    plt.imshow(img)
    plt.show()

    d_mean.reset_states()
    g_mean.reset_states()
    vgg_mean.reset_states()
    l1_mean.reset_states()

NameError: name 'dataset' is not defined