[LoRA] use the PyTorch classes wherever needed and start depcrecation cycles #7204
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
yiyixuxu
left a comment
Ohh great! Thanks so much
We can deprecate the scale argument everywhere now, too, right?
e.g., all the attention processors
@yiyixuxu up for another review. Rigorous review appreciated!
Thanks!
I have two main pieces of feedback:
- let's deprecate the `scale` argument by removing it from the signature (see here #7204 (comment))
- I see that you deprecated the `scale` argument from some more public classes, but for some less public classes, you simply removed it or silently ignored it - maybe we should just deprecate it everywhere and remove them all together later? e.g. #7204 (comment) and #7204 (comment)
cc @BenjaminBossan - we would appreciate it if you could give a review as well
@yiyixuxu addressed all your comments. They were very, very helpful! Thank you!
BenjaminBossan
left a comment
Very thorough PR, thanks a lot Sayak. This LGTM overall.
Just a suggestion for the deprecation. Currently, the message is:
> Use of `scale` is deprecated. Please remove the argument
As a user, I might not know what to do with that info: Is this feature removed completely or can I still use it, but have to do it differently? Also, I might get the impression that I can still pass scale and it works, it's just deprecated, when in fact the argument doesn't do anything, right? Perhaps the message could be clarified.
Moreover, if we already have an idea in which diffusers version this will be removed (hence raise an error), it could be added to the warning. On top, we could add a comment like # TODO remove argument in diffusers X.Y to make it more likely that this will indeed be cleaned up when this version is released.
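As an illustration of this suggestion (a sketch, not code from the PR), such a message could name a removal target and carry a cleanup TODO via the `deprecate` helper in `diffusers.utils`. The 1.0.0 version and the exact wording below are assumptions.

from diffusers.utils import deprecate

def forward(self, hidden_states, *args, **kwargs):
    if len(args) > 0 or kwargs.get("scale", None) is not None:
        # TODO(diffusers 1.0.0): remove this check together with the deprecated `scale` argument.
        deprecation_message = (
            "The `scale` argument is deprecated and will be ignored. Please remove it. "
            "To control the LoRA scale, pass it at the pipeline level instead, e.g. via "
            "`cross_attention_kwargs={'scale': ...}`."
        )
        deprecate("scale", "1.0.0", deprecation_message)
    ...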
Thanks, Benjamin!
Very good point. I clarified that as much as I could.
Regarding the error message:
I think it's almost too detailed; users will not normally pass the argument directly to the
Can we also add a sentence on how to control the scale instead?
Cool, I didn't know 👍
How about?
Yes, that sounds good, as it clarifies to the user what they need to do.
yiyixuxu
left a comment
oh thanks!
I did another round of review:
- I left a question about the deprecation message, and I think we should use the same message everywhere (I saw you updated it in some places but not others)
- let's add a warning everywhere when the `scale` passed via `cross_attention_kwargs` is ignored - we can remove all these warnings together at the same time in the future.
younesbelkada
left a comment
Thanks so much! IMO all good on the PEFT end! Great work @sayakpaul!
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
    output_states = ()

    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
):
    output_states = ()

    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    num_frames: int = 1,
) -> torch.FloatTensor:
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    num_frames: int = 1,
) -> torch.FloatTensor:
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if cross_attention_kwargs is not None:
    if cross_attention_kwargs.get("scale", None) is not None:
        logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Can anyone tell me how to set a scale for LoRA if I can't use
@panxiaoguang scale corresponds to
Apart from what @younesbelkada mentioned (applies to "training" only), you can definitely use
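The replies above are truncated in this view; as a hedged sketch of what they point to, the LoRA scale can still be controlled at the pipeline level with the `peft` backend. The LoRA repo id below is a placeholder.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load a LoRA checkpoint (placeholder repo id) under an adapter name.
pipe.load_lora_weights("your-username/your-lora-checkpoint", adapter_name="my_lora")

# Option 1: pass the scale at call time via cross_attention_kwargs.
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]

# Option 2: set the adapter weight explicitly through the PEFT integration.
pipe.set_adapters(["my_lora"], adapter_weights=[0.7])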
What does this PR do?
Since we have shifted to the `peft` backend for all things LoRA, there's no need for us to use the `LoRACompatible*` classes now. We should also start the deprecation cycles for `LoRALinearLayer` and `LoRAConv2dLayer`. This PR does that.
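For context, a minimal sketch (assumptions: the message wording, the 1.0.0 removal target, and the simplified signature) of what starting such a deprecation cycle can look like for one of these classes: the layer keeps working, but instantiating it emits a deprecation warning pointing users to the `peft` backend.

import torch
import torch.nn as nn
from diffusers.utils import deprecate


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        # Emit the deprecation warning as soon as the class is instantiated.
        deprecate(
            "LoRALinearLayer",
            "1.0.0",
            "Use of `LoRALinearLayer` is deprecated. Please switch to the PEFT backend by installing `peft`.",
        )
        # Standard LoRA decomposition: a down-projection to `rank` followed by an up-projection.
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(hidden_states))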