[Examples] Add a training script for SDXL DreamBooth LoRA#4016
Conversation
repo_id = create_repo(
    repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
I think it's okay to create the repos with public visibility here since we're NOT distributing the SDXL weights here. It's just the LoRA weights. So, that should be fine w.r.t privacy and access concerns.
do they contain the training checkpoints? or just for the lora weights?
Checkpoints (which are intermediate LoRA parameters) as well as the final LoRA parameters. Example: https://huggingface.co/diffusers/lora-trained-xl-potato-head/tree/main (private diffusers team members).
The documentation is not available anymore as the PR was closed or merged.
examples/dreambooth/README.md
Outdated
## Stable Diffusion XL

We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://github.com/Stability-AI/generative-models/blob/main/assets/sdxl_report.pdf) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
A separate README makes sense here, I think.
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
patrickvonplaten left a comment
Very nice! For now we don't allow training the text encoder, I assume?
@williamberman could you maybe also take a look here?
Yeah, that's right. The results are already quite nice, especially if someone uses the Refiner. So, let's wait a bit for community feedback; incorporating that support later won't be a big deal.
We don't include licensing in the READMEs, no?
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Training is cool, but do we have a way of applying these LoRAs with SDXL yet?
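For context on what these trained weights are: LoRA stores a low-rank update to frozen attention weights, so "applying" a LoRA means adding `B @ A` (scaled by `alpha / rank`) onto the base weight. A minimal sketch of the mechanism in plain PyTorch — all names and shapes here are illustrative, not the script's actual API:

```python
import torch

d_model, rank, alpha = 64, 4, 4  # illustrative sizes

W = torch.randn(d_model, d_model)      # frozen base weight (e.g. an attention projection)
A = torch.randn(rank, d_model) * 0.01  # trainable LoRA down-projection
B = torch.zeros(d_model, rank)         # trainable LoRA up-projection, zero-initialized

x = torch.randn(2, d_model)
# Forward pass with the low-rank update merged into the base weight.
y = x @ (W + (alpha / rank) * (B @ A)).T

# Because B starts at zero, the adapter is an exact no-op until training updates it.
assert torch.allclose(y, x @ W.T)
```

On the actual question: per the commit list in this PR, `LoraLoaderMixin` is being added to the subclassing mix of the SDXL pipelines together with an inference code sample, which is the intended path for loading these weights at inference time.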
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warn(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
Should we still be including xformers in new training scripts, given that we have native flash attention in PyTorch now?
The community still seems to use xformers quite a bit. So, let's keep its support.
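For reference, the native alternative mentioned above is `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+), which dispatches to a fused flash or memory-efficient kernel when one applies and falls back to the math implementation otherwise. A minimal sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim) — illustrative shapes
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# PyTorch selects a fused kernel (flash / memory-efficient) when available,
# otherwise computes softmax(q @ k.T / sqrt(d)) @ v directly.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

This is why xformers is no longer strictly required on recent PyTorch, though keeping the opt-in flag costs little for users who still rely on it.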
…e#4016)

* add dreambooth lora script for SDXL incorporating latest changes.
* remove use_auth_token=True.
* add: documentation
* remove unneeded cli.
* increase the number of training steps in the readme.
* add LoraLoaderMixin to the subclassing mix.
* add sdxl lora dreambooth test.
* add: inference code sample.
* add: refiner output.
* add LoraLoaderMixin to the mix of classes of StableDiffusionXLImg2ImgPipeline.
* change default resolution of DreamBoothDataset.
* better sdxl report path.
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
What does this PR do?
Adds a training script to support fine-tuning SDXL with DreamBooth and LoRA. Builds on top of #3896; since the merge conflicts there were brutal, this separate PR exists.
Notes
Used the following command to do a test run:
The `dog` dataset was downloaded using the following code:

TODOs