### setup

In [6]:
import cv2
import os
import random

In [4]:
# TensorFlow dependencies (model components and deep-learning layers),
# using the Keras functional API
import tensorflow as tf
from keras.models import Model
from keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten


In [5]:
# Let TensorFlow grow GPU memory on demand instead of reserving all of it at
# start-up, which avoids out-of-memory (OOM) errors.
gpu_list = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpu_list:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

### preprocessing

In [4]:
# Data comes from the DIV2K, Flickr2K and OutdoorSceneTraining (OST) datasets.
# The model is trained on aligned RGB low-/high-resolution patches, augmented
# with random horizontal and vertical flips and 90-degree rotations.

# let tf.data choose parallelism / buffer sizes dynamically
AUTO = tf.data.AUTOTUNE

In [13]:
def cropping(loRes, hiRes, random=False, center=False, crop_size=96, scale=4):
    '''
    Crop an aligned pair of patches from a low-resolution image and its
    high-resolution counterpart.

    The high-resolution patch is `crop_size` pixels square and the low-resolution
    patch is `crop_size // scale` pixels square, taken from the corresponding
    location so the two patches stay aligned.

    Args:
        loRes: low-resolution image tensor laid out as (height, width, channels);
            the first 3 channels are kept. Assumed to be exactly `scale` times
            smaller than `hiRes` in each spatial dimension — TODO confirm.
        hiRes: high-resolution image tensor, same layout.
        random: if True, choose the crop location uniformly at random.
        center: if True (and `random` is False), take the centre crop.
        crop_size: side length of the high-resolution patch.
        scale: upscaling factor between `loRes` and `hiRes`.

    Returns:
        (loRes_cropped, hiRes_cropped) tuple.

    Raises:
        ValueError: if neither `random` nor `center` is set.
    '''
    # NOTE: the `random` parameter shadows the stdlib `random` module, and Python
    # randint cannot take graph tensors inside tf.data.map, so offsets are drawn
    # with tf.random instead.
    loRes_crop = crop_size // scale

    loRes_shape = tf.shape(loRes)
    height, width = loRes_shape[0], loRes_shape[1]

    if random:
        # maxval is exclusive, so +1 lets the crop reach the bottom/right edge
        loheight = tf.random.uniform([], 0, height - loRes_crop + 1, dtype=tf.int32)
        lowidth = tf.random.uniform([], 0, width - loRes_crop + 1, dtype=tf.int32)
    elif center:
        loheight = (height - loRes_crop) // 2
        lowidth = (width - loRes_crop) // 2
    else:
        raise ValueError('cropping() needs random=True or center=True')

    # same location expressed in high-resolution coordinates
    hiheight = loheight * scale
    hiwidth = lowidth * scale

    loRes_cropped = tf.slice(loRes, [loheight, lowidth, 0], [loRes_crop, loRes_crop, 3])
    hiRes_cropped = tf.slice(hiRes, [hiheight, hiwidth, 0], [crop_size, crop_size, 3])

    return (loRes_cropped, hiRes_cropped)


In [12]:
def augmentations(loRes, hiRes):
    '''
    Randomly apply identical geometric augmentations to a low/high-res pair.

    A left-right flip, an up-down flip and a single 90-degree rotation
    (tf.image.rot90 with its default k=1) are each applied independently with
    probability 0.5, always to both images so they stay aligned.

    Returns:
        (loRes, hiRes) tuple of the possibly transformed images.
    '''
    transforms = (tf.image.flip_left_right, tf.image.flip_up_down, tf.image.rot90)
    for transform in transforms:
        if tf.random.uniform([]) < 0.5:
            loRes, hiRes = transform(loRes), transform(hiRes)

    return (loRes, hiRes)

In [14]:
def _default_features():
    '''Return the tf.train.Example feature spec read by load_train / load_test.'''
    # NOTE(review): assumes each record stores both images as tensors serialized
    # with tf.io.serialize_tensor in single bytes features (required by
    # tf.io.parse_tensor below) — confirm against the TFRecord writer.
    return {
        'lowresolution': tf.io.FixedLenFeature([], tf.string),
        'highresolution': tf.io.FixedLenFeature([], tf.string),
    }


def _parse_pair(data, features):
    '''Parse one serialized example into (loRes, hiRes) uint8 image tensors.'''
    if features is None:
        features = _default_features()
    example_image = tf.io.parse_single_example(data, features)

    loRes = tf.io.parse_tensor(example_image['lowresolution'], out_type=tf.uint8)
    hiRes = tf.io.parse_tensor(example_image['highresolution'], out_type=tf.uint8)
    return (loRes, hiRes)


