In [68]:
import tensorflow as tf
from tensorflow.keras.layers import LeakyReLU, Conv2DTranspose, Conv2D, ReLU, PReLU, Add, Concatenate, Activation, BatchNormalization, Dense, Flatten
from tensorflow.keras import Model, Input
import numpy as np
import matplotlib.pyplot as plt
import os
from time import time
from tensorflow.keras.optimizers import Adam

import skimage.transform

import seaborn as sns
import matplotlib.image as mpimg
from PIL import Image
import cv2


In [55]:
def discriminator(image_shape):
    '''
    Build a 70x70 PatchGAN discriminator
    '''
    input_img = Input(shape = image_shape)
    # k3n64s1
    d = Conv2D(64, (3,3), strides = (1,1), padding = 'same' )(input_img)
    d = LeakyReLU(0.2)(d)
    
    # k3n64s2
    d = Conv2D(64, (3,3), strides = (2,2), padding = 'same' )(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(0.2)(d)
    
#     # k3n128s1
#     d = Conv2D(128, (3,3), strides = (1,1), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
#     # k3n128s2
#     d = Conv2D(128, (3,3), strides = (2,2), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
#     # k3n256s1
#     d = Conv2D(256, (3,3), strides = (1,1), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
#     # k3n256s2
#     d = Conv2D(256, (3,3), strides = (2,2), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
#     # k3n512s1
#     d = Conv2D(512, (3,3), strides = (1,1), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
#     # k3n512s2
#     d = Conv2D(512, (3,3), strides = (2,2), padding = 'same' )(d)
#     d = BatchNormalization()(d)
#     d = LeakyReLU(0.2)(d)
    
    d = Flatten()(d)
    d = Dense(1024)(d)
    d = LeakyReLU(0.2)(d)

    d_out = Dense(1, activation = 'sigmoid')
    
    disc = Model(input_img, d_out)

    return disc

def resnet_block(n_filters, input_layer):
    # k3n64s1
    g = Conv2DTranspose(n_filters, (3,3), strides=(1,1), padding='same')(input_layer)
    g = BatchNormalization()(g)
    g = PReLU()(g)
    # k3n64s1
    g = Conv2DTranspose(n_filters, (3,3), strides=(1,1), padding='same')(g)
    g = BatchNormalization()(g)
    
#     g = Concatenate()([g, input_layer])
    g = Add()([g, input_layer])
    
    return g


def generator(image_shape, n_resnets):

    input_image = Input(shape=image_shape)
    
    g = Conv2DTranspose(64, (3,3), strides=(1,1), padding='same')(input_image)
    g = PReLU()(g)
    
    g_res = resnet_block(64, g)
    for i in range(n_resnets - 1):
        g_res = resnet_block(64, g_res)

    g_res = Conv2DTranspose(256, (3,3), strides=(1,1), padding='same')(g_res)
    g_res = BatchNormalization()(g_res)
    
    g = Add()([g, g_res])
    
    g = Concatenate()([g, g_res]) # Should be add instead of concatenate I think
    
    g = Conv2DTranspose(256, (3,3), strides=(1,1), padding='same')(g)
    g = tf.nn.depth_to_space(g, 2)
    g = PReLU()(g)

    g = Conv2DTranspose(64, (3,3), strides=(1,1), padding="same")(g)
    g = tf.nn.depth_to_space(g, 2)
    g = PReLU()(g)

    g = Conv2DTranspose(3, (3,3), strides=(1,1), padding="same")(g)
    
    output_image = Activation("sigmoid")(g)

    generator = Model(input_image, output_image)
    
    return generator

def feature_extractor(i, j):
    vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False
    features_list = [layer.output for layer in vgg.layers]
    
    for (k, layer) in zip(range(len(vgg.layers)), vgg.layers):
        if layer.name == "block"+str(i)+"_conv"+str(j):
            break
            
    return Model(vgg.input, features_list[k])
    

entropy_loss = tf.keras.losses.BinaryCrossentropy(from_logits = False)
mse_loss = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real_output, fake_output):
    real_loss = entropy_loss(tf.ones_like(real_output), real_output)
    fake_loss = entropy_loss(tf.zeros_like(fake_output), fake_output)

    total_loss = real_loss + fake_loss

    return total_loss

def generator_loss(fake_output):
    loss = entropy_loss(tf.ones_like(fake_output), fake_output)
    return loss
        
def content_loss(hr_images, sr_images):
    
    hr_features = feat_ext(hr_images)
    sr_features = feat_ext(sr_images)
    
    content_loss = mse_loss(hr_features, sr_features)
    return content_loss

In [56]:
def progress_update(model, input_img, epoch):
    prediction = model(input_img, training = False)

    display_list = [input_img[0], prediction[0], cycled[0]]
    titles = ['Downsampled Image', 'Upsampled Image', 'Original Image']

    plt.figure(figsize=(12,12))

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(titles[i])
        plt.imshow(display_list[i] * 0.5 + 0.5) # scale images back from [-1,1] range to [0,1] for plotting
        plt.axis('off')

    plt.tight_layout()
    
    plt.savefig('image_at_epoch{:04d}.png'.format(epoch+1))
    
    plt.show()
    
