In [70]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import datasets, layers, models, losses

In [71]:
# Load MNIST: the loader returns a (train, test) pair, each an (images, labels) pair.
mnist_train, mnist_test = datasets.mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test
train_images.shape

(60000, 28, 28)

In [72]:
# Zero-pad the 28x28 input images by 2 pixels on every side (giving 32x32) and
# scale the 8-bit pixel values to values between 0 and 1.
# NOTE: this cell reassigns its own inputs, so re-running it pads and scales again.
image_paddings = [[0, 0], [2, 2], [2, 2]]  # no padding on the sample axis
train_images = tf.pad(train_images, image_paddings) / 255
test_images = tf.pad(test_images, image_paddings) / 255
train_images.shape

TensorShape([60000, 32, 32])

In [73]:
# Insert a size-1 axis at position 3 of both image tensors (the channel axis
# expected by the Conv2D input below).
train_images, test_images = (
    tf.expand_dims(images, axis=3) for images in (train_images, test_images)
)
train_images.shape

TensorShape([60000, 32, 32, 1])

In [74]:
# Hold out the last `val_size` training samples as the validation split.
# The validation slices must be taken before the training arrays are truncated.
val_size = 2000
val_images, val_labels = train_images[-val_size:], train_labels[-val_size:]
train_images, train_labels = train_images[:-val_size], train_labels[:-val_size]

In [75]:
# Three 5x5 convolutions (6, 16 and 120 filters) with 2x2 max-pooling after the
# first two, followed by an 84-unit dense layer and a 10-way softmax output.
# Every hidden layer uses the 'relu6' activation.
model = models.Sequential([
    layers.Conv2D(6, 5, activation='relu6', input_shape=train_images.shape[1:]),
    layers.MaxPooling2D(2),
    layers.Conv2D(16, 5, activation='relu6'),
    layers.MaxPooling2D(2),
    layers.Conv2D(120, 5, activation='relu6'),
    layers.Flatten(),
    layers.Dense(84, activation='relu6'),
    layers.Dense(10, activation='softmax'),
])
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_6 (Conv2D)           (None, 28, 28, 6)         156       
                                                                 
 max_pooling2d_4 (MaxPoolin  (None, 14, 14, 6)         0         
 g2D)                                                            
                                                                 
 conv2d_7 (Conv2D)           (None, 10, 10, 16)        2416      
                                                                 
 max_pooling2d_5 (MaxPoolin  (None, 5, 5, 16)          0         
 g2D)                                                            
                                                                 
 conv2d_8 (Conv2D)           (None, 1, 1, 120)         48120     
                                                                 
 flatten_2 (Flatten)         (None, 120)              

In [11]:
# Compile with Adam and sparse categorical cross-entropy (integer class labels),
# then train for 10 epochs while monitoring the held-out validation split.
model.compile(
    optimizer='adam',
    loss=losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)
history = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=64,
    epochs=10,
    validation_data=(val_images, val_labels),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [13]:
def representative_data_gen():
    """Yield the first 100 training images one at a time, each as a batch of 1.

    Every yielded item is a one-element list holding a single batch taken from
    the global `train_images`.
    """
    samples = tf.data.Dataset.from_tensor_slices(train_images)
    single_batches = samples.batch(1).take(100)
    for batch in single_batches:
        # Model has only one input so each data point has one element.
        yield [batch]

In [14]:
# Convert the trained Keras model to a TFLite flatbuffer with default
# optimizations, using samples from `representative_data_gen`.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Restrict the converted graph to the int8 builtin op set.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Model input and output tensors are both int8 (or tf.uint8).
converter.inference_input_type = converter.inference_output_type = tf.int8
tflite_quant_model = converter.convert()

INFO:tensorflow:Assets written to: /var/folders/xs/nn2f1m4d4vg3mp72k2gv6c8h0000gn/T/tmpm1p672ss/assets


