@laksjdjf
Contributor

@laksjdjf laksjdjf commented Mar 1, 2024

What does this PR do?

The weights of add_embedding were not loaded by from_unet.
I checked it with the following code.

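from diffusers import UNet2DConditionModel, ControlNetModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet)

# Compare every parameter the two models share; a nonzero difference means it was not copied.
unet_state_dict = unet.state_dict()  # build the state dict once instead of on every iteration
for k, v in controlnet.state_dict().items():
    if k in unet_state_dict:
        diff = (unet_state_dict[k] - v).abs().mean().item()
        if diff != 0:
            print(k, diff)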
>add_embedding.linear_1.weight 0.016456469893455505
>add_embedding.linear_1.bias 0.017061714082956314
>add_embedding.linear_2.weight 0.017966963350772858
>add_embedding.linear_2.bias 0.01857309229671955
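
The same kind of check can be tried without downloading the SDXL weights. The sketch below is only an illustration: `TinySource`, `TinyTarget`, and `param_diffs` are hypothetical names, and the toy modules merely stand in for the UNet and ControlNet. It copies one submodule, deliberately skips `add_embedding`, and lists which shared parameters still differ.

```python
import torch
from torch import nn


class TinySource(nn.Module):
    """Hypothetical stand-in for the UNet."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Linear(4, 4)
        self.add_embedding = nn.Linear(4, 4)


class TinyTarget(nn.Module):
    """Hypothetical stand-in for the ControlNet."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Linear(4, 4)
        self.add_embedding = nn.Linear(4, 4)
        self.cond_embedding = nn.Linear(4, 4)  # exists only on the target


def param_diffs(source, target):
    """Return {key: mean absolute difference} for shared keys that differ."""
    source_sd = source.state_dict()
    diffs = {}
    for key, value in target.state_dict().items():
        if key in source_sd and source_sd[key].shape == value.shape:
            diff = (source_sd[key] - value).abs().mean().item()
            if diff != 0:
                diffs[key] = diff
    return diffs


torch.manual_seed(0)
source, target = TinySource(), TinyTarget()

# Copy only conv_in, mimicking an initializer that forgets add_embedding.
target.conv_in.load_state_dict(source.conv_in.state_dict())

diffs = param_diffs(source, target)
print(sorted(diffs))
```

Only the `add_embedding` keys should be reported here, since `conv_in` was copied and `cond_embedding` has no counterpart on the source.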

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@yiyixuxu yiyixuxu left a comment

oh thanks!

@yiyixuxu
Collaborator

yiyixuxu commented Mar 3, 2024

has this bug been affecting controlnet training?

cc @sayakpaul here

@sayakpaul
Member

It could have affected training, but I don't have any empirical evidence of that.

Member

@sayakpaul sayakpaul left a comment

Looks very nice.

I wonder why we can't do something like:

controlnet.load_state_dict(unet.state_dict(), strict=False)

@laksjdjf
Contributor Author

laksjdjf commented Mar 4, 2024

Looks very nice.

I wonder why we can't do something like:

controlnet.load_state_dict(unet.state_dict(), strict=False)

I've adjusted it to match the existing code, but your code seems to be better.

@sayakpaul
Member

sayakpaul commented Mar 4, 2024

I've adjusted it to match the existing code, but your code seems to be better.

Do you want to give it a try? We can revisit later as well if that's what you prefer.

@laksjdjf
Contributor Author

laksjdjf commented Mar 4, 2024

The code has been modified and confirmed to load correctly in SDXL and SD1.5.

@sayakpaul
Member

@yiyixuxu WDYT?

@yiyixuxu
Collaborator

yiyixuxu commented Mar 9, 2024

controlnet.load_state_dict(unet.state_dict(), strict=False)

not a fan of this because I think warnings will be confusing

@sayakpaul
Member

strict=False doesn’t lead to warnings. This way, we don’t run into the issue this PR aims to tackle. Unless there’s some specific init scheme the ControlNet needs to follow when from_unet is used, I think this is a simpler and more elegant solution.
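
To make the `strict=False` behaviour concrete, here is a minimal sketch on toy modules. `TinyUNet` and `TinyControlNet` are hypothetical stand-ins, not the diffusers classes. It records any Python warnings raised during the load and prints the key lists that `load_state_dict` returns.

```python
import warnings

import torch
from torch import nn


class TinyUNet(nn.Module):
    """Hypothetical stand-in for the UNet."""

    def __init__(self):
        super().__init__()
        self.add_embedding = nn.Linear(4, 4)
        self.up_blocks = nn.Linear(4, 4)  # has no counterpart on the ControlNet stand-in


class TinyControlNet(nn.Module):
    """Hypothetical stand-in for the ControlNet."""

    def __init__(self):
        super().__init__()
        self.add_embedding = nn.Linear(4, 4)
        self.controlnet_cond_embedding = nn.Linear(4, 4)  # has no counterpart on the UNet stand-in


unet, controlnet = TinyUNet(), TinyControlNet()

# Record every warning emitted while loading with strict=False.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = controlnet.load_state_dict(unet.state_dict(), strict=False)

print("warnings:", len(caught))
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
print("copied:", torch.equal(controlnet.add_embedding.weight, unet.add_embedding.weight))
```

Non-matching keys are reported only through the return value, so nothing is shown unless the caller inspects it. A shape mismatch on a shared key still raises a `RuntimeError`, even with `strict=False`.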

@yiyixuxu
Collaborator

yiyixuxu commented Mar 9, 2024

the code was a lot more readable before because you could immediately understand Controlnet's structure just by reading through these few lines of code.

@sayakpaul
Member

I don’t think that’s necessary, because from_unet is used for initialisation. It’s not meant to help users understand the ControlNet blocks.

@yiyixuxu
Collaborator

yiyixuxu commented Mar 9, 2024

thanks for your feedback! In general, I'm not a fan of strict=False - I think it's a bit too hacky and less readable.

@laksjdjf can you match the existing code and only fix the bug in this PR? thanks. Will merge once the tests pass

@laksjdjf
Contributor Author

@yiyixuxu

What do you think of this code?

for name, module in unet.named_children():
    if hasattr(controlnet, name):
        getattr(controlnet, name).load_state_dict(module.state_dict())
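
For reference, a minimal sketch of how this loop behaves on toy modules. `TinyUNet` and `TinyControlNet` are hypothetical stand-ins, not the diffusers classes. It records which child modules get copied.

```python
import torch
from torch import nn


class TinyUNet(nn.Module):
    """Hypothetical stand-in for the UNet."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Linear(4, 4)
        self.add_embedding = nn.Linear(4, 4)
        self.conv_out = nn.Linear(4, 4)  # has no counterpart on the ControlNet stand-in


class TinyControlNet(nn.Module):
    """Hypothetical stand-in for the ControlNet."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Linear(4, 4)
        self.add_embedding = nn.Linear(4, 4)
        self.controlnet_mid_block = nn.Linear(4, 4)  # ControlNet-only module


unet, controlnet = TinyUNet(), TinyControlNet()

copied = []
for name, module in unet.named_children():
    if hasattr(controlnet, name):
        # Each per-child load is strict, so a mismatch inside a shared child raises.
        getattr(controlnet, name).load_state_dict(module.state_dict())
        copied.append(name)

print(copied)
```

Because `named_children()` yields modules in registration order, only the children that both models share are copied. One caveat under this sketch's assumptions: if an attribute exists on the target but is `None`, the `load_state_dict` call would fail, so a real implementation would probably need a guard.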

@yiyixuxu
Collaborator

yiyixuxu commented Mar 10, 2024

@laksjdjf

I don't see any reason to make this change. I prefer to be explicit unless the trade-off is too big (e.g. when you would otherwise need to write tons of code, which is not the case here).

@laksjdjf laksjdjf closed this Mar 11, 2024
@laksjdjf laksjdjf deleted the laksjdjf-patch-1 branch March 11, 2024 01:20
@sayakpaul
Member

@laksjdjf any reason for closing?

@laksjdjf
Contributor Author

@laksjdjf any reason for closing?

I made a mistake while using GitHub and closed this by accident. May I recreate the PR?

@sayakpaul
Member

Yes sure. Please tag @yiyixuxu and me in the new PR. Many thanks for your patience and contributions.
