Open
Labels: bug (Something isn't working)
Description
Describe the bug
I use group offloading with QwenImagePipeline, and it reports that many blocks are not used in the forward pass. The final image is a mess. With enable_sequential_cpu_offload instead, everything works fine, so I believe the error comes from group offloading.
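A side note on the scheduler settings used in the reproduction script, to rule them out as a variable: base_shift and max_shift are both math.log(3), so the dynamic shift should not depend on image size. The sketch below checks that arithmetic in plain Python. The helper names (calculate_mu, exponential_time_shift, static_shift) are hypothetical, and the formulas are my assumption of how the scheduler interpolates mu and applies the "exponential" time shift, not a copy of the library code.

```python
import math


def calculate_mu(image_seq_len, base_seq_len=256, max_seq_len=8192,
                 base_shift=math.log(3), max_shift=math.log(3)):
    """Assumed form: linear interpolation of mu over the image sequence length."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


def exponential_time_shift(mu, sigma):
    """Assumed "exponential" time shift with exponent 1."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))


def static_shift(shift, sigma):
    """Classic fixed shift, for comparison."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


# Sequence lengths here are illustrative values inside the configured range.
for seq_len in (256, 4096, 8192):
    mu = calculate_mu(seq_len)
    print(f"seq_len={seq_len}: exp(mu)={math.exp(mu):.6f}")

# With exp(mu) == 3, the dynamic shift matches a fixed shift of 3.
for sigma in (0.25, 0.5, 1.0):
    dynamic = exponential_time_shift(math.log(3), sigma)
    fixed = static_shift(3.0, sigma)
    print(f"sigma={sigma}: dynamic={dynamic:.6f}, fixed={fixed:.6f}")
```

If these assumptions hold, the config is equivalent to a constant shift of 3 at every resolution, which matches the "shift=3 in distillation" comments and suggests the scheduler is not what differs between the working and broken runs.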
Reproduction
import math

import torch
import diffusers.hooks  # needed for diffusers.hooks.apply_group_offloading below
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),  # We use shift=3 in distillation
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),  # We use shift=3 in distillation
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,  # set shift_terminal to None
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", scheduler=scheduler, torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
)
onload_device = torch.device("cuda")
pipe.transformer.enable_group_offload(
    onload_device=onload_device, offload_type="leaf_level",
    use_stream=True, non_blocking=True, low_cpu_mem_usage=True,
)
diffusers.hooks.apply_group_offloading(
    pipe.text_encoder, onload_device=onload_device, offload_type="leaf_level",
    use_stream=True, non_blocking=True, low_cpu_mem_usage=True,
)
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
negative_prompt = " "
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=1024,
    height=1024,
    num_inference_steps=8,
    true_cfg_scale=1.0,
    generator=torch.manual_seed(0),
).images[0]
image.save("qwen_fewsteps.png")
Logs
System Info
ubuntu22.04-cuda12.4.0-py311-torch2.8.0-diffusers0.35.2
Who can help?