
[Qwen3.5] Fix Qwen3.5 linear attention multi-token cached forward#45513

Open
kashif wants to merge 1 commit into huggingface:main from kashif:fix-qwen35-linear-attn-multi-token-cached

Conversation

@kashif
Contributor

@kashif kashif commented Apr 19, 2026

What does this PR do?

The gated-delta-net forward only used the cached recurrent state when `seq_len == 1`. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to `chunk_gated_delta_rule(initial_state=None)`, silently restarting the linear-attention layers from zero and ignoring the prefill state.

This breaks the causal-LM invariant that the logits at position `i` must not depend on whether later tokens are batched into the same call: position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness.
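A minimal sketch of the dispatch bug, not the actual Transformers code: `chunk_rule` is an illustrative scalar stand-in for `chunk_gated_delta_rule` (the real kernel carries a matrix-valued state), and the cache dict is a toy version of the layer cache.

```python
import numpy as np

def chunk_rule(x, initial_state=None):
    # Toy stand-in for chunk_gated_delta_rule: a plain running-sum
    # recurrence (the real kernel updates a matrix-valued state).
    state = 0.0 if initial_state is None else initial_state
    out = []
    for t in x:
        state = state + t
        out.append(state)
    return np.array(out), state

def buggy_forward(x, cache):
    # Mirrors the reported bug: the cached state is only consulted when
    # seq_len == 1; any multi-token call silently restarts from zero.
    if len(x) == 1 and cache["has_previous_state"]:
        return chunk_rule(x, initial_state=cache["recurrent_state"])
    return chunk_rule(x, initial_state=None)  # cache ignored

cache = {"has_previous_state": True, "recurrent_state": 10.0}
single, _ = buggy_forward(np.array([1.0]), cache)      # cached decode
multi, _ = buggy_forward(np.array([1.0, 2.0]), cache)  # cached multi-token
print(single[0], multi[0])  # 11.0 1.0 -- same token, different outputs
```

Even in this toy, the first position of the two-token call diverges from the single-token cached decode on the identical token, which is exactly the invariant violation described above.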

Add a `use_cached_chunk` path that, when `has_previous_state` is true and `seq_len > 1`:

  • reads the cached `conv_state` / `recurrent_state`,
  • prepends the conv context onto the chunk input so the causal conv sees the correct left-context,
  • drops the prepended context from the output,
  • passes the cached `recurrent_state` as `initial_state` to `chunk_gated_delta_rule`.
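The conv-context part of the steps above can be sketched with a 1-D numpy conv; this is an illustration under assumed names (`conv_with_cache`, a scalar channel, kernel size `K = 4`), not the PR's implementation. It checks the key property: prepending the cached `K - 1` inputs and dropping the prepended outputs makes two cached chunks match one uncached pass.

```python
import numpy as np

K = 4  # short causal-conv kernel size; illustrative value

def conv_with_cache(x, conv_state, w):
    # Sketch of the conv part of the cached-chunk path (names are
    # assumptions): prepend the K-1 cached inputs, convolve, keep only
    # the outputs for the new positions, and roll the cache forward.
    full = np.concatenate([conv_state, x])
    out = np.array([full[i:i + K] @ w for i in range(len(x))])
    return out, full[len(full) - (K - 1):]

rng = np.random.default_rng(0)
w = rng.standard_normal(K)
seq = rng.standard_normal(10)

# Reference: one uncached pass over the whole sequence, zero left-pad.
padded = np.concatenate([np.zeros(K - 1), seq])
ref = np.array([padded[i:i + K] @ w for i in range(len(seq))])

# Cached: prefill 6 tokens, then a 4-token multi-token cached chunk.
state = np.zeros(K - 1)
out1, state = conv_with_cache(seq[:6], state, w)
out2, state = conv_with_cache(seq[6:], state, w)
print(np.allclose(np.concatenate([out1, out2]), ref))  # True
```

The same prepend-then-trim idea, applied per channel, is what keeps the causal conv consistent across chunk boundaries; the recurrent half of the fix is just threading the cached state through as `initial_state`.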

The same fix propagates to `qwen3_5_moe` via the modular system.

Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%.
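The shape of that invariance test can be sketched as follows: this is a toy harness under assumed names (`fixed_forward`, a scalar gated recurrence standing in for the real kernel), not the PR's actual test, but it exercises the same property: first-position output of a multi-token cached forward must equal the single-token cached forward on the same token and cache.

```python
import numpy as np

def chunk_rule(x, initial_state=None):
    # Toy gated recurrence standing in for chunk_gated_delta_rule
    # (an assumption; the real kernel carries a matrix-valued state).
    state = 0.0 if initial_state is None else initial_state
    out = []
    for t in x:
        state = 0.5 * state + t  # simple gated update
        out.append(state)
    return np.array(out), state

def fixed_forward(x, cache):
    # After the fix: the cached state is used whenever it exists,
    # regardless of seq_len.
    init = cache["recurrent_state"] if cache["has_previous_state"] else None
    return chunk_rule(x, initial_state=init)

cache = {"has_previous_state": True, "recurrent_state": 3.0}
tokens = np.array([1.0, 2.0, 4.0, 8.0])

single, _ = fixed_forward(tokens[:1], cache)  # 1-token cached decode
multi, _ = fixed_forward(tokens, cache)       # 4-token cached "verify"
print(np.isclose(single[0], multi[0]))  # True -- first positions match
```

With the buggy dispatch (cache consulted only at `seq_len == 1`), the same comparison fails on every element, matching the reported 100% mismatch.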

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_5, qwen3_5_moe

@kashif kashif requested a review from Cyrilvallez April 19, 2026 09:46
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
