# Flower Classification on TPU
## Table of Contents
- Import Packages
- Distribution Strategy
- Common Parameters
- Common Functions
- Import datasets
- Understand the data
- Model Development
- Submission

Code adapted from https://www.kaggle.com/lonnieqin/flower-classification-on-tpu

## Import Packages

In [1]:
import tensorflow as tf
import pandas as pd
from kaggle_datasets import KaggleDatasets
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import applications

## Distribution Strategy

In [2]:
# Choose a tf.distribute strategy: TPUStrategy when a TPU cluster can be
# resolved, otherwise TensorFlow's default strategy (CPU or a single GPU).
try:
    # Without arguments the resolver relies on the TPU_NAME environment
    # variable, which Kaggle sets on TPU sessions.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # A ValueError from the resolver is treated as "no TPU available".
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU workers and initialise them before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

## Common Parameters

In [3]:
# Input resolution in pixels (height, width); the TFRecords read below are
# the 192x192 variant of the dataset.
IMAGE_SIZE = [192, 192]  # at this size, a GPU will run out of memory. Use the TPU
EPOCHS = 100
# Global batch size: 16 examples for each replica the strategy runs on.
BATCH_SIZE = strategy.num_replicas_in_sync * 16

# Hard-coded example counts for the training and test splits.
NUM_TRAINING_IMAGES = 12753
NUM_TEST_IMAGES = 7382
# The training dataset repeats forever, so Keras needs an explicit epoch length.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE

# GCS location of the competition data (list it with "!gsutil ls $GCS_DS_PATH").
GCS_DS_PATH = KaggleDatasets().get_gcs_path()
print(GCS_DS_PATH)

## Common Functions

 **Load datasets**

In [4]:
def decode_image(image_data):
    """Decode a JPEG byte string into a float32 RGB tensor of shape [*IMAGE_SIZE, 3].

    Pixel values are scaled from [0, 255] into [0, 1].
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(rgb, tf.float32) / 255.0
    # A fully static shape is needed on TPU, hence the explicit reshape.
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an ``(image, label)`` pair.

    The "image" feature is decoded with ``decode_image``; the int64 "class"
    feature is cast to int32 and used as the label.
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),  # encoded JPEG bytes
            "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class index
        },
    )
    return decode_image(features['image']), tf.cast(features['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized test example into an ``(image, id)`` pair.

    Test records carry no "class" feature (predicting it is the task); the
    string "id" is returned unchanged so predictions can be matched to it.
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),  # encoded JPEG bytes
            "id": tf.io.FixedLenFeature([], tf.string),     # scalar example id
        },
    )
    return decode_image(features['image']), features['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord files into a dataset of decoded examples.

    Args:
        filenames: list of TFRecord paths (e.g. from ``tf.io.gfile.glob``).
        labeled: if True, parse records with ``read_labeled_tfrecord`` and
            yield ``(image, label)`` pairs; otherwise parse them with
            ``read_unlabeled_tfrecord`` and yield ``(image, id)`` pairs.
        ordered: if False, tf.data may emit elements out of order as soon as
            they are ready, which is faster. Pass True when downstream code
            iterates the dataset more than once and needs the passes to line
            up (e.g. predictions vs. ids).

    Returns:
        A ``tf.data.Dataset`` of parsed, decoded examples.
    """
    autotune = tf.data.experimental.AUTOTUNE

    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False  # disable order, increase speed

    # TFRecordDataset reads its files one after another unless
    # num_parallel_reads is given; read several files concurrently.
    # With ordered=True the interleaving stays deterministic.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=autotune)
    dataset = dataset.with_options(options)
    # Parse and decode in parallel; the non-deterministic option above only
    # has an effect when work like this runs in parallel.
    dataset = dataset.map(
        read_labeled_tfrecord if labeled else read_unlabeled_tfrecord,
        num_parallel_calls=autotune,
    )
    return dataset

def get_training_dataset():
    """Return the infinite, shuffled, batched training pipeline.

    Records come from the TFRecord folder matching ``IMAGE_SIZE`` (e.g.
    ``tfrecords-jpeg-192x192``), so the files always agree with the reshape
    in ``decode_image``.
    """
    folder = f'/tfrecords-jpeg-{IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}'
    dataset = load_dataset(tf.io.gfile.glob(GCS_DS_PATH + folder + '/train/*.tfrec'), labeled=True)
    dataset = dataset.repeat()  # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    # Prepare upcoming batches while the accelerator works on the current one.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

def get_validation_dataset():
    """Return the batched, cached validation pipeline.

    Records come from the TFRecord folder matching ``IMAGE_SIZE``. Batches are
    cached after the first pass so later epochs skip decoding.
    """
    folder = f'/tfrecords-jpeg-{IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}'
    dataset = load_dataset(tf.io.gfile.glob(GCS_DS_PATH + folder + '/val/*.tfrec'), labeled=True, ordered=False)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    # Prepare upcoming batches while the accelerator works on the current one.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

