🐛 [BUG] - ONNXModelCPU_tvm_0.9.0_cpu.yaml file was empty, can't get opset properly #120

Closed
sdhuie opened this issue Aug 22, 2023 · 11 comments
Labels
bug Something isn't working

Comments


sdhuie commented Aug 22, 2023

Description

The ONNXModelCPU_tvm_0.9.0_cpu.yaml file is not generated properly; its content is just `opset: {}`, which results in the following error:

ERROR   fuzz   - Traceback (most recent call last):
  File "/home/vincy/nnsmith/nnsmith/cli/fuzz.py", line 224, in run
    testcase = self.make_testcase(seed)
  File "/home/vincy/nnsmith/nnsmith/cli/fuzz.py", line 179, in make_testcase
    gen = model_gen(
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 796, in model_gen
    gen = SymbolicGen(opset, seed, symbolic_init=symbolic_init, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 467, in __init__
    super().__init__(opset, seed, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 42, in __init__
    assert len(opset) > 0, "opset must not be empty"
AssertionError: opset must not be empty

tvm version: 0.9.0
nnsmith: 0.2.0.dev16

Installation

pip install "git+https://github.com/ise-uiuc/nnsmith@main#egg=nnsmith[torch,onnx]" --upgrade

Reproduction

command: `python nnsmith/cli/fuzz.py fuzz.time=1s fuzz.root=${PATH_TO_REPORT} model.type=onnx backend.type=tvm filter.type="[nan,inf,dup]" fuzz.save_test=${PATH_TO_SAVE_TESTS}`

or 

command: `pytest -s tests/tvm/ `


Logs

WARNING dtest  - =====> [Failure] at torch.PTMatMul((float16, float16)) => (float16,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((float32, float32)) => (float32,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((float64, float64)) => (float64,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((int8, int8)) => (int8,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((int16, int16)) => (int16,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((int32, int32)) => (int32,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((int64, int64)) => (int64,)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

WARNING dtest  - =====> [Failure] at torch.PTMatMul((uint8, uint8)) => (uint8,)
INFO    fuzz   - Test success info supressed -- only showing logs for failed tests
INFO    fuzz   - Saving all intermediate testcases to PATH_TO_SAVE_TESTS
WARNING core   - Report folder already exists. Press [Y/N] to continue or exit...
ERROR   fuzz   - `make_testcase` failed with seed 2091741040. It can be NNSmith or Generator (onnx) bug.
ERROR   fuzz   - Traceback (most recent call last):
  File "/home/vincy/nnsmith/nnsmith/cli/fuzz.py", line 224, in run
    testcase = self.make_testcase(seed)
  File "/home/vincy/nnsmith/nnsmith/cli/fuzz.py", line 179, in make_testcase
    gen = model_gen(
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 796, in model_gen
    gen = SymbolicGen(opset, seed, symbolic_init=symbolic_init, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 467, in __init__
    super().__init__(opset, seed, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/graph_gen.py", line 42, in __init__
    assert len(opset) > 0, "opset must not be empty"
AssertionError: opset must not be empty

Front-end framework

onnx

Version of the front-end framework

onnx-1.14.0 onnxruntime-1.15.1

Back-end engine

tvm

Version of the back-end engine

tvm-0.9.0

Other context

The onnxruntime back-end engine works fine.

@sdhuie sdhuie added the bug Something isn't working label Aug 22, 2023
@sdhuie sdhuie changed the title 🐛 [BUG] - ONNXModelCPU_tvm_0.9.0_cpu.yaml file was not generated properly 🐛 [BUG] - ONNXModelCPU_tvm_0.9.0_cpu.yaml file was empty, can't get opset properly Aug 22, 2023

sdhuie commented Aug 24, 2023

@ganler


ganler commented Aug 24, 2023

Hi, can you remove the ONNXModelCPU_tvm_0.9.0_cpu.yaml file and retry your command with hydra.verbose=dtest appended? That gives us more complete logs. Oftentimes this is a very simple configuration error (e.g., TVM is not properly installed).


sdhuie commented Aug 25, 2023

Thanks for the answer, @ganler.
(1) I have tried deleting the yaml file and regenerating it; the regenerated ONNXModelCPU_tvm_0.9.0_cpu.yaml is still empty.
(2) tvm-0.9.0 is already built and installed.
Moreover, I even tried renaming the ONNXModelCPU_onnxruntime_1.15.1_cpu.yaml file to ONNXModelCPU_tvm_0.9.0_cpu.yaml to test it, but the tvm backend rejects it with an error. Alternatively, could you provide a copy of a standard ONNXModelCPU_tvm_0.9.0_cpu.yaml file so that I can proceed with it?


ganler commented Aug 25, 2023

Hi, I mean I am interested in the log. Deleting the yaml and regenerating it with hydra.verbose=dtest will show the detailed failure reasons, which can help me debug. Thanks.


sdhuie commented Aug 25, 2023

Okay, thanks, got it. Here is the debug log:

WARNING dtest  - =====> [Failure] at Cast {'to'= int32}((int16,)) => (int32,)
DEBUG   dtest  - Traceback (most recent call last):
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/factory.py", line 103, in checked_compile
    return self.checked_make_backend(testcase.model)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/factory.py", line 99, in checked_make_backend
    return self.make_backend(model)
  File "/home/vincy/.local/lib/python3.10/site-packages/multipledispatch/dispatcher.py", line 439, in __call__
    return func(self.obj, *args, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/tvm.py", line 86, in make_backend
    ).evaluate()
  File "/home/vincy/apache-tvm-src-v0.9.0/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/home/vincy/apache-tvm-src-v0.9.0/python/tvm/relay/build_module.py", line 592, in _make_executor
    mod = build(self.mod, target=self.target)
  File "/home/vincy/apache-tvm-src-v0.9.0/python/tvm/relay/build_module.py", line 438, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/home/vincy/apache-tvm-src-v0.9.0/python/tvm/relay/build_module.py", line 161, in build
    self._build(
  File "/home/vincy/apache-tvm-src-v0.9.0/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  4: TVMFuncCall
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  1: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  0: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  File "/home/vincy/apache-tvm-src-v0.9.0/src/target/codegen.cc", line 58
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (bf != nullptr) is false: target.build.llvm is not enabled

============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============


ganler commented Aug 25, 2023

So the problem is that you are not using a TVM build with the LLVM backend. Because the LLVM backend is requested but not available in your TVM, every operator that gets "smoke-tested" fails with "target.build.llvm is not enabled".

To fix this you can (i) install the official TVM wheels from https://pypi.org/project/apache-tvm/; or (ii) build TVM with LLVM support (set USE_LLVM in the cmake flags).
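As a quick sanity check, a minimal sketch (not part of nnsmith; it only assumes the registry name `target.build.llvm` taken from the error above) that asks the installed TVM whether the LLVM codegen is compiled in:

```python
# Minimal sketch: probe whether this TVM build has the LLVM codegen enabled.
# The registry name "target.build.llvm" is taken from the error message above.
import tvm

def llvm_enabled() -> bool:
    # allow_missing=True returns None instead of raising when the codegen
    # was not compiled into libtvm (i.e., USE_LLVM was OFF at build time).
    return tvm.get_global_func("target.build.llvm", allow_missing=True) is not None

if __name__ == "__main__":
    if llvm_enabled():
        print("LLVM codegen is enabled; the 'llvm' target should work.")
    else:
        print("target.build.llvm is not enabled: rebuild TVM with USE_LLVM "
              "or install the official apache-tvm wheel.")
```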


ganler commented Aug 25, 2023

You may also use other CPU targets here (https://github.com/ise-uiuc/nnsmith/blob/main/nnsmith/backends/tvm.py#L38) if you don't mind hard-coding the target you favor. But I don't plan to change the LLVM backend option to something else in NNSmith, given that LLVM is the most common CPU target.
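For illustration only, this is roughly what compiling against an explicitly chosen CPU target looks like in plain TVM Relay (a hedged sketch, not the actual code at the linked line; the target string is just an example of a registered TVM target):

```python
# Sketch: build a trivial Relay module for an explicitly chosen CPU target.
# Hard-coding a different target in nnsmith's TVM backend amounts to changing
# the string passed as `target` here (e.g. a specific "llvm -mcpu=...").
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

# Fails with "target.build.llvm is not enabled" if TVM lacks LLVM support.
lib = relay.build(mod, target="llvm")
print(lib)
```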


sdhuie commented Aug 25, 2023

Thanks! I configured the LLVM path so it is valid: `set(USE_LLVM $CLANG/bin/llvm-config)`.

@sdhuie sdhuie closed this as completed Aug 25, 2023

sdhuie commented Aug 25, 2023

ONNXModelCPU_tvm_0.9.0_cpu.yaml is now generated, but with some other types of errors; please confirm whether they are expected.
1

  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule) const
  8: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: _ZN3tvm7runtime13PackedFuncObj
  5: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  3: tvm::relay::TypeSolver::Solve()
  2: _ZN3tvm7runtime13PackedFuncObj
  1: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: bool tvm::relay::MatmulRel<tvm::relay::MatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/vincy/apache-tvm-src-v0.9.0/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [15:31:29] /home/vincy/apache-tvm-src-v0.9.0/src/relay/op/nn/nn.h:105: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (static_cast<int>(tensor_b->shape.size()) == 2) is false: 

2

  File "/home/vincy/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/materialize/torch/symbolnet.py", line 362, in forward
    output_tensors = inst(*input_tensors)
RuntimeError: "acos_vml_cpu" not implemented for 'Half'

3

WARNING dtest  - =====> [Failure] at torch.PTMatMul((uint8, uint8)) => (uint8,)
DEBUG   dtest  - Traceback (most recent call last):
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/factory.py", line 103, in checked_compile
    return self.checked_make_backend(testcase.model)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/factory.py", line 99, in checked_make_backend
    return self.make_backend(model)
  File "/home/vincy/.local/lib/python3.10/site-packages/multipledispatch/dispatcher.py", line 439, in __call__
    return func(self.obj, *args, **kwargs)
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/backends/tvm.py", line 76, in make_backend
    onnx_model = model.native_model
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/materialize/onnx/__init__.py", line 272, in native_model
    self.onnx_model = self.get_onnx_from_torch()
  File "/home/vincy/.local/lib/python3.10/site-packages/nnsmith/materialize/onnx/__init__.py", line 285, in get_onnx_from_torch
    onnx.checker.check_model(onnx_model, full_check=True)
  File "/home/vincy/.local/lib/python3.10/site-packages/onnx/checker.py", line 136, in check_model
    C.check_model(protobuf_string, full_check)
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:MatMul, node name: /MatMul): A typestr: T, has unsupported type: tensor(uint8)

@sdhuie sdhuie reopened this Aug 25, 2023

ganler commented Aug 25, 2023

If the yaml file is not empty, you can safely ignore these errors. The rationale is that we run smoke tests over certain operators and data types to examine whether each combination is compilable. It is perfectly fine for a subset of them not to work; what is not fine is, as you previously found, for none of them to work.
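To make the rationale concrete, here is a conceptual sketch (not nnsmith's actual implementation; helper names such as `try_compile` are hypothetical) of how such a smoke-tested opset table could be assembled:

```python
# Conceptual sketch: keep only the (operator, dtype) combinations that compile.
# If every combination fails (e.g. LLVM missing), the table ends up empty,
# which is exactly what made `assert len(opset) > 0` fire earlier in this issue.
from typing import Callable, Dict, List

def build_opset(
    candidates: Dict[str, List[str]],
    try_compile: Callable[[str, str], bool],
) -> Dict[str, List[str]]:
    """Return {op_name: [dtypes that passed the smoke test]}, dropping empty entries."""
    opset: Dict[str, List[str]] = {}
    for op_name, dtypes in candidates.items():
        ok = [dt for dt in dtypes if try_compile(op_name, dt)]
        if ok:  # a subset failing is fine; the op survives if anything works
            opset[op_name] = ok
    return opset

if __name__ == "__main__":
    # Toy stand-in for the real backend compile check.
    check = lambda op, dt: not (op == "MatMul" and dt == "uint8")
    print(build_opset({"MatMul": ["float32", "uint8"], "Acos": ["float32"]}, check))
```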


ganler commented Aug 27, 2023

Closed for now. Feel free to reopen if further assistance is needed.

@ganler ganler closed this as completed Aug 27, 2023