In [None]:
! pip install mtcnn

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
from PIL import Image 
from google.cloud import storage
import os
import io
import time
import math
import cv2 as cv
from mtcnn.mtcnn import MTCNN
# Single MTCNN face detector, created once here and reused by preprocess_image for every image.
face_detector = MTCNN()

In [None]:
try:
    # In Colab, authenticate the notebook user interactively.
    from google.colab import auth
    auth.authenticate_user()
    # None lets storage.Client infer credentials from the environment, so the
    # same `credentials=credentials` argument works in both branches.
    credentials = None

except ModuleNotFoundError:
    # Outside Colab, load a service-account key. The key location comes from the
    # standard GOOGLE_APPLICATION_CREDENTIALS environment variable rather than
    # being hardcoded in the notebook (raises KeyError if it is not set).
    from google.oauth2 import service_account

    credentials = service_account.Credentials.from_service_account_file(
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'])

In [None]:
# Outside Colab, pass the credentials loaded above explicitly:
#   client = storage.Client(project='deepfake-research', credentials=credentials)
client = storage.Client(project='deepfake-research')
# For a quick test run, add max_results=100 to the list_blobs call.
objects = client.list_blobs('celeba-jh', prefix='img_align_celeba/img_align_celeba')
# Use each blob's name directly instead of parsing its repr. Skip "folder"
# placeholder objects (names ending in '/'), which are not images.
image_list = [blob.name for blob in objects if not blob.name.endswith('/')]

In [None]:
# Number of image blobs found under the CelebA prefix.
len(image_list)

In [None]:
# Hyperparameters shared by later cells (preprocess_image reads 'image_dims').
params = dict(
    batch_size=100,
    image_dims=(192, 128),  # (height, width) of each preprocessed image
    noise_dims=100,
    ds_size=162770,         # presumably the number of training images - confirm
    start_epoch=1,
    end_epoch=1,
)

In [None]:
def convert_image_to_tfExample(image):
    """Wrap an encoded image byte string in a ``tf.train.Example``.

    Args:
        image: the encoded image bytes; stored unchanged.

    Returns:
        A ``tf.train.Example`` whose only feature, ``'image'``, is a
        ``BytesList`` holding ``image``.
    """
    image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
    feature_map = tf.train.Features(feature={'image': image_feature})
    return tf.train.Example(features=feature_map)

In [None]:
def preprocess_image(image, params=params, detector=None):
    """Crop an image to its first detected face and resize it for the model.

    The first face box reported by the detector is cropped out. When no face is
    detected, or the box does not overlap the image, the whole image is used
    instead; the affected images are few, so the effect on the model should be
    small. The result is resized with padding to ``params['image_dims']``.

    Args:
        image: the image as a NumPy array, assumed (H, W, 3) RGB - TODO confirm
            for other inputs.
        params: dict whose ``'image_dims'`` entry is the output
            ``(height, width)``. Defaults to the notebook-level ``params``.
        detector: object with a ``detect_faces(image)`` method returning a list
            of dicts whose ``'box'`` is ``[x, y, width, height]``. Defaults to
            the module-level ``face_detector``.

    Returns:
        ``np.ndarray`` of dtype uint8 sized ``params['image_dims']``.
    """
    if detector is None:
        detector = face_detector
    resized_height, resized_width = params['image_dims']

    face_array = image
    faces = detector.detect_faces(image)
    if faces:
        x_start, y_start, x_len, y_len = faces[0]['box']
        # The detector sometimes returns a negative starting x/y. A negative
        # slice start would wrap around in NumPy, so clamp it to 0 (i.e. do not
        # crop from the left/top edge).
        x_start, y_start = max(x_start, 0), max(y_start, 0)
        cropped = image[y_start:(y_start + y_len), x_start:(x_start + x_len)]
        # A degenerate box (zero size or outside the frame) yields an empty
        # array, which cannot be resized; keep the full image in that case.
        if cropped.size:
            face_array = cropped

    # Resize (bilinear by default) and pad to the model's input size.
    face_resized = tf.image.resize_with_pad(face_array,
                                           target_height=resized_height,
                                           target_width=resized_width)
    # resize_with_pad returns floating-point pixels in the input's 0-255 range;
    # cast back to uint8 (truncating) so PIL's Image.fromarray can encode it.
    face_resized = tf.cast(face_resized, tf.uint8)
    return np.array(face_resized)



In [None]:
def _load_and_preprocess_image(img_bucket, blob_name, tmp_image_file):
    """Download one image blob, preprocess it, and return it re-encoded as bytes.

    The preprocessed image is saved to ``tmp_image_file`` (PIL picks the
    encoding from the file extension) and read back as a byte string.

    Raises:
        FileNotFoundError: if ``blob_name`` does not exist in ``img_bucket``.
    """
    blob = img_bucket.get_blob(blob_name)
    if blob is None:
        raise FileNotFoundError(
            'Image blob {!r} not found in bucket {!r}'.format(blob_name, img_bucket.name))
    original_image_string = blob.download_as_string()
    with Image.open(io.BytesIO(original_image_string)) as original_image:
        # resize_with_pad needs a 3-D (H, W, C) array and JPEG cannot store an
        # alpha channel, so normalise grayscale/CMYK inputs to RGB. This is a
        # no-op for RGB images.
        original_image_array = np.array(original_image.convert('RGB'))
    preprocessed_image = Image.fromarray(preprocess_image(original_image_array))
    preprocessed_image.save(tmp_image_file)
    # The `with` block closes the file handle on every iteration.
    with open(tmp_image_file, 'rb') as image_file:
        return image_file.read()


def create_and_upload_tfrecord_file(start_index,
                                    end_index,
                                    TFRecord_object_name,
                                    file_list=image_list,
                                    TFRecord_bucket_name='celeba-ds-jh', #find bucket name for TFRecord files
                                    tmp_TFRecord_file='tmp.tfrecord',
                                    img_bucket_name='celeba-jh',
                                    tmp_image_file='image.jpg',
                                   ):
    """Write images ``file_list[start_index:end_index]`` to one TFRecord file and upload it.

    Each image is downloaded from ``img_bucket_name``, preprocessed with
    ``preprocess_image`` and re-encoded. It is stored as the single ``'image'``
    bytes feature of a ``tf.train.Example``. The finished TFRecord file is
    uploaded to ``TFRecord_bucket_name`` as ``TFRecord_object_name``.

    Args:
        start_index, end_index: slice bounds into ``file_list``.
        TFRecord_object_name: name of the uploaded object.
        file_list: blob names of the source images.
        TFRecord_bucket_name: destination bucket for the TFRecord file.
        tmp_TFRecord_file: local scratch path for the TFRecord file.
        img_bucket_name: bucket holding the source images.
        tmp_image_file: local scratch path for each re-encoded image; its
            extension selects the image format.
    """
    img_bucket = client.get_bucket(img_bucket_name)
    try:
        with tf.io.TFRecordWriter(tmp_TFRecord_file) as writer:
            for blob_name in file_list[start_index:end_index]:
                encoded_image = _load_and_preprocess_image(img_bucket, blob_name, tmp_image_file)
                tf_example = convert_image_to_tfExample(encoded_image)
                writer.write(tf_example.SerializeToString())
        # Upload only after the writer has been closed by the `with` block.
        TFRecord_bucket = client.get_bucket(TFRecord_bucket_name)
        TFRecord_object = TFRecord_bucket.blob(TFRecord_object_name)
        TFRecord_object.upload_from_filename(tmp_TFRecord_file)
    finally:
        # Remove the scratch files even if a download/upload failed. The image
        # file never exists when the slice was empty.
        for scratch_path in (tmp_image_file, tmp_TFRecord_file):
            if os.path.exists(scratch_path):
                os.remove(scratch_path)

        
        


In [None]:
def _tfrecord_index_ranges(number_of_images, num_TFRecord_files):
    """Return ``(start, end)`` slice bounds for each non-empty shard, in order.

    Shards hold ``ceil(number_of_images / num_TFRecord_files)`` images each,
    and the last shard holds the remainder. No empty shards are produced, so
    fewer than ``num_TFRecord_files`` ranges may be returned for small inputs.

    Raises:
        ValueError: if ``num_TFRecord_files`` is less than 1.
    """
    if num_TFRecord_files < 1:
        raise ValueError(
            'num_TFRecord_files must be at least 1, got {}'.format(num_TFRecord_files))
    if number_of_images == 0:
        return []
    images_per_TFRecord_file = math.ceil(number_of_images / num_TFRecord_files)
    return [(start, min(start + images_per_TFRecord_file, number_of_images))
            for start in range(0, number_of_images, images_per_TFRecord_file)]


def split_TFRecords(file_list=image_list,
                    num_TFRecord_files=20
                    ):
    """Shard ``file_list`` into TFRecord files and upload each one.

    Each shard is written and uploaded by ``create_and_upload_tfrecord_file``
    as ``celeba_all_preprocessed.tfrecord_<i>_of<total>`` (``i`` is 1-based).
    ``total`` is the number of shards actually produced. That equals
    ``num_TFRecord_files`` unless the list is too short to fill every shard.
    An empty ``file_list`` uploads nothing.

    Args:
        file_list: blob names of the images to export.
        num_TFRecord_files: requested number of TFRecord shards.

    Raises:
        ValueError: if ``num_TFRecord_files`` is less than 1.
    """
    index_ranges = _tfrecord_index_ranges(len(file_list), num_TFRecord_files)
    for TFR_counter, (start_index, end_index) in enumerate(index_ranges, start=1):
        TFRecord_object_name = 'celeba_all_preprocessed.tfrecord_{}_of{}'.format(
            TFR_counter, len(index_ranges))
        # Forward file_list so the indices computed from it slice the same list.
        create_and_upload_tfrecord_file(start_index, end_index, TFRecord_object_name,
                                        file_list=file_list)
    

In [None]:
# Export every image in image_list: download, preprocess, pack into TFRecord
# shards (20 by default) and upload them to the TFRecord bucket.
split_TFRecords()