In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers 
# NOTE(review): tensorflow_gan is imported twice below (as tgan and tfgan); tf1, pd,
# plt, cv, tgan and tfgan do not appear to be used in the cells shown -- confirm before removing
import tensorflow.compat.v1 as tf1
import tensorflow_gan as tgan
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import time
import matplotlib.pyplot as plt
import os
import cv2 as cv
from mtcnn.mtcnn import MTCNN
# single MTCNN face detector created once and reused by input_function's find_face
face_detector = MTCNN()
import warnings
# NOTE(review): this silences *every* warning for the whole kernel, which can hide
# real problems (e.g. tf.data or Keras shape warnings); consider narrowing the filter
warnings.filterwarnings("ignore")

# display.clear_output() is used by train_gan to refresh the progress printout
from IPython import display

In [None]:
params = dict(
    batch_size=20,          # images per training step
    image_dims=(192, 128),  # (height, width) that real images are resized/padded to
    noise_dims=100,         # length of the generator's input noise vector
    ds_size=162770,         # number of images in the CelebA train split
    start_epoch=1,          # first epoch number
    end_epoch=20,           # last epoch number (inclusive)
)

In [None]:
# Directory holding the CelebA TFRecord files that tfds.load reads (download=False,
# so the files must already be there). Set the CELEBA_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
ds_bucket = os.environ.get('CELEBA_DATA_DIR',
                           '/Users/jeremiahherberg/Downloads/celeba-dataset/img_align_celeba/')

In [None]:
@tf.autograph.experimental.do_not_convert
def input_function(params, mode=None):
    '''
    Build the tf.data pipeline that feeds the GAN.

    args:
        params: dict with 'batch_size', 'image_dims' as (height, width) and
        'noise_dims'; may optionally contain 'cache_file' (path for the dataset
        cache, default '' = in memory)

        mode: 'generator' returns a one-element dataset holding a single batch of
        noise shaped (batch_size, noise_dims); any other value returns the CelebA
        image pipeline

    returns:
        a tf.data.Dataset. Image-pipeline elements are float32 batches shaped
        (batch_size, height, width, 3) with pixel values scaled to [-1, 1]. That
        dataset repeats forever, so callers must bound iteration themselves.
    '''
    batch_size = params['batch_size']
    resized_height, resized_width = params['image_dims'] #s/b (192, 128)
    noise_dims = params['noise_dims'] #this can be an arbitrary number

    if mode == 'generator':
        #ds to generate images: exactly one batch of standard-normal noise
        return (tf.data.Dataset.from_tensors(0)
                .map(lambda _: tf.random.normal([batch_size, noise_dims])))

    def find_face(img):
        '''Crop the first face MTCNN finds in `img` (keeping the whole image if none
        is found), resize-with-pad it to (resized_height, resized_width) and scale
        pixel values to [-1, 1].'''
        #py_function passes an eager tensor; the face detector needs a numpy array
        image_array = img.numpy()
        detections = face_detector.detect_faces(image_array)
        face_array = image_array
        if detections:
            x_start, y_start, x_len, y_len = detections[0]['box']
            x_end, y_end = x_start + x_len, y_start + y_len
            #MTCNN can return boxes that begin slightly outside the frame; a negative
            #slice start would wrap around to the far edge and give a wrong or empty
            #crop, so clamp the start to the image border
            x_start, y_start = max(x_start, 0), max(y_start, 0)
            crop = image_array[y_start:y_end, x_start:x_end]
            #keep the whole image if the box is degenerate (zero-area crop)
            if crop.size:
                face_array = crop
        #resize to the model input size, preserving aspect ratio and padding the rest
        face_resized = tf.image.resize_with_pad(face_array,
                                               target_height=resized_height,
                                               target_width=resized_width)
        #convert tensor values to between -1 and 1 (0 to 255 -> -1 to 1)
        return (tf.cast(face_resized, tf.float32) - 127.5) / 127.5

    def find_face_dot_map_compatible(img):
        image = img['image'] #tfds element is a feature dict; keep only the image
        [image,] = tf.py_function(find_face, [image], [tf.float32])
        #py_function drops static shape information; restore it so batching and
        #the Keras models see fully defined (height, width, 3) images
        image.set_shape([resized_height, resized_width, 3])
        return image

    image_dataset = (tfds.load('celeb_a',
                               split='train',
                               data_dir=ds_bucket,
                               download=False, #tf record files should already be in the ds_bucket directory
                               shuffle_files=True)
                     .map(find_face_dot_map_compatible)
                     #NOTE: the default in-memory cache (cache_file='') of the whole
                     #train split at 192x128x3 float32 needs roughly 48 GB of RAM;
                     #set params['cache_file'] to a path to cache on disk instead
                     .cache(params.get('cache_file', ''))
                     .repeat())

    #TODO: consider .shuffle() before batching -- shuffle_files=True only reorders
    #the input files, not the individual examples within them
    image_dataset = (image_dataset.batch(batch_size, drop_remainder=True)
                     .prefetch(tf.data.experimental.AUTOTUNE))

    return image_dataset
        



