# 1. Import Libraries

resnet50 이 아닌 다른 모델을 사용할 경우, 아래의 `tensorflow.keras.applications.resnet50` import 대신 해당 모델의 모듈에서 모델 클래스와 `preprocess_input`, `decode_predictions` 를 import 한다

In [1]:
import os
import glob
import time

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
import tflite_runtime.interpreter as tflite

# 2. Download Data

In [2]:
!mkdir data

In [3]:
!wget  -O data/img0.JPG "https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/siberian-husky-detail.jpg?bust=1535566590&width=630"
!wget  -O data/img1.JPG "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"
!wget  -O data/img2.JPG "https://www.artis.nl/media/filer_public_thumbnails/filer_public/00/f1/00f1b6db-fbed-4fef-9ab0-84e944ff11f8/chimpansee_amber_r_1920x1080.jpg__1920x1080_q85_subject_location-923%2C365_subsampling-2.jpg"
!wget  -O data/img3.JPG "https://www.familyhandyman.com/wp-content/uploads/2018/09/How-to-Avoid-Snakes-Slithering-Up-Your-Toilet-shutterstock_780480850.jpg"

--2022-01-27 10:52:49--  https://d17fnq9dkz9hgj.cloudfront.net/breed-uploads/2018/08/siberian-husky-detail.jpg?bust=1535566590&width=630
Resolving d17fnq9dkz9hgj.cloudfront.net (d17fnq9dkz9hgj.cloudfront.net)... 54.192.175.42, 54.192.175.163, 54.192.175.115, ...
Connecting to d17fnq9dkz9hgj.cloudfront.net (d17fnq9dkz9hgj.cloudfront.net)|54.192.175.42|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24112 (24K) [image/jpeg]
Saving to: ‘data/img0.JPG’


2022-01-27 10:52:49 (16.1 MB/s) - ‘data/img0.JPG’ saved [24112/24112]

--2022-01-27 10:52:49--  https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg
Resolving www.hakaimagazine.com (www.hakaimagazine.com)... 164.92.73.117
Connecting to www.hakaimagazine.com (www.hakaimagazine.com)|164.92.73.117|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 452718 (442K) [image/jpeg]
Saving to: ‘data/img1.JPG’


2022-01-27 10:52:52 (857 KB/s) - ‘data/img1.JPG’ saved [452718/452718]

--2

# 3. Define Variables

- 모델을 이미 다운로드하거나 convert 해 두었다고 가정하고 작성한 코드
- `tftrt_model_path` 에는 convert 한 TF-TRT frozen graph 의 파일 경로를 할당 (디렉토리가 아닌 파일 경로)
  - `tftrt_model` 디렉토리 안의 `frozen_graph.pb` 파일을 가리켜야 함
- `tflite_model_path` 에는 Edge TPU 용으로 컴파일한 tflite 모델의 파일 경로를 할당
  - 예: `tflite_model/resnet50_edgetpu.tflite`

In [4]:
cur_path = os.getcwd()

# Both model artifacts are resolved relative to the current working directory:
# the TF-TRT frozen graph file and the Edge TPU-compiled TFLite file.
tftrt_model_path, tflite_model_path = (
    os.path.join(cur_path, model_dir, model_file)
    for model_dir, model_file in (
        ('tftrt_model', 'frozen_graph.pb'),
        ('tflite_model', 'resnet50_edgetpu.tflite'),
    )
)

In [5]:
print(tftrt_model_path)

/home/keti/tf_2.5.0/src/4.1. run tf-trt,lite/tftrt_model/frozen_graph.pb


# 4. Define Functions

