In [30]:
import os
import time
import tensorflow as tf
import numpy as np
import random
import shutil
from PIL import Image

In [26]:
PHOTO_FOLDER_PATH = "../sketchy_database/256x256/photo/tx_000100000000/"

SKETCH_FOLDER_PATH = "../sketchy_database/256x256/sketch/tx_000100000000/"

OUTPUT_FOLDER_PATH = "../sketchy_database_photo_pairs/"

# The percentage of all photo-sketch image pairs allocated to training (remainder is used for testing).
TRAIN_TEST_SPLIT = 0.95

In [4]:
def get_category_names():
    directories = [x[1] for x in os.walk(PHOTO_FOLDER_PATH)]
    non_empty_dirs = [x for x in directories if x]
    return [item for subitem in non_empty_dirs for item in subitem]

In [5]:
def load_images(photo_file_path, sketch_file_path):
    photo = tf.io.read_file(photo_file_path)
    photo = tf.image.decode_jpeg(photo)
    
    sketch = tf.io.read_file(sketch_file_path)
    sketch = tf.image.decode_png(sketch)
    
    return photo, sketch

In [6]:
def combine_images(photo, sketch):
    # remove alpha channel
    if photo.shape[2] == 4:
        photo = photo[:,:,:3]
    
    if sketch.shape[2] == 4:
        sketch = sketch[:,:,:3]

    return np.concatenate([photo, sketch], axis=1)

In [7]:
def save_combined_image(image, name):
    if not os.path.exists(OUTPUT_FOLDER_PATH):
        os.mkdir(OUTPUT_FOLDER_PATH)
        
    image = Image.fromarray(image)
    image.save(OUTPUT_FOLDER_PATH + name + ".jpg", "JPEG")

In [8]:
category_count = 1
combined_image_count = 1
start_time = time.time()

category_names = get_category_names()

# Iterate across each of the category names (e.g. airplane, alarm_clock, ant, etc.)
for category_name in category_names:
    
    print("[" + str(category_count) + "/" + str(len(category_names)) + "] Combining " + category_name + " photos with their sketch pairs...")
    category_count += 1
    
    # Create a list containing all the file names of photos from the current category.
    photo_file_names = os.listdir(PHOTO_FOLDER_PATH + category_name)

    # Create a list containing all the file names of sketches from the current category.
    sketch_file_names = os.listdir(SKETCH_FOLDER_PATH + category_name)

    # Iterate across each of the photo files.
    for photo_file_name in photo_file_names:

        # Define the function that will be used to determine if a sketch file is a pair to the current photo file.
        def filterFunction(sketch_file_name):
            
            # Remove file extension from the photo file name.
            photo_file_name_without_extension = photo_file_name[:-4]
            
            # Remove the file extension and pair number from the end of the sketch file name.
            sketch_file_name_without_pair_number = sketch_file_name[:-4].rpartition('-')[0]
            
            # Check whether the photo and sketch are pairs based on the absolute name of the files.
            if sketch_file_name_without_pair_number == photo_file_name_without_extension:
                return True
            else:
                return False

        # Use the filter function to create a new list containing the file names of sketches that are a pair to the current photo file.
        sketch_pair_file_names = list(filter(filterFunction, sketch_file_names))

        for sketch_file_name in sketch_pair_file_names:
            
            # Form the complete file paths of the photo and sketch images.
            photo_file_path = PHOTO_FOLDER_PATH + category_name + "/" + photo_file_name
            sketch_file_path = SKETCH_FOLDER_PATH + category_name + "/" + sketch_file_name
            
            # Load both photo and sketch images using their file paths.
            photo, sketch = load_images(photo_file_path, sketch_file_path)
            
            # Create a combined image where the photo and sketch images are side by side.
            combined_image = combine_images(photo, sketch)
            
            # Remove file extension from the sketch file name.
            sketch_file_name_without_extension = sketch_file_name[:-4]
            
            # Save the combined image to the specified output folder using the original name of the sketch image file.
            save_combined_image(combined_image, category_name + "-" + sketch_file_name_without_extension)
            
            combined_image_count += 1
            
print("Time taken to combine all images: " + str((time.time() - start_time) / 60) + " seconds.")            

SyntaxError: invalid syntax (<ipython-input-8-9b976cd33af8>, line 60)

### Split data into training and test sub folders

In [34]:
combined_image_file_names = os.listdir(OUTPUT_FOLDER_PATH)
random.shuffle(combined_image_file_names)
total_images = len(combined_image_file_names)

number_of_training_images = round(TRAIN_TEST_SPLIT * total_images)
number_of_test_images = total_images - number_of_training_images
print("Splitting a total of " + str(total_images) + " images into " + str(number_of_training_images) + " training images and " + str(number_of_test_images) + " test images.")

train_file_names = combined_image_file_names[:number_of_training_images]
test_file_names = combined_image_file_names[number_of_training_images:]

if not os.path.exists(OUTPUT_FOLDER_PATH + 'train/'):
        os.mkdir(OUTPUT_FOLDER_PATH + 'train/')
        
if not os.path.exists(OUTPUT_FOLDER_PATH + 'test/'):
        os.mkdir(OUTPUT_FOLDER_PATH + 'test/')

for train_file_name in train_file_names:
    shutil.move(OUTPUT_FOLDER_PATH + train_file_name, OUTPUT_FOLDER_PATH + 'train/' + train_file_name)
    
for test_file_name in test_file_names:
    shutil.move(OUTPUT_FOLDER_PATH + test_file_name, OUTPUT_FOLDER_PATH + 'test/' + test_file_name)

Splitting a total of 75481 images into 71707 training images and 3774 test images.
