
[Examples] Add a training script for SDXL DreamBooth LoRA#4016

Merged
sayakpaul merged 13 commits into main from dreambooth/sd-xl-3 on Jul 11, 2023

Conversation

@sayakpaul (Member) commented Jul 10, 2023

What does this PR do?

Adds a training script to support fine-tuning SDXL with DreamBooth and LoRA. Builds on top of #3896; this separate PR exists because the merge conflicts there were brutal.

Notes

Used the following command to do a test run:

export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9"
export INSTANCE_DIR="dog"
export CLASS_DIR="dog-class"
export OUTPUT_DIR="lora-trained-xl"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=100 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --push_to_hub

The dog dataset was downloaded using the following code:

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)

TODOs

  • Docs
  • Tests

Comment on lines +717 to +719
repo_id = create_repo(
    repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
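The `repo_id` fallback in the snippet above resolves to the basename of the output directory whenever `--hub_model_id` isn't passed. A minimal sketch of that resolution logic (the `resolve_repo_id` helper and the example ids are illustrative, not part of the script):

```python
from pathlib import Path


def resolve_repo_id(hub_model_id, output_dir):
    # Prefer an explicit --hub_model_id; otherwise fall back to the
    # basename of --output_dir, mirroring the create_repo call above.
    return hub_model_id or Path(output_dir).name


print(resolve_repo_id(None, "lora-trained-xl"))          # lora-trained-xl
print(resolve_repo_id("me/my-lora", "lora-trained-xl"))  # me/my-lora
```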
Member Author

I think it's okay to create the repos with public visibility here since we're NOT distributing the SDXL weights here. It's just the LoRA weights. So, that should be fine w.r.t privacy and access concerns.

Member

do they contain the training checkpoints? or just for the lora weights?

Member Author

Checkpoints (which are intermediate LoRA parameters) as well as the final LoRA parameters. Example: https://huggingface.co/diffusers/lora-trained-xl-potato-head/tree/main (private diffusers team members).

HuggingFaceDocBuilderDev commented Jul 10, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul sayakpaul marked this pull request as ready for review July 10, 2023 07:05
Comment on lines +741 to +743
## Stable Diffusion XL

We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://github.com/Stability-AI/generative-models/blob/main/assets/sdxl_report.pdf) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
Member Author

A separate README makes sense here, I think.

@patrickvonplaten (Contributor) left a comment:

Very nice! For now we don't allow training the text encoder I assume no?

@patrickvonplaten (Contributor)

@williamberman could you maybe also take a look here?

@sayakpaul (Member Author)

> Very nice! For now we don't allow training the text encoder I assume no?

Yeah, that's right, since the results are already quite nice, especially if someone uses the Refiner. So let's wait a bit for community feedback; incorporating that support later won't be a big deal.

@pcuenca (Member) left a comment:

Awesome!

Member

nit: license header :)

Member Author

We don't include licensing in the READMEs, no?


sayakpaul and others added 2 commits July 10, 2023 16:14
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@n00mkrad

Training is cool, but do we have a way of applying these LoRAs with SDXL yet?
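Per the commit list below, this PR also added `LoraLoaderMixin` to the SDXL pipelines, so the trained weights can be loaded back for inference. A sketch of what that looks like, assuming the LoRA was pushed to a hypothetical hub repo (requires a GPU and a model download; not verified here):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Load the SDXL base pipeline (fp16 to fit on consumer GPUs).
pipe = StableDiffusionXLPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-base-0.9",
    torch_dtype=torch.float16,
).to("cuda")

# Attach the DreamBooth LoRA produced by train_dreambooth_lora_sdxl.py.
# "your-username/lora-trained-xl" is a hypothetical repo id; a local
# output directory containing the saved LoRA weights should work too.
pipe.load_lora_weights("your-username/lora-trained-xl")

image = pipe(prompt="A photo of sks dog in a bucket").images[0]
image.save("sks_dog.png")
```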

Comment on lines +771 to +782
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warn(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
Contributor

Should we still be including xformers in new training scripts, given that we have native flash attention in PyTorch now?

Member Author

The community still seems to use xformers quite a bit. So, let's keep its support.
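The trade-off discussed above (explicit xformers opt-in vs. PyTorch's native scaled-dot-product attention) can be sketched as a small backend-selection helper. This is illustrative decision logic only, not diffusers' actual implementation; the function and backend names are made up:

```python
def pick_attention_backend(want_xformers, xformers_available, torch_sdpa_available):
    """Choose an attention implementation, mirroring the gate above.

    Priority: honor an explicit xformers request (erroring if it is not
    installed, like the training script does), otherwise prefer PyTorch's
    native scaled_dot_product_attention (PyTorch >= 2.0), and finally
    fall back to vanilla attention.
    """
    if want_xformers:
        if not xformers_available:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        return "xformers"
    if torch_sdpa_available:
        return "torch-sdpa"
    return "vanilla"


print(pick_attention_backend(False, False, True))  # torch-sdpa
print(pick_attention_backend(True, True, True))    # xformers
```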

@williamberman (Contributor) left a comment:

c'est magnifique! ("it's magnificent!")

@sayakpaul sayakpaul merged commit 3d74dc2 into main Jul 11, 2023
@sayakpaul sayakpaul deleted the dreambooth/sd-xl-3 branch July 11, 2023 02:08
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…e#4016)

* add dreambooth lora script for SDXL incorporating latest changes.

* remove use_auth_token=True.

* add: documentation

* remove unneeded cli.

* increase the number of training steps in the readme.

* add LoraLoaderMixin to the subclassing mix.

* add sdxl lora dreambooth test.

* add: inference code sample.

* add: refiner output.

* add LoraLoaderMixin to the mix of classes of StableDiffusionXLImg2ImgPipeline.

* change default resolution of DreamBoothDataset.

* better sdxl report path.

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024