INFO:tensorflow:Assets written to: /var/folders/xs/nn2f1m4d4vg3mp72k2gv6c8h0000gn/T/tmpm1p672ss/assets
2023-11-09 23:11:43.171724: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-09 23:11:43.171744: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-09 23:11:43.172160: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/xs/nn2f1m4d4vg3mp72k2gv6c8h0000gn/T/tmpm1p672ss
2023-11-09 23:11:43.174788: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2023-11-09 23:11:43.174807: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/xs/nn2f1m4d4vg3mp72k2gv6c8h0000gn/T/tmpm1p672ss
2023-11-09 23:11:43.178866: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:382] MLIR V1 optimization pass is not enabled
2023-11-09 23:11:43.180905: I tensorflow/cc/saved_model/load

In [15]:
# Load the converted model from memory and report the dtype of its first
# input and first output tensor.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
input_type = interpreter.get_input_details()[0]['dtype']
output_type = interpreter.get_output_details()[0]['dtype']
for label, tensor_type in (('input: ', input_type), ('output: ', output_type)):
    print(label, tensor_type)

input:  <class 'numpy.int8'>
output:  <class 'numpy.int8'>


In [16]:
import pathlib

# Output directory for both model files; created (with parents) if missing.
tflite_models_dir = pathlib.Path("../saved_models")
tflite_models_dir.mkdir(parents=True, exist_ok=True)

# Save the unquantized Keras model.
tf_model_file = tflite_models_dir.joinpath("lenet5.keras")
model.save(tf_model_file)

# Save the quantized TFLite flatbuffer; write_bytes returns the byte count,
# which the notebook displays as this cell's output.
tflite_model_quant_file = tflite_models_dir.joinpath("lenet5_int8.tflite")
tflite_model_quant_file.write_bytes(tflite_quant_model)

70328

In [59]:
# Helper function to run inference on a TFLite model
def run_tflite_model(tflite_file, test_image_indices):
  """Run a TFLite model on selected test images and return the predicted classes.

  Args:
    tflite_file: Path of the `.tflite` file to load (converted with `str`).
    test_image_indices: Sized iterable of indices into the global `test_images`.

  Returns:
    A 1-D integer NumPy array with one entry per requested index, holding the
    argmax of the model's first output tensor for that image.
  """
  global test_images

  # Initialize the interpreter
  interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  # Quantization parameters do not change between images, so read them once.
  input_dtype = input_details["dtype"]
  input_scale, input_zero_point = input_details["quantization"]
  # Rescale for every integer input type (int8 and uint8 alike). A zero scale
  # means the tensor carries no quantization parameters, so skip the division.
  needs_quantization = bool(np.issubdtype(input_dtype, np.integer)) and input_scale != 0
  if needs_quantization:
    dtype_limits = np.iinfo(input_dtype)

  predictions = np.zeros((len(test_image_indices),), dtype=int)
  for i, test_image_index in enumerate(test_image_indices):
    test_image = np.asarray(test_images[test_image_index])

    if needs_quantization:
      # q = round(x / scale + zero_point), saturated to the dtype's range.
      # A bare astype() would truncate toward zero and wrap out-of-range
      # values instead of clipping them.
      test_image = np.round(test_image / input_scale + input_zero_point)
      test_image = np.clip(test_image, dtype_limits.min, dtype_limits.max)

    test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
    interpreter.set_tensor(input_details["index"], test_image)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"])[0]

    predictions[i] = output.argmax()

  return predictions

In [60]:
# Helper function to evaluate a TFLite model on all images
def evaluate_model(tflite_file, model_type):
  """Print the accuracy of a TFLite model over every image in `test_images`.

  Args:
    tflite_file: Path of the `.tflite` file, forwarded to `run_tflite_model`.
    model_type: Label used as the prefix of the printed report line.

  Returns:
    None; the accuracy is only printed.
  """
  global test_images, test_labels

  sample_count = len(test_images)
  predicted = run_tflite_model(tflite_file, range(test_images.shape[0]))

  # Percentage of predictions that equal the ground-truth label.
  correct = np.sum(test_labels == predicted)
  accuracy = (correct * 100) / sample_count

  report = '%s model accuracy is %.4f%% (Number of test samples=%d)'
  print(report % (model_type, accuracy, sample_count))

