Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a ControlNet model & pipeline #2407

Merged
merged 135 commits into from Mar 2, 2023

Conversation

takuma104
Copy link
Contributor

@takuma104 takuma104 commented Feb 17, 2023

ControlNet by @lllyasviel is a neural network structure to control diffusion models by adding extra conditions. Discussed in #2331.

fig_diffusers_controlnet

Usage Example

TODO:

  • Conflict fixes
  • Add utility function to generate control embedding: controlnet_hinter
  • Docstring fixes
  • Unit test fixes and additions to acceptable
  • Test with other ControlNet models (probably subjective comparative evaluation with references): result report
    [ ] (time permitting) Investigate why pixel matches are not perfect

Discussion:

  1. How to provide users with functions to create control embedding? The reference implementation is highly dependent on other libraries, and porting it without thought will increase the number of dependent libraries for Diffusers. From my observations, the most common use case is not Canny Edge conversion, but rather inputting a similar line drawing image (for example), or generating an OpenPose compatible image (for example) and then inputting that image. In such use cases, it may be acceptable to prepare only a minimal conversion from the image.
    -- one idea: controlnet_hinter
  2. There is little difference between StableDiffusionPipeline and the (newly added) StableDiffusionControlNetPipeline. I think it would be preferable to change the behavior of StableDiffusionPipeline instead of creating a new StableDiffusionControlNetPipeline.
  3. How much test code is required? Diffusers does a lot of coverage testing, which is great, but I'm having a hard time physically doing this amount of coding. I would appreciate it if someone could help me.
  4. I currently have the converted models on my Huggingface hub, which I think is more appropriate for the @lllyasviel account. I would appreciate it if Huggingface team could support on release.

@williamberman @apolinario @xvjiarui

- copied convert_controlnet_to_diffusers.py from
convert_original_stable_diffusion_to_diffusers.py
- this makes Missking Key error on ControlNetModel
- init impl for StableDiffusionControlNetPipeline
- init impl for ControlNetModel
from create_unet_diffusers_config()

- add config: hint_channels
middle_block_out
- this makes missing key error on loading model
- copied from unet_2d_blocks.py as impl CrossAttnDownBlock2D,DownBlock2D
- this makes missing key error on loading model
and middle_block_out

- this makes no error message on model loading
- test_controlled_unet_inference passed
- test_modules_controllnet.py passed
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Feb 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@williamberman
Copy link
Contributor

phenomenal @takuma104! let's collaborate on this PR :)

@williamberman williamberman self-assigned this Feb 18, 2023
@takuma104
Copy link
Contributor Author

@williamberman Thanks! Let's brush up to better code together.

@patrickvonplaten
Copy link
Contributor

Alright, I think we can merge this PR now 🥳

After some testing, I would recommend to use the following settings for ControlNet:

# !pip install opencv-python transformers accelerate
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import numpy as np
import torch

import cv2
from PIL import Image

# download an image
image = load_image(
    "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
)
image = np.array(image)

# get canny image
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed
pipe.enable_xformers_memory_efficient_attention()

pipe.enable_model_cpu_offload()

# generate image
generator = torch.manual_seed(0)
image = pipe(
    "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
).images[0]

I've also updated the official example to this code snippet. The advantages:

  • We use the currently fastest sampler (UniPCMultistepScheduler) which usually only need 20-25 steps
  • We use smart model offloading which doesn't hurt speed performance really
  • We use memory efficient attention a.k.a x-formers.

=> This allows us to run the code in ~3 seconds on a V100 with max 4.5GB RAM (< 4GB VRAM). Think this should make the model very useable for the community!

Amazing work @takuma104 - it's super cool that you led the integration here and iterated with us on the design choices.


*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*

This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️ .
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@takuma104 - left a comment here for future readers for credit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten Thank you so much! It is a great honor!

@patrickvonplaten
Copy link
Contributor

@takuma104 , I've changed the official checkpoints to the ones that reside in the ControlNet's authors repo: https://huggingface.co/lllyasviel instead of using the ones of your namespace - I hope this is ok! I've made sure to mark down here: https://github.com/huggingface/diffusers/pull/2407/files#r1123177670 that you contributed the model.
I think it's better to just have the controlnet weights in the individual models to make it trivial to switch out different controlnets with SD-v1-5, e.g.:

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

