-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Fix ControlNetModel.from_unet do not load add_embedding.state_dict #7173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
yiyixuxu
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh thanks!
|
has bug this been affecting controlnet training? cc @sayakpaul here |
|
Could have affected, but I don't have any empirical evidence of that. |
sayakpaul
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks very nice.
I wonder why can't we do something like:
controlnet.load_state_dict(unet.state_dict(), strict=False)
I've adjusted it to match the existing code, but your code seems to be better. |
Do you want to give it a try? We can revisit later as well if that's what you prefer. |
|
The code has been modified and confirmed to load correctly in SDXL and SD1.5. |
|
@yiyixuxu WDYT? |
not a fan of this because I think warnings will be confusing |
|
strict=False doesn’t lead to warnings. This way, we don’t run into the issue this PR aims to tackle. Until and unless there’s some specific init scheme the ControlNet needs to follow when |
|
the code was a lot more readable before because you could immediately understand Controlnet's structure just by reading through these few lines of code. |
|
I don’t think it’s a necessary case because |
|
thanks for your feedback! In general, I'm not a fan of @laksjdjf can you match the existing code and only fix the bug in this PR? thanks. Will merge once the tests pass |
|
What do you think of this code? for name, module in unet.named_children():
if hasattr(controlnet, name):
getattr(controlnet, name).load_state_dict(module.state_dict()) |
|
I don't see any reason to make this change. I prefer to be explicit when the trade-off isn't too big (e.g. when you need to write tons of code otherwise, not the case here) |
d349af4 to
99734a3
Compare
|
@laksjdjf any reason for closing? |
I made a mistake with my GitHub operation, so may I recreate the PR? |
|
Yes sure. Please tag @yiyixuxu and me in the new PR. Many thanks for your patience and contributions. |
What does this PR do?
The weights of
add_embeddingwas not loaded byfrom_unet.I checked it with the following code.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.