
[Fix Issue #1197 since 2022] Support pre-trained openai/guided-diffusion (ADM) with minimal code change #6730

Open · wants to merge 3 commits into main

Conversation

@tongdaxu (Contributor) commented Jan 27, 2024:

What does this PR do?

TL;DR: This PR adds support for the openai/guided-diffusion (ADM) pre-trained models in diffusers, covering the family of checkpoints trained on ImageNet, LSUN Bedroom, LSUN Cat, LSUN Horse, and FFHQ.

The openai/guided-diffusion pre-trained models (ADM, unconditional 256x256) are widely used by the academic community (e.g. DPS, ICLR 2023: https://github.com/DPS2022/diffusion-posterior-sampling; FreeDoM, ICCV 2023: https://github.com/vvictoryuki/FreeDoM?tab=readme-ov-file), but they are not supported in huggingface/diffusers.

This issue was raised as early as 2022 (#1197) but was left unsolved, as it is indeed quite complicated.

I keep the changes to UNet2DModel as minimal as possible to make this work.

The changes include:

  • Interface of UNet2DModel: a new argument "attention_legacy_order".
  • Two necessary building blocks:
    • ADM's time_proj: theoretically it is the same as the diffusers implementation, but numerically they differ, and replacing one with the other breaks the model (see the sketch after this list).
    • An attention_legacy_order mode in class Attention: the legacy ordering is necessary; using the diffusers attention directly breaks the model.
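For reference, a minimal sketch of ADM's sinusoidal timestep embedding, paraphrased from openai/guided-diffusion (illustrative only, not the code added in this PR):

import math
import torch

def adm_timestep_embedding(timesteps, dim, max_period=10000):
    # Sinusoidal timestep embedding in the ADM convention: cos terms first, then sin,
    # with frequencies exp(-log(max_period) * i / half) for i in [0, half).
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd dims with a zero column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

diffusers' Timesteps can be configured to produce a very similar embedding (flip_sin_to_cos, downscale_freq_shift), but, as noted above, the two implementations are not numerically interchangeable for these checkpoints, which is why this PR keeps ADM's time_proj.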

I have been very careful not to break any existing code and to keep the new code as short as possible.

I have provided a script to convert pre-trained openai/guided-diffusion checkpoints into Hugging Face-compatible models: https://github.com/tongdaxu/diffusers/blob/main/scripts/convert_adm_to_diffusers.py.

I have also provided my converted models with configs. The conversions have a mean absolute error of ~5e-5 and a relative absolute error of ~6e-5 when the input noise is the same; as the error is minimal, the model and the conversion are correct. The complete list of converted models is:

Now we can sample from the pre-trained openai/guided-diffusion models using diffusers, out of the box:

from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_imagenet_256x256_unconditional").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

And the result is as good as the original openai/guided-diffusion model:

  • Sample in diffusers with the converted model: (image: sample)
  • Sample in openai/guided-diffusion with the original model: (image: sample_adm)

Before submitting

Who can review?

@patrickvonplaten @yiyixuxu and @sayakpaul

@tongdaxu tongdaxu changed the title [Community] Support pre-trained openai/guided-diffusion (ADM) with minimal code change [Fix Issue #1197 since 2022] Support pre-trained openai/guided-diffusion (ADM) with minimal code change Jan 27, 2024
@tongdaxu (Contributor, Author) commented Jan 28, 2024:

Some other samples using the converted model with diffusers:

  • Samples from LSUN cat model
from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_lsun_cat_256x256").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

generated_image_cat

  • Samples from FFHQ model
from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_ffhq_256x256").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

generated_image_ffhq

  • Samples from LSUN horse model
from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_lsun_horse_256x256").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

generated_image_horse

  • Samples from LSUN bedroom model
from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_lsun_bedroom_256x256").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

generated_image_bedroom

@sayakpaul (Member) commented:
Thanks very much for your work on this.

I agree that ADM is still very much used by the academic community but probably doesn't have a lot of real-world significance because of the lower quality. On the other hand, we do support Consistency Models as well as the original DDPM and DDIM models to respect the literature.

So, given the above point and also considering the minimal changes introduced in this PR, I'd be supportive of adding it. My only major feedback would be to try to not use legacy attention blocks if possible.

@patrickvonplaten @yiyixuxu WDYT here?

@tongdaxu (Contributor, Author) commented:

> So, given the above point and also considering the minimal changes introduced in this PR, I'd be supportive of adding it. My only major feedback would be to try to not use legacy attention blocks if possible.

The problem here is that all of the official pre-trained ADM models from OpenAI use the legacy attention order, so I really have no choice but to use it. I have tried using the diffusers attention directly, but the model produces garbage images like this one (it is supposed to be a bedroom):

generated_image_bedroom_bad

@sayakpaul (Member) commented:

> The problem here is that all of the official pre-trained ADM models from OpenAI use the legacy attention order, so I really have no choice but to use it.

Can we try to maybe find a way to port the legacy attention to the one that's used now?

@tongdaxu (Contributor, Author) commented:

> Can we try to maybe find a way to port the legacy attention to the one that's used now?

Sorry, I did not quite get what you mean by "port". Did you mean creating a separate class for legacy attention and selecting it with an argument like attention_type?

@sayakpaul (Member) commented:

See my comment here https://github.com/huggingface/diffusers/pull/6730/files#r1468858192

@tongdaxu (Contributor, Author) commented:
> See my comment here https://github.com/huggingface/diffusers/pull/6730/files#r1468858192

In fact, the part you refer to is only about model conversion, and I have already handled it by calling the code in https://github.com/tongdaxu/diffusers/blob/main/scripts/convert_consistency_to_diffusers.py#L143. In this way we can indeed unify the model weights of openai/guided-diffusion and diffusers, but it has nothing to do with the legacy vs. non-legacy attention order; it is purely about the parameterization of the linear layers.

What cannot be avoided, however, is the runtime difference between legacy and non-legacy attention. The "qkv, q, k, v" you are referring to are model weights; the "qkv, q, k, v" I am referring to are activation tensors. They are different things with different shapes.
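To make the weight-vs-activation distinction concrete, here is a tiny illustration (shapes and variable names are made up for the example, not taken from the PR): the fused qkv projection weight is identical under both attention variants, but the activation tensor it produces is grouped differently at runtime, so the resulting per-head q/k/v differ.

import torch

batch, heads, head_dim, tokens = 2, 4, 8, 16
channels = heads * head_dim

qkv_weight = torch.randn(3 * channels, channels)   # shared model weight (1x1 projection)
x = torch.randn(batch, channels, tokens)           # input activations
qkv = torch.einsum("oc,bct->bot", qkv_weight, x)   # fused activation, shape (B, 3C, T)

# Legacy (ADM) runtime split: view heads first, then slice q/k/v per head.
q_legacy = qkv.reshape(batch * heads, 3 * head_dim, tokens).split(head_dim, dim=1)[0]

# Non-legacy runtime split: chunk into q/k/v first, then fold heads into the batch.
q_new = qkv.chunk(3, dim=1)[0].reshape(batch * heads, head_dim, tokens)

print(torch.allclose(q_legacy, q_new))  # False in general: same weights, different grouping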

@tongdaxu (Contributor, Author) commented:
> See my comment here https://github.com/huggingface/diffusers/pull/6730/files#r1468858192

In openai/guided-diffusion, the normal attention and the legacy attention are implemented as separate classes.

The two attentions use exactly the same model weights, and the classes themselves have no parameters at all, so this cannot be handled by clever tricks in model conversion. It has to be solved at runtime.
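For concreteness, a condensed sketch of the two runtime orderings, paraphrased from the QKVAttentionLegacy and QKVAttention classes in openai/guided-diffusion (float casts and assertions omitted):

import math
import torch

def attention_legacy_order(qkv, n_heads):
    # QKVAttentionLegacy: reshape to (B * heads, 3 * head_dim, T), then slice q/k/v.
    bs, width, length = qkv.shape
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
    return torch.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

def attention_new_order(qkv, n_heads):
    # QKVAttention: chunk into q/k/v of full width first, then fold heads into the batch.
    bs, width, length = qkv.shape
    ch = width // (3 * n_heads)
    q, k, v = (t.reshape(bs * n_heads, ch, length) for t in qkv.chunk(3, dim=1))
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
    return torch.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

Neither function carries parameters, which is why, as argued above, the difference has to be handled at runtime rather than in the conversion script.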

@tongdaxu (Contributor, Author) commented:
Hi @sayakpaul, any further comments?

        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
-           attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+           attn_groups = None if resnet_time_scale_shift == "spatial" else resnet_groups
@yiyixuxu (Collaborator) commented on this diff:

This will break for resnet_time_scale_shift = "scale_shift", no? Even though the docstring says it only accepts "default" and "spatial", I'm not sure that no model has been configured with resnet_time_scale_shift = "scale_shift".

I want to know your thoughts here @sayakpaul

@tongdaxu (Contributor, Author) replied:

The issue is that when resnet_time_scale_shift = "scale_shift", the original code produces attn_groups = None, which nullifies the argument resnet_groups. This behaviour of UNetMidBlock2D is inconsistent with the other up and down classes such as AttnDownBlock2D and ResnetDownsampleBlock2D. The incoming code produces attn_groups = resnet_groups, which makes the attention layer of the mid block consistent with the up and down blocks (see the small check below).

I am just following the docstring here, but I think it is better to have the mid-block behaviour consistent with the up and down blocks.
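Concretely, the two conditions differ only in the "scale_shift" case (a small self-contained check, with resnet_groups = 32 as an example value):

resnet_groups = 32
for mode in ("default", "scale_shift", "spatial"):
    old = resnet_groups if mode == "default" else None    # existing condition
    new = None if mode == "spatial" else resnet_groups    # condition proposed in this PR
    print(mode, old, new)
# prints:
#   default 32 32
#   scale_shift None 32   <- the only case that changes
#   spatial None None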

@sayakpaul (Member) replied:

Can we maybe come back with a better condition that doesn't leave any grounds for a potentially backward-breaking change? I agree with @yiyixuxu here.

@tongdaxu (Contributor, Author) replied:

> Can we maybe come back with a better condition that doesn't leave any grounds for a potentially backward-breaking change? I agree with @yiyixuxu here.

Cool, let me see if I can solve it some other way. Would you prefer the current solution, or a slightly uglier one like a new UNetMidBlock class? In fact, I can easily make it free of backward-breaking changes, but that would introduce additional, uglier blocks with more code.

@tongdaxu (Contributor, Author) commented Feb 3, 2024:

Hi, I have found a way to avoid breaking possible backward dependencies in class UNetMidBlock2D and have updated the PR. The change is still minimal, and it does not break anything.

I would love your advice, @yiyixuxu @sayakpaul.

"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
@sayakpaul (Member) commented on this diff:

Could you show me the codepath that becomes effective when resnet_time_scale_shift == "scale_shift"?

@tongdaxu (Contributor, Author) replied:

The convert_adm_to_diffusers.py script just calls the con_pt_to_diffuser function in convert_consistency_to_diffusers.py.

First, resnet_time_scale_shift == "scale_shift" is not a new option introduced by this PR.

Setting resnet_time_scale_shift == "scale_shift" passes the argument through UNet2DModel. From the init of UNet2DModel it is passed into get_down_block, UNetMidBlock2D and get_up_block, and further into each building block such as ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, AttnDownBlock2D and AttnUpBlock2D. The eventual effect of resnet_time_scale_shift == "scale_shift" is to set time_embedding_norm == "scale_shift" in class ResnetBlock2D. This option affects the shape of the resnet's time-embedding projection (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L283) and the behaviour of the time embedding (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L378).

It has no special effect on UNetMidBlock2D, as I circumvent that problem in https://github.com/tongdaxu/diffusers/blob/main/src/diffusers/models/unets/unet_2d.py#L200.

resnet_time_scale_shift == "scale_shift" is necessary in the model conversion script because the output of the resnet's time-embedding projection is doubled in size with resnet_time_scale_shift == "scale_shift".
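A condensed sketch (not the actual ResnetBlock2D code) of the effect described above: with time_embedding_norm == "scale_shift" the time-embedding projection emits 2 * out_channels, which are split into a scale and a shift applied after normalization, instead of a single additive embedding.

import torch
import torch.nn as nn

out_channels, temb_channels = 64, 512
scale_shift = True  # corresponds to resnet_time_scale_shift == "scale_shift"

# The projection width doubles in the "scale_shift" case, which is why the
# conversion script has to know about it when mapping checkpoint weights.
time_emb_proj = nn.Linear(temb_channels, 2 * out_channels if scale_shift else out_channels)
norm = nn.GroupNorm(32, out_channels)

hidden_states = torch.randn(1, out_channels, 32, 32)
temb = time_emb_proj(torch.randn(1, temb_channels))[:, :, None, None]

if scale_shift:
    scale, shift = torch.chunk(temb, 2, dim=1)
    hidden_states = norm(hidden_states) * (1 + scale) + shift  # ADM-style modulation
else:
    hidden_states = norm(hidden_states + temb)                 # "default": additive embedding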

@sayakpaul (Member) replied:

That makes sense, thanks!

@sayakpaul (Member) left a comment:

Looking really lean. I left a question.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@patrickvonplaten (Contributor) commented:

Let's add this model to the research_projects folder, no? It's a bit too outdated to be in core diffusers, I'd say (cc @yiyixuxu).

@yiyixuxu (Collaborator) commented:

@tongdaxu, can we move this to the research folder?

@tongdaxu (Contributor, Author) commented:

> @tongdaxu, can we move this to the research folder?

I am fine with that. What should I do to move this to the research folder?

@sayakpaul (Member) commented:

> I am fine with that. What should I do to move this to the research folder?

You can follow the structure of https://github.com/huggingface/diffusers/tree/main/examples/research_projects/controlnetxs as an example. Here is what you could consider:

Put all the conversion scripts, modeling files and pipeline files under one folder and make sure they work.

@kschwethelm commented:

Hi @tongdaxu, thank you for your great work!

I am having trouble generating nice images with your PR. I hope you can help :)

I installed the PR as follows:

  1. conda create -n newenv python=3.9
  2. conda activate newenv
  3. pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
  4. pip install git+https://github.com/huggingface/diffusers.git@refs/pull/6730/head

Then I ran:

from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("xutongda/adm_imagenet_256x256_unconditional").to("cuda")
image = generator().images[0]
image.save("generated_image.png")

My generated images look like this:

generated_image

Did I install the PR wrong or is there a bug?

@tongdaxu (Contributor, Author) commented:

> I am having trouble generating nice images with your PR. [...] Did I install the PR wrong or is there a bug?

I am out of the office until 17 Feb; would you like to try other models first? Also, the Hugging Face model hub entries have been updated partway through the commits. Are you using the latest models?

@tongdaxu (Contributor, Author) commented:

> I am having trouble generating nice images with your PR. [...] Did I install the PR wrong or is there a bug?

Sorry, I do not have access to a GPU right now, but the instructions in https://github.com/tongdaxu/InverseDiffusion should work. Would you like to give them a try?

@kschwethelm commented Feb 14, 2024:

Thank you for your quick response. Sadly, your instructions did not work either.

I tried all versions of your repository and different pretrained models, but I still get bad results, e.g., FFHQ:

generated_image

@tongdaxu (Contributor, Author) commented:

> I tried all versions of your repository and different pretrained models, but I still get bad results (e.g., FFHQ).

I believe it is a bug on my side; it could be the last force push. I will fix it ASAP.

@tongdaxu (Contributor, Author) commented:

> I tried all versions of your repository and different pretrained models, but I still get bad results (e.g., FFHQ).

Hi, I just ran a small test with tongdaxu@111eac1. It seems to be ok.

(screenshot)

I am not sure what is happening here, and I might need more time to figure it out once I am back in the office after 17 Feb.

I can only run small sanity checks for now. All I can say is that with the commit above and the ImageNet model (I checked the hash sum), sampling should be fine. I am not sure whether there could be a dependency problem (I am using torch 2.1.0); I will need to get back to the office for more testing.

Thanks for pointing it out, and thanks for your patience.

@kschwethelm commented:

Hi, thank you very much for your help! The problem was actually the PyTorch version. I tried torch 2.2.0 and now it works fine.

generated_image

@tongdaxu (Contributor, Author) commented:

> 4. pip install git+https://github.com/huggingface/diffusers.git@refs/pull/6730/head

Thank you @kschwethelm, that is very strange. I also found that it fails with torch 1.9 and works with torch 2.1.

I do not have a clue why it fails. Have you figured it out?

@tongdaxu (Contributor, Author) commented:

> Hi, thank you very much for your help! The problem was actually the PyTorch version. I tried torch 2.2.0 and now it works fine.

I can't remember adding any torch-version-sensitive code.

github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 10, 2024
@sayakpaul (Member) commented:

Not stale. @yiyixuxu WDYT?

@github-actions github-actions bot removed the stale Issues that haven't received updates label Mar 11, 2024
github-actions (bot) commented Apr 5, 2024:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 5, 2024
@jiangyuhangcn commented:

How can I fine-tune with the ADM pre-trained models?
