
Use cublasHgemm "back" for fp16 computation with Volta GPU#3765

Merged
pengwa merged 3 commits into master from pengwa/fp16_gemm on Apr 30, 2020

Conversation

Contributor

@pengwa pengwa commented Apr 30, 2020

Description: Use cublasHgemm for fp16 computation with Volta GPU

cublasHgemm was used for fp16 computation on Volta GPUs before the training code was merged into master. For historical reasons, when we did the internal master -> old training branch merge, we commented out that path. I used "back" in the PR title to indicate that this change simply re-enables the existing path.

This change should do no harm to inference, because this was already the behavior in master before training was merged.
This change brings a perf improvement for training, tested on a 32GB V100.

[image: training performance comparison]

The reasoning is documented at https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode:
cublasGemmEx uses CUDA_R_32F as its computation type even though its main data inputs/outputs are CUDA_R_16F, whereas cublasHgemm performs the computation in CUDA_R_16F.
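The compute-type difference can be illustrated without a GPU. The sketch below is not part of the PR and does not call cuBLAS; it uses numpy to emulate the two accumulation strategies for a dot product of 4096 ones. FP32 accumulation (the cublasGemmEx path, compute type CUDA_R_32F) gets the exact answer, while a running sum rounded to FP16 at every step (the cublasHgemm path) saturates at 2048, because the FP16 spacing above 2048 is 2 and adding 1 rounds back down.

```python
import numpy as np

# Hypothetical illustration, not the PR's code: accumulate a dot
# product of 4096 ones with two different compute types.
k = 4096
ones = np.ones(k, dtype=np.float16)

# cublasGemmEx-style: FP16 inputs, FP32 accumulator.
acc32 = np.float32(0.0)
for x in ones:
    acc32 += np.float32(x)

# cublasHgemm-style: the running sum is rounded to FP16 at every step.
acc16 = np.float16(0.0)
for x in ones:
    acc16 = np.float16(acc16 + x)

print(float(acc32))  # 4096.0 (exact)
print(float(acc16))  # 2048.0: in FP16, 2048 + 1 rounds back to 2048
```

This loss of low-order bits in long reductions is the trade-off behind using hgemm for training: faster on Volta tensor cores, but the accumulation is no longer done in FP32.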

Motivation and Context

  • Why is this change required? What problem does it solve?
  • If it fixes an open issue, please link to the issue here.

@pengwa pengwa requested a review from a team as a code owner April 30, 2020 13:20
@pengwa pengwa added the label training (issues related to ONNX Runtime training; typically submitted using template) Apr 30, 2020
@pengwa pengwa changed the title Use cublasHgemm for fp16 computation with Volta GPU Use cublasHgemm "back" for fp16 computation with Volta GPU Apr 30, 2020
@pengwa
Contributor Author

pengwa commented Apr 30, 2020

@SherlockNoMad feel free to take over this PR if we want it ASAP for the latest benchmarking.

@pengwa pengwa merged commit 177c135 into master Apr 30, 2020
@pengwa pengwa deleted the pengwa/fp16_gemm branch April 30, 2020 16:36
Contributor

@weixingzhang weixingzhang left a comment


Can you verify convergence? Previously, the accumulation was done in FP32, but with this change it will be done in FP16. For training, it is probably not a good idea to use hgemm; even if it is OK for BERT-L, it may not be OK for other big models.


Labels

training issues related to ONNX Runtime training; typically submitted using template


4 participants