# Trainer

Using the DeiT small model here (DeiT Small Distilled, patch size 16, image size 224 x 224) in the interest of time and space for deployment.

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub

from tensorflow import keras
from keras.applications import imagenet_utils

from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
# --- Paths -------------------------------------------------------------
# Data directory; not referenced by the cells visible in this notebook.
BASE_PATH = '../data'

# --- Dataset / input geometry -------------------------------------------
# Not referenced by the cells visible in this notebook.
DATASET_SIZE = 9367
# Side length (pixels) of the square images fed to the model.
IMAGE_SIZE = 224

# --- Training loop settings ----------------------------------------------
EPOCHS = 10
BATCH_SIZE = 8
WORKERS = 4

# Target labels, listed alphabetically.
classes = ['cup', 'fork', 'glass', 'knife', 'plate', 'spoon']

First, we will load the training dataframe and split it into training and validation sets.

In [3]:
# Read the label table, keeping `Id` as text so values such as '0560' retain
# their leading zeros, then derive each image's path from its Id.
df_train_full = pd.read_csv('data/train.csv', dtype={'Id': str}).assign(
    filename=lambda frame: 'data/images/' + frame['Id'] + '.jpg'
)
df_train_full.head()

Unnamed: 0,Id,label,filename
0,560,glass,data/images/0560.jpg
1,4675,cup,data/images/4675.jpg
2,875,glass,data/images/0875.jpg
3,4436,spoon,data/images/4436.jpg
4,8265,plate,data/images/8265.jpg


In [4]:
# Positional 80/20 split: the first 80% of rows train the model, the
# remaining rows are held out for validation (row order is left as loaded).
val_cutoff = int(len(df_train_full) * 0.8)
df_train, df_val = (
    df_train_full.iloc[:val_cutoff],
    df_train_full.iloc[val_cutoff:],
)

## Training

Now let's create image generators

In [5]:
def preprocess_input(x, data_format=None):
    """Run Keras' ImageNet preprocessing in "tf" mode on an image array.

    Per the original author's note, the pre-trained models used in this
    notebook do not bundle ImageNet preprocessing, so it is applied here
    (the image generators call this function on every image).

    Args:
        x: Image data, forwarded unchanged to
            `imagenet_utils.preprocess_input`.
        data_format: Optional data-format argument, forwarded as-is.

    Returns:
        The value returned by `imagenet_utils.preprocess_input` for
        `mode="tf"`.
    """
    preprocessed = imagenet_utils.preprocess_input(
        x, mode="tf", data_format=data_format
    )
    return preprocessed

In [6]:
# Settings shared by the training and validation pipelines, kept in one place
# so the two cannot drift apart.
datagen_kwargs = dict(
    preprocessing_function=preprocess_input,
    dtype="float16",
)
flow_kwargs = dict(
    x_col='filename',
    y_col='label',
    # Pin the label order explicitly so both iterators use the same
    # label -> one-hot index mapping, even if a split happens to miss a class.
    # list(...) keeps this cell re-runnable after `classes` has been turned
    # into a NumPy array further down.
    classes=list(classes),
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
)

# Training batches are reshuffled every epoch.
train_datagen = ImageDataGenerator(**datagen_kwargs)
train_generator = train_datagen.flow_from_dataframe(
    df_train,
    shuffle=True,
    **flow_kwargs,
)

# Validation batches keep the dataframe's row order, so per-sample outputs
# (e.g. from model.predict) line up with df_val. Shuffling has no effect on
# the validation metrics themselves.
val_datagen = ImageDataGenerator(**datagen_kwargs)
val_generator = val_datagen.flow_from_dataframe(
    df_val,
    shuffle=False,
    **flow_kwargs,
)

Found 4447 validated image filenames belonging to 6 classes.
Found 1112 validated image filenames belonging to 6 classes.


In [7]:
# Rebind `classes` to a NumPy array holding the label names that key the
# training iterator's class_indices dict (iterating a dict yields its keys).
classes = np.array(list(train_generator.class_indices))
classes

array(['cup', 'fork', 'glass', 'knife', 'plate', 'spoon'], dtype='<U5')

In [8]:
# Stop training once validation accuracy has gone three epochs without
# improving by at least 1e-4, and restore the best weights seen so far.
earlystopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=1e-4,
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# Save the full model (not only its weights) to an .h5 file each time
# validation accuracy reaches a new best; the epoch number and the score are
# embedded in the file name.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='deit_s_d_p16_224_{epoch:02d}_{val_accuracy:.3f}.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
)

callbacks = [earlystopping, checkpoint]

In [9]:
def get_model_deit(model_url, res=IMAGE_SIZE, num_classes=len(classes)) -> tf.keras.Model:
    """Assemble a classifier: a frozen hub backbone followed by a softmax layer.

    Args:
        model_url: Handle passed to `hub.KerasLayer` to load the backbone.
        res: Height and width of the square, 3-channel model input.
        num_classes: Number of units in the final softmax layer. The default
            is computed once, from `classes` as it exists when this function
            is defined.

    Returns:
        An uncompiled `tf.keras.Model` from the image input to the softmax
        output.
    """
    image_input = tf.keras.Input((res, res, 3))

    # trainable=False keeps the pre-trained backbone weights fixed.
    backbone = hub.KerasLayer(model_url, trainable=False)

    # The hub layer returns a 2-tuple. Per the original author's note, the
    # second element is a dictionary of attention scores; it is not used here.
    features, _ = backbone(image_input)

    class_probs = keras.layers.Dense(num_classes, activation="softmax")(features)
    return tf.keras.Model(image_input, class_probs)

Warnings are normal; the pre-trained weights for the original classification heads are being skipped.

In [10]:
def build_model(
    model_url="https://tfhub.dev/sayakpaul/deit_small_distilled_patch16_224_fe/1",
    learning_rate=1e-2,
):
    """Create and compile the DeiT-based kitchenware classifier.

    The default handle points at the *small* distilled DeiT feature
    extractor (patch 16, 224 x 224), matching the model named in the
    notebook introduction and in the checkpoint file name
    (`deit_s_d_p16_224_...`). It is fetched over HTTPS.

    Args:
        model_url: TF Hub handle of the feature extractor, forwarded to
            `get_model_deit`.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        The model compiled with Adam, categorical cross-entropy loss and an
        accuracy metric.
    """
    model = get_model_deit(model_url)

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

In [11]:
# Build and compile the classifier; the backbone is loaded through TF Hub
# inside build_model() -> get_model_deit().
model = build_model()



In [12]:
# The generators already yield batches of BATCH_SIZE images, so `batch_size`
# is intentionally not passed to fit(): Keras documents that it must not be
# specified for generator / Sequence inputs.
history = model.fit(
    x=train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    workers=WORKERS,
    callbacks=callbacks,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 5: early stopping


In [13]:
# Save the trained model through BentoML's TensorFlow integration under the
# name "kitchenware-classification".
# NOTE(review): needs the `bentoml` package in the kernel environment; the
# recorded output of this cell is a ModuleNotFoundError, so install it before
# re-running.
import bentoml
bentoml.tensorflow.save_model("kitchenware-classification", model)

ModuleNotFoundError: No module named 'bentoml'