def get_test_dataset(ordered=False):
    """Return the batched test pipeline of ``(image, id)`` pairs.

    Args:
        ordered: forwarded to ``load_dataset``. Pass True when the dataset is
            iterated several times and the passes must line up.
    """
    folder = f'/tfrecords-jpeg-{IMAGE_SIZE[0]}x{IMAGE_SIZE[1]}'
    dataset = load_dataset(tf.io.gfile.glob(GCS_DS_PATH + folder + '/test/*.tfrec'), labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetching does not reorder elements, so ordered=True is still honoured.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

**Sample Images**

In [5]:
def sample_images(images, row_count, column_count):
    """Show the first ``row_count * column_count`` images in a grid.

    Args:
        images: indexable sequence of images accepted by ``plt.imshow``.
        row_count: number of grid rows (may be 1).
        column_count: number of grid columns (may be 1).

    Grid cells beyond ``len(images)`` are left blank instead of raising.
    """
    # squeeze=False keeps axs 2-D even for 1-row or 1-column grids.
    _, axs = plt.subplots(row_count, column_count, figsize=(10, 10), squeeze=False)
    # axs.flat walks row-major, so index == row * column_count + column.
    for index, ax in enumerate(axs.flat):
        if index < len(images):
            ax.imshow(images[index])
        ax.axis('off')
    plt.show()

In [6]:
def find_mean_img(full_mat, title, size=(192, 192, 3)):
    """Plot and return the pixel-wise average of a stack of images.

    Args:
        full_mat: images stacked along axis 0; each entry must hold
            ``prod(size)`` values (e.g. shape ``(N, 192, 192, 3)``).
        title: text appended to "Average " in the plot title.
        size: shape the mean image is reshaped to.

    Returns:
        The mean image with shape ``size``. Returns ``None`` (and plots
        nothing) when ``full_mat`` contains no images, e.g. when a class is
        absent from the sampled batch.
    """
    full_mat = np.asarray(full_mat)
    if full_mat.size == 0:
        # np.mean over an empty axis gives an all-NaN image (a blank plot).
        print(f'No images for {title}; skipping average.')
        return None

    # Average over the image axis, then restore the image shape.
    mean_img = np.mean(full_mat, axis=0).reshape(size)

    # imshow ignores vmin/vmax for RGB input and expects floats in [0, 1],
    # which is what decode_image produces; rescale 0-255 data for display.
    display_img = mean_img / 255.0 if mean_img.max() > 1.0 else mean_img
    plt.imshow(display_img)
    plt.title(f'Average {title}')
    plt.axis('off')
    plt.show()
    return mean_img

## Import datasets

The data is read from the competition's TFRecord files in Kaggle's Google Cloud Storage bucket. Each split is spread over several `.tfrec` files, so they can be read in parallel.

In [7]:
# Build both input pipelines once; the cells below reuse them.
training_dataset, validation_dataset = get_training_dataset(), get_validation_dataset()

## Understand the data

Let's see what the dataset looks like.

In [8]:
# Pull a single batch from the training pipeline to inspect it as NumPy arrays.
for batch_images, batch_labels in training_dataset.take(1):
    images = batch_images.numpy()
    labels = batch_labels.numpy()
images.shape, labels.shape

In [9]:
# Class ids present in this batch and how many images each has.
unique_labels, label_counts = np.unique(labels, return_counts=True)
unique_labels, label_counts

In [10]:
# Display the first 16 images of the batch as a 4x4 grid.
sample_images(images, row_count=4, column_count=4)

In [11]:
print(labels)

# Select the images of three specific classes from this batch so their
# average images can be compared below.
full_mat_F1 = images[labels == 4]
full_mat_F2 = images[labels == 67]
full_mat_F3 = images[labels == 73]

In [12]:
# Plot (and keep) the average image of each selected class, in order F1, F2, F3.
F1_mean, F2_mean, F3_mean = (
    find_mean_img(class_images, class_title)
    for class_images, class_title in (
        (full_mat_F1, 'F1'),
        (full_mat_F2, 'F2'),
        (full_mat_F3, 'F3'),
    )
)

## Model Development

### Model Checkpoint

In [13]:
# Save the model to HDF5 only when the monitored metric improves (the
# callback's default monitor is val_loss); the submission cell reloads it.
checkpoint_path = "model.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_best_only=True,
)

### Learning Rate Scheduler

In [14]:
LR_START = 0.00005
LR_MAX = 0.00005 * strategy.num_replicas_in_sync  # scale the peak LR with the replica count
LR_MIN = 0.0000025
LR_RAMPUP_EPOCHS = 3
LR_SUSTAIN_EPOCHS = 6
LR_EXP_DECAY = .8


