In [1]:
# Refactored from Daniil's blog: TFRecords Guide
# http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

import argparse
import os
import sys

from PIL import Image
import numpy as np
import skimage.io as io
from scipy.misc import imread, imresize
import tensorflow as tf
%matplotlib inline

FLAGS = None

In [2]:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(data_set, name):
    """Converts a dataset to tfrecords."""
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples
    
    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                        (images.shape[0], num_examples))
    rows = images.shape[0]
    cols = images.shape[2]
    depth = images.shape[3]
    
    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image_str = images[index].tostring()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytess_feature(image_raw)
        }))
        writer.write(example.SerializeToString())
    writer.close()
    
def main(_):
    # Load Custom Dat (Caltech101)
    cwd = os.getcwd()
    path = os.path.join(cwd, FLAGS.directory)
    valid_exts = ['.jpg', '.gif', '.png', '.jpeg']
    print('%d categories in %s' % (len(os.listdir(path)), path))
    
    categories = sorted(os.listdir(path))
    num_categ = len(categories)
    
    imgnames = []
    labels = []
    for label, category in enumerate(categories):
        filelist = os.listdir(os.path.join(path, category))
        imglist = []
        for f in filelist:
            ext = os.path.splitext(f)[-1]
            if ext.lower() not in valid_exts:
                continue
            imglist.append(f)
        imgnames += imglist
        labels += [label]*len(imglist)
        print('[%3d/%3d] %d %s images are found' % (label+1, num_categ,
                                                    len(imglist), category))

    numfiles = len(labels)
    ratio = FLAGS.trainset_ratio
    idxrand = np.random.permutation(numfiles)
    idxtrain = idxrand[:int(ratio*numfiles)]
    idxtest = idxrand[int(ratio*numfiles):]
    
    tfrecord_fname_train = 'custom_train.tfrecords'
    tfrecord_fname_test = 'custom_test.tfrecords'
    
    writer_train = tf.python_io.TFRecordWriter(os.path.join(FLAGS.directory, '../', tfrecord_fname_train))
    writer_test = tf.python_io.TFRecordWriter(os.path.join(FLAGS.directory, '../', tfrecord_fname_test))
    
    disp_step = 400
    print('\nWriting \"%s\" start' % tfrecord_fname_train)
    for itr, idx in enumerate(idxtrain):
        label = labels[idx]
        imgpath = os.path.join(path, categories[label], imgnames[idx])
        image = np.array(Image.open(imgpath))
        
        height = image.shape[0]
        width = image.shape[1]
        try:
            depth = image.shape[2]
        except:
            depth = 1
        image_str = image.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_str),
            'label': _int64_feature(label)
        }))
        writer_train.write(example.SerializeToString())
        
        if (itr+1)% disp_step == 0 or (itr+1) == idxtrain.shape[0]:
            print('[%3d/%3d] train data has been written' % (itr+1,
                                                             idxtrain.shape[0]))
    
    print('\nWriting \"%s\" start' % tfrecord_fname_test)
    for itr, idx in enumerate(idxtest):
        label = labels[idx]
        imgpath = os.path.join(path, categories[label], imgnames[idx])
        image = np.array(Image.open(imgpath))
        height = image.shape[0]
        width = image.shape[1]
        try:
            depth = image.shape[2]
        except:
            depth = 1
        image_str = image.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_str),
            'label': _int64_feature(label)
        }))
        writer_test.write(example.SerializeToString())
        
        if (itr+1) % disp_step == 0 or (itr+1) == idxtest.shape[0]:
            print('[%3d/%3d] test data has been written' % (itr+1,
                                                            idxtest.shape[0]))

    writer_train.close()
    writer_test.close()
    print('Done')              

In [3]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='tmp/data/101_ObjectCategories',
        help='Directory to read data files and write the converted result'
    )
    parser.add_argument(
        '--trainset_ratio',
        type=float,
        default=0.7,
        help='Ratio of trainset to whole data'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

101 categories in /home/chang/notebooks/2017/tf-slim/tmp/data/101_ObjectCategories
[  1/101] 435 Faces images are found
[  2/101] 435 Faces_easy images are found
[  3/101] 200 Leopards images are found
[  4/101] 798 Motorbikes images are found
[  5/101] 55 accordion images are found
[  6/101] 800 airplanes images are found
[  7/101] 42 anchor images are found
[  8/101] 42 ant images are found
[  9/101] 47 barrel images are found
[ 10/101] 54 bass images are found
[ 11/101] 46 beaver images are found
[ 12/101] 33 binocular images are found
[ 13/101] 128 bonsai images are found
[ 14/101] 98 brain images are found
[ 15/101] 43 brontosaurus images are found
[ 16/101] 85 buddha images are found
[ 17/101] 91 butterfly images are found
[ 18/101] 50 camera images are found
[ 19/101] 43 cannon images are found
[ 20/101] 123 car_side images are found
[ 21/101] 47 ceiling_fan images are found
[ 22/101] 59 cellphone images are found
[ 23/101] 62 chair images are found
[ 24/101] 107 chandelier imag

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
