In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

In [None]:
(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)

In [None]:
num_classes = metadata.features['label'].num_classes
num_classes

In [None]:
get_label_name = metadata.features['label'].int2str

In [None]:
image, label = next(iter(train_ds))

In [None]:
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

In [None]:
image.shape

In [None]:
IMG_SIZE = 180

resize_and_rescale = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE),
    tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
])

In [None]:
result = resize_and_rescale(image)

In [None]:
_ = plt.imshow(result)

In [None]:
print("Min and Max of image: {}, {}".format(np.min(result), np.max(result)))

In [None]:
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip(),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
])

In [None]:
image = tf.expand_dims(image, 0)

In [None]:
plt.figure(figsize=(10, 10))
for i in range(9):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0])

In [None]:
model = tf.keras.Sequential([
    resize_and_rescale,
    data_augmentation,
    tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
])

In [None]:
aug_ds = train_ds.map(lambda x, y: (resize_and_rescale(x, training=True), y))

In [None]:
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
    ds = ds.map(lambda x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1000)
    
    ds = ds.batch(batch_size)

    if augment:
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    
    return ds.prefetch(AUTOTUNE)

In [None]:
train_ds = prepare(train_ds, True, True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

In [None]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(num_classes)
])

model.compile('adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(True), metrics=['accuracy'])

epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

In [None]:
loss, acc = model.evaluate(test_ds)
acc

In [None]:
def random_invert_img(x, p=0.5):
    """Invert pixel intensities (255 - x) with probability `p`.

    Args:
        x: image tensor. The constant 255 assumes 8-bit [0, 255] pixel values.
        p: probability of applying the inversion.

    Returns:
        `255 - x` with probability `p`; otherwise `x` unchanged.
    """
    # A single scalar draw per call, so a batched `x` is inverted (or not) as a whole.
    if tf.random.uniform([]) < p:
        x = 255 - x
    return x

In [None]:
def random_invert(factor=0.5):
    return tf.keras.layers.Lambda(lambda x: random_invert_img(x, factor))

random_invert = random_invert()

In [None]:
plt.figure(figsize=(10,10))
for i in range(9):
    augmented_image = random_invert(image)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_image[0].numpy().astype('uint8'))

In [None]:
class RandomInvert(tf.keras.layers.Layer):
    def __init__(self, factor=0.5, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor
    
    def call(self, x):
        return random_invert_img(x)

In [None]:
_ = plt.imshow(RandomInvert()(image)[0])

In [None]:
(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)

In [None]:
image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))

In [None]:
def visualize(original, augmented):
    """Show an original image and its augmented version side by side.

    Args:
        original: image accepted by `plt.imshow`.
        augmented: transformed image accepted by `plt.imshow`.
    """
    # Open a new figure so consecutive calls do not draw over each other.
    plt.figure()

    plt.subplot(1, 2, 1)
    plt.title('Original image')
    plt.imshow(original)

    plt.subplot(1, 2, 2)
    plt.title('Augmented')
    plt.imshow(augmented)

In [None]:
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

In [None]:
grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()

In [None]:
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

In [None]:
brighted = tf.image.adjust_brightness(image, 0.4)
visualize(image, brighted)

In [None]:
croped = tf.image.central_crop(image, 0.5)
visualize(image, croped)

In [None]:
rotated = tf.image.rot90(image)
visualize(image, rotated)

In [None]:
for i in range(3):
    seed = (i, 0)
    stateless_random_brightness = tf.image.stateless_random_brightness(image, 0.95, seed)
    visualize(image, stateless_random_brightness)

In [None]:
for i in range(3):
    seed = (i, 0)
    stateless_random_contrast = tf.image.stateless_random_contrast(image, 0.1, 0.9, seed)
    visualize(image, stateless_random_contrast)

In [None]:
for i in range(3):
    seed = (i, 0)
    stateless_random_crop = tf.image.stateless_random_crop(image, [210, 300, 3], seed)
    visualize(image, stateless_random_crop)

In [None]:
(train_datasets, val_ds, test_ds), metadata = tfds.load('tf_flowers', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)

In [None]:
def resize_and_rescale(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = (image / 255.0)
    return image, label

In [None]:
def augment(image_label, seed):
    """Resize/rescale one example, then randomly crop and brighten it deterministically.

    Args:
        image_label: an (image, label) pair.
        seed: stateless RNG seed, used directly for the crop and split to derive the
            brightness seed. Presumably a length-2 integer tensor — TODO confirm.

    Returns:
        (image, label), with the image cropped to [IMG_SIZE, IMG_SIZE, 3] and clipped to [0, 1].
    """
    image, label = resize_and_rescale(*image_label)
    # Pad to (IMG_SIZE + 6) square so the random crop below has room to shift.
    padded = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)

    # Derive a second seed from the incoming one for the brightness op.
    brightness_seed = tf.random.experimental.stateless_split(seed, 1)[0, :]

    cropped = tf.image.stateless_random_crop(padded, [IMG_SIZE, IMG_SIZE, 3], seed)
    brightened = tf.image.stateless_random_brightness(cropped, 0.5, brightness_seed)
    return tf.clip_by_value(brightened, 0, 1), label

In [None]:
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))

In [None]:
train_ds = train_ds.shuffle(1000).map(augment, num_parallel_calls = AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)

In [None]:
val_ds = val_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)

In [None]:
test_ds = test_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)

In [None]:
rng = tf.random.Generator.from_seed(123, alg='philox')

In [None]:
def f(x, y):
    seed = rng.make_seeds(2)[0]
    image, label = augment((x, y), seed)
    return image, label

In [None]:
train_ds = train_datasets.shuffle(1000).map(f, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)

In [None]:
val_ds = val_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)

In [None]:
test_ds = test_ds.map(resize_and_rescale, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)