Status: Closed
Labels: bug (Something isn't working)
Description
Describe the bug
Loading a Lumina 2.0 LoRA from Civitai throws an error.
Loading works fine for https://huggingface.co/sayakpaul/trained-lumina2-lora-yarn
Reproduction
I tried the LoRAs listed here:
https://civitai.com/search/models?baseModel=Lumina&modelType=LORA&sortBy=models_v9&query=lumina
with the code below, based on
https://huggingface.co/sayakpaul/trained-lumina2-lora-yarn
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
).to("cuda")

# Art Style of Hitoshi Ashinano https://civitai.com/models/1269546/art-style-of-hitoshi-ashinano-lumina-image-20
pipe.load_lora_weights("newgenai79/lumina2_lora", weight_name="Art_Style_of_Hitoshi_Ashinano.safetensors")
# Art Style of Studio Ghibli https://civitai.com/models/1257597/art-style-of-studio-ghibli-lumina-image-20
# pipe.load_lora_weights("newgenai79/lumina2_lora", weight_name="Art_Style_of_Studio_Ghibli.safetensors")
# Yarn https://huggingface.co/sayakpaul/trained-lumina2-lora-yarn
# pipe.load_lora_weights("newgenai79/lumina2_lora", weight_name="lumina2_puppy_lora.safetensors")

prompt = "Hitoshi Ashinano style. A young girl with vibrant green hair and large purple eyes peeks out from behind a white wooden door. She is wearing a white shirt and have a curious expression on her face. The background shows a blue sky with a few clouds, and there's a white fence visible. Green leaves hang down from the top left corner, and a small white circle can be seen in the sky. The scene captures a moment of innocent curiosity and wonder."

image = pipe(
    prompt,
    negative_prompt="blurry, ugly, bad, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, cropped, out of frame, worst quality, low quality, jpeg artifacts, fused fingers, morbid, mutilated, extra fingers, mutated hands, bad anatomy, bad proportion, extra limbs",
    guidance_scale=6,
    num_inference_steps=35,
    generator=torch.manual_seed(0),
).images[0]
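One way to narrow down the `Target modules ... not found in the base model` error is to compare the module paths the LoRA file targets with the module paths that actually exist in `pipe.transformer`. The sketch below uses hypothetical helpers (`lora_target_paths`, `missing_targets`), not a diffusers or PEFT API. In practice the key list could be read with `safetensors.safe_open(path, framework="pt")` and its `.keys()` method, and the module names with `[name for name, _ in pipe.transformer.named_modules()]`. The `diffusion_model.` prefix and the example key and module names are illustrative assumptions, not read from the actual files.

```python
# Hypothetical diagnostic helpers (not part of diffusers or PEFT).

# Suffixes that mark LoRA parameters in common naming conventions.
LORA_SUFFIXES = (
    ".lora_A.weight",
    ".lora_B.weight",
    ".lora_down.weight",
    ".lora_up.weight",
    ".alpha",
)


def lora_target_paths(keys, strip_prefix="diffusion_model."):
    """Return the set of module paths targeted by a LoRA state dict.

    `strip_prefix` is an ASSUMED wrapper prefix used by some non-diffusers
    exporters; it is removed when present.
    """
    paths = set()
    for key in keys:
        if key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        for suffix in LORA_SUFFIXES:
            if key.endswith(suffix):
                paths.add(key[: -len(suffix)])
                break
    return paths


def missing_targets(lora_keys, model_module_names):
    """Return the targeted module paths absent from the model, sorted."""
    existing = set(model_module_names)
    return sorted(p for p in lora_target_paths(lora_keys) if p not in existing)


if __name__ == "__main__":
    # Illustrative key and module names only (assumptions).
    civitai_style = [
        "diffusion_model.layers.0.attention.qkv.lora_A.weight",
        "diffusion_model.layers.0.attention.qkv.lora_B.weight",
        "diffusion_model.layers.0.feed_forward.w1.lora_A.weight",
    ]
    diffusers_modules = ["layers.0.attn.to_q", "layers.0.feed_forward.linear_1"]
    print(missing_targets(civitai_style, diffusers_modules))
    # ['layers.0.attention.qkv', 'layers.0.feed_forward.w1']
```

An empty list would mean every targeted module exists in the model; a non-empty list shows which naming convention the file uses.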
Logs
(venv) C:\aiOWN\diffuser_webui>python lumina2_lora.py
Loading checkpoint shards: 100%|████████████████████████████████████| 2/2 [00:07<00:00, 3.75s/it]
Loading checkpoint shards: 100%|████████████████████████████████████| 3/3 [00:11<00:00, 3.70s/it]
Loading pipeline components...: 100%|███████████████████████████████| 5/5 [00:19<00:00, 3.98s/it]
Loading default_0 was unsucessful with the following error:
Target modules {'w2', 'adaLN_modulation.1', 'w1', 'out', 'qkv', 'w3'} not found in the base model. Please check the target modules and try again.
Traceback (most recent call last):
File "C:\aiOWN\diffuser_webui\lumina2_lora.py", line 8, in <module>
pipe.load_lora_weights("models/lora/lumina2/Art_Style_of_Hitoshi_Ashinano.safetensors")
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\diffusers\loaders\lora_pipeline.py", line 3957, in load_lora_weights
self.load_lora_into_transformer(
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\diffusers\loaders\lora_pipeline.py", line 3994, in load_lora_into_transformer
transformer.load_lora_adapter(
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\diffusers\loaders\peft.py", line 303, in load_lora_adapter
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\peft\mapping.py", line 260, in inject_adapter_in_model
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\peft\tuners\lora\model.py", line 141, in __init__
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\peft\tuners\tuners_utils.py", line 184, in __init__
self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "C:\aiOWN\diffuser_webui\venv\lib\site-packages\peft\tuners\tuners_utils.py", line 520, in inject_adapter
raise ValueError(error_msg)
ValueError: Target modules {'w2', 'adaLN_modulation.1', 'w1', 'out', 'qkv', 'w3'} not found in the base model. Please check the target modules and try again.
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Windows-10-10.0.26100-SP0
- Running on Google Colab?: No
- Python version: 3.10.11
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.1
- Transformers version: 4.48.1
- Accelerate version: 1.4.0.dev0
- PEFT version: 0.14.0
- Bitsandbytes version: 0.45.3.dev0
- Safetensors version: 0.5.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4060 Laptop GPU, 8188 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
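If the Civitai files use the original Lumina-Image-2.0 module names listed in the error (`qkv`, `out`, `w1`/`w2`/`w3`, `adaLN_modulation.1`), a possible workaround is to rename the keys before loading. This is a minimal sketch under stated assumptions, not the diffusers implementation: the `diffusion_model.` prefix, the target names (`attn.to_q`/`to_k`/`to_v`, `attn.to_out.0`, `feed_forward.linear_1..3`, `norm1.linear`) and the q/k/v split sizes are all assumptions to check against `pipe.transformer.named_modules()`. Weights are plain nested lists here so the logic runs without torch.

```python
# Hypothetical key-conversion sketch. The name mapping, the prefix and the
# q/k/v split sizes are ASSUMPTIONS, not confirmed diffusers behaviour.

PREFIX = "diffusion_model."

# Assumed one-to-one renames: original Lumina name -> diffusers-style name.
RENAMES = {
    ".attention.out.": ".attn.to_out.0.",
    ".feed_forward.w1.": ".feed_forward.linear_1.",
    ".feed_forward.w2.": ".feed_forward.linear_2.",
    ".feed_forward.w3.": ".feed_forward.linear_3.",
    ".adaLN_modulation.1.": ".norm1.linear.",
}


def split_rows(matrix, sizes):
    """Split a row-major matrix (list of rows) into consecutive row blocks."""
    if sum(sizes) != len(matrix):
        raise ValueError(f"sizes {sizes} do not sum to {len(matrix)} rows")
    blocks, start = [], 0
    for size in sizes:
        blocks.append(matrix[start:start + size])
        start += size
    return blocks


def convert_lora_keys(state_dict, qkv_sizes):
    """Rename keys and split the fused qkv LoRA into to_q/to_k/to_v.

    For a fused projection the LoRA update is B @ A. Splitting B by rows
    gives the q, k and v updates exactly; A (and alpha) are shared.
    """
    out = {}
    for key, value in state_dict.items():
        if key.startswith(PREFIX):
            key = key[len(PREFIX):]
        if ".attention.qkv." in key:
            base, tail = key.split(".attention.qkv.")
            if tail.startswith(("lora_B", "lora_up")):
                parts = split_rows(value, qkv_sizes)
            else:  # lora_A / lora_down / alpha are shared by q, k and v
                parts = [value] * 3
            for name, part in zip(("to_q", "to_k", "to_v"), parts):
                out[f"{base}.attn.{name}.{tail}"] = part
            continue
        for old, new in RENAMES.items():
            key = key.replace(old, new)
        out[key] = value
    return out
```

With real tensors, the row split would be `torch.split(lora_B, qkv_sizes, dim=0)`. For Lumina-Image-2.0, assuming a hidden size of 2304 with 24 query heads and 8 key/value heads (head dim 96), `qkv_sizes` would be `(2304, 768, 768)`. The converted dict could then be passed to `pipe.load_lora_weights`, which accepts a state dict as well as a repo id or file path.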
Who can help?