Describe the bug
The function signature of load_model_dict_into_meta changed in #10604, and device is no longer an accepted argument. However, IP-Adapter loading still passes device, as we can see below:

diffusers/src/diffusers/loaders/transformer_flux.py in e3bc4aa

load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

diffusers/src/diffusers/loaders/transformer_flux.py
Line 156 in e3bc4aa

load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)

diffusers/src/diffusers/loaders/transformer_sd3.py
Lines 78 to 80 in e3bc4aa

load_model_dict_into_meta(
    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)

diffusers/src/diffusers/loaders/transformer_sd3.py in e3bc4aa

load_model_dict_into_meta(image_proj, updated_state_dict, device=self.device, dtype=self.dtype)

diffusers/src/diffusers/loaders/unet.py
Line 756 in e3bc4aa

load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

diffusers/src/diffusers/loaders/unet.py
Line 849 in e3bc4aa

load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
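For context, a rough sketch of the signature change (keyword names inferred only from the call sites quoted in this issue, so the actual signatures from #10604 may differ slightly):

# Before #10604 (sketch): placement was given via a `device` keyword.
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None):
    ...

# After #10604 (sketch): `device` is gone and placement is expressed as a
# device_map, matching the FromOriginalModelMixin call quoted further below.
def load_model_dict_into_meta(
    model,
    state_dict,
    dtype=None,
    device_map=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    unexpected_keys=None,
):
    ...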
Now that #10604 is merged, should we follow a similar approach to the one in FromOriginalModelMixin, shown below? That is, pass device_map as {"": param_device} instead?
diffusers/src/diffusers/loaders/single_file_model.py
Lines 369 to 387 in e3bc4aa

device_map = None
if is_accelerate_available():
    param_device = torch.device(device) if device else torch.device("cpu")
    empty_state_dict = model.state_dict()
    unexpected_keys = [
        param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
    ]
    device_map = {"": param_device}
    load_model_dict_into_meta(
        model,
        diffusers_format_checkpoint,
        dtype=torch_dtype,
        device_map=device_map,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        unexpected_keys=unexpected_keys,
    )
else:
    _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
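If I read the device_map convention correctly, the empty-string key "" matches every module prefix, so {"": param_device} simply places the whole model on param_device. Under that assumption, the failing IP-Adapter call sites could be updated along these lines (an untested sketch of the proposal, not the actual patch; attn_procs, value_dict, device and dtype are the existing names from the surrounding loader code, where torch is already imported):

# Build a whole-model device map, mirroring FromOriginalModelMixin above,
# and drop the removed `device` kwarg.
device_map = {"": torch.device(device) if device else torch.device("cpu")}
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)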
Happy to contribute a PR updating IP-Adapter loading :)
Reproduction
The error occurs when loading any IP-Adapter with low_cpu_mem_usage=True, which is the default value when torch >= 1.9.0, for example:
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
pipe: FluxPipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
)
pipe.load_ip_adapter(
"XLabs-AI/flux-ip-adapter",
weight_name="ip_adapter.safetensors",
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
)
pipe.set_ip_adapter_scale(0.6)
pipe.enable_sequential_cpu_offload()
ip_adapter_image = load_image("https://huggingface.co/guiyrt/sample-images/resolve/main/astronaut.jpg")
image = pipe(
width=1024,
height=1024,
prompt="A vintage picture of an astronaut in a starry sky",
generator=torch.manual_seed(42),
ip_adapter_image=ip_adapter_image,
).images[0]
image.save('result.jpg')
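As a side note, the pipeline call itself is not needed to hit the error; it is raised inside load_ip_adapter. Assuming load_ip_adapter forwards low_cpu_mem_usage to _load_ip_adapter_weights (as the traceback below suggests), passing it explicitly on the same pipe is enough to trigger the same TypeError:

pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    low_cpu_mem_usage=True,  # explicit here; True is already the default with torch >= 1.9.0
)

Logs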
Traceback (most recent call last):
File "/home/guiyrt/diffusers/run.py", line 10, in <module>
pipe.load_ip_adapter(
File "/home/guiyrt/anaconda3/envs/diffusers/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
return fn(*args, **kwargs)
File "/home/guiyrt/diffusers/src/diffusers/loaders/ip_adapter.py", line 553, in load_ip_adapter
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
File "/home/guiyrt/diffusers/src/diffusers/loaders/transformer_flux.py", line 168, in _load_ip_adapter_weights
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
File "/home/guiyrt/diffusers/src/diffusers/loaders/transformer_flux.py", line 156, in _convert_ip_adapter_attn_to_diffusers
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
TypeError: load_model_dict_into_meta() got an unexpected keyword argument 'device'

System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.16
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
- Jax version: 0.5.0
- JaxLib version: 0.5.0
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.3
- Accelerate version: 1.3.0
- PEFT version: 0.14.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4070 Ti SUPER, 16376 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: