Avoid cudaStreamSync at the end of Forward/Backward#9470

Merged
SherlockNoMad merged 2 commits into master from bahuang/no_stream_sync
Oct 21, 2021
Conversation

Contributor

@SherlockNoMad SherlockNoMad commented Oct 21, 2021

Since ORTModule should match the behavior of nn.Module, we don't need to explicitly introduce a cudaStreamSync at the end of each subgraph execution.

Additional cudaStreamSync at the end of forward
As shown in the profiling result below, the ORTModule run has an extra “cudaStreamSync” call at the end of the forward section. This was introduced as the finalizing step for InferenceSession::PartialRun(); the behavior was copied from the original InferenceSession::Run() code when we implemented the PartialRun executor.
However, PyTorch automatically introduces a “cudaStreamSync” whenever subsequent CPU computation depends on a GPU tensor. In other words, ORT doesn't need to introduce this call explicitly.
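The lazy-sync behavior described above can be sketched with a toy model (the `FakeStream`/`FakeTensor` classes below are hypothetical stand-ins, not real PyTorch or ORT APIs): work queued on a stream stays asynchronous until the host actually reads a device value, and that read itself drains the stream, making an explicit sync at the end of the subgraph redundant.

```python
# Toy model of lazy stream synchronization. Illustrative only: FakeStream
# and FakeTensor are hypothetical, not the real CUDA runtime or PyTorch.

class FakeStream:
    """Models an async compute stream: work is queued, not run eagerly."""
    def __init__(self):
        self.pending = []       # queued kernels (thunks)
        self.sync_count = 0     # how many times we had to synchronize

    def launch(self, fn):
        self.pending.append(fn)  # enqueue and return immediately

    def synchronize(self):
        for fn in self.pending:  # drain the queue, like cudaStreamSync
            fn()
        self.pending.clear()
        self.sync_count += 1


class FakeTensor:
    """A device tensor whose value is only materialized on host read."""
    def __init__(self, stream):
        self.stream = stream
        self.value = None

    def item(self):
        # A host read of a device value is where the framework implicitly
        # synchronizes, so no explicit sync at subgraph end is needed.
        self.stream.synchronize()
        return self.value


stream = FakeStream()
out = FakeTensor(stream)
stream.launch(lambda: setattr(out, "value", 42))  # async "kernel"

# No sync has happened yet; the work is still queued.
assert stream.sync_count == 0 and len(stream.pending) == 1

# CPU-side computation depends on the GPU tensor -> implicit sync.
result = out.item()
assert result == 42 and stream.sync_count == 1
```

The design point is that the sync cost is only paid when (and if) the host actually needs the value, instead of unconditionally at the end of every forward/backward subgraph.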

image

Warmup Patterns after cudaStreamSync
Zooming into the time segment following the cudaStreamSync call, we can see a window lasting ~4 ms in which the GPU is barely utilized. Because the cudaStreamSync call depletes the tasks in the compute stream, the CPU needs to refill the compute stream from scratch. This results in GPU starvation, as the CPU is not able to launch kernels fast enough, worsened by the fact that the scheduled kernels are short (<10 us). The starvation is eventually relieved when a larger kernel kicks in, taking >100 us to complete and giving the CPU time to catch up on scheduling work.
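The starvation dynamic above can be modeled with back-of-the-envelope arithmetic (the launch overhead and kernel durations below are illustrative assumptions, not numbers measured from this trace): after a sync the stream queue is empty, so each short kernel must wait for its own CPU launch, while a pre-filled queue or a long kernel hides the launch latency.

```python
# Simple single-stream launch/execution model. LAUNCH_US and the kernel
# durations are assumed values for illustration, not measurements.

LAUNCH_US = 10.0  # assumed CPU launch overhead per kernel, in microseconds

def gpu_idle_time(kernel_durations_us, launch_us=LAUNCH_US, queue_depth=0):
    """Simulate one stream: the CPU enqueues kernels one launch at a time,
    and the GPU runs them back-to-back as they become available. The first
    queue_depth kernels are assumed already enqueued (i.e., no preceding
    sync drained the queue). Returns total GPU idle (starved) time."""
    cpu_time = 0.0   # when the CPU finishes launching each kernel
    gpu_free = 0.0   # when the GPU finishes its current work
    idle = 0.0
    for i, dur in enumerate(kernel_durations_us):
        if i >= queue_depth:             # not pre-enqueued: pay launch cost
            cpu_time += launch_us
        ready = max(cpu_time, gpu_free)  # kernel starts once launched
        idle += ready - gpu_free         # any gap is GPU starvation
        gpu_free = ready + dur
    return idle

# Many short (5 us) kernels right after a sync: the GPU starves on each one,
# since the 10 us launch cost exceeds the 5 us kernel duration.
short = [5.0] * 20
assert gpu_idle_time(short, queue_depth=0) > 0

# Same kernels with the queue already filled (no sync drained it): no gaps.
assert gpu_idle_time(short, queue_depth=len(short)) == 0

# A long (200 us) kernel lets the CPU run ahead, so the short kernels
# after it no longer starve the GPU.
mixed = [5.0] * 5 + [200.0] + [5.0] * 5
assert gpu_idle_time(mixed) < gpu_idle_time([5.0] * len(mixed))
```

This matches the trace's shape: starvation while only sub-10 us kernels are scheduled, relief once a >100 us kernel gives the CPU room to get ahead.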

image

@SherlockNoMad SherlockNoMad added the training (issues related to ONNX Runtime training; typically submitted using template) and component:ortmodule labels Oct 21, 2021
@SherlockNoMad SherlockNoMad merged commit ff23b9f into master Oct 21, 2021
@SherlockNoMad SherlockNoMad deleted the bahuang/no_stream_sync branch October 21, 2021 18:28
