# Train the model used to recognise images

In [3]:
import numpy as np
import importlib

## Process data

- Raw data (not included in this repository) is sourced from the Google Quick Draw project: https://github.com/googlecreativelab/quickdraw-dataset
- In order to successfully recognise partial images, the model is trained on each image at each stage of its drawing. For example, if a drawing has three strokes, it will appear in the dataset in three versions.
- Each drawing (at each stage) is trimmed to remove whitespace, scaled to fill the canvas, and then normalised to a 0-1 range.

In [4]:
import process
importlib.reload(process)  # re-execute the module so local edits to process.py take effect

# --- Configuration ---
data_folder = 'data/'
MAX_SEQ_LEN = 200
TRAIN_SAMPLE_SIZE = 4096     # number per category
# NOTE(review): the recorded X_test shape has 1024 rows in total, so this
# value does not appear to be applied per category -- confirm in process.py.
TEST_SAMPLE_SIZE = 1024

# --- Categories ---
CATEGORIES = process.get_categories_from_data(data_folder)
NUM_CATEGORIES = len(CATEGORIES)
# Map every category to its integer index (its position in CATEGORIES).
category_ids = {name: index for index, name in enumerate(CATEGORIES)}

# --- Train / test split ---
X_train, Y_train, X_test, Y_test = process.get_train_test_data(
    data_folder,
    CATEGORIES,
    TRAIN_SAMPLE_SIZE,
    TEST_SAMPLE_SIZE,
)

In [5]:
# Sanity check: show the shape of each array, in the order
# X_train, Y_train, X_test, Y_test (one shape per line).
for dataset in (X_train, Y_train, X_test, Y_test):
    print(dataset.shape)

(1413120, 200, 3)
(1413120,)
(1024, 200, 3)
(1024,)


## Training

In [6]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

import classifier
importlib.reload(classifier)  # re-execute the module so local edits to classifier.py take effect

# Shape of a single sample: MAX_SEQ_LEN steps with 3 values per step.
# This matches the trailing dimensions of X_train printed above.
input_shape = MAX_SEQ_LEN, 3

# Build the classifier wrapper; the underlying Keras model is exposed as `rnn.model`.
rnn = classifier.RNN(input_shape, NUM_CATEGORIES)

In [7]:
# Training parameters
epochs = 10
batch_size = 64

# Keep only the weights from the epoch with the best validation accuracy.
# Without save_best_only=True the checkpoint file is overwritten at the end of
# every epoch, so 'best_model_weights.h5' would end up holding the weights of
# the *last* epoch rather than the best one.
# NOTE(review): monitoring 'val_accuracy' assumes the model built in
# classifier.RNN is compiled with an 'accuracy' metric -- confirm. If the
# metric is missing, Keras emits a warning and skips saving.
checkpoint_filepath = 'best_model_weights.h5'
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)

# Train the model. The test split doubles as the validation set, so the
# checkpoint above is selected on X_test / Y_test.
history = rnn.model.fit(
    X_train,             # Training data
    Y_train,             # Training labels
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(X_test, Y_test),  # Validation data and labels
    shuffle=True,
    callbacks=[model_checkpoint_callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
