Skip to content
50 changes: 50 additions & 0 deletions docs/source/en/training/lora.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,40 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--seed=1337
```

### Finetuning the text encoder and UNet

The script also allows you to finetune the `text_encoder` along with the `unet`.

<Tip warning={true}>

Training the text encoder requires additional memory and it won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option.

</Tip>

Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:

```bash
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=15000 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR} \
--push_to_hub \
--hub_model_id=${HUB_MODEL_ID} \
--report_to=wandb \
--checkpointing_steps=500 \
--validation_prompt="A pokemon with blue eyes." \
--train_text_encoder \
--seed=1337
```

### Inference[[text-to-image-inference]]

Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]:
Expand Down Expand Up @@ -144,6 +178,22 @@ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.

</Tip>

If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
weights. For example:

```python
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "takuoko/classic-anime-expressions-lora"
base_model_id = "stablediffusionapi/anything-v5"

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id, weight_name="pytorch_lora_weights.bin")
image = pipe("1girl, >_<", num_inference_steps=50).images[0]
```


## DreamBooth

Expand Down
41 changes: 41 additions & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,47 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)

def test_text_to_image_lora_with_text_encoder(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
--train_text_encoder
--num_validation_images=0
""".split()

run_command(self._launch_args + initial_run_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

# check `text_encoder` is present at all.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present)

# the names of the keys of the state dict should either start with `unet`
# or `text_encoder`.
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
self.assertTrue(is_correct_naming)

def test_unconditional_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
Expand Down
Loading