
# Introduction #

Welcome to the [**Petals to the Metal**](https://www.kaggle.com/c/tpu-getting-started) competition! In this competition, you’re challenged to build a machine learning model to classify 104 types of flowers based on their images.

In this tutorial notebook, you'll learn how to build an image classifier in Keras and train it on a [Tensor Processing Unit (TPU)](https://www.kaggle.com/docs/tpu). At the end, you'll have a complete project you can build off of with ideas of your own.

<blockquote style="margin-right:auto; margin-left:auto; background-color: #ebf9ff; padding: 1em; margin:24px;">
    <strong>Fork This Notebook!</strong><br>
Create your own editable copy of this notebook by clicking on the <strong>Copy and Edit</strong> button in the top right corner.
</blockquote>

# Step 1: Imports #

We begin by importing several Python packages.

In [None]:
import math, re, os
import numpy as np
import tensorflow as tf

print("Tensorflow version " + tf.__version__)

# Step 2: Distribution Strategy #

A TPU has eight different *cores* and each of these cores acts as its own accelerator. (A TPU is sort of like having eight GPUs in one machine.) We tell TensorFlow how to make use of all these cores at once through a **distribution strategy**. Run the following cell to create the distribution strategy that we'll later apply to our model.

In [None]:
# Detect TPU, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() 

print("REPLICAS: ", strategy.num_replicas_in_sync)

We'll use the distribution strategy when we create our neural network model. Then, TensorFlow will distribute the training among the eight TPU cores by creating eight different *replicas* of the model, one for each core.

# Step 3: Loading the Competition Data #

## Get GCS Path ##

When used with TPUs, datasets need to be stored in a [Google Cloud Storage bucket](https://cloud.google.com/storage/). You can use data from any public GCS bucket by giving its path, just as you would for data in `'/kaggle/input'`. The following cell retrieves the GCS path for this competition's dataset.

In [None]:
from kaggle_datasets import KaggleDatasets

GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH) # what do gcs paths look like?

You can use data from any public dataset here on Kaggle in just the same way. If you'd like to use data from one of your private datasets, see [here](https://www.kaggle.com/docs/tpu#tpu3pt5).

## Load Data ##

When used with TPUs, datasets are often serialized into [TFRecords](https://www.kaggle.com/ryanholbrook/tfrecords-basics). This format is convenient for distributing data to each of the TPU's cores. We've hidden the cell that reads the TFRecords for our dataset since the process is a bit long. You can come back to it later for guidance on using your own datasets with TPUs.

In [None]:
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose'] 
AUTO = tf.data.experimental.AUTOTUNE
# Define the batch size. This will be 16 with TPU off and 128 (=16*8) with TPU on
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
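To make the batch-size arithmetic concrete, here is a minimal sketch. The helper name `global_batch_size` and the replica counts are illustrative assumptions, not part of the notebook; the code only mirrors the `16 * strategy.num_replicas_in_sync` expression above.

```python
# Illustrative only: mirrors BATCH_SIZE = 16 * strategy.num_replicas_in_sync.
PER_REPLICA_BATCH = 16  # per-core batch size used in this notebook


def global_batch_size(num_replicas, per_replica=PER_REPLICA_BATCH):
    """Return the global batch size that is split evenly across replicas."""
    return per_replica * num_replicas


# 1 replica = no TPU (default strategy); 8 replicas = one TPU with eight cores.
for replicas in (1, 8):
    print(replicas, "replica(s) ->", global_batch_size(replicas), "images per step")
```

Scaling the global batch with the replica count keeps the per-core workload constant whether or not a TPU is attached.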


# Explore Data #

Let's take a moment to look at some of the images in the dataset.

In [None]:
from matplotlib import pyplot as plt

def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case,
                                     # these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is
    # the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
    
def display_batch_of_images(databatch, predictions=None):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: this will drop data that does not fit into square
    # or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
        
    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()


def display_training_curves(training, validation, title, subplot):
    if subplot%10==1: # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
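The "auto-squaring" step in `display_batch_of_images` can be hard to picture, so here is a small sketch of the same arithmetic in isolation. The helper name `grid_shape` is hypothetical; it only reproduces the `rows`/`cols` computation and reports how many images would be left out of the grid.

```python
import math


def grid_shape(num_images):
    """Hypothetical helper mirroring the rows/cols computation used for plotting."""
    rows = int(math.sqrt(num_images))  # largest square side that fits
    cols = num_images // rows          # widen the grid as far as whole rows allow
    return rows, cols


for n in (9, 20, 23):
    rows, cols = grid_shape(n)
    dropped = n - rows * cols
    print(f"{n} images -> {rows}x{cols} grid, {dropped} dropped")
```

A batch of 20 fills a 4x5 grid exactly, which is one reason the notebook re-batches to 20 before plotting.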

# Transform the Data

In [None]:
def decode_image(image_data,image_size):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*image_size, 3]) # explicit size needed for TPU
    return image

def read_labeled_tfrecord(example, image_size):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'], image_size)
    label = tf.cast(example['class'], tf.int32)
    return image, label # returns a single (image, label) pair

def read_unlabeled_tfrecord(example, image_size):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
        # class is missing, this competition's challenge is to predict flower classes for the test dataset
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'], image_size)
    idnum = example['id']
    return image, idnum # returns a single (image, id) pair

def load_dataset(filenames, image_size, labeled=True, ordered=False):
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.

    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(lambda x: read_labeled_tfrecord(x, image_size) if labeled else read_unlabeled_tfrecord(x, image_size), num_parallel_calls=AUTO)
    # returns a dataset of (image, label) pairs if labeled=True or (image, id) pairs if labeled=False
    return dataset

# Augmentation of Data

In [None]:
def data_augment(image, label):
    # Thanks to the dataset.prefetch(AUTO)
    # statement in the next function (below), this happens essentially
    # for free on TPU. Data pipeline code is executed on the "CPU"
    # part of the TPU while the TPU itself is computing gradients.
    image = tf.image.random_flip_left_right(image)
    #image = tf.image.random_saturation(image, 0, 2)
    return image, label   

def get_training_dataset(train_file_name, image_size):
    dataset = load_dataset(train_file_name, image_size, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(validation_file_name, image_size,ordered=False):
    dataset = load_dataset(validation_file_name, image_size, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_test_dataset(test_file_name,image_size,ordered=False):
    dataset = load_dataset(test_file_name,image_size, labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def count_data_items(filenames):
    # the number of data items is written in the name of the .tfrec
    # files, e.g. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)
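As a quick illustration of how the item count is recovered from file names, the sketch below applies the same regular expression to a few made-up names. Only `flowers00-230.tfrec` comes from the comment above; the other names and the helper `items_in_file` are assumptions for the example.

```python
import re

# The first name is taken from the comment in count_data_items;
# the other two are hypothetical shards following the same pattern.
example_filenames = [
    "flowers00-230.tfrec",
    "flowers01-230.tfrec",
    "flowers02-212.tfrec",
]

# Same pattern as in count_data_items: digits between a hyphen and a dot.
COUNT_PATTERN = re.compile(r"-([0-9]*)\.")


def items_in_file(filename):
    """Extract the item count encoded at the end of a .tfrec file name."""
    return int(COUNT_PATTERN.search(filename).group(1))


counts = [items_in_file(name) for name in example_filenames]
total = sum(counts)
print(counts, "->", total, "items in total")
```

This avoids reading every record just to count them, at the cost of trusting the naming convention.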

# Learning Rate

In [None]:
EPOCHS = 12
# Learning Rate Schedule for Fine Tuning #
def exponential_lr(epoch,
                   start_lr = 0.00001, min_lr = 0.00001, max_lr = 0.00005,
                   rampup_epochs = 5, sustain_epochs = 0,
                   exp_decay = 0.8):

    def lr(epoch, start_lr, min_lr, max_lr, rampup_epochs, sustain_epochs, exp_decay):
        # linear increase from start to rampup_epochs
        if epoch < rampup_epochs:
            lr = ((max_lr - start_lr) /
                  rampup_epochs * epoch + start_lr)
        # constant max_lr during sustain_epochs
        elif epoch < rampup_epochs + sustain_epochs:
            lr = max_lr
        # exponential decay towards min_lr
        else:
            lr = ((max_lr - min_lr) *
                  exp_decay**(epoch - rampup_epochs - sustain_epochs) +
                  min_lr)
        return lr
    return lr(epoch,
              start_lr,
              min_lr,
              max_lr,
              rampup_epochs,
              sustain_epochs,
              exp_decay)

lr_callback = tf.keras.callbacks.LearningRateScheduler(exponential_lr, verbose=True)

rng = [i for i in range(EPOCHS)]
y = [exponential_lr(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))
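To see the shape of the schedule without plotting, the following sketch recomputes it as a flat function and prints one value per epoch. The name `schedule` is hypothetical; the defaults are copied from `exponential_lr` above, and the logic is intended to match it (linear ramp-up, optional plateau, then exponential decay towards `min_lr`).

```python
def schedule(epoch, start_lr=1e-5, min_lr=1e-5, max_lr=5e-5,
             rampup_epochs=5, sustain_epochs=0, exp_decay=0.8):
    """Flat re-statement of the ramp-up / sustain / decay schedule."""
    if epoch < rampup_epochs:
        # linear increase from start_lr towards max_lr
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        # hold at max_lr
        return max_lr
    # exponential decay from max_lr towards min_lr
    decay_steps = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay ** decay_steps + min_lr


for epoch in range(12):
    print(f"epoch {epoch:2d}: lr = {schedule(epoch):.3g}")
```

With these defaults the rate starts at `1e-5`, peaks at `5e-5` in epoch 5, and then shrinks by a factor of 0.8 (relative to `min_lr`) each epoch.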

# Model - VGG 16 PRETRAINED MODEL

In [None]:
def create_vgg16_model(input_shape, N_CLASSES):
    # Only defines how the model is built. Call this function inside
    # `strategy.scope()` so the model's variables are distributed
    # across the replicas.
    pretrained_model = tf.keras.applications.VGG16(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )
    pretrained_model.trainable = True

    model = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(N_CLASSES, activation='softmax')
    ])

    return model

def compile_model(model, optimizer):
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )

def fine_tuning_lr(epoch, base_lr=0.00001, multiplier=0.2):
    return base_lr * (multiplier ** epoch)

# Example usage:
initial_fine_tuning_lr = fine_tuning_lr(0)

# Confusion Matrix

In [None]:

import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

def display_confusion_matrix(cmat, score, precision, recall):
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()

In [None]:
def generate_confusion_matrix(model, validation_dataset, number_validation_images):
    images_ds = validation_dataset.map(lambda image, label: image)
    labels_ds = validation_dataset.map(lambda image, label: label).unbatch()

    correct_labels = next(iter(labels_ds.batch(number_validation_images))).numpy()
    probabilities = model.predict(images_ds)
    predictions = np.argmax(probabilities, axis=-1)

    labels = range(len(CLASSES))
    cmat = confusion_matrix(correct_labels, predictions, labels=labels)
    cmat = (cmat.T / cmat.sum(axis=1)).T  # normalize
    
    f1 = f1_score(correct_labels, predictions, labels=labels, average='macro')
    precision = precision_score(correct_labels, predictions, labels=labels, average='macro')
    recall = recall_score(correct_labels, predictions, labels=labels, average='macro')

    display_confusion_matrix(cmat, f1, precision, recall)
    return cmat  # return the normalized matrix so callers can keep it
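The normalization line in `generate_confusion_matrix` is compact, so here is a small sketch of what it does on a toy matrix. The 3x3 counts are made up for illustration; rows are true classes and columns are predicted classes, following scikit-learn's convention.

```python
import numpy as np

# Made-up counts: rows = true class, columns = predicted class.
cmat = np.array([[8, 2, 0],
                 [1, 3, 0],
                 [0, 5, 15]], dtype=float)

# Form used in the notebook: transpose, divide by row sums, transpose back.
normalized_a = (cmat.T / cmat.sum(axis=1)).T

# Equivalent form using broadcasting with keepdims.
normalized_b = cmat / cmat.sum(axis=1, keepdims=True)

print(normalized_a)
print("per-class recall (diagonal):", np.diag(normalized_a))
```

After normalization each row sums to 1, so the diagonal reads as per-class recall. Note that a class with no validation samples would give a zero row sum and therefore NaN entries.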

In [None]:
def compute_predictions(model, dataset):
    images_ds = dataset.map(lambda image, idnum: image)
    probabilities = model.predict(images_ds)
    predictions = np.argmax(probabilities, axis=-1)
    return predictions

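def generate_submission(model, test_dataset, output_filename, number_test_images):
    # `test_dataset` is iterated twice (once for images, once for IDs), so it
    # should be built with ordered=True; otherwise predictions and IDs may
    # not line up.
    predictions = compute_predictions(model, test_dataset)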
    test_ids_ds = test_dataset.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(number_test_images))).numpy().astype('U')
    np.savetxt(
        output_filename,
        np.rec.fromarrays([test_ids, predictions]),
        fmt=['%s', '%d'],
        delimiter=',',
        header='id,label',
        comments='',
    )

    print(f'Generated {output_filename} file successfully!')

# **RESOLUTION = 331 x 331**

In [None]:
IMAGE_SIZE_331 = [331, 331]
GCS_PATH_331 = GCS_DS_PATH + '/tfrecords-jpeg-331x331'

TRAINING_FILENAMES_331 = tf.io.gfile.glob(GCS_PATH_331 + '/train/*.tfrec')
VALIDATION_FILENAMES_331 = tf.io.gfile.glob(GCS_PATH_331 + '/val/*.tfrec')
TEST_FILENAMES_331 = tf.io.gfile.glob(GCS_PATH_331 + '/test/*.tfrec')

In [None]:
NUM_TRAINING_IMAGES_331 = count_data_items(TRAINING_FILENAMES_331)
NUM_VALIDATION_IMAGES_331 = count_data_items(VALIDATION_FILENAMES_331)
NUM_TEST_IMAGES_331 = count_data_items(TEST_FILENAMES_331)

print('Dataset - 331 x 331: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES_331, NUM_VALIDATION_IMAGES_331, NUM_TEST_IMAGES_331))

In [None]:
ds_train_331 = get_training_dataset(TRAINING_FILENAMES_331, IMAGE_SIZE_331)
ds_valid_331 = get_validation_dataset(VALIDATION_FILENAMES_331, IMAGE_SIZE_331)
ds_test_331 = get_test_dataset(TEST_FILENAMES_331, IMAGE_SIZE_331)

print("Training - 331:", ds_train_331)
print("Validation - 331:", ds_valid_331)
print("Test - 331:", ds_test_331)

In [None]:
np.set_printoptions(threshold=15, linewidth=80)

print("Training - 331 data shapes:")
for image, label in ds_train_331.take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Training - 331 data label examples:", label.numpy())

In [None]:
print("Test - 331 data shapes:")
for image, idnum in ds_test_331.take(3):
    print(image.numpy().shape, idnum.numpy().shape)
print("Test - 331 data IDs:", idnum.numpy().astype('U')) # U=unicode string

In [None]:
ds_iter_331 = iter(ds_train_331.unbatch().batch(20))

In [None]:
print("PRINTING ONE BATCH OF 331 x 331 SIZE IMAGES")
one_batch_331 = next(ds_iter_331)
display_batch_of_images(one_batch_331)

In [None]:
lr_callback_331_vgg16 = tf.keras.callbacks.LearningRateScheduler(exponential_lr, verbose=True)
lr_callback_331_vgg19 = tf.keras.callbacks.LearningRateScheduler(exponential_lr, verbose=True)

In [None]:
with strategy.scope():
    model_331_vgg16 = create_vgg16_model(IMAGE_SIZE_331 + [3], len(CLASSES))

In [None]:
optimizer_331_vgg16 = tf.keras.optimizers.Adam(learning_rate=initial_fine_tuning_lr)
optimizer_331_vgg19 = tf.keras.optimizers.Adam(learning_rate=initial_fine_tuning_lr)

In [None]:
with strategy.scope():
    compile_model(model_331_vgg16,optimizer_331_vgg16)
    model_331_vgg16.summary()

In [None]:
STEPS_PER_EPOCH_331 = NUM_TRAINING_IMAGES_331 // BATCH_SIZE
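As a worked example of the steps-per-epoch arithmetic, the sketch below uses an assumed image count; the real value is whatever `count_data_items` printed earlier. The variable names are illustrative and not part of the notebook.

```python
# Assumed numbers for illustration only.
num_training_images = 12753   # hypothetical count; use the printed value in practice
batch_size = 16 * 8           # 16 images per replica on eight TPU cores

# Integer division: only full batches count as steps.
steps_per_epoch = num_training_images // batch_size
leftover = num_training_images - steps_per_epoch * batch_size

print(f"{steps_per_epoch} steps per epoch, {leftover} images left over")
```

Because the training dataset repeats indefinitely, the leftover images are not lost; they should simply be consumed at the start of the following epoch.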

In [None]:
history_331_vgg16 = model_331_vgg16.fit(
    ds_train_331,
    validation_data=ds_valid_331,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH_331,
    callbacks=[lr_callback_331_vgg16],
)

In [None]:
display_training_curves(
    history_331_vgg16.history['loss'],
    history_331_vgg16.history['val_loss'],
    'loss',
    211,
)
display_training_curves(
    history_331_vgg16.history['sparse_categorical_accuracy'],
    history_331_vgg16.history['val_sparse_categorical_accuracy'],
    'accuracy',
    212,
)

In [None]:
confusion_matrix_331_vgg16 = generate_confusion_matrix(model_331_vgg16, get_validation_dataset(VALIDATION_FILENAMES_331, IMAGE_SIZE_331, ordered = True), NUM_VALIDATION_IMAGES_331)

In [None]:
dataset_331 = get_validation_dataset(VALIDATION_FILENAMES_331, IMAGE_SIZE_331)
dataset_331 = dataset_331.unbatch().batch(20)
batch_331 = iter(dataset_331)

In [None]:
images, labels = next(batch_331)
probabilities_vgg16 = model_331_vgg16.predict(images)
predictions = np.argmax(probabilities_vgg16, axis=-1)
display_batch_of_images((images, labels), predictions)

In [None]:
dataset_331 = get_validation_dataset(VALIDATION_FILENAMES_331, IMAGE_SIZE_331)
loss, accuracy = model_331_vgg16.evaluate(dataset_331)
print(f'Validation Loss: {loss:.4f}, Validation Accuracy: {accuracy:.4f}')