[XLA:GPU] Disable cuDNN FMHA by default.
cuDNN FMHA dispatches pattern-matched regions to a FlashAttention kernel by
default. FlashAttention does not preserve numerics, and is thus an illegal
optimization to have on by default.

PiperOrigin-RevId: 631384623
bchetioui authored and Copybara-Service committed May 7, 2024
1 parent d8f814e commit 8799ff0
Showing 1 changed file with 6 additions and 1 deletion.
xla/debug_options_flags.cc

@@ -89,7 +89,12 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_cpu_fast_math_honor_division(true);
 
   // TODO(AyanmoI): Remove this flag when cuDNN FMHA is fully supported.
-  opts.set_xla_gpu_enable_cudnn_fmha(true);
+  //
+  // cuDNN FMHA currently rewrites attention layers to use FlashAttention by
+  // default. This reassociation is not semantics-preserving, and the user
+  // should explicitly opt in if they wish to use this feature. cuDNN FMHA can
+  // not be turned on by default.
+  opts.set_xla_gpu_enable_cudnn_fmha(false);
 
   opts.set_xla_gpu_fused_attention_use_cudnn_rng(false);

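Users who still want the FlashAttention path after this change can opt back in at runtime. XLA debug options such as this one are typically overridden through the XLA_FLAGS environment variable, with the flag name mirroring the DebugOptions field. A minimal sketch, assuming the standard flag-name convention; the script name is a placeholder:

XLA_FLAGS=--xla_gpu_enable_cudnn_fmha=true python train.py

Opting in this way explicitly accepts the non-semantics-preserving FlashAttention rewrite described in the commit message.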
