In [1]:
import tvm.testing
from tvm.contrib.download import download_testdata
import onnxruntime.providers.stvm   # necessary to register tvm_onnx_import_and_compile and others

In [2]:
from typing import AnyStr, List, Optional

import numpy as np
import onnx
from onnx import ModelProto, checker, helper, mapping

## Helper functions

In [3]:
def get_onnx_input_names(model: ModelProto) -> List[AnyStr]:
    """Return the names of the graph inputs that are not initializers.

    Args:
        model: ONNX model whose ``graph.input`` and ``graph.initializer``
            entries are inspected.

    Returns:
        The de-duplicated input names, sorted in ascending order. Names that
        also appear in ``graph.initializer`` are excluded.
    """
    initializer_names = {tensor.name for tensor in model.graph.initializer}
    graph_input_names = {value_info.name for value_info in model.graph.input}
    # Sorting gives a deterministic order that the other helpers rely on.
    return sorted(graph_input_names - initializer_names)


def get_onnx_output_names(model: ModelProto) -> List[AnyStr]:
    """Return the names of all graph outputs, in the order the graph lists them.

    Args:
        model: ONNX model whose ``graph.output`` entries are inspected.

    Returns:
        One name per entry of ``model.graph.output``.
    """
    output_names = []
    for value_info in model.graph.output:
        output_names.append(value_info.name)
    return output_names


