In [1]:
import os
import numpy as np
from scipy.misc import imread, imresize
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
# Edge length, in pixels, that every image is resized to before it is serialized.
img_height = img_width = 64

# Output location: both record files are written inside `tfrecord_dir`.
tfrecord_dir = 'tfrecords'
tfrecord_train = 'train.tfrecord'
tfrecord_test = 'test.tfrecord'

In [3]:
# Image root: an `images` folder that sits next to the current working directory.
cur_dir = os.getcwd()
image_dir = os.path.join(cur_dir, '..', 'images')

# File extensions (lower-case, with leading dot) that are accepted as images.
valid_exts = ['.jpg', '.gif', '.png', '.jpeg']

# Every sub-directory of `image_dir` is one category. Plain files that happen to
# live in the root (e.g. Thumbs.db, .DS_Store) are skipped: they are not
# categories, and listing their "contents" later would raise an error.
# Sorting keeps the category order stable across runs and file systems.
categories = sorted(
    entry for entry in os.listdir(image_dir)
    if os.path.isdir(os.path.join(image_dir, entry))
)
num_categ = len(categories)

print('%d categories in %s' % (num_categ, image_dir))
print(categories)

2 categories in c:\Work\FC_TF_course\..\images
['cats', 'dogs']


In [4]:
# Collect every image file name together with its integer class label.
# `imgnames[i]` and `labels[i]` describe the same sample; the label is the
# index of the category inside `categories`.
imgnames = []
labels = []
for label, category in enumerate(categories):
    # os.listdir returns entries in arbitrary order, so sort them to make the
    # sample order (and therefore any index-based split) reproducible.
    filelist = sorted(os.listdir(os.path.join(image_dir, category)))

    # Keep only files whose extension (case-insensitive) is a known image type.
    imglist = [f for f in filelist
               if os.path.splitext(f)[-1].lower() in valid_exts]

    imgnames.extend(imglist)
    labels.extend([label] * len(imglist))
    print('{} {} images are found'.format(len(imglist), category))

100 cats images are found
100 dogs images are found


In [5]:
# Shuffle all sample indices and split them into a train and a test part.
# A dedicated, seeded generator makes the split identical on every
# "Restart & Run All"; change `split_seed` to draw a different split.
split_seed = 0
train_ratio = 0.8  # fraction of the samples that goes into the training set

numfiles = len(labels)
idxrand = np.random.RandomState(split_seed).permutation(numfiles)

# Compute the cut position once so both slices are guaranteed to agree.
numtrain = int(train_ratio * numfiles)
idxtrain = idxrand[:numtrain]
idxtest = idxrand[numtrain:]

In [6]:
# Make sure the output directory exists before opening the writers; the
# writers do not create missing parent directories themselves.
tfrecord_path = os.path.join(cur_dir, tfrecord_dir)
if not os.path.isdir(tfrecord_path):
    os.makedirs(tfrecord_path)

# One writer per split. Both are closed at the end of the notebook.
writer_train = tf.python_io.TFRecordWriter(os.path.join(tfrecord_path, tfrecord_train))
writer_test  = tf.python_io.TFRecordWriter(os.path.join(tfrecord_path, tfrecord_test))

In [7]:
def _int64_feature(value):
    """Wrap a single integer `value` in a `tf.train.Feature` backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _bytes_feature(value):
    """Wrap a single bytes `value` in a `tf.train.Feature` backed by a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

In [8]:
# Serialize every training sample as one tf.train.Example with two features:
# 'image' (raw pixel bytes of the resized image) and 'label' (category index).
for idx in idxtrain:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])

    # Force 3-channel RGB on load. Without `mode`, grayscale JPEGs, palette
    # GIFs and RGBA PNGs come back with a different number of channels, so the
    # stored byte strings would not all have the same length.
    # NOTE(review): scipy.misc.imread/imresize are deprecated and absent from
    # recent SciPy releases; this notebook needs a SciPy version that has them.
    image = imread(img_path, mode='RGB')

    image = imresize(image, [img_height, img_width])
    # tobytes() returns the same raw buffer as the deprecated tostring() alias.
    image_str = image.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_train.write(example.SerializeToString())

In [9]:
# Serialize every test sample as one tf.train.Example with two features:
# 'image' (raw pixel bytes of the resized image) and 'label' (category index).
for idx in idxtest:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])

    # Force 3-channel RGB on load. Without `mode`, grayscale JPEGs, palette
    # GIFs and RGBA PNGs come back with a different number of channels, so the
    # stored byte strings would not all have the same length.
    # NOTE(review): scipy.misc.imread/imresize are deprecated and absent from
    # recent SciPy releases; this notebook needs a SciPy version that has them.
    image = imread(img_path, mode='RGB')

    image = imresize(image, [img_height, img_width])
    # tobytes() returns the same raw buffer as the deprecated tostring() alias.
    image_str = image.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_test.write(example.SerializeToString())

In [10]:
# Close both record writers, train first and then test.
for writer in (writer_train, writer_test):
    writer.close()