From 77b9f04a0c0edc20b9cff1c88589332c70a647d3 Mon Sep 17 00:00:00 2001
From: Levi McCallum
Date: Fri, 28 Jul 2023 08:09:56 -0700
Subject: [PATCH 1/2] Add rank argument to train_dreambooth_lora_sdxl.py

---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 443f2d25a788..6e8694b2042e 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -402,6 +402,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -767,7 +773,11 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
 
@@ -777,8 +787,12 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_one, dtype=torch.float32, rank=args.rank
+        )
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_two, dtype=torch.float32, rank=args.rank
+        )
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):

From 8997fc6d025123aec4b7a0d00a45c85b84e8015a Mon Sep 17 00:00:00 2001
From: Levi McCallum
Date: Fri, 28 Jul 2023 11:30:25 -0700
Subject: [PATCH 2/2] Update train_dreambooth_lora_sdxl.py

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 6e8694b2042e..4448d784c7a5 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -774,9 +774,7 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
         module = lora_attn_processor_class(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
         )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
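
For reference, below is a minimal standalone sketch (not part of the patch) of what the new --rank flag controls: the value is forwarded to the LoRA attention processors as the inner dimension of the LoRA update matrices. The import path and the class-selection logic are the ones the training script already uses; the values rank=8, hidden_size=1280, and cross_attention_dim=2048 are illustrative placeholders chosen here, not values taken from this diff.

# Illustrative sketch only; shows how a rank value reaches the LoRA attention
# processor, constructed the same way the training script constructs it.
import torch.nn.functional as F

from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

rank = 8  # e.g. passed on the command line as --rank 8; the script's default is 4

# Same class selection the script performs (PyTorch 2.x picks the SDPA variant).
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

# Placeholder sizes for one attention block; in the script these come from the UNet config.
module = lora_attn_processor_class(hidden_size=1280, cross_attention_dim=2048, rank=rank)

# Each LoRA update matrix carries `rank` as one of its dimensions.
for name, p in module.named_parameters():
    print(name, tuple(p.shape))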