Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] add: controlnet support for SDXL #4038

Merged
merged 54 commits into from
Jul 18, 2023
Merged

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Jul 11, 2023

This PR adds support for ControlNets with SDXL. The two primary components being added to this PR:

  • Training script train_controlnet_sdxl.py.
  • Pipeline StableDiffusionXLControlNetPipeline (with changes to ControlNetModel to accommodate the pipeline-level changes).

However, these seems to be something weird going on here.

I first started training on a small subset of dataset (the circles dataset) with the following command:

export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="controlnet-sdxl-circles"

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png

accelerate launch train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --mixed_precision="fp16" \
 --resolution=1024 \
 --learning_rate=5e-5 \
 --max_train_samples=500 \
 --max_train_steps=1000 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --validation_steps=25 \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --report_to="wandb" \
 --seed=42 \
 --push_to_hub

The trained checkpoints seem to only generate black images: https://huggingface.co/fusing/controlnet-sdxl-circles-fixed (only visible to the diffusers team members).

To further debug this, I tried:

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
import torch 

base_ckpt_id = "stabilityai/stable-diffusion-xl-base-0.9"
controlnet_ckpt_id = "controlnet-sdxl-circles-fixed"

controlnet = ControlNetModel.from_pretrained(
	controlnet_ckpt_id, subfolder="checkpoint-500/controlnet", torch_dtype=torch.float16
).to("cuda")

pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
	base_ckpt_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")


cond_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png"
)
prompt = "red circle with blue background"

image = pipeline(prompt, image=cond_image).images[0]
image.save("controlnet@ckpt-500.png")

This doesn't generate the expected results (which is expected since the number of training steps is quite low) but doesn't generate all black images either.

@patrickvonplaten @williamberman could you take a deeper look here?

TODOs

  • tests
  • docs
  • misc changes

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 11, 2023

The documentation is not available anymore as the PR was closed or merged.

@gkorepanov
Copy link

The trained checkpoints seem to only generate black images:

Hi! Do you mean that in validation during training all images are black, but if you manually load trained checkpoint using external script, the images are fine?

@sayakpaul
Copy link
Member Author

Hi! Do you mean that in validation during training all images are black, but if you manually load trained checkpoint using external script, the images are fine?

Exactly.

@gkorepanov
Copy link

Exactly.

I think that might relate to SDXL VAE producing NANs in some cases with fp16 mode.

From https://github.com/kohya-ss/sd-scripts/tree/sdxl:

The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using fp16. The images will be black. Currently, the NaNs cannot be avoided even with --no_half_vae option. It works with bf16 or without mixed precision.

Also:
https://huggingface.co/stabilityai/sdxl-vae/discussions/6
https://huggingface.co/madebyollin/sdxl-vae-fp16-fix

@sayakpaul
Copy link
Member Author

Thanks for being willing to help.

I think the issue with VAE is handled. See: https://github.com/huggingface/diffusers/blob/db78a4cb4e3f105cbc7534890f606e25e906e23a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1118C1-L1133C38.

Also, when I run the manual validation, it's in FP16 only.

@gkorepanov
Copy link

I think the issue with VAE is handled

Ah, really, seems so, thanks. BTW, in the code you mentioned there might be a small bug with unnecessary not which was recently fixed in the main branch: #4019

Also, to run your code, I had to put extra StableDiffusionXLControlNetPipeline imports in few places, I think you have forgotten to include few __init__.py into PR

@sayakpaul
Copy link
Member Author

@gkorepanov thanks so much for your catches. I incorporated the fixes. Let me run the dummy experiment one more time to check quickly.

@laksjdjf
Copy link
Contributor

laksjdjf commented Jul 11, 2023

Is it because autocast is used to generate the validation image?

with torch.autocast("cuda"):
image = pipeline(
validation_prompt, validation_image, num_inference_steps=20, generator=generator
).images[0]

kohya-ss's problems also seem to have been caused by autocast.
kohya-ss/sd-scripts@814996b#diff-5f48c8e976d43e587007dc13a34100a96621cdc5fbe083ee772e920855648722R3877

@patrickvonplaten
Copy link
Contributor

Cool! Let's make sure we have a working controlnet training run before merging this though :-)

@sayakpaul
Copy link
Member Author

Cool! Let's make sure we have a working controlnet training run before merging this though :-)

There's a working script in the PR. The issues described in the original post is why I am seeking reviews for.

@gkorepanov
Copy link

gkorepanov commented Jul 11, 2023

There's a working script in the PR. The issues described in the original post is why I am seeking reviews for.

After disabling autocast in validation and using torch.float32 when loading the pipeline the validation looks better (at least images are not black anymore):
image

But images seem to be awkward.

@gkorepanov
Copy link

gkorepanov commented Jul 11, 2023

But images seem to be awkward.

The difference was caused by different resolution in inference. By default, controlnet pipeline takes height/width from control image

@sayakpaul
Copy link
Member Author

@gkorepanov thanks again for your inputs! Very much appreciated.

Let me run a couple of experiments now.

@sayakpaul
Copy link
Member Author

@gkorepanov may I know which GPU model did you use for your tests? I am currently using a 40GB A100 and when I try log_validation() in FP32, it OOMs.

@sayakpaul
Copy link
Member Author

@patrickvonplaten thanks for all the reviews. A final review and I think we're good to go. Let me know.

@adhikjoshi
Copy link

Will existing controlnet 1.1 checkpoints work here?

@sayakpaul
Copy link
Member Author

Will existing controlnet 1.1 checkpoints work here?

No.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

@sayakpaul sayakpaul merged commit 3eb498e into main Jul 18, 2023
10 checks passed
@sayakpaul sayakpaul deleted the feat/sd-xl-controlnet-2 branch July 18, 2023 12:55
@sayakpaul
Copy link
Member Author

@gkorepanov we start a PR for adding switching support and MultiControlNet too since switching likely impacts that more.

Let me know :)

@sayakpaul
Copy link
Member Author

@williamberman #4188.

orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate huggingface#4019 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from huggingface#4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate huggingface#4019 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from huggingface#4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate huggingface#4019 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from huggingface#4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@zdxpan
Copy link

zdxpan commented Aug 29, 2023

tranning met loss is nan, and the pred_noise cntain nan, which will case the tranning fail, the traing target is predict noise with given noise, which mse alwas nearby 1 (~= 1),
possible reason
1、 lr too large
2、 data had some nan value
3、is there anty other reason?

and met the log_validate validate image alwas balck

@sayakpaul
Copy link
Member Author

sayakpaul commented Aug 29, 2023

Try passing the following as your VAE: madebyollin/sdxl-vae-fp16-fix.

Additionally, you can ask questions on the repositories like the following, which leveraged our training scripts to obtain nice results: https://huggingface.co/thibaud/controlnet-openpose-sdxl-1.0/discussions.

@patrickvonplaten
Copy link
Contributor

tranning met loss is nan, and the pred_noise cntain nan, which will case the tranning fail, the traing target is predict noise with given noise, which mse alwas nearby 1 (~= 1), possible reason 1、 lr too large 2、 data had some nan value 3、is there anty other reason?

and met the log_validate validate image alwas balck

@zdxpan please make sure to open a new issue instead of commenting on the PR here

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate huggingface#4019 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from huggingface#4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate huggingface#4019 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from huggingface#4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet