In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers 
import tensorflow.compat.v1 as tf1
import tensorflow_gan as tgan
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import matplotlib.pyplot as plt
import os
import cv2 as cv
from mtcnn.mtcnn import MTCNN
# One shared MTCNN detector, constructed once when this cell runs and reused by
# find_face (inside input_function) for every image in the pipeline.
face_detector = MTCNN()
import warnings
# NOTE(review): this silences every warning for the whole kernel session
# (including TensorFlow deprecation notices) -- consider narrowing it with a
# `category=` or `message=` filter so genuine problems stay visible.
warnings.filterwarnings("ignore")



In [None]:
# Directory that holds the pre-downloaded CelebA TFRecord files (passed to
# tfds.load as data_dir). Set the CELEBA_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
ds_bucket = os.environ.get(
    'CELEBA_DATA_DIR',
    '/Users/jeremiahherberg/Downloads/celeba-dataset/img_align_celeba/')

In [None]:
@tf.autograph.experimental.do_not_convert
def input_function(mode, params):
    """Build the tf.data input pipeline for the CelebA GAN estimator.

    Args:
        mode: a ``tf.estimator.ModeKeys`` value. TRAIN reads the 'train' split
            and shuffles it; PREDICT returns noise only; any other mode reads
            the 'test' split without shuffling.
        params: dict with keys
            'batch_size' -- number of examples per batch,
            'image_dims' -- (height, width) each face crop is padded/resized to,
            'noise_dims' -- length of each generator noise vector.

    Returns:
        In PREDICT mode, a dataset yielding a single float32 noise tensor of
        shape (batch_size, noise_dims). Otherwise a zipped dataset of
        (noise_batch, image_batch) pairs, where image_batch has static shape
        (batch_size, height, width, 3) and values scaled to [-1, 1].
    """
    batch_size = params['batch_size']
    # NOTE(review): confirm TPUEstimator passes 'image_dims' through params.
    resized_height, resized_width = params['image_dims']
    # Arbitrary size; presumably must match the generator's input -- confirm.
    noise_dims = params['noise_dims']

    split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
    shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
    just_noise = (mode == tf.estimator.ModeKeys.PREDICT)

    # Generator inputs: one noise batch per element; a single element for
    # PREDICT, endless otherwise so it can be zipped with the image stream.
    noise_dataset = (tf.data.Dataset.from_tensors(0)
                     .map(lambda _: tf.random.normal([batch_size, noise_dims]))
                     .repeat(1 if just_noise else None))

    if just_noise:
        return noise_dataset

    def find_face(img):
        """Crop the first MTCNN-detected face from ``img``, resize and rescale it.

        Runs eagerly inside ``tf.py_function``, so ``img`` is an eager tensor.
        If no face is detected (or the detected box is empty after clipping),
        the whole image is used instead; this should affect only a small
        number of CelebA images.
        """
        image_array = img.numpy()
        faces = face_detector.detect_faces(image_array)
        face_array = image_array
        if faces:
            x_start, y_start, x_len, y_len = faces[0]['box']
            x_end, y_end = x_start + x_len, y_start + y_len
            # MTCNN can report boxes that begin slightly outside the image. A
            # negative start index would wrap around in the slice below and
            # usually yield an empty crop, so clip it to the image edge.
            x_start, y_start = max(x_start, 0), max(y_start, 0)
            cropped = image_array[y_start:y_end, x_start:x_end]
            if cropped.size > 0:
                face_array = cropped
        # Pad/resize to the model's input size without distorting aspect ratio.
        face_resized = tf.image.resize_with_pad(face_array,
                                                target_height=resized_height,
                                                target_width=resized_width)
        # Scale pixel values from [0, 255] to [-1, 1].
        return (tf.cast(face_resized, tf.float32) - 127.5) / 127.5

    def find_face_dot_map_compatible(img):
        """Dataset.map wrapper: run find_face on the example's 'image' feature."""
        [image] = tf.py_function(find_face, [img['image']], [tf.float32])
        # py_function discards static shape information; restore it so the
        # batched output has a fully defined shape (required on TPU).
        image.set_shape([resized_height, resized_width, 3])
        return image

    image_dataset = (tfds.load('celeb_a',
                               split=split,
                               data_dir=ds_bucket,
                               download=False,  # TFRecord files must already be in ds_bucket
                               shuffle_files=True)
                     .map(find_face_dot_map_compatible)
                     # Cache the processed crops so face detection runs only on
                     # the first pass over the data.
                     .cache()
                     .repeat())

    if shuffle:
        # NOTE(review): the buffer must be filled (205k face detections) before
        # the first batch is produced, and the in-memory cache keeps every
        # float32 crop -- confirm this fits the target machine.
        image_dataset = image_dataset.shuffle(buffer_size=205000,
                                              reshuffle_each_iteration=True)

    image_dataset = (image_dataset
                     .batch(batch_size, drop_remainder=True)
                     .prefetch(tf.data.experimental.AUTOTUNE))

    return tf.data.Dataset.zip((noise_dataset, image_dataset))
        



In [None]:
def discriminator_model(input_shape=(192, 128, 3)):
    """Build the CNN discriminator.

    Six stages of Conv2D(5x5, 'same') -> LeakyReLU -> MaxPooling2D(2x2) halve
    the spatial size each time (192x128 -> 3x2 for the default input). They are
    followed by a LeakyReLU-activated Dense(512) and a single-unit Dense output
    with no activation, i.e. a raw real/fake logit.

    Args:
        input_shape: (height, width, channels) of the images fed to the model.
            Accepts any sequence; the default is a tuple so it cannot be
            mutated between calls.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=tuple(input_shape)))

    for filters in (32, 64, 128, 256, 256, 512):
        model.add(layers.Conv2D(filters, (5, 5), padding='same'))
        # Without a non-linearity the stacked convolutions collapse towards a
        # linear model; LeakyReLU is the usual choice for GAN discriminators.
        model.add(layers.LeakyReLU())
        model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU())
    # Logit output: no activation, so sigmoid/cross-entropy is applied by the loss.
    model.add(layers.Dense(1))

    return model
discriminator_model().summary()

In [None]:
def generator_model(input_shape=(1028,)):
    """Build the transposed-convolution generator.

    A bias-free Dense layer projects the noise vector onto a 3x2x256 seed.
    Six stride-2 Conv2DTranspose layers then upsample it to 192x128
    (3 * 2**6 by 2 * 2**6), which matches the discriminator's default input.
    The hidden upsampling layers use BatchNormalization + LeakyReLU. The output
    layer has 3 channels and a tanh activation, so pixels lie in [-1, 1], the
    same range that input_function scales real images to.

    Args:
        input_shape: shape of one noise vector, i.e. (noise_dims,).
            NOTE(review): presumably must equal params['noise_dims'] used by
            input_function -- confirm.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model producing (192, 128, 3) images.
    """
    seed_shape = (3, 2, 256)
    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=tuple(input_shape)))

    # The Dense width must equal the element count of seed_shape, otherwise the
    # Reshape below cannot be built.
    model.add(layers.Dense(seed_shape[0] * seed_shape[1] * seed_shape[2],
                           use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape(seed_shape))

    # Hidden upsampling stages: each doubles height and width.
    for filters in (512, 256, 256, 128, 64):
        model.add(layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2),
                                         padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())

    # Output stage: the sixth doubling, 3 RGB channels squashed to [-1, 1].
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same',
                                     activation='tanh'))

    return model
generator_model().summary()