qwen image: per-sample split attention for batched training#2648

Merged
bghira merged 1 commit into bghira:main from rafstahelin:pr/qwen-split-attention on Mar 20, 2026

Conversation

@rafstahelin (Contributor) commented Mar 19, 2026

Summary

When batch_size > 1, QwenImage batched training produces corrupted outputs (desaturated mannequins by step ~50). Root cause: the joint-mask approach broadcasts a single boolean mask across all samples, so padding tokens from shorter sequences contaminate the attention scores of longer ones.
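
The contamination is easy to reproduce in isolation. The sketch below (hypothetical shapes, not SimpleTuner code) compares a single broadcast mask against proper per-sample masks under PyTorch SDPA: the fully valid sample is unaffected, while the padded sample's output changes because its padding keys leak into the softmax.

```python
import torch
import torch.nn.functional as F

# Two samples: text lengths 5 and 3, padded to max_seq_len = 5.
# Hypothetical shapes: (batch, heads, seq, dim).
torch.manual_seed(0)
q = torch.randn(2, 1, 5, 8)
k = torch.randn(2, 1, 5, 8)
v = torch.randn(2, 1, 5, 8)

# Correct per-sample masks: sample 0 is full length, sample 1 has 2 pad tokens.
masks = torch.tensor([[1, 1, 1, 1, 1],
                      [1, 1, 1, 0, 0]], dtype=torch.bool)

# Broadcasting one mask (here sample 0's, all ones) across the whole batch
# lets sample 1's padding keys take part in attention.
shared = masks[0].view(1, 1, 1, 5)        # one mask for every sample
out_shared = F.scaled_dot_product_attention(q, k, v, attn_mask=shared)

per_sample = masks.view(2, 1, 1, 5)       # each sample keeps its own mask
out_correct = F.scaled_dot_product_attention(q, k, v, attn_mask=per_sample)

# Sample 0 agrees either way; sample 1 differs because pads contaminated it.
print(torch.allclose(out_shared[0], out_correct[0]))  # True
print(torch.allclose(out_shared[1], out_correct[1]))  # False
```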

Three coordinated changes:

  1. model.py (collate): pad embeds + masks to max_seq_len before torch.cat (the crash fix alone is also submitted as the independent PR #2647, "qwen image: fix collate padding for batch_size > 1")
  2. transformer.py (forward + processor): when batch_size > 1 and a mask is present, run per-sample attention — slice each sample's text Q/K/V to its true length, run unmasked SDPA per-sample, then re-stack. For batch_size == 1, the existing efficient joint-mask path is unchanged.
  3. pipeline.py: preserve all-ones masks instead of discarding them. The all-ones optimization (if mask.all(): mask = None) breaks the routing signal for split attention.
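
The per-sample path of change 2 can be sketched as follows (hypothetical helper and shapes, not the actual transformer.py code): assuming joint sequences laid out as [text; image] with the text block padded to a common length, each sample's real text tokens are sliced out together with all image tokens, attention runs unmasked, and the result is re-stacked into the padded layout.

```python
import torch
import torch.nn.functional as F

def split_attention(q, k, v, text_mask, text_len):
    """Per-sample split attention (sketch only, hypothetical helper).

    q/k/v: (batch, heads, text_len + img_len, dim) joint sequences with the
    text block first; text_mask: (batch, text_len) bool, True = real token.
    Assumes padding forms a contiguous suffix of each text block.
    """
    batch = q.shape[0]
    outputs = []
    for i in range(batch):
        n = int(text_mask[i].sum())             # this sample's true text length
        keep = torch.cat([
            torch.arange(n),                    # real text tokens only
            torch.arange(text_len, q.shape[2])  # all image tokens
        ])
        # Unmasked SDPA over just the valid tokens -- no broadcast mask needed.
        out = F.scaled_dot_product_attention(
            q[i : i + 1, :, keep], k[i : i + 1, :, keep], v[i : i + 1, :, keep]
        )
        # Re-stack into the padded layout; padded text rows stay zero.
        full = torch.zeros_like(q[i : i + 1])
        full[:, :, keep] = out
        outputs.append(full)
    return torch.cat(outputs, dim=0)
```

For batch_size == 1 the joint-mask path is kept, since a single sample has no cross-sample padding to leak and the fused masked SDPA is more efficient.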

The approach follows musubi-tuner PR #688 (merged), which solved the same variable-length text-sequence problem.
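
Change 3 amounts to a guard around the usual all-ones optimization. A minimal sketch (hypothetical helper, not the actual pipeline.py code; the batch_size == 1 carve-out is an assumption, since a single unpadded sample needs no mask on the joint path):

```python
import torch

def prepare_mask(mask, batch_size):
    """Keep the mask when it is the routing signal for split attention.

    The common optimization drops an all-ones mask:
        if mask is not None and mask.all():
            mask = None   # "no padding, so no mask needed"
    But the processor routes on the mask's *presence*: batch_size > 1 plus a
    mask selects the per-sample split path, so discarding an all-ones mask
    silently falls back to joint attention for batched samples.
    """
    if mask is not None and mask.all() and batch_size == 1:
        return None   # safe: single sample, joint unmasked path is exact
    return mask
```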

Test plan

  • batch_size=1 training unchanged (joint-mask path, no regression)
  • batch_size=4 with diffusers attention: 300+ steps clean, no corruption
  • Baseline comparison: without fix, corruption appears at step ~50
  • Prodigy optimizer, RTX PRO 6000 Blackwell (98GB)
  • Running to step 800 for full validation (currently at 300, clean)

qwen image: per-sample split attention for batched training

When batch_size > 1, QwenImage text sequences have different lengths.
After padding (collate fix), the joint-mask approach still fails because
it broadcasts one mask for all samples — padding tokens from shorter
sequences contaminate attention scores of longer ones, producing
corrupted outputs (desaturated mannequins by step ~50).

This commit introduces three coordinated changes:

1. model.py (collate): pad embeds + masks to max_seq_len before torch.cat
2. transformer.py (forward + processor): when batch_size > 1 and a mask
   is present, run per-sample attention — slice each sample's text Q/K/V
   to its true length, run unmasked attention per-sample, then re-stack.
   For batch_size == 1, the existing efficient joint-mask path is unchanged.
3. pipeline.py: preserve all-ones masks instead of discarding them. The
   mask is the routing signal for split attention; discarding it forces
   fallback to joint attention even when batch > 1.

The approach follows musubi-tuner PR #688 (merged), which solved the same
variable-length text sequence problem for a different model.

Tested: batch_size=4, diffusers attention, Prodigy optimizer, 300+ steps
clean with no corruption (baseline corrupted at step ~50 without fix).
@rafstahelin rafstahelin marked this pull request as ready for review March 19, 2026 19:34
@bghira bghira merged commit 2b6cdd5 into bghira:main Mar 20, 2026
2 checks passed
