# TF2로 MNIST 양자화 인식 훈련(QAT)

* https://www.tensorflow.org/model_optimization/guide/quantization/training_example

In [17]:
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print('Please check GPU available')
else:
    # Expose only the first GPU to TensorFlow and let its memory grow on demand.
    # Both settings must be applied at program start, before the device is
    # initialized; otherwise TensorFlow raises RuntimeError, which is printed.
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[0], True)
        print('GPU[0] is ready')
    except RuntimeError as e:
        print(e)
    
import os
import sys
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
from tensorflow import feature_column as fc
import tensorflow_datasets as tfds
# Plot font; presumably chosen so the notebook's Korean text renders in figures.
plt.rcParams["font.family"] = 'NanumBarunGothic'
# Absolute path to the tensorboard executable, exported via the environment.
# NOTE(review): machine-specific path; adjust for other environments.
TENSORBOARD_BINARY = '/home/hoondori/anaconda3/envs/ai/bin/tensorboard'
os.environ['TENSORBOARD_BINARY'] = TENSORBOARD_BINARY
# ERROR level hides TensorFlow's INFO and WARNING log lines. INFO level would
# enable informational logging rather than suppress warnings.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

GPU[0] is ready


In [3]:
import tempfile
import os
import tensorflow as tf
from tensorflow import keras

# QAT 없이 MNIST 모델 훈련

In [8]:
# Load the MNIST train/test split.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Divide by 255 so every pixel value lies between 0 and 1.
train_images, test_images = train_images / 255.0, test_images / 255.0

# Baseline CNN: reshape to a single channel, one 3x3 conv with 12 filters,
# 2x2 max-pool, flatten, and a 10-unit dense layer that emits raw logits
# (hence from_logits=True in the loss).
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28)))
model.add(keras.layers.Reshape(target_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10))

# Compile the digit classifier with integer-label cross-entropy.
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_2 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 26, 26, 12)        120       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 13, 13, 12)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 2028)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                20290     
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
_________________________________________________________________


In [9]:
# Train the baseline for one epoch, holding out 10% of the training set for validation.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=1,
    validation_split=0.1,
)



<tensorflow.python.keras.callbacks.History at 0x7eff053822b0>

# QAT로 pre-trained 모델 복제 및 fine-tuning

In [10]:
import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model

# Clone the pre-trained baseline into a quantization-aware model.
q_aware_model = quantize_model(model)

# The model returned by quantize_model has to be compiled again before use.
q_aware_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

q_aware_model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantize_layer_2 (QuantizeLa (None, 28, 28)            3         
_________________________________________________________________
quant_reshape_2 (QuantizeWra (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d_2 (QuantizeWrap (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d_2 (Quant (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten_2 (QuantizeWra (None, 2028)              1         
_________________________________________________________________
quant_dense_2 (QuantizeWrapp (None, 10)                20295     
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________

In [12]:
# Fine-tune the QAT model on only the first 1000 of the 60000 training examples.
train_images_subset = train_images[:1000]
train_labels_subset = train_labels[:1000]

q_aware_model.fit(
    x=train_images_subset,
    y=train_labels_subset,
    batch_size=500,
    epochs=1,
    validation_split=0.1,
)



<tensorflow.python.keras.callbacks.History at 0x7eff05077780>

## baseline 모델과 QAT 미세 조정(fine-tuned) 모델의 성능 비교

In [13]:
# Both models were compiled with metrics=['accuracy'], so evaluate() yields
# (loss, accuracy); the loss is discarded.
_, baseline_model_accuracy = model.evaluate(x=test_images, y=test_labels, verbose=0)
_, q_aware_model_accuracy = q_aware_model.evaluate(x=test_images, y=test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9495000243186951
Quant test accuracy: 0.9535999894142151


# INT8 가중치 및 UINT8 활성화를 갖는 양자화 모델 얻기

In [18]:
# Convert the quantization-aware Keras model to TFLite. Setting
# Optimize.DEFAULT asks the converter to apply its default optimizations
# (quantization) to the output model.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()  # serialized model bytes; later written to disk and fed to tf.lite.Interpreter

INFO:tensorflow:Assets written to: /tmp/tmpb1cvs8df/assets


INFO:tensorflow:Assets written to: /tmp/tmpb1cvs8df/assets


## TF Lite 모델 성능 확인하기

In [20]:
def evaluate_model(interpreter, images=None, labels=None, log_every=1000):
    """Return the top-1 accuracy of a TFLite interpreter on a labelled image set.

    Each image is fed individually with a leading batch dimension of 1. The
    predicted class is the argmax of the first output tensor.

    Args:
        interpreter: a ``tf.lite.Interpreter`` on which ``allocate_tensors()``
            has already been called. Only its first input and first output
            tensors are used.
        images: iterable of images to classify. Defaults to the notebook's
            global ``test_images``.
        labels: ground-truth class labels aligned with ``images``. Defaults to
            the notebook's global ``test_labels``.
        log_every: print a progress line every ``log_every`` images; a falsy
            value (0 or None) disables progress output.

    Returns:
        Fraction of images whose predicted class equals its label.

    Raises:
        ValueError: if ``images`` is empty or its length differs from ``labels``.
    """
    # Resolve defaults at call time so the globals need not exist when images
    # and labels are passed explicitly.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    prediction_digits = []
    for i, image in enumerate(images):
        if log_every and i % log_every == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Add the batch dimension and cast to float32 to match the model's
        # input data format.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        interpreter.invoke()

        # get_tensor returns a copy of the output. interpreter.tensor() would
        # instead return a view into the interpreter's internal buffers.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    print('\n')

    if not prediction_digits:
        raise ValueError('evaluate_model: no images to evaluate')
    labels = np.asarray(labels)
    if len(prediction_digits) != len(labels):
        raise ValueError(
            'evaluate_model: got {p} predictions but {l} labels'.format(
                p=len(prediction_digits), l=len(labels)))

    # Compare predictions with the ground-truth labels to get the accuracy.
    prediction_digits = np.array(prediction_digits)
    return (prediction_digits == labels).mean()

# Load the quantized TFLite model from its in-memory bytes and allocate
# its tensors before running inference.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

# Compare the TFLite accuracy with the Keras QAT model's accuracy computed earlier.
print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9536
Quant TF test accuracy: 0.9535999894142151


## Quantized model(정수)은 Float(32) TF Lite 모델보다 약 4배 작아진 것을 확인 (이 실행에서는 0.0806 MB → 0.0234 MB로 약 3.4배)

In [21]:
# Create a float (non-quantized) TFLite model from the baseline for comparison.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models by writing each one to a temporary .tflite file.
# mkstemp returns an already-open OS file descriptor. Wrapping it with
# os.fdopen inside a `with` closes it after writing instead of leaking it.
float_fd, float_file = tempfile.mkstemp('.tflite')
quant_fd, quant_file = tempfile.mkstemp('.tflite')

with os.fdopen(quant_fd, 'wb') as f:
    f.write(quantized_tflite_model)

with os.fdopen(float_fd, 'wb') as f:
    f.write(float_tflite_model)

# Sizes are reported in units of 2**20 bytes.
print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))

INFO:tensorflow:Assets written to: /tmp/tmpow03y2a1/assets


INFO:tensorflow:Assets written to: /tmp/tmpow03y2a1/assets


Float model in Mb: 0.0806121826171875
Quantized model in Mb: 0.02344512939453125