In [None]:
def generator_input_function(params):
    '''
    Sample one batch of generator input noise.

    args:
        params: dict with 'batch_size' and 'noise_dims'

    returns:
        a float32 tensor shaped (batch_size, noise_dims) drawn from a standard normal
    '''
    batch_size = params['batch_size']
    noise_dims = params['noise_dims'] #this can be an arbitrary number
    return tf.random.normal([batch_size, noise_dims])

In [None]:
def discriminator_model(input_shape=(192, 128, 3)):
    '''
    Build the GAN discriminator.

    Six 5x5 conv blocks (32-64-128-256-256-512 filters), each followed by 2x2 max
    pooling and a LeakyReLU, then a 512-unit dense layer (LeakyReLU) and a single
    output unit. The output is an unscaled logit (loss_function applies
    BinaryCrossentropy with from_logits=True).

    args:
        input_shape: (height, width, channels) of the input images; the default
        matches the 192x128 RGB crops produced with the default params

    returns:
        an uncompiled tf.keras.Sequential model mapping (batch, *input_shape) to (batch, 1)
    '''
    model = tf.keras.Sequential()
    for block, filters in enumerate((32, 64, 128, 256, 256, 512)):
        #only the first layer declares the model's input shape
        first_layer_kwargs = {'input_shape': input_shape} if block == 0 else {}
        model.add(layers.Conv2D(filters, (5, 5), padding='same', **first_layer_kwargs))
        model.add(layers.MaxPooling2D((2, 2)))
        #without a nonlinearity between them, stacked conv/dense layers collapse
        #towards a single linear map; LeakyReLU is the usual GAN choice
        model.add(layers.LeakyReLU())

    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    #keeps Dense(512) -> Dense(1) from reducing to one linear layer
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(1))

    return model
discriminator_model().summary()

In [None]:
def generator_model(params=params):
    '''
    Build the GAN generator: noise vector -> 192x128 RGB image with values in [-1, 1].

    A dense projection is reshaped to a 3x2x256 feature map and upsampled by six
    stride-2 transposed convolutions (3x2 -> 6x4 -> 12x8 -> 24x16 -> 48x32 ->
    96x64 -> 192x128). Hidden blocks use BatchNorm + LeakyReLU; the output layer
    uses tanh so generated pixels live in the same [-1, 1] range that
    input_function scales real images to.

    args:
        params: dict with 'noise_dims', the length of the input noise vector

    returns:
        an uncompiled tf.keras.Sequential model mapping (batch, noise_dims) to
        (batch, 192, 128, 3)
    '''
    noise_dims = params['noise_dims']
    model = tf.keras.Sequential()
    #3 * 2 * 256 = 1536 units, reshaped into the 3x2 seed feature map
    model.add(layers.Dense(3 * 2 * 256, use_bias=False, input_shape=(noise_dims,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((3, 2, 256)))

    #hidden upsampling blocks, each doubling height and width; bias is redundant
    #before BatchNormalization
    for filters in (512, 256, 256, 128, 64):
        model.add(layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2), padding='same',
                                         use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

    #number of filters on last layer must be equal to 3 (one for each of R, G, B);
    #tanh bounds the output to [-1, 1] to match the real-image scaling
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same',
                                     activation='tanh'))

    return model
generator_model().summary()

In [None]:
#build the networks trained by training_step (generator first, then discriminator)
generator, discriminator = generator_model(), discriminator_model()

