Skip to content

Remove Cast before and after Gelu#11885

Merged
Lafi7e merged 3 commits intomasterfrom
weicwang/cast_gelu
Jun 22, 2022
Merged

Remove Cast before and after Gelu#11885
Lafi7e merged 3 commits intomasterfrom
weicwang/cast_gelu

Conversation

@Lafi7e
Copy link
Contributor

@Lafi7e Lafi7e commented Jun 17, 2022

Some mix-precision models will cast data to float before Gelu and cast back to half after. This PR changes the Gelu and GeluGrad CUDA kernel to use float for compute internally, and remove these Cast nodes from graph on CUDA EP.

  • Change Gelu/GeluGrad CUDA kernel to use float internally will not harm the perf, while we can have better precision.
  • This PR doesn't handle ROCm EP as currently we don't have Gelu/GeluGrad ROCm kernel.
  • This PR doesn't handle FastGelu, as the kernel implementation for FastGelu is to complicated to changed to use float internally.

@Lafi7e Lafi7e added the training issues related to ONNX Runtime training; typically submitted using template label Jun 17, 2022
@tianleiwu
Copy link
Contributor

tianleiwu commented Jun 17, 2022

Is there a way to generalize the fusion for element-wise operator, like
x-->Element-wise Op1 --> Element-wise Op2--> can be fused as Element-wise Op2(Op1(x))

@ytaous
Copy link
Contributor

ytaous commented Jun 17, 2022

Wouldn't the Cast Propagation feature (#7454) that Satya implemented already covers this case?

@Lafi7e
Copy link
Contributor Author

Lafi7e commented Jun 20, 2022

Wouldn't the Cast Propagation feature (#7454) that Satya implemented already covers this case?

Right. Previously I just want to enabled this for CUDA only, but the PropagateCastOps cannot do this. But on second thought, we don't have MLFloat16 kernel for CPU for Gelu/GeluGrad, it's safe to use PropagateCastOps. I've changed the code to use PropagateCastOps by promoting Gelu from Level2 Ops to Level1, which also solves this issue.

@Lafi7e Lafi7e merged commit 03beed0 into master Jun 22, 2022
@Lafi7e Lafi7e deleted the weicwang/cast_gelu branch June 22, 2022 01:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

training issues related to ONNX Runtime training; typically submitted using template

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants