<a href="https://colab.research.google.com/github/consequencesunintended/RefinementGAN/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import keras as kr
import numpy as np
import matplotlib.pyplot as plt
import time
from IPython import display

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation, Reshape
from tensorflow.keras.layers import Convolution2D
from tensorflow.keras.layers import concatenate
from keras.engine.topology import Layer

In [None]:
# Record library versions so results can be tied to a specific environment.
for _template, _module in (
    ("Tensorflow version: {}", tf),
    ("Keras version: {}", kr),
    ("Numpy version: {}", np),
):
    print(_template.format(_module.__version__))
del _template, _module

In [None]:
# --- Image / model geometry ---
DIMENSION = 64      # full-resolution side length; pyramid levels are //2, //4, //8
FC_DIM = 128        # size of the discriminator's auto-encoder bottleneck

# --- Training schedule ---
BATCH_SIZE = 16
EPOCHS = 1_000
SHUFFLE_BUFFER_SIZE = 100
LEARNING_RATE = 0.0001

# --- Model selection / evaluation ---
MODEL_NUMBER = 1    # use model numbers 1-3 for different variations of the architecture
NUM_TEST_IMG = 4    # how many synthetic images are previewed / refined for evaluation

In [None]:
# Glob pattern for the real (target-domain) photographs used for training.
data_dir = "data/real images/*.jpg"
real_images_ds = tf.data.Dataset.list_files(file_pattern=data_dir)

In [None]:
# Glob pattern for the synthetic images the refiner is evaluated on.
data_dir = "data/synthetic images/*.jpg"
synthetic_images_ds = tf.data.Dataset.list_files(file_pattern=data_dir)

