In [46]:
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

cnn_model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(12, (3, 3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D((2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)])

cnn_model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
cnn_model.summary()
cnn_model.fit(x_train, y_train, epochs=1, validation_split=0.1)
cnn_model.evaluate(x_test,  y_test)


Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_10 (Conv2D)           (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 13, 13, 12)        0         
_________________________________________________________________
flatten_7 (Flatten)          (None, 2028)              0         
_________________________________________________________________
dense_8 (Dense)              (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________


[0.1224130243062973, 0.9653000235557556]

In [33]:
converter = tf.lite.TFLiteConverter.from_keras_model(cnn_model)
cnn_tflite_model = converter.convert()

with open("./cnn_model.tflite", 'wb') as f:
    f.write(cnn_tflite_model)

INFO:tensorflow:Assets written to: C:\Users\lswsj\AppData\Local\Temp\tmpvohgwcl2\assets
INFO:tensorflow:Assets written to: C:\Users\lswsj\AppData\Local\Temp\tmpvohgwcl2\assets


83476

In [62]:
import numpy as np

def run_tflite_model(path, x_test, y_test):
    interpreter = tf.lite.Interpreter(str(path))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    y_pred = []
    for i, test_image in enumerate(x_test):
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)

        interpreter.set_tensor(input_details['index'], test_image)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details['index'])

        y_pred.append(output.argmax())

    y_pred = np.array(y_pred)
    accuracy = (y_pred == y_test).mean()
    return accuracy


In [63]:
run_tflite_model("./cnn_model.tflite", x_test, y_test)

0.9638

In [58]:
import tensorflow_model_optimization as tfmot

quantized_cnn_model = tfmot.quantization.keras.quantize_model(cnn_model)

quantized_cnn_model.compile(optimizer='adam',
                      loss="sparse_categorical_crossentropy",
                      metrics=['accuracy'])
quantized_cnn_model.summary()

train_image_subset = x_train[0:1000]
train_labels_subset = y_train[0:1000]

quantized_cnn_model.fit(train_image_subset, train_labels_subset, 
                        batch_size=500, epochs=1, validation_split=0.1)
quantized_cnn_model.evaluate(x_test, y_test)

Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantize_layer_7 (QuantizeLa (None, 28, 28, 1)         3         
_________________________________________________________________
quant_conv2d_10 (QuantizeWra (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d_8 (Quant (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten_7 (QuantizeWra (None, 2028)              1         
_________________________________________________________________
quant_dense_8 (QuantizeWrapp (None, 10)                20295     
Total params: 20,447
Trainable params: 20,410
Non-trainable params: 37
_________________________________________________________________


[0.3501076400279999, 0.9667999744415283]

In [57]:
converter = tf.lite.TFLiteConverter.from_keras_model(quantized_cnn_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_cnn_tflite_model = converter.convert()

with open("./quantized_cnn_model.tflite", 'wb') as f:
    f.write(quantized_cnn_tflite_model)

INFO:tensorflow:Assets written to: C:\Users\lswsj\AppData\Local\Temp\tmpaft0o2qk\assets
INFO:tensorflow:Assets written to: C:\Users\lswsj\AppData\Local\Temp\tmpaft0o2qk\assets


23648

In [64]:
run_tflite_model("./quantized_cnn_model.tflite", x_test, y_test)


0.966