
Conv3D performance degradation after ONNX conversion #2303

Open
jm2201 opened this issue Feb 13, 2024 · 2 comments
Labels
bug: An unexpected problem or unintended behavior
pending on user response: Waiting for more information or validation from user

Comments

@jm2201

jm2201 commented Feb 13, 2024

Describe the bug
A simple TensorFlow model with Conv3D and pooling runs roughly 3.6x slower on CPU after converting to ONNX.
The same model with Conv3D replaced by Conv2D runs roughly 10x faster on CPU after converting to ONNX.

Urgency
If not resolved in the next 4-6 months, this bug will block the planned release of a TF to ONNX-converted model.

System information

  • OS Platform and Distribution: Windows 10 Pro 22H2
  • TensorFlow Version: 2.10.0
  • Python version: 3.9.16
  • ONNX version: 1.15.0
  • ONNXRuntime version: 1.16.3

To Reproduce

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers, initializers
import onnxruntime as rt

def get_model(_3d=True):
    tf.keras.backend.set_image_data_format('channels_first')
    if _3d:
        input_shape = (1, 80, 128, 128)  # channels first
        conv_layer = layers.Conv3D
        pool_layer = layers.MaxPooling3D        
    else:
        input_shape = (1, 128, 128)
        conv_layer = layers.Conv2D
        pool_layer = layers.MaxPooling2D
        
    input = layers.Input(input_shape, name='input_0')
    x = conv_layer(32, 3, activation='relu', padding='same')(input)
    x = conv_layer(32, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)
    
    x = conv_layer(64, 3, activation='relu', padding='same')(x)
    x = conv_layer(64, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)
    
    x = conv_layer(128, 3, activation='relu', padding='same')(x)
    x = conv_layer(128, 3, activation='relu', padding='same')(x)
    x = pool_layer(2)(x)
    
    out = conv_layer(1, 1, activation='sigmoid')(x)
    
    model = models.Model(inputs=input, outputs=out)
    # 'lr' is a deprecated alias in TF 2.x; use 'learning_rate'
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3, decay=1e-5), loss='binary_crossentropy')
    return model

Test the 3D tensorflow model:

MODEL_EX = 'model_ex/saved_model'
_3D = True
tf_model = get_model(_3D)
input_ = np.random.random((1, 1, 80, 128, 128)).astype(np.float32)
%timeit out = tf_model.predict(input_)
1/1 [==============================] - 0s 89ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 23ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
195 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Convert the 3D model to ONNX:

tf_model.save(MODEL_EX)
!python -m tf2onnx.convert --opset 18 --saved-model model_ex/saved_model --output model_ex/tmp.onnx

Test the 3D ONNX model:

TMP_MODEL = os.path.join('model_ex', 'tmp.onnx')
sess = rt.InferenceSession(TMP_MODEL)
%timeit result = sess.run(None, {'input_0': input_})
737 ms ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
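To see which nodes account for the extra time, ONNX Runtime's built-in profiler can be enabled on the session. A minimal sketch reusing the model path and input shape from the repro above (the trace is a JSON file viewable in chrome://tracing):

```python
import numpy as np
import onnxruntime as rt

so = rt.SessionOptions()
so.enable_profiling = True                       # emit a per-node timing trace
sess = rt.InferenceSession('model_ex/tmp.onnx', so)
input_ = np.random.random((1, 1, 80, 128, 128)).astype(np.float32)
sess.run(None, {'input_0': input_})
trace_path = sess.end_profiling()                # path to the JSON trace file
print(trace_path)
```

The trace shows per-node wall time, which should reveal whether the Conv nodes themselves are slow or whether time is lost elsewhere (e.g. in layout transposes inserted during conversion).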

Test the 2D tensorflow model:

MODEL_EX = 'model_ex/saved_model'
_3D = False
tf_model = get_model(_3D)
input_ = np.random.random((1, 1, 128, 128)).astype(np.float32)
%timeit out = tf_model.predict(input_)
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 20ms/step
...
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 19ms/step
1/1 [==============================] - 0s 18ms/step
59.7 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Convert the 2D model to ONNX:

tf_model.save(MODEL_EX)
!python -m tf2onnx.convert --opset 18 --saved-model model_ex/saved_model --output model_ex/tmp.onnx

Test the 2D ONNX model:

TMP_MODEL = os.path.join('model_ex', 'tmp.onnx')
sess = rt.InferenceSession(TMP_MODEL)
%timeit result = sess.run(None, {'input_0': input_})
4.55 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
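For context on the raw timings, here is a back-of-the-envelope MAC count for the two architectures above (a pure-Python sketch; it assumes 'same' padding preserves spatial dims and each pooling layer halves them, as in the repro model):

```python
from math import prod

def conv_macs(spatial, cin, cout, k_elems):
    # multiply-accumulates for a 'same'-padded conv: one dot product per output element
    return prod(spatial) * cin * cout * k_elems

def model_macs(spatial, k_elems):
    # mirrors the repro model: three conv-conv-pool blocks, then a 1x1 sigmoid conv
    macs, cin = 0, 1
    for cout in (32, 64, 128):
        macs += conv_macs(spatial, cin, cout, k_elems)
        macs += conv_macs(spatial, cout, cout, k_elems)
        spatial = [s // 2 for s in spatial]   # MaxPooling with pool size 2
        cin = cout
    macs += conv_macs(spatial, cin, 1, 1)     # final 1x1 conv
    return macs

macs_3d = model_macs([80, 128, 128], 3 ** 3)  # Conv3D, 3x3x3 kernels
macs_2d = model_macs([128, 128], 3 ** 2)      # Conv2D, 3x3 kernels
print(macs_3d / macs_2d)                      # the 3D model does ~128x more arithmetic
```

So the 3D model does roughly two orders of magnitude more arithmetic per inference than the 2D one; the telling comparison is each model's ONNX-vs-TF ratio rather than the absolute times.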
@jm2201 jm2201 added the bug An unexpected problem or unintended behavior label Feb 13, 2024
@fatcat-z
Collaborator

These two ONNX model files contain the same ONNX ops, so I'm wondering whether the performance difference is caused by the inputs. Even for the same ONNX op, ONNX Runtime can show a performance difference when the inputs differ, so I suggest opening an issue in the onnxruntime repo for further investigation.
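One way to double-check that the two exported graphs really contain the same ops is to compare per-op-type histograms. A minimal pure-Python sketch; in practice you would feed each function `[n.op_type for n in onnx.load(path).graph.node]` from the `onnx` package (the op lists below are toy placeholders):

```python
from collections import Counter

def op_histogram(op_types):
    # count how many nodes of each ONNX op type a graph contains
    return Counter(op_types)

def op_diff(a, b):
    # report op types whose node counts differ between two graphs
    ha, hb = op_histogram(a), op_histogram(b)
    return {op: (ha[op], hb[op]) for op in ha.keys() | hb.keys() if ha[op] != hb[op]}

# toy example with hypothetical op lists
print(op_diff(["Conv", "Conv", "Relu"], ["Conv", "Relu", "MaxPool"]))
```

If the diff is empty, the graphs are structurally equivalent at the op level and the slowdown must come from the kernels or data layout rather than the converted topology.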

@fatcat-z fatcat-z added the pending on user response Waiting for more information or validation from user label Mar 11, 2024
@jm2201
Author

jm2201 commented Mar 14, 2024

Could you test again? The two ONNX models were identical because of a missing line in the repro code above: the 2D TF model wasn't saved before conversion. It's fixed now.
