In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

# Define the model
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    Flatten(),
    Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer=Adam(),
              loss=SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

In [9]:
# Load the MNIST dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the data
x_train = x_train / 255.0
x_test = x_test / 255.0

# Reshape the data to add the channel dimension
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

In [3]:
# Train the model
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x2bd578cba90>

Convert and Quantize the Model to TFLite

The tf.lite.TFLiteConverter with tf.lite.Optimize.DEFAULT optimizes the model, including quantization to int8. This is done during the conversion to TFLite format.

In [10]:
# Ensure the data is in FLOAT32 format
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32')

# Provide a representative dataset for better quantization
def representative_data_gen():
    for input_value in tf.data.Dataset.from_tensor_slices(x_train).batch(1).take(100):
        yield [input_value]

# Convert the model to TensorFlow Lite format with quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_data_gen

tflite_model = converter.convert()

# Save the quantized model
with open('model_quantized.tflite', 'wb') as f:
    f.write(tflite_model)



INFO:tensorflow:Assets written to: C:\Users\srico\AppData\Local\Temp\tmp0ifj8r7v\assets


INFO:tensorflow:Assets written to: C:\Users\srico\AppData\Local\Temp\tmp0ifj8r7v\assets


Verify Quantization

Use the TFLite interpreter to check if the model is quantized correctly to int8.

In [12]:
import numpy as np

# Load the quantized TFLite model
interpreter = tf.lite.Interpreter(model_path='model_quantized.tflite')
interpreter.allocate_tensors()

# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Check if the input and output tensors are int8
print(f"Input type: {input_details[0]['dtype']}")
print(f"Output type: {output_details[0]['dtype']}")

# Check if intermediate tensors are int8
for detail in interpreter.get_tensor_details():
    tensor = interpreter.tensor(detail['index'])()
    print(f"Tensor {detail['name']} type: {tensor.dtype}")

Input type: <class 'numpy.float32'>
Output type: <class 'numpy.float32'>
Tensor serving_default_conv2d_input:0 type: float32
Tensor sequential/flatten/Const type: int32
Tensor sequential/dense/BiasAdd/ReadVariableOp type: int32
Tensor sequential/dense/MatMul type: int8
Tensor sequential/conv2d/BiasAdd/ReadVariableOp type: int32
Tensor sequential/conv2d/Conv2D type: int8
Tensor tfl.quantize type: int8
Tensor sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp type: int8
Tensor sequential/flatten/Reshape type: int8
Tensor sequential/dense/MatMul;sequential/dense/BiasAdd type: int8
Tensor StatefulPartitionedCall:01 type: int8
Tensor StatefulPartitionedCall:0 type: float32
Tensor  type: int8


In [13]:
# Example: Get quantization parameters for the input layer
input_scale, input_zero_point = input_details[0]['quantization']

# Example: Get quantization parameters for the output layer
output_scale, output_zero_point = output_details[0]['quantization']

print(f"Input scale: {input_scale}, zero point: {input_zero_point}")
print(f"Output scale: {output_scale}, zero point: {output_zero_point}")

# Extract filter and bias quantization parameters
for i, detail in enumerate(interpreter.get_tensor_details()):
    if 'Conv2D' in detail['name']:
        filter_quant_params = detail['quantization_parameters']
        filter_scale = filter_quant_params['scales']
        filter_zero_point = filter_quant_params['zero_points']
        break

print(f"Filter scale: {filter_scale}, zero point: {filter_zero_point}")

Input scale: 0.0, zero point: 0
Output scale: 0.0, zero point: 0
Filter scale: [0.00458585 0.00083913 0.00345411 0.0079803  0.0006902  0.00072748
 0.00352251 0.00235633 0.0036677  0.00163086 0.0060626  0.00498375
 0.00565246 0.00654385 0.00211097 0.00171206 0.00551453 0.00431368
 0.00302073 0.00298816 0.00207585 0.00535683 0.00388202 0.00146785
 0.00416391 0.00835785 0.00572529 0.00205417 0.00353958 0.00174995
 0.00310315 0.00222287], zero point: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
