
Support cublasLt Fp8 Approx Gelu epilogue fusion. #7751

Closed
wants to merge 3 commits

Conversation

wenscarl
Contributor

Because fast accumulation is enabled in forward mode, the cublasLt FP8 GEMM with a GELU epilogue can run as a single fused kernel. Compared against the XLA-generated GELU kernel on H100, the fused version shows a modest improvement for [8192, 4096] x [4096, 16384] + GELU:

Execution time for matmul using cublasLt and gelu (XLA): 1.28ms
Execution time for matmul_gelu using cublasLt: 1.25ms
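For context, the GELU epilogue being fused here is the tanh ("approximate") form of GELU. This is not code from the PR, just a minimal NumPy sketch of the math the epilogue computes, with an unfused GEMM-then-GELU reference of the kind the fusion replaces:

```python
import numpy as np

def gelu_approx(x):
    # tanh-based approximate GELU:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def matmul_gelu(a, b):
    # Unfused reference: a GEMM followed by an elementwise GELU.
    # The PR lets cublasLt apply the GELU as an epilogue of the FP8 GEMM,
    # avoiding a separate elementwise kernel launch.
    return gelu_approx(a @ b)
```

The fused path produces the same values; the win is purely in avoiding the extra kernel and the intermediate round trip through memory.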

@reedwm reedwm self-requested a review December 13, 2023 20:22
ENTRY test {
x = f8e4m3fn[16,32] parameter(0)
y = f8e4m3fn[32,16] parameter(1)
x_f32 = bf16[16,32] convert(x)
Member
This should be named x_bf16, and similarly for y. Same for the other test.

@wenscarl wenscarl requested a review from reedwm December 14, 2023 03:16
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 15, 2023
Imported from GitHub PR openxla/xla#7751

Because fast accumulation is enabled in forward mode, the cublasLt FP8 GEMM with a GELU epilogue can run as a single fused kernel. Compared against the XLA-generated GELU kernel on H100, the fused version shows a modest improvement for [8192, 4096] x [4096, 16384] + GELU:

Execution time for matmul using cublasLt and gelu (XLA): 1.28ms
Execution time for matmul_gelu using cublasLt: 1.25ms
Copybara import of the project:

--
e8abce3b41f68cae1bb625cdecd5885413a0781d by Shu Wang <shuw@nvidia.com>:

Support cublasLt Fp8 Approx Gelu epilogue fusion.

--
818127cf582af7ceba014d88bdf027857fc8f0e5 by shuw <shuw@nvidia.com>:

Remove F32 check

--
5ce3108a9bc8459e20456d23a3ae493ef7a6a387 by shuw <shuw@nvidia.com>:

Improve based on review #1

Merging this change closes #7751

PiperOrigin-RevId: 591236441

5 participants