In [6]:
def load_frozen_graph(input_path, inputs=None, outputs=None, print_graph=False):
    """Load a frozen GraphDef file and wrap it as a callable TF2 function.

    Args:
        input_path: path to the serialized frozen graph (e.g. ``frozen_graph.pb``).
        inputs: graph tensor names to feed. Defaults to ``["x:0"]``, the name
            this notebook used for its ResNet50 graph; pass other names when
            loading a graph exported with different tensor names.
        outputs: graph tensor names to fetch. Defaults to ``["Identity:0"]``.
        print_graph: if true, print every operation name of the imported graph
            (handy for discovering the right ``inputs``/``outputs`` names).

    Returns:
        The pruned ``WrappedFunction`` produced by ``wrapped_import.prune``.
        Calling it with the input tensor returns the fetched tensors in the
        same structure as ``outputs`` (a list by default).
    """
    # None defaults avoid sharing a mutable list across calls.
    if inputs is None:
        inputs = ["x:0"]
    if outputs is None:
        outputs = ["Identity:0"]

    def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
        def _imports_graph_def():
            tf.compat.v1.import_graph_def(graph_def, name="")

        # Import the GraphDef inside a TF1-style wrapped function so it can be
        # pruned down to a ConcreteFunction that maps inputs -> outputs.
        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph

        if print_graph:
            print("-" * 50)
            print("Frozen model layers: ")
            for op in import_graph.get_operations():
                print(op.name)
            print("-" * 50)

        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))

    # Load the frozen graph using TensorFlow 1.x compat APIs.
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(input_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    return wrap_frozen_graph(graph_def=graph_def,
                             inputs=inputs,
                             outputs=outputs,
                             print_graph=print_graph)

In [7]:
"""
    tf-trt 모델을 load 할 것이라면 type=tf_trt
    tflite 모델을 load 할 것이라면 type=tf_lite 로 함수 호출
"""

def load_model(type, input_path):
    if(type == 'tf_trt'):
        print('start load tf-trt model...')
        loaded_model = load_frozen_graph(input_path)
        
    elif(type == 'tf_lite'):
        print('start load tflite model...')
        loaded_model = tflite.Interpreter(input_path, 
                          experimental_delegates=[tflite.load_delegate('libedgetpu.so.1')])
        loaded_model.allocate_tensors()
    return loaded_model

In [8]:
"""
reference: https://www.tensorflow.org/lite/guide/inference?hl=ko#load_and_run_a_model_in_python
"""

def predict(input_model, input_data):
    if(str(type(input_model)) == "<class 'tensorflow.python.eager.wrap_function.WrappedFunction'>"):
        # model type: trt frozen graph
        prediction = input_model(input_data)
        prediction = prediction[0].numpy()
        # return decode_predictions(preds, top=3)[0][0][1]
        
    elif(str(type(input_model)) == "<class 'tflite_runtime.interpreter.Interpreter'>"):
        # model type: tflite model
        input_details = input_model.get_input_details()
        output_details = input_model.get_output_details()

        input_data = np.array(input_data, dtype=np.uint8)
        input_model.set_tensor(input_details[0]['index'], input_data) # set input data to interpreter
        input_model.invoke()    # 추론
        prediction = input_model.get_tensor(output_details[0]['index']) # get output

    else:
        print('모델 타입이 일치하지 않음')
        return -1;
    
    return prediction

# Test

In [9]:
# Build the test input from a real image (not random data): load the sample
# husky photo at 224x224, add a batch axis of one, apply the resnet50
# `preprocess_input`, and wrap the result as a tf constant.
img_path = 'data/img0.JPG'  # Siberian_husky
img = image.load_img(img_path, target_size=(224, 224))
x = tf.constant(
    preprocess_input(
        np.expand_dims(image.img_to_array(img), axis=0)
    )
)

### inference test: tf-trt model

In [10]:
tftrt_loaded_model = load_model(type='tf_trt', input_path=tftrt_model_path)

start load tf-trt model...


In [11]:
preds = predict(input_model=tftrt_loaded_model, input_data=x)

In [12]:
# Decode the top-5 predicted classes, but display only the top-1 class name
# ([0] = first image, [0] = best entry, [1] = class name).
# decode_predictions is imported from the resnet50 module; change that import when using another model.
decode_predictions(preds, top=5)[0][0][1]   # top-1 class name

'Siberian_husky'

### inference test: tflite model

In [15]:
tflite_loaded_model = load_model(type='tf_lite', input_path=tflite_model_path)

start load tflite model...


In [16]:
preds = predict(input_model=tflite_loaded_model, input_data=x)

In [17]:
# Decode the top-5 predicted classes, but display only the top-1 class name
# ([0] = first image, [0] = best entry, [1] = class name).
# decode_predictions is imported from the resnet50 module; change that import when using another model.
# NOTE(review): a label that disagrees with the TF-TRT result for the same image
# (e.g. 'nematode' for the husky) suggests the quantized TFLite input was not
# scaled with the input tensor's (scale, zero_point) — verify predict's TFLite path.
decode_predictions(preds, top=5)[0][0][1]   # top-1 class name

'nematode'