In [None]:
# Let tf.data pick the degree of parallelism (same sentinel as tf.data.experimental.AUTOTUNE).
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
def decode_img(img):
    """Decode JPEG bytes into a 4-level image pyramid with a shared random flip.

    Args:
        img: raw JPEG-encoded bytes (scalar string tensor).

    Returns:
        ``(img_2, img_4, img_8, img)``: float32 RGB images in [0, 1] at side
        ``DIMENSION // 2``, ``DIMENSION // 4``, ``DIMENSION // 8`` and
        ``DIMENSION``. With probability 0.5 every level is mirrored left-right
        together, so the levels stay aligned.
    """
    full = tf.image.decode_jpeg(img, channels=3)
    full = tf.image.convert_image_dtype(full, tf.float32)
    full = tf.image.resize(full, [DIMENSION, DIMENSION])

    # Lower pyramid levels are derived from the already-resized full image.
    eighth, quarter, half = [
        tf.image.resize(full, [DIMENSION // factor, DIMENSION // factor])
        for factor in (8, 4, 2)
    ]

    # Under Dataset.map this tensor condition is converted by AutoGraph.
    if tf.random.uniform(()) > 0.5:
        eighth, quarter, half, full = (
            tf.image.flip_left_right(eighth),
            tf.image.flip_left_right(quarter),
            tf.image.flip_left_right(half),
            tf.image.flip_left_right(full),
        )

    return half, quarter, eighth, full

In [None]:
def process_path(file_path):
    """Load the JPEG at *file_path* and return its ``(img_2, img_4, img_8, img)`` pyramid."""
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes)

In [None]:
# Turn both file-path datasets into datasets of decoded image pyramids.
real_images_ds, synthetic_images_ds = [
    paths.map(process_path, num_parallel_calls=AUTOTUNE)
    for paths in (real_images_ds, synthetic_images_ds)
]

In [None]:
# Training pipeline over the real images: shuffle, batch, and overlap I/O with compute.
train_dataset = (
    real_images_ds
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Evaluation pipeline over the synthetic images, built the same way as training.
test_dataset = (
    synthetic_images_ds
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Number of real images matched by the glob; later used to size an epoch (STEP_SIZE).
NUM_IMAGES = len(real_images_ds)

In [None]:
# Show the real-image count as this cell's output.
NUM_IMAGES

In [None]:
# Draw one shuffled batch of synthetic pyramids (1/2, 1/4, 1/8, full) for previews and evaluation.
syn_img_2, syn_img_4, syn_img_8, syn_img = next(iter(test_dataset))

In [None]:
# Inspect the full-resolution batch shape: (batch, DIMENSION, DIMENSION, 3); batch is at most BATCH_SIZE.
syn_img.shape

In [None]:
def plotImages(images_arr):
    """Plot a batch of images side by side in a single row.

    Args:
        images_arr: sequence or batch of HxWxC images (tensors or arrays);
            one column is drawn per image. An empty input draws nothing.
    """
    count = len(images_arr)
    if count == 0:
        return

    # figsize is in inches, so scale it with the number of images rather than
    # their pixel size. squeeze=False keeps a 2-D axes array even for count == 1
    # (a bare Axes has no .flatten()).
    _, axes = plt.subplots(1, count, figsize=(3 * count, 3), squeeze=False)
    for img, ax in zip(images_arr, axes.flat):
        ax.imshow(np.asarray(img))
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first NUM_TEST_IMG synthetic images at full resolution.
plotImages(syn_img[:NUM_TEST_IMG])

In [None]:
# Preview the same images at 1/2 resolution.
plotImages(syn_img_2[:NUM_TEST_IMG])

In [None]:
# Preview the same images at 1/4 resolution.
plotImages(syn_img_4[:NUM_TEST_IMG])

In [None]:
# Preview the same images at 1/8 resolution.
plotImages(syn_img_8[:NUM_TEST_IMG])

In [None]:
class ResizeNN(tf.keras.layers.Layer):
    """Parameter-free layer resizing feature maps with nearest-neighbour sampling.

    Subclasses ``tf.keras.layers.Layer`` so it composes with the tf.keras
    functional models built in this file.

    Args:
        image_size: target ``(height, width)``; only the first two entries are used.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, image_size=(512, 512), **kwargs):
        # Initialise the base Layer first so attribute assignment goes through
        # fully set-up Keras tracking.
        super().__init__(**kwargs)
        self.image_size = image_size[0], image_size[1]

    def call(self, inputs, **kwargs):
        return tf.image.resize(inputs, self.image_size, method='nearest')

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.image_size[0], self.image_size[1], input_shape[-1]

    def get_config(self):
        # Required for saving/reloading models that contain this layer
        # (reload with custom_objects={"ResizeNN": ResizeNN}).
        config = super().get_config()
        config.update({"image_size": list(self.image_size)})
        return config

In [None]:
def disc_encoder():
    """Build the encoder half of the auto-encoding discriminator.

    Maps a ``(DIMENSION, DIMENSION, 3)`` image to an ``FC_DIM`` vector. It
    uses four blocks of two same-padded 3x3 convolutions with ELU, with
    64/128/192/256 filters. Each block ends in a nearest-neighbour resize to
    DIMENSION, DIMENSION // 2, DIMENSION // 4 and DIMENSION // 8 respectively.
    The first resize is a no-op because the input is already DIMENSION. The
    result is flattened and linearly projected to ``FC_DIM``.

    Returns:
        A functional ``Model`` named ``"disc_encoder"``.
    """
    input_shape = [DIMENSION, DIMENSION, 3]
    kernel = 3
    filters = 64

    def _conv_block(x, n_filters, out_side):
        # Two conv+ELU layers, then resize to out_side x out_side.
        for _ in range(2):
            x = Convolution2D(n_filters, (kernel, kernel), strides=(1, 1), padding="same")(x)
            x = tf.nn.elu(x)
        return ResizeNN([out_side, out_side])(x)

    inputs = Input(shape=input_shape)

    x = _conv_block(inputs, filters, DIMENSION)
    print("Encoder - Block 1 Created!")

    x = _conv_block(x, 2 * filters, DIMENSION // 2)
    print("Encoder - Block 2 Created!")

    x = _conv_block(x, 3 * filters, DIMENSION // 4)
    print("Encoder - Block 3 Created!")

    x = _conv_block(x, 4 * filters, DIMENSION // 8)

    # Deliberately a single linear projection: stacking another activation-free
    # Dense in front of it (e.g. 16384 -> 16384 at DIMENSION=64, ~268M weights)
    # collapses to one linear map, adding memory and compute but no capacity.
    outputs = Dense(FC_DIM)(Flatten()(x))

    print("Encoder - Block 4 Created!")

    return Model(inputs=inputs, outputs=outputs, name="disc_encoder")

In [None]:
def disc_decoder():
    """Build the decoder half of the auto-encoding discriminator.

    The ``FC_DIM`` input goes through a Dense layer and is reshaped to a
    ``(DIMENSION // 8, DIMENSION // 8, 64)`` grid. Three conv blocks then
    resize it to DIMENSION // 4, DIMENSION // 2 and DIMENSION. Three more
    conv+ELU layers follow, and a final 3-channel conv with sigmoid produces
    the reconstruction in [0, 1].

    The seed grid is derived from DIMENSION rather than hard-coded to 8;
    both are identical at DIMENSION=64. DIMENSION is presumably meant to be
    a multiple of 8 (the pyramids use //2, //4, //8), but that is not checked.

    Returns:
        A functional ``Model`` named ``"disc_decoder"``.
    """
    kernel = 3
    filters = 64
    seed_side = DIMENSION // 8

    def _conv_elu(x):
        # One same-padded conv followed by ELU.
        x = Convolution2D(filters, (kernel, kernel), strides=(1, 1), padding="same")(x)
        return tf.nn.elu(x)

    inputs = Input(shape=(FC_DIM,))

    x = Dense(seed_side * seed_side * filters)(inputs)
    x = Reshape([seed_side, seed_side, filters])(x)

    x = ResizeNN([DIMENSION // 4, DIMENSION // 4])(_conv_elu(_conv_elu(x)))
    print("Decoder - Block 1 Created!")

    x = ResizeNN([DIMENSION // 2, DIMENSION // 2])(_conv_elu(_conv_elu(x)))
    print("Decoder - Block 2 Created!")

    x = ResizeNN([DIMENSION, DIMENSION])(_conv_elu(_conv_elu(x)))
    print("Decoder - Block 3 Created!")

    for _ in range(3):
        x = _conv_elu(x)
    x = Convolution2D(3, (kernel, kernel), padding="same")(x)
    outputs = Activation("sigmoid")(x)

    print("Decoder - Block 4 Created!")

    return Model(inputs=inputs, outputs=outputs, name="disc_decoder")

In [None]:
def make_discriminator_model():
    """Chain the encoder and decoder into one auto-encoder used as the discriminator.

    Returns:
        A ``tf.keras.Sequential`` model mapping an image to its reconstruction.
    """
    encoder, decoder = disc_encoder(), disc_decoder()
    auto_encoder = tf.keras.Sequential([encoder, decoder])

    print("Discriminator - Model Generated!")

    return auto_encoder

In [None]:
# Module-level discriminator (encoder + decoder) used by train_step.
discriminator = make_discriminator_model()

In [None]:
def make_generator_model( model_number = 1 ):
    """Build the multi-scale generator/refiner.

    The network starts from the 1/8-scale image and upsamples through 1/4
    and 1/2 to full resolution, ending in a 3-channel sigmoid output.

    Variants:
        1: only the 1/8-scale input feeds the network.
        2: the 1/4-scale input is also concatenated before block 2.
        3: the 1/2-scale input is further concatenated before block 3.

    The model always declares four inputs ``[full, 1/2, 1/4, 1/8]``. No layer
    consumes the full-resolution input, and variants 1-2 also ignore some of
    the others. This presumably keeps one call convention across variants.

    Args:
        model_number: architecture variant, 1, 2 or 3.

    Returns:
        A functional ``Model`` named ``"refiner"``.

    Raises:
        ValueError: if ``model_number`` is not 1, 2 or 3.
    """
    if model_number not in (1, 2, 3):
        raise ValueError("model_number must be 1, 2 or 3, got {!r}".format(model_number))

    kernel = 3
    filters = 64

    def _conv_elu(x):
        # One same-padded conv followed by ELU.
        x = Convolution2D(filters, (kernel, kernel), strides=(1, 1), padding="same")(x)
        return tf.nn.elu(x)

    input_image_8 = Input(shape=[DIMENSION // 8, DIMENSION // 8, 3])
    input_image_4 = Input(shape=[DIMENSION // 4, DIMENSION // 4, 3])
    input_image_2 = Input(shape=[DIMENSION // 2, DIMENSION // 2, 3])
    input_image = Input(shape=[DIMENSION, DIMENSION, 3])

    x = _conv_elu(_conv_elu(input_image_8))
    x = ResizeNN([DIMENSION // 4, DIMENSION // 4])(x)
    print("Generator/Refiner - Block 1 Created!")

    if model_number in (2, 3):
        x = concatenate([input_image_4, x])
    x = _conv_elu(_conv_elu(x))
    x = ResizeNN([DIMENSION // 2, DIMENSION // 2])(x)
    print("Generator/Refiner - Block 2 Created!")

    if model_number == 3:
        x = concatenate([input_image_2, x])
    x = _conv_elu(_conv_elu(x))
    x = ResizeNN([DIMENSION, DIMENSION])(x)
    print("Generator/Refiner - Block 3 Created!")

    for _ in range(3):
        x = _conv_elu(x)
    x = Convolution2D(3, (kernel, kernel), padding="same")(x)
    outputs = Activation("sigmoid")(x)

    print("Generator/Refiner - Model Generated!")

    return Model(inputs=[input_image, input_image_2, input_image_4, input_image_8], outputs=outputs, name="refiner")

In [None]:
# Module-level refiner for the variant selected by MODEL_NUMBER.
generator = make_generator_model(model_number=MODEL_NUMBER)

In [None]:
def calculate_LGAN(v, Dv):
    """Return the mean absolute (L1) difference between ``v`` and ``Dv`` as a scalar tensor."""
    return tf.reduce_mean(tf.abs(v - Dv))

In [None]:
def get_loss(k_t, x, D_x, G_z, D_G_z, outputs, lambda_r=0.2, lambda_k=0.001, gamma=0.75):
    """Compute refiner/discriminator losses and the next balance coefficient.

    The discriminator is an auto-encoder, so each "score" is an L1
    reconstruction error computed by ``calculate_LGAN``.

    Args:
        k_t: current balance coefficient weighting the refined-image term of
            the discriminator loss (``train`` clamps it to [0, 1]).
        x: real images fed to the discriminator.
        D_x: discriminator reconstruction of ``x``.
        G_z: refiner output.
        D_G_z: discriminator reconstruction of ``G_z``.
        outputs: targets for the refiner's direct L1 term (``train_step``
            passes the same tensor as ``x``).
        lambda_r: weight of the direct reconstruction term in the refiner loss.
        lambda_k: step size of the ``k_t`` update.
        gamma: equilibrium ratio; ``k_t`` grows while ``gamma * L(x)``
            exceeds ``L(G_z)``.

    Returns:
        ``(G_loss, D_loss, k_tp)`` where ``k_tp`` is the updated, unclamped
        balance coefficient.
    """
    LGAN_x = calculate_LGAN(x, D_x)
    LGAN_gz = calculate_LGAN(G_z, D_G_z)
    LRCN_z = calculate_LGAN(outputs, G_z)

    D_loss = LGAN_x - k_t * LGAN_gz
    G_loss = (1.0 - lambda_r) * LGAN_gz + lambda_r * LRCN_z

    k_tp = k_t + lambda_k * (gamma * LGAN_x - LGAN_gz)

    return G_loss, D_loss, k_tp

In [None]:
# Independent Adam instances so each network keeps its own moment estimates.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [None]:
 eval_input = [syn_img[:NUM_TEST_IMG], syn_img_2[:NUM_TEST_IMG], syn_img_4[:NUM_TEST_IMG], syn_img_8[:NUM_TEST_IMG]]

In [None]:
def train_step(k_t, values):
    """Run one joint update of the refiner and the discriminator.

    Args:
        k_t: balance coefficient forwarded to ``get_loss``.
        values: one batch ``(img_2, img_4, img_8, img)`` as produced by
            ``process_path``.

    Returns:
        The updated balance coefficient from ``get_loss`` (not clamped here).
    """
    half, quarter, eighth, target = values
    refiner_inputs = [target, half, quarter, eighth]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        refined = generator(refiner_inputs, training=True)

        refined_recon = discriminator(refined, training=True)
        real_recon = discriminator(target, training=True)

        gen_loss, disc_loss, k_t = get_loss(k_t, target, real_recon, refined, refined_recon, target)

    # Compute both gradients before applying either update.
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return k_t

In [None]:
def _plot_image_grid(title, batch, cols=4):
    """Show every image of *batch* in one titled figure, *cols* images per row."""
    count = batch.shape[0]
    # Keep the 4x4 layout for small batches but add rows for larger ones;
    # plt.subplot rejects indices beyond rows * cols.
    rows = max(cols, -(-count // cols))

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle(title, fontsize=16)
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(batch[i, :, :])
        plt.axis('off')
    plt.show()


def display_refined_images(model, synthetic_images):
    """Plot the synthetic inputs and the model's refinements as two image grids.

    Args:
        model: refiner, called as ``model(synthetic_images, training=False)``.
        synthetic_images: list of input batches in the refiner's input order.
            Element 0 (the full-resolution batch in ``eval_input``) is what is
            plotted as "Synthetic Images".
    """
    predictions = model(synthetic_images, training=False)

    _plot_image_grid('Synthetic Images', synthetic_images[0])
    _plot_image_grid('Refined Images', predictions)

In [None]:
# Full batches per epoch (a trailing partial batch is skipped).
STEP_SIZE = NUM_IMAGES // BATCH_SIZE

def train(epochs):
    """Train the refiner and discriminator for *epochs* passes over the real images.

    Every 100 steps, and once at the end, the notebook output is replaced
    with the refiner's current results on ``eval_input``.

    Args:
        epochs: number of epochs to run.
    """
    k_t = 0.0
    step = 0

    for epoch in range(epochs):
        start = time.time()

        # Iterate the pipeline once per epoch so shuffle/prefetch state is reused.
        # Creating a fresh iterator every step would rebuild the pipeline and
        # only ever yield the first batch of a new pass.
        for values in train_dataset.take(STEP_SIZE):
            k_t = train_step(k_t, values)
            # Keep the balance coefficient in [0, 1] and as a plain Python float.
            k_t = min(max(float(k_t), 0.0), 1.0)
            step += 1

            if step % 100 == 0:
                display.clear_output(wait=True)
                display_refined_images(generator, eval_input)

        print ('{} seconds for epoch {}/{}'.format(time.time()-start, epoch + 1, epochs))

    display.clear_output(wait=True)
    display_refined_images(generator, eval_input)

In [None]:
%%time
# Run the full training loop; the %%time cell magic reports how long the cell took.
train(EPOCHS)