# Augmentation testing

The best model was BeitLargePatch16 with validation accuracy 0.9856114983558655.
- In the interest of time, we test with the DeiT base model here (small and fast).

Here we will experiment with image augmentation to see if this can be improved. If not, the extra complexity is not worth it.

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

import tensorflow_hub as hub
import tensorflow_addons as tfa

from tensorflow import keras
from keras.applications import imagenet_utils

from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
# --- Data ---
# NOTE(review): DATASET_SIZE and BASE_PATH are not referenced by the cells
# visible in this notebook (paths are written out as 'data/...'); confirm
# before relying on them.
DATASET_SIZE = 9367
BASE_PATH = '../data'

# --- Input pipeline / training ---
IMAGE_SIZE = 224   # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 8
WORKERS = 4        # passed to model.fit(workers=...)
EPOCHS = 10        # upper bound on epochs per run

# Target labels, in alphabetical order.
classes = ['cup', 'fork', 'glass', 'knife', 'plate', 'spoon']

First, we will load the training dataframe and split it into training and validation sets.

In [3]:
# Read Id as a string so zero-padded ids (e.g. "0560") keep their leading
# zeros; the image path is derived directly from that id.
df_train_full = pd.read_csv('data/train.csv', dtype={'Id': str})
image_paths = 'data/images/' + df_train_full['Id'] + '.jpg'
df_train_full['filename'] = image_paths
df_train_full.head()

Unnamed: 0,Id,label,filename
0,560,glass,data/images/0560.jpg
1,4675,cup,data/images/4675.jpg
2,875,glass,data/images/0875.jpg
3,4436,spoon,data/images/4436.jpg
4,8265,plate,data/images/8265.jpg


In [4]:
# Positional 80/20 split: the first 80% of rows are used for training and the
# remaining rows for validation (rows are taken in file order, no shuffling).
val_cutoff = int(len(df_train_full) * 0.8)
df_train, df_val = df_train_full.iloc[:val_cutoff], df_train_full.iloc[val_cutoff:]

## Baseline model

Now let's create image generators

In [5]:
# The hub models used here do not ship with ImageNet preprocessing built in,
# so it has to be applied in the input pipeline.
def preprocess_input(x, data_format=None):
    """Apply Keras' ImageNet preprocessing in ``"tf"`` mode.

    Thin wrapper around ``imagenet_utils.preprocess_input`` that pins
    ``mode="tf"`` (per the Keras documentation this scales pixel values to
    the range [-1, 1]) so it can be used as a ``preprocessing_function``.

    Args:
        x: Image data to preprocess.
        data_format: Forwarded unchanged to ``imagenet_utils``; ``None``
            lets Keras use its configured default.

    Returns:
        The preprocessed image data.
    """
    return imagenet_utils.preprocess_input(x, data_format=data_format, mode="tf")

In [6]:
# Baseline pipeline: no augmentation, only the model's input preprocessing.
# Training and validation use identically configured generators.
train_datagen = ImageDataGenerator(dtype="float16", preprocessing_function=preprocess_input)
train_generator = train_datagen.flow_from_dataframe(
    df_train,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    x_col='filename',
    y_col='label',
)

val_datagen = ImageDataGenerator(dtype="float16", preprocessing_function=preprocess_input)
val_generator = val_datagen.flow_from_dataframe(
    df_val,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    x_col='filename',
    y_col='label',
)

Found 4447 validated image filenames belonging to 6 classes.
Found 1112 validated image filenames belonging to 6 classes.


In [7]:
# Take the label names from the generator's class_indices mapping (in the
# mapping's own key order) so they match what the generator actually uses.
classes = np.array(list(train_generator.class_indices))
classes

array(['cup', 'fork', 'glass', 'knife', 'plate', 'spoon'], dtype='<U5')

In [8]:
# Stop a run once val_accuracy has failed to improve by at least 1e-4 for
# 3 epochs in a row, and restore the weights from the best epoch.
earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=1e-4,
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# Callback list shared by every model.fit call in this notebook.
callbacks = [earlystopping]

In [9]:
def get_model_deit(model_url, res=IMAGE_SIZE, num_classes=len(classes)) -> tf.keras.Model:
    """Build a classifier: a frozen TF Hub backbone plus a softmax head.

    Args:
        model_url: Handle passed to ``hub.KerasLayer`` to load the backbone.
        res: Side length of the square RGB input; the model input shape is
            ``(res, res, 3)``.
        num_classes: Number of units in the final softmax layer. Note that
            the default is evaluated once, when this function is defined.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class
        probabilities. Only the dense head is trainable, since the hub
        layer is created with ``trainable=False``.
    """
    image_input = tf.keras.Input((res, res, 3))
    backbone = hub.KerasLayer(model_url, trainable=False)

    # The hub layer returns a tuple; the second element is a dictionary
    # containing attention scores, which is not needed here.
    features, _ = backbone(image_input)
    class_probs = keras.layers.Dense(num_classes, activation="softmax")(features)

    return tf.keras.Model(image_input, class_probs)

