-
Notifications
You must be signed in to change notification settings - Fork 344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Fix handling of flags in the cuDNN FMHA test. #12224
Conversation
@@ -91,7 +89,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { | |||
DebugOptions GetDebugOptionsForTest() override { | |||
auto debug_options = HloTestBase::GetDebugOptionsForTest(); | |||
debug_options.set_xla_gpu_enable_xla_runtime_executable(false); | |||
debug_options.set_xla_gpu_enable_cudnn_fmha(false); | |||
debug_options.set_xla_gpu_enable_cudnn_fmha(true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: The code below would be more concise and easier to read if you just pass the bool config_with_fmha
into this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is an override, it does not accept parameters. Functions like ParseAndReturnVerifiedModule apply debug options set by GetDebugOptionsForTest automatically. You are right though, code can be simplified further, I did that.
Imported from GitHub PR openxla/xla#12224 The test got broken by openxla/xla@8799ff0, this commit fixes it. Copybara import of the project: -- 80528497321ee6020126b15035050f4c1a0beea9 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Fix handling of flags in the cuDNN FMHA test. Merging this change closes #12224 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12224 from openxla:fix_fmha_test 80528497321ee6020126b15035050f4c1a0beea9 PiperOrigin-RevId: 631728934
Imported from GitHub PR openxla/xla#12224 The test got broken by openxla/xla@8799ff0, this commit fixes it. Copybara import of the project: -- 80528497321ee6020126b15035050f4c1a0beea9 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Fix handling of flags in the cuDNN FMHA test. Merging this change closes #12224 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12224 from openxla:fix_fmha_test 80528497321ee6020126b15035050f4c1a0beea9 PiperOrigin-RevId: 631728934
Imported from GitHub PR openxla/xla#12224 The test got broken by openxla/xla@8799ff0, this commit fixes it. Copybara import of the project: -- 80528497321ee6020126b15035050f4c1a0beea9 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Fix handling of flags in the cuDNN FMHA test. Merging this change closes #12224 PiperOrigin-RevId: 631744108
The test got broken by 8799ff0, this commit fixes it.