# load new controlnet
pipe.controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-normal", torch_dtype=torch.float16)

Also, we'll make sure to link to your Twitter - ok to use this handle: https://twitter.com/takuma104 ?

Great job again on getting this done!

@patrickvonplaten patrickvonplaten merged commit 8dfff7c into huggingface:main Mar 2, 2023
@takuma104
Copy link
Contributor Author

@patrickvonplaten

I've changed the official checkpoints to the ones that reside in the ControlNet's authors repo: https://huggingface.co/lllyasviel instead of using the ones of your namespace - I hope this is ok!

Yes, that was my intention from the beginning. I'm glad it turned out as hoped. I think there is still a link to my Colab in the documentation (although it can be removed if necessary), so once that is rewritten, I plan to make my takuma104/control_sd_xxx private.

I think it's better to just have the controlnet weights in the individual models to make it trivial to switch out different controlnets with SD-v1-5, e.g.:

I agree. I think it's a very good method. It has become easier to understand than the previous method, and there is no unnecessary loading.

Also, we'll make sure to link to your Twitter - ok to use this handle: https://twitter.com/takuma104 ?

I'm no longer using that Twitter account, so if possible, I would like to link to my GitHub account instead. https://github.com/takuma104

@takuma104
Copy link
Contributor Author

takuma104 commented Mar 2, 2023

To the Diffusers team, thank you! I believe that Diffusers is one of the most high-quality libraries among OSS. I realized that your attention to detail is what supports it. It was a very inspiring experience. I would be happy to help with anything in the future.


pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe = pipe.to(torch_device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not move to CUDA when enabling CPU offload cc @williamberman

@@ -362,19 +362,18 @@ def test_seg(self):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/house_seg_out.npy"
)

assert np.abs(expected_image - image).max() < 1e-4
assert np.abs(expected_image - image).max() < 5e-3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1e-4 won't pass on CI

@takuma104
Copy link
Contributor Author

I have updated the Colab notebook for the merged version in main. It also serves as a simple test for the functionality, and everything works fine. It is based on the sample code at https://huggingface.co/lllyasviel/sd-controlnet-canny

@patrickvonplaten While I would like to keep Colab for the convenience of users, I think it's better for me to let it go. Could you please copy the Colab file and use it as a separate Colab notebook?
I had trouble installing human_pose with pip install (I think it's still a work in progress), so I have included controlnet_hinter, but it's okay to replace it with human_pose or others.

@patrickvonplaten
Copy link
Contributor

I've actually changed the libraries name and made it a pypi package:
https://github.com/patrickvonplaten/controlnet_aux

@yaqing01
Copy link

yaqing01 commented Mar 3, 2023

hi,

I have some troubles while doing backpropagation under controlnet (torch==1.13.1):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [24, 1280, 8, 8]], which is output 0 of AddBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

it's solved when I changed all the inplace operations in controlnet.py and unet_2d_condition.py from x += y to x = x + y. I didn't know whether it is a bug or not since I am new to pytorch.

@Suhail
Copy link

Suhail commented Mar 5, 2023

Does anyone know how to do the control net mode conversion correctly or can point me to where the right directory to dl the necessary files?

I tried doing:

python scripts/convert_original_stable_diffusion_to_diffusers.py --controlnet --checkpoint_path ../models/control_sd15_canny.pth --dump_path ../models/control_sd15_canny --device cpu --original_config_file ./cldm_v15.yaml

CleanShot 2023-03-05 at 13 10 10

It worked but I only got this:
CleanShot 2023-03-05 at 13 11 01

I am trying to attempt this: https://github.com/haofanwang/ControlNet-for-Diffusers#controlnet--inpainting but diffusers is already out of date with that README

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Mar 5, 2023

@Suhail I think if you pass --controlnet it will only save the controlnet element

cc @williamberman should know more

@Suhail
Copy link

Suhail commented Mar 5, 2023

Got it. Ok that somewhat solved my problem.

I was able to run this instead (without --controlnet):

python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path ../models/control_sd15_canny.pth --dump_path ../models/control_sd15_canny --device cpu --original_config_file ./cldm_v15.yaml

Getting this dir:
CleanShot 2023-03-05 at 14 26 51

However I am still met with this error:
CleanShot 2023-03-05 at 14 26 35