Warnings are normal; the pre-trained weights for the original classification heads are not being skipped.

In [10]:
# TF Hub handle of the DeiT (base, distilled, patch 16, 224px) feature
# extractor. The handle uses https:// so the pre-trained weights are not
# downloaded over an unencrypted connection.
model_gcs_path = "https://tfhub.dev/sayakpaul/deit_base_distilled_patch16_224_fe/1"

# Build the baseline model: frozen backbone + freshly initialised softmax head.
model = get_model_deit(model_gcs_path)



In [11]:
# Training configuration for the baseline run: Adam at a fixed learning rate
# with categorical cross-entropy, tracking accuracy.
loss = keras.losses.CategoricalCrossentropy()
learning_rate = 0.01
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])

In [12]:
# Baseline run (no augmentation). EPOCHS is only an upper bound: the
# early-stopping callback may end training sooner.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks,
    workers=WORKERS,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 6: early stopping


## Augmented Model

In [13]:
# Augmentation experiment: rotations of up to 90 degrees plus horizontal and
# vertical flips. Shift, shear and zoom options (0.1 each) are left disabled.
train_datagen = ImageDataGenerator(
    dtype="float16",
    preprocessing_function=preprocess_input,
    rotation_range=90,
    horizontal_flip=True,
    vertical_flip=True,
)

# A fixed seed is passed to the generator for this experiment.
train_generator = train_datagen.flow_from_dataframe(
    df_train,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    x_col='filename',
    y_col='label',
    seed=1,
)

Found 4447 validated image filenames belonging to 6 classes.


In [14]:
# Rebuild the model before training on augmented data. Re-compiling the
# existing `model` would keep the weights already trained in the baseline run,
# so this experiment would continue training instead of starting from the
# same point as the other runs.
model = get_model_deit(model_gcs_path)

# Use a fresh optimizer as well, so no optimizer state (step count, moment
# estimates) from the previous run carries over.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [15]:
# Train on the rotation/flip-augmented generator; validation data is the
# same un-augmented generator used for the baseline.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks,
    workers=WORKERS,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 5: early stopping


In [None]:
# Augmentation experiment: width/height shifts of up to 25% plus horizontal
# flips. Rotation (10), shear (0.25) and zoom (0.1) are left disabled.
train_datagen = ImageDataGenerator(
    dtype="float16",
    preprocessing_function=preprocess_input,
    width_shift_range=0.25,
    height_shift_range=0.25,
    horizontal_flip=True,
)

train_generator = train_datagen.flow_from_dataframe(
    df_train,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    x_col='filename',
    y_col='label',
)

Found 4447 validated image filenames belonging to 6 classes.


In [None]:
# Start this experiment from a freshly built model.
model = get_model_deit(model_gcs_path)

# Pair the new model with a new optimizer: reusing the Adam instance from an
# earlier run would carry its internal state (step count, moment estimates
# tied to the previous model's variables) into this run.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [None]:
# Train on the shift/flip-augmented generator; validation data is the same
# un-augmented generator used for the baseline.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks,
    workers=WORKERS,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 5: early stopping


In [15]:
# Augmentation experiment: rotations of up to 45 degrees plus horizontal
# flips. Shift (0.1), shear (0.25) and zoom (0.2) are left disabled.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=45,
    horizontal_flip=True,
    # Same dtype as every other generator in this notebook (including the
    # validation generator), so augmentation is the only thing that differs
    # between runs.
    dtype="float16",
)

train_generator = train_datagen.flow_from_dataframe(
    df_train,
    x_col='filename',
    y_col='label',
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
)

Found 4447 validated image filenames belonging to 6 classes.


In [16]:
# Start this experiment from a freshly built model.
model = get_model_deit(model_gcs_path)

# Pair the new model with a new optimizer: reusing the Adam instance from an
# earlier run would carry its internal state (step count, moment estimates
# tied to the previous model's variables) into this run.
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [17]:
# Train on the 45-degree-rotation/flip-augmented generator; validation data
# is the same un-augmented generator used for the baseline.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=callbacks,
    workers=WORKERS,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 6: early stopping


It seems there is very little to be gained from image augmentation for this dataset. We will forgo augmentation, as the additional complexity does not bring a significant benefit.