
fix(wan): apply patchify flatten in forward/consumers (fixes #1063) #1435

Open
genno-whittlery wants to merge 2 commits into modelscope:main from genno-whittlery:fix-wan-patchify

Conversation

@genno-whittlery

Summary

Fixes #1063.

WanModel.forward (and the parallel usp_dit_forward in diffsynth/utils/xfuser/xdit_context_parallel.py) both unpack patchify's return value:

x, (f, h, w) = self.patchify(x)

But patchify only returns x — and that x is a 5-D (B, dim, f, h, w) tensor straight out of Conv3d, not the (B, f*h*w, dim) token sequence the transformer blocks expect.

Net result: every Wan training forward call dies immediately with:

ValueError: not enough values to unpack (expected 2, got 1)

The block loop never executes.
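
To make the mismatch concrete, here is a minimal illustration (the shapes are taken from the validation run below; the values are arbitrary):

import torch

# What patchify currently returns: the raw Conv3d output.
x = torch.randn(1, 512, 4, 4, 4)       # (B, dim, f, h, w)

# forward's unpacking iterates the tensor along dim 0 (size 1), so
# `x, (f, h, w) = self.patchify(x)` raises:
# ValueError: not enough values to unpack (expected 2, got 1)

# What the transformer blocks actually need:
tokens = x.flatten(2).transpose(1, 2)  # (B, f*h*w, dim) == (1, 64, 512)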

The fix

Inside patchify, after the Conv3d:

  1. Read (f, h, w) from the Conv3d output shape (last three dims).
  2. Flatten the 3-D spatial-temporal grid into a token sequence: (B, dim, f, h, w) → (B, f*h*w, dim).
  3. Return (x, (f, h, w)) to match what forward and unpatchify expect.

7 lines changed, including the comment.
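
A minimal sketch of the changed method (it mirrors the diff hunk quoted in the review below; the surrounding class and the Conv3d setup are elided):

def patchify(self, x: torch.Tensor):
    # Conv3d patch embedding: (B, C, T, H, W) -> (B, dim, f, h, w)
    x = self.patch_embedding(x)
    # 1. Grid size from the last three dims of the Conv3d output.
    f, h, w = x.shape[-3], x.shape[-2], x.shape[-1]
    # 2. Flatten the spatial-temporal grid into a token sequence:
    #    (B, dim, f, h, w) -> (B, f*h*w, dim)
    x = x.flatten(2).transpose(1, 2)
    # 3. Return the tuple that forward and unpatchify expect.
    return x, (f, h, w)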

Why it surfaces now

diffsynth/models/wan_video_dit.py is recent in the repo's history (added with #1427). It looks like forward and patchify got refactored from another file but didn't land in sync — forward was updated to expect the tuple while patchify still returns the un-flattened Conv3d output.

Validation

Smoke-tested locally against a synthetic 8-layer WanModel on 2× RTX 5090:

  • Input: (1, 16, 4, 8, 8) (B, C, T, H, W)
  • After patch_embedding with patch_size=(1, 2, 2): (1, 512, 4, 4, 4) → flattened to (1, 64, 512)
  • Forward through all 8 blocks: (1, 64, 512)
  • Head + unpatchify: back to (1, 16, 4, 8, 8)
  • Backward: loss gradient propagates through every block, LoRA params receive non-zero grads

This is the same path the example training scripts (examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh etc.) follow. Also verified that the existing unpatchify consumes the (B, f*h*w, dim_after_head) shape correctly via its rearrange('b (f h w) (x y z c) -> b c (f x) (h y) (w z)', ...) pattern.
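
A self-contained check of that round-trip (only the shapes and the rearrange pattern come from the run above; the mini-model itself is elided):

import torch
from einops import rearrange

B, C, T, H, W = 1, 16, 4, 8, 8
px, py, pz = 1, 2, 2                   # patch_size=(1, 2, 2)
f, h, w = T // px, H // py, W // pz    # token grid: 4 x 4 x 4

# Token sequence after the head: (B, f*h*w, dim_after_head)
tokens = torch.randn(B, f * h * w, px * py * pz * C)

# unpatchify's rearrange pattern, as quoted above:
video = rearrange(
    tokens, "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
    f=f, h=h, w=w, x=px, y=py, z=pz,
)
assert video.shape == (B, C, T, H, W)  # back to (1, 16, 4, 8, 8)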

Test plan

  • Synthetic mini-WanModel forward + backward — completes with correct output shape, gradients propagate
  • unpatchify consumes the new shape correctly (round-trip preserves dims)
  • Real Wan2.2-I2V-A14B training step — pending; want a maintainer's preferred reproduction recipe first since the weights are ~28 GB to download

Once this lands, a follow-up PR adds an env-var-gated dual-GPU model-parallel path (WAN_DUAL_GPU=true) so users on 2× 24+ GB consumer cards can train Wan 2.2 14B LoRA without OOMing on video activations. That follow-up is structurally similar to the FLUX.2 PR #1434 from this account.

Fixes modelscope#1063.

WanModel.forward (and the parallel usp_dit_forward in
diffsynth/utils/xfuser/xdit_context_parallel.py) both unpack
patchify's return value as `x, (f, h, w) = self.patchify(x)`, but
patchify only returns `x` -- and that `x` is a 5D (B, dim, f, h, w)
tensor straight out of Conv3d, not the (B, f*h*w, dim) sequence the
transformer blocks expect.

Net result: every Wan training run crashes at the first forward call
with `ValueError: not enough values to unpack (expected 2, got 1)`.
The block loop never executes.

The fix: extract (f, h, w) from the Conv3d output shape, flatten the
3D spatial-temporal grid into a token sequence (B, f*h*w, dim) ready
for the transformer blocks, return (x, (f, h, w)) to match what
forward and unpatchify expect.

Verified end-to-end with a synthetic 8-layer WanModel on the same
forward path (input shape (1, 16, 4, 8, 8) -> output shape
(1, 16, 4, 8, 8) round-trip), forward + backward both complete,
loss gradient propagates through all blocks.

@gemini-code-assist (bot) left a comment


Code Review

This pull request modifies the patchify method in wan_video_dit.py to return a tuple containing the flattened tensor and its spatial dimensions, aligning it with the transformer blocks' expectations. However, this review identifies critical regressions: existing callers of the method were not updated, which will lead to an AttributeError and to shape mismatches from double embedding. A code suggestion is included to refine the dimension slicing.

Comment thread: diffsynth/models/wan_video_dit.py (outdated), on lines +505 to +507
f, h, w = x.shape[-3], x.shape[-2], x.shape[-1]
x = x.flatten(2).transpose(1, 2)
return x, (f, h, w)

critical

While this change correctly aligns patchify with the expectations of WanModel.forward, it introduces critical regressions in other parts of the codebase that were not updated in this PR:

  1. model_fn_wan_video (lines 1411-1414): This function still assigns the return value of patchify to a single variable x. Consequently, line 1421 (f, h, w = x.shape[2:]) will raise an AttributeError because x is now a tuple. Furthermore, line 1422 attempts to flatten a tensor that is already flattened and transposed by patchify.
  2. model_fn_wans2v (line 1661): This function passes an already embedded tensor (dit.patch_embedding(x) + ...) to patchify. Since patchify now always applies self.patch_embedding (lines 494/496), this results in a double embedding, which will cause a shape mismatch or incorrect spatial dimensions due to the stride in the Conv3d layer.

Please update these callers to handle the new return signature and avoid redundant embedding steps.

Suggested change
-f, h, w = x.shape[-3], x.shape[-2], x.shape[-1]
+f, h, w = x.shape[-3:]
 x = x.flatten(2).transpose(1, 2)
 return x, (f, h, w)

@xiuyuan18

patchify's return type should not be changed. f, h, w can be retrieved from the shape of the returned variable, which avoids changing the signature.

@xiuyuan18

The official implementation of Wan 2.1 also flattens x in forward, not in patchify. It is better to modify the code in forward instead of patchify.

Per @xiuyuan18 review: the bug should be fixed in forward (and the
other tuple-unpacking consumers), not by changing patchify's return
signature. This matches the official Wan 2.1 implementation, which
also flattens x in forward, not in patchify.

Reverts patchify back to returning just `x` (post-Conv3d tensor shape
(B, dim, f, h, w)). Adds the flatten + (f, h, w) extraction at each
tuple-unpacking consumer:

  - wan_video_dit.py: WanModel.forward
  - xdit_context_parallel.py: usp_dit_forward
  - wan_video.py: model_fn_wans2v (two sites: main x and ref_latents)

WanToDance callers in wan_video.py (lines 1412/1414) continue to work
unchanged because they always wanted the raw post-Conv3d tensor and
do their own flattening downstream.

Pre-existing S2V double-embedding pattern at model_fn_wans2v (caller
pre-runs patch_embedding, then patchify re-runs patch_embedding
internally) is preserved as-is — out of scope for this PR.

Refs modelscope#1063
@genno-whittlery
Author

Thanks for the detailed feedback — you're right on both counts. Pushed a v2 that:

  1. Reverts patchify to its original signature (return x, raw post-Conv3d tensor) — matches the official Wan 2.1 implementation as you noted.
  2. Moves the flatten + (f, h, w) extraction into the consumers that need it:
    • WanModel.forward (the original broken caller)
    • usp_dit_forward in xdit_context_parallel.py
    • model_fn_wans2v in wan_video.py (two sites: main x and ref_latents)
  3. Leaves WanToDance callers at wan_video.py:1412,1414 untouched — they always wanted the raw post-Conv3d tensor and do their own flatten downstream. (This was also the breakage @gemini-code-assist flagged in v1.)

The pre-existing S2V double-embedding pattern at model_fn_wans2v (caller pre-runs patch_embedding then patchify re-runs it) is preserved as-is — looks like a separate latent bug but out of scope for this fix.

Diff: 3 files, +15/-13 lines vs main. PTAL.
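
For reference, the consumer-side pattern each of those sites now uses (a minimal sketch; the exact variable names may differ from the diff):

# patchify keeps its original signature: raw Conv3d output.
x = self.patchify(x)                  # (B, dim, f, h, w)
f, h, w = x.shape[2], x.shape[3], x.shape[4]
x = x.flatten(2).transpose(1, 2)      # (B, f*h*w, dim) token sequence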

genno-whittlery changed the title from "fix(wan): WanModel.patchify flatten + return grid size tuple (fixes #1063)" to "fix(wan): apply patchify flatten in forward/consumers (fixes #1063)" on May 11, 2026
genno-whittlery added a commit to genno-whittlery/DiffSynth-Studio that referenced this pull request May 11, 2026
Off by default -- no behavior change unless WAN_DUAL_GPU=true is set.

Wan 2.2 14B variants (I2V-A14B, T2V-A14B, S2V-14B, etc.) are ~28 GB
in bf16 -- the weights fit on one 32 GB consumer card with fp8 quant,
but video training activations at 480x832x49 frames + gradient
checkpointing routinely push the actual step over 32 GB even on a
14B model. Splitting the transformer blocks across two GPUs gives
training-step headroom that single-GPU users can't otherwise reach
without dropping resolution or frame count.

What changed:

- examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py (new):
  ~150 LOC helper. Splits WanModel.blocks at the midpoint across
  cuda:0/cuda:1. Registers forward_pre_hook on every cuda:1 block
  (not just the boundary -- Wan's forward passes loop-level constants
  context / t_mod / freqs positionally to each iteration, so a
  boundary-only hook would leave subsequent blocks receiving cuda:0
  tensors). Bridges activations back to cuda:0 at the head module.
  Also explicitly moves WanModel.freqs (a tuple of plain CPU tensors,
  not registered buffers) so .to(device) doesn't miss them.

- examples/wanvideo/model_training/train.py: forces CPU model load
  when WAN_DUAL_GPU=true (so the bf16 transformer doesn't
  pre-allocate on cuda:0 before split), runs torchao
  Float8WeightOnlyConfig quantize_ with the same LoRA-skip filter
  used by the FLUX.2 port (skips lora_A/lora_B Linear submodules --
  otherwise their requires_grad is stripped and backward fails),
  then calls enable_wan_dual_gpu(model.pipe.dit) after PEFT has
  injected LoRA so block.to(device) carries LoRA params with their
  base layers. Also sets FLUX2_DUAL_GPU=true after distribute so the
  existing runner.py branch from PR modelscope#1434 catches the
  device_placement=[False, ...] case in accelerator.prepare without
  needing a parallel WAN_DUAL_GPU branch there.

Depends on modelscope#1435 (patchify fix). The current main has a broken
WanModel.patchify that returns the wrong shape and arity; Wan
training fails immediately at the first forward call regardless of
dual-GPU. Once modelscope#1435 lands, both single-GPU and dual-GPU Wan training
paths work.

Validated locally on 2x RTX 5090 with a synthetic 8-layer WanModel
(same architecture shape as real Wan 2.2, miniaturized to fit a
quick smoke test): forward + backward complete across the cross-
device split, output round-trips to the original (B, C, T, H, W)
shape, LoRA gradients land on both cuda:0 and cuda:1 (proving cross-
device autograd).

Same patch shape as the validated FLUX.2 port in PR modelscope#1434 from this
account. Both share the runner.py model-parallel branch.
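
A minimal sketch of the hook mechanics that message describes (enable_wan_dual_gpu is the helper name from the message; its real body, and the blocks/head/freqs attribute access shown here, are assumptions):

import torch

def enable_wan_dual_gpu(dit, dev0="cuda:0", dev1="cuda:1"):
    """Sketch: split dit.blocks at the midpoint across two GPUs."""
    mid = len(dit.blocks) // 2
    for i, block in enumerate(dit.blocks):
        block.to(dev0 if i < mid else dev1)

    def to_dev1(module, args):
        # Hook every cuda:1 block, not just the boundary: Wan's forward
        # re-passes loop-level constants (context, t_mod, freqs)
        # positionally on each iteration, and those stay on cuda:0.
        return tuple(a.to(dev1) if torch.is_tensor(a) else a for a in args)

    for block in list(dit.blocks)[mid:]:
        block.register_forward_pre_hook(to_dev1)

    # Bridge activations back to dev0 at the head module.
    dit.head.register_forward_pre_hook(
        lambda m, args: tuple(
            a.to(dev0) if torch.is_tensor(a) else a for a in args
        )
    )

    # freqs is a tuple of plain CPU tensors (not registered buffers),
    # so module.to(device) misses it; move it explicitly.
    dit.freqs = tuple(t.to(dev0) for t in dit.freqs)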

@xiuyuan18 left a comment


I've reviewed the changes to these files. They correctly resolve the embedding flattening and shape unpacking issues. Additionally, the original patchify function signature remains unchanged, so there are no compatibility concerns with existing code.

It looks good to merge.

@genno-whittlery
Author

Thanks for the review @xiuyuan18 — appreciate the quick turnaround.

Standing by for the merge; happy to handle any follow-ups (rebase, additional tests) if useful. Once this lands, I'll rebase #1436 (the dual-GPU helper for Wan video training, which depends on this fix) and we'll have the end-to-end Wan dual-GPU path ready for review there.

Fixes: the patchify function in wan_video_dit.py is incorrectly implemented (#1063)
