
AssertionError in Context Parallelism during WanImageToVideoPipeline inference: Tensor size along sharding dimension not divisible by mesh size #12536

@leeguandong

Description

Hi @a-r-r-o-w, I would like to ask you about an error I'm hitting when using Context Parallelism for inference.

Issue Description

Environment

  • Diffusers: 0.36.0.dev0

Problem Description

I'm trying to run image-to-video generation with WanImageToVideoPipeline, with the models quantized to qfloat8_e4m3fn via Quanto and frozen, and Context Parallelism enabled with ulysses_degree=8. The pipeline initializes successfully, but at the very first denoising step (0/20) it raises an AssertionError in the Context Parallel hook:

AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size

This occurs in diffusers/hooks/context_parallel.py during the sharding of hidden_states in the transformer block's forward pass.

Expected Behavior: The pipeline should generate the video frames without crashing, distributing computation across GPUs via Context Parallelism.

Actual Behavior: The pipeline crashes immediately at the start of the denoising loop.
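
For what it's worth, the sequence length the transformer should see for these inputs looks divisible by 8 to me. Below is a rough sanity check; the 4x temporal / 8x spatial VAE compression and the (1, 2, 2) patch size are my assumptions based on the Wan 2.1 architecture, so the numbers may be off if the checkpoint differs:

# Rough sanity check (my assumptions, not values read from the failing run):
# Wan 2.1's VAE compresses time by 4x and space by 8x, and the transformer
# patchifies with a (1, 2, 2) patch size.
num_frames, height, width, ulysses_degree = 81, 480, 832, 8

latent_frames = (num_frames - 1) // 4 + 1              # 21
latent_height, latent_width = height // 8, width // 8  # 60, 104

tokens = latent_frames * (latent_height // 2) * (latent_width // 2)  # 21 * 30 * 52 = 32760
print(tokens % ulysses_degree)  # 0, so hidden_states itself looks evenly shardable

If that arithmetic is right, hidden_states itself should split evenly, so perhaps the assertion is tripping on one of the other inputs the hook shards (e.g. encoder_hidden_states or the rotary embeddings)?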

Minimal Reproducible Code

Here's the full failing script (run on 8 GPUs with torchrun --nproc_per_node=8 test.py):

import torch
from PIL import Image
from diffusers import (
    ContextParallelConfig, WanImageToVideoPipeline, WanTransformer3DModel
)
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from optimum.quanto import freeze, qfloat8_e4m3fn, quantize
from transformers import UMT5EncoderModel

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
dtype = torch.bfloat16

model_id = '/share/models/checkpoints/Wan-AI/Wan2___1-I2V-14B-720P-Diffusers'

transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=dtype
)

# Quantize text_encoder
quantize(text_encoder, weights=qfloat8_e4m3fn)
freeze(text_encoder)

# Quantize transformer
quantize(transformer, weights=qfloat8_e4m3fn)
freeze(transformer)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=dtype
)

flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)
pipe.to("cuda")

transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=8)
)

image = Image.open("/share/common/AIPhoto/3.jpeg").resize((832, 480))
# also tried: .resize((800, 1280))

prompt = (
    "Modern urban-style photography: a young man in a white printed T-shirt and black "
    "shorts sits on a transparent glass staircase, wearing black-and-white canvas shoes, "
    "his posture casual and natural. His skin is fair and his build well proportioned; "
    "his legs are slightly apart and his elbows rest on his knees. The background is a "
    "towering glass curtain wall and modern architecture, with the city skyline visible "
    "through the glass. In a fixed shot, he slowly raises both hands [making a heart "
    "gesture with both hands]; the motion is relaxed and fluid, and the whole frame is "
    "full of modern, urban atmosphere. Slow motion reveals fine dynamic detail."
)

negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, "
    "fused fingers, still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)

# Must pass a seeded generator so all ranks start from the same latents (or supply your own)
generator = torch.Generator().manual_seed(42)

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,
).frames[0]

if rank == 0:
    export_to_video(output, "output.mp4", fps=16)

if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

Error Traceback

0%|          | 0/20 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/share/gdli7/common/AIPhoto/test.py", line 46, in <module>
[rank1]:     output = pipe(
[rank1]:               ^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
[rank1]:     noise_pred = current_model(
[rank1]:                  ^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]:     hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 209, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 259, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
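
For context, my reading of the failing line is that the sharder simply chunks the tensor equally along the given dimension across the mesh, hence the divisibility requirement. A minimal sketch of that behavior (my own illustration, not the actual diffusers implementation):

import torch

def equipartition_shard(tensor, dim, mesh_size, rank):
    # Each rank must receive an identically sized chunk, so the sharded
    # dimension has to divide evenly by the mesh size (the same assertion
    # as in the traceback above).
    assert tensor.size(dim) % mesh_size == 0, (
        "Tensor size along dimension to be sharded must be divisible by mesh size"
    )
    return tensor.chunk(mesh_size, dim=dim)[rank]

x = torch.randn(1, 32760, 5120)  # hypothetical (batch, seq, hidden) shape
print(equipartition_shard(x, dim=1, mesh_size=8, rank=0).shape)  # torch.Size([1, 4095, 5120])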
