ONNX Runtime much slower than PyTorch (2-3x slower) #12880

Open
thomas-beznik opened this issue Sep 7, 2022 · 21 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments

@thomas-beznik

Describe the bug
We built a UNet3D image segmentation model in PyTorch (based on this repo) and want to start distributing it. ONNX seemed like a good option, as it allows us to compress our models and the dependencies needed to run them. As our models are large and slow, we need to run them on GPU.

We were able to convert these models to ONNX, but noticed a significant slow-down of inference (2-3x). Timing is quite critical for us and our models are already relatively slow, so we can't afford more slow-downs.

I'm running my comparison tests following what was done in this issue

I could use your help to better understand where the issue is coming from and whether it is resolvable at all. What tests, settings, etc. can I try to see where the issue might be?

Urgency
This is quite an urgent issue: we need to deliver our models to our clients in the coming month and will need to resort to other solutions if we can't fix ONNX soon.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • ONNX Runtime installed from (source or binary): source
  • ONNX Runtime version: 1.12
  • Python version: 3.8.13
  • Visual Studio version (if applicable):
  • CUDA/cuDNN version: 11.1 / 8.0.5
  • GPU model and memory: NVIDIA GeForce RTX 3060, 12GB

To Reproduce
The code for the model will be quite hard to extract, so I'll first try to describe the issue and what I've tested. I'm currently generating my model using:

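import torch

# torchModel (half precision, on cuda:0), dummyInput and outPath are defined elsewhere
with torch.no_grad():
    torch.onnx.export(
        torchModel,
        dummyInput,
        outPath,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": [0],   # dynamic batch dimension
            "output": [0],
        },
        verbose=True,
    )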

The model that we are using uses the following operations:

  • 3D conv
  • Group normalisation
  • Max pooling
  • Nearest neighbor interpolation & concatenation

When converting to ONNX, I could see some weird things in the graph (see the first screenshot):

  1. For some operations the device is cpu instead of cuda:0 like all other operations; what does this mean? Will ONNX Runtime run these operations on the CPU? See below for the partial output of the conversion with verbose = True:
  %154 : Half(*, 512, *, *, *, strides=[884736, 1728, 144, 12, 1], requires_grad=0, device=cuda:0) = onnx::Relu(%153) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1297:0
  %155 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 0  8 -1 [ CPULongType{3} ]]()
  %156 : Half(0, 8, *, device=cpu) = onnx::Reshape[allowzero=0](%154, %155)
  %157 : Half(8, strides=[1], device=cpu) = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUHalfType{8} ]]()
  %158 : Half(8, strides=[1], device=cpu) = onnx::Constant[value= 0  0  0  0  0  0  0  0 [ CPUHalfType{8} ]]()
  %159 : Half(0, 8, *, device=cpu) = onnx::InstanceNormalization[epsilon=1.0000000000000001e-05](%156, %157, %158)
  %160 : Long(5, strides=[1], device=cpu) = onnx::Shape(%154)
  %161 : Half(*, *, *, *, *, device=cpu) = onnx::Reshape[allowzero=0](%159, %160)
  %164 : Half(*, *, *, *, *, device=cpu) = onnx::Mul(%161, %309)
  %167 : Half(*, *, *, *, *, strides=[884736, 1728, 144, 12, 1], requires_grad=0, device=cuda:0) = onnx::Add(%164, %310) # /usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:2360:0
  %169 : Long(5, strides=[1], device=cpu) = onnx::Shape(%167)
  %170 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
  %171 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
  %172 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}]()
  %173 : Long(2, strides=[1], device=cpu) = onnx::Slice(%169, %171, %172, %170)
  %175 : Long(5, strides=[1], device=cpu) = onnx::Concat[axis=0](%173, %311)
  %176 : Tensor? = prim::Constant()
  %177 : Tensor? = prim::Constant()
  2. The graph is rather "ugly" compared to the one generated for ResNet (all the Mul, Add, Reshape, etc. operations). Could this be the reason for the slow-down?

I saw that group normalisation isn't directly supported by ONNX (it is exported as the Reshape, InstanceNormalization, Reshape, Mul and Add nodes shown above), so I thought this might be the cause of the slow-down. I therefore tried an alternative model without group norm, which led to a nicer graph (see the 2nd screenshot) and less slow-down (from 3x slower to 2x slower). The slow-down is still significant, though, and the Slice, Concat, etc. operations are still reported as running on the CPU; could these be the issue?
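The Reshape, InstanceNormalization, Reshape, Mul, Add chain in the verbose output is how the exporter expresses group normalisation at opset 14. A minimal NumPy sketch (the function names and tensor sizes are hypothetical; 8 groups matches the `[0, 8, -1]` reshape constant in the log) checks that the exported chain computes the same thing as a direct group norm:

```python
import numpy as np


def group_norm_reference(x, gamma, beta, groups, eps=1e-5):
    """Direct group norm: normalise each (sample, group of channels) block on its own."""
    out = np.empty_like(x)
    channels = x.shape[1]
    per_group = channels // groups
    for n in range(x.shape[0]):
        for g in range(groups):
            block = x[n, g * per_group:(g + 1) * per_group]
            out[n, g * per_group:(g + 1) * per_group] = (block - block.mean()) / np.sqrt(block.var() + eps)
    affine = (1, channels) + (1,) * (x.ndim - 2)
    return out * gamma.reshape(affine) + beta.reshape(affine)


def group_norm_as_exported(x, gamma, beta, groups, eps=1e-5):
    """The exported chain: Reshape -> InstanceNormalization -> Reshape -> Mul -> Add."""
    n, channels = x.shape[:2]
    grouped = x.reshape(n, groups, -1)  # Reshape(x, [0, groups, -1])
    # InstanceNormalization with scale=1 and bias=0 normalises each (sample, group) row
    mean = grouped.mean(axis=-1, keepdims=True)
    var = grouped.var(axis=-1, keepdims=True)
    normed = (grouped - mean) / np.sqrt(var + eps)
    restored = normed.reshape(x.shape)  # Reshape(..., Shape(x))
    affine = (1, channels) + (1,) * (x.ndim - 2)
    return restored * gamma.reshape(affine) + beta.reshape(affine)  # Mul, then Add


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 3, 4, 5))
gamma, beta = rng.standard_normal(16), rng.standard_normal(16)
print(np.allclose(group_norm_reference(x, gamma, beta, 8), group_norm_as_exported(x, gamma, beta, 8)))  # True
```

This only shows that the decomposition is mathematically equivalent; it says nothing about how fast each of those nodes runs on the GPU.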

Overall, it would be great to get some guidance on where the problem could be located: should we adapt our model architecture, the way we export to ONNX, etc.? Is it even possible at all with a model like UNet3D?

Thanks for the help!

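Screenshots
[Screenshot 1: model onnx (2), exported graph of the original model]
[Screenshot 2: model2 onnx, exported graph of the model without group normalisation]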

@skottmckay
Contributor

I don't think the device used by the PyTorch export has any relevance to how ORT runs the model, so I wouldn't worry about that. If you set the log severity to VERBOSE, the node assignments will be printed out. Look for 'VerifyEachNodeIsAssignedToAnEp' in the output.

import onnxruntime

so = onnxruntime.SessionOptions()
so.log_severity_level = 0  # 0 = VERBOSE
ort_session = onnxruntime.InferenceSession('model.onnx', so)

Should produce output like this:

2022-09-08 09:54:11.5652487 [V:onnxruntime:, session_state.cc:1186 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Node placements
2022-09-08 09:54:11.5722903 [V:onnxruntime:, session_state.cc:1188 onnxruntime::VerifyEachNodeIsAssignedToAnEp] All nodes have been placed on [CPUExecutionProvider].

ORT will perform various optimizations when loading the model. If you want to see how they modify the model, you can use this script to save the model with the optimizations applied:
python -m onnxruntime.tools.optimize_onnx_model --opt_level extended <input model> <output model>

Optimization levels are described here: https://onnxruntime.ai/docs/performance/graph-optimizations.html

Are you binding the model input and outputs so they stay on GPU? Otherwise there's device copy between CPU and GPU that can significantly affect performance.

@thomas-beznik
Author

Thank you for your answer.

> I don't think the device used by the PyTorch export has any relevance to how ORT runs the model, so I wouldn't worry about that. If you set the log severity to VERBOSE, the node assignments will be printed out. Look for 'VerifyEachNodeIsAssignedToAnEp' in the output.
>
> so = onnxruntime.SessionOptions()
> so.log_severity_level = 0
> ort_session = onnxruntime.InferenceSession('model.onnx', so)

I looked into the node assignments and got this output:

2022-09-08 08:45:30.708811165 [V:onnxruntime:, session_state.cc:1193 VerifyEachNodeIsAssignedToAnEp]  Provider: [CPUExecutionProvider]: [Slice (Slice_95), Concat (Concat_96), Slice (Slice_125), Concat (Concat_126), Slice (Slice_155), Concat (Concat_156), ]
2022-09-08 08:45:30.708831105 [V:onnxruntime:, session_state.cc:1193 VerifyEachNodeIsAssignedToAnEp]  Provider: [CUDAExecutionProvider]: [Conv (Conv_0), Relu (Relu_1), Reshape (Reshape_3), InstanceNormalization (InstanceNormalization_6), ...
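For large models these placement lines get long. A small sketch that groups the nodes in such lines by execution provider, assuming the "Provider: [EP]: [Op (Name), ...]" format shown above and a log captured to a file (for example by redirecting the script's stderr); node_placements and summarize are hypothetical helper names:

```python
import re
from collections import defaultdict

# Matches the verbose placement lines, e.g.
# "... VerifyEachNodeIsAssignedToAnEp]  Provider: [CPUExecutionProvider]: [Slice (Slice_95), Concat (Concat_96), ]"
PROVIDER_LINE = re.compile(r"Provider: \[(?P<ep>\w+)\]: \[(?P<nodes>.*)")
NODE = re.compile(r"(?P<op>\w+) \((?P<name>[^)]+)\)")


def node_placements(log_lines):
    """Return {execution provider: [(op_type, node_name), ...]} from verbose ORT log lines."""
    placements = defaultdict(list)
    for line in log_lines:
        match = PROVIDER_LINE.search(line)
        if match:
            nodes = NODE.finditer(match.group("nodes"))
            placements[match.group("ep")].extend((m.group("op"), m.group("name")) for m in nodes)
    return dict(placements)


def summarize(log_path):
    """Print how many nodes each EP received and list the ones left on the CPU."""
    with open(log_path) as f:
        placements = node_placements(f)
    for ep, nodes in placements.items():
        print(f"{ep}: {len(nodes)} nodes")
    for op, name in placements.get("CPUExecutionProvider", []):
        print(f"  on CPU: {op} ({name})")
```

A truncated line such as the CUDA one above still parses; the trailing "..." simply matches no node.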

So it does seem like some operations run on the CPU, though I don't know if this is the cause of the slow-down.

> ORT will perform various optimizations when loading the model. If you want to see how they modify the model, you can use this script to save the model with the optimizations applied:
> python -m onnxruntime.tools.optimize_onnx_model --opt_level extended <input model> <output model>
>
> Optimization levels are described here: https://onnxruntime.ai/docs/performance/graph-optimizations.html

I get the attached graph as output when running the optimisations. The weird thing is that the optimised model is even slower: I go from 350 ms to 690 ms per inference.

> Are you binding the model input and outputs so they stay on GPU? Otherwise there's device copy between CPU and GPU that can significantly affect performance.

I think that I'm doing it correctly, see code below:

import time

import numpy as np
import onnxruntime
import torch

deviceType = 'cuda'
deviceId = '0'
device = torch.device(deviceType + ':' + deviceId)  # 'cuda:0' ('cuda0' is not a valid device string)
input_names = ort_session.get_inputs()[0].name
output_names = ort_session.get_outputs()[0].name
io_binding = ort_session.io_binding()

# x: float16 NumPy array of shape [batch_size, 1, 96, 96, 96], copied to the GPU once here
data = onnxruntime.OrtValue.ortvalue_from_numpy(x, device.type, int(deviceId))
io_binding.bind_input(input_names, deviceType, int(deviceId), np.float16, [batch_size, 1, 96, 96, 96], data.data_ptr())
# let ORT allocate the output on the GPU
io_binding.bind_output(output_names, device_type=deviceType,
                       device_id=int(deviceId))

# warm-up run
ort_session.run_with_iobinding(io_binding)

latency = []
for i in range(total_samples):
    t0 = time.time()
    ort_session.run_with_iobinding(io_binding)
    latency.append(time.time() - t0)
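When comparing against PyTorch, the baseline has to be timed the same way: PyTorch launches CUDA kernels asynchronously, so wrapping a forward call in time.time() without torch.cuda.synchronize() can under-report its latency. A sketch of a shared timing helper (benchmark is a hypothetical name) that applies the same warm-up and optional synchronisation to both sides:

```python
import statistics
import time


def benchmark(run_once, warmup=3, iters=20, sync=None):
    """Time run_once() after `warmup` untimed calls and return latency stats in milliseconds.

    `sync` is an optional callable that blocks until queued GPU work has finished,
    e.g. torch.cuda.synchronize for a PyTorch model.
    """
    for _ in range(warmup):
        run_once()
    if sync is not None:
        sync()  # make sure warm-up work is not counted in the first timed call
    latencies_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        run_once()
        if sync is not None:
            sync()  # otherwise asynchronous CUDA work may still be running when the timer stops
        latencies_ms.append((time.perf_counter() - t0) * 1000.0)
    return {
        "mean_ms": statistics.mean(latencies_ms),
        "median_ms": statistics.median(latencies_ms),
        "min_ms": min(latencies_ms),
    }
```

For example, benchmark(lambda: torch_model(x_gpu), sync=torch.cuda.synchronize) inside torch.no_grad() for the PyTorch side and benchmark(lambda: ort_session.run_with_iobinding(io_binding)) for the ORT side, where torch_model and x_gpu are hypothetical names for the half-precision model and a CUDA input tensor.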

Overall, I'm a bit stuck finding where the slow-down comes from. Without resolving this, we'll have to use a solution other than ONNX.

[Screenshot: model_optimised onnx, graph of the model after the optimisations]

@hariharans29
Member

hariharans29 commented Sep 8, 2022

The Slices and Concats that are being forced down to CPU are part of shape subgraphs - if you look at what they are doing, they slice out one int and concatenate 2 ints and so on. There is no need for these ops to be hardware accelerated (in fact they are detrimental). So ORT has logic to force these to CPU to save device bandwidth for ops that actually require hardware acceleration. So, I don't believe this is the cause for the poor perf.
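To make the point about shape subgraphs concrete, here is a NumPy sketch of what nodes %169 to %175 in the original poster's verbose export compute. The batch size of 2 and the value of %311 (not shown in the log; assumed here to be the target spatial size of the nearest-neighbour upsampling) are hypothetical; the 512 x 12 x 12 x 12 feature map is inferred from the strides [884736, 1728, 144, 12, 1] in the log:

```python
import numpy as np


def shape_subgraph(input_shape, spatial_size):
    """Mirror of nodes %169-%175: Shape -> Slice(starts=0, ends=2, axes=0) -> Concat(axis=0)."""
    shape = np.array(input_shape, dtype=np.int64)        # %169 = Shape(%167)
    batch_and_channels = shape[0:2]                       # %173 = Slice(%169, starts=0, ends=2, axes=0)
    spatial = np.array(spatial_size, dtype=np.int64)      # %311 (value not shown; assumed)
    return np.concatenate([batch_and_channels, spatial])  # %175 = Concat(axis=0)


target = shape_subgraph((2, 512, 12, 12, 12), (24, 24, 24))
print(target.tolist(), target.nbytes)  # [2, 512, 24, 24, 24] 40
```

The whole computation touches five 64-bit integers (40 bytes), so launching GPU kernels for it and copying the result back would cost more than the work itself.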

Have you tried using nvprof and checking which kernel takes up the most time? That is the best way to move forward with this.

@hariharans29 hariharans29 added the ep:CUDA issues related to the CUDA execution provider label Sep 8, 2022
@hariharans29
Member

hariharans29 commented Sep 8, 2022

Since you have a Conv-heavy fp16 model and a card that supports tensor core operations, can you try this simple one-line update to your script:

https://onnxruntime.ai/docs/performance/tune-performance.html#convolution-heavy-models-and-the-cuda-ep.

This is why I expect this to help your use-case:

[image]
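The linked section comes down to a CUDA EP provider option. A minimal sketch, assuming the cudnn_conv_use_max_workspace option name that also appears later in this thread; conv_heavy_cuda_providers and make_session are hypothetical helper names:

```python
def conv_heavy_cuda_providers(device_id=0):
    """Provider list for a Conv-heavy fp16 model on the CUDA EP, with the CPU EP as fallback."""
    cuda_options = {
        "device_id": str(device_id),
        # Let cuDNN use as much workspace memory as it needs, so that faster convolution
        # algorithms that need more memory can be considered during the algorithm search.
        "cudnn_conv_use_max_workspace": "1",
    }
    return [("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"]


def make_session(model_path, device_id=0):
    import onnxruntime as ort  # imported lazily so the provider helper has no dependencies

    return ort.InferenceSession(model_path, providers=conv_heavy_cuda_providers(device_id))
```

Provider option values are given as strings here; the CPU EP stays in the list so any node the CUDA EP does not take can still run.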

@skottmckay
Contributor

> I get the attached graph as output when running the optimisations. The weird thing is that the optimised model is even slower: I go from 350 ms to 690 ms per inference.

Sorry, I overlooked that that script doesn't have a way to enable the CUDA EP when running. I'm guessing that results in CPU EP-specific custom ops being inserted, which leads to more of the model running on the CPU EP.

The script is very simple though, and you can do the equivalent (set the optimization level and output filename in session options) manually if you want to see the optimized model for a session with the CUDA EP enabled.

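import onnxruntime as ort

so = ort.SessionOptions()
so.optimized_model_filepath = str(output_path.resolve())  # output_path: a pathlib.Path
so.graph_optimization_level = level  # e.g. ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED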
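Spelled out end to end, the session-options route might look like this. A sketch in which save_optimized_model and the LEVEL_NAMES mapping are hypothetical (the short names imitate the extended value passed to the script); unlike the script, the session is created with the CUDA EP enabled:

```python
from pathlib import Path

# Hypothetical short names, in the style of the script's "--opt_level extended" argument.
LEVEL_NAMES = {
    "basic": "ORT_ENABLE_BASIC",
    "extended": "ORT_ENABLE_EXTENDED",
    "all": "ORT_ENABLE_ALL",
}


def save_optimized_model(input_path, output_path, level="extended"):
    """Create a CUDA-enabled session so ORT writes out the graph it will actually run."""
    import onnxruntime as ort  # imported here so the mapping above has no dependency

    so = ort.SessionOptions()
    so.optimized_model_filepath = str(Path(output_path).resolve())
    so.graph_optimization_level = getattr(ort.GraphOptimizationLevel, LEVEL_NAMES[level])
    # Creating the session applies the optimizations and saves the result to output_path.
    ort.InferenceSession(str(input_path), so,
                         providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
    return Path(output_path)
```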

@thomas-beznik
Author

> Since you have a Conv-heavy fp16 model and a card that supports tensor core operations, can you try this simple one-line update to your script:
>
> https://onnxruntime.ai/docs/performance/tune-performance.html#convolution-heavy-models-and-the-cuda-ep.
>
> This is why I expect this to help your use-case:
>
> [image]

I was actually already using this (I used the same setup as here).

@thomas-beznik
Author

> > I get the attached graph as output when running the optimisations. The weird thing is that the optimised model is even slower: I go from 350 ms to 690 ms per inference.
>
> Sorry, I overlooked that that script doesn't have a way to enable the CUDA EP when running. I'm guessing that results in CPU EP-specific custom ops being inserted, which leads to more of the model running on the CPU EP.
>
> The script is very simple though, and you can do the equivalent (set the optimization level and output filename in session options) manually if you want to see the optimized model for a session with the CUDA EP enabled.
>
> so = ort.SessionOptions()
> so.optimized_model_filepath = str(output_path.resolve())
> so.graph_optimization_level = level

I was already using ORT_ENABLE_ALL, and I don't really see a difference in performance when using ORT_ENABLE_BASIC.

@skottmckay
Contributor

Any difference depends on the model and the EPs that are enabled. If no internal ORT operators with CUDA implementations apply to the nodes the CUDA EP is taking, there won't be a difference between 'basic' and 'extended'/'all'.

@sloth2012

I'm seeing a similar issue. My ONNX Runtime model's performance is very close to the PyTorch model's, but still lower.

@hariharans29
Member

> > Since you have a Conv-heavy fp16 model and a card that supports tensor core operations, can you try this simple one-line update to your script:
> > https://onnxruntime.ai/docs/performance/tune-performance.html#convolution-heavy-models-and-the-cuda-ep.
> > This is why I expect this to help your use-case:
> > [image]
>
> I was actually already using this (I used the same setup as here).

My bad. I didn't notice that.

Could you run nvprof against your script and just give us the results?

Also (if shareable), can you please give us the model?

@wschin
Contributor

wschin commented Sep 14, 2022

Can you try running with ORTModule? You can just wrap your nn.Module model:

from onnxruntime.training.ortmodule import ORTModule
new_model = ORTModule(model)
# ORTModule is also nn.Module so just use it with the original inputs
output = new_model(inputs)

ORTModule has some optimizations that reduce the overhead of using IOBinding. If you observe ORTModule bringing some speedup, then the problem really is IOBinding.

@wschin
Contributor

wschin commented Sep 14, 2022

Could you run nsys profile on your model with and without ONNX Runtime? It was easy for me to identify the performance bottleneck once I had profiling results. For example, in the following figure, it's clear that IOBinding causes a performance issue.
[image]
Another example below shows unnecessary cudaMemcpy calls happening in Reshape and slowing down the model.
[image]

@wschin
Contributor

wschin commented Sep 14, 2022

Why are there so many Casts in the figure in this reply? ORT recently added support for "strided" tensors, so I expect those Casts to be no-ops. If I am wrong and ORT does real computation for those Casts, they could slow the model down significantly. To check whether Cast is the problem, would you mind running your model in float32 and comparing the time with and without ORT? Thanks a lot!

@thomas-beznik
Author

Hello @wschin! Thanks for all your suggestions. I'll try them as soon as I can; I've been getting some errors from the ORTModule installation (similar to this) and don't have time to fix them right now.

@wschin
Contributor

wschin commented Sep 15, 2022

@thomas-beznik, sure thing. If #9754 is the blocker, you probably need to build ORT from source. Note that you need a clean machine to avoid dependency interference for a clean build. Btw, you don't need ORTModule to do the float32 comparison. If you have time, please do float32 comparison first. Thank you.

@thomas-beznik
Author

> @thomas-beznik, sure thing. If #9754 is the blocker, you probably need to build ORT from source. Note that you need a clean machine to avoid dependency interference for a clean build. Btw, you don't need ORTModule to do the float32 comparison. If you have time, please do float32 comparison first. Thank you.

Got it. Is there an easy way to undo the installation of ORTModule? Now I cannot run normal ONNX inference either :/

@wschin
Contributor

wschin commented Sep 15, 2022

First, pip uninstall onnxruntime-training? Then, pip install onnxruntime?

How many inputs/outputs/operators do you have? If the number of inputs/outputs is on the same scale as the number of operators, IOBinding is super slow.
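The counts asked for here can be read from the ONNX file. A small sketch using the onnx package (assumed installed; graph_counts and model_counts are hypothetical helper names). Initializers are excluded because some exporters also list the weights as graph inputs:

```python
def graph_counts(graph):
    """Count graph inputs (excluding initializers), outputs and nodes of an ONNX GraphProto."""
    initializer_names = {init.name for init in graph.initializer}
    real_inputs = [inp for inp in graph.input if inp.name not in initializer_names]
    return {"inputs": len(real_inputs), "outputs": len(graph.output), "nodes": len(graph.node)}


def model_counts(model_path):
    import onnx  # imported lazily; graph_counts only relies on the GraphProto field names

    return graph_counts(onnx.load(model_path).graph)
```

For the export call at the top of this issue (input_names=["input"], output_names=["output"]), this should report one input and one output.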

@thomas-beznik
Author

thomas-beznik commented Sep 19, 2022

> Could you run nsys profile on your model with and without ONNX Runtime? It was easy for me to identify the performance bottleneck once I had profiling results. For example, in the following figure, it's clear that IOBinding causes a performance issue. [image] Another example below shows unnecessary cudaMemcpy calls happening in Reshape and slowing down the model. [image]

I've attached the profiling results. Unfortunately, I have a hard time understanding them (this is without ORTModule, by the way).
nsys-profiling.zip

@davidmezzetti

davidmezzetti commented Nov 29, 2022

I ran into a similar issue where an ONNX model was much slower than its PyTorch counterpart on the GPU. I tried all the suggestions here, including io_binding, but nothing worked.

To solve the issue, profiling was enabled via the following code:

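import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True  # write a JSON trace with per-node timings

# path: model file; inputs: dict mapping input names to NumPy arrays
session = ort.InferenceSession(path, opts,
  ["CUDAExecutionProvider", "CPUExecutionProvider"])
session.run(None, inputs)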

Once the program exited, a profiling JSON file was generated. I took a look at that to find the longest running nodes.

jq . onnxruntime_profile.json | grep dur | cut -d ":" -f2 | sort --numeric
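The same search can be done in Python, which also adds up repeated runs of the same node. A sketch assuming the profile is a JSON list of trace events with cat, name and dur (microseconds) fields, plus an args.op_name field on node events; top_kernel_times and slowest_nodes are hypothetical helper names:

```python
import json
from collections import Counter


def top_kernel_times(events, top=10):
    """Sum `dur` per node over all runs and return the `top` slowest ((name, op), total) pairs."""
    totals = Counter()
    for event in events:
        # Per-node kernel timings are "Node" events whose name ends in "_kernel_time";
        # session-level events (model run, session initialization) are skipped.
        if event.get("cat") == "Node" and event.get("name", "").endswith("_kernel_time"):
            op = event.get("args", {}).get("op_name", "?")
            totals[(event["name"], op)] += event.get("dur", 0)
    return totals.most_common(top)


def slowest_nodes(profile_path, top=10):
    """Load an ORT profiling JSON file and rank its nodes by total kernel time."""
    with open(profile_path) as f:
        return top_kernel_times(json.load(f), top)
```

With profiling enabled, session.end_profiling() returns the name of the profile file, so slowest_nodes(session.end_profiling()) can be called at the end of the script.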

Skipping nodes for the full model run and session initialization, I was seeing nodes like this: feed_forward/w_1/Conv_kernel_time. Reading the documentation, the following setting stood out, cudnn_conv_algo_search.

The program was re-run with that setting changed (to either HEURISTIC or DEFAULT).

import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True

session = ort.InferenceSession(path, opts, 
  [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"])
session.run(None, inputs)

This time the performance was equal to or even slightly better than the PyTorch model on the GPU.

I'm not sure why ONNX Runtime defaults to an EXHAUSTIVE search. From reading the similar code in PyTorch, it doesn't appear that PyTorch does (it looks like it defaults to what ONNX Runtime calls HEURISTIC), and that accounts for the performance difference in my case.

Hope this helps anyone running into performance issues one way or another. Looking at the original post, there were a lot of Conv operations, so it's worth a try.
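According to the follow-up comments, the convolution algorithm search runs once and its result is reused for later runs with the same input shapes, so it helps to time the first run separately from the steady state. A sketch that does this for each cudnn_conv_algo_search value; providers_for and compare_algo_search are hypothetical helper names, and feeds is assumed to be a dict of NumPy inputs with fixed shapes:

```python
import time

ALGO_SEARCH_MODES = ("EXHAUSTIVE", "HEURISTIC", "DEFAULT")


def providers_for(mode):
    """CUDA EP provider list using the given cudnn_conv_algo_search value."""
    if mode not in ALGO_SEARCH_MODES:
        raise ValueError(f"unknown cudnn_conv_algo_search value: {mode}")
    return [("CUDAExecutionProvider", {"cudnn_conv_algo_search": mode}), "CPUExecutionProvider"]


def compare_algo_search(model_path, feeds, runs=20, warmup=3):
    """Return {mode: (first_run_ms, steady_state_ms)} for each algo search mode."""
    import onnxruntime as ort  # imported lazily; providers_for has no dependencies

    results = {}
    for mode in ALGO_SEARCH_MODES:
        session = ort.InferenceSession(model_path, providers=providers_for(mode))
        t0 = time.perf_counter()
        session.run(None, feeds)  # the first run includes the convolution algorithm search
        first_ms = (time.perf_counter() - t0) * 1000.0
        for _ in range(warmup):
            session.run(None, feeds)
        t0 = time.perf_counter()
        for _ in range(runs):
            session.run(None, feeds)  # NumPy outputs are copied back, so each call waits for the GPU
        results[mode] = (first_ms, (time.perf_counter() - t0) * 1000.0 / runs)
    return results
```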

@tianleiwu
Contributor

For UNet (diffusion) model, try the following setting for best performance:

providers=[
    ("CUDAExecutionProvider", {"cudnn_conv_use_max_workspace": "1", "cudnn_conv_algo_search": "EXHAUSTIVE"}),
    "CPUExecutionProvider",
]

@davidmezzetti, the cuDNN convolution algorithm search only happens once. Even though the first inference run is slow, the following runs might be faster. You can send some warm-up queries before serving user queries.

@davidmezzetti

davidmezzetti commented Dec 3, 2022

My understanding is that cuDNN only caches the results when the input shape is static. I was able to confirm this same behavior with a Torch model having dynamic input shapes exported to ONNX.

Benchmark mode in PyTorch is what ONNX Runtime calls EXHAUSTIVE, and EXHAUSTIVE is the default ONNX Runtime setting per the documentation. PyTorch defaults to using cudnnGetConvolutionForwardAlgorithm_v7, which is much faster. So in this case, with dynamic inputs, the Torch model appears to run faster.
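A toy model of the caching behaviour described above (plain Python, not ONNX Runtime's or cuDNN's actual code; count_searches is a hypothetical helper): when the chosen algorithm is cached per input shape, a fixed shape pays for the search once, while every new shape pays for it again.

```python
def count_searches(input_shapes):
    """Toy model: an expensive search whose result is cached per input shape.

    Returns how many times the search actually runs for a sequence of inference requests.
    """
    cache = {}
    searches = 0
    for shape in input_shapes:
        key = tuple(shape)
        if key not in cache:
            searches += 1  # cache miss: pay for the (EXHAUSTIVE-style) search again
            cache[key] = "algo chosen for " + "x".join(map(str, key))
    return searches


static = [(1, 1, 96, 96, 96)] * 100
dynamic = [(1, 1, 96, 96, depth) for depth in range(64, 164)]
print(count_searches(static), count_searches(dynamic))  # 1 100
```

With a fixed input shape such as the [batch_size, 1, 96, 96, 96] input bound earlier in this thread, the search cost would show up only once per batch size in this model; inputs whose shapes keep changing keep paying it.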

I wrote an article with detailed steps on this comparison.
https://medium.com/neuml/debug-onnx-gpu-performance-c9290fe07459

This link also has a related discussion.
https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3

Projects
None yet
Development

No branches or pull requests

7 participants