In [1]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
img_width = 64
img_height = 64

In [3]:
tfrecord_train = 'train.tfrecord'
tfrecord_test = 'test.tfrecord'
tfrecord_dir = 'tfrecords'

if not os.path.exists(tfrecord_dir):
    os.makedirs(tfrecord_dir)

In [4]:
cur_dir = os.getcwd()
image_dir = os.path.join(cur_dir, '..', 'data', 'korean_food')

valid_exts = ['.jpg', '.png']
print ('{} categories in {}'.format(len(os.listdir(image_dir)), image_dir))

categories = sorted(os.listdir(image_dir))
num_categ = len(categories)
print (categories)

20 categories in c:\Work\tensorflow_training_8th-master\..\data\korean_food
['bibimbap', 'bossam', 'bulgogi', 'daegaejjim', 'dakbokkeumtang', 'doenjang_chigae', 'galbijjim', 'galchijorim', 'ganjang_gejang', 'gimbob', 'gyoza', 'jjajangmyeon', 'kimchi', 'nangmyeon', 'ojingeo_bokkeum', 'ramen', 'samgupsal', 'samgyetang', 'soondae', 'soondubu_jjigae']


In [5]:
imgnames = []
labels = []
for label, category in enumerate(categories):
    filelist = os.listdir(os.path.join(image_dir, category))
    imglist = []
    for f in filelist:
        ext = os.path.splitext(f)[-1]
        if ext.lower() not in valid_exts:
            continue
        img = Image.open(os.path.join(image_dir, category, f))
        img = np.asarray(img)
        #print(img.shape)
        if img.shape[2] != 3:
            #print(f)
            continue
        imglist.append(f)
    imgnames += imglist
    labels += [label]*len(imglist)
    print('{} {} images are found'.format(len(imglist), category))

904 bibimbap images are found
605 bossam images are found
688 bulgogi images are found
622 daegaejjim images are found
631 dakbokkeumtang images are found
629 doenjang_chigae images are found
782 galbijjim images are found
698 galchijorim images are found
687 ganjang_gejang images are found
753 gimbob images are found
899 gyoza images are found
599 jjajangmyeon images are found
875 kimchi images are found
639 nangmyeon images are found
721 ojingeo_bokkeum images are found
911 ramen images are found
670 samgupsal images are found
660 samgyetang images are found
670 soondae images are found
614 soondubu_jjigae images are found


In [6]:
numfiles = len(labels)
idxrand = np.random.permutation(numfiles)
idxtrain = idxrand[:int(0.8*numfiles)]
idxtest = idxrand[int(0.8*numfiles):]
print(len(idxtrain), len(idxtest))

11405 2852


In [7]:
train_tfr_path = os.path.join(cur_dir, tfrecord_dir, tfrecord_train)
test_tfr_path = os.path.join(cur_dir, tfrecord_dir, tfrecord_test)

writer_train = tf.python_io.TFRecordWriter(train_tfr_path)
writer_test = tf.python_io.TFRecordWriter(test_tfr_path)

In [8]:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [9]:
for idx in idxtrain:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])
    
    image = Image.open(img_path)    
    image = image.resize((img_height, img_width))
    
    image = np.asarray(image)    
    image_str = image.tobytes()
    
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_train.write(example.SerializeToString())

In [10]:
for idx in idxtest:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])
    image = Image.open(img_path)    
    image = image.resize((img_height, img_width))
    
    image = np.asarray(image)    
    image_str = image.tobytes()
    
    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_test.write(example.SerializeToString())

In [11]:
writer_train.close()
writer_test.close()