In [1]:
# Refactored from Daniil's blog: TFRecords Guide
# http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

import argparse
import os
import sys
from six.moves import urllib
import tarfile

from PIL import Image
import numpy as np
import skimage.io as io
from scipy.misc import imread, imresize
import tensorflow as tf
%matplotlib inline

IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
FLAGS = None

### Function for downloading a dataset

In [2]:
# from https://github.com/tensorflow/models/blob/master/slim/datasets/dataset_utils.py
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Downloads the `tarball_url` and uncompresses it locally.
    Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % 
                         (filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)

### Auxiliary functions and main function for converting

In [3]:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    
def main(_):
    # Download dataset (Caltech101)
    if not os.path.isdir(FLAGS.directory):
        parentdir = os.path.join(*FLAGS.directory.split('/')[:-1])
        os.makedirs(parentdir)
        download_and_uncompress_tarball(FLAGS.fileurl, parentdir)
        
    
    # Load Custom Dat (Caltech101)
    path = FLAGS.directory
    valid_exts = ['.jpg', '.gif', '.png', '.jpeg']
    print('%d categories in %s' % (len(os.listdir(path)), path))
    
    categories = sorted(os.listdir(path))
    num_categ = len(categories)
    
    imgnames = []
    labels = []
    for label, category in enumerate(categories):
        filelist = os.listdir(os.path.join(path, category))
        imglist = []
        for f in filelist:
            ext = os.path.splitext(f)[-1]
            if ext.lower() not in valid_exts:
                continue
            imglist.append(f)
        imgnames += imglist
        labels += [label]*len(imglist)
        print('[%3d/%3d] %d %s images are found' % (label+1, num_categ,
                                                    len(imglist), category))

    numfiles = len(labels)
    ratio = FLAGS.trainset_ratio
    idxrand = np.random.permutation(numfiles)
    idxtrain = idxrand[:int(ratio*numfiles)]
    idxtest = idxrand[int(ratio*numfiles):]
    
    tfrecord_fname_train = 'custom_train.tfrecords'
    tfrecord_fname_test = 'custom_test.tfrecords'
    
    writer_train = tf.python_io.TFRecordWriter(os.path.join(FLAGS.directory, '../', tfrecord_fname_train))
    writer_test = tf.python_io.TFRecordWriter(os.path.join(FLAGS.directory, '../', tfrecord_fname_test))
    
    disp_step = 400
    print('\nWriting \"%s\" start' % tfrecord_fname_train)
    for itr, idx in enumerate(idxtrain):
        label = labels[idx]
        imgpath = os.path.join(path, categories[label], imgnames[idx])
        image = np.array(Image.open(imgpath))
        
        try:
            depth = image.shape[2]
        except:
            image = np.tile(np.expand_dims(image, axis=2), [1,1,3])
            depth = image.shape[2]
            
        image = imresize(image, [IMAGE_HEIGHT, IMAGE_WIDTH])
        image_str = image.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(IMAGE_HEIGHT),
            'width': _int64_feature(IMAGE_WIDTH),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_str),
            'label': _int64_feature(label)
        }))
        writer_train.write(example.SerializeToString())
        
        if (itr+1)% disp_step == 0 or (itr+1) == idxtrain.shape[0]:
            print('[%3d/%3d] train data has been written' % (itr+1,
                                                             idxtrain.shape[0]))
    
    print('\nWriting \"%s\" start' % tfrecord_fname_test)
    for itr, idx in enumerate(idxtest):
        label = labels[idx]
        imgpath = os.path.join(path, categories[label], imgnames[idx])
        image = np.array(Image.open(imgpath))

        try:
            depth = image.shape[2]
        except:
            image = np.tile(np.expand_dims(image, axis=2), [1,1,3])
            depth = image.shape[2]
            
        image = imresize(image, [IMAGE_HEIGHT, IMAGE_WIDTH])
        image_str = image.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(IMAGE_HEIGHT),
            'width': _int64_feature(IMAGE_WIDTH),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_str),
            'label': _int64_feature(label)
        }))
        writer_test.write(example.SerializeToString())
        
        if (itr+1) % disp_step == 0 or (itr+1) == idxtest.shape[0]:
            print('[%3d/%3d] test data has been written' % (itr+1,
                                                            idxtest.shape[0]))

    writer_train.close()
    writer_test.close()
    print('Done')              

In [4]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--fileurl',
        type=str,
        default='http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz',
        help='Dataset url'
    )
    parser.add_argument(
        '--directory',
        type=str,
        default='tmp/101_ObjectCategories',
        help='Directory to read data files and write the converted result'
    )
    parser.add_argument(
        '--trainset_ratio',
        type=float,
        default=0.7,
        help='Ratio of trainset to whole data'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

>> Downloading 101_ObjectCategories.tar.gz 100.0%
Successfully downloaded 101_ObjectCategories.tar.gz 131740031 bytes.
102 categories in tmp/101_ObjectCategories
[  1/102] 467 BACKGROUND_Google images are found
[  2/102] 435 Faces images are found
[  3/102] 435 Faces_easy images are found
[  4/102] 200 Leopards images are found
[  5/102] 798 Motorbikes images are found
[  6/102] 55 accordion images are found
[  7/102] 800 airplanes images are found
[  8/102] 42 anchor images are found
[  9/102] 42 ant images are found
[ 10/102] 47 barrel images are found
[ 11/102] 54 bass images are found
[ 12/102] 46 beaver images are found
[ 13/102] 33 binocular images are found
[ 14/102] 128 bonsai images are found
[ 15/102] 98 brain images are found
[ 16/102] 43 brontosaurus images are found
[ 17/102] 85 buddha images are found
[ 18/102] 91 butterfly images are found
[ 19/102] 50 camera images are found
[ 20/102] 43 cannon images are found
[ 21/102] 123 car_side images are found
[ 22/102] 47 ceilin

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