def scheduler_callback(epoch):
    """Return the learning rate for the 0-based ``epoch``.

    Schedule:
      * warm-up: linear from LR_START toward LR_MAX over LR_RAMPUP_EPOCHS epochs;
      * sustain: LR_MAX for the next LR_SUSTAIN_EPOCHS epochs;
      * decay: exponential by LR_EXP_DECAY per epoch, approaching LR_MIN.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        # Deterministic linear warm-up. The previous random draw in
        # [0, LR_START) made runs non-reproducible and never ramped up.
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX
    decay_epochs = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY ** decay_epochs + LR_MIN


scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler_callback, verbose=True)

### Early Stopping

In [15]:
# End training once the monitored metric (val_loss by default) has not
# improved for 10 consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    patience=10,
)

In [16]:
# Callbacks handed to model.fit: early stopping, best-model checkpointing,
# and the per-epoch learning-rate schedule.
callbacks = [
    early_stop,
    checkpoint,
    scheduler,
]

### Optimizer

In [17]:
# Adam with its hyperparameters spelled out. The initial learning rate is
# overwritten each epoch by the LearningRateScheduler in `callbacks`.
# NOTE(review): this optimizer is created outside strategy.scope(); Keras
# recommends creating it inside the scope under TPUStrategy. Confirm this
# works on the TensorFlow version in use.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
)

### Get Pretrained Model
The model is built on top of a pretrained Keras application. Many pretrained backbones are available, such as InceptionV3 and EfficientNetB0–B7; pick the one you want to fine-tune.

In [18]:
model_types = [
    "dense_net",
    "xception",
    "inception",
    "inceptionResNet",
]

model_type = model_types[2]  # "inception" -> InceptionV3


def get_pretraind_model(model_type, input_shape):
    """Return a pretrained Keras application backbone without its classifier top.

    Args:
        model_type: one of the names in ``model_types``.
        input_shape: input shape for the backbone, e.g. ``[*IMAGE_SIZE, 3]``.

    Returns:
        The backbone built with ``include_top=False``.

    Raises:
        ValueError: if ``model_type`` is not listed in ``model_types`` (an
            unknown name used to fall through and return ``None``).
    """
    if model_type not in model_types:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {model_types}"
        )
    if model_type == "dense_net":
        backbone = applications.densenet.DenseNet121
    elif model_type == "xception":
        backbone = applications.Xception
    elif model_type == "inception":
        backbone = applications.InceptionV3
    else:  # "inceptionResNet"
        backbone = applications.InceptionResNetV2
    return backbone(include_top=False, input_shape=input_shape)
   

### Train the Model

In [19]:
def train(
    model_type, epochs, optimizer, callbacks,
    strategy, layers, num_classes=104):
    """Build, compile and fit a classifier on a pretrained backbone.

    The model is ``backbone -> *layers -> Dense(num_classes, softmax)``, built
    and compiled inside ``strategy.scope()``. It is trained on the global
    ``training_dataset`` and validated on the global ``validation_dataset``,
    and the training history is plotted afterwards.

    Args:
        model_type: backbone name passed to ``get_pretraind_model``.
        epochs: number of epochs for ``model.fit``.
        optimizer: Keras optimizer used for compilation.
        callbacks: callbacks passed to ``model.fit``.
        strategy: ``tf.distribute`` strategy whose scope owns the model variables.
        layers: sequence of layers inserted between backbone and classifier.
        num_classes: size of the softmax output (104 flower classes by default).

    Returns:
        The fitted ``tf.keras.Sequential`` model, holding the weights from the
        last completed epoch.
    """
    tf.keras.backend.clear_session()
    with strategy.scope():
        input_shape = [*IMAGE_SIZE, 3]
        pretrained_model = get_pretraind_model(model_type, input_shape)
        # summary() prints itself and returns None; wrapping it in print()
        # used to emit a stray "None".
        pretrained_model.summary()
        pretrained_model.trainable = True  # fine-tune the whole backbone
        all_layers = [
            pretrained_model,
            *layers,
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ]
        model = tf.keras.Sequential(all_layers)
        model.compile(
            optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy']
        )
        history = model.fit(training_dataset,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            epochs=epochs,
                            validation_data=validation_dataset,
                            callbacks=callbacks
                           )
    pd.DataFrame(history.history).plot()
    plt.show()
    return model

In [20]:
# Head inserted between the backbone and the softmax classifier:
# dropout, spatial average pooling, dropout.
head_layers = [
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.5),
]
model = train(model_type, EPOCHS, optimizer, callbacks, strategy, layers=head_layers)

## Submission

In [21]:
# Images and ids are read in two separate passes over test_ds, so the
# dataset must be ordered for the passes to line up.
test_ds = get_test_dataset(ordered=True)
# Restore the best checkpoint rather than the last-epoch weights.
model.load_weights(checkpoint_path)
print('Computing predictions...')
test_images_ds = test_ds.map(lambda image, idnum: image)
probabilities = model.predict(test_images_ds)
predictions = np.argmax(probabilities, axis=-1)
print(predictions)

print('Generating submission.csv file...')
test_ids_ds = test_ds.map(lambda image, idnum: idnum)
# Concatenate every id batch instead of relying on the hard-coded
# NUM_TEST_IMAGES, so the ids cover exactly the images that were predicted.
test_ids = np.concatenate([ids.numpy() for ids in test_ids_ds]).astype('U')
if len(test_ids) != len(predictions):
    raise ValueError(
        f'Got {len(test_ids)} test ids but {len(predictions)} predictions'
    )
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')