[LoRA] fix cross_attention_kwargs problems and tighten tests #7388

Merged: 10 commits merged into main from debug-lora-scale-issue on Mar 19, 2024

Conversation

sayakpaul (Member)

What does this PR do?

First of all, I would like to apologize for not being rigorous enough with #7338. That change was actually breaking the following test:

RUN_SLOW=1 pytest tests/lora/test_lora_layers_peft.py::StableDiffusionLoRATests::test_integration_logits_with_scale

This is because pop() permanently removes the requested key from the underlying dictionary on the first call, so every subsequent call falls back to the default value. Since the unet within a DiffusionPipeline is called iteratively (once per denoising step), this leads to a lot of unexpected consequences, and the above-mentioned test fails. Here are the lora_scale values:

lora scale: 0.5
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0
lora scale: 1.0

Notice how it defaults to 1.0 after the first denoising step.
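To make the mechanics concrete, here is a minimal, hypothetical sketch (not the actual diffusers code) of the failure mode:

def unet_forward(cross_attention_kwargs):
    # pop() deletes "scale" from the caller's dict on the first call, so every
    # later call falls back to the default of 1.0.
    lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    print(f"lora scale: {lora_scale}")

cross_attention_kwargs = {"scale": 0.5}
for _ in range(4):  # the pipeline calls the UNet once per denoising step
    unet_forward(cross_attention_kwargs)
# prints 0.5 once, then 1.0 for every subsequent step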

A simple solution is to create a shallow copy of cross_attention_kwargs so that the original one is left untouched. This is what this PR does.
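Sketched against the toy example above (again, not the exact diffusers diff), the fix boils down to copying the dict before popping from it:

def unet_forward(cross_attention_kwargs):
    if cross_attention_kwargs is not None:
        # shallow copy: the caller's dict keeps its "scale" entry for the next step
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    print(f"lora scale: {lora_scale}")  # now prints 0.5 on every denoising step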

Additionally, you may wonder why the following set of tests passes:

pytest tests/lora/test_lora_layers_peft.py -k "test_simple_inference_with_text_unet_lora_and_scale"

My best guess is that we use too few num_inference_steps to validate things. To check whether my hunch was right, I increased num_inference_steps to 5 here and ran these tests WITHOUT the changes introduced in this PR (i.e., without the shallow copy). All of those tests failed. With the changes, they pass.
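As a self-contained illustration of why a low step count hides the bug (same toy pop() behavior as above, not the real test code):

kwargs = {"scale": 0.5}
scales = [kwargs.pop("scale", 1.0) for _ in range(1)]  # 1 step  -> [0.5], looks fine
kwargs = {"scale": 0.5}
scales = [kwargs.pop("scale", 1.0) for _ in range(5)]  # 5 steps -> [0.5, 1.0, 1.0, 1.0, 1.0]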

Once this PR is merged, I will take care of making another patch release.

Once again, I am genuinely sorry for the oversight on my end.

@sayakpaul (Member, Author)

Cc: @younesbelkada for visibility.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member, Author)

I will also wait for @BenjaminBossan to approve it, and then I will proceed.

@younesbelkada (Contributor) left a comment:

Nice catch! Thanks! One could also use get() to avoid copying the kwargs on each forward pass!

@sayakpaul (Member, Author)

The problem with get() is that the scale value gets propagated to the internal layers of the UNet, causing unnecessary warnings. This would be confusing for users. Let me know if that makes sense.
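(For context, a minimal illustration of the trade-off, hypothetical and not the actual UNet code: get() reads the value without removing it, so the key keeps travelling down with the kwargs.)

cross_attention_kwargs = {"scale": 0.5}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
# "scale" is still present here, so the dict that is forwarded to the attention
# processors still carries it, and they warn about the unexpected argument.
print(cross_attention_kwargs)  # {'scale': 0.5}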

@younesbelkada (Contributor)

Ok, makes sense! Thanks for explaining!

@BenjaminBossan (Member) left a comment:

Thanks for fixing this bug; I think the copy solution is solid.

@sayakpaul merged commit b09a2aa into main on Mar 19, 2024
17 checks passed
@sayakpaul deleted the debug-lora-scale-issue branch on March 19, 2024 at 12:23
sayakpaul added a commit that referenced this pull request Mar 20, 2024
* debugging

* let's see the numbers

* let's see the numbers

* let's see the numbers

* restrict tolerance.

* increase inference steps.

* shallow copy of cross_attention_kwargs

* remove print