diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md
index a340be350db8..89d87d65dd44 100644
--- a/examples/dreambooth/README_sd3.md
+++ b/examples/dreambooth/README_sd3.md
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
   --push_to_hub
 ```
 
+### Targeting Specific Blocks & Layers
+As image generation models get bigger & more powerful, more fine-tuners are finding that training only a subset of the
+transformer blocks (sometimes as few as two) can be enough to get great results.
+In some cases, it can be even better to keep some of the blocks/layers frozen.
+
+For **SD3.5-Large** specifically, you may find this information useful (taken from the [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93)):
+> [!NOTE]
+> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24`)* influence the overall composition/primary form more.
+> So, freezing other layers/targeting specific layers is a viable approach.
+> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
+> **Photorealism**
+> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening.
+> **Anatomy preservation**
+> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.
+
+
+We've added `--lora_layers` and `--lora_blocks` to make the LoRA-trained modules configurable.
+- with `--lora_blocks` you can specify the indices of the blocks you wish to train. E.g. passing -
+```diff
+--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
+```
+will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
+- with `--lora_layers` you can specify the types of layers you wish to train.
+By default, the trained layers are -
+`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
+If you wish to have a leaner LoRA, or to prioritize training more blocks over more layer types, you could pass -
+```diff
++ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
+```
+This will reduce the LoRA size by roughly 50% for the same rank compared to the default.
+However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and
+freezing some of the early blocks is usually better.
+
+
 ### Text Encoder Training
 
 Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
 To do so, just specify `--train_text_encoder` while launching training.
 Please keep the following points in mind:
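For readers who want to script the flag values described in the README hunk above, here is a minimal Python sketch. The block ranges (12-24, 30-37) and the leaner layer list are just the README's own SD3.5-Large examples, not required values, and the "roughly 50%" figure is reproduced here simply as a 4-of-8 layer-type count under the assumption of similarly sized projections.

```python
# Compose the comma-separated values for --lora_blocks / --lora_layers,
# using the block ranges from the README example (12-24 and 30-37).
blocks = list(range(12, 25)) + list(range(30, 38))
lora_blocks = ",".join(str(b) for b in blocks)
print(lora_blocks)
# -> "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"

# Default layer types trained by the script vs. the leaner README suggestion.
default_layers = [
    "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
    "attn.to_k", "attn.to_out.0", "attn.to_q", "attn.to_v",
]
lean_layers = ["attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0"]
lora_layers = ",".join(lean_layers)
print(lora_layers)

# Training 4 of the 8 default layer types is where the "roughly 50% smaller
# LoRA at the same rank" estimate comes from (assuming similar projection sizes).
print(f"{len(lean_layers)}/{len(default_layers)} layer types "
      f"-> ~{len(lean_layers) / len(default_layers):.0%} of the default LoRA size")
```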
diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
index ec323be4143e..5d6c8bb9938a 100644
--- a/examples/dreambooth/test_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
     pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
     script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
 
+    transformer_block_idx = 0
+    layer_type = "attn.to_k"
+
     def test_dreambooth_lora_sd3(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
@@ -136,6 +139,74 @@ def test_dreambooth_lora_latent_caching(self):
         starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
         self.assertTrue(starts_with_transformer)
 
+    def test_dreambooth_lora_block(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_blocks {self.transformer_block_idx}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names.
+            # In this test, only params of transformer block 0 should be in the state dict
+            starts_with_transformer = all(
+                key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
+            )
+            self.assertTrue(starts_with_transformer)
+
+    def test_dreambooth_lora_layer(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_layers {self.layer_type}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
+            starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_transformer)
+
     def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
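To make the intent of the two new tests easier to see outside the test harness, here is a small illustrative Python sketch of the key-name checks they perform on the saved LoRA state dict. The key strings below are made-up examples of how PEFT-style LoRA keys for the SD3 transformer are typically named; they are not output captured from a real run.

```python
# Illustrative example keys for a run launched with `--lora_blocks 0`:
# every adapter parameter should live under transformer block 0.
example_keys_block0 = [
    "transformer.transformer_blocks.0.attn.to_k.lora_A.weight",
    "transformer.transformer_blocks.0.attn.to_k.lora_B.weight",
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
]
assert all(k.startswith("transformer.transformer_blocks.0") for k in example_keys_block0)

# Illustrative example keys for a run launched with `--lora_layers attn.to_k`:
# every adapter parameter should reference the `attn.to_k` projection,
# regardless of which block it belongs to.
example_keys_to_k = [
    "transformer.transformer_blocks.0.attn.to_k.lora_A.weight",
    "transformer.transformer_blocks.1.attn.to_k.lora_B.weight",
]
assert all("attn.to_k" in k for k in example_keys_to_k)
```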
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 4b39dcfe41b0..fc3c69b8901f 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -571,6 +571,25 @@ def parse_args(input_args=None):
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
 
+    parser.add_argument(
+        "--lora_layers",
+        type=str,
+        default=None,
+        help=(
+            "The transformer block layers to apply LoRA training on. Please specify the layers in a comma separated string. "
+            "For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md"
+        ),
+    )
+    parser.add_argument(
+        "--lora_blocks",
+        type=str,
+        default=None,
+        help=(
+            "The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma separated manner. "
+            'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md'
+        ),
+    )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -1222,13 +1241,31 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()
 
+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+    else:
+        target_modules = [
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
+            "attn.to_k",
+            "attn.to_out.0",
+            "attn.to_q",
+            "attn.to_v",
+        ]
+    if args.lora_blocks is not None:
+        target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
+        target_modules = [
+            f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
+        ]
+
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
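As a standalone illustration of the training-script change above, the sketch below mirrors how the two flag strings expand into PEFT `target_modules` patterns before being handed to `LoraConfig`. The flag values and the rank of 4 are illustrative placeholders (the script uses `args.rank` and the user-supplied strings); only `peft.LoraConfig` and its `r`, `lora_alpha`, `init_lora_weights`, and `target_modules` arguments are taken from the diff itself.

```python
# Minimal sketch of the target-module expansion added to train_dreambooth_lora_sd3.py.
from peft import LoraConfig

lora_layers = "attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"  # e.g. value of --lora_layers
lora_blocks = "12,30"                                        # e.g. value of --lora_blocks

# Start from the requested layer types (or the script's defaults when unset).
target_modules = [layer.strip() for layer in lora_layers.split(",")]

# Scope each layer pattern to the requested transformer blocks,
# producing patterns such as "transformer_blocks.12.attn.to_k".
if lora_blocks is not None:
    target_blocks = [int(block.strip()) for block in lora_blocks.split(",")]
    target_modules = [
        f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
    ]

transformer_lora_config = LoraConfig(
    r=4,                          # the script uses args.rank here
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=target_modules,
)
print(target_modules)  # 2 blocks x 4 layer types -> 8 module patterns
```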