In [None]:
import os
import random
import tensorflow as tf
from tqdm import tqdm
import dataset_utils


from matplotlib import pyplot as plt
%matplotlib inline  

In [None]:
_RANDOM_SEED = 0
_NUM_VALIDATION = 5000#8550


class JpegImageReader(object):
    """Decodes JPEG-encoded bytes into 3-channel image arrays with a TF1 op.

    The placeholder and decode op are created once, in the current default
    graph, when the reader is constructed; each call then evaluates that op
    through ``sess.run`` with the raw bytes fed into the placeholder.
    """

    def __init__(self):
        # Raw encoded bytes are fed here at run time.
        encoded_bytes = tf.placeholder(dtype=tf.string)
        self._decode_jpeg_data = encoded_bytes
        # channels=3 requests RGB output regardless of the source encoding.
        self._decode_jpeg = tf.image.decode_jpeg(encoded_bytes, channels=3)

    def read_image_dims(self, sess, image_data):
        """Decode ``image_data`` and return its ``(height, width)``."""
        height, width = self.decode_jpeg(sess, image_data).shape[:2]
        return height, width

    def decode_jpeg(self, sess, image_data):
        """Run the decode op on ``image_data``; return an array shaped (H, W, 3).

        Raises:
            AssertionError: If the decoded result is not rank 3 with 3 channels.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = sess.run(self._decode_jpeg, feed_dict=feed)
        # Sanity check on the op's output: rank 3, last axis of size 3.
        assert len(decoded.shape) == 3 and decoded.shape[2] == 3
        return decoded
    
class PngImageReader(object):
    """Decodes PNG-encoded bytes into 3-channel image arrays with a TF1 op.

    The placeholder and decode op are created once, in the current default
    graph, when the reader is constructed. Each call then evaluates that op
    through ``sess.run`` with the raw bytes fed into the placeholder, so the
    reader must be built in the same graph that ``sess`` runs.
    """

    def __init__(self):
        # Builds the op that decodes PNG data; channels=3 requests RGB output.
        self._decode_png_data = tf.placeholder(dtype=tf.string)
        self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Decode ``image_data`` and return its ``(height, width)``."""
        image = self.decode_png(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_png(self, sess, image_data):
        """Run the decode op on ``image_data``; return an array shaped (H, W, 3).

        Raises:
            AssertionError: If the decoded result is not rank 3 with 3 channels.
        """
        image = sess.run(self._decode_png,
                         feed_dict={self._decode_png_data: image_data})
        # Sanity check on the op's output: rank 3, last axis of size 3.
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image

def _get_filenames_and_classes(dataset_dir):
    directories = []
    class_names = []
    for filename in os.listdir(dataset_dir):
        path = os.path.join(dataset_dir, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)

    return photo_filenames, sorted(class_names)

def _dataset_exists(dataset_dir):
    """Return True only when both the train and validation TFRecords exist.

    Checks ``train`` first and stops at the first missing file.
    """
    return all(
        tf.gfile.Exists(_get_dataset_filename(dataset_dir, split))
        for split in ('train', 'validation'))


def _get_dataset_filename(dataset_dir, split_name):
    output_filename = '%s.tfrecord' % (split_name)
    return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Write one split's images and labels to ``<dataset_dir>/<split_name>.tfrecord``.

    Each file is read as raw bytes and decoded once with ``PngImageReader``,
    only to obtain its height and width. The original encoded bytes are what
    is stored, tagged with format ``b'png'``. The integer label is looked up
    from the name of the file's parent directory.

    Args:
        split_name: Either ``'train'`` or ``'validation'``.
        filenames: Image paths to include in this split.
        class_names_to_ids: Mapping from class (parent directory) name to label id.
        dataset_dir: Directory the TFRecord file is written to.

    Raises:
        AssertionError: If ``split_name`` is not one of the known splits.
        KeyError: If a file's parent directory is missing from ``class_names_to_ids``.
    """
    assert split_name in ['train', 'validation']

    output_filename = _get_dataset_filename(dataset_dir, split_name)
    files_written = 0
    # A fresh graph keeps the decode ops from piling up in the default graph
    # across calls.
    with tf.Graph().as_default():
        image_reader = PngImageReader()

        with tf.Session('') as sess:
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                for filename in tqdm(filenames):
                    # The context manager closes each handle immediately
                    # instead of leaving it open until garbage collection.
                    with tf.gfile.FastGFile(filename, 'rb') as image_file:
                        image_data = image_file.read()
                    height, width = image_reader.read_image_dims(sess, image_data)

                    class_name = os.path.basename(os.path.dirname(filename))
                    class_id = class_names_to_ids[class_name]

                    # NOTE(review): every input is assumed to be PNG (fixed
                    # decoder and b'png' tag) — confirm for this dataset.
                    example = dataset_utils.image_to_tfexample(
                        image_data, b'png', height, width, class_id)
                    tfrecord_writer.write(example.SerializeToString())
                    files_written += 1
    print('%d files written for %s' % (files_written, split_name))
                    
def generate_dataset(dataset_dir):
    """Convert per-class image folders under ``dataset_dir`` into TFRecord splits.

    ``dataset_dir`` should contain one subdirectory per class; it is created
    if missing. The collected files are shuffled after seeding ``random`` with
    ``_RANDOM_SEED``. The first ``_NUM_VALIDATION`` files become the
    validation split and the rest the training split. Each split is written
    to ``<dataset_dir>/<split>.tfrecord``, and a labels file is written via
    ``dataset_utils.write_label_file``.

    If both TFRecord files already exist, nothing is re-written. The split is
    still computed and returned, so ``a, b, c = generate_dataset(...)`` keeps
    working when the notebook is re-run; returning ``None`` there would make
    that unpacking raise ``TypeError``.

    NOTE(review): on that path the lists are recomputed from the current
    directory listing. They match the stored records only if the directory
    contents (and listing order) are unchanged since the records were written.

    Args:
        dataset_dir: Root directory with one subdirectory per class.

    Returns:
        ``(training_filenames, validation_filenames, class_names_to_ids)``.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
    class_names_to_ids = {name: idx for idx, name in enumerate(class_names)}

    # Divide into train and validation. Seeding right before the shuffle makes
    # the split repeatable for the same input order.
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_VALIDATION:]
    validation_filenames = photo_filenames[:_NUM_VALIDATION]

    if _dataset_exists(dataset_dir):
        print('Dataset files already exist. Exiting without re-creating them.')
        return training_filenames, validation_filenames, class_names_to_ids

    _convert_dataset('train', training_filenames, class_names_to_ids,
                     dataset_dir)
    _convert_dataset('validation', validation_filenames, class_names_to_ids,
                     dataset_dir)

    # Finally, write the labels file (integer label -> class name).
    labels_to_class_names = dict(enumerate(class_names))
    dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
    return training_filenames, validation_filenames, class_names_to_ids

In [None]:
# Convert the face crops under dataset_dir into train/validation TFRecords and
# keep the resulting file lists and label map for later cells.
dataset_dir = "data/160x160_faces"
(training_filenames,
 validation_filenames,
 class_names_to_ids) = generate_dataset(dataset_dir)