Commit
[Core] add: controlnet support for SDXL (#4038)
* add: controlnet sdxl.
* modifications to controlnet.
* run styling.
* add: __init__.pys
* incorporate #4019 changes.
* run make fix-copies.
* resize the conditioning images.
* remove autocast.
* run styling.
* disable autocast.
* debugging
* device placement.
* back to autocast.
* remove comment.
* save some memory by reusing the vae and unet in the pipeline.
* apply styling.
* Allow low precision sd xl
* finish
* finish
* changes to accommodate the improved VAE.
* modifications to how we handle vae encoding in the training.
* make style
* make existing controlnet fast tests pass.
* change vae checkpoint cli arg.
* fix: vae pretrained paths.
* fix: steps in get_scheduler().
* debugging.
* debugging./
* fix: weight conversion.
* add: docs.
* add: limited tests./
* add: datasets to the requirements.
* update docstrings and incorporate the usage of watermarking.
* incorporate fix from #4083
* fix watermarking dependency handling.
* run make-fix-copies.
* Empty-Commit
* Update requirements_sdxl.txt
* remove vae upcasting part.
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* run make style
* run make fix-copies.
* disable suppot for multicontrolnet.
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* run make fix-copies.
* dtyle/.
* fix-copies.
---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent c6e56e9, commit 3eb498e
Showing 12 changed files with 2,686 additions and 19 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# ControlNet training example for Stable Diffusion XL (SDXL) | ||
|
||
The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). | ||
|
||
## Running locally with PyTorch | ||
|
||
### Installing the dependencies | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
**Important** | ||
|
||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: | ||
|
||
```bash | ||
git clone https://github.com/huggingface/diffusers | ||
cd diffusers | ||
pip install -e . | ||
``` | ||
|
||
Then cd into the `examples/controlnet` folder and run: | ||
```bash | ||
pip install -r requirements_sdxl.txt | ||
``` | ||
|
||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
||
```bash | ||
accelerate config | ||
``` | ||
|
||
Or, for a default accelerate configuration without answering questions about your environment: | ||
|
||
```bash | ||
accelerate config default | ||
``` | ||
|
||
Or, if your environment doesn't support an interactive shell (e.g., a notebook): | ||
|
||
```python | ||
from accelerate.utils import write_basic_config | ||
write_basic_config() | ||
``` | ||
|
||
When running `accelerate config`, setting torch compile mode to True can speed up training considerably. | ||
|
||
## Circle filling dataset | ||
|
||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. | ||
|
||
## Training | ||
|
||
Our training examples use two test conditioning images. They can be downloaded by running | ||
|
||
```sh | ||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png | ||
|
||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png | ||
``` | ||
|
||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub. | ||
|
||
```bash | ||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9" | ||
export OUTPUT_DIR="path to save model" | ||
|
||
accelerate launch train_controlnet_sdxl.py \ | ||
--pretrained_model_name_or_path=$MODEL_DIR \ | ||
--output_dir=$OUTPUT_DIR \ | ||
--dataset_name=fusing/fill50k \ | ||
--mixed_precision="fp16" \ | ||
--resolution=1024 \ | ||
--learning_rate=1e-5 \ | ||
--max_train_steps=15000 \ | ||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ | ||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ | ||
--validation_steps=100 \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--report_to="wandb" \ | ||
--seed=42 \ | ||
--push_to_hub | ||
``` | ||
|
||
To better track our training experiments, we're using the following flags in the command above: | ||
|
||
* `report_to="wandb"` ensures the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. | ||
* `validation_image`, `validation_prompt`, and `validation_steps` let the script run a few validation inference passes during training, so we can qualitatively check whether training is progressing as expected. | ||
|
||
Our experiments were conducted on a single 40GB A100 GPU. | ||
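To get a feel for what the flags in the command above add up to, here is a small back-of-the-envelope sketch. The flag values are copied from the command. The process count (one, since a single GPU was used), the dataset size of 50,000 pairs, and the reading of `max_train_steps` as optimizer updates are assumptions that this README does not state.

```python
# Values copied from the `accelerate launch` command above.
train_batch_size = 1
gradient_accumulation_steps = 4
max_train_steps = 15_000
validation_steps = 100

# Assumptions (not stated in the README): a single process, because the
# experiments ran on one GPU, and 50,000 training pairs in fill50k.
num_processes = 1
dataset_size = 50_000

# Number of samples that contribute to each optimizer update.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes

# Assumption: max_train_steps counts optimizer updates, not micro-batches.
samples_seen = effective_batch_size * max_train_steps
approx_epochs = samples_seen / dataset_size

# One validation pass every `validation_steps` updates.
validation_runs = max_train_steps // validation_steps

print(f"effective batch size: {effective_batch_size}")
print(f"samples seen:         {samples_seen}")
print(f"approximate epochs:   {approx_epochs:.2f}")
print(f"validation runs:      {validation_runs}")
```

If memory allows a larger `train_batch_size`, lowering `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged under these assumptions.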
|
||
### Inference | ||
|
||
Once training is done, we can perform inference like so: | ||
|
||
```python | ||
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | ||
from diffusers.utils import load_image | ||
import torch | ||
|
||
base_model_path = "stabilityai/stable-diffusion-xl-base-0.9" | ||
controlnet_path = "path to controlnet" | ||
|
||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) | ||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained( | ||
base_model_path, controlnet=controlnet, torch_dtype=torch.float16 | ||
) | ||
|
||
# speed up diffusion process with faster scheduler and memory optimization | ||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | ||
# remove following line if xformers is not installed or when using Torch 2.0. | ||
pipe.enable_xformers_memory_efficient_attention() | ||
# memory optimization. | ||
pipe.enable_model_cpu_offload() | ||
|
||
control_image = load_image("./conditioning_image_1.png") | ||
prompt = "pale golden rod circle with old lace background" | ||
|
||
# generate image | ||
generator = torch.manual_seed(0) | ||
image = pipe( | ||
prompt, num_inference_steps=20, generator=generator, image=control_image | ||
).images[0] | ||
image.save("./output.png") | ||
``` | ||
|
||
## Notes | ||
|
||
### Specifying a better VAE | ||
|
||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). |
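As a rough illustration of where this argument goes, the sketch below assembles a launch command with and without the VAE override. `build_launch_command` is a hypothetical helper written for this example, not part of the script. The flags and the VAE repository id are the ones mentioned in this README, and the output directory is a placeholder.

```python
import shlex


def build_launch_command(model_dir, output_dir, vae_path=None):
    """Hypothetical helper: build the argument list for train_controlnet_sdxl.py."""
    cmd = [
        "accelerate", "launch", "train_controlnet_sdxl.py",
        f"--pretrained_model_name_or_path={model_dir}",
        f"--output_dir={output_dir}",
        "--dataset_name=fusing/fill50k",
        "--mixed_precision=fp16",
        "--resolution=1024",
    ]
    # Only pass the override when a replacement VAE is requested.
    if vae_path is not None:
        cmd.append(f"--pretrained_vae_model_name_or_path={vae_path}")
    return cmd


cmd = build_launch_command(
    "stabilityai/stable-diffusion-xl-base-0.9",
    "path/to/output",  # placeholder output directory
    vae_path="madebyollin/sdxl-vae-fp16-fix",
)

# Print a shell-ready command line; pass `cmd` to subprocess.run to execute it.
print(shlex.join(cmd))
```

Keeping the override optional means the same helper also reproduces the default command when no replacement VAE is wanted.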
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
accelerate>=0.16.0 | ||
torchvision | ||
transformers>=4.25.1 | ||
ftfy | ||
tensorboard | ||
Jinja2 | ||
invisible-watermark>=0.2.0 | ||
datasets | ||
wandb |
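A quick way to see whether an environment already satisfies the requirements above is to compare them with the installed package metadata. The sketch below is only an illustration: the helper names are made up for this example, the version comparison is deliberately simplistic (numeric components only), and it uses nothing beyond the standard library.

```python
import re
from importlib import metadata

# Contents of the requirements file shown above.
REQUIREMENTS = """\
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb
"""


def parse_requirement(line):
    """Split 'name>=version' into (name, minimum); minimum is None if unpinned."""
    name, sep, minimum = line.strip().partition(">=")
    return name, (minimum if sep else None)


def version_tuple(version):
    """Turn '0.16.0' into (0, 16, 0), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check(requirements):
    """Return {package: status} for every non-empty requirement line."""
    report = {}
    for line in requirements.splitlines():
        if not line.strip():
            continue
        name, minimum = parse_requirement(line)
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if minimum and version_tuple(installed) < version_tuple(minimum):
            report[name] = f"too old ({installed} < {minimum})"
        else:
            report[name] = f"ok ({installed})"
    return report


for name, status in check(REQUIREMENTS).items():
    print(f"{name}: {status}")
```

For anything beyond a quick sanity check, `pip install -r requirements_sdxl.txt` remains the authoritative way to resolve these dependencies.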