def get_onnx_input_types(model: ModelProto) -> List[np.dtype]:
    """Return the NumPy element type of every non-initializer graph input.

    Args:
        model: ONNX model whose ``graph.input`` entries are inspected.

    Returns:
        One NumPy type per retained input, ordered by input name so the result
        lines up with ``get_onnx_input_names``.
    """
    wanted_names = get_onnx_input_names(model)
    inputs_by_name = sorted(model.graph.input, key=lambda value_info: value_info.name)

    np_types = []
    for value_info in inputs_by_name:
        # Skip graph inputs that are really initializers.
        if value_info.name in wanted_names:
            elem_type = value_info.type.tensor_type.elem_type
            np_types.append(mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
    return np_types


def get_onnx_input_shapes(model: ModelProto) -> List[List[int]]:
    """Return the declared shape of every non-initializer graph input.

    Args:
        model: ONNX model whose ``graph.input`` entries are inspected.

    Returns:
        One list of ``dim_value`` integers per retained input, ordered by input
        name so the result lines up with ``get_onnx_input_names``.

    Note:
        Only ``dim_value`` is read. NOTE(review): a dimension without a
        concrete value (e.g. a symbolic batch size) presumably shows up as the
        protobuf default of 0 -- confirm before relying on these shapes for
        models with dynamic dimensions.
    """
    wanted_names = get_onnx_input_names(model)
    inputs_by_name = sorted(model.graph.input, key=lambda value_info: value_info.name)

    shapes = []
    for value_info in inputs_by_name:
        # Skip graph inputs that are really initializers.
        if value_info.name in wanted_names:
            dims = value_info.type.tensor_type.shape.dim
            shapes.append([dim.dim_value for dim in dims])
    return shapes


def get_random_model_inputs(model: ModelProto, seed: Optional[int] = None) -> List[np.ndarray]:
    """Generate one uniformly random array for every non-initializer model input.

    Args:
        model: ONNX model whose input shapes and element types drive generation.
        seed: Optional seed for reproducible inputs. When ``None`` (the
            default) values are drawn from NumPy's global random state, exactly
            as before this parameter existed.

    Returns:
        Arrays sampled from ``uniform(0, 1)`` and cast to each input's dtype,
        ordered like ``get_onnx_input_names(model)``.

    Raises:
        AssertionError: If the number of discovered shapes and types differ.
    """
    input_shapes = get_onnx_input_shapes(model)
    input_types = get_onnx_input_types(model)
    assert len(input_types) == len(input_shapes)

    # A private RandomState keeps seeded runs reproducible without touching
    # the global NumPy generator that other cells may depend on.
    rng = np.random if seed is None else np.random.RandomState(seed)
    return [
        rng.uniform(size=shape).astype(dtype)
        for shape, dtype in zip(input_shapes, input_types)
    ]

In [4]:
def get_onnxruntime_output(model: ModelProto, inputs: List, provider_name: AnyStr) -> np.ndarray:
    """Run ``model`` through ONNX Runtime with a single execution provider.

    Args:
        model: ONNX model; it is serialized and handed to an ``InferenceSession``.
        inputs: Input values, ordered like ``get_onnx_input_names(model)``.
        provider_name: Name of the execution provider to request.

    Returns:
        The single output value when the session returns exactly one output;
        otherwise the full list of outputs as returned by the session.

    Raises:
        AssertionError: If ``inputs`` does not have one value per model input.
    """
    requested_outputs = get_onnx_output_names(model)
    feed_names = get_onnx_input_names(model)
    assert len(feed_names) == len(inputs)
    feeds = dict(zip(feed_names, inputs))

    session = onnxruntime.InferenceSession(model.SerializeToString(), providers=[provider_name])
    results = session.run(requested_outputs, feeds)

    # A lone output is handed back bare rather than wrapped in a list.
    return results[0] if len(results) == 1 else results


def get_pure_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:
    """Run ``model`` on ONNX Runtime's ``CPUExecutionProvider`` (the reference result).

    Args:
        model: ONNX model to execute.
        inputs: Input values, ordered like ``get_onnx_input_names(model)``.

    Returns:
        Whatever ``get_onnxruntime_output`` returns for that provider.
    """
    return get_onnxruntime_output(model, inputs, provider_name="CPUExecutionProvider")


def get_stvm_onnxruntime_output(model: ModelProto, inputs: List) -> np.ndarray:
    """Run ``model`` on ONNX Runtime's ``StvmExecutionProvider``.

    Args:
        model: ONNX model to execute.
        inputs: Input values, ordered like ``get_onnx_input_names(model)``.

    Returns:
        Whatever ``get_onnxruntime_output`` returns for that provider.
    """
    return get_onnxruntime_output(model, inputs, provider_name="StvmExecutionProvider")

In [5]:
def verify_with_ort_with_inputs(
    model,
    inputs,
    out_shape=None,
    opset=None,
    freeze_params=False,
    dtype="float32",
    rtol=1e-5,
    atol=1e-5,
    opt_level=1,
):
    """Check that the STVM provider reproduces the CPU provider's outputs.

    Args:
        model: ONNX model to run. Mutated in place when ``opset`` is given.
        inputs: Input values, ordered like ``get_onnx_input_names(model)``.
        out_shape: Accepted for signature compatibility; not used here.
        opset: If not ``None``, written to ``model.opset_import[0].version``
            before running.
        freeze_params: Accepted for signature compatibility; not used here.
        dtype: Accepted for signature compatibility; not used here.
        rtol: Relative tolerance for the value comparison.
        atol: Absolute tolerance for the value comparison.
        opt_level: Accepted for signature compatibility; not used here.

    Raises:
        AssertionError: If the providers return a different number of outputs,
            or any pair of outputs differs in value (beyond tolerance) or dtype.
    """

    def _as_output_list(outputs):
        # get_onnxruntime_output unwraps a single output into a bare value.
        # Re-wrap it so the loop below compares whole outputs instead of
        # iterating over the first axis of one array (which breaks on 0-d
        # outputs and hides mismatched leading dimensions).
        if isinstance(outputs, (list, tuple)):
            return list(outputs)
        return [outputs]

    if opset is not None:
        model.opset_import[0].version = opset

    ort_outputs = _as_output_list(get_pure_onnxruntime_output(model, inputs))
    stvm_outputs = _as_output_list(get_stvm_onnxruntime_output(model, inputs))

    # zip() would silently drop surplus outputs, so check the counts first.
    assert len(stvm_outputs) == len(ort_outputs), (
        "output count mismatch: STVM returned %d, ONNX Runtime returned %d"
        % (len(stvm_outputs), len(ort_outputs))
    )

    for stvm_val, ort_val in zip(stvm_outputs, ort_outputs):
        tvm.testing.assert_allclose(ort_val, stvm_val, rtol=rtol, atol=atol)
        assert ort_val.dtype == stvm_val.dtype

## Check accuracy of STVM for a simple model

In [6]:
def get_two_input_model(op_name: AnyStr) -> ModelProto:
    """Build and validate a one-node ONNX model with two inputs and one output.

    Both inputs (``in1``, ``in2``) and the output (``out``) are float32 tensors
    of shape ``[1, 2, 3, 3]``.

    Args:
        op_name: ONNX operator type of the single node (e.g. ``"Add"``).

    Returns:
        The model, after it has passed ``checker.check_model`` with
        ``full_check=True``.
    """
    tensor_shape = [1, 2, 3, 3]
    elem_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype("float32")]

    def tensor_info(name):
        # Inputs and output all share the same element type and shape.
        return helper.make_tensor_value_info(name, elem_type, tensor_shape)

    node = helper.make_node(op_name, ["in1", "in2"], ["out"])
    graph = helper.make_graph(
        [node],
        "two_input_test",
        inputs=[tensor_info("in1"), tensor_info("in2")],
        outputs=[tensor_info("out")],
    )
    model = helper.make_model(graph, producer_name="two_input_test")
    checker.check_model(model, full_check=True)
    return model

