working model with Resize node becomes invalid after using convert_float_to_float16 #14827

Open
ssube opened this issue Feb 25, 2023 · 4 comments
Labels
model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.

Comments

ssube commented Feb 25, 2023

Describe the issue

After converting a model to FP16 internally using onnxruntime.transformers.float16.convert_float_to_float16, the model can be loaded with onnx.load_model but fails to load with onnxruntime.InferenceSession. The model was valid before conversion and both ONNX and ORT could load it:

> python3 fp16-repro.py 
ONNX load succeeded for original model
ORT load succeeded for original model
saved optimized model
ONNX load succeeded for optimized model
ORT load failed for optimized model
Traceback (most recent call last):
  File "/home/ssube/onnx-repro/fp16-repro.py", line 42, in <module>
    sess = InferenceSession(opt_file, providers=["CPUExecutionProvider"])
  File "/tmp/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 360, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/tmp/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 397, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./optimized.onnx failed:This is an invalid model. Type Error: Type 'tensor(float16)' of input parameter (onnx::Resize_9300) of operator (Resize) in node (Resize_4928) is invalid.

This appears to be related to #8327 and #2848, both of which were closed as errors in third-party models, but this model was valid before calling convert_float_to_float16.

The error seems to come from the FP16 type being incorrectly applied to one of the Resize node inputs: This is an invalid model. Type Error: Type 'tensor(float16)' of input parameter (onnx::Resize_9300) of operator (Resize) in node (Resize_4928) is invalid.

I've tested with the CPUExecutionProvider (which I would expect to be the most compatible) as well as the CUDAExecutionProvider, and both report the same error.

onnx/onnxmltools#361 (comment) suggests the conversion may be overly aggressive and recommends raising an issue here:

It assumes all operators in graph have a runtime which supports float16. That's not the case for onnxruntime. The model may be valid but cannot be run unless onnxruntime supports float16 for all operators in your graph. You need to use a different runtime or raise an issue on onnxruntime repository with the list of operators you need to support float16.

I don't see any way to exclude the Resize nodes during conversion.

To reproduce

Download https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/onnx/vae_decoder/model.onnx ahead of time or within the notebook/script:

!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/onnx/vae_decoder/model.onnx
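
(If wget is not available, a roughly equivalent download in Python; this assumes the URL is directly downloadable without authentication:)

# rough wget equivalent for fetching the decoder model
from urllib.request import urlretrieve

urlretrieve(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/onnx/vae_decoder/model.onnx",
    "./model.onnx",
)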

from onnx import load_model, save_model
from onnxruntime import InferenceSession
from onnxruntime.transformers.float16 import convert_float_to_float16
from traceback import print_exception

init_file = "./model.onnx"
opt_file = "./optimized.onnx"

# verify original model
try:
    model = load_model(init_file)
    print("ONNX load succeeded for original model")
except Exception as err:
    print("ONNX load failed for original model")
    print_exception(type(err), err, err.__traceback__)

# reload with ORT
try:
    sess = InferenceSession(init_file, providers=["CPUExecutionProvider"])
    print("ORT load succeeded for original model")
except Exception as err:
    print("ORT load failed for original model")
    print_exception(type(err), err, err.__traceback__)

# load and convert
model = load_model(init_file)
optimized = convert_float_to_float16(model, keep_io_types=True, force_fp16_initializers=False, disable_shape_infer=True) # TODO: test with initializers=false
save_model(optimized, opt_file, save_as_external_data=True, all_tensors_to_one_file=True)
print("saved optimized model")

# reload with ONNX
try:
    model = load_model(opt_file)
    print("ONNX load succeeded for optimized model")
except Exception as err:
    print("ONNX load failed for optimized model")
    print_exception(type(err), err, err.__traceback__)

# reload with ORT
try:
    sess = InferenceSession(opt_file, providers=["CPUExecutionProvider"])
    print("ORT load succeeded for optimized model")
except Exception as err:
    print("ORT load failed for optimized model")
    print_exception(type(err), err, err.__traceback__)

Using what appear to be the latest packages (fresh venv created for this repro):

pip3 list --local | grep -e onnx -e torch
onnx                     1.13.1
onnxruntime-gpu          1.14.0
torch                    1.13.1
torchaudio               0.13.1
torchvision              0.14.1

Forcing FP16 initializers (force_fp16_initializers=True) does not seem to make a difference.

Converting a model that does not use the Resize node, like the corresponding encoder https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/onnx/vae_encoder/model.onnx, does work correctly and can be loaded by both ONNX and ORT.

Urgency

Not urgent, but this is limiting usage on cards with low VRAM (ssube/onnx-web#121 (comment)).

I'm happy to start a PR if this is the correct place and not user error.

Platform

Linux

OS Version

Linux compute-infer-1 5.15.0-60-generic #66-Ubuntu SMP Fri Jan 20 14:29:49 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux
Ubuntu 22.04.1 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

ORT 1.14.0, PyTorch 1.13.1

github-actions bot added the model:transformer label (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.) Feb 25, 2023
ssube changed the title from "onnx and onnxruntime disagree on model validity after using convert_float_to_float16" to "working model with Resize node becomes invalid after using convert_float_to_float16" Feb 25, 2023
@tianleiwu
Contributor


ssube commented Feb 25, 2023

I'm hoping to eventually use all of the optimizations that ORT provides; they look pretty helpful.

For now, I'm still getting the same error with the nightly package and op_block_list in my repro script:

ONNX load succeeded for optimized model
ORT load failed for optimized model
Traceback (most recent call last):
  File "/home/ssube/onnx-repro/fp16-repro.py", line 41, in <module>
    sess = InferenceSession(opt_file, providers=["CPUExecutionProvider"])
  File "/home/ssube/onnx-repro/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 366, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/ssube/onnx-repro/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 403, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from ./optimized.onnx failed:This is an invalid model. Type Error: Type 'tensor(float16)' of input parameter (onnx::Resize_967) of operator (Resize) in node (Resize_340) is invalid.

(ort_env) ssube@compute-infer-1:~/onnx-repro$ pip3 list --local | grep -e onnx -e ort -e torch
onnx                     1.13.1
ort-nightly-gpu          1.15.0.dev20230222003
torch                    1.13.1
torchaudio               0.13.1
torchvision              0.14.1

with the conversion call changed to

optimized = convert_float_to_float16(
    model, 
    keep_io_types=False, 
    force_fp16_initializers=False, 
    disable_shape_infer=True, 
    op_block_list=[
        "RandomNormalLike",
        "Resize",
    ]
)

based on https://github.com/microsoft/onnxruntime/blob/780054900994b704dc37c9f1a4f7874538bfc11e/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md#optimize-onnx-pipeline.

The optimize command shown in those docs runs to completion, and the resulting model loads with the CUDA provider but not the CPU provider:

> python3 -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ~/onnx-web/models/stable-diffusion-onnx-v1-5 -o ./sd-v1-5-fp16 --float16

...
optimize_sd_pipeline: Convert unet to float16 ...                                                                                                                                                
get_operator_statistics: Operators:{'Constant': 192, 'Transpose': 294, 'MatMul': 112, 'Shape': 66, 'Reshape': 64, 'Gather': 65, 'NhwcConv': 98, 'Unsqueeze': 158, 'GroupNorm': 61, 'Concat': 47, 'ConstantOfShape': 1, 'Mul': 20, 'Equal': 1, 'Where': 1, 'Expand': 1, 'Sin': 1, 'Cos': 1, 'Slice': 2, 'Gemm': 24, 'Sigmoid': 2, 'Add': 60, 'LayerNormalization': 16, 'MultiHeadAttention': 32, 'SkipLayerNormalization': 32, 'Resize': 3, 'BiasSplitGelu': 16, 'BiasAdd': 16, 'Cast': 3}
get_fused_operator_statistics: Optimized operators:{'Attention': 0, 'MultiHeadAttention': 32, 'LayerNormalization': 16, 'SkipLayerNormalization': 32, 'BiasSplitGelu': 16, 'GroupNorm': 61, 'NhwcConv': 98}
  save_model_to_file: Sort graphs in topological order
  save_model_to_file: Model saved to sd-v1-5-fp16/unet/model.onnx
optimize_sd_pipeline: unet is optimized
...

> python3
>>> import onnxruntime
>>> sess = onnxruntime.InferenceSession("./sd-v1-5-fp16/vae_decoder/model.onnx", providers=['CPUExecutionProvider'])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ssube/onnx-repro/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 366, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/ssube/onnx-repro/ort_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 414, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Failed to find kernel for NhwcConv(1) (node NhwcConv_0-/post_quant_conv/Conv). Kernel not found

>>> sess = onnxruntime.InferenceSession("./sd-v1-5-fp16/vae_decoder/model.onnx", providers=['CUDAExecutionProvider'])
2023-02-25 23:11:07.349772352 [W:onnxruntime:, session_state.cc:1136 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-02-25 23:11:07.349789765 [W:onnxruntime:, session_state.cc:1138 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
>>> 

The fp16 optimization is not really meant for CPU, so that's probably ok, but using convert_float_to_float16 on its own is still causing problems.


tianleiwu commented Feb 26, 2023

@ssube, I can reproduce the issue. The workaround is either to set disable_shape_infer=False or to pass op_block_list=['Identity'] when calling convert_float_to_float16 for this model. It seems that after disabling shape inference, the data type information is incomplete for the conversion. We will improve type inference later.
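
A minimal sketch of those two options against the repro script above (reusing the same file names; only one of the two convert calls is needed):

from onnx import load_model, save_model
from onnxruntime.transformers.float16 import convert_float_to_float16

model = load_model("./model.onnx")

# Option 1: leave shape inference enabled so the converter has complete type information.
optimized = convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=False)

# Option 2: keep shape inference disabled, but block Identity nodes from conversion.
# optimized = convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True,
#                                      op_block_list=["Identity"])

save_model(optimized, "./optimized.onnx", save_as_external_data=True, all_tensors_to_one_file=True)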

You are right. The optimization for stable diffusion works for CUDA only.

CPUExecutionProvider does not support float16 operators. If an operator has a float32 implementation in CPUExecutionProvider, ORT will actually add Cast nodes to convert tensors back to float32 so that it can run the operator in float32. That might cause a float16 model to be slower than float32 in CPUExecutionProvider.
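
One way to see those inserted Cast nodes (a sketch; the file names are just the ones used in this thread, and the input model needs to be one that loads, e.g. a model produced with the workaround above):

import onnxruntime as ort

so = ort.SessionOptions()
# Ask ORT to write out the graph it actually runs after its transformations;
# the fp16->fp32 Cast nodes inserted for CPU fallback should appear in this file.
so.optimized_model_filepath = "./optimized_cpu_graph.onnx"
ort.InferenceSession("./optimized.onnx", so, providers=["CPUExecutionProvider"])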


ssube commented Feb 27, 2023

Thanks for testing that. I had tried setting disable_shape_infer=False, but that fails on very large models (> 2GB), so I set it back to True at some point. I can reproduce your findings as well; that seems like a viable workaround if I call onnx.shape_inference.infer_shapes_path first with the model path rather than the loaded proto.
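
For reference, a rough sketch of that flow for a large model (the -inferred file name is just illustrative; infer_shapes_path writes the inferred model to a second path instead of returning a proto):

from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16

# Run shape inference on the file path, which avoids the 2 GB in-memory protobuf limit.
infer_shapes_path("./model.onnx", "./model-inferred.onnx")

model = load_model("./model-inferred.onnx")
optimized = convert_float_to_float16(model, keep_io_types=True, disable_shape_infer=True)
save_model(optimized, "./optimized.onnx", save_as_external_data=True, all_tensors_to_one_file=True)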
