
feat: enable Flash Attention 2 for T5Gemma2 #45991

Closed

TheShreyanshiDwivedi wants to merge 3 commits into huggingface:main from TheShreyanshiDwivedi:fresh/t5gemma2-flash-attn-2

Conversation

@TheShreyanshiDwivedi

What does this PR do?

Enables Flash Attention 2 support for T5Gemma2. SDPA already worked; FA2 needed a mask materialization guard for hybrid mask cases.

Changes

  • Add _supports_flash_attn = True to T5Gemma2PreTrainedModel
  • Add runtime fallback guard in T5Gemma2MergedAttention for hybrid mask inputs
  • Add logger.warning_once for ineligible inputs (training mode, q_len > 1); a sketch of the guard follows this list
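
A minimal sketch of how such a guard could look; the helper name `_fa2_eligible`, the tensor layout, and the warning text are assumptions, while the actual guard lives in T5Gemma2MergedAttention:

```python
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _fa2_eligible(query: torch.Tensor, training: bool) -> bool:
    """True only for single-token decode steps that FA2 can handle exactly."""
    q_len = query.shape[2]  # layout assumed (batch, heads, q_len, head_dim)
    if training or q_len > 1:
        logger.warning_once(
            "Flash Attention 2 is unsupported for T5Gemma2 merged attention "
            "in training mode or for q_len > 1; falling back to SDPA."
        )
        return False
    return query.is_cuda and query.dtype in (torch.float16, torch.bfloat16)
```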

Files changed

  • src/transformers/models/t5gemma2/modeling_t5gemma2.py
  • src/transformers/models/t5gemma2/modular_t5gemma2.py

Srijan Upadhyay and others added 3 commits May 15, 2026 15:17
…merged attention

Encoder layers inherit FA2 support from Gemma3Attention unchanged.

T5Gemma2MergedAttention fuses self-attention and cross-attention over a
concatenated [self_KV ‖ cross_KV] sequence with a fused causal+bidirectional
mask.  This cannot be expressed in a single FA2 kernel for q_len > 1 (prefill
and training), because FA2's window_size would incorrectly gate encoder tokens
behind the sliding window and there is no single mask primitive that is causal
over decoder tokens while bidirectional over encoder tokens.
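
For concreteness, a minimal standalone sketch of the fused mask semantics; the shapes, offset convention, and helper name are illustrative assumptions, not the library's code:

```python
import torch


def merged_mask(q_len: int, self_kv_len: int, cross_kv_len: int) -> torch.Tensor:
    """Boolean mask of shape (q_len, self_kv_len + cross_kv_len); True = may attend."""
    # Causal over the decoder self-KV slice: with a cache, query i sits at
    # absolute position offset + i and may attend self keys 0..offset + i.
    offset = self_kv_len - q_len
    causal = torch.arange(self_kv_len) <= (torch.arange(q_len)[:, None] + offset)
    # Bidirectional over the encoder cross-KV slice: every query sees every key.
    bidirectional = torch.ones(q_len, cross_kv_len, dtype=torch.bool)
    return torch.cat([causal, bidirectional], dim=-1)
```

At q_len == 1 the causal row is all True, so the merged mask degenerates to all-ones; that is the case the decode-only restriction below exploits.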

This patch restricts FA2 in MergedAttention to single-token decode steps
(q_len == 1, eval mode, fp16/bf16, CUDA).  At q_len == 1 there are no future
tokens to mask, so is_causal=False over the merged KV is exact.  Sliding-window
truncation is applied to the self-KV slice before concatenation, so the window
is honoured while FA2's window_size stays None and encoder tokens are never
gated behind it.  All other configurations fall back to SDPA with a one-time
warning.
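
A minimal sketch of that decode-only path, with hypothetical names; SDPA stands in here for the FA2 kernel call so the example is self-contained:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def merged_fa2_decode(
    q: torch.Tensor,        # (batch, heads, 1, head_dim); q_len == 1 only
    self_k: torch.Tensor,   # (batch, heads, self_len, head_dim)
    self_v: torch.Tensor,
    cross_k: torch.Tensor,  # (batch, heads, cross_len, head_dim)
    cross_v: torch.Tensor,
    sliding_window: Optional[int],
) -> torch.Tensor:
    if sliding_window is not None:
        # Truncate only the self-KV slice to the window; encoder KV stays whole,
        # so no window_size needs to be passed to the attention kernel.
        self_k = self_k[:, :, -sliding_window:]
        self_v = self_v[:, :, -sliding_window:]
    k = torch.cat([self_k, cross_k], dim=2)
    v = torch.cat([self_v, cross_v], dim=2)
    # With q_len == 1 there are no future tokens, so a non-causal kernel is exact.
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)
```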

The None-mask crash that occurred when both create_causal_mask and
create_bidirectional_mask returned None (no padding) is fixed by the
_merge_fa2() helper in T5Gemma2Decoder, which materialises all-ones tensors
from KV shapes before torch.cat.
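
A minimal sketch of that helper, assuming the 2D (batch, kv_len) mask layout FA2 consumes; the signature is an assumption based on the description above:

```python
from typing import Optional

import torch


def _merge_fa2(
    self_mask: Optional[torch.Tensor],   # (batch, self_len) or None when no padding
    cross_mask: Optional[torch.Tensor],  # (batch, cross_len) or None when no padding
    self_kv: torch.Tensor,               # (batch, heads, self_len, head_dim)
    cross_kv: torch.Tensor,              # (batch, heads, cross_len, head_dim)
) -> torch.Tensor:
    """Concatenate the two masks, materialising all-ones tensors for None inputs."""
    bsz = self_kv.shape[0]
    if self_mask is None:
        self_mask = torch.ones(bsz, self_kv.shape[2], dtype=torch.long, device=self_kv.device)
    if cross_mask is None:
        cross_mask = torch.ones(bsz, cross_kv.shape[2], dtype=torch.long, device=cross_kv.device)
    return torch.cat([self_mask, cross_mask], dim=-1)
```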

Fixes huggingface#45161
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: t5gemma2
