## 1. 模块导入

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn

from tensorflow import keras
import tensorflow as tf
import sys
import os
import time
import datetime

for module in [np, pd, mpl, sklearn, keras, tf]:
    print(module.__name__, module.__version__)

# Enable GPU memory growth on every visible GPU.
# - The loop is a no-op on CPU-only machines (indexing gpus[0] on an empty list
#   raised IndexError).
# - set_memory_growth raises RuntimeError once the GPUs have been initialized
#   (e.g. when this cell is re-run), so report that instead of crashing.
gpus = tf.config.experimental.list_physical_devices("GPU")
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(err)

numpy 1.18.1
pandas 0.25.3
matplotlib 3.1.2
sklearn 0.22.1
tensorflow_core.python.keras.api._v2.keras 2.2.4-tf
tensorflow 2.1.0


## 2. Keras模型到TFLite

In [2]:
# Load the HDF5 Keras model and run one forward pass on an all-ones image
# as a smoke test; the last expression is displayed as the cell output.
hdf5_model_path = "./keras_hdf5_model/save/fashion_mnist.h5"
keras_model = keras.models.load_model(hdf5_model_path)
smoke_input = np.ones(shape=(1, 28, 28, 1))
keras_model(smoke_input)

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.07633319, 0.10343505, 0.09567624, 0.02766261, 0.05517917,
        0.01143559, 0.10081402, 0.01948361, 0.4656996 , 0.0442809 ]],
      dtype=float32)>

In [3]:
# Convert the in-memory Keras model to a serialized TFLite model (bytes).
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(
    keras_model
)
keras_to_tflite = keras_to_tflite_converter.convert()

In [4]:
model_dir = "tflite_models"
# exist_ok=True creates the directory only if missing, without the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(model_dir, exist_ok=True)
keras_tflite_file = os.path.join(model_dir, "keras_to_tflite")

# Persist the serialized TFLite model produced by the converter cell.
with open(keras_tflite_file, "wb") as f:
    f.write(keras_to_tflite)

## 3. 具体函数（concrete function）到TFLite

In [5]:
# Wrap the Keras model in a tf.function and trace it with a TensorSpec built
# from the model's first input, producing a concrete function.
# The argument is named `x`; that name becomes the traced input's name.
def _call_keras_model(x):
    return keras_model(x)

model_func = tf.function(_call_keras_model)
keras_input_tensor = keras_model.inputs[0]
keras_concrete_func = model_func.get_concrete_function(
    tf.TensorSpec(shape=keras_input_tensor.shape, dtype=keras_input_tensor.dtype)
)

keras_concrete_func(tf.constant(np.ones((1, 28, 28, 1)), dtype=tf.float32))

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.07633319, 0.10343505, 0.09567624, 0.02766261, 0.05517917,
        0.01143559, 0.10081402, 0.01948361, 0.4656996 , 0.0442809 ]],
      dtype=float32)>

In [6]:
# Convert the traced concrete function to TFLite and save it in model_dir
# alongside the Keras-converted model.
concrete_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [keras_concrete_func]
)
concrete_to_tflite = concrete_to_tflite_converter.convert()

concrete_tflite_file = os.path.join(model_dir, "concrete_to_tflite")
with open(concrete_tflite_file, mode="wb") as out_file:
    out_file.write(concrete_to_tflite)

## 4. `SavedModel`到TFLite

In [7]:
# Convert the SavedModel export on disk to TFLite and save it in model_dir.
# NOTE(review): on the all-ones input, the TFLite model built from this
# directory gives different outputs (and a different input tensor name) than
# the HDF5 Keras model. Compare the inference cells below. It appears to hold a
# different model; confirm this is the export you intend to compare.
saved_model_dir = "./keras_saved_model/"
savedModel_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
savedModel_to_tflite = savedModel_to_tflite_converter.convert()