def show_loss_history(losses, names, title):
    epoch_idx = range(1, len(losses)+1)
    plt.figure(figsize=(15,12))
    sns.set()
    for i in range(len(names)):
        sns.lineplot(x = epoch_idx, y = losses[:,i], label = names[i])
    plt.xlabel('Epoch')
    plt.legend(loc = 'best')
    plt.title(title)

    plt.savefig(title + '.png')
    print('Saved ' + title)

    plt.close()

In [61]:
def train_step(lr_images, hr_images):
#     with tf.GradientTape() as gen_adv_tape, tf.GradientTape() as gen_content_tape, tf.GradientTape() as disc_tape:
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

        sr_images = gen(lr_images, training = True)

        real_output = disc(hr_images, training = True)
        fake_output = disc(sr_images, training = True)
        
        disc_loss = discriminator_loss(real_output, fake_output)
        
        gen_adv_loss = generator_loss(fake_output)
        gen_content_loss = content_loss(hr_images, sr_images)
        gen_total_loss = gen_adv_loss + gen_content_loss

    # Calculate gradients
    # gen_adv_gradients = gen_adv_tape.gradient(gen_adv_loss, gen.trainable_variables)
    # gen_content_gradients = gen_content_tape.gradient(gen_content_loss, gen.trainable_variables)
    
    gen_gradients = gen_tape.gradient(gen_total_loss, gen.trainable_variables)
    
    disc_gradients = disc_tape.gradient(disc_loss, disc.trainable_variables)
    
    # Apply gradients
    # gen_optimizer.apply_gradients(zip(gen_adv_gradients, gen.trainable_variables))
    # gen_optimizer.apply_gradients(zip(gen_content_gradients, gen.trainable_variables))
    
    gen_optimizer.apply_gradients(zip(gen_gradients, gen.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradients, disc.trainable_variables))
    
    return gen_adv_loss, gen_content_loss, disc_loss

In [77]:
def downSample(original_image, factor):
    orig_dim = (original_image.shape[1], original_image.shape[0]) # (width, height)

    width = int(original_image.shape[1] / factor)
    height = int(original_image.shape[0] / factor)
    dim = (width, height)
    downsampled_img = cv2.resize(original_image, dim, interpolation = cv2.INTER_AREA)
    return downsampled_img


def downSampleAll(images, factor):

    downsampled = [images[0]] * len(images)

    for i in range(len(images)):
        downsampled[i] = downSample(images[i], factor)

    return downsampled


In [78]:
def train(data, batch_size, start_epoch, n_epochs, factor):
    sample_image = next(iter(data))
    # create models directory if doesn't exist
    dirName = 'models'
    if not os.path.exists(dirName):
        os.mkdir(dirName)
        print("Directory " , dirName ,  " Created ")
    else:    
        print("Directory " , dirName ,  " already exists")
    
    for epoch in range(start_epoch, n_epochs):
        losses_per_epoch = np.zeros(3)
        start = time()
        print('Starting epoch {}'.format(epoch+1))
        
        n = 0
        for hr_images in data:
            print(hr_images.shape)
            lr_images = downSampleAll(hr_images, factor)
            losses_per_epoch += train_step(lr_images, hr_images)
            n += 1

        print('Epoch {} time: {:.2f}'.format(epoch+1, time()-start))
        
        losses_per_epoch /= n*batch_size
        
        if epoch == start_epoch:
            losses = np.asarray([losses_per_epoch])
        else:
            losses = np.vstack([losses, losses_per_epoch])

        show_loss_history(losses, ["gen_adv","gen_content", "disc"], title = 'Training History')    
            
        if (epoch + 1) % 5 == 0:
            progress_update(gen, sample_image, epoch)
        
        # save models every 10 epochs
        if (epoch + 1) % 10 == 0:
            # create directory for epoch
            dirName = 'models/models_at_epoch{:04d}'.format(epoch+1)
            if not os.path.exists(dirName):
                os.mkdir(dirName)
                
            gen.save(dirName + '/gen')
            disc.save(dirName + '/disc')


In [74]:
def load_data():
    data = np.ones((32, 128, 128, 3))
    return data

In [79]:
def loadImages(directory, images):
    # i = 0
    for filename in os.listdir(directory):
        if filename.endswith('.jpg'): # and i < 10 for testing
            image = mpimg.imread(os.path.join(directory,filename))
            images.append(image)
    return images

In [75]:
factor = 4
i = 5
j = 4

image_shape = (128,128,3)

data = load_data()

gen = generator(image_shape, 16)
disc = discriminator((image_shape[0]*factor, image_shape[1]*factor, 3))

gen_optimizer = Adam(lr = 0.0002, beta_1=0.5)
disc_optimizer = Adam(lr = 0.0002, beta_1=0.5)

feat_ext = feature_extractor(i, j) 

batch_size = 32
start_epoch = 0
n_epochs = 10
bufferSize = 1000

data = tf.data.Dataset.from_tensor_slices(data).shuffle(bufferSize).batch(batch_size)


train(data, batch_size, start_epoch, n_epochs, factor)


Directory  models  already exists
Starting epoch 1
(32, 128, 128, 3)


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 1 time: 30.00
Saved Tr