In [7]:
# Build the minimal two-input "Add" model, feed it random inputs, and compare
# the STVM execution provider against the CPU execution provider.
onnx_model = get_two_input_model("Add")
inputs = get_random_model_inputs(onnx_model)
# Raises an AssertionError if the two providers disagree in value or dtype.
verify_with_ort_with_inputs(onnx_model, inputs)
print("****************** Success! ******************")

STVM ep options:
target: llvm -mcpu=skylake-avx512
target_host: llvm -mcpu=skylake-avx512
opt level: 3
freeze weights: 1
tuning file path: 
tuning type: Ansor
convert layout to NHWC: 0
input tensor names: 
input tensor shapes: 
Build TVM graph executor
****************** Success! ******************




## Check for ResNet50

In [8]:
# Common URL prefix (onnx/models GitHub repository, raw files) for every entry below.
BASE_MODEL_URL = "https://github.com/onnx/models/raw/master/"
# Friendly model name -> path of the .onnx file relative to BASE_MODEL_URL.
MODEL_URL_COLLECTION = {
    "ResNet50-v1": "vision/classification/resnet/model/resnet50-v1-7.onnx",
    "ResNet50-v2": "vision/classification/resnet/model/resnet50-v2-7.onnx",
    "SqueezeNet-v1.1": "vision/classification/squeezenet/model/squeezenet1.1-7.onnx",
    "SqueezeNet-v1.0": "vision/classification/squeezenet/model/squeezenet1.0-7.onnx",
    "Inception-v1": "vision/classification/inception_and_googlenet/inception_v1/model/inception-v1-7.onnx",
    "Inception-v2": "vision/classification/inception_and_googlenet/inception_v2/model/inception-v2-7.onnx",
}


def get_model_url(model_name):
    """Return the full download URL for a model listed in ``MODEL_URL_COLLECTION``.

    Args:
        model_name: Key of ``MODEL_URL_COLLECTION`` (e.g. ``"ResNet50-v1"``).

    Returns:
        ``BASE_MODEL_URL`` followed by the model's relative path.

    Raises:
        KeyError: If ``model_name`` is not a known key.
    """
    relative_path = MODEL_URL_COLLECTION[model_name]
    return BASE_MODEL_URL + relative_path


def get_name_from_url(url):
    """Return the last path component of ``url`` with surrounding whitespace removed.

    Args:
        url: URL (or any string); everything after the final ``/`` is kept.
            A string without ``/`` is returned whole (stripped).

    Returns:
        The trailing component, possibly empty when ``url`` ends with ``/``.
    """
    trailing_component = url.rsplit("/", 1)[-1]
    return trailing_component.strip()