savedModel_tflite_file = os.path.join(model_dir, "savedModel_to_tflite")
with open(savedModel_tflite_file, mode="wb") as out_file:
    out_file.write(savedModel_to_tflite)

## 5. TFLite推理（Interpreter）

In [8]:
def tflite_inference(tflite_file, input_data=None, verbose=True):
    """Run one forward pass of a TFLite model with ``tf.lite.Interpreter``.

    Only the model's first input and first output tensor are used.

    Args:
        tflite_file: Path to a serialized TFLite model on disk.
        input_data: Optional array-like value for the first input tensor. It is
            converted to the dtype reported by ``get_input_details``. If None
            (the default), an all-ones array with the input tensor's shape and
            dtype is used. For float32 models this is the same input as before.
        verbose: If True (the default), print the interpreter's input and
            output details, as the original function always did.

    Returns:
        numpy.ndarray: Value of the first output tensor after ``invoke()``.

    Errors raised by ``set_tensor`` (e.g. an input shape that does not match
    the model) propagate to the caller.
    """
    with open(tflite_file, "rb") as f:
        tflite_content = f.read()

    interpreter = tf.lite.Interpreter(model_content=tflite_content)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    if verbose:
        print(input_details)
        print(output_details)

    input_spec = input_details[0]
    if input_data is None:
        # Take the dtype from the interpreter instead of hard-coding float32,
        # so models whose input is not float32 still receive a valid tensor.
        input_data = np.ones(input_spec["shape"], dtype=input_spec["dtype"])
    else:
        input_data = np.asarray(input_data, dtype=input_spec["dtype"])
    # set_tensor accepts a NumPy array directly; no tf.constant wrapper needed.
    interpreter.set_tensor(input_spec["index"], input_data)

    interpreter.invoke()

    return interpreter.get_tensor(output_details[0]["index"])

In [9]:
# Time the TFLite load + inference call on its own. time.perf_counter() is the
# clock intended for measuring elapsed intervals; time.time() is wall-clock time.
# Printing the result happens after the measurement, so it is not counted.
start_time = time.perf_counter()
savedModel_result = tflite_inference(savedModel_tflite_file)
elapsed_seconds = time.perf_counter() - start_time
print(savedModel_result)
print(elapsed_seconds)

[{'name': 'conv2d_input', 'index': 28, 'shape': array([ 1, 28, 28,  1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[[0.01264767 0.00580803 0.01378929 0.00291022 0.01054621 0.00768616
  0.03106462 0.00463048 0.87353516 0.03738211]]
0.007292747497558594


In [10]:
# Run the Keras-converted TFLite model. Then check that it matches the in-memory
# Keras model on the same all-ones input, so the conversion is verified, not eyeballed.
keras_tflite_result = tflite_inference(keras_tflite_file)
print(keras_tflite_result)
keras_reference = keras_model(np.ones((1, 28, 28, 1), dtype=np.float32)).numpy()
np.testing.assert_allclose(keras_tflite_result, keras_reference, rtol=1e-5, atol=1e-5)

[{'name': 'conv2d_12_input', 'index': 1, 'shape': array([ 1, 28, 28,  1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[[0.07633319 0.10343514 0.09567622 0.0276626  0.05517913 0.01143559
  0.10081402 0.01948359 0.46569952 0.04428094]]


In [11]:
# Run the concrete-function TFLite model and check that it matches the in-memory
# Keras model on the same all-ones input.
concrete_tflite_result = tflite_inference(concrete_tflite_file)
print(concrete_tflite_result)
concrete_reference = keras_model(np.ones((1, 28, 28, 1), dtype=np.float32)).numpy()
np.testing.assert_allclose(concrete_tflite_result, concrete_reference, rtol=1e-5, atol=1e-5)

[{'name': 'x', 'index': 28, 'shape': array([ 1, 28, 28,  1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[[0.07633319 0.10343514 0.09567622 0.0276626  0.05517913 0.01143559
  0.10081402 0.01948359 0.46569952 0.04428094]]
