Install the required packages

In [None]:
!pip install tensorflow
!pip install tensorflow_datasets
!pip install numpy
!pip install tensorflow-model-optimization
!pip install tf_keras

Import the required libraries

In [None]:
import os

# tensorflow-model-optimization (tfmot) only supports Keras 2 models. Since TF 2.16,
# `tf.keras` resolves to Keras 3 unless this variable is set BEFORE TensorFlow is
# first imported in the kernel; with it, `tf.keras` uses the installed `tf_keras`.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

Load the EMNIST 'byclass' dataset

In [None]:
# Fetch the EMNIST "byclass" train and test splits as (image, label) tuples,
# together with the dataset metadata object.
(train_ds, test_ds), info = tfds.load(
    name='emnist/byclass',
    with_info=True,       # also return the DatasetInfo metadata
    as_supervised=True,   # yield (image, label) tuples instead of feature dicts
    split=['train', 'test'],
)

# Label names from the dataset metadata (kept for reference; not used further below).
original_class_names = info.features['label'].names

In EMNIST byclass:<br>
0-9 are digits<br>
10-35 are uppercase letters A-Z<br>
36-61 are lowercase letters a-z<br>

Define the range of labels for uppercase letters

In [None]:
# Uppercase letters occupy byclass labels 10 ('A') through 35 ('Z'), inclusive.
LABELS_TO_KEEP_START_INDEX = 10
LABELS_TO_KEEP_END_INDEX = 35
# Size of the inclusive range: 35 + 1 - 10 == 26 letters.
NUM_CLASSES_TO_TRAIN = LABELS_TO_KEEP_END_INDEX + 1 - LABELS_TO_KEEP_START_INDEX

def filter_uppercase(_, label):
    """Dataset predicate keeping only uppercase-letter samples.

    The image argument is ignored. Returns a boolean tensor that is True where
    LABELS_TO_KEEP_START_INDEX <= label <= LABELS_TO_KEEP_END_INDEX.
    """
    at_or_above_start = tf.greater_equal(label, LABELS_TO_KEEP_START_INDEX)
    at_or_below_end = tf.less_equal(label, LABELS_TO_KEEP_END_INDEX)
    return tf.logical_and(at_or_above_start, at_or_below_end)

def remap_label_to_zero_indexed(image, label):
    """Shift a kept label down so 'A' maps to 0; the image passes through unchanged."""
    return image, label - LABELS_TO_KEEP_START_INDEX

def preprocess_image(image, label):
    """Re-orient an EMNIST image and scale its pixels to [0, 1] as float32.

    Rotating 90 degrees clockwise and then mirroring left-right is exactly a
    height/width transpose, so the two steps are done in one op here. The
    label is returned unchanged.
    """
    upright = tf.image.transpose(image)
    return tf.cast(upright, tf.float32) / 255.0, label

IMG_HEIGHT = 28
IMG_WIDTH = 28
BATCH_SIZE = 32
# Number of training samples held in the shuffle buffer.
SHUFFLE_BUFFER_SIZE = 10_000
SHUFFLE_SEED = 42  # fixed seed so the (per-epoch) shuffle order is reproducible


def make_uppercase_batches(dataset, shuffle=False):
    """Build a batched, prefetched pipeline containing only uppercase letters.

    Args:
        dataset: tf.data.Dataset of (image, label) pairs from EMNIST byclass.
        shuffle: if True, shuffle the samples (re-shuffled every epoch) before
            batching. Use it for the training split only, so evaluation order
            stays deterministic.

    Returns:
        tf.data.Dataset of (image_batch, label_batch), with labels remapped to
        0..NUM_CLASSES_TO_TRAIN-1 and images processed by `preprocess_image`.
    """
    pipeline = dataset.filter(filter_uppercase)
    if shuffle:
        # Shuffle before preprocessing, while images have not yet been cast to float32.
        pipeline = pipeline.shuffle(
            SHUFFLE_BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=True
        )
    pipeline = (
        pipeline
        .map(remap_label_to_zero_indexed, num_parallel_calls=tf.data.AUTOTUNE)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    )
    return pipeline.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_batches = make_uppercase_batches(train_ds, shuffle=True)
test_batches = make_uppercase_batches(test_ds)

Define the Base Keras Model (Non-Quantized)<br>
TODO: consider adding BatchNormalization layers for better quantization stability.<br>

In [None]:
base_model = tf.keras.Sequential()
# Input: a single-channel IMG_HEIGHT x IMG_WIDTH glyph image.
base_model.add(tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1)))
# Feature extractor: two 3x3 conv + 2x2 max-pool stages, with deliberately small filter counts.
base_model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu'))
base_model.add(tf.keras.layers.MaxPooling2D((2, 2)))
base_model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
base_model.add(tf.keras.layers.MaxPooling2D((2, 2)))
# Classifier head: compact hidden layer, dropout for regularization, softmax over the letter classes.
base_model.add(tf.keras.layers.Flatten())
base_model.add(tf.keras.layers.Dense(64, activation='relu'))
base_model.add(tf.keras.layers.Dropout(0.3))
base_model.add(tf.keras.layers.Dense(NUM_CLASSES_TO_TRAIN, activation='softmax'))

print("\n--- Base Keras Model Summary (Before QAT) ---")
base_model.summary()

Apply Quantization-Aware Training Wrappers<br>
This step wraps the layers of the base_model with quantization operations.<br>

In [None]:
# Wrap the layers of base_model with quantization operations (quantization-aware training).
# `quantize_model` is kept as a module-level alias of the tfmot Keras entry point.
# NOTE(review): tfmot's Keras API works with Keras 2 models; on TF >= 2.16 this presumably
# requires TF_USE_LEGACY_KERAS=1 to be set before TensorFlow is imported — confirm for the
# installed versions.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(base_model)

Compile the Quantization-Aware Model

In [None]:
# Labels are integer class indices and the model's last layer is a softmax,
# so use sparse categorical cross-entropy on probabilities (not logits).
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

print("\n--- Quantization-Aware Model Summary (During Training) ---")
q_aware_model.summary()

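Define an EarlyStopping Callback<br>
The callback monitors validation accuracy and, when training ends, restores the model weights from the best epoch.<br>
Note that Keras' `baseline` (0.976 here) is not a "stop once reached" target. The patience counter is only reset by an improvement of more than `min_delta` (0.001) that also beats the baseline. Training therefore stops after 20 consecutive epochs without such an improvement, or when the epoch limit is reached.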

In [None]:
# Early stopping on validation accuracy, restoring the best epoch's weights at the end.
# Per the Keras EarlyStopping docs, `baseline` is not a stop-when-reached target: the
# patience counter only resets on an improvement (> min_delta) that also beats the
# baseline, so training ends after `patience` epochs without such an improvement.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='max',          # higher accuracy is better
    min_delta=0.001,     # smaller gains do not count as an improvement
    patience=20,
    baseline=0.976,
    restore_best_weights=True,
    verbose=1,
)

Define per-class weights for problematic classes. The keys are the remapped label indices (0 = 'A' … 25 = 'Z'). If the model tends to favour certain characters over visually similar ones, lower their weight below 1.0.<br>
Conversely, if the model tends to ignore certain characters, raise their weight above 1.0 (up to 3.0).<br>
The weights are estimates based on manual testing on the actual device: an ESP32 with a Cirque trackpad.<br>

In [None]:
# Per-class loss weights for `fit`. Every class starts at 1.0; the letters listed
# below are overridden, keyed by their offset from 'A' (the remapped label index).
class_weight = dict.fromkeys(range(NUM_CLASSES_TO_TRAIN), 1.0)
class_weight.update(
    (ord(letter) - ord("A"), weight)
    for letter, weight in {
        "B": 0.8,
        "C": 3.0,
        "E": 2.6,
        "G": 1.5,
        "I": 3.0,
        "J": 2.8,
        "O": 2.0,
        "P": 1.7,
        # NOTE(review): 0.08 is an order of magnitude below every other
        # down-weight (B uses 0.8); confirm this is not a typo for 0.8.
        "Q": 0.08,
        "R": 1.5,
        "S": 2.0,
        "Y": 2.0,
        "Z": 2.2,
    }.items()
)

Train the Quantization-Aware Model with the Callback<br>
Training runs for at most 50 epochs and may stop earlier when the early-stopping condition is met.<br>
<br>
_Note: warnings such as "Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence" are expected with certain TensorFlow versions. They are harmless._

In [None]:
# Train the QAT model with early stopping and per-class weights.
# The banner reports the actual EarlyStopping settings: `baseline` is not a
# "stop when reached" threshold, so the old "Stopping at val_accuracy >= ..." text was wrong.
# NOTE(review): the test split doubles as validation data, so it drives early
# stopping and best-weight restoration; the accuracy later reported on it is
# therefore optimistic. Consider holding out a validation split from training data.
print(
    f"\n--- Training QAT Model (EarlyStopping on {early_stopping_callback.monitor}: "
    f"baseline={early_stopping_callback.baseline}, "
    f"patience={early_stopping_callback.patience}) ---"
)
history = q_aware_model.fit(
    train_batches,
    epochs=50,  # upper bound; EarlyStopping may end training sooner
    validation_data=test_batches,
    callbacks=[early_stopping_callback],
    class_weight=class_weight,
)

Evaluate the Quantization-Aware Model

In [None]:
# Keras-side evaluation: quantization is only simulated (fake-quant ops in a float graph).
loss, accuracy = q_aware_model.evaluate(test_batches)
print("\n--- Quantization-Aware Model Evaluation (Keras, Float Sim) ---")
print("Test Loss: {:.4f}".format(loss))
print("Test Accuracy: {:.4f}".format(accuracy))

Convert the Quantization-Aware Model to LiteRT (.tflite) with FLOAT16 weights<br>
Input and output are still float32 (0.0-1.0)<br>

In [None]:
# Build a LiteRT (TFLite) converter directly from the in-memory QAT Keras model.
converter_qat = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)

# Keep the model's input and output tensors in float32.
converter_qat.inference_input_type = tf.float32
converter_qat.inference_output_type = tf.float32

# Enable default optimizations and request float16 as the supported type.
# NOTE(review): with fake-quant ops from QAT present, the converter may emit int8
# quantized ops instead of float16 weights. Confirm what the output actually contains
# (e.g. tf.lite.experimental.Analyzer) before relying on the "float16" naming.
converter_qat.optimizations = [tf.lite.Optimize.DEFAULT]
converter_qat.target_spec.supported_types = [tf.float16]

tflite_model_qat_float16 = converter_qat.convert()

Save the Quantization-Aware TFLite model (Float16 weights)<br>

In [None]:
qat_float16_model_path = '.workdir/emnist_uppercase_qat_float16_model.tflite'

# open(..., 'wb') does not create missing directories, so make sure the target
# directory exists first (otherwise a fresh checkout fails with FileNotFoundError).
os.makedirs(os.path.dirname(qat_float16_model_path) or '.', exist_ok=True)
with open(qat_float16_model_path, 'wb') as f:
    f.write(tflite_model_qat_float16)

print(f"\nQuantization-Aware TFLite model (Float16 weights) saved to: {qat_float16_model_path}")

Evaluate the Float16 Quantized LiteRT Model

In [None]:
# Evaluate the converted model with the Python LiteRT interpreter, one image at a time.
interpreter = tf.lite.Interpreter(model_content=tflite_model_qat_float16)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_index = input_details[0]['index']
output_index = output_details[0]['index']

quant_test_accuracy = 0  # running count of correct predictions
num_test_samples = 0

# Iterate as numpy arrays so labels compare as plain ints rather than EagerTensors.
for images, labels in test_batches.as_numpy_iterator():
    for image, true_label in zip(images, labels):
        # Add a leading batch axis of 1: (H, W, C) -> (1, H, W, C).
        interpreter.set_tensor(input_index, image[np.newaxis, ...])
        interpreter.invoke()
        output = interpreter.get_tensor(output_index)

        predicted_label = int(np.argmax(output[0]))
        if predicted_label == int(true_label):
            quant_test_accuracy += 1
        num_test_samples += 1

if num_test_samples == 0:
    raise ValueError("test_batches yielded no samples; cannot compute TFLite accuracy")

final_quant_accuracy = quant_test_accuracy / num_test_samples
print("\n--- QAT TFLite Model Evaluation (Python Interpreter, Float16 Weights) ---")
print(f"Test Accuracy: {final_quant_accuracy:.4f}")

Save the Quantization-Aware TFLite model as a C array<br>

In [None]:
def tflite_to_c_header(model_bytes: bytes, array_name: str, bytes_per_line: int = 12) -> str:
    """Render a model as a C++ header that defines a byte array and its length.

    Args:
        model_bytes: the serialized .tflite flatbuffer.
        array_name: identifier for the array; `<array_name>_len` holds its size.
        bytes_per_line: how many hex bytes to emit per source line.

    Returns:
        Header text with an include guard. The array is declared `alignas(16)`
        (C++11) to match the alignment of TensorFlow Lite Micro's own generated
        model arrays.

    Raises:
        ValueError: if model_bytes is empty (a zero-length C array is invalid).
    """
    if not model_bytes:
        raise ValueError("model_bytes is empty; refusing to emit a zero-length C array")
    guard = f"{array_name.upper()}_H"
    rows = [
        "  " + ", ".join(f"0x{byte:02x}" for byte in model_bytes[start:start + bytes_per_line]) + ","
        for start in range(0, len(model_bytes), bytes_per_line)
    ]
    return "\n".join([
        f"#ifndef {guard}",
        f"#define {guard}",
        "",
        f"alignas(16) const unsigned char {array_name}[] = {{",
        *rows,
        "};",
        f"const unsigned int {array_name}_len = {len(model_bytes)};",
        "",
        f"#endif  // {guard}",
        "",
    ])


# Read back the model that was written to disk, so the header matches the saved file.
with open(qat_float16_model_path, 'rb') as f:
    tflite_model = f.read()

c_array_name = 'emnist_uppercase_model_qat_float16'
c_array_code = tflite_to_c_header(tflite_model, c_array_name)

output_file = '../include/emnist_uppercase_model_qat_float16.h'
# Create the include directory if it is missing instead of failing on open().
os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
with open(output_file, 'w') as f:
    f.write(c_array_code)

print(f"QAT Float16 model converted to C array and saved as {output_file}")