In [None]:
def loss_function(real_output, fake_output):
    '''
    Compute the cross-entropy GAN losses from discriminator logits.

    args:
        real_output: discriminator logits for a batch of real images, a tensor
        shaped (batch_size, 1)

        fake_output: discriminator logits for a batch of images produced by the
        generator, a tensor shaped (batch_size, 1)

    returns:
        generator_loss: cross-entropy of the fake logits against "real" labels (1s)

        discriminator_loss: cross-entropy of the fake logits against 0s plus the
        real logits against 1s
    '''
    #NOTE: under a tf.distribute strategy (e.g. TPU) the loss reduction may need
    #adjusting; see https://www.tensorflow.org/tutorials/distribute/custom_training
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    #the generator does well when its fakes are classified as real
    gen_loss = bce(tf.ones_like(fake_output), fake_output)

    #the discriminator should label fakes 0 and real images 1
    fake_term = bce(tf.zeros_like(fake_output), fake_output)
    real_term = bce(tf.ones_like(real_output), real_output)

    return gen_loss, fake_term + real_term

In [None]:
#independent Adam optimizers (learning rate 1e-4) for the two networks
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
@tf.function
def training_step(params, real_images):
    '''
    Run one simultaneous update of the generator and the discriminator.

    args:
        params: dict read by generator_input_function ('batch_size', 'noise_dims')

        real_images: a batch of real images for the discriminator

    returns:
        (generator_loss, discriminator_loss) for this step, so callers can log or
        monitor training; callers that ignore the return value are unaffected
    '''
    noise = generator_input_function(params)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        generator_loss, discriminator_loss = loss_function(real_output, fake_output)

    #each network is updated only from its own loss
    generator_gradients = gen_tape.gradient(generator_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    return generator_loss, discriminator_loss

In [None]:
def train_gan(params):
    '''
    Train the module-level `generator` and `discriminator` on the CelebA pipeline.

    Runs epochs params['start_epoch'] through params['end_epoch'] (inclusive); each
    epoch is ceil(ds_size / batch_size) calls to training_step. Every 10th epoch the
    cell output is cleared and timing stats are printed. Training stops early once
    params.get('max_hours', 4) hours have elapsed.

    args:
        params: dict with 'ds_size', 'start_epoch', 'end_epoch' and 'batch_size'
        (plus the keys input_function and training_step read); optional 'max_hours'
    '''
    ds_size = params['ds_size'] #celeba DS is 162,770 images for training
    start_epoch = params['start_epoch']
    end_epoch = params['end_epoch']
    batch_size = params['batch_size']
    max_hours = params.get('max_hours', 4)
    #integer ceiling division
    steps_per_epoch = -(-ds_size // batch_size)

    #input_function's dataset repeats forever, so looping over it directly would
    #never finish an epoch; pull exactly steps_per_epoch batches per epoch from one
    #persistent iterator instead
    batches = iter(input_function(params))
    start_time = time.time()
    set_start_time = start_time

    for epoch in range(start_epoch, end_epoch + 1):
        for _ in range(steps_per_epoch):
            training_step(params, next(batches))

        end_time = time.time()
        total_time = end_time - start_time

        #display time stats every 10 epochs
        if epoch % 10 == 0:
            set_minutes, set_seconds = divmod(end_time - set_start_time, 60)
            total_hours, remainder = divmod(total_time, 60 * 60)
            total_minutes, total_seconds = divmod(remainder, 60)
            display.clear_output()

            print('Set of 10 epochs, ending in epoch {} has taken {} minutes, {} seconds'.format(
                epoch, int(set_minutes), round(set_seconds, 2)))
            print('Time elapsed through epoch {} is {} hours, {} minutes, {} seconds'.format(
                epoch, int(total_hours), int(total_minutes), round(total_seconds, 2)))
            #the next set of 10 epochs is timed from here
            set_start_time = end_time

        #stop training (not just this epoch) once the time budget is used up
        if total_time >= max_hours * 60 * 60:
            break

        #TODO: periodically save the discriminator and generator via model.save;
        #decide how frequently the models should be saved
        
            
    
    

In [None]:
#run training with the settings from the params cell (can run for hours)
train_gan(params)