Significant difference in the performance of pytorch and exported onnx models #7212
Comments
We'll investigate and get back to you. cc @RandyShuai
I think this is what's happening. In the PyTorch case, you're calling `tensor.to('cuda', ...)`. This incurs the cost of copying the tensor from pinned memory to CUDA only once; subsequent iterations are no-ops. In the ORT case, the copy is done every single time you call `run()`. A fairer comparison is to make sure your input is copied to the GPU only once, before `run()` is called, in both cases. You can achieve this with IOBinding.
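A timing harness that follows this advice keeps the host-to-device copy (or the IOBinding setup) outside the timed loop and synchronizes the device around each timed call. A minimal sketch, assuming nothing beyond the standard library; `benchmark` is a hypothetical helper, and `sync` would be `torch.cuda.synchronize` when timing GPU work:

```python
import statistics
import time


def benchmark(step, setup=None, warmup=3, iters=10, sync=None):
    """Median wall-clock time of step(arg) in milliseconds.

    setup: optional zero-argument callable run ONCE, outside the timed region;
           its return value (e.g. an input already on the GPU) is passed to step.
    sync:  optional zero-argument callable that blocks until queued device work
           has finished (e.g. torch.cuda.synchronize); called around each timed call.
    """
    arg = setup() if setup is not None else None
    wait = sync if sync is not None else (lambda: None)
    # Warm-up runs absorb one-time costs (e.g. memory allocation, cuDNN algorithm search).
    for _ in range(warmup):
        step(arg)
    times_ms = []
    for _ in range(iters):
        wait()  # do not bill earlier queued work to this iteration
        start = time.perf_counter()
        step(arg)
        wait()  # include the asynchronous device work in the measurement
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
```

For ORT, `setup` could build and return the IOBinding (as the repro script in this thread does with `io_binding()`, `bind_input(...)` and `bind_output(...)`) and `step` could be `session.run_with_iobinding`; for PyTorch, `setup` would return the input already moved with `.to('cuda')` and `step` would be the model.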
Thanks for the reply!
The allclose assert starts to fail after I increase the size of the channel dimension. Should I report a bug?
ORT currently sets `convDesc.mathType` to `CUDNN_DEFAULT_MATH` for `cudnnConvolutionForward`, which differs from PyTorch's `CUDNN_TENSOR_OP_MATH`. Do you know why?
Do you see any performance difference after switching to IOBinding?
Yes, I still see a difference. Current measurements with IOBinding:
This may also be interesting.

PyTorch:

```text
I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     srcDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,896,1,832];
i!         strideA: type=int; val=[745472,832,832,1];
i!     srcData: location=dev; addr=0x7f1f46000000;
i!     filterDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,1,19];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     filterData: location=dev; addr=0x7f1fc4000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[0,9];
i!         strideA: type=int; val=[1,2];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     destDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,1792,1,416];
i!         strideA: type=int; val=[745472,416,416,1];
i!     destData: location=dev; addr=0x7f1f48d80000;
i!     requestedAlgoCount: type=int; val=8;
i!     returnedAlgoCount: location=host; addr=0x7ffd1db1e260;
i!     perfResults: location=host; addr=0x562d4de29310;
i!     workSpace: location=dev; addr=0x7f1f22000000;
i!     workSpaceSizeInBytes: type=size_t; val=453246976;
i! Time: 2021-04-02T20:20:58.376987 (0d+0h+0m+6s since start)
i! Process=20457; Thread=20457; GPU=0; Handle=0x562d0f5507f0; StreamId=(nil) (defaultStream).
```

ONNX Runtime:

```text
I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called:
i!     handle: type=cudnnHandle_t; streamId=0x556d192637c0;
i!     srcDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[4,896,128,1];
i!         strideA: type=int; val=[114688,128,1,1];
i!     srcData: location=dev; addr=0x7fbac2c7f200;
i!     filterDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,19,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     filterData: location=dev; addr=0x7fb97c000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[9,0];
i!         strideA: type=int; val=[2,1];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     destDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[4,1792,64,1];
i!         strideA: type=int; val=[114688,64,1,1];
i!     destData: location=dev; addr=0x7fb97fc00000;
i!     requestedAlgoCount: type=int; val=1;
i!     returnedAlgoCount: location=host; addr=0x7ffd76b1e7f0;
i!     perfResults: location=host; addr=0x7ffd76b1ea70;
i!     workSpace: location=dev; addr=0x7fb9a2000000;
i!     workSpaceSizeInBytes: type=size_t; val=33554432;
i! Time: 2021-04-02T20:20:39.710280 (0d+0h+0m+1s since start)
i! Process=20361; Thread=20361; GPU=0; Handle=0x556d19454340; StreamId=0x556d192637c0.
```
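To find the differences between the two `cudnnFindConvolutionForwardAlgorithmEx()` calls mechanically rather than by eye, one option is to parse each call into key/value pairs and diff them. The sketch below uses hypothetical helpers (`parse_cudnn_log` and `diff_logs` are not part of cuDNN or ORT); it assumes the field layout shown in the two logs above and treats a fixed list of field names as belonging to the most recently opened descriptor, which is a heuristic. The sample strings are abridged excerpts of those two logs, in the single-line form they take when pasted together.

```python
import re

# Field names treated as sub-fields of the most recently opened descriptor
# (srcDesc, filterDesc, convDesc, destDesc). Heuristic, based on the logs above.
DESC_FIELDS = {
    "dataType", "vect", "nbDims", "dimA", "strideA", "format", "mode",
    "mathType", "reorderType", "arrayLength", "padA", "dilationA", "groupCount",
}
NAME = re.compile(r"^(\w+):")
HEADER = re.compile(r"^(\w+): type=\w+:$")            # "srcDesc: type=cudnnTensorDescriptor_t:"
VALUE = re.compile(r"^(\w+): type=\w+; val=(.*?);$")   # "nbDims: type=int; val=4;"


def parse_cudnn_log(text):
    """Flatten ONE logged cuDNN call into {"desc.field" or "field": value}.

    Accepts both the multi-line log and the single-line pasted form. Fields
    without val= (handles, location=/addr= pointers) are skipped.
    """
    fields = {}
    current = None  # descriptor whose sub-fields are being read
    for chunk in re.split(r"[iI]!\s+", text):
        chunk = chunk.strip()
        m = NAME.match(chunk)
        if not m:
            continue
        name = m.group(1)
        if HEADER.match(chunk):
            current = name
            continue
        if current is not None and name in DESC_FIELDS:
            key = f"{current}.{name}"
        else:
            current = None  # any other field closes the open descriptor
            key = name
        value = VALUE.match(chunk)
        if value:
            fields[key] = value.group(2)
    return fields


def diff_logs(a, b):
    """Return {key: (value_in_a, value_in_b)} for every key whose values differ."""
    return {k: (a.get(k), b.get(k)) for k in sorted(a.keys() | b.keys()) if a.get(k) != b.get(k)}


PYTORCH = (
    "I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called: "
    "i! srcDesc: type=cudnnTensorDescriptor_t: "
    "i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2); "
    "i! dimA: type=int; val=[32,896,1,832]; "
    "i! srcData: location=dev; addr=0x7f1f46000000; "
    "i! convDesc: type=cudnnConvolutionDescriptor_t: "
    "i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0); "
    "i! mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1); "
    "i! requestedAlgoCount: type=int; val=8; "
    "i! workSpaceSizeInBytes: type=size_t; val=453246976;"
)
ORT = (
    "I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called: "
    "i! srcDesc: type=cudnnTensorDescriptor_t: "
    "i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2); "
    "i! dimA: type=int; val=[4,896,128,1]; "
    "i! srcData: location=dev; addr=0x7fbac2c7f200; "
    "i! convDesc: type=cudnnConvolutionDescriptor_t: "
    "i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2); "
    "i! mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1); "
    "i! requestedAlgoCount: type=int; val=1; "
    "i! workSpaceSizeInBytes: type=size_t; val=33554432;"
)

if __name__ == "__main__":
    for key, (a, b) in diff_logs(parse_cudnn_log(PYTORCH), parse_cudnn_log(ORT)).items():
        print(f"{key}: pytorch={a} ort={b}")
```

On these excerpts the diff reports `convDesc.dataType` (FLOAT vs HALF), `srcDesc.dimA`, `requestedAlgoCount` (8 vs 1) and `workSpaceSizeInBytes` (453246976 vs 33554432), while `mathType` is identical in both calls.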
And could you please explain the reason for these differences (data types, algorithm types, strides)?
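Part of the stride and padding difference is only layout. Both logs describe the same convolution (896 to 1792 channels, kernel 19, padding 9, stride 2): PyTorch places the sequence axis in W (`[N, C, 1, L]`, `padA=[0,9]`, `strideA=[1,2]`), ORT places it in H (`[N, C, L, 1]`, `padA=[9,0]`, `strideA=[2,1]`), and the batch and length (32×832 vs 4×128) simply differ between the two runs. The sketch below checks this with the usual cross-correlation output-size formula, floor((L + 2p − d(k − 1) − 1) / s) + 1, and the strides of a packed NCHW tensor; the helper names are illustrative, not from either library.

```python
def conv_out_len(n, kernel, pad, stride, dilation=1):
    """Output size along one spatial axis of a cross-correlation (floor division)."""
    return (n + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def packed_strides(dims):
    """Element strides of a fully packed (contiguous) tensor with shape `dims`."""
    strides, step = [], 1
    for d in reversed(dims):
        strides.append(step)
        step *= d
    return strides[::-1]


def nchw_conv_out(src, filt, pad, stride, dilation=(1, 1)):
    """Output dims [N, K, H', W'] of a 2-D conv; src=[N, C, H, W], filt=[K, C, R, S]."""
    n, _, h, w = src
    k, _, r, s = filt
    return [n, k,
            conv_out_len(h, r, pad[0], stride[0], dilation[0]),
            conv_out_len(w, s, pad[1], stride[1], dilation[1])]


# Values copied from the srcDesc/filterDesc/convDesc sections of the two cuDNN logs.
PYTORCH = dict(src=[32, 896, 1, 832], filt=[1792, 896, 1, 19], pad=(0, 9), stride=(1, 2))
ORT = dict(src=[4, 896, 128, 1], filt=[1792, 896, 19, 1], pad=(9, 0), stride=(2, 1))

if __name__ == "__main__":
    for name, cfg in (("pytorch", PYTORCH), ("ort", ORT)):
        dst = nchw_conv_out(**cfg)
        print(name, "src strides", packed_strides(cfg["src"]),
              "dst", dst, "dst strides", packed_strides(dst))
```

The computed output shapes and strides match the `destDesc` and `strideA` values in both logs, so both frameworks pass fully packed NCHW tensors. This accounts for the stride and padding differences, but not for the compute `dataType` (FLOAT vs HALF) or the different `requestedAlgoCount`.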
We found something interesting. If I configure ORT with the following providers:

```python
providers = [
    ('CUDAExecutionProvider', {
        'cudnn_conv_algo_search': 'DEFAULT',
    }),
    'CPUExecutionProvider',
]
onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers=providers)
```

the metrics become similar to the PyTorch version:
We believe the problem lies in how ORT uses the cuDNN find function (`cudnnFindConvolutionForwardAlgorithmEx`): https://github.com/microsoft/onnxruntime/blob/d01006fc222799f879c8ce70edd58e97f53b7767/onnxruntime/core/providers/cuda/nn/conv.cc
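To check whether the search mode alone explains the gap, the same model can be timed under each value of `cudnn_conv_algo_search`. According to the ONNX Runtime CUDA execution provider documentation, `EXHAUSTIVE` (the default) benchmarks candidates with `cudnnFindConvolutionForwardAlgorithmEx`, `HEURISTIC` uses `cudnnGetConvolutionForwardAlgorithm_v7`, and `DEFAULT` uses a fixed default algorithm without searching. A minimal sketch, assuming an ORT build that accepts all three values; `cuda_providers` and `sessions_per_mode` are hypothetical helpers:

```python
# Values the CUDA execution provider accepts for cudnn_conv_algo_search.
ALGO_SEARCH_MODES = ("EXHAUSTIVE", "HEURISTIC", "DEFAULT")


def cuda_providers(algo_search="EXHAUSTIVE", device_id=0):
    """Provider list for onnxruntime.InferenceSession with a given cuDNN algo search mode."""
    if algo_search not in ALGO_SEARCH_MODES:
        raise ValueError(f"unknown cudnn_conv_algo_search value: {algo_search!r}")
    return [
        ("CUDAExecutionProvider", {
            "device_id": device_id,
            "cudnn_conv_algo_search": algo_search,
        }),
        "CPUExecutionProvider",  # fallback for nodes the CUDA EP does not take
    ]


def sessions_per_mode(onnx_path):
    """One InferenceSession per search mode, so the same benchmark can run on each."""
    import onnxruntime  # imported here so cuda_providers() works without ORT installed
    return {
        mode: onnxruntime.InferenceSession(onnx_path, providers=cuda_providers(mode))
        for mode in ALGO_SEARCH_MODES
    }
```

Building a separate session per mode keeps each one's algorithm cache independent, so a search done under one mode cannot leak into the timings of another.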
This PR, #7227, might address the issue. I will implement the algorithm search for ConvGrad in the next PR and look at both problems together. Please let me know whether this fix addresses the issue.
But isn't the math type already set properly before we do the algorithm search, even without your addition? I am referring to these lines:
Hi @SherlockNoMad! I took measurements on the ORT nightly build, and it got worse: ~35 msec with ORT 1.7, now ~61 msec with 1.7.0.dev202104071.
@SherlockNoMad Together with @II245, we will upload the cuDNN logs / nvprof traces, but it seems that the fix did not help.
Thanks @vadimkantorov, I will revert the change for now to unblock you and investigate in the meantime.
Thanks! I think it would be great to have a Conv perf test in ORT that compares performance against vanilla PyTorch (I wonder whether PyTorch has a similar test in its benchmarks). For speech recognition, Conv1d performance is critical, so bugs like this should be easy to catch :) A 3x-4x perf degradation greatly undermines the suitability of ORT as an inference backend...
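The kind of check suggested here could compare the median latencies of the two backends on the same Conv1d shape and fail above a tolerance. A minimal sketch of just the comparison step; the 1.25x threshold and the function names are arbitrary assumptions, and the timings would come from a harness that synchronizes the GPU and excludes host-to-device copies:

```python
import statistics


def latency_ratio(candidate_ms, reference_ms):
    """Median latency of the candidate (e.g. ORT) divided by the reference (e.g. PyTorch)."""
    if not candidate_ms or not reference_ms:
        raise ValueError("need at least one timing per backend")
    return statistics.median(candidate_ms) / statistics.median(reference_ms)


def check_conv_perf(candidate_ms, reference_ms, max_ratio=1.25):
    """Fail if the candidate backend is more than `max_ratio` times slower than the reference."""
    ratio = latency_ratio(candidate_ms, reference_ms)
    assert ratio <= max_ratio, (
        f"candidate is {ratio:.2f}x slower than reference (limit {max_ratio}x)"
    )
    return ratio
```

Medians are used so that a single slow outlier iteration does not dominate the comparison.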
Definitely...
For functional perf testing on speech recognition models, maybe you can use the NVIDIA NeMo architectures or the ones from Hugging Face (jaspernet, quartznet, etc.).
Hi @vadimkantorov, hope this PR addresses your problem.
Hi @SherlockNoMad! Thank you very much for fixing the reported case. After examining the whole production model, I found that some cases that previously worked almost as well with the CUDA provider fix (mentioned above) now show some performance degradation with the nightly build. These are the cases where GPU memory is almost full. Here is the table:
where B is the batch size. With B=48 I expect an average run of about 1400 msec. Could you please investigate this case too?
Model is here:
Repro is here:

```shell
CUDA_VISIBLE_DEVICES=0 python3 perf_ort.py --onnx fp16_model.onnx --iterations 3 --iterations-warmup 3 --run-with-io-binding -B 32 -T 120
```

```python
# perf_ort.py: time an ONNX model on ONNX Runtime's CUDA EP with the input pre-bound on the GPU
import argparse
import time

import numpy as np
import onnxruntime
import torch
import torch.cuda.profiler

parser = argparse.ArgumentParser()
parser.add_argument('--iterations', type=int, default=16)
parser.add_argument('--iterations-warmup', type=int, default=16)
parser.add_argument('--channels', type=int, default=64)
parser.add_argument('--sample-rate', type=int, default=8000)
parser.add_argument('--onnx', required=True)
parser.add_argument('-B', type=int)
parser.add_argument('-T', type=int)
parser.add_argument('--profile-cuda', action='store_true')
parser.add_argument('--run-with-io-binding', action='store_true')  # IOBinding is always used below
args = parser.parse_args()
print(args)

use_cuda = True


def tictoc():
    # Wait for all queued CUDA work before reading the clock.
    if use_cuda:
        torch.cuda.synchronize()
    return time.time()


def make_io_binding(session, batch):
    # Copy the input to GPU 0 once; run_with_iobinding() then reuses it without a host-to-device copy.
    x_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
    io_binding = session.io_binding()
    io_binding.bind_input(name='x', device_type=x_ortvalue.device_name(), device_id=0,
                          element_type=np.float32, shape=x_ortvalue.shape(),
                          buffer_ptr=x_ortvalue.data_ptr())
    io_binding.bind_output('logits', 'cuda')  # keep the output on the GPU as well
    # bind_input() is given a raw pointer, so the OrtValue owning the memory must stay alive.
    return io_binding, x_ortvalue


providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'gpu_mem_limit': 6 * 1024 * 1024 * 1024,  # 6 GiB cap on ORT's CUDA memory arena
    }),
    'CPUExecutionProvider',
]
onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers=providers)

batch_shape = [args.B, args.T * args.sample_rate]
io_binding, x_ortvalue = make_io_binding(onnxruntime_session, torch.rand(*batch_shape, dtype=torch.float32))

print('Warming up for', args.iterations_warmup, 'iterations')
tic_wall = tictoc()
torch.backends.cudnn.benchmark = True  # affects only PyTorch's own cuDNN calls, not ORT
for i in range(args.iterations_warmup):
    onnxruntime_session.run_with_iobinding(io_binding)
print('Warmup done in {:.02f} wall clock seconds'.format(tictoc() - tic_wall))
print()

if args.profile_cuda:
    torch.cuda.profiler.start()

print('Starting benchmark for', args.iterations, 'iterations:', 'fwd')
# Bind a fresh random batch for the timed runs.
io_binding, x_ortvalue = make_io_binding(onnxruntime_session, torch.rand(*batch_shape, dtype=torch.float32))
times_fwd = torch.zeros(args.iterations)
for i in range(args.iterations):
    tic = tictoc()
    onnxruntime_session.run_with_iobinding(io_binding)
    times_fwd[i] = tictoc() - tic

print('Batch shape', batch_shape)
print('average load+fwd {:.02f} msec'.format(float(times_fwd.mean()) * 1e3))
```
Performance issue
I noticed a significant performance difference between a PyTorch model with a simple Conv layer and the same model exported to ONNX.
After warm-up, the difference is more than 5x.
Urgency
none
System information
- OS Platform and Distribution: Linux Ubuntu 18.04 (x86_64)
- ONNX Runtime installed from (source or binary): binary
- ONNX Runtime version: onnxruntime-1.7.0
- Python version: 3.8.5
- PyTorch version: 1.8.1
- CUDA/cuDNN version: CUDA 11.1
- GPU model and memory: Tesla V100 (32 GB)
To Reproduce
Run the script below with the following parameters.
Reproduce ONNX performance (the script creates the ONNX model on the fly):
Output:
Reproduce PyTorch:
Output:
Full cuDNN logs: updated_cudnn_logs.zip
To get the logs and traces, please run:
Expected behavior
ONNX Runtime performance matches PyTorch.
Additional context
I found that some arguments of cudnnConvolutionForward differ (see the cuDNN log).
pytorch
onnx
List of noted differences:
Also, the traces show that different matrix kernels are used:
- PyTorch: `volta_fp16_s884cudnn_fp16_128x128_ldg8_relu_f2f_exp_large_nhwc2nchw_tn_v1`
- ONNX: `implicit_convolve_hhgemm`
What causes these differences?
Or am I doing something wrong in the benchmark or the ONNX Runtime setup?
UPDATE:
UPDATE 2:
Now PyTorch is ~3.5 times faster.