Conversation

@weixingzhang weixingzhang commented Aug 7, 2020

For some 1P training tasks, we found an accuracy issue on V100. It turns out that the accumulation for MatMul needs to be done in FP32 for training.

Here is the throughput of BERT-L on a V100 16GB with LAMB. As expected, the performance is almost the same.

|        | Batch Size | Seq Len | Throughput (ex/s) |
|--------|------------|---------|-------------------|
| Before | 64         | 128     | 213.935           |
| Before | 10         | 512     | 41.9966           |
| After  | 64         | 128     | 213.228           |
| After  | 10         | 512     | 41.6287           |

To avoid the accuracy issue, the accumulation needs to be done in FP32 for training.
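Roughly, the change amounts to the following kind of call (a minimal sketch for illustration, not the actual diff; the helper name `GemmBatchedFp32Accum` and the host-side alpha/beta conversion are placeholders): keep A/B/C in FP16 but ask cuBLAS to accumulate in FP32 by going through cublasGemmBatchedEx with CUDA_R_16F data types and a CUDA_R_32F compute type, instead of cublasHgemmBatched, whose compute type is FP16.

```cpp
// Illustrative sketch, not the PR's code: batched FP16 GEMM with FP32 accumulation.
// Written against the CUDA 10.x cublasGemmBatchedEx signature; on CUDA 11+ the compute
// type argument is a cublasComputeType_t (CUBLAS_COMPUTE_32F) rather than CUDA_R_32F.
#include <cublas_v2.h>
#include <cuda_fp16.h>

inline cublasStatus_t GemmBatchedFp32Accum(cublasHandle_t handle,
                                           cublasOperation_t transa,
                                           cublasOperation_t transb,
                                           int m, int n, int k,
                                           const __half* alpha,
                                           const __half** Aarray, int lda,
                                           const __half** Barray, int ldb,
                                           const __half* beta,
                                           __half** Carray, int ldc,
                                           int batch_count) {
  // alpha/beta must match the compute type, so convert the FP16 scalars to float on the
  // host (assumes the default CUBLAS_POINTER_MODE_HOST).
  const float h_a = __half2float(*alpha);
  const float h_b = __half2float(*beta);
  return cublasGemmBatchedEx(handle, transa, transb,
                             m, n, k,
                             &h_a,
                             (const void**)Aarray, CUDA_R_16F, lda,
                             (const void**)Barray, CUDA_R_16F, ldb,
                             &h_b,
                             (void**)Carray, CUDA_R_16F, ldc,
                             batch_count,
                             CUDA_R_32F,                       // accumulate in FP32
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP);   // allow Tensor Cores on V100
}
```

Since the inputs and outputs stay in FP16, memory traffic is unchanged, which is consistent with the near-identical throughput numbers above.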
@weixingzhang weixingzhang requested a review from a team as a code owner August 7, 2020 04:04
@weixingzhang weixingzhang added the "training" label (issues related to ONNX Runtime training; typically submitted using template) Aug 7, 2020
(const __half**)Barray, ldb,
beta,
(__half**)Carray, ldc,
batch_count);
Contributor:

> batch_count);

nit: Maybe explicitly set CUDA_R_16F here to avoid confusion?

Contributor (author):

Maybe I misunderstood your comment. Were you suggesting that we specify CUDA_R_16F when calling cublasHgemmBatched? It doesn't support setting a data type.

Contributor:

CUDA_R_32F is one case. In the other case, we can also explicitly provide that default argument (I wrongly thought the default argument was CUDA_R_16F, sorry) for clarity.
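To make the API difference being discussed here concrete, a sketch of the two call shapes (the variable names follow the illustrative helper in the description above and are assumptions, not this file's code): cublasHgemmBatched is hard-wired to __half and exposes no type arguments, while cublasGemmBatchedEx takes per-matrix data types plus a compute type, which is where CUDA_R_16F and CUDA_R_32F can be stated explicitly.

```cpp
// Fixed-type API: every operand is __half; there is no data-type or compute-type argument.
cublasHgemmBatched(handle, transa, transb, m, n, k,
                   alpha,                        // const __half*
                   Aarray, lda, Barray, ldb,     // arrays of const __half* pointers
                   beta,
                   Carray, ldc,                  // arrays of __half* pointers
                   batch_count);

// Mixed-precision API: data types and the compute type are explicit arguments,
// so FP16 storage can be paired with FP32 accumulation.
cublasGemmBatchedEx(handle, transa, transb, m, n, k,
                    &h_a,                                     // float, matches the compute type
                    (const void**)Aarray, CUDA_R_16F, lda,
                    (const void**)Barray, CUDA_R_16F, ldb,
                    &h_b,
                    (void**)Carray, CUDA_R_16F, ldc,
                    batch_count,
                    CUDA_R_32F,                               // compute type: FP32 accumulation
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP);
```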

inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
Contributor:

> int m, int n, int k,

nit: some int parameters can be const.

Contributor (author):

What I learned is that const on built-in-type function parameters passed by value is not necessary. See this link: https://abseil.io/tips/109

Contributor:

Nice tip! To quote from that link:

> Do use top-level const on function parameters in definitions at your (or your team’s) discretion. You might follow the same rationale as you would for when to declare a function-local variable const.

I think const still has some usefulness here, similar to const local variables.

Contributor (author):

I remember that when I previously used const on an int as suggested here, the compiler complained with something like "it is not necessary to use const for built-in types". Moreover, NVIDIA doesn't use const on int in these cuBLAS APIs either.

Contributor:

The major benefit of const, to me, is readability (so I marked my comment as a nit).
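As a concrete illustration of the tip being discussed (a generic example, not code from this PR): top-level const on a by-value parameter is ignored in the declaration, so it only matters in the definition, where it reads much like a const local variable.

```cpp
// Declaration: top-level const on by-value parameters does not change the signature,
// so callers and overload resolution see the same function either way.
float ScaleSum(int n, float factor);

// Definition: const here only documents that the body never reassigns the parameters,
// the same way a const local variable would.
float ScaleSum(const int n, const float factor) {
  float total = 0.0f;
  for (int i = 0; i < n; ++i) {
    total += static_cast<float>(i) * factor;
  }
  return total;
}
```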

inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
Contributor:

> int m, int n, int k,

Similar to other places, some parameters can be const.

@wschin (Contributor) left a comment:

LGTM. For training, accumulation should be in fp32.

transb,
m, n, k,
&h_a,
(const void**)Aarray, CUDA_R_16F, lda,
Contributor:

@SherlockNoMad

Contributor:

Please run the E2E test and the "Pytorch Frontend E2E" test. I'm not sure whether the BatchMatMul result is different enough to affect the expected test values.

@weixingzhang weixingzhang merged commit afa8956 into master Aug 14, 2020
@weixingzhang weixingzhang deleted the wezhan/matmul branch August 14, 2020 09:12