In [None]:
!pip install tensorflow-gan mtcnn

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers 
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
# import tensorflow_gan as tfgan
import time
import matplotlib.pyplot as plt
import os
import cv2 as cv

import warnings
warnings.filterwarnings("ignore")

from IPython import display

In [None]:
# Connect to the Colab TPU. The address is read from the COLAB_TPU_ADDR environment
# variable, so this cell raises KeyError when that variable is not set.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
# Distribution strategy used below to create the models and to run the training step
strategy = tf.distribute.experimental.TPUStrategy(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [None]:
from google.cloud import storage
# Authenticate for Google Cloud Storage access.
# Inside Colab the interactive user login is used and `credentials` stays None.
# When the google.colab module cannot be imported, credentials are loaded from a
# service-account key file instead.
try:
    from google.colab import auth
    auth.authenticate_user()
    credentials=None

except ModuleNotFoundError:

    from google.oauth2 import service_account

    credentials = service_account.Credentials.from_service_account_file( #path to the GCS service-account private key file (placeholder value)
        'xx')

In [None]:
# List every object under the 'celeba_all_preprocessed' prefix of the bucket and
# collect its gs:// URI in `tfrecords` (read later by input_function).
client = storage.Client(project='deepfake-research', credentials=credentials)
objects = client.list_blobs('celeba-ds-jh', prefix='celeba_all_preprocessed')
# Read the object name from the blob's `name` attribute instead of parsing the blob's
# string representation, and build the URI with plain string formatting
# (os.path.join is meant for local filesystem paths, not for gs:// URIs).
tfrecords = ['gs://celeba-ds-jh/{}'.format(object_.name) for object_ in objects]

In [None]:
# Run configuration shared by the input pipeline, the models and the training loop
params = dict(
    batch_size=128,         # images per training batch
    image_dims=(192, 128),  # (height, width) of the training images
    noise_dims=100,         # length of the generator's input noise vector
    ds_size=202599,         # images per epoch; used to derive the steps per epoch
    start_epoch=1,          # first epoch of the run (inclusive)
    end_epoch=19,           # last epoch of the run (inclusive)
)

In [None]:
@tf.autograph.experimental.do_not_convert
def input_function(params, mode=None):
    '''
    Builds the batched, endlessly repeating dataset of real training images

    args:
        params: dict, only params['batch_size'] is read here

        mode: unused, kept so that existing callers keep working

    returns:
        a tf.data.Dataset that yields batches of float32 image tensors with values
        scaled to the range -1 to 1; incomplete batches are dropped

    The TFRecord files are taken from the module-level `tfrecords` list.
    Images are not resized here.
    # NOTE(review): presumably the records already hold images sized
    # params['image_dims'] -- confirm against the preprocessing job
    '''
    batch_size = params['batch_size']

    def preprocess_image(serialized_example):
        #decode TFexample record, the JPEG bytes are stored under the 'image' key
        features_dictionary = {
            'image': tf.io.FixedLenFeature([], tf.string)
        }
        features = tf.io.parse_single_example(serialized_example, features_dictionary)
        decoded_image = tf.io.decode_jpeg(features['image'], 3)

        #convert tensor values to between -1 and 1 (0 to 255 -> -1 to 1)
        return (tf.cast(decoded_image, tf.float32) - 127.5) / 127.5

    #NOTE(review): the dataset is cached twice (raw records and decoded images);
    #both stages are kept as in the original pipeline, but one of them is probably
    #enough -- confirm against the memory available on the host
    image_dataset = (tf.data.TFRecordDataset(filenames=tfrecords).
                     cache().
                     map(preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE).
                     cache().
                     repeat())

    image_dataset = (image_dataset.batch(batch_size,
                                        drop_remainder=True,)
                                        .prefetch(tf.data.experimental.AUTOTUNE))

    return image_dataset
        



In [None]:
def generator_input_function(params):
    '''
    Samples one batch of input noise for the generator

    args:
        params: dict providing 'batch_size' and 'noise_dims'
        (noise_dims can be an arbitrary number)

    returns:
        a tensor shaped (batch_size, noise_dims) drawn from tf.random.normal
    '''
    noise_shape = [params['batch_size'], params['noise_dims']]
    return tf.random.normal(noise_shape)

In [None]:
def discriminator_model(input_shape=[192, 128, 3]):
    '''
    Builds the discriminator as a tf.keras.Sequential model

    args:
        input_shape: shape of a single input image, defaults to [192, 128, 3]

    returns:
        the (uncompiled) model: six blocks of 5x5 Conv2D -> 2x2 MaxPooling2D -> LeakyReLU,
        followed by Flatten, Dense(512) and Dense(1). The last layer has no
        activation, so the output is one raw score per image.
    '''
    #number of filters of each convolution block, in order
    conv_filters = (32, 64, 128, 256, 256, 512)

    net = tf.keras.Sequential()
    for block_index, n_filters in enumerate(conv_filters):
        #only the very first layer declares the input shape
        first_layer_kwargs = {'input_shape': input_shape} if block_index == 0 else {}
        net.add(layers.Conv2D(n_filters, (5, 5), padding='same', **first_layer_kwargs))
        net.add(layers.MaxPooling2D((2, 2)))
        net.add(layers.LeakyReLU())

    for head_layer in (layers.Flatten(), layers.Dense(512), layers.Dense(1)):
        net.add(head_layer)

    return net
# discriminator_model().summary()

In [None]:
def generator_model(params=params):
    '''
    Builds the generator as a tf.keras.Sequential model

    args:
        params: dict, params['noise_dims'] sets the length of the input noise vector
        (the default is the module-level params dict as it exists when this
        function is defined)

    returns:
        the (uncompiled) model. The noise vector is projected to a 3 x 2 x 256
        feature map and then upsampled by six stride-2 transposed convolutions,
        ending in a 3 channel output.
    '''
    noise_shape = (params['noise_dims'],)
    #settings shared by every transposed convolution
    upsample = dict(strides=(2, 2), padding='same')

    net = tf.keras.Sequential()
    net.add(layers.Dense(1536, use_bias=False, input_shape=noise_shape))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())

    #1536 = 3 * 2 * 256
    net.add(layers.Reshape((3, 2, 256)))

    net.add(layers.Conv2DTranspose(512, (5, 5), use_bias=False, activation='tanh', **upsample))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())

    for n_filters in (256, 256, 128, 64):
        net.add(layers.Conv2DTranspose(n_filters, (5, 5), **upsample))

    #number of filters on last layer must be equal to 3 (one for each of R, G, B)
    #no tanh here: tanh produced poor results on this layer
    net.add(layers.Conv2DTranspose(3, (5, 5), **upsample))

    return net
