diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 443f2d25a788..4448d784c7a5 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -402,6 +402,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -767,7 +773,9 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
 
@@ -777,8 +785,12 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_one, dtype=torch.float32, rank=args.rank
+        )
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_two, dtype=torch.float32, rank=args.rank
+        )
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
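
The new `--rank` flag defaults to 4, which matches the rank the LoRA layers implicitly used before, and is now threaded through to both the UNet attention processors and the monkey-patched text encoders. Below is a minimal sketch, not part of the diff, of what the value controls, assuming the `LoRAAttnProcessor` constructor from the diffusers release this script targets; the `hidden_size`/`cross_attention_dim` values are hypothetical SDXL-like sizes.

# Minimal sketch (assumption: LoRAAttnProcessor is still importable with this signature).
from diffusers.models.attention_processor import LoRAAttnProcessor

# Hypothetical sizes for illustration only.
proc = LoRAAttnProcessor(hidden_size=1280, cross_attention_dim=2048, rank=8)

# Each LoRA layer factorizes the weight update as up @ down; ``rank`` is the
# inner dimension of that factorization, so it appears in the weight shapes.
print(proc.to_q_lora.down.weight.shape)  # expected: torch.Size([8, 1280])
print(proc.to_q_lora.up.weight.shape)    # expected: torch.Size([1280, 8])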