https://drive.google.com/file/d/1zUUeT9-vFgmyXCJfZALE5FoYQRrAZgHp/view?usp=sharing

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Target size of every stored image; both sides are 64.
img_width, img_height = 64, 64

In [None]:
# File names of the two TFRecord outputs and the directory that holds them.
tfrecord_train = 'train.tfrecord'
tfrecord_test = 'test.tfrecord'
tfrecord_dir = 'tfrecords'

# exist_ok avoids the check-then-create race and makes re-running the cell
# a no-op when the directory is already there.
os.makedirs(tfrecord_dir, exist_ok=True)

In [None]:
# The dataset is expected in ./korean_food_small, one sub-directory per category.
cur_dir = os.getcwd()
image_dir = os.path.join(cur_dir, 'korean_food_small')

# Lower-case file extensions that are accepted as images.
valid_exts = ['.jpg', '.png']

# Only sub-directories are categories: a stray file in image_dir (for example
# .DS_Store) would otherwise become a bogus label and break the later
# os.listdir() call on it. Sorting keeps the label indices stable across runs.
categories = sorted(
    entry for entry in os.listdir(image_dir)
    if os.path.isdir(os.path.join(image_dir, entry))
)
num_categ = len(categories)
print('{} categories in {}'.format(num_categ, image_dir))
print(categories)

In [None]:
# Collect the usable image file names and their integer labels.
# imgnames[i] is a file name inside categories[labels[i]].
imgnames = []
labels = []
for label, category in enumerate(categories):
    category_dir = os.path.join(image_dir, category)
    imglist = []
    for f in os.listdir(category_dir):
        ext = os.path.splitext(f)[-1]
        if ext.lower() not in valid_exts:
            continue
        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(os.path.join(category_dir, f)) as img:
            pixels = np.asarray(img)
        # Keep 3-channel images only. Grayscale/palette images decode to a
        # 2-D array, so the rank must be checked before indexing shape[2].
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            continue
        imglist.append(f)
    imgnames += imglist
    labels += [label] * len(imglist)
    print('{} {} images are found'.format(len(imglist), category))

In [None]:
# Shuffle the sample indices, then cut them 80 % / 20 % into train and test.
numfiles = len(labels)
idxrand = np.random.permutation(numfiles)
split_at = int(0.8*numfiles)
idxtrain, idxtest = idxrand[:split_at], idxrand[split_at:]
print(len(idxtrain), len(idxtest))

In [None]:
# Absolute paths of the two TFRecord files to be written.
train_tfr_path = os.path.join(cur_dir, tfrecord_dir, tfrecord_train)
test_tfr_path = os.path.join(cur_dir, tfrecord_dir, tfrecord_test)

# tf.python_io is not available in TensorFlow 2.x; prefer tf.io.TFRecordWriter
# and fall back to the legacy name on old releases that lack it.
try:
    _TFRecordWriter = tf.io.TFRecordWriter
except AttributeError:
    _TFRecordWriter = tf.python_io.TFRecordWriter

# The writers stay open across the following cells and are closed at the end.
writer_train = _TFRecordWriter(train_tfr_path)
writer_test = _TFRecordWriter(test_tfr_path)

In [None]:
def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose ``int64_list`` holds the single ``value``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose ``bytes_list`` holds the single ``value``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

In [None]:
# Serialize every training sample as one tf.train.Example with the raw pixel
# bytes of the resized image under 'image' and its class index under 'label'.
for idx in idxtrain:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])

    # Close the source file once the resized copy exists.
    with Image.open(img_path) as src:
        # PIL's resize() expects (width, height), not (height, width).
        image = src.resize((img_width, img_height))

    image = np.asarray(image)
    image_str = image.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_train.write(example.SerializeToString())

In [None]:
# Serialize every test sample as one tf.train.Example with the raw pixel
# bytes of the resized image under 'image' and its class index under 'label'.
for idx in idxtest:
    label = labels[idx]
    img_path = os.path.join(image_dir, categories[label], imgnames[idx])

    # Close the source file once the resized copy exists.
    with Image.open(img_path) as src:
        # PIL's resize() expects (width, height), not (height, width).
        image = src.resize((img_width, img_height))

    image = np.asarray(image)
    image_str = image.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image': _bytes_feature(image_str),
        'label': _int64_feature(label)
    }))
    writer_test.write(example.SerializeToString())

In [None]:
# Close both record writers, train first, now that all examples are written.
for _writer in (writer_train, writer_test):
    _writer.close()