def find_of_download(model_name):
    """Resolve ``model_name`` to a URL and fetch it through ``download_testdata``.

    Args:
        model_name: Key of ``MODEL_URL_COLLECTION``.

    Returns:
        Whatever ``download_testdata`` returns for the model URL, using the
        URL's file name as the target name and ``module="models"``.
    """
    url = get_model_url(model_name)
    file_name = get_name_from_url(url)
    return download_testdata(url, file_name, module="models")


def get_onnx_model(model_name):
    """Fetch the named model via ``find_of_download`` and load it with ``onnx.load``.

    Args:
        model_name: Key of ``MODEL_URL_COLLECTION``.

    Returns:
        The object returned by ``onnx.load`` for the fetched file.
    """
    return onnx.load(find_of_download(model_name))

In [9]:
# Key into MODEL_URL_COLLECTION selecting which model to verify.
model_name = "ResNet50-v1"

# Fetch the model, feed it random inputs, and compare the STVM execution
# provider against the CPU execution provider.
onnx_model = get_onnx_model(model_name)
inputs = get_random_model_inputs(onnx_model)
# Raises an AssertionError if the two providers disagree in value or dtype.
verify_with_ort_with_inputs(onnx_model, inputs)
print("****************** Success! ******************")

STVM ep options:
target: llvm -mcpu=skylake-avx512
target_host: llvm -mcpu=skylake-avx512
opt level: 3
freeze weights: 1
tuning file path: 
tuning type: Ansor
convert layout to NHWC: 0
input tensor names: 
input tensor shapes: 




Build TVM graph executor


One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


****************** Success! ******************


## Configuration options

In [10]:
# Prepare the model and a name -> value feed dictionary for the session below.
model_name = "ResNet50-v1"
onnx_model = get_onnx_model(model_name)
# Alternative small model for quick experiments:
# onnx_model = get_two_input_model("Add")
inputs = get_random_model_inputs(onnx_model)
input_names = get_onnx_input_names(onnx_model)
output_names = get_onnx_output_names(onnx_model)
# Both sequences are ordered by input name, so pairing them positionally is safe.
input_dict = dict(zip(input_names, inputs))

In [11]:
# Settings forwarded to the STVM execution provider.
client_target = "llvm -mtriple=x86_64-linux-gnu"
client_target_host = client_target
client_opt_level = 3
freeze = True
client_tuning_logfile = ""

# One options dictionary per requested provider, in the same order as `providers`.
po = [
    {
        "target": client_target,
        "target_host": client_target_host,
        "opt_level": client_opt_level,
        "freeze_weights": freeze,
        "tuning_file_path": client_tuning_logfile,
        "tuning_type": "Ansor",
    }
]
stvm_session = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(),
    providers=["StvmExecutionProvider"],
    provider_options=po,
)

output = stvm_session.run(output_names, input_dict)
print(output)

STVM ep options:
target: llvm -mtriple=x86_64-linux-gnu
target_host: llvm -mtriple=x86_64-linux-gnu
opt level: 3
freeze weights: 1
tuning file path: 
tuning type: Ansor
convert layout to NHWC: 0
input tensor names: 
input tensor shapes: 
Build TVM graph executor
[array([[-1.68543422e+00,  7.99613670e-02,  8.89679372e-01,
         7.08967865e-01,  1.10995877e+00,  6.46288872e-01,
         5.37635624e-01, -5.03077269e-01, -1.56129766e+00,
        -2.59686410e-01,  2.23753119e+00,  2.82836032e+00,
         1.95826304e+00,  2.73472595e+00,  1.93408489e+00,
         1.59744227e+00,  1.25942326e+00,  9.91862416e-01,
         2.48975468e+00,  1.87842631e+00, -1.28254578e-01,
         2.12853694e+00,  1.81127214e+00,  1.75756657e+00,
         3.29883844e-01, -2.94335544e-01,  9.62437019e-02,
         4.83264215e-02,  1.99886739e-01, -9.31235850e-01,
        -1.57790756e+00,  1.34628928e+00, -1.01166713e+00,
         9.82081711e-01,  2.14204979e+00, -1.71066701e+00,
        -2.54290909e-01, -1.