🚨 Enable SDPA (and other attention backends) for T5 and propagate to the T5 family #47014
Open
jiqing-feng wants to merge 5 commits into
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Contributor
[For maintainers] Suggested jobs to run (before merge)
run-slow: longt5, mt5, pix2struct, pop2piano, switch_transformers, t5, udop, umt5
Contributor
CI recap
Dashboard: View test results in Grafana
Member
Attention dispatch so cc @ArthurZucker @Cyrilvallez
Collaborator
Will check it out when I have time but it's a big PR so expect delays 🙏
What this PR does

Refactors the T5 attention stack to route through `ALL_ATTENTION_FUNCTIONS`, so T5 can dispatch to `sdpa` / `eager` instead of the old eager-only path, and propagates the change to the copied-from family.

Now `_supports_sdpa = True` / `_supports_attention_backend = True` and running `test_eager_matches_sdpa_inference`:

- `t5` (reference), `mt5`, `udop`, `pop2piano`
- `pix2struct`: also required migrating the vision-tower attention (`Pix2StructVisionAttention`), a standard dense attention that was not on the interface; otherwise SDPA is unusable for the full model.
- `umt5`: its attention is its own dense implementation (not `# Copied from` T5); migrated the same way.

`switch_transformers` (modular, inherits T5 attention) is regenerated so its modeling matches the refactored T5, but it stays eager-only (`_supports_sdpa` not set): it is not on the SDPA path; this only propagates the eager refactor.

Alignment applied to migrated modules: attention goes through `ALL_ATTENTION_FUNCTIONS` (with `eager_attention_forward` fallback), `self.scaling`, `is_causal`; the relative position bias is folded into the additive attention mask; forwards use `**kwargs: Unpack[TransformersKwargs]` + `@can_return_tuple` / `@auto_docstring` / `@merge_with_config_defaults` / `@capture_outputs`, dropping the manual `return_dict` / `output_attentions` / `output_hidden_states` resolution.

🚨 Behavior changes

- With `sdpa`, softmax no longer force-upcasts to fp32 as the old path did; numerics stay within the `eager_matches_sdpa` tolerance.
- `output_attentions` removed from internal block/layer tuple returns; attentions are collected via `capture_outputs`.
- `test_eager_matches_sdpa_inference` is not skipped for any enabled model.

longt5 stays not-SDPA-enabled

The regular `LongT5Attention` (copied from T5) is on the interface, but the encoder always runs the block-sparse `LongT5LocalAttention` / `LongT5TransientGlobalAttention`. These operate on 5D blocked tensors `(batch, num_blocks, heads, block_len, 3*block_len)`, while `sdpa_attention_forward` assumes 4D `(batch, heads, seq, dim)`, so they cannot go through the interface at all.
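
The shape mismatch described above can be reproduced with a small NumPy sketch. This is a simplified illustration, not the LongT5 implementation: the sizes are arbitrary, the neighbour blocks are zero-padded at the sequence edges as an assumption of the sketch, and masking and position bias are left out. It only shows why local block attention yields 5D scores where dense attention yields 4D ones.

```python
import numpy as np

batch, heads, seq, dim, block_len = 2, 4, 12, 8, 3
num_blocks = seq // block_len

rng = np.random.default_rng(0)
q = rng.standard_normal((batch, seq, heads, dim))
k = rng.standard_normal((batch, seq, heads, dim))

# Dense layout: one score per (query, key) pair, giving a 4D tensor.
dense_scores = np.einsum("bqhd,bkhd->bhqk", q, k)

# Blocked layout: split the sequence into blocks of block_len tokens.
q_blocks = q.reshape(batch, num_blocks, block_len, heads, dim)
k_blocks = k.reshape(batch, num_blocks, block_len, heads, dim)

# Each query block sees its left neighbour, itself and its right neighbour.
# Zero blocks stand in for the missing neighbours at the sequence edges.
padded = np.pad(k_blocks, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)))
k_local = np.concatenate(
    [padded[:, i : i + num_blocks] for i in range(3)], axis=2
)

# Scores are now indexed per block, giving a 5D tensor.
local_scores = np.einsum("bnqhd,bnkhd->bnhqk", q_blocks, k_local)

print(dense_scores.shape)  # (2, 4, 12, 12)
print(local_scores.shape)  # (2, 4, 4, 3, 9)
```

The extra `num_blocks` axis and the `3*block_len` key window have no counterpart in a plain `(batch, heads, seq, dim)` call, so a dense fused kernel cannot consume this layout without a separate reshaping and masking scheme.
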
`LongT5EncoderModel` is purely block-sparse, so the framework refuses to switch its attention implementation:

Forcing `_supports_sdpa = True` makes the encoder-only sdpa tests fail (fp16 `mean relative difference: nan`). Same precedent as `pegasus_x` (block-sparse encoder + regular decoder), which also sets `_supports_sdpa = False`.

Testing

`utils/check_copies.py` passes; full sweep green: