diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 72fcfa648b48..e3222357ae9b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -436,6 +436,12 @@ def parse_args(input_args=None): default=None, help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -845,7 +851,9 @@ def main(args): LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) unet_lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=args.rank, ) unet.set_attn_processor(unet_lora_attn_procs) @@ -860,7 +868,9 @@ def main(args): for name, module in text_encoder.named_modules(): if name.endswith(TEXT_ENCODER_ATTN_MODULE): text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_proj.out_features, cross_attention_dim=None + hidden_size=module.out_proj.out_features, + cross_attention_dim=None, + rank=args.rank, ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) temp_pipeline = DiffusionPipeline.from_pretrained(