Description
Describe the bug
I ran into an issue where my model kept getting moved to the CPU after loading LoRA weights with the `load_lora_weights()` method.
I found that `is_sequential_cpu_offload` is set to `True` while the LoRA is loaded (https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_base.py#L441), even though I never enable CPU offloading in my code. As a result, inference takes about 15x longer than when the model stays on the GPU.
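The following stdlib-only sketch makes the suspected mechanism concrete. It is a simplified stand-in, not diffusers' actual code. The hook classes are placeholders named after accelerate's hook types, and `Component` and `detect_offloading` are hypothetical. The sketch assumes, without verifying it here, that the 8-bit quantized transformer carries a device-alignment hook even though nothing is offloaded. Under that assumption, a check that treats any such hook as sequential offloading reports `True`, which matches the behavior described above.

```python
# Stdlib-only model of a hook-based "was offloading enabled?" check.
# The hook classes are stand-ins named after accelerate's hook types;
# they are NOT the real accelerate classes, and this is not diffusers' code.


class AlignDevicesHook:
    """Stand-in for a hook that aligns inputs/weights to a device.

    Assumption: both sequential CPU offload and quantized/device-map loading
    may attach a hook of this kind; `offload` says whether weights are offloaded.
    """

    def __init__(self, offload=False):
        self.offload = offload


class CpuOffload:
    """Stand-in for the hook used by model-level CPU offload."""


class SequentialHook:
    """Stand-in for a container that chains several hooks."""

    def __init__(self, *hooks):
        self.hooks = list(hooks)


class Component:
    """Hypothetical pipeline component that may carry an `_hf_hook`."""

    def __init__(self, name, hook=None):
        self.name = name
        if hook is not None:
            self._hf_hook = hook


def detect_offloading(components, require_offload_flag=False):
    """Return (is_model_cpu_offload, is_sequential_cpu_offload).

    By default any AlignDevicesHook counts as sequential offload, regardless
    of why it was attached. `require_offload_flag=True` is a hypothetical
    stricter variant that also requires the hook to actually offload.
    """
    is_model = False
    is_sequential = False
    for comp in components:
        hook = getattr(comp, "_hf_hook", None)
        if hook is None:
            continue  # no hook attached: nothing to inspect
        is_model = is_model or isinstance(hook, CpuOffload)
        # Look through a chained hook at its first member.
        if isinstance(hook, SequentialHook) and hook.hooks:
            hook = hook.hooks[0]
        if isinstance(hook, AlignDevicesHook):
            if not require_offload_flag or hook.offload:
                is_sequential = True
    return is_model, is_sequential


if __name__ == "__main__":
    # Quantized transformer with an alignment hook but no offloading.
    transformer = Component("transformer", AlignDevicesHook(offload=False))
    print(detect_offloading([transformer, Component("vae")]))  # (False, True)
```

The `require_offload_flag` variant only illustrates the design trade-off: a check keyed purely on hook type cannot tell "offloaded" apart from "placed on a device". It is not claimed to be what #8750 does.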
I'm using an 8-bit quantized FLUX model with a FLUX LoRA, and the model is supposed to stay on the GPU.
When I add the parameter that is commented out in the reproduction code (`device_map="balanced"`) to the pipeline initialization, the issue disappears and the model stays on the GPU.
Is it intended behavior that the model only stays on the GPU after loading LoRA weights when this extra parameter is passed?
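The `device_map="balanced"` workaround would be consistent with the offload check being skipped whenever the pipeline was created with its own device map. The sketch below models that idea. This is an assumption inferred from the observed behavior, not confirmed from the diffusers source. The function `would_reenable_offload` and the stand-in `AlignDevicesHook` class are hypothetical and are not diffusers or accelerate APIs.

```python
from typing import Dict, Optional, Union


class AlignDevicesHook:
    """Stand-in for a device-alignment hook (not accelerate's real class)."""


def would_reenable_offload(
    component_hooks: Dict[str, Optional[object]],
    hf_device_map: Optional[Dict[str, Union[int, str]]] = None,
) -> bool:
    """Hypothetical model of the decision made while loading LoRA weights.

    Assumption: if the pipeline was created with a device map (for example
    device_map="balanced"), hook inspection is skipped entirely.
    """
    if hf_device_map is not None:
        # Pipeline-level device map present: leave the hooks alone.
        return False
    # No device map: any device-alignment hook is read as sequential offload.
    return any(isinstance(hook, AlignDevicesHook) for hook in component_hooks.values())


hooks = {"transformer": AlignDevicesHook(), "vae": None}
print(would_reenable_offload(hooks))                      # True
print(would_reenable_offload(hooks, {"transformer": 0}))  # False
```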
This issue is related to #7539, but that issue is stale and I hit the problem in a different way, so I decided to open a new one. In #7539, the author explicitly calls `align_device_hook()`. I'm only loading LoRA weights and never touch device hooks myself.
The author of #7539 also opened a PR (#8750) that fixes my issue without requiring `device_map="balanced"`, but the PR was never merged.
Reproduction
```python
import torch
from diffusers import FluxPipeline, BitsAndBytesConfig, FluxTransformer2DModel
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 8-bit quantized T5 text encoder
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=TransformersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# 8-bit quantized FLUX transformer
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=text_encoder_2_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    # device_map="balanced",  # uncommenting this keeps the model on the GPU
)

# After this call the model ends up on the CPU, although no offloading was enabled
pipe.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="anime_lora.safetensors")

prompt = "a hand holding a knife and cutting a cabbage, anime"
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    num_inference_steps=25,
).images[0]
out.save("image.png")
```
Logs
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.1
- Accelerate version: 1.3.0
- PEFT version: 0.14.1.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.4.5
- xFormers version: 0.0.29.post2
- Accelerator: NVIDIA L4, 23034 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response