Install the required packages

In [1]:
!pip install tensorflow
!pip install tensorflow_datasets
!pip install numpy
!pip install tensorflow-model-optimization
!pip install tf_keras


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip ins

Import the required libraries

In [2]:
import os

# tensorflow-model-optimization is built on Keras 2 (the `tf_keras` package
# installed above). From TensorFlow 2.16 on, `tf.keras` resolves to Keras 3
# unless this variable is set *before* TensorFlow is imported, and
# `quantize_model` then rejects the model. `setdefault` keeps any value that
# was already exported in the environment.
os.environ.setdefault("TF_USE_LEGACY_KERAS", "1")

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import tensorflow_model_optimization as tfmot

Load the EMNIST 'byclass' dataset

In [3]:
# Fetch both splits as (image, label) pairs together with the dataset metadata.
_emnist_splits, info = tfds.load(
    name='emnist/byclass',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
)
train_ds, test_ds = _emnist_splits

# Class names of the label feature as reported by the dataset metadata.
original_class_names = info.features['label'].names

In EMNIST byclass:<br>
0-9 are digits<br>
10-35 are uppercase letters A-Z<br>
36-61 are lowercase letters a-z<br>

Define the range of labels for uppercase letters

In [4]:
# Inclusive range of EMNIST 'byclass' label ids that is kept for training.
LABELS_TO_KEEP_START_INDEX = 10  # label id of 'A'
LABELS_TO_KEEP_END_INDEX = 35    # label id of 'Z'

# Both ends are inclusive, so the range covers 26 classes.
NUM_CLASSES_TO_TRAIN = len(range(LABELS_TO_KEEP_START_INDEX, LABELS_TO_KEEP_END_INDEX + 1))

def filter_uppercase(_, label):
    """Dataset predicate: keep only samples whose label is in the kept range.

    Args:
        _: The image of the sample (ignored).
        label: Label tensor of the sample.

    Returns:
        A boolean tensor that is True when
        LABELS_TO_KEEP_START_INDEX <= label <= LABELS_TO_KEEP_END_INDEX.
    """
    not_below_start = tf.greater_equal(label, LABELS_TO_KEEP_START_INDEX)
    not_above_end = tf.less_equal(label, LABELS_TO_KEEP_END_INDEX)
    return tf.logical_and(not_below_start, not_above_end)

def remap_label_to_zero_indexed(image, label):
    """Shift a kept label so that the first kept class becomes class 0.

    Args:
        image: Image of the sample, passed through unchanged.
        label: Original label of the sample.

    Returns:
        The tuple (image, label - LABELS_TO_KEEP_START_INDEX).
    """
    return image, label - LABELS_TO_KEEP_START_INDEX

def preprocess_image(image, label):
    """Re-orient an image and scale its pixel values by 1/255.

    The rotation with k=-1 followed by a left-right flip swaps the height and
    width axes of the image (a transpose) - presumably to undo the orientation
    in which EMNIST samples are stored; verify against a plotted sample.

    Args:
        image: Image tensor of a single sample.
        label: Label of the sample, passed through unchanged.

    Returns:
        The tuple (float32 image divided by 255.0, label).
    """
    reoriented = tf.image.flip_left_right(tf.image.rot90(image, k=-1))
    scaled = tf.cast(reoriented, tf.float32) / 255.0
    return scaled, label

IMG_HEIGHT = 28
IMG_WIDTH = 28
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 10_000  # number of samples held in the shuffle buffer


def make_batches(dataset, training=False):
    """Build the batched input pipeline shared by the train and test splits.

    Steps: keep uppercase samples only, remap labels to start at 0,
    preprocess the images, optionally shuffle, then batch and prefetch.

    Args:
        dataset: A tf.data.Dataset yielding (image, label) pairs.
        training: When True the samples are shuffled (and reshuffled on every
            epoch) before batching, so that consecutive epochs do not see the
            exact same batches. Evaluation data keeps its original order.

    Returns:
        A tf.data.Dataset yielding batches of BATCH_SIZE (image, label) pairs.
    """
    dataset = (
        dataset
        .filter(filter_uppercase)
        .map(remap_label_to_zero_indexed, num_parallel_calls=tf.data.AUTOTUNE)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    )
    if training:
        dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_batches = make_batches(train_ds, training=True)
test_batches = make_batches(test_ds)

Define the Base Keras Model (Non-Quantized)<br>
Perhaps add BatchNormalization for better quantization stability?<br>

In [5]:
# Small CNN: two conv/pool stages followed by a compact dense classifier.
base_model = tf.keras.Sequential()
base_model.add(tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 1)))

# Feature extractor (reduced filter counts).
base_model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu'))
base_model.add(tf.keras.layers.MaxPooling2D((2, 2)))
base_model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
base_model.add(tf.keras.layers.MaxPooling2D((2, 2)))

# Classifier head (reduced dense size, dropout kept for regularization).
base_model.add(tf.keras.layers.Flatten())
base_model.add(tf.keras.layers.Dense(64, activation='relu'))
base_model.add(tf.keras.layers.Dropout(0.3))
# One softmax output per kept class.
base_model.add(tf.keras.layers.Dense(NUM_CLASSES_TO_TRAIN, activation='softmax'))

print("\n--- Base Keras Model Summary (Before QAT) ---")
base_model.summary()


--- Base Keras Model Summary (Before QAT) ---
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 16)        160       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 16)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 32)        4640      
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 32)          0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 800)               0         
                                                                 
 dense (D

Apply Quantization-Aware Training Wrappers<br>
This step wraps the layers of the base_model with quantization operations.<br>

In [6]:
# Short alias for the tfmot helper that wraps a Keras model for
# quantization-aware training.
quantize_model = tfmot.quantization.keras.quantize_model
# Wrapped copy of base_model; this is the model that gets compiled and trained below.
q_aware_model = quantize_model(base_model)

Compile the Quantization-Aware Model

In [7]:
# Labels are integer class ids, hence the sparse variant of the loss. The
# model already ends in a softmax layer, so the loss is left at its default
# of expecting probabilities.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

print("\n--- Quantization-Aware Model Summary (During Training) ---")
q_aware_model.summary()


--- Quantization-Aware Model Summary (During Training) ---
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLa  (None, 28, 28, 1)         3         
 yer)                                                            
                                                                 
 quant_conv2d (QuantizeWrap  (None, 26, 26, 16)        195       
 perV2)                                                          
                                                                 
 quant_max_pooling2d (Quant  (None, 13, 13, 16)        1         
 izeWrapperV2)                                                   
                                                                 
 quant_conv2d_1 (QuantizeWr  (None, 11, 11, 32)        4707      
 apperV2)                                                        
                                                              

Define an EarlyStopping Callback<br>
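Stop training when validation accuracy has not improved by at least 0.001 for 20 consecutive epochs. Note that `baseline=0.965` does not stop training once that accuracy is reached: it only requires the model to beat 0.965, and epochs that fail to do so count towards the patience limit.<br>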
Restore model weights from the best epoch.

In [8]:
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    # What is watched, and in which direction it should move.
    monitor='val_accuracy',
    mode='max',
    # A change only counts as an improvement if it exceeds min_delta.
    min_delta=0.001,
    # Number of epochs without improvement that are tolerated before stopping.
    patience=20,
    # The monitored value must improve over this baseline; reaching the
    # baseline does not by itself end training.
    baseline=0.965,
    # After stopping, roll the model back to the weights of the best epoch.
    restore_best_weights=True,
    verbose=1,
)

Define per-class loss weights to rebalance problematic classes: higher for under-recognised letters, lower for over-predicted ones (e.g., "O"=14, "Q"=16)<br>
The different weights are estimations based on manual testing on the actual device - an ESP32 with a Cirque trackpad.<br>

In [9]:
# Letters whose loss weight deviates from the default of 1.0.
_weight_overrides = {
    "B": 0.8,
    "C": 3.0,
    "E": 2.0,
    "G": 1.5,
    "I": 3.0,
    "J": 2.8,
    "O": 2.0,
    "P": 1.7,
    "Q": 0.05,
    "R": 1.5,
    "S": 2.0,
    "Y": 2.0,
    "Z": 2.0,
}

# Start with a neutral weight for every class, then apply the overrides.
# Class ids are zero-based letter positions ("A" -> 0, ..., "Z" -> 25).
class_weight = dict.fromkeys(range(NUM_CLASSES_TO_TRAIN), 1.0)
class_weight.update(
    (ord(letter) - ord("A"), weight) for letter, weight in _weight_overrides.items()
)

Train the Quantization-Aware Model with the Callback<br>
Training will run for 50 epochs or until the early stopping condition is met.<br>

In [10]:
# The banner reports the actual early-stopping configuration. EarlyStopping
# does not stop as soon as val_accuracy reaches the baseline; it stops after
# `patience` epochs without sufficient improvement.
print(
    "\n--- Training QAT Model "
    f"(early stopping: patience={early_stopping_callback.patience}, "
    f"baseline={early_stopping_callback.baseline}) ---"
)

# NOTE(review): test_batches doubles as validation data, so early stopping and
# best-weight restoration are tuned on the same data that is later reported as
# test accuracy. Consider holding out a separate validation split.
history = q_aware_model.fit(
    train_batches,
    epochs=50,
    validation_data=test_batches,
    callbacks=[early_stopping_callback],
    class_weight=class_weight
)


--- Training QAT Model (Stopping at val_accuracy >= 0.965) ---
Epoch 1/50
     41/Unknown - 0s 4ms/step - loss: 4.5025 - accuracy: 0.1799

2025-05-23 16:21:17.633490: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified `buffer_size` of 8388608


   5898/Unknown - 24s 4ms/step - loss: 0.5580 - accuracy: 0.8713

2025-05-23 16:21:41.665685: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 2/50
  38/5905 [..............................] - ETA: 24s - loss: 0.3702 - accuracy: 0.9054

2025-05-23 16:21:44.989527: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 42: early stopping
Restoring model weights from the end of the best epoch: 22.


Evaluate the Quantization-Aware Model

In [11]:
# evaluate() yields the compiled loss followed by the compiled metric (accuracy).
loss, accuracy = q_aware_model.evaluate(test_batches)

print(
    "\n--- Quantization-Aware Model Evaluation (Keras, Float Sim) ---",
    f"Test Loss: {loss:.4f}",
    f"Test Accuracy: {accuracy:.4f}",
    sep="\n",
)


--- Quantization-Aware Model Evaluation (Keras, Float Sim) ---
Test Loss: 0.1020
Test Accuracy: 0.9768


Convert the Quantization-Aware Model to LiteRT (.tflite) with FLOAT16 weights<br>
Input and output are still float32 (0.0-1.0)<br>

In [12]:
converter_qat = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)

# The model keeps a float32 interface on both ends.
converter_qat.inference_input_type = tf.float32
converter_qat.inference_output_type = tf.float32

# Default optimizations restricted to float16 as the supported type.
converter_qat.optimizations = [tf.lite.Optimize.DEFAULT]
converter_qat.target_spec.supported_types = [tf.float16]

# Serialized model as a bytes object.
tflite_model_qat_float16 = converter_qat.convert()

INFO:tensorflow:Assets written to: /var/folders/c2/tzr4yl1d2_d2zcw6f3x5h_480000gn/T/tmpdt2qkg8d/assets


INFO:tensorflow:Assets written to: /var/folders/c2/tzr4yl1d2_d2zcw6f3x5h_480000gn/T/tmpdt2qkg8d/assets
W0000 00:00:1748011242.522097 3155413 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1748011242.522105 3155413 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-05-23 16:40:42.522245: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/c2/tzr4yl1d2_d2zcw6f3x5h_480000gn/T/tmpdt2qkg8d
2025-05-23 16:40:42.523435: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-05-23 16:40:42.523439: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /var/folders/c2/tzr4yl1d2_d2zcw6f3x5h_480000gn/T/tmpdt2qkg8d
I0000 00:00:1748011242.529023 3155413 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-05-23 16:40:42.529753: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-05-23 16:40:42.557799: I tenso

Save the Quantization-Aware TFLite model (Float16 weights)<br>

In [13]:
import os

qat_float16_model_path = '.workdir/emnist_uppercase_qat_float16_model.tflite'

# open(..., 'wb') does not create missing directories, so make sure the
# target folder exists on a fresh checkout before writing.
os.makedirs(os.path.dirname(qat_float16_model_path), exist_ok=True)
with open(qat_float16_model_path, 'wb') as f:
    f.write(tflite_model_qat_float16)

print(f"\nQuantization-Aware TFLite model (Float16 weights) saved to: {qat_float16_model_path}")


Quantization-Aware TFLite model (Float16 weights) saved to: .workdir/emnist_uppercase_qat_float16_model.tflite


Evaluate the Float16 Quantized LiteRT Model

In [14]:
# Run the converted model sample by sample and measure its accuracy.
interpreter = tf.lite.Interpreter(model_content=tflite_model_qat_float16)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Tensor indices are constant, so look them up once outside the loop.
input_index = input_details[0]['index']
output_index = output_details[0]['index']

quant_test_accuracy = 0  # running count of correct predictions
num_test_samples = 0

# Iterate over NumPy batches so the interpreter receives ndarrays and no
# eager tensor is sliced for every single sample.
for images, labels in test_batches.as_numpy_iterator():
    for image, true_label in zip(images, labels):
        # The interpreter input has a leading batch dimension of size 1.
        interpreter.set_tensor(input_index, image[np.newaxis, ...])
        interpreter.invoke()
        output = interpreter.get_tensor(output_index)

        predicted_label = np.argmax(output[0])
        quant_test_accuracy += int(predicted_label == true_label)
        num_test_samples += 1

final_quant_accuracy = quant_test_accuracy / num_test_samples
print(f"\n--- QAT TFLite Model Evaluation (Python Interpreter, Float16 Weights) ---")
print(f"Test Accuracy: {final_quant_accuracy:.4f}")

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.



--- QAT TFLite Model Evaluation (Python Interpreter, Float16 Weights) ---
Test Accuracy: 0.9767


2025-05-23 16:40:47.423475: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Save the Quantization-Aware TFLite model as a C array<br>

In [15]:
# Read the saved model back so this cell only depends on the file on disk.
with open(qat_float16_model_path, 'rb') as f:
    tflite_model = f.read()

c_array_name = 'emnist_uppercase_model_qat_float16'
BYTES_PER_LINE = 12

# An include guard keeps the array from being redefined when the header is
# included more than once in the same translation unit.
include_guard = f"{c_array_name.upper()}_H"

# Emit a fixed number of bytes per row instead of one very long line.
hex_rows = [
    '  ' + ', '.join(f'0x{byte:02x}' for byte in tflite_model[start:start + BYTES_PER_LINE])
    for start in range(0, len(tflite_model), BYTES_PER_LINE)
]

header_lines = [
    f"#ifndef {include_guard}",
    f"#define {include_guard}",
    "",
    f"const unsigned char {c_array_name}[] = {{",
    ',\n'.join(hex_rows),
    "};",
    "",
    f"#endif  // {include_guard}",
]
c_array_code = '\n'.join(header_lines) + '\n'

output_file = '../include/emnist_uppercase_model_qat_float16.h'
with open(output_file, 'w') as f:
    f.write(c_array_code)

print(f"QAT Float16 model converted to C array and saved as {output_file}")

QAT Float16 model converted to C array and saved as ../include/emnist_uppercase_model_qat_float16.h
