# PyTorch -> ONNX -> TFLite

PyTorch 모델을 ONNX를 거쳐 TFLite까지 변환하는 것을 목표로 하는 Notebook입니다.

In [None]:
import torch
import torchvision
import os
from torchsummary import summary
import numpy as np

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Model load

In [None]:
model_dir = '/opt/ml/final_project/kyh/model/'

# model1: DeepLabV3 (backbone: MobileNetV3-Large) with the weights from the
# deeplabV3 .pth file shared by Yongbeom.
model_path = os.path.join(model_dir, 'deeplabv3_mobilenet_v3_large_sample.pth')
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine.
model_pt = torch.load(model_path, map_location=device)
# NOTE(review): pretrained=True downloads weights that load_state_dict replaces
# right away. It presumably also selects the architecture variant (e.g. the aux
# head) whose keys the checkpoint matches, so it is kept as-is. Confirm before
# switching to pretrained=False.
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_mobilenet_v3_large', pretrained=True).to(device)
model.load_state_dict(model_pt)

# model2: torchvision AlexNet, the example model from the ONNX export tutorial.
model2 = torchvision.models.alexnet(pretrained=True).to(device)

# Put both models in inference mode before tracing/exporting them.
model.eval()
model2.eval()

## Model -> ONNX

In [None]:
batch_size = 8
# Dummy input used to trace the model. The exported graph is fixed to this
# batch size and resolution (N, C, H, W) = (8, 3, 960, 720).
dummy_input = torch.zeros(batch_size, 3, 960, 720).to(device)

# Only the image input is named. The previous extra "learned_%d" names were
# copied from the AlexNet tutorial: they merely relabelled the first 16
# parameters of this model, and torch.onnx.export raises if more names are
# given than the graph has inputs.
input_names = ["actual_input_1"]
# NOTE(review): torchvision's DeepLabV3 returns a dict of outputs; presumably
# only the first one receives this name. Confirm in the exported graph.
output_names = ["output1"]

# Write to the same absolute path the later cells load from, so the notebook
# works regardless of the current working directory.
torch.onnx.export(
    model,
    dummy_input,
    f="/opt/ml/final_project/kyh/sample_onnx_model.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)

## ONNX 빌드가 잘 되었는지 체크

In [None]:
import onnx

# Load the exported ONNX model under its own name. Reusing `model` would
# overwrite the PyTorch model, and re-running the export cell would then hand
# an ONNX ModelProto to torch.onnx.export.
onnx_model = onnx.load("/opt/ml/final_project/kyh/sample_onnx_model.onnx")

# Check that the IR is well formed (raises if the model is invalid).
onnx.checker.check_model(onnx_model)

# Print a human-readable representation of the graph. printable_graph returns
# a string, and printing it shows real line breaks instead of an escaped repr.
print(onnx.helper.printable_graph(onnx_model.graph))

## ONNX 파일 Inference

In [None]:
import onnxruntime as ort

# Since onnxruntime 1.9, creating a session without an explicit `providers`
# list raises on builds that ship more than one execution provider. Prefer
# CUDA when this build has it; CPU is always available.
available_providers = ort.get_available_providers()
ort_providers = [
    name
    for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if name in available_providers
]

ort_session = ort.InferenceSession(
    "/opt/ml/final_project/kyh/sample_onnx_model.onnx",
    providers=ort_providers,
)

# The graph was exported with a fixed input shape, so the random input must
# use the same `batch_size` (defined in the export cell) and resolution.
outputs = ort_session.run(
    None,  # None -> return every graph output
    {'actual_input_1': np.random.randn(batch_size, 3, 960, 720).astype(np.float32)},
)

In [None]:
# Display the first output returned by the ONNX Runtime session.
first_ort_output = outputs[0]
print(first_ort_output)

## ONNX -> TensorFlow 빌드

In [None]:
import sys

# Make the local onnx-tensorflow checkout importable. Because the entry is
# appended (lowest priority), an installed onnx_tf package still takes precedence.
sys.path.extend(["./onnx-tensorflow"])

In [None]:
import onnx
import tensorflow as tf
from onnx_tf.backend import prepare

# Paths shared by this conversion and the later TensorFlow / TFLite cells.
project_dir = "/opt/ml/final_project/kyh"
onnx_model_path = f"{project_dir}/sample_onnx_model.onnx"
tf_model_path = f"{project_dir}/tf_model"

# Wrap the ONNX graph in an onnx-tf backend representation, then write it out
# to tf_model_path. Later cells read that path as a TensorFlow SavedModel.
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)

## TensorFlow Inference

In [None]:
# Load the SavedModel exported by onnx-tf under its own name, so the PyTorch
# `model` from the first cells is not overwritten. The old
# `model.trainable = False` line was dropped: on an object returned by
# tf.saved_model.load it only set an unused attribute.
tf_model = tf.saved_model.load(tf_model_path)

# Random input matching the exported graph's fixed shape. `batch_size` is
# reused from the export cell instead of repeating the literal 8.
input_tensor = tf.random.uniform([batch_size, 3, 960, 720])
out = tf_model(**{'actual_input_1': input_tensor})

In [None]:
# Display the SavedModel's result for the random input tensor.
tf_output = out
print(tf_output)

## TensorFlow -> TFLite 빌드

In [None]:
import os

tflite_model_path = "/opt/ml/final_project/kyh/tflite_model/tflite_model.tflite"

# Convert the SavedModel exported by onnx-tf into a TFLite flatbuffer.
# NOTE(review): onnx-tf graphs may contain ops outside the TFLite builtin set.
# If conversion fails, enabling tf.lite.OpsSet.SELECT_TF_OPS on
# converter.target_spec.supported_ops may be needed. Confirm.
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
tflite_model = converter.convert()

# open() does not create missing parent directories, so ensure tflite_model/ exists.
os.makedirs(os.path.dirname(tflite_model_path), exist_ok=True)

# Save the model (convert() returns bytes, hence binary mode).
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)


## TFLite Inference

In [None]:
import numpy as np
import tensorflow as tf

# Create an interpreter for the converted flatbuffer and allocate its tensors.
interpreter = tf.lite.Interpreter(
    model_path="/opt/ml/final_project/kyh/tflite_model/tflite_model.tflite"
)
interpreter.allocate_tensors()

# Metadata (index, shape, dtype, ...) for the model's inputs and outputs.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Smoke test: feed float32 uniform random values in [0, 1) to the first input.
in_info = input_details[0]
input_shape = in_info['shape']
input_data = np.random.random_sample(input_shape).astype(np.float32)
interpreter.set_tensor(in_info['index'], input_data)

interpreter.invoke()

# get_tensor() returns a copy of the output data; tensor() would instead give
# a pointer (view) into the interpreter's own buffer.
out_info = output_details[0]
output_data = interpreter.get_tensor(out_info['index'])
print(output_data)