[Bug] WanVideoToVideoPipeline fails due to missing handling of control arguments required by the underlying VACE transformer #12574

@Mathias5

Description

Describe the bug

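Problem:
WanVideoToVideoPipeline.__call__ does not expose control_hidden_states or control_hidden_states_scale in its signature.
However, the underlying VACE transformer (WanVACETransformer3DModel.forward() in transformer_wan_vace.py) assumes control_hidden_states is non-None: both arguments default to None, but when control_hidden_states_scale is None it calls control_hidden_states.new_ones(...) to build a default scale.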

When the pipeline is loaded from a VACE checkpoint, this makes every call crash:

pipe(video=my_frames, prompt="studio lighting")
# -> AttributeError: 'NoneType' object has no attribute 'new_ones'

Passing the tensors explicitly does not help either, because the pipeline does not accept those keyword arguments:

pipe(video=my_frames, prompt="...", control_hidden_states=...)
# -> TypeError: __call__() got an unexpected keyword argument 'control_hidden_states'
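Before passing extra keyword arguments to a pipeline, you can check whether its __call__ accepts them using the standard-library inspect.signature. The sketch below is a minimal illustration; accepts_kwarg and DummyPipeline are hypothetical names, and on a real pipeline you would pass pipe.__call__ instead of the dummy.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if `fn` accepts `name` as a keyword, explicitly or via **kwargs."""
    params = inspect.signature(fn).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all would also accept the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class DummyPipeline:
    """Hypothetical stand-in whose __call__ lacks the control arguments."""

    def __call__(self, video=None, prompt=None, num_inference_steps=50):
        return video


pipe = DummyPipeline()
print(accepts_kwarg(pipe.__call__, "prompt"))                 # True
print(accepts_kwarg(pipe.__call__, "control_hidden_states"))  # False
```

Because the bound method's signature omits self, only the user-facing parameters are inspected.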

Suggested fix:
Expose control_hidden_states and control_hidden_states_scale in WanVideoToVideoPipeline.__call__, or set neutral defaults internally when they are not provided.
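A minimal pure-Python sketch of the "neutral defaults" option follows. It assumes that each VACE hint is added to the hidden states as hint * scale at its layer, so a zero scale disables the control branch. resolve_control_scale and apply_hints are hypothetical helpers, and plain lists stand in for tensors.

```python
from typing import List, Optional, Sequence


def resolve_control_scale(
    control_hidden_states: Optional[object],
    control_hidden_states_scale: Optional[Sequence[float]],
    num_vace_layers: int,
) -> List[float]:
    """Hypothetical guarded default: explicit scale wins; else 1.0 with control, 0.0 without."""
    if control_hidden_states_scale is not None:
        return list(control_hidden_states_scale)
    # Without a control input, a zero scale makes every VACE hint contribute nothing.
    fill = 1.0 if control_hidden_states is not None else 0.0
    return [fill] * num_vace_layers


def apply_hints(hidden: List[float], hints: List[List[float]], scale: Sequence[float]) -> List[float]:
    """Toy residual update, assumed form: hidden = hidden + hint * scale, one hint per VACE layer."""
    for hint, s in zip(hints, scale):
        hidden = [h + x * s for h, x in zip(hidden, hint)]
    return hidden


base = [0.5, -1.0]
hints = [[0.25, 0.25], [1.0, -2.0]]
print(apply_hints(base, hints, resolve_control_scale(None, None, 2)))      # [0.5, -1.0]
print(apply_hints(base, hints, resolve_control_scale("ctrl", None, 2)))    # [1.75, -2.75]
```

In the real transformer the VACE blocks also consume the control tensor itself, so a neutral default would additionally need a zero-filled control_hidden_states of the shape the model expects. This toy does not model that part.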

Reproduction

from diffusers import WanVideoToVideoPipeline, AutoencoderKLWan
from PIL import Image
import torch

# 1. Load the model and VAE
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan-VACE-14B",
    subfolder="vae",
    torch_dtype=torch.float32,
)
pipe = WanVideoToVideoPipeline.from_pretrained(
    "Wan-AI/Wan-VACE-14B",
    vae=vae,
    torch_dtype=torch.bfloat16,
).to("cuda")

# 2. Create a dummy video (8 identical frames)
frames = [Image.new("RGB", (512, 288), "gray") for _ in range(8)]

# 3. Run baseline inference
with torch.autocast("cuda", dtype=torch.bfloat16):
    result = pipe(video=frames, prompt="studio lighting")

Expected behavior:
The call should run inference normally, producing an output video.

Actual behavior:
It crashes with:
AttributeError: 'NoneType' object has no attribute 'new_ones'

Logs

Traceback (most recent call last):
  File "repro_vace_crash.py", line 21, in <module>
    result = pipe(video=frames, prompt="studio lighting")
  File ".../diffusers/pipelines/wan/pipeline_wan_video2video.py", line 663, in __call__
    noise_pred = self.transformer(
  File ".../torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File ".../diffusers/models/transformers/transformer_wan_vace.py", line 295, in forward
    control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
AttributeError: 'NoneType' object has no attribute 'new_ones'
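The failing line can be reproduced without a GPU. The sketch below is a pure-Python stand-in that mirrors the defaulting branch shown in the traceback: when the scale is omitted, it is derived from control_hidden_states, which stays None because the pipeline never passes it. FakeTensor, default_control_scale and the layer indices are hypothetical.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor exposing only new_ones()."""

    def new_ones(self, n):
        return [1.0] * n


def default_control_scale(control_hidden_states, control_hidden_states_scale, vace_layers):
    """Mirrors the defaulting branch from transformer_wan_vace.py shown in the traceback."""
    if control_hidden_states_scale is None:
        # Dereferences control_hidden_states even when it was never provided.
        control_hidden_states_scale = control_hidden_states.new_ones(len(vace_layers))
    return control_hidden_states_scale


vace_layers = [0, 5, 10, 15]  # hypothetical layer indices
print(default_control_scale(FakeTensor(), None, vace_layers))  # [1.0, 1.0, 1.0, 1.0]

try:
    # The pipeline never passes control_hidden_states, so it stays None.
    default_control_scale(None, None, vace_layers)
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'new_ones'
```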

System Info

  • 🤗 Diffusers version: 0.35.2
  • Platform: Linux-6.8.0-1036-aws-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.10
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • PEFT version: 0.17.1
  • Bitsandbytes version: 0.48.1
  • Safetensors version: 0.6.2
  • xFormers version: 0.0.32.post2
  • Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
