
Fix flash attention crash with 3D position_ids (Qwen3.5) #44911

Closed
ouroborosscr wants to merge 4 commits into huggingface:main from ouroborosscr:fix/qwen35-flash-attn-3d-position-ids

Conversation


@ouroborosscr ouroborosscr commented Mar 21, 2026

Qwen3.5 uses 3D position_ids [3, batch, seq_len] for multi-dimensional rotary embedding. _is_packed_sequence() misinterprets this as a packed sequence, causing cu_seqlens to be constructed with 3x the actual token count. Flash attention then reads beyond tensor boundaries, resulting in CUDA illegal memory access.

Add a dimensionality check to reject >2D position_ids, since packed sequences always use 2D [batch, seq_len] format.

What does this PR do?

Qwen3.5 uses a hybrid architecture (GatedDeltaNet + standard attention) with 3D position_ids of shape [3, batch_size, seq_len] for multi-dimensional rotary embedding. The function _is_packed_sequence() in modeling_flash_attention_utils.py does not handle >2D tensors, causing it to misidentify the input as a packed sequence. This leads to cu_seqlens being constructed with 3× the actual token count, and flash_attn_varlen_func reads beyond tensor boundaries, resulting in CUDA error: illegal memory access.

The fix: Add if position_ids.dim() > 2: return False at the top of _is_packed_sequence(), since packed sequences always use 2D [batch, seq_len] position_ids.

Intercepted evidence before crash:

q: torch.Size([256, 16, 256])           ← 256 tokens
cu_seqlens_q: tensor([0, 256, 512, 768]) ← claims 768 tokens (3×256)
q total=256 vs cu_seqlens_q[-1]=768      ← MISMATCH → illegal memory access
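The arithmetic of the mismatch can be reproduced in isolation. This is a minimal sketch of what the packed-sequence path effectively computes, assuming sequence boundaries are detected at position-id resets:

```python
import torch

seq_len = 256
# Qwen3.5-style mrope ids: three planes of 0..255, shape [3, 1, 256].
position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, 1, seq_len)

# Misinterpreted as packed: flatten all planes and treat each reset to 0
# as the start of a new sequence.
flat = position_ids.reshape(-1)                      # 768 entries, not 256
starts = torch.nonzero(flat == 0).flatten()
cu_seqlens = torch.cat([starts, torch.tensor([flat.numel()])])

print(cu_seqlens.tolist())     # [0, 256, 512, 768]
print(seq_len)                 # q actually holds only 256 tokens
```

flash_attn_varlen_func then indexes up to cu_seqlens[-1] = 768 rows into a 256-row q buffer, hence the illegal memory access.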

Fixes #44910


Who can review?

@vasqu @ArthurZucker @Cyrilvallez (attention)

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Contributor

JJJYmmm commented Mar 22, 2026

Hi, would you mind checking out this comment: https://github.com/QwenLM/Qwen3.5/issues/104#issuecomment-4105702644 ? (btw, I'm not sure if multimodal sequence packing is supported yet. If it isn't, then the current fix is fine.)

@ouroborosscr
Author

@JJJYmmm's suggested approach has been applied to this pull request.

Collaborator

@ArthurZucker ArthurZucker left a comment


Ty both! Do you mind adding a direct ref to the issue in the packed sequence doc (we want to make sure vendored / authors are credited !) 🤗

Member

@zucchini-nlp zucchini-nlp left a comment


Nope, this should not be the way. Please see the comment under the issue; we already pass the text positions along without the THW planes.

@ouroborosscr
Author

Done! Added a direct reference to both issues and credited contributors in the docstring. Thanks for the reminder 🤗 @ArthurZucker

Contributor

JJJYmmm commented Mar 23, 2026

Nope, this should not be the way. Please see the comment under the issue; we already pass the text positions along without the THW planes.

@zucchini-nlp Oops, I overlooked that text_position_ids would be set to None when position_ids.shape[0] == 3. Previously, I thought this was just a corner case where users skipped preparing 4D position_ids when calling the model directly. Sorry for the confusion. 🥹
Could you test this bug with the latest commit? @ouroborosscr
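For context, the dispatch being discussed can be sketched like this. The function and variable names here are ours and purely illustrative of the behavior described above, not the upstream API:

```python
import torch

def split_positions(position_ids: torch.Tensor):
    """Return (text_position_ids, mrope_position_ids) for a Qwen-style model."""
    if position_ids.dim() == 3 and position_ids.shape[0] == 4:
        # A 4th plane carries plain 1D text positions; those are safe to
        # hand to the flash-attention packing check.
        return position_ids[3], position_ids[:3]
    if position_ids.dim() == 3 and position_ids.shape[0] == 3:
        # Only the THW planes were prepared: there are no text positions
        # to forward, so text_position_ids ends up as None.
        return None, position_ids
    return position_ids, None  # already plain [batch, seq_len]
```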

@ouroborosscr
Author

@zucchini-nlp @JJJYmmm My testing is complete. With transformers==5.3.0.dev0, there is no longer an illegal memory access.

Now:

TORCH_USE_CUDA_DSA=1 CUDA_LAUNCH_BLOCKING=1 \
CUDA_HOME=/usr/local/cuda-12.8 \
LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 \
CUDA_VISIBLE_DEVICES=1 python intercept_transformers_flash.py
✅ transformers flash attention interceptor installed

Loading model...
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 427/427 [00:04<00:00, 95.91it/s]

======================================================================
Running model forward (seq=256)...
======================================================================

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 3
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-10.7500, 9.7500], nan=False, inf=False
  key: range=[-12.8125, 11.7500], nan=False, inf=False
  value: range=[-14.1250, 8.5000], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 7
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-14.8125, 11.5000], nan=False, inf=False
  key: range=[-14.2500, 11.1875], nan=False, inf=False
  value: range=[-10.5000, 18.1250], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 11
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-11.3750, 12.5625], nan=False, inf=False
  key: range=[-10.3750, 12.4375], nan=False, inf=False
  value: range=[-8.2500, 10.7500], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 15
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-9.1250, 9.4375], nan=False, inf=False
  key: range=[-10.3125, 10.3750], nan=False, inf=False
  value: range=[-8.6250, 7.5000], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 19
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-10.6875, 8.6250], nan=False, inf=False
  key: range=[-11.9375, 10.1250], nan=False, inf=False
  value: range=[-17.0000, 20.6250], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 23
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-8.2500, 10.8750], nan=False, inf=False
  key: range=[-13.2500, 10.0000], nan=False, inf=False
  value: range=[-28.8750, 26.0000], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 27
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-8.6875, 8.8750], nan=False, inf=False
  key: range=[-8.9375, 9.8750], nan=False, inf=False
  value: range=[-94.5000, 57.5000], nan=False, inf=False

============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 31
  kwarg position_ids: shape=torch.Size([1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-11.1875, 10.8750], nan=False, inf=False
  key: range=[-8.0000, 9.3750], nan=False, inf=False
  value: range=[-43.0000, 56.5000], nan=False, inf=False

✅ Success!

Done

Past:

TORCH_USE_CUDA_DSA=1 CUDA_LAUNCH_BLOCKING=1 \
CUDA_HOME=/usr/local/cuda-12.8 \
LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 \
CUDA_VISIBLE_DEVICES=2 python intercept_transformers_flash.py
✅ transformers flash attention interceptor installed
Loading model...
torch_dtype is deprecated! Use dtype instead!
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 427/427 [00:08<00:00, 50.40it/s]
======================================================================
Running model forward (seq=256)...
======================================================================
============================================================
📌 transformers _flash_attention_forward called!
============================================================
  arg[0]: shape=torch.Size([1, 256, 16, 256]), dtype=torch.bfloat16
  arg[1]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[2]: shape=torch.Size([1, 256, 4, 256]), dtype=torch.bfloat16
  arg[3]: NoneType = None
  kwarg query_length: 256
  kwarg is_causal: True
  kwarg dropout: 0.0
  kwarg softmax_scale: 0.0625
  kwarg sliding_window: None
  kwarg softcap: None
  kwarg use_top_left_mask: False
  kwarg target_dtype: None
  kwarg attn_implementation: flash_attention_2
  kwarg layer_idx: 3
  kwarg position_ids: shape=torch.Size([3, 1, 256]), dtype=torch.int64, nan=False
  kwarg use_cache: False
  query: range=[-11.4375, 11.1875], nan=False, inf=False
  key: range=[-12.9375, 12.1250], nan=False, inf=False
  value: range=[-9.2500, 9.3750], nan=False, inf=False
❌ Crash: CUDA error: an illegal memory access was encountered
Done

@zucchini-nlp
Member

Nice @ouroborosscr, then we can close the PR as there is no bug. The next release will include the working FA for qwen3 :)


Development

Successfully merging this pull request may close these issues.

[Bug] Flash Attention crashes with illegal memory access on Qwen3.5 due to 3D position_ids being misinterpreted as packed sequence
