Support EDM-style training in DreamBooth LoRA SDXL script #7126

Merged
merged 29 commits into from
Mar 3, 2024
f3013e2
add: dreambooth lora script for Playground v2.5
sayakpaul Feb 28, 2024
cb5f0d7
fix: kwarg
sayakpaul Feb 28, 2024
84047c2
address suraj's comments.
sayakpaul Feb 28, 2024
746b625
Apply suggestions from code review
sayakpaul Feb 28, 2024
dcf36b4
Merge branch 'main' into playground-dreambooth-lora
sayakpaul Feb 28, 2024
0d1547f
apply suraj's suggestion
sayakpaul Feb 28, 2024
c9dd1d0
incorporate changes in the canonical script./
sayakpaul Feb 28, 2024
e2b7144
tracker naming
sayakpaul Feb 28, 2024
db025fc
fix: schedule determination
sayakpaul Feb 28, 2024
80ef425
add: two simple tests
sayakpaul Feb 28, 2024
3ae0e28
remove playground script
sayakpaul Feb 28, 2024
2f49630
note about edm-style training
sayakpaul Feb 28, 2024
c8ed8af
address pedro's comments.
sayakpaul Feb 28, 2024
e02e6f6
address part of Suraj's comments.
sayakpaul Feb 28, 2024
a8bea31
Apply suggestions from code review
sayakpaul Feb 28, 2024
60bf1e6
remove guidance_scale.
sayakpaul Feb 28, 2024
00156a3
use mse_loss.
sayakpaul Feb 28, 2024
3621e18
add comments for preconditioning.
sayakpaul Feb 28, 2024
206f2c7
quality
sayakpaul Feb 28, 2024
f7fc1f6
Update examples/dreambooth/train_dreambooth_lora_sdxl.py
sayakpaul Feb 29, 2024
d96d8ea
Merge branch 'main' into playground-dreambooth-lora
sayakpaul Feb 29, 2024
dde7595
tackle v-pred.
sayakpaul Feb 29, 2024
128b877
Empty-Commit
sayakpaul Feb 29, 2024
e052fe1
Merge branch 'main' into playground-dreambooth-lora
sayakpaul Feb 29, 2024
65c382c
support edm for sdxl too.
sayakpaul Feb 29, 2024
0f046d8
Merge branch 'main' into playground-dreambooth-lora
sayakpaul Mar 1, 2024
10c55bb
address suraj's comments.
sayakpaul Mar 1, 2024
e3210fc
Merge branch 'main' into playground-dreambooth-lora
sayakpaul Mar 2, 2024
fe75b46
Empty-Commit
sayakpaul Mar 2, 2024
37 changes: 37 additions & 0 deletions examples/dreambooth/README_sdxl.md
@@ -206,3 +206,40 @@ You can explore the results from a couple of our internal experiments by checkin
## Running on a free-tier Colab Notebook

Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).

## Conducting EDM-style training

It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
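
As a rough illustration of what "EDM-style" refers to, the sketch below computes the preconditioning coefficients defined in the paper for a given noise level `sigma`. It follows the paper's formulation and is not taken from the training script; the helper name `edm_preconditioning` is hypothetical, and `sigma_data=0.5` is the paper's default, assumed here.

```python
import math


def edm_preconditioning(sigma: float, sigma_data: float = 0.5) -> dict:
    """Return the EDM preconditioning coefficients from Karras et al. (2022).

    Hypothetical helper for illustration only; not part of the training script.
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    total_var = sigma**2 + sigma_data**2
    return {
        # Weight of the noisy input in the skip connection.
        "c_skip": sigma_data**2 / total_var,
        # Scale applied to the raw network output.
        "c_out": sigma * sigma_data / math.sqrt(total_var),
        # Scale applied to the noisy input before it enters the network.
        "c_in": 1.0 / math.sqrt(total_var),
        # Noise-level conditioning value passed to the network.
        "c_noise": math.log(sigma) / 4.0,
    }


if __name__ == "__main__":
    for sigma in (0.02, 0.5, 80.0):
        print(sigma, edm_preconditioning(sigma))
```

In the paper, the denoised estimate is `c_skip * x + c_out * F(c_in * x, c_noise)`, where `F` is the raw network. At small `sigma` the skip term dominates, and at large `sigma` the network output dominates.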

For the SDXL model, simple set:
Member

Suggested change
- For the SDXL model, simple set:
+ For the standard SDXL model, simply set:

Does it work with SDXL out of the box? 🤯

Member Author

There's a test that you can check but I haven't done a full-blown training run.

Contributor

For LoRA it might not work, but can def be fine-tuned with EDM.

Member Author

@patil-suraj elaborate?


```diff
+ --do_edm_style_training \
```

Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:

```bash
accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
--instance_data_dir="dog" \
--output_dir="dog-playground-lora" \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--use_8bit_adam \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
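
To make the batch-related flags in the command above concrete, here is a small sketch of the arithmetic. It assumes a single process and that `--max_train_steps` counts optimizer updates (one update per `gradient_accumulation_steps` forward passes); the helper `samples_seen` is hypothetical.

```python
def samples_seen(
    train_batch_size: int,
    gradient_accumulation_steps: int,
    max_train_steps: int,
    num_processes: int = 1,  # assumption: single-process run
) -> dict:
    """Hypothetical helper: effective batch size and total samples processed."""
    effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
    return {
        "effective_batch_size": effective_batch_size,
        "total_samples": effective_batch_size * max_train_steps,
    }


if __name__ == "__main__":
    # Values taken from the example command.
    print(samples_seen(train_batch_size=1, gradient_accumulation_steps=4, max_train_steps=500))
```

Under those assumptions the example command uses an effective batch size of 4 and processes 2000 samples in total, so a small instance set is revisited many times.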

> [!CAUTION]
> Min-SNR gamma is not yet supported with EDM-style training. When training with the PlaygroundAI model, it's recommended not to pass any "variant".
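
The Min-SNR restriction can be thought of as an argument-validation guard. The sketch below is a reduced, hypothetical parser and not the script's actual one. It assumes that Min-SNR weighting is enabled through an `--snr_gamma` option and that the incompatible combination is rejected up front; the function name and the error message are illustrative.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical, reduced argument parser showing the incompatibility check."""
    parser = argparse.ArgumentParser(description="Illustrative subset of training flags.")
    parser.add_argument(
        "--do_edm_style_training",
        action="store_true",
        help="Enable EDM-style training.",
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="Assumed option name for Min-SNR weighting.",
    )
    args = parser.parse_args(argv)

    # Reject the combination that the caution note describes as unsupported.
    if args.do_edm_style_training and args.snr_gamma is not None:
        raise ValueError("Min-SNR weighting is not supported with EDM-style training.")
    return args


if __name__ == "__main__":
    print(parse_args(["--do_edm_style_training"]))
```

Failing early like this avoids starting a run whose loss weighting would be silently ignored or incorrect.
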
99 changes: 99 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_edm.py
@@ -0,0 +1,99 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors.torch  # load_file lives in the torch submodule, so import it explicitly


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):
    def test_dreambooth_lora_sdxl_with_edm(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --do_edm_style_training
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)

    def test_dreambooth_lora_playground(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"unet"` in their names.
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)
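
Both tests end with the same two checks on the saved state dict. As a minimal sketch, those checks can be isolated into a helper that works on key names alone, which makes their behavior easy to see without running training. The helper name and the sample keys below are hypothetical and only meant to resemble LoRA parameter names.

```python
def check_lora_keys(keys):
    """Hypothetical helper mirroring the two assertions in the tests.

    Returns (is_lora, starts_with_unet) for an iterable of state-dict keys.
    """
    keys = list(keys)
    # Every parameter name should mention "lora".
    is_lora = all("lora" in k for k in keys)
    # Without text-encoder training, every name should carry the "unet" prefix.
    starts_with_unet = all(k.startswith("unet") for k in keys)
    return is_lora, starts_with_unet


if __name__ == "__main__":
    # Made-up key names for illustration.
    sample = [
        "unet.block.attn.to_q.lora_A.weight",
        "unet.block.attn.to_q.lora_B.weight",
    ]
    print(check_lora_keys(sample))
    print(check_lora_keys(sample + ["text_encoder.layer.q_proj.lora_A.weight"]))
```

Note that `all(...)` over an empty collection is `True`, so checks written this way pass vacuously on an empty state dict; an additional non-emptiness assertion would close that gap.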