Wan2.2 TI2V-5B Tiled VAE Tensor size mismatch #12529

@CalamitousFelicitousness

Description

Describe the bug

When running Wan 2.2 TI2V 5B with VAE tiling enabled, the VAE decode step fails with:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
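
For readers unfamiliar with this class of error: PyTorch raises it when two tensors combined element-wise differ in a dimension where neither has size 1. The sketch below mimics that check in plain Python. The shapes, the helper name `mismatched_dims`, and the axis names (batch, channels, frames, height, width) are illustrative assumptions; only the sizes 2 and 4 at dimension 2 come from the message above.

```python
from typing import List, Sequence

# Assumed axis names for a 5-D video tensor. This layout is an assumption
# made for illustration; the error message itself does not state it.
AXES = ("batch", "channels", "frames", "height", "width")


def mismatched_dims(shape_a: Sequence[int], shape_b: Sequence[int]) -> List[int]:
    """Return every dimension where two equal-rank shapes cannot broadcast."""
    if len(shape_a) != len(shape_b):
        raise ValueError("this sketch only compares shapes of equal rank")
    return [
        dim
        for dim, (a, b) in enumerate(zip(shape_a, shape_b))
        if a != b and a != 1 and b != 1  # sizes differ and neither is 1
    ]


# Hypothetical shapes, chosen only to reproduce the sizes in the message.
residual_branch = (1, 256, 2, 16, 16)
shortcut_branch = (1, 256, 4, 16, 16)

for dim in mismatched_dims(residual_branch, shortcut_branch):
    print(f"dimension {dim} ({AXES[dim]}): "
          f"{residual_branch[dim]} vs {shortcut_branch[dim]}")
```

If the decoder's tensors do follow that layout, dimension 2 would be the frame axis, which would point at the two branches of the sum disagreeing on frame count rather than on spatial tile size.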

Reproduction

import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

dtype = torch.bfloat16
device = "cuda"

model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

# Load the VAE in float32 and enable tiling with custom tile sizes and strides.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=64,
    tile_sample_stride_width=64,
)

pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=dtype)
pipe.to(device)

height = 704
width = 1280
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0

prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
).frames[0]

export_to_video(output, "5bit2v_output.mp4", fps=24)

print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

Logs

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.07it/s]
The config attributes {'clip_output': False} were passed to AutoencoderKLWan, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:28<00:00,  5.60s/it]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.77it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:26<00:00,  8.66s/it]
Traceback (most recent call last):
  File "/workspace/Wan2.2/wan22tiledvae.py", line 75, in <module>
    output = pipe(
             ^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/wan/pipeline_wan.py", line 645, in __call__
    video = self.vae.decode(latents, return_dict=False)[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1248, in decode
    decoded = self._decode(z).sample
              ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1204, in _decode
    return self.tiled_decode(z, return_dict=return_dict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1374, in tiled_decode
    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 893, in forward
    x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 709, in forward
    x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
        ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
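
The traceback shows the exception being raised inside `tiled_decode`, reached from `vae.decode`. As a possible stopgap, a caller could retry the decode with tiling disabled when this specific error appears. The following is a minimal sketch, not a fix: the helper name `decode_with_tiling_fallback` is hypothetical, it assumes the untiled path does not hit the same error (not verified here), and it assumes the latents are already in the form `vae.decode` expects. Untiled decoding at this resolution may also exceed available GPU memory.

```python
def decode_with_tiling_fallback(vae, latents):
    """Decode with the VAE's current settings; on a tensor-size mismatch,
    disable tiling and retry once.

    `vae` is expected to expose `decode(latents, return_dict=False)` and
    `disable_tiling()`, as AutoencoderKLWan does.
    """
    try:
        return vae.decode(latents, return_dict=False)[0]
    except RuntimeError as err:
        # Only handle the shape error from this report; re-raise anything else
        # (for example, out-of-memory errors).
        if "must match the size of tensor" not in str(err):
            raise
        vae.disable_tiling()
        return vae.decode(latents, return_dict=False)[0]
```

Matching on the message text is brittle and is used here only to avoid masking unrelated runtime errors.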

System Info

Tested on multiple setups with no difference in behavior. One of them:

  • 🤗 Diffusers version: 0.35.2
  • Platform: Linux-6.8.0-65-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.35.3
  • Transformers version: 4.51.3
  • Accelerate version: 1.11.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A5000, 24564 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul @DN6

    Labels

    bug (Something isn't working)
