
ONNX optimizer crashes for initializers not listed as graph inputs #2198

Closed
thilow opened this issue Jul 25, 2019 · 2 comments
Labels: bug, optimizer (Issues related to ONNX optimizers)

thilow commented Jul 25, 2019

Initializers are not required to be listed in the graph inputs (to my understanding, this requirement was lifted in the context of issue #1449).

Problem: The ONNX optimizer fails on models whose initializers are not listed in the graph inputs.
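To check whether a model is affected before handing it to the optimizer, it is enough to compare initializer names against graph-input names. A minimal sketch, assuming only that the graph object exposes `input` and `initializer` sequences whose items have a `name` field (as `onnx.GraphProto` does). The function `initializers_missing_from_inputs` is a hypothetical helper, and plain `SimpleNamespace` objects stand in for the protobuf messages so the snippet runs without ONNX installed.

```python
from types import SimpleNamespace


def initializers_missing_from_inputs(graph):
    """Return, in order, the names of initializers that have no entry in graph.input.

    Hypothetical helper: works on any object shaped like onnx.GraphProto.
    """
    input_names = {value_info.name for value_info in graph.input}
    return [init.name for init in graph.initializer if init.name not in input_names]


# Stand-in graph mirroring model_2 from the script: graph input 'a',
# plus an initializer 'b' that is NOT declared as a graph input.
demo_graph = SimpleNamespace(
    input=[SimpleNamespace(name="a")],
    initializer=[SimpleNamespace(name="b")],
)
print(initializers_missing_from_inputs(demo_graph))  # ['b']
```

With ONNX installed, the same call works on a real graph, e.g. `initializers_missing_from_inputs(model.graph)`. A non-empty result indicates the situation reported in this issue: a model that `onnx.checker` accepts but that the ONNX 1.5 optimizer crashed on.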

Below is a sample script demonstrating the issue. The optimizer succeeds on the first version of the model (model_1) but fails on the second (model_2), even though the checker accepts both.
The script output is shown further below.

#%%
import onnx
from onnx import helper, optimizer
from onnx import TensorProto

# Create a simple model that adds 42 to the input value.
# The initializer is not added to the graph inputs.
def create_simpel_model():
    # Create one input (ValueInfoProto)
    a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [1])

    # Create an initializer with value 42
    b = helper.make_tensor("b", TensorProto.FLOAT, [1], [42])

    # Create one output (ValueInfoProto)
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])

    # Create a node (NodeProto)
    node_def = helper.make_node(
        'Add', # op type
        ['a', 'b'], # inputs
        ['Y'], # outputs
    )

    # Create the graph (GraphProto)
    graph_def = helper.make_graph(
        [node_def],
        'test-model',
        [a],
        [Y],
        [b]
    )

    # Create the model (ModelProto)
    return helper.make_model(graph_def, producer_name='onnx-example')


print("ONNX version: ", onnx.__version__)
#%%
model_1 = create_simpel_model()

# Post-process the model by adding the initializer 'b' to the graph inputs.
# This prevents the optimizer from crashing.
new_input = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1])
model_1.graph.input.append(new_input)

#%% check model_1
onnx.checker.check_model(model_1)
print('model_1 is checked and ok!')

#%%
# Optimize the model
passes = [
    "eliminate_deadend",
    "eliminate_identity",
    "eliminate_nop_dropout",
    "eliminate_nop_monotone_argmax",
    "eliminate_nop_pad",
    "eliminate_nop_transpose",
    "extract_constant_to_initializer",
    "fuse_consecutive_concats",
    "fuse_consecutive_log_softmax",
    "fuse_consecutive_reduce_unsqueeze",
    "fuse_consecutive_squeezes",
    "fuse_consecutive_transposes"]

model_1 = optimizer.optimize(model_1, passes)
print("model_1 is optimized")

#%%
# This time, do not add the initializer to the graph inputs
model_2 = create_simpel_model()


# check model_2
onnx.checker.check_model(model_2)
print('model_2 is checked and ok too')

print("trying to optimize model_2")
# Next line fails!
model_1 = optimizer.optimize(model_2, passes)
print("model_2 is optimized")

The script produces the output below:

ONNX version:  1.5.0
model_1 is checked and ok!
model_1 is optimized
model_2 is checked and ok too
trying to optimize model_2
Traceback (most recent call last):
  File "c:\Users\thilow\.vscode\extensions\ms-python.python-2019.6.24221\pythonFiles\ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "c:\Users\thilow\.vscode\extensions\ms-python.python-2019.6.24221\pythonFiles\lib\python\ptvsd\__main__.py", line 434, in main
    run()
  File "c:\Users\thilow\.vscode\extensions\ms-python.python-2019.6.24221\pythonFiles\lib\python\ptvsd\__main__.py", line 312, in run_file
    runpy.run_path(target, run_name='__main__')
  File "e:\src\BrainWaveMunich\smartReply\OnnxPythonTools\optimizer_bug_repro.py", line 85, in <module>
    model_1 = optimizer.optimize(model_2, passes)
  File "E:\anaconda3\lib\site-packages\onnx\optimizer.py", line 55, in optimize
    optimized_model_str = C.optimize(model_str, passes)
IndexError: invalid unordered_map<K, T> key
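The fix applied to model_1 in the script (declaring the initializer as a graph input) can be generalized into a workaround for any model. A minimal sketch, assuming graph objects shaped like `onnx.GraphProto`; `add_initializers_to_inputs` and its `make_input` hook are hypothetical names, and the default hook calls `onnx.helper.make_tensor_value_info` just as the script does. The demo injects a `SimpleNamespace`-based hook so it runs without ONNX installed.

```python
from types import SimpleNamespace


def add_initializers_to_inputs(graph, make_input=None):
    """Append a graph input for every initializer that lacks one; return the added names.

    Hypothetical workaround helper, not part of the ONNX API.
    """
    if make_input is None:
        # Default hook: build a ValueInfoProto from the initializer's type and shape,
        # mirroring the manual fix applied to model_1.
        from onnx import helper

        def make_input(init):
            return helper.make_tensor_value_info(init.name, init.data_type, list(init.dims))

    existing = {value_info.name for value_info in graph.input}
    added = []
    for init in graph.initializer:
        if init.name not in existing:
            graph.input.append(make_input(init))
            existing.add(init.name)  # guard against duplicate initializer names
            added.append(init.name)
    return added


# Stand-in graph mirroring model_2: input 'a', undeclared initializer 'b'.
demo_graph = SimpleNamespace(
    input=[SimpleNamespace(name="a")],
    initializer=[SimpleNamespace(name="b")],
)
added = add_initializers_to_inputs(
    demo_graph, make_input=lambda init: SimpleNamespace(name=init.name)
)
print(added)  # ['b']
```

On a real model this would be called as `add_initializers_to_inputs(model_2.graph)` before `optimizer.optimize(model_2, passes)`. It only reshapes the model into the form that model_1 shows the optimizer accepting; it does not fix the optimizer itself.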
houseroad (Member) commented:

Yes, we should fix this, since the spec supports such cases.

cc: @spandantiwari

jcwchen (Member) commented Apr 15, 2021

Please note that the ONNX optimizer has been moved to a separate repo, https://github.com/onnx/optimizer, as of ONNX 1.9. If you still have questions related to the optimizer, please raise an issue there. Thank you!
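Scripts like the repro in this issue therefore need a different import depending on the installed ONNX version. A minimal sketch, assuming the standalone package from https://github.com/onnx/optimizer is installed under the import name `onnxoptimizer` and that, like `onnx.optimizer`, it exposes `optimize(model, passes)`; the helper names `optimizer_module_name` and `load_optimizer` are hypothetical.

```python
import importlib


def optimizer_module_name(onnx_version: str) -> str:
    """Pick the module that provides optimize() for a given ONNX version string."""
    # Compare only major.minor; any suffix on the patch part (e.g. "rc1") is ignored.
    major, minor = (int(part) for part in onnx_version.split(".")[:2])
    # Per the comment above, the optimizer left the onnx package in ONNX 1.9.
    # Assumption: the standalone package is importable as "onnxoptimizer".
    return "onnx.optimizer" if (major, minor) < (1, 9) else "onnxoptimizer"


def load_optimizer(onnx_version: str):
    """Import and return the optimizer module; raises ImportError if it is not installed."""
    return importlib.import_module(optimizer_module_name(onnx_version))


print(optimizer_module_name("1.5.0"))   # onnx.optimizer
print(optimizer_module_name("1.10.2"))  # onnxoptimizer
```

Usage would look like `load_optimizer(onnx.__version__).optimize(model, passes)`. Comparing `(major, minor)` tuples rather than raw strings avoids the trap where "1.10" sorts before "1.9" lexically.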

jcwchen closed this as completed Apr 15, 2021