
Hooks conflicts: Context Parallelism and CPU Offload #12533

@DefTruth

Description

Describe the bug

Enabling CPU offload before enabling context parallelism raises a shape error after the first pipe call (the warmup call succeeds; the second call fails). This looks like a diffusers bug: the CPU-offload hooks are not fully compatible with the context-parallelism hooks, and vice versa. In the logs below, the failing size (2048) is exactly half of the 4096-token image sequence, consistent with the rotary frequencies and the attention inputs being split inconsistently across ranks.

  • CPU offload before context parallelism (does not work):
pipe.enable_model_cpu_offload(device=device)

# pipe.transformer.set_attention_backend("flash")
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=dist.get_world_size())
)
  • CPU offload after context parallelism (works):
# pipe.transformer.set_attention_backend("flash")
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=dist.get_world_size())
)
pipe.enable_model_cpu_offload(device=device)
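
A minimal workaround sketch that pins the working order in one place; the helper name is hypothetical, and it assumes the same imports and objects as the reproduction below:

def enable_cp_then_offload(pipe, device, world_size):
    # Attach the context-parallel hooks first ...
    pipe.transformer.set_attention_backend("_native_cudnn")
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )
    # ... then the CPU-offload hooks; the reverse order triggers the
    # shape error shown in the logs below.
    pipe.enable_model_cpu_offload(device=device)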

Reproduction

import os
import time
import torch
import torch.distributed as dist
from diffusers import (
    QwenImagePipeline,
    QwenImageTransformer2DModel,
    ContextParallelConfig,
)


def maybe_init_distributed():
    if not dist.is_initialized():
        dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    return rank, device


def maybe_destroy_distributed():
    if dist.is_initialized():
        dist.destroy_process_group()


rank, device = maybe_init_distributed()

pipe = QwenImagePipeline.from_pretrained(
    os.environ.get(
        "QWEN_IMAGE_DIR",
        "Qwen/Qwen-Image",
    ),
    torch_dtype=torch.bfloat16,
)

# NOTE: Enabling cpu offload before enabling parallelism raises a
# shape error after the first pipe call. This looks like a diffusers
# bug: cpu offload is not fully compatible with context parallelism,
# and vice versa. The call below reproduces the failing order.
pipe.enable_model_cpu_offload(device=device)

assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
# pipe.transformer.set_attention_backend("flash")
pipe.transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=dist.get_world_size())
)

# NOTE: Enabling cpu offload here, after parallelism, works instead:
# pipe.enable_model_cpu_offload(device=device)

# assert isinstance(pipe.vae, AutoencoderKLQwenImage)
# pipe.vae.enable_tiling()

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for English prompts
    "zh": ", 超清,4K,电影级构图.",  # for Chinese prompts
}

# Generate image
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""

# use an empty string if you do not have a specific concept to remove
negative_prompt = " "

pipe.set_progress_bar_config(disable=rank != 0)


def run_pipe():
    # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    image = pipe(
        prompt=prompt + positive_magic["en"],
        negative_prompt=negative_prompt,
        width=1024,
        height=1024,
        num_inference_steps=50,
        true_cfg_scale=4.0,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).images[0]

    return image


# warmup
_ = run_pipe()  # the warmup call always works

start = time.time()
image = run_pipe()  # raises an error here if cpu offload was enabled before parallelism
end = time.time()


if rank == 0:
    time_cost = end - start
    save_path = f"qwen-image.cp{dist.get_world_size()}.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

maybe_destroy_distributed()
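
For reference, the script is launched with torchrun; the script name here is taken from the traceback below, and --nproc_per_node should match the number of GPUs:

torchrun --nproc_per_node=2 run_qwen_image_cp_naive.py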

Logs

Error:


[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/dev/vipshop/cache-dit/examples/parallelism/run_qwen_image_cp_naive.py", line 71, in <module>
[rank0]:     start = time.time()
[rank0]:             ^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/cache-dit/examples/parallelism/run_qwen_image_cp_naive.py", line 54, in run_pipe
[rank0]:     # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 691, in __call__
[rank0]:     noise_pred = self.transformer(
[rank0]:                  ^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 175, in new_forward
[rank0]:     output = module._old_forward(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 647, in forward
[rank0]:     encoder_hidden_states, hidden_states = block(
[rank0]:                                            ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 443, in forward
[rank0]:     attn_output = self.attn(
[rank0]:                   ^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/attention_processor.py", line 605, in forward
[rank0]:     return self.processor(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 322, in __call__
[rank0]:     img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 139, in apply_rotary_emb_qwen
[rank0]:     x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
[rank0]:                                ~~~~~~~~~~^~~~~~~~~~~
[rank0]: RuntimeError: The size of tensor a (4096) must match the size of tensor b (2048) at non-singleton dimension 1

System Info

diffusers 0.36.dev0 (latest main branch), PyTorch 2.9.0

Who can help?

@yiyixuxu @DN6