def load_train(data, features=None):
    '''
    Turn one serialized training record into an augmented (loRes, hiRes) patch pair.

    `features` defaults to a spec for the 'lowresolution'/'highresolution' keys so
    the function can be passed directly to Dataset.map, which supplies only the
    serialized record.

    Returns:
        (loRes, hiRes) with static shapes (24, 24, 3) and (96, 96, 3).
    '''
    (loRes, hiRes) = _parse_pair(data, features)
    (loRes, hiRes) = cropping(loRes, hiRes, random=True)
    (loRes, hiRes) = augmentations(loRes, hiRes)

    # pin static shapes so the patches can be batched; these sizes match
    # cropping()'s defaults (crop_size=96, scale=4 -> 24x24 low-res patch)
    loRes = tf.reshape(loRes, (24, 24, 3))
    hiRes = tf.reshape(hiRes, (96, 96, 3))

    return (loRes, hiRes)


def load_test(data, features=None):
    '''
    Turn one serialized evaluation record into a centre-cropped (loRes, hiRes) pair.

    `features` defaults to the same spec as load_train, so this can also be
    passed directly to Dataset.map.

    Returns:
        (loRes, hiRes) with static shapes (24, 24, 3) and (96, 96, 3).
    '''
    (loRes, hiRes) = _parse_pair(data, features)
    (loRes, hiRes) = cropping(loRes, hiRes, center=True)

    # see load_train: sizes follow cropping()'s defaults
    loRes = tf.reshape(loRes, (24, 24, 3))
    hiRes = tf.reshape(hiRes, (96, 96, 3))

    return (loRes, hiRes)


In [15]:
def justload(filenames, batchSize, train=False):
    '''
    Build a batched, infinitely repeating tf.data pipeline from TFRecord files.

    Args:
        filenames: TFRecord path(s) passed to tf.data.TFRecordDataset.
        batchSize: batch size; also used as the shuffle buffer size.
        train: if True, records go through load_train (random crop +
            augmentation); otherwise through load_test (centre crop).

    Returns:
        A prefetched, repeating tf.data.Dataset of (loRes, hiRes) batches.
    '''
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)

    if train:
        # Do NOT cache after load_train: its crop/augmentation is random, and a
        # cache() placed after it would replay the first epoch's patches on
        # every later epoch, defeating augmentation.
        dataset = dataset.map(load_train, num_parallel_calls=AUTO)
    else:
        # the centre crop is deterministic, so caching the small patches is safe
        dataset = dataset.map(load_test, num_parallel_calls=AUTO).cache()

    # NOTE(review): a shuffle buffer equal to the batch size only mixes nearby
    # records; consider a larger buffer for training.
    dataset = dataset.shuffle(batchSize)
    dataset = dataset.batch(batchSize)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(AUTO)

    return dataset

In [9]:
# Scratch check of tf.slice: `begin` is the start index on each axis and `size`
# is how many elements to take along each axis.
t3 = tf.reshape(tf.range(1, 33, delta=2), (2, 2, 4))  # odd numbers 1..31, int32

print(tf.slice(t3, [1, 1, 0], [1, 1, 2]))

tf.Tensor([[[25 27]]], shape=(1, 1, 2), dtype=int32)


### model

In [3]:
import keras.layers as kl
from tensorflow._api.v2.nn import depth_to_space
import numpy as np

In [7]:
# Scratch check: two LeakyReLU convs on a random batch, then their feature maps
# concatenated along the channel axis.
example_tensor = np.random.random_sample((1, 32, 32, 3))

x = kl.Conv2D(filters=32, kernel_size=3, padding='same',
              activation=kl.LeakyReLU(alpha=0.2))(example_tensor)
x2 = kl.Conv2D(filters=64, kernel_size=3, padding='same',
               activation=kl.LeakyReLU(alpha=0.2))(x)

result = kl.Concatenate(axis=-1)([x, x2])

