In [37]:
import os
import pathlib
import urllib

import numpy as np
import tensorflow as tf

In [4]:
# Sentinel that lets tf.data choose parallelism and buffer sizes at runtime.
# Prefer the stable alias and fall back to the experimental one on
# TensorFlow releases that do not expose `tf.data.AUTOTUNE` yet.
try:
    AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:
    AUTOTUNE = tf.data.experimental.AUTOTUNE

In [9]:
# Fetch and unpack the flower photos archive, then keep the resulting
# location as a `pathlib.Path` so it can be globbed and joined below.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        fname='flower_photos',
        origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True,
    )
)

In [15]:
# Count the JPEG files that sit exactly one directory below the dataset root.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
image_count

3670

In [18]:
# Class names are the names of the sub-directories of the dataset root.
# Sorting makes the label order independent of the filesystem's listing
# order, and `is_dir()` skips any stray top-level file (e.g. LICENSE.txt).
CLASS_NAME = np.array(sorted(item.name for item in data_dir.glob('*') if item.is_dir()))
CLASS_NAME

array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'],
      dtype='<U10')

## Load using keras.preprocessing

In [24]:
# Generator that multiplies every pixel value by 1/255 while loading images.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1.0 / 255,
)

In [28]:
# Batch size and the spatial size every image is resized to.
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# Number of batches needed to cover every image once. `np.ceil` returns a
# float, so convert it: a step count has to be usable as an integer
# (for example in `range(...)`).
STEPS_PER_EPOCH = int(np.ceil(image_count / BATCH_SIZE))

In [39]:
# Stream shuffled batches of resized RGB images from the class
# sub-directories, with one-hot ('categorical') labels ordered as CLASS_NAME.
train_data_gen = image_generator.flow_from_directory(
    str(data_dir),
    classes=list(CLASS_NAME),
    class_mode='categorical',
    color_mode='rgb',
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Found 3670 images belonging to 5 classes.


## Load using tf.data

In [40]:
# Dataset of file paths: every entry one level below each class directory.
list_ds = tf.data.Dataset.list_files(file_pattern=str(data_dir.joinpath('*', '*')))

In [45]:
# Peek at a handful of the matched file paths.
for path_tensor in list_ds.take(5):
    print(path_tensor.numpy())

b'/home/kaimo/.keras/datasets/flower_photos/daisy/9054268881_19792c5203_n.jpg'
b'/home/kaimo/.keras/datasets/flower_photos/daisy/2611119198_9d46b94392.jpg'
b'/home/kaimo/.keras/datasets/flower_photos/tulips/14093565032_a8f1e349d1.jpg'
b'/home/kaimo/.keras/datasets/flower_photos/tulips/15922772266_1167a06620.jpg'
b'/home/kaimo/.keras/datasets/flower_photos/dandelion/7226987694_34552c3115_n.jpg'


In [47]:
def get_label(filepath):
    """Build the one-hot label of an image from the directory it is stored in.

    Args:
        filepath: Path of one image file (a string tensor when this function
            is applied through `Dataset.map`). The name of the file's parent
            directory is used as the class name.

    Returns:
        A boolean tensor of shape [NUM_CLASS] that is True where the
        corresponding entry of `CLASS_NAME` equals the parent directory name.
    """
    # Split the path on the platform's separator; requires the `os` module.
    path_components = tf.strings.split(input=filepath, sep=os.path.sep)
    # The second-to-last component is the class directory. Comparing it with
    # the array of class names yields one boolean per class.
    return path_components[-2] == CLASS_NAME

In [49]:
def decode_img(img):
    """Turn the encoded contents of a JPEG file into a fixed-size image tensor.

    Args:
        img: Encoded JPEG contents, e.g. the result of `tf.io.read_file`.

    Returns:
        A float32 image tensor with 3 channels, resized to
        `[IMG_HEIGHT, IMG_WIDTH]`.
    """
    # Decode to a 3-channel image tensor.
    rgb = tf.image.decode_jpeg(contents=img, channels=3)
    # Switch to float32 before resizing.
    as_float = tf.image.convert_image_dtype(rgb, tf.float32)
    target_size = [IMG_HEIGHT, IMG_WIDTH]
    return tf.image.resize(as_float, target_size)

In [62]:
def process_path(filepath):
    """Load one image file and pair it with its label.

    Args:
        filepath: Path of the image file to load.

    Returns:
        A tuple `(image, label)`: the decoded image from `decode_img` and the
        label from `get_label`.
    """
    raw_bytes = tf.io.read_file(filename=filepath)
    return decode_img(raw_bytes), get_label(filepath)

In [63]:
# Map every file path to an (image, label) pair, letting tf.data decide how
# many elements to process in parallel.
labeled_ds = list_ds.map(
    map_func=process_path,
    num_parallel_calls=AUTOTUNE,
)

__Basic method for training__

In [67]:
def prepare_for_training(dataset, cache=True, shuffle_buffer_size=1000, batch_size=None):
    """Turn a dataset of single examples into an endless stream of batches.

    The stages are applied in this order:
    cache -> shuffle -> repeat -> batch -> prefetch.

    Args:
        dataset: Dataset of individual (unbatched) examples.
        cache: Controls caching of the upstream preprocessing work.
            A falsy value disables caching. A `str` or `os.PathLike` value
            (e.g. a `pathlib.Path`) is used as the cache file name, for
            datasets that do not fit in memory. Any other truthy value
            caches in memory.
        shuffle_buffer_size: Size of the shuffle buffer.
        batch_size: Number of examples per batch. When None (the default),
            the notebook-level `BATCH_SIZE` is used.

    Returns:
        The transformed dataset. It repeats forever.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    if cache:
        # Accept path objects as well as plain strings; otherwise a
        # `pathlib.Path` would silently fall through to in-memory caching.
        if isinstance(cache, (str, os.PathLike)):
            dataset = dataset.cache(os.fspath(cache))
        else:
            # Small dataset: load it once and keep it in memory.
            dataset = dataset.cache()

    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Repeat forever.
    dataset = dataset.repeat()

    dataset = dataset.batch(batch_size)

    # `prefetch` lets the dataset fetch batches in the background while the
    # model is training.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)

    return dataset

In [69]:
# Build the training pipeline through the helper ...
train_ds = prepare_for_training(dataset=labeled_ds, shuffle_buffer_size=1000, cache=True)

# ... or spell out the same chain of transformations by hand.
train_ds = (
    labeled_ds
    .cache()
    .shuffle(1000)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)