# generator_model().summary()

In [None]:
def loss_function(real_output, fake_output):
    '''
    Computes the generator and discriminator losses from the discriminator's
    scores for one batch of real images and one batch of generated images

    args:
        real_output: tensor, discriminator output for a batch of real images

        fake_output: tensor, discriminator output for a batch of images produced
        by the generator

    returns:
        generator_loss: binary cross-entropy of fake_output against all-ones targets

        discriminator_loss: binary cross-entropy of fake_output against (noisy) zeros
        plus binary cross-entropy of real_output against (noisy) ones

    Both outputs are treated as logits and the loss reduction is disabled, so the
    returned losses are not averaged over the batch.
    '''
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                             reduction=tf.keras.losses.Reduction.NONE)

    #the generator wants its images to be scored as real
    generator_loss = bce(tf.ones_like(fake_output), fake_output)

    #a single random offset in [0, 0.1), shared by the whole batch, softens the
    #discriminator targets
    label_noise = tf.random.uniform((1,), 0, 0.1)
    fake_targets = tf.zeros_like(fake_output) + label_noise
    real_targets = tf.ones_like(real_output) - label_noise

    loss_on_fake = bce(fake_targets, fake_output)
    loss_on_real = bce(real_targets, real_output)

    return generator_loss, loss_on_fake + loss_on_real

