
Found regression on ORT 1.8.1 #8513

Open

II245 opened this issue Jul 27, 2021 · 3 comments
Labels
api (issues related to all other APIs: C, C++, Python, etc.)
ep:CUDA (issues related to the CUDA execution provider)

Comments


II245 commented Jul 27, 2021

Describe the bug
Hello! We found a performance regression in the new ORT 1.8.1: it is roughly 5-6 times slower than ORT 1.8.0.

Urgency
Not urgent, but we use the C# version of ORT in our production environment. Unfortunately we can't use C# ORT 1.8.0 because of bug #8052, and the old 1.7.0 has performance issues as well. The current 1.8.1 shows a significant performance degradation. A Python reproduction is below.

System information
OS Platform and Distribution: Linux Ubuntu 18.04 (x86_64)
ONNX Runtime installed from (source or binary): binary
ONNX Runtime version: onnxruntime-1.8.0 and onnxruntime-1.8.1
Python version: Python 3.8.5
Pytorch version: 1.8.1
CUDA/cuDNN version: CUDA Version: 11.1
GPU model and memory: Tesla V100 (32G)

To Reproduce

Save the following as benchmark_repro.py:

import argparse
import time
import torch
import torch.cuda.profiler
import onnxruntime
import numpy as np


def infer_ort(onnxruntime_session, io_binding):
	onnxruntime_session.run_with_iobinding(io_binding)


def get_model(channels):
	return torch.nn.Conv1d(in_channels=channels, out_channels=channels*2, kernel_size=19,
			stride=2, padding=(1 * 19 // 2), dilation=1, groups=1)


def export_onnx(onnx_path, channels, dtype = torch.float16):
	model = get_model(channels)
	torch.set_grad_enabled(False)
	model.eval()
	model.to('cuda', dtype=dtype)
	waveform_input = torch.rand((4, channels, 128), device='cuda', dtype=dtype)

	logits = model(waveform_input)

	torch.onnx.export(
			model, (waveform_input),
			onnx_path,
			opset_version=12,
			export_params=True,
			do_constant_folding=True,
			input_names=['input'],
			dynamic_axes=dict(input={
				0: 'B',
				2: 'T'
			})
	)

	## check export correctness
	onnxruntime_session = onnxruntime.InferenceSession(onnx_path)
	(logits_,) = onnxruntime_session.run(None, dict(input=waveform_input.cpu().to(dtype=dtype).numpy()))

	print((torch.from_numpy(logits_) - logits.cpu()).abs().to(dtype=torch.float32).max())
	assert torch.allclose(
			logits.cpu(),
			torch.from_numpy(logits_),
			**{
				'rtol': 1e-01,
				'atol': 1e-01
			}
	)

parser = argparse.ArgumentParser()
parser.add_argument('--iterations', type=int, default=16)
parser.add_argument('--iterations-warmup', type=int, default=16)
parser.add_argument('--channels', type=int, default=64)
parser.add_argument('--onnx')
parser.add_argument('-B', type=int)
parser.add_argument('-T', type=int)
parser.add_argument('--profile-cuda', action='store_true')
parser.add_argument('--run-with-io-binding', action='store_true')
args = parser.parse_args()

print(args)

dtype = torch.float16
use_cuda = True

if args.onnx:
	export_onnx(args.onnx, args.channels, dtype)
	onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
	model = lambda x: onnxruntime_session.run(None, dict(input=x))
	if args.run_with_io_binding:
		model = lambda io_binding: onnxruntime_session.run_with_iobinding(io_binding)
	pass
else:
	model = get_model(args.channels)
	model.to('cuda')
	model.eval()
	model.to(dtype=dtype)

tictoc = lambda: (use_cuda and torch.cuda.synchronize()) or time.time()

batch_shape = [args.B, args.channels, args.T]

batch = torch.rand(*batch_shape, dtype=torch.float16)

if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='input', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float16,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('3', 'cuda')
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)
else:
	batch = batch.to(device='cuda')
	infer = lambda model, batch: model(batch).cpu()

print('Warming up for', args.iterations_warmup, 'iterations')
tic_wall = tictoc()
torch.backends.cudnn.benchmark = True
for i in range(args.iterations_warmup):
	y = infer(model, batch)
print('Warmup done in {:.02f} wall clock seconds'.format(tictoc() - tic_wall))
print()

if args.profile_cuda:
	torch.cuda.profiler.start()

print('Starting benchmark for', args.iterations, 'iterations:', 'fwd')
times_fwd = torch.zeros(args.iterations)
batch = torch.rand(*batch_shape, dtype=torch.float16)

if args.onnx:
	io_binding = onnxruntime_session.io_binding()
	X_ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch.numpy(), 'cuda', 0)
	io_binding.bind_input(name='input', device_type=X_ortvalue.device_name(), device_id=0, element_type=np.float16,
			shape=X_ortvalue.shape(), buffer_ptr=X_ortvalue.data_ptr())
	io_binding.bind_output('3', 'cuda')
	batch = io_binding
	model = onnxruntime_session
	infer = lambda model, batch: infer_ort(model, batch)
else:
	batch = batch.to(device='cuda')
	infer = lambda model, batch: model(batch)

for i in range(args.iterations):
	tic = tictoc()
	y = infer(model, batch)
	times_fwd[i] = tictoc() - tic
	y = None

print('Batch shape', batch_shape)
print('average load+fwd {:.02f} msec'.format(float(times_fwd.mean()) * 1e3))

Install ORT 1.8.1:

pip install onnxruntime==1.8.1 onnxruntime-gpu==1.8.1
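
Not part of the original report, but possibly relevant: installing the CPU wheel (onnxruntime) and the GPU wheel (onnxruntime-gpu) side by side can end up silently running on the CPU provider, which would also show up as a large timing difference. A quick sketch to confirm the CUDA EP is actually active (provider names are the standard ORT ones):

import onnxruntime

# 'GPU' is expected when the GPU wheel is the one that gets imported
print(onnxruntime.get_device())

session = onnxruntime.InferenceSession(
	'conv_fp16.onnx',
	providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# CUDAExecutionProvider should be listed first if it was registered successfully
print(session.get_providers())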

Reproduce on 1.8.1:

CUDA_VISIBLE_DEVICES=0 python3 benchmark_repro.py \
  --onnx conv_fp16.onnx \
  --iterations 30 \
  --channels 896 \
  --iterations-warmup 30 \
  --run-with-io-binding \
  -B 32 \
  -T 832

Install ORT 1.8.0:

pip install onnxruntime==1.8.0 onnxruntime-gpu==1.8.0

Reproduce on 1.8.0:

CUDA_VISIBLE_DEVICES=0 python3 benchmark_repro.py \
  --onnx conv_fp16.onnx \
  --iterations 30 \
  --channels 896 \
  --iterations-warmup 30 \
  --run-with-io-binding \
  -B 32 \
  -T 832

And compare the metrics:

ORT 1.8.1:
average load+fwd 61.56 msec

vs

ORT 1.8.0:
average load+fwd 11.77 msec
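
As a cross-check (not in the original report), the per-run latency can also be measured with CUDA events instead of wall-clock time, to rule out differences in host-side synchronization overhead between the two versions. A minimal sketch reusing the session and io_binding from the script above:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
onnxruntime_session.run_with_iobinding(io_binding)
end.record()
torch.cuda.synchronize()
print('run_with_iobinding took {:.2f} msec'.format(start.elapsed_time(end)))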

Expected behavior
ORT 1.8.1 should be as fast as 1.8.0.

Additional context
This older issue may be related: #7212
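
Speculative, not from the original report: if the slowdown comes from cuDNN convolution algorithm selection (the topic of #7212), forcing an exhaustive algorithm search via the CUDA EP options may be worth trying. The tuple-style providers argument below assumes an ORT build that accepts per-provider options:

import onnxruntime

session = onnxruntime.InferenceSession(
	'conv_fp16.onnx',
	providers=[
		('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'EXHAUSTIVE'}),
		'CPUExecutionProvider',
	])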

Thanks!

hariharans29 added the ep:CUDA label on Jul 28, 2021.

lipo5476 commented Sep 8, 2021

Hi,

We are seeing a similar issue on a Jetson Xavier with the CUDA EP in version 1.8.2; in version 1.7.0, inference was much faster.

Thanks

@Squire-tomsk

Hello. We have the same performance problem on 1.8.1. Are there any updates, @yuslepukhin?


stale bot commented Apr 19, 2022

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

The stale bot added the stale label on Apr 19, 2022.
sophies927 added the api label and removed the api:CSharp label on Aug 12, 2022.
The stale bot removed the stale label on Aug 12, 2022.