Skip to content

Add CUDA option to run copy in default stream#5445

Merged
ke1337 merged 4 commits intomasterfrom
kedeng/stream
Oct 13, 2020
Merged

Add CUDA option to run copy in default stream#5445
ke1337 merged 4 commits intomasterfrom
kedeng/stream

Conversation

@ke1337
Copy link
Copy Markdown
Contributor

@ke1337 ke1337 commented Oct 10, 2020

This change fixes #4829. Thanks @maherzog for providing the repro!

The bug is caused by memory reuse in BFC arena, where copy and
compute stream in CUDA has a racing condition.

BFC arena is an arena allocator on top of cudaMalloc/Free to
reduce the cost in syncing CPU and GPU when alloc/free. It means
when CPU alloc/free the memory, GPU might not finished previous
work on the memory, so that CPU and GPU could run asynchronously.

This is OK if there's only one stream, where the execution order
in CPU and GPU are consistent. For example, if we have two kernels
A and B, CPU runs allocA->computeA->freeA->allocB->computeB->freeB,
A and B could shares the same memory since computeA and computeB
will not have racing as long as they run in the same GPU compute
stream.

However, if CPU runs allocA->CopyA->freeA->allocB->computeB->freeB,
the order of execution in GPU could have copyA happen after computeB,
if copy and compute happens in different GPU streams.

This change makes copy to run in default compute stream, while adding
an option to fall back to previous behavior if there's perf hit. This
is a short term fix before BFC arena could support multiple streams.

User may use following options to revert to previous behavior:
C API:
struct OrtCUDAProviderOptions cudaProviderOpt;
cudaProviderOpt.do_copy_in_default_stream = false;
C++ API:
CUDAExecutionProviderInfo cudaEPInfo;
cudaEPInfo.do_copy_in_default_stream = false;
C# API:
pending...
Python:
import onnxruntime
onnxruntime.capi._pybind_state.set_do_copy_in_default_stream(False)

This change fixes #4829. Thanks @maherzog for providing the repro!

The bug is caused by memory reuse in BFC arena, where copy and
compute stream in CUDA has a racing condition.

BFC arena is an arena allocator on top of cudaMalloc/Free to
reduce the cost in syncing CPU and GPU when alloc/free. It means
when CPU alloc/free the memory, GPU might not finished previous
work on the memory, so that CPU and GPU could run asynchronously.

This is OK if there's only one stream, where the execution order
in CPU and GPU are consistent. For example, if we have two kernels
A and B, CPU runs allocA->computeA->freeA->allocB->computeB->freeB,
A and B could shares the same memory since computeA and computeB
will not have racing as long as they run in the same GPU compute
stream.

However, if CPU runs allocA->CopyA->freeA->allocB->computeB->freeB,
the order of execution in GPU could have copyA happen after computeB,
if copy and compute happens in different GPU streams.

This change makes copy to run in default compute stream, while adding
an option to fall back to previous behavior if there's perf hit. This
is a short term fix before BFC arena could support multiple streams.

User may use following options to revert to previous behavior:
C API:
  struct OrtCUDAProviderOptions cudaProviderOpt;
  cudaProviderOpt.do_copy_in_default_stream = false;
C++ API:
  CUDAExecutionProviderInfo cudaEPInfo;
  cudaEPInfo.do_copy_in_default_stream = false;
C# API:
  pending...
Python:
  import onnxruntime
  onnxruntime.capi._pybind_state.set_do_copy_in_default_stream(False)
@ke1337 ke1337 requested a review from a team as a code owner October 10, 2020 05:07
Comment thread onnxruntime/test/python/onnxruntime_test_python.py Outdated
Comment thread onnxruntime/python/onnxruntime_pybind_state.cc Outdated
Comment thread onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Copy link
Copy Markdown
Contributor

@HectorSVC HectorSVC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Wrong requested shape after a few thousand inference steps when using CUDA

2 participants