In [None]:
#create the models and optimizers inside strategy.scope() so that their variables
#are created under the TPU distribution strategy; both optimizers are Adam with a
#learning rate of 1e-4
with strategy.scope():
    generator = generator_model()
    discriminator = discriminator_model()
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [None]:
@tf.function
def _distributed_training_step(params, real_images_):
    '''
    Runs one generator and one discriminator update through strategy.run

    returns:
        (generator_loss, discriminator_loss): the batch-mean losses, averaged
        across the replicas of the strategy
    '''
    generator_input_ = generator_input_function(params)

    def step(generator_input, real_images):
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            fake_images = generator(generator_input, training=True)

            real_output = discriminator(real_images, training=True)
            fake_output = discriminator(fake_images, training=True)

            generator_loss, discriminator_loss = loss_function(real_output, fake_output)

        #the losses are not reduced over the batch; tape.gradient differentiates their sum
        generator_gradients = gen_tape.gradient(generator_loss, generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

        #hand the losses back as return values: appending to a Python list in here
        #would only run while the function is traced and would store symbolic tensors
        return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)

    per_replica_gen_loss, per_replica_disc_loss = strategy.run(step, args=(generator_input_, real_images_))
    avg_generator_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_gen_loss, axis=None)
    avg_discriminator_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_disc_loss, axis=None)
    return avg_generator_loss, avg_discriminator_loss


def training_step(params, real_images_, monitor_loss=False):
    '''
    Performs one GAN training step on a batch of real images

    args:
        params: dict providing 'batch_size' and 'noise_dims' for the generator input

        real_images_: a batch of real images for the discriminator

        monitor_loss: when True, the average generator and discriminator losses of
        this step are appended (as Python floats) to the module-level lists
        gen_loss and disc_loss. This needs eager execution, so only pass True
        when calling from outside of a tf.function.

    returns:
        (generator_loss, discriminator_loss): scalar tensors with the average
        losses of this step
    '''
    avg_generator_loss, avg_discriminator_loss = _distributed_training_step(params, real_images_)

    if monitor_loss:
        gen_loss.append(float(avg_generator_loss))
        disc_loss.append(float(avg_discriminator_loss))

    return avg_generator_loss, avg_discriminator_loss

In [None]:
def display_generated_pic(gen_output_tensor, return_=False):
    '''
    Shows a generated image with matplotlib

    The tensor is shifted from the range -1 to 1 into the range 0 to 1 before it
    is displayed.

    args:
        gen_output_tensor: a single image tensor as produced by the generator

        return_: when truthy, the rescaled tensor is returned

    returns:
        the rescaled tensor if return_ is truthy, otherwise None
    '''
    rescaled_image = (gen_output_tensor + 1) / 2
    plt.imshow(rescaled_image)
    plt.show()
    return rescaled_image if return_ else None

In [None]:
def train_gan(params):
    '''
    Trains the GAN from params['start_epoch'] through params['end_epoch'] (inclusive)

    This is a plain Python driver and deliberately not a tf.function: it relies on
    time.time(), print() and notebook display calls, which would only run once
    (while tracing) inside a tf.function. The compiled work happens in training_step.

    args:
        params: dict providing 'ds_size', 'batch_size', 'noise_dims', 'start_epoch'
        and 'end_epoch'

    Every 100 steps (starting at step 60 of an epoch) timing stats and one
    generated image are displayed. Whenever epoch % 10 == 9, the elapsed time is
    reported and training stops once 4 or more hours have passed.
    '''
    ds_size = params['ds_size']
    start_epoch = params['start_epoch']
    end_epoch = params['end_epoch']
    batch_size = params['batch_size']
    #settings for sampling a single noise vector for the preview image
    params_ = {'noise_dims': params['noise_dims'], 'batch_size': 1}
    #only a Python int is needed here, so no tensor op is involved
    steps_per_epoch = int(np.ceil(ds_size / batch_size))

    real_images = input_function(params)
    start_time = time.time() #start of the whole run
    start_time_ = time.time() #start of the current set of epochs

    for epoch in np.arange(start_epoch, end_epoch + 1):
        counter = 0
        tme_start = time.time()
        for image_batch in real_images:
            start = time.time()
            training_step(params, image_batch)
            end = time.time()
            tme = end - start
            counter += 1
            if counter % 100 == 60:
                #every 100 steps, starting at step 60, display some stats and a generated image
                display.clear_output(wait=True)
                print('step {} took {} seconds'.format(counter, tme))
                print('total time: {} // {} per step'. format ((end - tme_start), ((end- tme_start) / counter)))
                gen_input = generator_input_function(params_)
                gen_output = generator(gen_input)
                _ = display_generated_pic(gen_output[0])
                print(params)

            #end loop (finish the epoch) once the appropriate number of steps have been completed
            if counter >= steps_per_epoch:
                break

        #display time stats every 10 epochs
        if epoch % 10 == 9:
            end_time = time.time()
            set_time = end_time - start_time_
            set_minutes = int(set_time / 60)
            set_seconds = round(set_time % 60, 2)
            total_time = end_time - start_time
            total_hours = int(total_time / (60 * 60))
            #seconds left after the full hours, converted to whole minutes
            total_minutes = int(total_time % (60 * 60) / 60)
            total_seconds = round(total_time % 60, 2)
            display.clear_output(wait=True)

            print('Set of 10 epochs, ending in epoch {} has taken {} minutes, {} seconds'.format(epoch,
                                                                                                 set_minutes,
                                                                                                 set_seconds))
            print('Time elapsed through epoch {} is {} hours, {} minutes, {} seconds'.format(epoch,
                                                                                              total_hours,
                                                                                              total_minutes,
                                                                                              total_seconds))
            #restart the timer so that the next report only covers the next set of epochs
            start_time_ = end_time

            #stop training after 4 hours
            if total_hours >= 4:
                break

        #TODO: save discriminator and generator via model.save
        #(decide how frequently the models should be saved, e.g. every 50 epochs)
        
            
    
    

