https://www.tensorflow.org/guide/datasets

https://www.tensorflow.org/guide/performance/datasets

# Reading inputs

tf.data API를 통해 우리는 on-memory(from_tensor_slices), generator(from_generator), file(TFRecord) 등과 같이 다양한 input source로부터 dataset을 만들고, interleave, map, filter, reduce 등과 같은 함수로 dataset을 재가공 또는 수정하여 최종적으로 iterator를 통해 데이터를 가져오는 방식임을 배웠다.

이제 이번 코드에서는 실제 각 case 별로 데이터를 읽어 graph로 가져오는 코드를 수행해본다.

In [14]:
import tensorflow as tf
import cv2
import os

# Reading images in categorical directories

In [15]:
def _get_filenames_and_classes(dataset_dir):
      flower_root = os.path.join(dataset_dir)
      directories = []
      class_names = []
      for dir_name in os.listdir(flower_root):
        path = os.path.join(flower_root, dir_name)
        if os.path.isdir(path):
          directories.append(path)
          class_names.append(dir_name)

      photo_filenames = []
      for directory in directories:
        for filename in os.listdir(directory):
          path = os.path.join(directory, filename)
          photo_filenames.append(path)

      return photo_filenames, sorted(class_names)

In [16]:
# Dataset root handed to _get_filenames_and_classes below, which treats each
# immediate sub-directory as one class. Machine-specific absolute path;
# presumably the evaluation split of the flowers data -- confirm before reuse.
path = '/home/dan/prj/datasets/flowers/flower_example1/eval'

In [17]:
# List every file under `path` together with the sorted class names, then
# give each class name an integer id equal to its position in that list.
filepaths, class_names = _get_filenames_and_classes(path)
class_names_to_ids = {name: idx for idx, name in enumerate(class_names)}

In [18]:
# filename = tf.constant('/home/dan/datasets/flower_photos/sunflowers/14901528533_ac1ce09063.jpg')
# image_string = tf.read_file(filename)
# image_decoded = tf.image.decode_jpeg(image_string, channels=3)
# image_resized = tf.image.resize_images(image_decoded, [224, 224])
# with tf.Session() as sess:
#     image = sess.run(image_resized)    
    
def _parse_function(filename, label):
    """Load one JPEG file and resize it to 224x224, passing the label through.

    Intended as the function given to ``Dataset.map``.

    Args:
        filename: Scalar string tensor holding the path of a JPEG file.
        label: Label associated with the file; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the decoded 3-channel
        image resized to 224x224.
    """
    raw_bytes = tf.read_file(filename)
    # channels=3 forces RGB output regardless of how the file was encoded.
    rgb_image = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.image.resize_images(rgb_image, [224, 224]), label

In [19]:
# Input-pipeline settings. `num_images` is used as the shuffle buffer size.
epoch = 10
batch_size = 64
num_images = len(filepaths)

# The label of a file is the id of the directory that directly contains it.
label_ids = [
    class_names_to_ids[os.path.basename(os.path.dirname(filepath))]
    for filepath in filepaths
]

# Build one dataset per component, then zip them into (filepath, label) pairs.
dataset_filepath = tf.data.Dataset.from_tensor_slices(
    tf.cast(filepaths, tf.string))
dataset_class = tf.data.Dataset.from_tensor_slices(label_ids)
dataset = tf.data.Dataset.zip((dataset_filepath, dataset_class))

In [20]:
# Input pipeline: shuffle over the whole file list, repeat without a count,
# decode/resize with 4 parallel calls, group into batches, and keep up to
# 2 batches prefetched.
dataset = (
    dataset
    .shuffle(num_images)
    .repeat()
    .map(_parse_function, num_parallel_calls=4)
    .batch(batch_size)
    .prefetch(2)
)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# Shows the symbolic (image batch, label batch) tensors, not actual data.
print(next_element)

# Pull a single batch: `a` receives the images, `b` the labels.
with tf.Session() as sess:
    for i in range(1):
        a, b = sess.run(next_element)
        

(<tf.Tensor 'IteratorGetNext_1:0' shape=(?, 224, 224, 3) dtype=float32>, <tf.Tensor 'IteratorGetNext_1:1' shape=(?,) dtype=int32>)


In [21]:
# Shapes of the fetched image batch and label batch, one per line.
print(a.shape, b.shape, sep='\n')

(64, 224, 224, 3)
(64,)
