
Fp8 Fast Accumulation support for cublasLt #6599

Closed
wenscarl wants to merge 3 commits from wenscarl:fp8_fast_accumulation

Conversation

@wenscarl (Contributor) commented on Oct 27, 2023

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue #6168

This PR is closely related to Flax PR-3416.

// accumulation is enabled. When Precision is set to HIGHEST, indicative of
// scenarios in backward propagation, a higher precision accumulation method
// is utilized.
bool fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
Contributor

Can we call it use_fast_accum or enable_fast_accum to imply it is a bool? Also, shouldn't it be is_fp8(lhs) && is_fp8(rhs) && cfg.compute_precision == 0?

Contributor Author

Based on the current rewrite rule, FP8 matmuls have to take both inputs as FP8 types, so either one of them being an FP8 type is enough to indicate an FP8 matmul.
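For illustration, the condition under discussion has roughly this shape (a minimal sketch only; rhs_layout and the compute_precision comparison follow the reviewer's suggestion and the surrounding diff rather than the final merged code):

// Sketch: use fast accumulation for FP8 operands only when the requested
// compute precision is DEFAULT (i.e. compute_precision == 0).
bool use_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
                       xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
                      cfg.compute_precision == 0;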

// encountered during forward propagation with E4M3 operands, fast
// accumulation is enabled. When Precision is set to HIGHEST, indicative of
// scenarios in backward propagation, a higher precision accumulation method
// is utilized.
Contributor

There's no need to specify the particular cases for using different precisions. Let's simply state that for fp8 matmul, there are two options available: fast accumulation (PrecisionConfig.Precision.DEFAULT) and higher precision accumulation (PrecisionConfig.Precision.HIGHEST).

@@ -210,6 +210,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
AsCublasOperation(trans_b)));
TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
TF_RETURN_IF_ERROR(
SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, int8_t(fast_accum)));
Contributor

Can we use static_cast<int8_t>?
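Presumably that means spelling the same conversion as:

TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                           static_cast<int8_t>(fast_accum)));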

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Oct 30, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Oct 30, 2023
@kaixih (Contributor) commented on Oct 30, 2023

@reedwm We conducted a brief performance test on this. Implementing this change alone could result in approximately a 4% speedup. Could you please review the PR, considering both this update and the Flax one?

@reedwm (Member) left a comment

Is it possible to add a test checking that an FP8 matmul output is accurate enough when the PrecisionConfig is HIGHEST? I'm OK with having no test if this isn't easy to test.

Comment on lines 321 to 323
// For FP8 matmuls, there are two options available: fast
// accumulation(PrecisionConfig.Precision.DEFAULT) and
// higher precision accumulation (PrecisionConfig.Precision.HIGHEST).
Member

You don't mention the HIGH case. I would phrase this as:

FP8 matmuls have a fast accumulation mode that is less precise than the default accumulation mode. Use the fast accumulation mode if the compute precision is DEFAULT.

Contributor Author

Added a test in gemm_rewrite_test.cc

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Nov 1, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Nov 1, 2023
replacements["<<precision>>"] = "default";
const auto hlo_text_default = absl::StrReplaceAll(hlo_template, replacements);
EXPECT_TRUE(RunAndCompare(hlo_text_default, ErrorSpec{1e-3, 1e-3}));
EXPECT_FALSE(RunAndCompare(hlo_text_default, ErrorSpec{1e-4, 1e-4}));
Member

This expectation seems to fail; it looks like, if we are lucky, the result already has enough precision to pass with a tolerance of 1e-4.

@wenscarl (Contributor Author) commented on Nov 2, 2023

Right. Do you suggest removing this test or replacing it with a FileCheck-based check?

Member

This should be guarded by a check for Ada/Hopper, because it only affects those GPUs and our tests are running on pre-Ada GPUs. But it's possible new GPUs will treat the fast-accumulation flag differently, so we should not do this check anyway.

I'll remove this line when merging.
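For reference, the guard mentioned above might look roughly like this (a sketch only, assuming the test fixture exposes a GetCudaComputeCapability() helper as other tests in gemm_rewrite_test.cc do):

// Hypothetical guard: the fast-accumulation flag only changes results on
// Ada (compute capability 8.9) and newer GPUs.
if (!GetCudaComputeCapability().IsAtLeast(8, 9)) {
  GTEST_SKIP() << "FP8 fast accumulation is only exercised on Ada/Hopper GPUs.";
}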

copybara-service bot pushed a commit that referenced this pull request Nov 2, 2023
Imported from GitHub PR #6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue #6168

This PR is closely related to Flax PR google/flax#3416.
Copybara import of the project:

--
a4140da by shuw <shuw@nvidia.com>:

Add FP8 fast accumulation support for cublasLt.

--
9684568 by shuw <shuw@nvidia.com>:

Improve based on review #1

--
e906d76 by shuw <shuw@nvidia.com>:

Improve based on review #2

Merging this change closes #6599

FUTURE_COPYBARA_INTEGRATE_REVIEW=#6599 from wenscarl:fp8_fast_accumulation e906d76
PiperOrigin-RevId: 578904075
@copybara-service copybara-service bot closed this in b716639 Nov 2, 2023
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 2, 2023
Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue openxla/xla#6168

This PR is closely related to Flax PR google/flax#3416.
Copybara import of the project:

--
a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <shuw@nvidia.com>:

Add FP8 fast accumulation support for cublasLt.

--
96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <shuw@nvidia.com>:

Improve based on review #1

--
e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <shuw@nvidia.com>:

Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Nov 2, 2023
cuBLAS LT has a flag, CUBLASLT_MATMUL_DESC_FAST_ACCUM, that can be set for FP8 gemms. This flag causes the matmul to run faster but with lower accumulation precision. NVIDIA recommends using this flag on the forward pass on FP8 models but not the backwards pass, since the backwards pass needs more accumulation precision.

The PR openxla/xla#6599 enabled fast accumulation on FP8 dots whose PrecisionConfig is DEFAULT (but not HIGH or HIGHEST). This allows layers in frameworks to use fast accumulation on the forward pass but not the backwards pass by setting the PrecisionConfig on the backwards pass to be HIGH or HIGHEST.

The issue is, Flax and Praxis do not yet set the PrecisionConfig to HIGH or HIGHEST on the backwards pass, so the PR will cause poor FP8 training quality. The PR should not have been merged until Flax and Praxis set the PrecisionConfig, but I didn't realize this and merged it anyway. Reverting the PR is a pain, so instead this CL just removes the line that sets CUBLASLT_MATMUL_DESC_FAST_ACCUM, while keeping most of the plumbing around it. This CL will be rolled back once Flax and Praxis set the PrecisionConfig.

PiperOrigin-RevId: 579018421
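To illustrate the intended usage described above, a framework would request high-precision accumulation for its backward-pass dots roughly like this (a sketch against the XLA client builder API; activations_f8, incoming_grad_f8, and dnums are placeholders):

// Sketch: mark a backward-pass dot with HIGHEST operand precision so the
// FP8 fast-accumulation path is not selected for it.
xla::PrecisionConfig precision_config;
precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
xla::XlaOp grad_w = xla::DotGeneral(activations_f8, incoming_grad_f8, dnums,
                                    &precision_config);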
7 participants