Describe the bug
I want to control LoRA adapters dynamically, but when I use CPU group offloading in leaf_level mode, an error occurs after calling delete_adapters. I think it is caused by LoRA's hook; a standalone sketch of the mechanism I suspect is below.
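For context, here is a self-contained PyTorch sketch of the failure mode I have in mind (my own illustration, not diffusers code): a pre-forward hook onloads CPU-resident weights just-in-time, the way leaf_level offloading does, and swapping out the hooked module loses the hook.

import torch
import torch.nn as nn

# Sketch of the suspected mechanism (illustration only, not diffusers code):
# a pre-forward hook onloads the CPU-resident weights to GPU before each call.
linear = nn.Linear(4, 4)  # parameters live on CPU (the offload device)

def onload(module, args):
    module.to("cuda:0")  # onload just-in-time, as an offloading hook would

linear.register_forward_pre_hook(onload)
x = torch.randn(1, 4, device="cuda:0")
linear(x)  # works: the hook moves the weights to cuda:0 first

# Simulate what I think delete_adapters effectively does: the wrapped module
# is replaced by a fresh one, so the hook is gone and the weights stay on CPU.
linear = nn.Linear(4, 4)
linear(x)  # RuntimeError: Expected all tensors to be on the same device ...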
Reproduction
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# cpu offload
onload_device = torch.device("cuda:0")
offload_device = torch.device("cpu")
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    # offload_type="block_level",
    # num_blocks_per_group=1,
    use_stream=True,
    record_stream=False,
)
pipe.vae = pipe.vae.to(onload_device)
pipe.text_encoder_2 = pipe.text_encoder_2.to(onload_device)
pipe.text_encoder = pipe.text_encoder.to(onload_device)
pipe.load_lora_weights("AI-ModelScope/FLUX.1-Canny-dev-lora", adapter_name="canny")
pipe.set_adapters("canny", 0.85)
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("test_image.png")
processor = CannyDetector()
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=5,
    guidance_scale=30.0,
).images[0]
image.save("output.png")
pipe.delete_adapters("canny")
image = pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=5,
    guidance_scale=30.0,
).images[0]
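A possible workaround I considered (untested sketch; both calls are public diffusers APIs): instead of deleting the adapter, keep it loaded and take it out of the computation, so the hooked modules are never swapped.

# Untested workaround sketch: avoid delete_adapters, which appears to break
# the leaf_level offloading hooks, and neutralize the adapter instead.
pipe.set_adapters("canny", 0.0)  # keep the adapter loaded, weight it to zero
# or:
# pipe.disable_lora()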
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/transformers/transformer_flux.py", line 696, in forward
else self.time_text_embed(timestep, guidance, pooled_projections)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/embeddings.py", line 1607, in forward
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/embeddings.py", line 1290, in forward
sample = self.linear_1(sample)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/layer.py", line 757, in forward
result = self.base_layer(x, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/diffusers/hooks/hooks.py", line 189, in new_forward
output = function_reference.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but got mat1 is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA_addmm)System Info
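If it helps triage, this quick check (a sketch; pipe as in the reproduction above) prints which device each group of transformer parameters ended up on after the failure. With the offloading hooks intact I would expect everything on cpu between calls; a mix of cuda:0 and cpu would point at the modules whose hooks were lost.

# Diagnostic sketch (not part of the original run): group the transformer's
# parameters by device to see which modules lost their offloading hooks.
devices = {}
for name, param in pipe.transformer.named_parameters():
    devices.setdefault(str(param.device), []).append(name)
for device, names in devices.items():
    print(f"{device}: {len(names)} parameters, e.g. {names[:3]}")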
System Info
- 🤗 Diffusers version: 0.36.0.dev0
- Platform: Linux-5.10.134-15.al8.x86_64-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.55.0
- Accelerate version: 1.10.1
- PEFT version: 0.17.1
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 49140 MiB
  NVIDIA GeForce RTX 4090, 49140 MiB
- Using GPU in script?: YES
- Using distributed or parallel set-up in script?: NO
Who can help?
No response