Full code here: https://github.com/haofanwang/ControlNet-for-Diffusers/blob/main/pipeline_stable_diffusion_controlnet_inpaint_img2img.py#L850

Ah I see something got rearchitected I guess:

Old code:
CleanShot 2023-03-05 at 14 46 28

New code:
CleanShot 2023-03-05 at 15 18 30

Updating that gets this error:
CleanShot 2023-03-05 at 15 07 41

I am not totally sure what in_channels to set here I guess.

@Suhail
Copy link

Suhail commented Mar 5, 2023

I got it working! Yay. I had a small bug with latent_mode_input in the wrong location:
CleanShot 2023-03-05 at 15 26 50

@williamberman
Copy link
Contributor

@Suhail https://github.com/haofanwang/ControlNet-for-Diffusers/ looks like it was based off of an older commit of this PR.

It might be easiest to work off the version currently in main and using the already converted controlnet models i.e. https://huggingface.co/lllyasviel/sd-controlnet-canny

I haven't looked deeply at the other repo but if there's a particular feature you want from it please let us know and I can see if we can add the feature directly in diffusers :)

@Suhail
Copy link

Suhail commented Mar 6, 2023

Hey @williamberman I am looking to do img2img with inpainting which I think diffusers doesn't quite support. Someone at HF encouraged me to look at the https://github.com/haofanwang/ControlNet-for-Diffusers/ repo as a result

@williamberman
Copy link
Contributor

Hey @williamberman I am looking to do img2img with inpainting which I think diffusers doesn't quite support. Someone at HF encouraged me to look at the https://github.com/haofanwang/ControlNet-for-Diffusers/ repo as a result

ah ok, not sure if we'll officially support that but I can throw together a community pipeline for you :)

@williamberman
Copy link
Contributor

@Suhail we have a PR here #2561 for the community pipelines

mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
* add scaffold
- copied convert_controlnet_to_diffusers.py from
convert_original_stable_diffusion_to_diffusers.py

* Add support to load ControlNet (WIP)
- this makes Missking Key error on ControlNetModel

* Update to convert ControlNet without error msg
- init impl for StableDiffusionControlNetPipeline
- init impl for ControlNetModel

* cleanup of commented out

* split create_controlnet_diffusers_config()
from create_unet_diffusers_config()

- add config: hint_channels

* Add input_hint_block, input_zero_conv and
middle_block_out
- this makes missing key error on loading model

* add unet_2d_blocks_controlnet.py
- copied from unet_2d_blocks.py as impl CrossAttnDownBlock2D,DownBlock2D
- this makes missing key error on loading model

* Add loading for input_hint_block, zero_convs
and middle_block_out

- this makes no error message on model loading

* Copy from UNet2DConditionalModel except __init__

* Add ultra primitive test for ControlNetModel
inference

* Support ControlNetModel inference
- without exceptions

* copy forward() from UNet2DConditionModel

* Impl ControlledUNet2DConditionModel inference
- test_controlled_unet_inference passed

* Frozen weight & biases for training

* Minimized version of ControlNet/ControlledUnet
- test_modules_controllnet.py passed

* make style

* Add support model loading for minimized ver

* Remove all previous version files

* from_pretrained and inference test passed

* copied from pipeline_stable_diffusion.py
except `__init__()`

* Impl pipeline, pixel match test (almost) passed.

* make style

* make fix-copies

* Fix to add import ControlNet blocks
for `make fix-copies`

* Remove einops dependency

* Support  np.ndarray, PIL.Image for controlnet_hint

* set default config file as lllyasviel's

* Add support grayscale (hw) numpy array

* Add and update docstrings

* add control_net.mdx

* add control_net.mdx to toctree

* Update copyright year

* Fix to add PIL.Image RGB->BGR conversion
- thanks @Mystfit

* make fix-copies

* add basic fast test for controlnet

* add slow test for controlnet/unet

* Ignore down/up_block len check on ControlNet

* add a copy from test_stable_diffusion.py

* Accept controlnet_hint is None

* merge pipeline_stable_diffusion.py diff

* Update class name to SDControlNetPipeline

* make style

* Baseline fast test almost passed (w long desc)

* still needs investigate.

Following didn't passed descriped in TODO comment:
- test_stable_diffusion_long_prompt
- test_stable_diffusion_no_safety_checker

