diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 04d8a6442ea3..0d7876db9509 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -647,6 +647,15 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--use_dora",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to train a DoRA as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+            "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+        ),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1147,6 +1156,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     unet_lora_config = LoraConfig(
         r=args.rank,
+        use_dora=args.use_dora,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1158,6 +1168,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
+            use_dora=args.use_dora,
             lora_alpha=args.rank,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
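
For context, a minimal sketch (not part of the patch) of the `LoraConfig` the patched script builds for the UNet attention layers when `--use_dora` is passed. It assumes a `peft` version recent enough that `LoraConfig` accepts `use_dora` (installed from main at the time of this change); the literal `rank` and `use_dora` values below stand in for the parsed `args.rank` and `args.use_dora`.

```python
# Minimal sketch: DoRA-enabled LoRA config as constructed by the patched script.
# Requires peft with `use_dora` support, e.g.:
#   pip install git+https://github.com/huggingface/peft.git
from peft import LoraConfig

rank = 4          # stand-in for args.rank
use_dora = True   # stand-in for args.use_dora

unet_lora_config = LoraConfig(
    r=rank,
    use_dora=use_dora,  # DoRA: decompose updates into magnitude and direction
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# With --train_text_encoder, the script builds an analogous config that
# targets ["q_proj", "k_proj", "v_proj", "out_proj"] in the text encoders.
```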