-
Notifications
You must be signed in to change notification settings - Fork 4.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support EDM-style training in DreamBooth LoRA SDXL script #7126
Merged
+307
−29
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
f3013e2
add: dreambooth lora script for Playground v2.5
sayakpaul cb5f0d7
fix: kwarg
sayakpaul 84047c2
address suraj's comments.
sayakpaul 746b625
Apply suggestions from code review
sayakpaul dcf36b4
Merge branch 'main' into playground-dreambooth-lora
sayakpaul 0d1547f
apply suraj's suggestion
sayakpaul c9dd1d0
incorporate changes in the canonical script./
sayakpaul e2b7144
tracker naming
sayakpaul db025fc
fix: schedule determination
sayakpaul 80ef425
add: two simple tests
sayakpaul 3ae0e28
remove playground script
sayakpaul 2f49630
note about edm-style training
sayakpaul c8ed8af
address pedro's comments.
sayakpaul e02e6f6
address part of Suraj's comments.
sayakpaul a8bea31
Apply suggestions from code review
sayakpaul 60bf1e6
remove guidance_scale.
sayakpaul 00156a3
use mse_loss.
sayakpaul 3621e18
add comments for preconditioning.
sayakpaul 206f2c7
quality
sayakpaul f7fc1f6
Update examples/dreambooth/train_dreambooth_lora_sdxl.py
sayakpaul d96d8ea
Merge branch 'main' into playground-dreambooth-lora
sayakpaul dde7595
tackle v-pred.
sayakpaul 128b877
Empty-Commit
sayakpaul e052fe1
Merge branch 'main' into playground-dreambooth-lora
sayakpaul 65c382c
support edm for sdxl too.
sayakpaul 0f046d8
Merge branch 'main' into playground-dreambooth-lora
sayakpaul 10c55bb
address suraj's comments.
sayakpaul e3210fc
Merge branch 'main' into playground-dreambooth-lora
sayakpaul fe75b46
Empty-Commit
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 HuggingFace Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
import os | ||
import sys | ||
import tempfile | ||
|
||
import safetensors | ||
|
||
|
||
sys.path.append("..") | ||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 | ||
|
||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
logger = logging.getLogger() | ||
stream_handler = logging.StreamHandler(sys.stdout) | ||
logger.addHandler(stream_handler) | ||
|
||
|
||
class DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate): | ||
def test_dreambooth_lora_sdxl_with_edm(self): | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
test_args = f""" | ||
examples/dreambooth/train_dreambooth_lora_sdxl.py | ||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe | ||
--do_edm_style_training | ||
--instance_data_dir docs/source/en/imgs | ||
--instance_prompt photo | ||
--resolution 64 | ||
--train_batch_size 1 | ||
--gradient_accumulation_steps 1 | ||
--max_train_steps 2 | ||
--learning_rate 5.0e-04 | ||
--scale_lr | ||
--lr_scheduler constant | ||
--lr_warmup_steps 0 | ||
--output_dir {tmpdir} | ||
""".split() | ||
|
||
run_command(self._launch_args + test_args) | ||
# save_pretrained smoke test | ||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) | ||
|
||
# make sure the state_dict has the correct naming in the parameters. | ||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) | ||
is_lora = all("lora" in k for k in lora_state_dict.keys()) | ||
self.assertTrue(is_lora) | ||
|
||
# when not training the text encoder, all the parameters in the state dict should start | ||
# with `"unet"` in their names. | ||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) | ||
self.assertTrue(starts_with_unet) | ||
|
||
def test_dreambooth_lora_playground(self): | ||
with tempfile.TemporaryDirectory() as tmpdir: | ||
test_args = f""" | ||
examples/dreambooth/train_dreambooth_lora_sdxl.py | ||
--pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe | ||
--instance_data_dir docs/source/en/imgs | ||
--instance_prompt photo | ||
--resolution 64 | ||
--train_batch_size 1 | ||
--gradient_accumulation_steps 1 | ||
--max_train_steps 2 | ||
--learning_rate 5.0e-04 | ||
--scale_lr | ||
--lr_scheduler constant | ||
--lr_warmup_steps 0 | ||
--output_dir {tmpdir} | ||
""".split() | ||
|
||
run_command(self._launch_args + test_args) | ||
# save_pretrained smoke test | ||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) | ||
|
||
# make sure the state_dict has the correct naming in the parameters. | ||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) | ||
is_lora = all("lora" in k for k in lora_state_dict.keys()) | ||
self.assertTrue(is_lora) | ||
|
||
# when not training the text encoder, all the parameters in the state dict should start | ||
# with `"unet"` in their names. | ||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) | ||
self.assertTrue(starts_with_unet) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it work with SDXL out of the box? 🤯
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a test that you can check but I haven't done a full-blown training run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For LoRA it might not work, but can def be fine-tuned with EDM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patil-suraj elaborate?