In [1]:
import os
import tempfile

import tensorflow as tf
import tensorflow_model_optimization as tfmot

from tensorflow import keras

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# This is best-effort: on a CPU-only machine there is nothing to configure.
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
  try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
  except (ValueError, RuntimeError) as err:
    # ValueError: invalid device. RuntimeError: memory growth cannot be changed
    # once the GPU runtime is initialized (e.g. when this cell is re-run).
    print(f'Could not enable GPU memory growth: {err}')

In [3]:
# Seed TensorFlow's global RNG so that training is repeatable across
# Restart & Run All (GPU kernels may still add minor nondeterminism).
SEED = 42
tf.random.set_seed(SEED)

# Load MNIST dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture: one small conv layer followed by a dense
# classifier whose final layer has no activation, i.e. it outputs raw logits.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model. `from_logits=True` matches the
# activation-free final Dense layer above.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)



<tensorflow.python.keras.callbacks.History at 0x7f6ae0208410>

### Clone and fine-tune the pre-trained model with quantization-aware training

The resulting model is quantization-aware but not yet quantized (for example, its weights are still float32 rather than int8).

In [4]:
# Clone the pre-trained `model` with quantization wrappers for quantization
# aware training ("q_aware" is short for "quantization aware").
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile before the new model can be trained.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

q_aware_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1         
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295     
Total params: 20,445
Trainable params: 20,410
Non-trainable params: 35
_________________________________________________________________


In [5]:
# Fine-tune the quantization-aware model on a small subset of the training data.
# This cell only trains; no evaluation is performed here.
FINE_TUNE_SAMPLES = 1000
FINE_TUNE_BATCH_SIZE = 500

train_images_subset = train_images[:FINE_TUNE_SAMPLES]
train_labels_subset = train_labels[:FINE_TUNE_SAMPLES]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=FINE_TUNE_BATCH_SIZE, epochs=1, validation_split=0.1)



<tensorflow.python.keras.callbacks.History at 0x7f6acc771110>