Pass packed boundary metadata to Qwen3.5 linear-attention fast kernels from data collator#45034
Pass packed boundary metadata to Qwen3.5 linear-attention fast kernels from data collator#45034sdharani91 wants to merge 8 commits intohuggingface:mainfrom
Conversation
…s for issue 44717
|
[For maintainers] Suggested jobs to run (before merge) run-slow: qwen3_5 |
vasqu
left a comment
There was a problem hiding this comment.
Thanks for sticking with this and sorry about the confusion about the other PR / issues
I added some comments, my main point stands to move even more to the data collator and take a look at bamba maybe which has done something similar with seq_idx for example
| @@ -214,6 +206,9 @@ def forward( | |||
| hidden_states: torch.Tensor, | |||
| cache_params: Qwen3_5DynamicCache | None = None, | |||
| attention_mask: torch.Tensor | None = None, | |||
| seq_idx: torch.IntTensor | None = None, | |||
| cu_seq_lens_q: torch.LongTensor | None = None, | |||
| cu_seq_lens_k: torch.LongTensor | None = None, | |||
There was a problem hiding this comment.
Ok, not on you but it is annoying because they have another different standard on FLA side 😭 They use cu_seqlens, see e.g. https://github.com/fla-org/flash-linear-attention/blob/2e90142c8075af0a0efe4979c22136194a307140/fla/ops/gated_delta_rule/fused_recurrent.py#L298
We had a similar thing in Bamba where we added these for typing under our kwargs, see
just now adjusted for this special case in this weird FLA/conv mixupWe do need to change the datacollator tho to support only returning cu_seqlens instead of the q/k versions
| has_fast_path = self.causal_conv1d_fn is not None and self.chunk_gated_delta_rule.__module__.startswith( | ||
| "fla." | ||
| ) | ||
| if not has_fast_path and any(x is not None for x in (seq_idx, cu_seq_lens_q, cu_seq_lens_k)): | ||
| raise NotImplementedError( | ||
| "Padding-free training kwargs require fast path support. Please install `flash-linear-attention` " | ||
| "and `causal-conv1d`." | ||
| ) |
There was a problem hiding this comment.
| has_fast_path = self.causal_conv1d_fn is not None and self.chunk_gated_delta_rule.__module__.startswith( | |
| "fla." | |
| ) | |
| if not has_fast_path and any(x is not None for x in (seq_idx, cu_seq_lens_q, cu_seq_lens_k)): | |
| raise NotImplementedError( | |
| "Padding-free training kwargs require fast path support. Please install `flash-linear-attention` " | |
| "and `causal-conv1d`." | |
| ) |
we shouldn't have these checks, it should stay a power feature for people who know what they do
| chunk_kwargs = {} | ||
| if getattr(self.chunk_gated_delta_rule, "__module__", "").startswith("fla."): | ||
| chunk_kwargs["cu_seqlens"] = cu_seq_lens_q | ||
|
|
There was a problem hiding this comment.
Mentioned in the first comments, should become part of the collator instead and just pass it directly here then
| seq_idx=kwargs.get("seq_idx"), | ||
| cu_seq_lens_q=kwargs.get("cu_seq_lens_q"), | ||
| cu_seq_lens_k=kwargs.get("cu_seq_lens_k"), |
There was a problem hiding this comment.
Would pass kwargs directly and adjust the signature instead
| @@ -57,6 +59,7 @@ class Qwen3_5TextModelTester(CausalLMModelTester): | |||
|
|
|||
| def __init__(self, parent): | |||
| super().__init__(parent=parent) | |||
| self.hidden_act = "silu" | |||
| if not is_flash_linear_attention_available() or not is_causal_conv1d_available(): | ||
| self.skipTest("Qwen3.5 padding-free fast path requires `flash-linear-attention` and `causal-conv1d`.") |
There was a problem hiding this comment.
Let's make require decorators out of these instead
| def test_padding_free_matches_padded_fast_path_regression(self): | ||
| if not is_flash_linear_attention_available() or not is_causal_conv1d_available(): | ||
| self.skipTest("Qwen3.5 padding-free fast path requires `flash-linear-attention` and `causal-conv1d`.") | ||
| torch.manual_seed(0) |
There was a problem hiding this comment.
| torch.manual_seed(0) |
| config = Qwen3_5TextConfig( | ||
| vocab_size=100, | ||
| hidden_size=64, | ||
| intermediate_size=128, | ||
| num_hidden_layers=2, | ||
| num_attention_heads=4, | ||
| num_key_value_heads=2, | ||
| head_dim=16, | ||
| max_position_embeddings=64, | ||
| hidden_act="silu", | ||
| layer_types=["full_attention", "linear_attention"], | ||
| linear_conv_kernel_dim=2, | ||
| linear_key_head_dim=16, | ||
| linear_value_head_dim=16, | ||
| linear_num_key_heads=2, | ||
| linear_num_value_heads=4, | ||
| pad_token_id=0, | ||
| ) | ||
| model = Qwen3_5ForCausalLM(config).to(torch_device).eval() |
There was a problem hiding this comment.
Can we use the prepare configs function instead and get the text config through that?
| linear_attn = model.model.layers[1].linear_attn | ||
| self.assertIsNotNone(linear_attn.causal_conv1d_fn) | ||
| self.assertTrue(linear_attn.chunk_gated_delta_rule.__module__.startswith("fla.")) | ||
| self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla.")) |
There was a problem hiding this comment.
| linear_attn = model.model.layers[1].linear_attn | |
| self.assertIsNotNone(linear_attn.causal_conv1d_fn) | |
| self.assertTrue(linear_attn.chunk_gated_delta_rule.__module__.startswith("fla.")) | |
| self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla.")) |
Not needed
| self.assertTrue(linear_attn.recurrent_gated_delta_rule.__module__.startswith("fla.")) | ||
|
|
||
| padded_input_ids = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], device=torch_device) | ||
| attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]], dtype=torch.long, device=torch_device) |
There was a problem hiding this comment.
I think we can sparsify this a bit more (i.e. more padding)
What does this PR do?
This is a follow up to #44867
This PR fixes Qwen3.5 padding-free packed inputs on the linear-attention fast path by consuming collator-provided packed metadata. The linear-attention block now uses seq_idx for the causal convolution path and cu_seq_lens_q / cu_seq_lens_k for the FLA path, matching the repo’s existing DataCollatorWithFlattening contract. I also added a deterministic fast-path regression test comparing padded and padding-free inputs, plus a slow-path contract test that raises clearly when padding-free kwargs are passed without fast-kernel support. The slow fallback implementation itself is unchanged in this PR.
Fixes # 44717
Code Agent Policy
The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.
PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.
This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read
CONTRIBUTING.md.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@vasqu