Skip to content

Fix cross-attention cache layer type for T5Gemma2 long inputs#45540

Open
Beichen-Ma wants to merge 1 commit intohuggingface:mainfrom
Beichen-Ma:fix-cross-attention-cache-not-sliding
Open

Fix cross-attention cache layer type for T5Gemma2 long inputs#45540
Beichen-Ma wants to merge 1 commit intohuggingface:mainfrom
Beichen-Ma:fix-cross-attention-cache-not-sliding

Conversation

@Beichen-Ma
Copy link
Copy Markdown

Fixes #45521. Cross-attention in T5Gemma2ForConditionalGeneration is supposed to attend to all encoder tokens, but for inputs whose encoder length is >= sliding_window (default 4096) generation crashes with:

RuntimeError: The size of tensor a (4097) must match the size of tensor b (5018) at non-singleton dimension 3

The root cause was in T5Gemma2ForConditionalGeneration._prepare_cache_for_generation, the cross-attention config was being stripped of its sliding-window settings via del:

cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True))
del cross_attn_config.sliding_window
del cross_attn_config.layer_types

T5Gemma2DecoderConfig with defaults sliding_window: int | None = 4096 and layer_types: list[str] | None = None. Removing the instance attributes therefore makes attribute lookup fall back to those class defaults, so cross_attn_config once again is sliding_window=4096.

DynamicCache.__init__ sees sliding_window=4096 with layer_types=None will auto-derives layer_types = ["sliding_attention"] * num_hidden_layers, and instantiates DynamicSlidingWindowLayer for every cross-attention layer. On update, those layers truncate the encoder K/V states to the last sliding_window-1 tokens:

self.keys   = full_key_states[:,   :, -self.sliding_window + 1 :, :]
self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]

So when enc_len == 4096, the cached cross-attention keys end up with shape [..., 4095, head_dim], which (after concatenation with the decoder self-attention key in T5Gemma2MergedAttention.forward) yields an attn_weights last-dim of 4097. Hence the mismatch.

Fix

Explicitly set sliding_window to null and layer_types to full attention for all layers, instead of deleting the instance attributes.

Tests

  • Added test T5Gemma2ModelTest::test_cross_attention_cache_is_not_sliding, which asserts that after generate() every layer of output.past_key_values.cross_attention_cache is DynamicLayer. Confirmed test fails on main branch and passes on this branch.
  • tests/models/t5gemma2/test_modeling_t5gemma2.py passes.
  • Verified provided end-to-end reproducer passed after the fix.
python /tmp/transformers_bug_repro.py 
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████| 1327/1327 [00:04<00:00, 323.79it/s]

--- target=2500 ---
[transformers] The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
OK  (input=2500, output=17)

--- target=3500 ---
OK  (input=3500, output=17)

--- target=4000 ---
OK  (input=4000, output=17)

--- target=4090 ---
OK  (input=4090, output=17)

--- target=4100 ---
OK  (input=4100, output=17)

--- target=4500 ---
OK  (input=4500, output=17)

--- target=5000 ---
OK  (input=5000, output=17)

--- target=8000 ---
OK  (input=8000, output=17)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: t5gemma2

@Rocketknight1
Copy link
Copy Markdown
Member

cc @vasqu since you were active in the original issue!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

T5Gemma2: decoder self-attention fixed 4097-element mask at batch=1, fails on inputs >4094 tokens

2 participants