qwen image: per-sample split attention for batched training #2648
Merged: bghira merged 1 commit into bghira:main · Mar 20, 2026
Conversation
qwen image: per-sample split attention for batched training

When batch_size > 1, QwenImage text sequences have different lengths. After padding (collate fix), the joint-mask approach still fails because it broadcasts one mask for all samples — padding tokens from shorter sequences contaminate attention scores of longer ones, producing corrupted outputs (desaturated mannequins by step ~50).

This commit introduces three coordinated changes:

1. model.py (collate): pad embeds + masks to max_seq_len before torch.cat.
2. transformer.py (forward + processor): when batch_size > 1 and a mask is present, run per-sample attention — slice each sample's text Q/K/V to its true length, run unmasked attention per sample, then re-stack. For batch_size == 1, the existing efficient joint-mask path is unchanged.
3. pipeline.py: preserve all-ones masks instead of discarding them. The mask is the routing signal for split attention; discarding it forces fallback to joint attention even when batch > 1.

Approach follows musubi-tuner PR bghira#688 (merged), which solved the same variable-length text sequence problem for a different model.

Tested: batch_size=4, diffusers attention, Prodigy optimizer, 300+ steps clean with no corruption (baseline corrupted at step ~50 without fix).
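The collate fix in step 1 can be sketched as follows. This is an illustrative helper, not the PR's actual code; the function name and argument layout are assumptions, but it shows the idea of padding per-sample embeds and masks to the batch max before stacking:

```python
import torch
import torch.nn.functional as F


def collate_text_embeds(embeds, masks):
    """Pad variable-length text embeddings and masks to the batch max,
    then stack (hypothetical helper; names are illustrative).

    embeds: list of [seq_len_i, dim] tensors
    masks:  list of [seq_len_i] bool tensors (True = real token)
    """
    max_len = max(e.shape[0] for e in embeds)
    padded_embeds, padded_masks = [], []
    for e, m in zip(embeds, masks):
        pad = max_len - e.shape[0]
        # F.pad pads the last dim first: (dim_left, dim_right, seq_left, seq_right)
        padded_embeds.append(F.pad(e, (0, 0, 0, pad)))
        padded_masks.append(F.pad(m, (0, pad), value=False))
    return torch.stack(padded_embeds), torch.stack(padded_masks)
```

Without this padding, torch.cat/torch.stack on unequal sequence lengths simply crashes; the mask records each sample's true length so later stages can undo the padding.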
Summary
When batch_size > 1, QwenImage batched training produces corrupted outputs (desaturated mannequins by step ~50). Root cause: the joint-mask approach broadcasts one bool mask for all samples, so padding tokens from shorter sequences contaminate attention scores of longer ones.

Three coordinated changes:
1. model.py (collate): pad embeds + masks to `max_seq_len` before `torch.cat` (also submitted as independent PR #2647, "qwen image: fix collate padding for batch_size > 1", for the crash fix alone).
2. transformer.py: when `batch_size > 1` and a mask is present, run per-sample attention — slice each sample's text Q/K/V to its true length, run unmasked SDPA per sample, then re-stack. For `batch_size == 1`, the existing efficient joint-mask path is unchanged.
3. pipeline.py: preserve all-ones masks instead of discarding them; discarding (`if mask.all(): mask = None`) breaks the routing signal for split attention.

Approach follows musubi-tuner PR #688 (merged), which solved the same variable-length text sequence problem.
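The per-sample split attention in step 2 can be sketched like this. It is a minimal illustration, not the PR's processor code; tensor layouts and names are assumptions, and it relies on the collate step putting all padding at the end of each text sequence:

```python
import torch
import torch.nn.functional as F


def per_sample_split_attention(img_q, img_k, img_v, txt_q, txt_k, txt_v, txt_mask):
    """Per-sample joint attention over padded text (illustrative sketch).

    img_*: [B, H, L_img, D] image-stream Q/K/V
    txt_*: [B, H, L_txt, D] padded text-stream Q/K/V
    txt_mask: [B, L_txt] bool, True for real tokens (padding at the end)
    Returns [B, H, L_txt + L_img, D]; padded text positions stay zero.
    """
    B, H, L_txt, D = txt_q.shape
    L_img = img_q.shape[2]
    out = torch.zeros(B, H, L_txt + L_img, D, dtype=img_q.dtype, device=img_q.device)
    for b in range(B):
        n = int(txt_mask[b].sum())  # this sample's true text length
        # Joint sequence for this sample: [real text tokens, image tokens]
        q = torch.cat([txt_q[b, :, :n], img_q[b]], dim=1).unsqueeze(0)
        k = torch.cat([txt_k[b, :, :n], img_k[b]], dim=1).unsqueeze(0)
        v = torch.cat([txt_v[b, :, :n], img_v[b]], dim=1).unsqueeze(0)
        o = F.scaled_dot_product_attention(q, k, v)[0]  # unmasked SDPA
        out[b, :, :n] = o[:, :n]       # re-stack text part; padding stays zero
        out[b, :, L_txt:] = o[:, n:]   # re-stack image part
    return out
```

Because each sample attends only over its own real tokens, no broadcast mask is needed and shorter samples' padding can never leak into longer samples' scores, at the cost of a Python-level loop over the batch.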
Test plan
- `batch_size=1` training unchanged (joint-mask path, no regression)
- `batch_size=4` with diffusers attention: 300+ steps clean, no corruption
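The routing between the two tested paths can be sketched as follows (function and argument names are hypothetical; the real decision lives in the attention processor):

```python
def choose_attention_path(batch_size: int, txt_mask) -> str:
    """Pick the attention path: per-sample split when a mask is present
    and the batch is larger than one, otherwise the joint-mask path."""
    if batch_size > 1 and txt_mask is not None:
        return "per_sample_split"  # variable-length-safe path
    return "joint_mask"            # efficient single-sample path
```

This is why pipeline.py must preserve all-ones masks: dropping them makes `txt_mask` None and silently forces the joint-mask path even for batches that need the split path.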