-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[LoRA] add support for more Qwen LoRAs #12581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2213,6 +2213,10 @@ def convert_key(key: str) -> str: | |||||||||
|
|
||||||||||
| state_dict = {convert_key(k): v for k, v in state_dict.items()} | ||||||||||
|
|
||||||||||
| has_default = any("default." in k for k in state_dict) | ||||||||||
| if has_default: | ||||||||||
| state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()} | ||||||||||
|
Comment on lines
+2217
to
+2218
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, that it's done as intended:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as here - 'default' is not in the key's prefix, so this won't be the intended behavior in this case |
||||||||||
|
|
||||||||||
| converted_state_dict = {} | ||||||||||
| all_keys = list(state_dict.keys()) | ||||||||||
| down_key = ".lora_down.weight" | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4940,7 +4940,8 @@ def lora_state_dict( | |
| has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict) | ||
| has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict) | ||
| has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) | ||
| if has_alphas_in_sd or has_lora_unet or has_diffusion_model: | ||
| has_default = any("default." in k for k in state_dict) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as #12581 (comment) - 'default' is not in the key's prefix, so this won't be the intended behavior in this case |
||
| if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default: | ||
| state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict) | ||
|
|
||
| out = (state_dict, metadata) if return_lora_metadata else state_dict | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's starting with "default.", then let's be explicit about that:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"default" isn't the prefix though, e.g.
transformer_blocks.0.attn.add_v_proj.lora_A.default.weight