
GlobalAveragePool with ORT_ENABLE_ALL generates incorrect output #20540

Closed
SuhwanSong opened this issue May 2, 2024 · 0 comments

Comments

@SuhwanSong

Describe the issue

ONNX Runtime produces incorrect output for the GlobalAveragePool layer when the graph optimization level is set to ORT_ENABLE_ALL. The inference results diverge from the expected output obtained with ORT_DISABLE_ALL.
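
The ONNX operator specification defines GlobalAveragePool as the mean of each channel over all spatial dimensions, with those dimensions kept at size 1, so an (N, C, H, W) tensor becomes (N, C, 1, 1). A minimal NumPy sketch of those semantics follows; it could act as an independent reference when checking the layer's output under either optimization level. The function name `global_average_pool` is hypothetical, and the (1, 64, 13, 13) shape is borrowed from the script's input only as an example, not as the shape that actually reaches the pooling node in poc.onnx.

```python
import numpy as np


def global_average_pool(x: np.ndarray) -> np.ndarray:
    """Hypothetical reference GlobalAveragePool: mean over every axis after N and C."""
    if x.ndim < 3:
        raise ValueError("expected input of shape (N, C, D1, ..., Dk)")
    spatial_axes = tuple(range(2, x.ndim))
    # keepdims=True keeps each spatial axis with size 1, matching the ONNX output shape.
    return x.mean(axis=spatial_axes, keepdims=True)


if __name__ == "__main__":
    x = np.random.randn(1, 64, 13, 13).astype(np.float32)
    y = global_average_pool(x)
    print(y.shape)  # (1, 64, 1, 1)
```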


To reproduce

poc.zip

  1. Download the provided ONNX model (poc.onnx) and save the following Python script.
  2. Run the script, which executes the ONNX model under two optimization levels (ORT_DISABLE_ALL and ORT_ENABLE_ALL).
  3. Observe the outputs and compare the results.
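import numpy as np
import onnxruntime

if __name__ == "__main__":
    onnx_model_path = 'poc.onnx'

    # The two optimization levels to compare: no graph optimizations vs. all of them.
    rt_option_list = [onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL,
                      onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL]

    # One SessionOptions object per optimization level.
    rt_sess_option_list = []
    for level in rt_option_list:
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = level
        rt_sess_option_list.append(sess_options)

    # Create one inference session per optimization level.
    ort_sessions = []
    for sess_options in rt_sess_option_list:
        try:
            ort_session = onnxruntime.InferenceSession(onnx_model_path, sess_options)
            ort_sessions.append(ort_session)
        except Exception as e:
            print(e)

    # The comparison below needs both sessions; stop early if either failed to load.
    if len(ort_sessions) != len(rt_option_list):
        raise SystemExit('failed to create one session per optimization level')

    num_of_error = 0
    for _ in range(100):
        # Random float32 input for the model input 'fire9/squeeze1x1_1'.
        input_np = np.random.randn(1, 64, 13, 13).astype(np.float32)
        input_ = {'fire9/squeeze1x1_1': input_np}

        outputs = []
        for ort_session in ort_sessions:
            try:
                ort_output = ort_session.run(None, input_)
                outputs.append(np.array(ort_output[0]))
            except Exception as e:
                print(e)

        # Skip this trial if either run raised an exception.
        if len(outputs) != 2:
            continue

        # Count a mismatch whenever the two levels disagree on the top-1 index.
        if outputs[0].argmax() != outputs[1].argmax():
            print(f'output {outputs[0].argmax()} != {outputs[1].argmax()}')
            num_of_error += 1

    print(f'result: {num_of_error} / 100')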
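
The script above only counts trials whose argmax indices differ, which can hide smaller numeric divergences and does not show how large the difference is. A minimal sketch of a stricter element-wise check, assuming the two outputs are NumPy arrays of the same shape (such as outputs[0] and outputs[1] from the script). The helper name `compare_outputs` and its default tolerances are hypothetical choices, not anything ONNX Runtime defines.

```python
import numpy as np


def compare_outputs(reference: np.ndarray, candidate: np.ndarray,
                    rtol: float = 1e-4, atol: float = 1e-5) -> dict:
    """Summarize how far `candidate` deviates from `reference` (hypothetical helper)."""
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    # Compute differences in float64 so the summary itself adds no float32 rounding.
    abs_diff = np.abs(reference.astype(np.float64) - candidate.astype(np.float64))
    return {
        # Largest absolute element-wise difference between the two outputs.
        "max_abs_diff": float(abs_diff.max()),
        # Same criterion as the original script: do the top-1 indices agree?
        "argmax_match": bool(reference.argmax() == candidate.argmax()),
        # Element-wise |candidate - reference| <= atol + rtol * |reference|.
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
    }


if __name__ == "__main__":
    ref = np.array([[0.1, 0.7, 0.2]], dtype=np.float32)
    cand = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
    print(compare_outputs(ref, cand))
```

In the reproduction loop, calling `compare_outputs(outputs[0], outputs[1])` would treat the ORT_DISABLE_ALL result as the reference and report the maximum deviation of the ORT_ENABLE_ALL result alongside the argmax check.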


Urgency

No response

Platform

Linux

OS Version

6.2.0-35-generic

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response
