feat: enable Flash Attention 2 for T5Gemma2 #45991
Closed
TheShreyanshiDwivedi wants to merge 3 commits into
Conversation
…merged attention

Encoder layers inherit FA2 support from Gemma3Attention unchanged. T5Gemma2MergedAttention fuses self-attention and cross-attention over a concatenated [self_KV ‖ cross_KV] sequence with a fused causal+bidirectional mask. This cannot be expressed in a single FA2 kernel for q_len > 1 (prefill and training), because FA2's window_size would incorrectly gate encoder tokens behind the sliding window, and there is no single mask primitive that is causal over decoder tokens while bidirectional over encoder tokens.

This patch restricts FA2 in MergedAttention to single-token decode steps (q_len == 1, eval mode, fp16/bf16, CUDA). At q_len == 1 there are no future tokens to mask, so is_causal=False over the merged KV is exact. Sliding-window truncation is applied to the self-KV slice before concatenation so that FA2's window_size=None does not contaminate encoder tokens. All other configurations fall back to SDPA with a one-time warning.

The None-mask crash that occurred when both create_causal_mask and create_bidirectional_mask returned None (no padding) is fixed by the _merge_fa2() helper in T5Gemma2Decoder, which materialises all-ones tensors from KV shapes before torch.cat.

Fixes huggingface#45161
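A minimal sketch of the decode-step gate and self-KV truncation described above, assuming a (batch, heads, seq_len, head_dim) layout; the helper names _fa2_eligible and _truncate_self_kv are illustrative, not the actual identifiers in the patch.

```python
import torch

def _fa2_eligible(query: torch.Tensor, training: bool) -> bool:
    # FA2 path is only exact for single-token decode: q_len == 1, eval mode,
    # half precision, CUDA. Everything else falls back to SDPA.
    return (
        not training
        and query.shape[-2] == 1
        and query.dtype in (torch.float16, torch.bfloat16)
        and query.is_cuda
    )

def _truncate_self_kv(self_k: torch.Tensor, self_v: torch.Tensor, window: int):
    # Apply the sliding window to the self-attention KV slice *before*
    # concatenating with the cross-attention KV, so FA2 can be called with
    # window_size=None without clipping any encoder tokens.
    return self_k[..., -window:, :], self_v[..., -window:, :]

# Merged KV for the q_len == 1 case: no future tokens exist, so attending
# over [self_KV ‖ cross_KV] with is_causal=False is exact:
#   k = torch.cat([self_k, cross_k], dim=-2)
#   v = torch.cat([self_v, cross_v], dim=-2)
```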
[For maintainers] Suggested jobs to run (before merge): run-slow: t5gemma2
What does this PR do?
Enables Flash Attention 2 support for T5Gemma2. SDPA already worked; FA2 needed a mask materialization guard for hybrid mask cases.
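A rough sketch of that materialization guard, assuming boolean padding masks of shape (batch, kv_len); the exact signature of _merge_fa2 in T5Gemma2Decoder may differ — the PR only states that it builds all-ones tensors from the KV shapes before torch.cat.

```python
import torch

def _merge_fa2(self_mask, cross_mask, self_kv, cross_kv):
    # create_causal_mask / create_bidirectional_mask both return None when the
    # batch has no padding; materialise all-ones masks from the KV shapes so
    # torch.cat never receives a None operand.
    if self_mask is None:
        self_mask = self_kv.new_ones(self_kv.shape[0], self_kv.shape[-2], dtype=torch.bool)
    if cross_mask is None:
        cross_mask = cross_kv.new_ones(cross_kv.shape[0], cross_kv.shape[-2], dtype=torch.bool)
    # Merged mask covers the concatenated [self_KV ‖ cross_KV] length.
    return torch.cat([self_mask, cross_mask], dim=-1)
```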
Changes
- _supports_flash_attn = True added to T5Gemma2PreTrainedModel
- FA2 path in T5Gemma2MergedAttention for hybrid mask inputs
- logger.warning_once for ineligible inputs (training mode, q_len > 1)

Files changed
src/transformers/models/t5gemma2/modeling_t5gemma2.py
src/transformers/models/t5gemma2/modular_t5gemma2.py