In [None]:
#unpack the run configuration into module-level names used by the training loop
start_epoch, end_epoch = params['start_epoch'], params['end_epoch']
ds_size, batch_size = params['ds_size'], params['batch_size']

#settings for sampling a single noise vector when previewing the generator
params_ = {'noise_dims': params['noise_dims'], 'batch_size': 1}

#number of batches needed to cover ds_size images, rounded up
steps_per_epoch = int(tf.math.ceil(ds_size / batch_size))

In [None]:
#build the batched, repeating dataset of real images that the training loop iterates over
real_images = input_function(params)

In [None]:
#training code
#gen_loss / disc_loss hold the monitored losses; training_step is expected to append
#to them when it is called with monitor_loss=True
gen_loss = []
disc_loss = []
step = [] #global step number of every monitored step (x axis of the loss plot)
global_step = 0 #steps completed over all epochs, never reset
for epoch in range(start_epoch, end_epoch + 1):
    step_counter = 0 #steps completed in the current epoch
    tme_start = time.time()
    for image_batch in real_images:
        step_counter += 1
        global_step += 1
        #monitor every 100 steps, starting at step 60 of the epoch; the monitored
        #step is the regular training step, the batch is not trained on twice
        monitor = step_counter % 100 == 60
        training_step(params, image_batch, monitor_loss=monitor)
        end = time.time()
        if monitor:
            step.append(global_step)
            display.clear_output(wait=True)
            print('step {}, epoch {}'.format(step_counter, epoch))
            print('total time: {} seconds // {} per step'. format ((end - tme_start), ((end- tme_start) / step_counter)))
            gen_input = generator_input_function(params_)
            gen_output = generator(gen_input)
            _ = display_generated_pic(gen_output[0])
            print(params)
            #loss curves over the global step
            plt.plot(step, gen_loss, label='gen')
            plt.plot(step, disc_loss, label='disc')
            plt.ylabel('loss')
            plt.xlabel('step')
            plt.legend()
            plt.show()

        #the dataset repeats forever, so the epoch is ended by hand after
        #exactly steps_per_epoch steps
        if step_counter >= steps_per_epoch:
            break #end epoch

    

In [None]:
#Google Cloud Storage destinations for the trained models
#gan_model_path is where the generator is saved, disc_model_path the discriminator
gan_model_path = 'gs://jh-gan-testing/gan_1'
disc_model_path = 'gs://jh-gan-testing/disc_1'

In [None]:
#persist both trained models to their Google Cloud Storage paths
generator.save(gan_model_path)
discriminator.save(disc_model_path)