Significant difference in the performance of pytorch and exported onnx models #7212

Open
II245 opened this issue Apr 1, 2021 · 22 comments

@II245

II245 commented Apr 1, 2021

Performance issue
I noticed a significant difference in performance between a PyTorch model and the exported ONNX model for a simple conv layer.
The difference is more than 5x after warming up.

Urgency
none

System information
OS Platform and Distribution: Linux Ubuntu 18.04 (x86_64)
ONNX Runtime installed from (source or binary): binary
ONNX Runtime version: onnxruntime-1.7.0
Python version: Python 3.8.5
Pytorch version: 1.8.1
CUDA/cuDNN version: CUDA Version: 11.1
GPU model and memory: Tesla V100 (32G)

To Reproduce
Run the script below with the following parameters.

To reproduce the ONNX Runtime performance (the script creates the ONNX model on the fly):

CUDA_VISIBLE_DEVICES=0 python3 benchmark_repro.py \
  --onnx conv_fp16.onnx \
  --iterations 30 \
  --channels 896 \
  --iterations-warmup 30 \
  --run-with-io-binding \
  -B 32 \
  -T 832

Output:

Batch shape [32, 896, 832]
average load+fwd 35.92 msec

To reproduce the PyTorch performance:

CUDA_VISIBLE_DEVICES=0 python3 benchmark_repro.py \
  --iterations 30 \
  --channels 896 \
  --iterations-warmup 30 \
  -B 32 \
  -T 832

Output:

Batch shape [32, 896, 832]
average load+fwd 10.41 msec

benchmark_repro.py:

import argparse
import time
import torch
import torch.cuda.profiler
import onnxruntime
import numpy as np


def infer_ort(onnxruntime_session, io_binding):
	onnxruntime_session.run_with_iobinding(io_binding)


def get_model(channels):
	return torch.nn.Conv1d(in_channels=channels, out_channels=channels*2, kernel_size=19,
			stride=2, padding=(1 * 19 // 2), dilation=1, groups=1)


def export_onnx(onnx_path, channels, dtype = torch.float16):
	model = get_model(channels)
	torch.set_grad_enabled(False)
	model.eval()
	model.to('cuda', dtype=dtype)
	waveform_input = torch.rand((4, channels, 128), device='cuda', dtype=dtype)

	logits = model(waveform_input)

	torch.onnx.export(
			model, (waveform_input),
			onnx_path,
			opset_version=12,
			export_params=True,
			do_constant_folding=True,
			input_names=['input'],
			dynamic_axes=dict(input={
				0: 'B',
				2: 'T'
			})
	)

	## check export correctness
	onnxruntime_session = onnxruntime.InferenceSession(onnx_path)
	(logits_,) = onnxruntime_session.run(None, dict(input=waveform_input.cpu().to(dtype=dtype).numpy()))

	print((torch.from_numpy(logits_) - logits.cpu()).abs().to(dtype=torch.float32).max())
	assert torch.allclose(
			logits.cpu(),
			torch.from_numpy(logits_),
			**{
				'rtol': 1e-01,
				'atol': 1e-01
			}
	)

parser = argparse.ArgumentParser()
parser.add_argument('--iterations', type=int, default=16)
parser.add_argument('--iterations-warmup', type=int, default=16)
parser.add_argument('--channels', type=int, default=64)
parser.add_argument('--onnx')
parser.add_argument('-B', type=int)
parser.add_argument('-T', type=int)
parser.add_argument('--profile-cuda', action='store_true')
parser.add_argument('--run-with-io-binding', action='store_true')
args = parser.parse_args()

print(args)

dtype = torch.float16
use_cuda = True

if args.onnx:
	export_onnx(args.onnx, args.channels, dtype)
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
	# default path: feed a numpy array; ORT copies it to the device on every run()
	model = lambda x: onnxruntime_session.run(None, dict(input=x))
	if args.run_with_io_binding:
		# io-binding path: input/output buffers stay on the GPU between runs
		model = lambda io_binding: onnxruntime_session.run_with_iobinding(io_binding)
else:
	model = get_model(args.channels)
	model.to('cuda')
	model.eval()
	model.to(dtype=dtype)

# synchronize CUDA before reading the clock so queued kernels are included in the timing;
# torch.cuda.synchronize() returns None, so the expression falls through to time.time()
tictoc = lambda: (use_cuda and torch.cuda.synchronize()) or time.time()

batch_shape = [args.B, args.channels, args.T]

batch = torch.rand(*batch_shape, dtype=torch.float16)

if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='input', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float16,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('3', 'cuda')  # '3' is the graph output name (no output_names were passed to torch.onnx.export)
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)
else:
	batch = batch.to(device='cuda')
	infer = lambda model, batch: model(batch).cpu()

print('Warming up for', args.iterations_warmup, 'iterations')
tic_wall = tictoc()
torch.backends.cudnn.benchmark = True
for i in range(args.iterations_warmup):
	y = infer(model, batch)
print('Warmup done in {:.02f} wall clock seconds'.format(tictoc() - tic_wall))
print()

if args.profile_cuda:
	torch.cuda.profiler.start()

print('Starting benchmark for', args.iterations, 'iterations:', 'fwd')
times_fwd = torch.zeros(args.iterations)
batch = torch.rand(*batch_shape, dtype=torch.float16)

if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='input', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float16,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('3', 'cuda')
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)
else:
	batch = batch.to(device='cuda')
	infer = lambda model, batch: model(batch)

for i in range(args.iterations):
	tic = tictoc()
	y = infer(model, batch)
	times_fwd[i] = tictoc() - tic
	y = None

print('Batch shape', batch_shape)
print('average load+fwd {:.02f} msec'.format(float(times_fwd.mean()) * 1e3))

Full cuDNN logs are attached:
updated_cudnn_logs.zip

To get the logs and traces, please run:

TRACEFILE=profile_one_conv_onnx_with_warmup.sqlite
TRACELOG=profile_one_conv_onnx_with_warmup.txt
TRACEPYPROFLOG=profile_one_conv_onnx_with_warmup.pyprof.txt
CUDNN_LOGDEST=profile_one_conv_onnx_with_warmup_cudnn_dbg.txt
CUBLAS_LOGDEST=profile_one_conv_onnx_with_warmup_cublas_dbg.txt

CUDA_VISIBLE_DEVICES=0 CUDNN_LOGINFO_DBG=1 CUDNN_LOGDEST_DBG=$CUDNN_LOGDEST CUBLAS_LOGINFO_DBG=1 CUBLAS_LOGDEST_DBG=$CUBLAS_LOGDEST nvprof -f -o $TRACEFILE -s --devices 0 --profile-from-start off -- python3 benchmark_repro.py \
  --onnx conv_fp16.onnx \
  --iterations 1 \
  --iterations-warmup 4 \
  --profile-cuda \
  -B 32 \
  -T 1664 &> $TRACELOG

python3 nvprof2json.py $TRACEFILE > $TRACEFILE.json

Expected behavior
ONNX Runtime performance is the same as PyTorch.

Additional context

I found that some arguments of cudnnConvolutionForward differ (see the cuDNN logs).

pytorch

I! CuDNN (v8005) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,896,1,832];
i!         strideA: type=int; val=[745472,832,832,1];
i!     xData: location=dev; addr=0x7f1f48d80000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,1,19];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7f1fc4000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[0,9];
i!         strideA: type=int; val=[1,2];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (1);
i!     workSpace: location=dev; addr=0x7f1f24d80000;
i!     workSpaceSizeInBytes: type=size_t; val=217458192;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,1792,1,416];
i!         strideA: type=int; val=[745472,416,416,1];
i!     yData: location=dev; addr=0x7f1f22000000;
i! Time: 2021-04-02T20:20:59.364486 (0d+0h+0m+7s since start)
i! Process=20457; Thread=20457; GPU=0; Handle=0x562d0f5507f0; StreamId=(nil) (defaultStream).

onnx

I! CuDNN (v8005) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=0x556d1919df70;
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,896,832,1];
i!         strideA: type=int; val=[745472,832,1,1];
i!     xData: location=dev; addr=0x7fb90c000000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,19,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7fb97c000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[9,0];
i!         strideA: type=int; val=[2,1];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM (0);
i!     workSpace: location=dev; addr=NULL_PTR;
i!     workSpaceSizeInBytes: type=size_t; val=0;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,1792,416,1];
i!         strideA: type=int; val=[745472,416,1,1];
i!     yData: location=dev; addr=0x7fb914000000;
i! Time: 2021-04-02T20:20:42.158192 (0d+0h+0m+4s since start)
i! Process=20361; Thread=20361; GPU=0; Handle=0x556d193e9ea0; StreamId=0x556d1919df70.

List of noted differences:

  • xDesc dimA and strideA: onnx [32,896,832,1] vs pytorch [32,896,1,832]
  • convDesc.dataType: onnx val=CUDNN_DATA_HALF (2) vs pytorch val=CUDNN_DATA_FLOAT (0)
  • from time to time the algo val becomes CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM instead of CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM

Also, the traces show different GEMM kernels being used:
pytorch: volta_fp16_s884cudnn_fp16_128x128_ldg8_relu_f2f_exp_large_nhwc2nchw_tn_v1
onnx: implicit_convolve_hhgemm

Could you please explain the cause of these differences?
Or what am I doing wrong in the benchmarking or ONNX setup?

UPDATE:

  • Added io_binding. The result is copied to CPU in both cases.
  • Changed script params and updated timings. Performance now differs by a factor of two.
  • cuDNN logs and GEMM names updated.

UPDATE 2:

  • Removed the copy from GPU to CPU.
  • Updated the performance comparison: PyTorch is now ~3.5 times faster.
@pranavsharma
Contributor

We'll investigate and get back. cc @RandyShuai

@pranavsharma
Contributor

pranavsharma commented Apr 2, 2021

I think this is what is happening. In the pytorch case, you're calling tensor.to('cuda'....). This will incur the cost of copying the tensor from pinned memory to cuda only once; the subsequent iterations will be a no-op. In the case of ORT, the copy is done every single time you call run(). A better comparison is to ensure your input is copied to the GPU only once in both cases before run() is called. You can achieve this using iobinding.
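(For reference, a minimal sketch of the io_binding pattern described above, reusing the input/output names ('input' and '3'), file name, and shapes from the benchmark script; this is an illustration, not the exact code from the issue:)

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('conv_fp16.onnx')  # assumes the CUDA execution provider is available
io_binding = session.io_binding()

# one host-to-device copy up front
x = np.random.rand(32, 896, 832).astype(np.float16)
x_gpu = onnxruntime.OrtValue.ortvalue_from_numpy(x, 'cuda', 0)

io_binding.bind_input(name='input', device_type=x_gpu.device_name(), device_id=0,
		element_type=np.float16, shape=x_gpu.shape(), buffer_ptr=x_gpu.data_ptr())
io_binding.bind_output('3', 'cuda')  # keep the output on the GPU as well

for _ in range(30):
	session.run_with_iobinding(io_binding)  # no per-iteration host<->device copies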

@II245
Author

II245 commented Apr 2, 2021

Thanks for the reply!
Description updated.
Do you know how to synchronize CUDA via ORT? It would help me avoid the GPU-to-CPU copy.
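(A hedged side note on the synchronization question: the updated script above relies on torch.cuda.synchronize() before reading the timer, which waits for all outstanding GPU work, including kernels ORT launched on its own stream. If the output values are needed on the host, the bound outputs can be fetched after the run; a sketch continuing the script's variables:)

onnxruntime_session.run_with_iobinding(io_binding)
torch.cuda.synchronize()                    # wait for the GPU before stopping the timer
outputs = io_binding.copy_outputs_to_cpu()  # optional: fetch bound outputs as numpy arrays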

@II245
Author

II245 commented Apr 2, 2021

The allclose assert starts to fail after I increase the value of the channel dimension; should I report a bug?

@II245
Author

II245 commented Apr 2, 2021

Now the ORT setup uses convDesc.mathType = CUDNN_DEFAULT_MATH in cudnnConvolutionForward, which differs from PyTorch's CUDNN_TENSOR_OP_MATH. Do you know why?

@pranavsharma
Contributor

Do you see any difference in performance after using iobinding?

@II245
Author

II245 commented Apr 5, 2021

Yes, I still see a difference. Current measurements after using iobinding:
pytorch:
Batch shape [32, 896, 832]
average load+fwd 10.41 msec
onnxruntime:
Batch shape [32, 896, 832]
average load+fwd 35.92 msec

@II245
Author

II245 commented Apr 5, 2021

This may also be interesting.
There are differences in the cudnnFindConvolutionForwardAlgorithmEx calls:

Pytorch
I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     srcDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,896,1,832];
i!         strideA: type=int; val=[745472,832,832,1];
i!     srcData: location=dev; addr=0x7f1f46000000;
i!     filterDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,1,19];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     filterData: location=dev; addr=0x7f1fc4000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[0,9];
i!         strideA: type=int; val=[1,2];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     destDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[32,1792,1,416];
i!         strideA: type=int; val=[745472,416,416,1];
i!     destData: location=dev; addr=0x7f1f48d80000;
i!     requestedAlgoCount: type=int; val=8;
i!     returnedAlgoCount: location=host; addr=0x7ffd1db1e260;
i!     perfResults: location=host; addr=0x562d4de29310;
i!     workSpace: location=dev; addr=0x7f1f22000000;
i!     workSpaceSizeInBytes: type=size_t; val=453246976;
i! Time: 2021-04-02T20:20:58.376987 (0d+0h+0m+6s since start)
i! Process=20457; Thread=20457; GPU=0; Handle=0x562d0f5507f0; StreamId=(nil) (defaultStream).
Onnxruntime
I! CuDNN (v8005) function cudnnFindConvolutionForwardAlgorithmEx() called:
i!     handle: type=cudnnHandle_t; streamId=0x556d192637c0;
i!     srcDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[4,896,128,1];
i!         strideA: type=int; val=[114688,128,1,1];
i!     srcData: location=dev; addr=0x7fbac2c7f200;
i!     filterDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         vect: type=int; val=0;
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[1792,896,19,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     filterData: location=dev; addr=0x7fb97c000000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=2;
i!         padA: type=int; val=[9,0];
i!         strideA: type=int; val=[2,1];
i!         dilationA: type=int; val=[1,1];
i!         groupCount: type=int; val=1;
i!     destDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[4,1792,64,1];
i!         strideA: type=int; val=[114688,64,1,1];
i!     destData: location=dev; addr=0x7fb97fc00000;
i!     requestedAlgoCount: type=int; val=1;
i!     returnedAlgoCount: location=host; addr=0x7ffd76b1e7f0;
i!     perfResults: location=host; addr=0x7ffd76b1ea70;
i!     workSpace: location=dev; addr=0x7fb9a2000000;
i!     workSpaceSizeInBytes: type=size_t; val=33554432;
i! Time: 2021-04-02T20:20:39.710280 (0d+0h+0m+1s since start)
i! Process=20361; Thread=20361; GPU=0; Handle=0x556d19454340; StreamId=0x556d192637c0.

@II245
Author

II245 commented Apr 5, 2021

And could you please explain the reason for the differences (data types, algo types, strides)?

@II245
Author

II245 commented Apr 5, 2021

We found an interesting thing. If I preconfigure ORT with a provider config like:

	providers = [
		('CUDAExecutionProvider', {
			'cudnn_conv_algo_search': 'DEFAULT',
		}),
		'CPUExecutionProvider',
	]
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers=providers)

the metrics become similar to the PyTorch version:
average load+fwd 10.88 msec
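(For reference: the CUDA provider's cudnn_conv_algo_search option accepts 'EXHAUSTIVE' (the default), 'HEURISTIC', and 'DEFAULT', so a hypothetical middle-ground variant of the workaround above would be:)

	providers = [
		('CUDAExecutionProvider', {
			'cudnn_conv_algo_search': 'HEURISTIC',  # cheaper than 'EXHAUSTIVE', less rigid than 'DEFAULT'
		}),
		'CPUExecutionProvider',
	]
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers=providers)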

@II245
Author

II245 commented Apr 5, 2021

Now if we compare the cuDNN logs of 'DEFAULT' with 'EXHAUSTIVE', we will see that the mathType has changed as well.
(screenshot of the cuDNN log comparison; 'DEFAULT' on the left)

@SherlockNoMad
Contributor

SherlockNoMad commented Apr 6, 2021

This PR #7227 might have addressed the issue.
We need to set the convolution math type (cudnnSetConvolutionMathType) properly before doing the algo search;
see the changes in conv.cc, CudnnConvolutionDescriptor::Set().

I will implement the algorithm search for ConvGrad in the next PR and will look at this problem altogether.

Please let me know if this fix has addressed the issue.

@hariharans29
Member

@SherlockNoMad,

But isn't the math type getting set properly before we do the algo search even without your addition? I am referring to these lines:

  // set math type to tensor core before algorithm search
  if (std::is_same<T, MLFloat16>::value)
    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); 

@II245
Author

II245 commented Apr 8, 2021

Hi @SherlockNoMad! I took measurements on the ORT nightly. It got worse: it was ~35 msec (ORT 1.7), and now it is ~61 msec (1.7.0.dev202104071).

@vadimkantorov

@SherlockNoMad We, along with @II245, will upload the cuDNN logs / nvprof traces, but it seems that the fix did not help.

@SherlockNoMad
Contributor

Thanks @vadimkantorov, I will revert the change for now to unblock you and will investigate in the meantime.

@vadimkantorov

Thanks! I think it would be great to have a Conv perf test in ORT that compares performance against vanilla PyTorch (actually, I wonder if PyTorch has a similar test in their benchmarks?).

For speech recognition, Conv1d performance is critical, so such bugs should be easily discoverable :) A 3x-4x perf degradation greatly undermines the suitability of ORT as an inference backend...

@SherlockNoMad
Contributor

Definitely...
I also found that we are lacking op-level micro-benchmarks to help us catch such regressions.
I will use your input/kernel shapes to verify this fix.

@vadimkantorov

For functional perf testing of speech recognition models, maybe you can use the NVIDIA NeMo architectures or the ones from Hugging Face (JasperNet, QuartzNet, etc.).

@SherlockNoMad
Contributor

Hi @vadimkantorov, hope this PR addresses your problem.

@II245
Author

II245 commented Apr 28, 2021

Hi @SherlockNoMad! Thank you very much for solving the reported case.
I took measurements, and now both PyTorch and ORT work well and equally fast.

However, after examining a whole production model, I found that some cases which previously worked almost well with the CUDA provider fix (mentioned above) now show some performance degradation with the nightly build.

This concerns cases when GPU memory is almost full.
I suppose it is somehow linked to the free memory available for the algo search.

Let me show you the table:

B    T     forward time, msec    expected forward time, msec
32   120   914.24                914.24
48   120   4066.77               1.5 * 914.24 = ~1400

where B is the batch size and T is the time dimension in seconds.

With B=48 I expect an average run of ~1400 msec.

Could you please investigate this case too?
It looks like abnormal behavior.
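(A hedged thought on narrowing this down, building on the earlier workaround: if the slowdown is tied to the exhaustive algo search running with little free GPU memory, combining the memory cap with a non-exhaustive search should show whether the search itself is the culprit. A hypothetical variant of the provider config used in the repro below:)

	providers = [
		('CUDAExecutionProvider', {
			'device_id': 0,
			'gpu_mem_limit': 6 * 1024 * 1024 * 1024,  # same 6 GB cap as in the repro
			'cudnn_conv_algo_search': 'HEURISTIC',    # or 'DEFAULT', as in the earlier comment
		}),
		'CPUExecutionProvider',
	]
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers=providers)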

Model is here:
https://disk.yandex.com/d/nTCPSCn67eL9sQ

Repro is here:

CUDA_VISIBLE_DEVICES=0 python3 perf_ort.py  --onnx fp16_model.onnx --iterations 3 --iterations-warmup 3 --run-with-io-binding -B 32 -T 120

perf_ort.py:

import argparse
import time
import torch
import torch.cuda.profiler
import onnxruntime
import numpy as np


def infer_ort(onnxruntime_session, io_binding):
	onnxruntime_session.run_with_iobinding(io_binding)


parser = argparse.ArgumentParser()
parser.add_argument('--iterations', type=int, default=16)
parser.add_argument('--iterations-warmup', type=int, default=16)
parser.add_argument('--channels', type=int, default=64)
parser.add_argument('--sample-rate', type=int, default=8000)
parser.add_argument('--onnx')
parser.add_argument('-B', type=int)
parser.add_argument('-T', type=int)
parser.add_argument('--profile-cuda', action='store_true')
parser.add_argument('--run-with-io-binding', action='store_true')
args = parser.parse_args()

print(args)

dtype = torch.float16
use_cuda = True

if args.onnx:
	providers = [
		('CUDAExecutionProvider', {
			'device_id': 0,
			'gpu_mem_limit': 6 * 1024 * 1024 * 1024,  # limit CUDA EP GPU memory to 6 GB
		}),
		'CPUExecutionProvider',
	]

	onnxruntime_session = onnxruntime.InferenceSession(args.onnx, providers = providers)

	model = lambda x: onnxruntime_session.run(None, dict(x=x))
	if args.run_with_io_binding:
		model = lambda io_binding: onnxruntime_session.run_with_iobinding(io_binding)
	pass


tictoc = lambda: (use_cuda and torch.cuda.synchronize()) or time.time()

batch_shape = [args.B, args.T * args.sample_rate]

batch = torch.rand(*batch_shape, dtype=torch.float32)

if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='x', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float32,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('logits', 'cuda')
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)


print('Warming up for', args.iterations_warmup, 'iterations')
tic_wall = tictoc()
torch.backends.cudnn.benchmark = True
for i in range(args.iterations_warmup):
	y = infer(model, batch)
print('Warmup done in {:.02f} wall clock seconds'.format(tictoc() - tic_wall))
print()

if args.profile_cuda:
	torch.cuda.profiler.start()

print('Starting benchmark for', args.iterations, 'iterations:', 'fwd')
times_fwd = torch.zeros(args.iterations)
batch = torch.rand(*batch_shape, dtype=torch.float32)


if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='x', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float32,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('logits', 'cuda')
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)


for i in range(args.iterations):
	tic = tictoc()
	y = infer(model, batch)
	times_fwd[i] = tictoc() - tic
	y = None

print('Batch shape', batch_shape)
print('average load+fwd {:.02f} msec'.format(float(times_fwd.mean()) * 1e3))

cc @vadimkantorov
