In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow_datasets as tfds
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, MobileNetV2
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# Show which TensorFlow version this notebook is running against.
print(f"tf.__version__ = {tf.__version__}")

In [None]:
# --- Notebook-wide configuration -------------------------------------------
# Side length (pixels) of the square images fed to the network.
INPUT_SIZE = 224
# Number of samples per batch in both the training and validation pipelines.
BATCH_SIZE = 32
# Let tf.data choose the parallelism / prefetch buffer size at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Location passed to the ModelCheckpoint callback.
CHECK_POINT_DIR = './checkpoint'

In [None]:
# Load tf_flowers as (image, label) pairs; the first 90% of the single
# 'train' split is used for training and the remaining 10% for validation.
(train_data_ds, val_data_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:90%]', 'train[90%:]'],
    as_supervised=True,
    with_info=True,
)

# Number of distinct labels, read from the dataset metadata.
num_classes = metadata.features['label'].num_classes

In [None]:
# NOTE(review): AUGMENT_SIZE is not referenced by any cell visible in this
# notebook — confirm whether it is still needed.
AUGMENT_SIZE = 384

def encode_one_hot(image, label):
    """Turn an integer class label into a one-hot vector.

    Args:
        image: Image tensor; passed through untouched.
        label: Integer class index.

    Returns:
        Tuple ``(image, one_hot_label)`` where the label has depth
        ``num_classes``.
    """
    return image, tf.one_hot(label, depth=num_classes)

def augment(image, label):
    """Training-time preprocessing with random augmentation.

    Takes a random square crop whose side is 80% of the image's shorter
    edge, randomly mirrors it horizontally, resizes it to
    ``INPUT_SIZE x INPUT_SIZE`` and applies the MobileNetV2
    ``preprocess_input`` transform.

    Args:
        image: Image tensor with 3 channels (height, width, 3).
        label: Label for the image; returned unchanged.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    pixels = tf.cast(image, tf.float32)
    dims = tf.shape(pixels)
    short_side = tf.minimum(dims[0], dims[1])
    # Crop side is 80% of the shorter edge, truncated to an integer.
    crop_side = tf.cast(tf.cast(short_side, tf.float32) * 0.8, tf.int32)
    pixels = tf.image.random_crop(pixels, size=[crop_side, crop_side, 3])
    pixels = tf.image.random_flip_left_right(pixels)
    pixels = tf.image.resize(pixels, [INPUT_SIZE, INPUT_SIZE])
    return preprocess_input(pixels), label

# Training input pipeline, built one stage at a time:
# shuffle -> augment -> one-hot labels -> batch -> prefetch.
train_ds = train_data_ds.shuffle(4096)
train_ds = train_ds.map(augment, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.map(encode_one_hot, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(AUTOTUNE)

In [None]:
def center_crop_and_resize(image, label):
    """Deterministic preprocessing used for validation images.

    Crops the largest centered square out of the image, resizes it to
    ``INPUT_SIZE x INPUT_SIZE`` and applies the MobileNetV2
    ``preprocess_input`` transform.

    Args:
        image: Image tensor laid out as (height, width, channels).
        label: Label for the image; returned unchanged.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    pixels = tf.cast(image, tf.float32)
    dims = tf.shape(pixels)
    rows, cols = dims[0], dims[1]
    side = tf.minimum(rows, cols)
    # Offsets that center the square crop along each axis.
    top = (rows - side) // 2
    left = (cols - side) // 2
    pixels = tf.image.crop_to_bounding_box(pixels, top, left, side, side)
    pixels = tf.image.resize(pixels, [INPUT_SIZE, INPUT_SIZE])
    return preprocess_input(pixels), label
    
# Validation input pipeline (no shuffling, no random augmentation):
# center crop -> one-hot labels -> batch -> prefetch.
val_ds = val_data_ds.map(center_crop_and_resize, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(encode_one_hot, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE)
val_ds = val_ds.prefetch(AUTOTUNE)

In [None]:
# ImageNet-pretrained MobileNetV2 backbone (width multiplier 1.4) without its
# classification head. The `classes` argument is intentionally not passed:
# it only configures the top layer, which is excluded by include_top=False,
# so the number of flower classes is set on the Dense layer added later.
base_model = MobileNetV2(
    alpha=1.4,
    weights='imagenet',
    include_top=False,
    pooling=None,
    input_shape=(INPUT_SIZE, INPUT_SIZE, 3),
)

In [None]:
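# Optional: uncomment the two lines below to freeze every layer of the
# pretrained backbone so that only the layers added on top of it are trained.
# for layer in base_model.layers:
#     layer.trainable = False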

In [None]:
# Classification head on top of the MobileNetV2 backbone.
model = Sequential()
model.add(base_model)

# Global average pooling collapses the spatial dimensions whatever their size,
# so the head no longer hard-codes the 7x7 feature map that a 224x224 input
# produces. For INPUT_SIZE == 224 the output is identical to
# AveragePooling2D((7, 7)) followed by Flatten.
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dropout(0.5))

# Softmax probabilities over the dataset's classes.
model.add(layers.Dense(num_classes, activation='softmax'))

model.compile(
    # Labels are one-hot encoded by the input pipeline, so the categorical
    # (non-sparse) loss is used; label smoothing softens the targets.
    loss=CategoricalCrossentropy(label_smoothing=0.1),
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras releases.
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy']
)

In [None]:
# Print the layer-by-layer overview of the assembled model.
model.summary()

In [None]:
# First epoch index at which the exponential decay branch is used.
DECAY_START = 1


def scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    Epochs before ``DECAY_START`` use a fixed rate of 5e-5. From
    ``DECAY_START`` onwards the rate starts at 1e-4 and is multiplied by
    ``exp(-0.1)`` for every further epoch, never dropping below 1e-6.

    Args:
        epoch: Zero-based epoch index.

    Returns:
        The learning rate for that epoch.
    """
    in_warmup = epoch < DECAY_START
    if in_warmup:
        return 5e-5

    decayed_rate = 1e-4 * np.exp(0.1 * (DECAY_START - epoch))
    # Clamp to the floor so the rate never becomes vanishingly small.
    return max(decayed_rate, 1e-6)

# Keras callback that asks `scheduler` for the learning rate of each epoch;
# verbose=1 makes it report the rate it sets.
lr_callback = LearningRateScheduler(scheduler, verbose=1)

In [None]:
# Save the model to CHECK_POINT_DIR, but only when the validation loss
# improves on the best (lowest) value seen so far.
checkpoint_callback = ModelCheckpoint(
    filepath=CHECK_POINT_DIR,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Train for 20 epochs, validating after each one. The callbacks set the
# per-epoch learning rate and checkpoint the model with the best val_loss.
history = model.fit(
    x=train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[lr_callback, checkpoint_callback],
    verbose=1,
)

In [None]:
# Export the model as it stands after the last training epoch in SavedModel
# format. NOTE(review): this is the final-epoch model, not necessarily the
# best-val_loss one written by the checkpoint callback — confirm intended.
tf.saved_model.save(model, './saved_model')

In [None]:
import matplotlib.pyplot as plt

def evaluate(val_ds, classifier=None):
    """Compute accuracy over a dataset and collect the misclassified samples.

    Args:
        val_ds: Iterable of ``(images, labels)`` batches. Labels are expected
            to be one-hot (or per-class scores); the class index is taken
            with ``argmax`` along axis 1.
        classifier: Object exposing a Keras-style ``predict(images, verbose=0)``
            method. Defaults to the notebook-level ``model`` when omitted, so
            existing ``evaluate(val_ds)`` calls keep working.

    Returns:
        Tuple ``(acc, errors)``. ``acc`` is the fraction of correctly
        classified samples, or ``nan`` when the dataset yields no samples.
        ``errors`` is a list of dicts with keys ``'image'``, ``'label'`` and
        ``'predict'`` for every misclassified sample.
    """
    if classifier is None:
        # Fall back to the model defined earlier in the notebook.
        classifier = model

    num_total = 0
    errors = []
    for images, labels in val_ds:
        labels = np.argmax(labels, axis=1)
        num_total += len(images)
        # verbose=0: avoid printing one progress bar per batch.
        predicts = classifier.predict(images, verbose=0)
        predicts = np.argmax(predicts, axis=1)
        for image, label, predict in zip(images, labels, predicts):
            if label != predict:
                errors.append({
                    'image': image,
                    'label': label,
                    'predict': predict
                })

    if num_total == 0:
        # Accuracy is undefined for an empty dataset; avoid ZeroDivisionError.
        return float('nan'), errors

    acc = 1 - len(errors) / num_total
    return acc, errors

# Converts an integer class index into its label name using the dataset
# metadata's int2str lookup.
get_label_name = metadata.features['label'].int2str

def plot_errors(errors, num_col=4):
    """Show the misclassified images in a grid with true/predicted labels.

    Args:
        errors: List of dicts with keys ``'image'``, ``'label'`` and
            ``'predict'``, as produced by ``evaluate``.
        num_col: Number of grid columns; must be at least 1.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        ValueError: If ``num_col`` is smaller than 1.
    """
    if num_col < 1:
        raise ValueError(f"num_col must be >= 1, got {num_col}")
    if not errors:
        # Nothing to draw; a zero-row figure would be invalid.
        print("No misclassified samples to plot.")
        return

    # Integer ceiling division: plt.subplot requires an int row count,
    # whereas np.ceil would return a float.
    num_row = (len(errors) + num_col - 1) // num_col
    plt.figure(figsize=(4 * num_col, 4 * num_row))
    for idx, error in enumerate(errors):
        plt.subplot(num_row, num_col, idx + 1)
        # Map pixel values back to the 0..255 range for display; assumes the
        # image was scaled to [-1, 1] by the MobileNetV2 preprocess_input.
        image = tf.cast(((error['image'] + 1) * 127.5), tf.int32)
        plt.imshow(image)
        plt.title(f"label: {get_label_name(error['label'])}, predict: {get_label_name(error['predict'])}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Measure validation accuracy and collect the misclassified samples.
acc, errors = evaluate(val_ds)
print("acc = {:.4f}".format(acc))

In [None]:
# Visualize every misclassified validation image.
plot_errors(errors)