In [61]:
# Report the accuracy of the saved int8 model over the whole test set.
evaluate_model(tflite_file=tflite_model_quant_file, model_type="Int8")

Int8 model accuracy is 98.9000% (Number of test samples=10000)


In [62]:
import pickle

# Re-open the saved int8 model so its weight and bias tensors can be exported.
tflite_interpreter = tf.lite.Interpreter(model_path='../saved_models/lenet5_int8.tflite')
tflite_interpreter.allocate_tensors()

tensor_details = tflite_interpreter.get_tensor_details()
# Layer numbers count down from these totals in the order the weight tensors
# are encountered in `tensor_details`.
num_fc_layers = 2
num_conv2d_layers = 3


def _checked_bias(layer_key, weights, bias):
    """Return `bias` after verifying it has one entry per output unit of `weights`.

    The loop below pairs each weight tensor with the most recently seen
    'BiasAdd' tensor, which is only right when the bias precedes its weights
    in the tensor list; fail loudly instead of exporting a mismatched pair.
    """
    if bias is None or bias.shape != (weights.shape[0],):
        raise ValueError(
            '%s: the preceding bias tensor does not match the weights (bias shape %s, weights shape %s)'
            % (layer_key, None if bias is None else bias.shape, weights.shape))
    return bias


obj = []
bias = None

for detail in tensor_details:
    tensor_index = detail['index']
    name = detail['name']
    # Tensors whose name contains ';' are skipped entirely.
    if ';' in name:
        continue

    if 'BiasAdd' in name:
        # get_tensor returns a copy, so no view into interpreter memory is kept.
        bias = tflite_interpreter.get_tensor(tensor_index)

    if 'MatMul' in name:
        weights = tflite_interpreter.get_tensor(tensor_index)
        layer_key = 'fc' + str(num_fc_layers)
        # Stored as (out, in); exported as (in, out).
        reshaped_weights = np.ascontiguousarray(np.transpose(weights))
        obj.append({
            layer_key + '.weights': reshaped_weights,
            layer_key + '.bias': _checked_bias(layer_key, weights, bias),
        })
        num_fc_layers -= 1

    if name.split('/')[-1] == 'Conv2D':
        weights = tflite_interpreter.get_tensor(tensor_index)
        layer_key = 'conv' + str(num_conv2d_layers)
        # Move the last axis to position 1: result[l, i, k, j] == weights[l, k, j, i].
        reshaped_weights = np.ascontiguousarray(
            np.transpose(weights, (0, 3, 1, 2)), dtype=np.int8)
        obj.append({
            layer_key + '.weights': reshaped_weights,
            layer_key + '.bias': _checked_bias(layer_key, weights, bias),
        })
        num_conv2d_layers -= 1

with open('./params.pkl', 'wb') as handle:
    pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [63]:
# Read the exported parameters back and print the weight/bias shapes per layer.
# Entries were appended last layer first, hence the descending list positions.
with open('./params.pkl', 'rb') as handle:
    b = pickle.load(handle)
    layer_slots = (
        (4, 'conv1.weights', 'conv1.bias'),
        (3, 'conv2.weights', 'conv2.bias'),
        (2, 'conv3.weights', 'conv3.bias'),
        (1, 'fc1.weights', 'fc1.bias'),
        (0, 'fc2.weights', 'fc2.bias'),
    )
    for position, weights_key, bias_key in layer_slots:
        entry = b[position]
        print(entry[weights_key].shape, entry[bias_key].shape)

(6, 1, 5, 5) (6,)
(16, 6, 5, 5) (16,)
(120, 16, 5, 5) (120,)
(120, 84) (84,)
(84, 10) (10,)