Following didn't passed same as stable_diffusion_pipeline:
- test_attention_slicing_forward_pass
- test_inference_batch_single_identical
- test_xformers_attention_forwardGenerator_pass
these seems come from calc accuracy.

* Add note comment related vae_scale_factor

* add test_stable_diffusion_controlnet_ddim

* add assertion for vae_scale_factor != 8

* slow test of pipeline almost passed
Failed: test_stable_diffusion_pipeline_with_model_offloading
- ImportError: `enable_model_offload` requires `accelerate v0.17.0` or higher

but currently latest version == 0.16.0

* test_stable_diffusion_long_prompt passed

* test_stable_diffusion_no_safety_checker passed

- due to its model size, move to slow test

* remove PoC test files

* fix num_of_image, prompt length issue add add test

* add support List[PIL.Image] for controlnet_hint

* wip

* all slow test passed

* make style

* update for slow test

* RGB(PIL)->BGR(ctrlnet) conversion

* fixes

* remove manual num_images_per_prompt test

* add document

* add `image` argument docstring

* make style

* Add line to correct conversion

* add controlnet_conditioning_scale (aka control_scales
strength)

* rgb channel ordering by default

* image batching logic

* Add control image descriptions for each checkpoint

* Only save controlnet model in conversion script

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

typo

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add gerated image example

* a depth mask -> a depth map

* rename control_net.mdx to controlnet.mdx

* fix toc title

* add ControlNet abstruct and link

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: dqueue <dbyqin@gmail.com>

* remove controlnet constructor arguments re: @patrickvonplaten

* [integration tests] test canny

* test_canny fixes

* [integration tests] test_depth

* [integration tests] test_hed

* [integration tests] test_mlsd

* add channel order config to controlnet

* [integration tests] test normal

* [integration tests] test_openpose test_scribble

* change height and width to default to conditioning image

* [integration tests] test seg

* style

* test_depth fix

* [integration tests] size fixes

* [integration tests] cpu offloading

* style

* generalize controlnet embedding

* fix conversion script

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Style adapted to the documentation of pix2pix

* merge main by hand

* style

* [docs] controlling generation doc nits

* correct some things

* add: controlnetmodel to autodoc.

* finish docs

* finish

* finish 2

* correct images

* finish controlnet

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* uP

* upload model

* up

* up

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: dqueue <dbyqin@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
jackalcooper pushed a commit to siliconflow/onediff that referenced this pull request Mar 27, 2023
还有一些问题需要讨论和确定

