From f8ff1cba7bf66a2b20f4e0e639fe35b40b051cc9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 23 Feb 2024 18:19:24 +0200 Subject: [PATCH 1/5] add is_dora arg --- .../train_dreambooth_lora_sdxl_advanced.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 94a32bcc07f8..1bdcec0e76b6 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -661,6 +661,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--use_dora", + type=bool, + default=False, + help=("Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353"), + ) parser.add_argument( "--cache_latents", action="store_true", @@ -1323,6 +1329,7 @@ def main(args): unet_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, + use_dora=args.use_dora, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) @@ -1334,6 +1341,7 @@ def main(args): text_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, + use_dora=args.use_dora, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) From ea92fc174b525e7900b4d3be03a7ee13fa013cd0 Mon Sep 17 00:00:00 2001 From: Linoy Date: Sat, 24 Feb 2024 11:07:27 +0000 Subject: [PATCH 2/5] style --- .../train_dreambooth_lora_sdxl_advanced.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 1bdcec0e76b6..9d6816b002b8 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -665,7 +665,9 @@ def parse_args(input_args=None): "--use_dora", type=bool, default=False, - help=("Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353"), + help=( + "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353" + ), ) parser.add_argument( "--cache_latents", action="store_true", From afd7f6e817752839212c84b8294e2322a162ec3f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 1 Mar 2024 15:42:24 +0100 Subject: [PATCH 3/5] add dora training feature to sd 1.5 script --- .../train_dreambooth_lora_sd15_advanced.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 778e90e2eb25..e2f1f17a11f4 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -651,6 +651,14 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--use_dora", + type=bool, + default=False, + help=( + "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353" + ), + ) parser.add_argument( "--cache_latents", action="store_true", @@ -1219,6 +1227,7 @@ def main(args): 
 unet_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, + use_dora=args.use_dora, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) @@ -1230,6 +1239,7 @@ def main(args): text_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, + use_dora=args.use_dora, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) From a7aa411204d164c20a3550cbd5392ec943d5b50e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 2 Mar 2024 11:06:29 +0100 Subject: [PATCH 4/5] added notes about DoRA training --- .../advanced_diffusion_training/README.md | 29 +++++++++++++++++-- .../train_dreambooth_lora_sd15_advanced.py | 5 +++-- .../train_dreambooth_lora_sdxl_advanced.py | 5 +++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md index 0a49284543d2..d1c2ff71e639 100644 --- a/examples/advanced_diffusion_training/README.md +++ b/examples/advanced_diffusion_training/README.md @@ -80,8 +80,7 @@ To do so, just specify `--train_text_encoder_ti` while launching training (for r Please keep the following points in mind: * SDXL has two text encoders. So, we fine-tune both using LoRA. -* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry. - +* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. ### 3D icon example @@ -234,6 +233,32 @@ In ComfyUI we will load a LoRA and a textual embedding at the same time. SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). +### DoRA training +The advanced script now supports DoRA training too! +> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353), +**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters. +The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference. + +> [!NOTE] +> 💡DoRA training is still _experimental_ +> and is likely to require different hyperparameter values to perform best compared to a LoRA. +> Specifically, we've noticed 2 differences to take into account in your training: +> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA) +> 2. **DoRA quality is superior to LoRA, especially at lower ranks**: the difference in quality between a DoRA of rank 8 and a LoRA of rank 8 appears to be more significant than when training at ranks of 32 or 64, for example. +> This is also aligned with some of the quantitative analysis shown in the paper. + +**Usage** +1. To use DoRA you need to install `peft` from main: +```bash
+pip install git+https://github.com/huggingface/peft.git
+```
+2. 
Enable DoRA training by adding this flag:
+```bash
+--use_dora
+```
+**Inference**
+The inference is the same as for a regular LoRA 🤗
+ ### Tips and Tricks Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index e2f1f17a11f4..41284fbd9a33 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -654,9 +654,10 @@ def parse_args(input_args=None): parser.add_argument( "--use_dora", - type=bool, + action="store_true", default=False, help=( - "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353" + "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) parser.add_argument( diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 9d6816b002b8..30117dcaeec1 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -664,9 +664,10 @@ def parse_args(input_args=None): parser.add_argument( "--use_dora", - type=bool, + action="store_true", default=False, help=( - "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353" + "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. " + "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" ), ) parser.add_argument( From 98f94c2c2cd4709b1b596953fdf0c6dbe3c68eec Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 6 Mar 2024 12:47:27 +0100 Subject: [PATCH 5/5] dora in canonical script --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 04d8a6442ea3..0d7876db9509 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -647,6 +647,15 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--use_dora", + action="store_true", + default=False, + help=( + "Whether to train a DoRA, as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. 
" + "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`" + ), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -1147,6 +1156,7 @@ def main(args): # now we will add new LoRA weights to the attention layers unet_lora_config = LoraConfig( r=args.rank, + use_dora=args.use_dora, lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], @@ -1158,6 +1168,7 @@ def main(args): if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, + use_dora=args.use_dora, lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],