🚨 feat: add non-breaking support to serialize metadata in loras. #9143
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
BenjaminBossan
left a comment
Nice quick work. Looks almost perfect to me already :) I can't really say whether all places have been adjusted or some were missed; I'll leave that to someone who has better knowledge of the diffusers code base.
Apart from that, I only have a few comments, please check.
I also noticed that there is already a mechanism for providing alphas via network_alpha_dict. IIUC, this needs to be user provided for now and cannot be inferred automatically from the checkpoint. Still, I wonder if we can set network_alpha_dict from config instead of having two somewhat independent code paths to achieve the same goal.
The docstrings for config are still empty, please add them, as it's not immediately obvious what values are expected.
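For reference, if I read the PR right, the config being passed around here is essentially a serialized peft.LoraConfig, so something along these lines is what I'd expect it to contain (a sketch using peft directly, not the diffusers code):

```python
from peft import LoraConfig

# A serialized LoRA config; the fields relevant to this discussion are
# r, lora_alpha, rank_pattern and alpha_pattern.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v"],
    rank_pattern={"to_k": 4},
    alpha_pattern={"to_k": 8},
)
serialized = config.to_dict()
print(serialized["lora_alpha"], serialized["rank_pattern"], serialized["alpha_pattern"])
```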
src/diffusers/utils/peft_utils.py
Outdated
# Try to retrieve config.
alpha_retrieved = False
if config is not None:
    lora_alpha = config["lora_alpha"]
I think this also needs to consider alpha_pattern -- if it's not intended to be supported for now, at least raise an error when alpha_pattern is given? If support is added, the unit test should include different alpha values (or a separate test could be added for this).
alpha_pattern cannot be provided through get_peft_kwargs.
Hmm, maybe I'm misunderstanding, but we're passing a LoraConfig instance, how can we know that this does not have alpha_pattern?
I wonder if this should be:
# Try to retrieve config.
alpha_retrieved = False
if config is not None:
    lora_alpha = config["lora_alpha"] if "lora_alpha" in config else lora_alpha
    alpha_retrieved = True
+   if config.get("alpha_pattern", None):
+       alpha_pattern = config["alpha_pattern"]

Similar argument for rank_pattern. In general, it's not clear to me how we should handle it if rank_pattern/alpha_pattern differ from rank_dict/network_alpha_dict (or is it not possible)?
Good point. Let me accommodate those changes and have it ready for your review.
@BenjaminBossan does f7d30de work?
No, IIUC there is still an issue. Say get_peft_kwargs is called with both network_alpha_dict and config being passed. As is, network_alpha_dict is completely ignored and alpha_pattern remains an empty dict. I think what should happen is that either
- alpha_pattern is taken from network_alpha_dict (same as previously), or
- alpha_pattern is taken from config if config.alpha_pattern is not None. If it is None, network_alpha_dict should be used.
So either network_alpha_dict or config.alpha_pattern should take precedence. And if both are given, potentially warn about the one being ignored. WDYT?
Depending on the choice here, rank_pattern should also be adjusted for consistency.
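Roughly the precedence I have in mind, as a sketch (the helper name is made up and this is not the actual get_peft_kwargs logic):

```python
import warnings


def pick_alpha_pattern(network_alpha_dict=None, config=None):
    # Hypothetical helper: a serialized config's alpha_pattern takes precedence,
    # otherwise fall back to the user-provided network_alpha_dict.
    if config is not None and config.get("alpha_pattern"):
        if network_alpha_dict:
            warnings.warn(
                "Both config.alpha_pattern and network_alpha_dict were given; ignoring network_alpha_dict."
            )
        return dict(config["alpha_pattern"])
    return dict(network_alpha_dict or {})


print(pick_alpha_pattern(network_alpha_dict={"to_k": 4}))  # falls back to network_alpha_dict
print(pick_alpha_pattern({"to_k": 4}, {"alpha_pattern": {"to_k": 8}}))  # config wins and warns
```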
alpha_pattern is taken from config if config.alpha_pattern is not None. If it is None, network_alpha_dict should be used.
This is not what I am doing. That is because network_alpha_dict can only be provided for non-diffusers checkpoints, and for those checkpoints the metadata won't have a PEFT config.
Long story short, rank_pattern and alpha_pattern from config (if found) will simply be ignored for now.
Thanks for clarifying, so it is assumed that rank_pattern/alpha_pattern are never passed when config is passed, and vice versa. In that case, this could be checked and an error raised, or at least a comment added, WDYT?
Will add a comment. For now, a warning like we are doing now would suffice.
Added a comment in 178a459.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Thanks for the suggestions, @BenjaminBossan! I have added the docstrings. LMK what you think.
For now we've set the default alpha to equal the rank, since I was a bit confused about what the alpha parameter is actually for. As it turns out, some trainers implement it a bit differently, e.g. permanently scaling the weights before exporting/saving. But rank=alpha works for us for now; I'll implement this once it's merged in.
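For context, the usual LoRA convention scales the update by alpha / rank, so alpha == rank gives an effective scale of 1.0. A toy sketch of that, independent of any particular trainer:

```python
import torch

# Standard LoRA update: delta_W = (alpha / rank) * (B @ A).
# With alpha == rank, the scaling factor is 1.0, i.e. the low-rank weights are used as-is.
rank, alpha, d_out, d_in = 4, 4, 32, 16
A = torch.randn(rank, d_in)
B = torch.randn(d_out, rank)
scaling = alpha / rank
delta_w = scaling * (B @ A)
print(delta_w.shape, scaling)  # torch.Size([32, 16]) 1.0
```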
Right, thanks!
if file_extension == SAFETENSORS_FILE_EXTENSION:
    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata()
Will we have any way or desire to warn users loading a LoRA from a .pth file that we can't scale it?
We currently error out when trying to save metadata to a pth file (or more generally, when use_safetensors is False).
Good! I never use them, but I was concerned for others.
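Roughly the guard being described, as a sketch (the helper and the metadata key name are hypothetical, not the exact diffusers code):

```python
import json

import torch
from safetensors.torch import save_file


def save_lora_with_metadata(state_dict, path, lora_adapter_metadata=None, use_safetensors=True):
    # Hypothetical helper: metadata only survives in safetensors files, so error out
    # instead of silently dropping it for .pth/.bin checkpoints.
    if not use_safetensors:
        if lora_adapter_metadata is not None:
            raise ValueError("LoRA adapter metadata can only be saved when `use_safetensors=True`.")
        torch.save(state_dict, path)
        return
    metadata = {"format": "pt"}
    if lora_adapter_metadata is not None:
        # safetensors metadata values must be strings, hence the json.dumps.
        metadata["lora_adapter_config"] = json.dumps(lora_adapter_metadata, sort_keys=True)
    save_file(state_dict, path, metadata=metadata)


save_lora_with_metadata(
    {"lora_A": torch.randn(4, 16)}, "adapter.safetensors", lora_adapter_metadata={"r": 4, "lora_alpha": 4}
)
```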
BenjaminBossan
left a comment
Thanks for clarifying my questions about get_peft_kwargs. LGTM
bghira
left a comment
Haven't tested it, but it looks like it does the job.
Closing this PR for now in light of #9236
What does this PR do?
To avoid bugs like:
In my view, this is a non-breaking change; otherwise, the training tests would have failed because of unpacking mismatches.
I attempted this a while back in #6135, but we had to close it. Now its necessity is evident.
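For illustration, a rough sketch of the loading-side round trip this enables (the metadata key and the helper name are placeholders, not the final API):

```python
import json

from safetensors import safe_open


def read_lora_adapter_metadata(model_file):
    # Placeholder helper: pull the serialized LoRA config back out of a safetensors
    # checkpoint so lora_alpha (and friends) no longer have to be guessed or user-provided.
    with safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    raw = metadata.get("lora_adapter_config")
    return json.loads(raw) if raw is not None else None


config = read_lora_adapter_metadata("adapter.safetensors")
if config is not None:
    print(config["r"], config["lora_alpha"])
```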