In [16]:
class ESRGAN(object):
    '''
    Keras functional-API builder for an ESRGAN-style x4 super-resolution generator.

    Generator layout:
        9x9 conv -> stack of Residual-in-Residual Dense Blocks (RRDBs) -> 3x3 conv
        -> merge with the first conv's features -> two (conv, depth_to_space(2),
        LeakyReLU) upsampling stages -> 9x9 tanh conv rescaled to [0, 255].
    '''

    @staticmethod
    def _leaky_conv(filters):
        '''3x3 "same" conv with LeakyReLU(0.2), the unit used inside the dense blocks.'''
        return kl.Conv2D(filters, 3, padding="same", activation=kl.LeakyReLU(alpha=0.2))

    @staticmethod
    def _upsample(x, filters):
        '''Conv -> depth_to_space(2) (pixel shuffle, x2 spatial) -> LeakyReLU(0.2).'''
        x = kl.Conv2D(filters, 3, padding="same")(x)
        # wrapped in a Lambda layer so the op is an explicit Keras layer
        x = kl.Lambda(lambda t: tf.nn.depth_to_space(t, 2))(x)
        return kl.LeakyReLU(alpha=0.2)(x)

    def RDB(self, x, growth=32):
        '''
        Residual dense block.

        Each of four LeakyReLU convs receives the block input concatenated with
        the outputs of all previous convs, each tensor included exactly once. A
        final linear conv fuses everything back to the input's channel count.
        The result is scaled by 0.2 and added to the input.

        Args:
            x: 4-D feature tensor whose channel dimension is statically known.
            growth: number of filters of each intermediate conv.

        Returns:
            Tensor with the same shape as `x`.
        '''
        features = [x]
        for _ in range(4):
            inputs = features[0] if len(features) == 1 else kl.concatenate(features)
            features.append(self._leaky_conv(growth)(inputs))

        fused = kl.Conv2D(x.shape[-1], 3, padding="same")(kl.concatenate(features))
        fused = kl.Lambda(lambda t: t * 0.2)(fused)
        return kl.Add()([x, fused])

    def RRDB(self, x_input):
        '''
        Residual-in-residual dense block: three RDBs in sequence, whose output is
        scaled by 0.2 and added back to the block input.
        '''
        x = x_input
        for _ in range(3):
            x = self.RDB(x)
        x = kl.Lambda(lambda t: t * 0.2)(x)
        return kl.Add()([x_input, x])

    def generator(self, num_blocks=16):
        '''
        Build the generator model.

        Args:
            num_blocks: number of RRDBs in the trunk (16 reproduces the original).

        Returns:
            keras Model mapping a (H, W, 3) image to a (4H, 4W, 3) image with
            values in [0, 255]. The input is divided by 255, so presumably it is
            expected in [0, 255] too — confirm with the data pipeline.
        '''
        lr_input = Input(shape=(None, None, 3), name='input_image')
        x_in = kl.Rescaling(scale=1.0/255, offset=0.0)(lr_input)

        # shallow feature extraction
        x = kl.Conv2D(64, 9, padding="same", activation=kl.LeakyReLU(alpha=0.2))(x_in)

        # residual-in-residual trunk
        trunk = x
        for _ in range(num_blocks):
            trunk = self.RRDB(trunk)

        trunk = kl.Conv2D(64, 3, padding="same")(trunk)
        # NOTE(review): the ESRGAN paper merges here with an element-wise Add;
        # concatenation (kept from the original design) doubles channels to 128.
        x_skip = kl.concatenate([x, trunk])

        # two x2 sub-pixel upsampling stages -> x4 overall
        x = self._upsample(x_skip, 128)
        x = self._upsample(x, 64)

        # back to 3 channels; tanh output mapped from [-1, 1] to [0, 255]
        x = kl.Conv2D(3, 9, padding="same", activation="tanh")(x)
        output = kl.Rescaling(scale=127.5, offset=127.5)(x)

        return Model(lr_input, output, name='generator!!')

    def discriminator(self):
        '''Placeholder: the discriminator has not been implemented yet.'''
        raise NotImplementedError("ESRGAN.discriminator() is not implemented yet")

In [17]:
# build the generator graph
esrgan = ESRGAN()

gen = esrgan.generator()
# TODO: build the discriminator here once ESRGAN.discriminator() is implemented.


In [21]:
# NOTE(review): duplicate summary cell; its saved output (model name "model")
# predates the current generator name and is stale.
gen.summary(line_length=None)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 rescaling (Rescaling)          (None, None, None,   0           ['input_1[0][0]']                
                                3)                                                                
                                                                                                  
 conv2d (Conv2D)                (None, None, None,   15616       ['rescaling[0][0]']              
                                64)                                                           

In [26]:
# NOTE(review): duplicate summary cell; its saved output (model name "model_2")
# predates the current generator name and is stale.
gen.summary(line_length=None)

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 rescaling_3 (Rescaling)        (None, None, None,   0           ['input_3[0][0]']                
                                3)                                                                
                                                                                                  
 conv2d_93 (Conv2D)             (None, None, None,   15616       ['rescaling_3[0][0]']            
                                64)                                                         

In [19]:
# layer-by-layer summary of the generator
gen.summary(line_length=None)

Model: "generator!!"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, None, None,  0           []                               
                                 3)]                                                              
                                                                                                  
 rescaling_5 (Rescaling)        (None, None, None,   0           ['input_image[0][0]']            
                                3)                                                                
                                                                                                  
 conv2d_265 (Conv2D)            (None, None, None,   15616       ['rescaling_5[0][0]']            
                                64)                                                     