1.  为支持Control Net需要的升级:
- huggingface官方的 [diffusers
0.14.0](https://github.com/huggingface/diffusers/releases/tag/v0.14.0)
支持了Control Net pipeline。所以我们要支持的话也需要安装0.14.0版本的diffusers `python3 -m pip
install "transformers>=4.26" "diffusers[torch]==0.14.0"`

- canny example 中为了检测Canny edge用到了 cv2,需要安装 `pip install
opencv-contrib-python`

2.  可能还需要测试升级diffusers版本后对其他pipeline的影响

3.  第一次提PR不知还有什么步骤需要完善

### Reference
-
https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/controlnet

- huggingface/diffusers#2407
@sayakpaul sayakpaul mentioned this pull request Apr 29, 2023
2 tasks
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add scaffold
- copied convert_controlnet_to_diffusers.py from
convert_original_stable_diffusion_to_diffusers.py

* Add support to load ControlNet (WIP)
- this makes Missking Key error on ControlNetModel

* Update to convert ControlNet without error msg
- init impl for StableDiffusionControlNetPipeline
- init impl for ControlNetModel

* cleanup of commented out

* split create_controlnet_diffusers_config()
from create_unet_diffusers_config()

- add config: hint_channels

* Add input_hint_block, input_zero_conv and
middle_block_out
- this makes missing key error on loading model

* add unet_2d_blocks_controlnet.py
- copied from unet_2d_blocks.py as impl CrossAttnDownBlock2D,DownBlock2D
- this makes missing key error on loading model

* Add loading for input_hint_block, zero_convs
and middle_block_out

- this makes no error message on model loading

* Copy from UNet2DConditionalModel except __init__

* Add ultra primitive test for ControlNetModel
inference

* Support ControlNetModel inference
- without exceptions

* copy forward() from UNet2DConditionModel

* Impl ControlledUNet2DConditionModel inference
- test_controlled_unet_inference passed

* Frozen weight & biases for training

* Minimized version of ControlNet/ControlledUnet
- test_modules_controllnet.py passed

* make style

* Add support model loading for minimized ver

* Remove all previous version files

* from_pretrained and inference test passed

* copied from pipeline_stable_diffusion.py
except `__init__()`

* Impl pipeline, pixel match test (almost) passed.

* make style

* make fix-copies

* Fix to add import ControlNet blocks
for `make fix-copies`

* Remove einops dependency

* Support  np.ndarray, PIL.Image for controlnet_hint

* set default config file as lllyasviel's

* Add support grayscale (hw) numpy array

* Add and update docstrings

* add control_net.mdx

* add control_net.mdx to toctree

* Update copyright year

* Fix to add PIL.Image RGB->BGR conversion
- thanks @Mystfit

* make fix-copies

* add basic fast test for controlnet

* add slow test for controlnet/unet

* Ignore down/up_block len check on ControlNet

* add a copy from test_stable_diffusion.py

* Accept controlnet_hint is None

* merge pipeline_stable_diffusion.py diff

* Update class name to SDControlNetPipeline

* make style

* Baseline fast test almost passed (w long desc)

* still needs investigate.

Following didn't passed descriped in TODO comment:
- test_stable_diffusion_long_prompt
- test_stable_diffusion_no_safety_checker

Following didn't passed same as stable_diffusion_pipeline:
- test_attention_slicing_forward_pass
- test_inference_batch_single_identical
- test_xformers_attention_forwardGenerator_pass
these seems come from calc accuracy.

* Add note comment related vae_scale_factor

* add test_stable_diffusion_controlnet_ddim

* add assertion for vae_scale_factor != 8

* slow test of pipeline almost passed
Failed: test_stable_diffusion_pipeline_with_model_offloading
- ImportError: `enable_model_offload` requires `accelerate v0.17.0` or higher

but currently latest version == 0.16.0

* test_stable_diffusion_long_prompt passed

* test_stable_diffusion_no_safety_checker passed

- due to its model size, move to slow test

* remove PoC test files

* fix num_of_image, prompt length issue add add test

* add support List[PIL.Image] for controlnet_hint

* wip

* all slow test passed

* make style

* update for slow test

* RGB(PIL)->BGR(ctrlnet) conversion

* fixes

* remove manual num_images_per_prompt test

* add document

* add `image` argument docstring

* make style

* Add line to correct conversion

* add controlnet_conditioning_scale (aka control_scales
strength)

* rgb channel ordering by default

* image batching logic

* Add control image descriptions for each checkpoint

* Only save controlnet model in conversion script

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

typo

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add gerated image example

* a depth mask -> a depth map

* rename control_net.mdx to controlnet.mdx

* fix toc title

* add ControlNet abstruct and link

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: dqueue <dbyqin@gmail.com>

* remove controlnet constructor arguments re: @patrickvonplaten

* [integration tests] test canny

* test_canny fixes

* [integration tests] test_depth

* [integration tests] test_hed

* [integration tests] test_mlsd

* add channel order config to controlnet

* [integration tests] test normal

* [integration tests] test_openpose test_scribble

* change height and width to default to conditioning image

* [integration tests] test seg

* style

* test_depth fix

* [integration tests] size fixes

* [integration tests] cpu offloading

* style

* generalize controlnet embedding

* fix conversion script

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Style adapted to the documentation of pix2pix

* merge main by hand

* style

* [docs] controlling generation doc nits

* correct some things

* add: controlnetmodel to autodoc.

* finish docs

* finish

* finish 2

* correct images

* finish controlnet

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* uP

* upload model

* up

* up

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: dqueue <dbyqin@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add scaffold
- copied convert_controlnet_to_diffusers.py from
convert_original_stable_diffusion_to_diffusers.py

* Add support to load ControlNet (WIP)
- this makes Missking Key error on ControlNetModel

* Update to convert ControlNet without error msg
- init impl for StableDiffusionControlNetPipeline
- init impl for ControlNetModel

* cleanup of commented out

* split create_controlnet_diffusers_config()
from create_unet_diffusers_config()

- add config: hint_channels

* Add input_hint_block, input_zero_conv and
middle_block_out
- this makes missing key error on loading model

* add unet_2d_blocks_controlnet.py
- copied from unet_2d_blocks.py as impl CrossAttnDownBlock2D,DownBlock2D
- this makes missing key error on loading model

* Add loading for input_hint_block, zero_convs
and middle_block_out

- this makes no error message on model loading

* Copy from UNet2DConditionalModel except __init__

* Add ultra primitive test for ControlNetModel
inference

* Support ControlNetModel inference
- without exceptions

* copy forward() from UNet2DConditionModel

* Impl ControlledUNet2DConditionModel inference
- test_controlled_unet_inference passed

* Frozen weight & biases for training

* Minimized version of ControlNet/ControlledUnet
- test_modules_controllnet.py passed

* make style

* Add support model loading for minimized ver

* Remove all previous version files

* from_pretrained and inference test passed

* copied from pipeline_stable_diffusion.py
except `__init__()`

* Impl pipeline, pixel match test (almost) passed.

* make style

* make fix-copies

* Fix to add import ControlNet blocks
for `make fix-copies`

* Remove einops dependency

* Support  np.ndarray, PIL.Image for controlnet_hint

* set default config file as lllyasviel's

* Add support grayscale (hw) numpy array

* Add and update docstrings

* add control_net.mdx

* add control_net.mdx to toctree

* Update copyright year

* Fix to add PIL.Image RGB->BGR conversion
- thanks @Mystfit

* make fix-copies

* add basic fast test for controlnet

* add slow test for controlnet/unet

* Ignore down/up_block len check on ControlNet

* add a copy from test_stable_diffusion.py

* Accept controlnet_hint is None

* merge pipeline_stable_diffusion.py diff

* Update class name to SDControlNetPipeline

* make style

* Baseline fast test almost passed (w long desc)

* still needs investigate.

Following didn't passed descriped in TODO comment:
- test_stable_diffusion_long_prompt
- test_stable_diffusion_no_safety_checker

Following didn't passed same as stable_diffusion_pipeline:
- test_attention_slicing_forward_pass
- test_inference_batch_single_identical
- test_xformers_attention_forwardGenerator_pass
these seems come from calc accuracy.

* Add note comment related vae_scale_factor

* add test_stable_diffusion_controlnet_ddim

* add assertion for vae_scale_factor != 8

* slow test of pipeline almost passed
Failed: test_stable_diffusion_pipeline_with_model_offloading
- ImportError: `enable_model_offload` requires `accelerate v0.17.0` or higher

but currently latest version == 0.16.0

* test_stable_diffusion_long_prompt passed

* test_stable_diffusion_no_safety_checker passed

- due to its model size, move to slow test

* remove PoC test files

* fix num_of_image, prompt length issue add add test

* add support List[PIL.Image] for controlnet_hint

* wip

* all slow test passed

* make style

* update for slow test

* RGB(PIL)->BGR(ctrlnet) conversion

* fixes

* remove manual num_images_per_prompt test

* add document

* add `image` argument docstring

* make style

* Add line to correct conversion

* add controlnet_conditioning_scale (aka control_scales
strength)

* rgb channel ordering by default

* image batching logic

* Add control image descriptions for each checkpoint

* Only save controlnet model in conversion script

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

typo

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update docs/source/en/api/pipelines/stable_diffusion/control_net.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add gerated image example

* a depth mask -> a depth map

* rename control_net.mdx to controlnet.mdx

* fix toc title

* add ControlNet abstruct and link

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

Co-authored-by: dqueue <dbyqin@gmail.com>

* remove controlnet constructor arguments re: @patrickvonplaten

* [integration tests] test canny

* test_canny fixes

* [integration tests] test_depth

* [integration tests] test_hed

* [integration tests] test_mlsd

* add channel order config to controlnet

* [integration tests] test normal

* [integration tests] test_openpose test_scribble

* change height and width to default to conditioning image

* [integration tests] test seg

* style

* test_depth fix

* [integration tests] size fixes

* [integration tests] cpu offloading

* style

* generalize controlnet embedding

* fix conversion script

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Style adapted to the documentation of pix2pix

* merge main by hand

* style

* [docs] controlling generation doc nits

* correct some things

* add: controlnetmodel to autodoc.

* finish docs

* finish

* finish 2

* correct images

* finish controlnet

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* uP

* upload model

* up

* up

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: dqueue <dbyqin@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet