Commit a0bbc82: dreambooth if docs - stage II, more info
1 parent 4f14b36

1 file changed: docs/source/en/training/dreambooth.mdx (+132 -16 lines)

## IF

You can use the LoRA and full Dreambooth scripts to train the text-to-image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler [IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).

Note that IF has a predicted variance, and our finetuning scripts only train the model's predicted error, so for finetuned IF models we switch to a fixed variance schedule. The full finetuning scripts will update the scheduler config of the full saved model for you. However, when loading saved LoRA weights, you must also update the pipeline's scheduler config:

```py
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")

pipe.load_lora_weights("<lora weights path>")

# Update the scheduler config to a fixed variance schedule
pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")
```

Additionally, a few alternative CLI flags are needed for IF:

`--resolution=64`: IF is a pixel-space diffusion model. To operate on uncompressed pixels, the input images must be of a much smaller resolution.

`--pre_compute_text_embeddings`: IF uses T5 as its text encoder. To save GPU memory, we pre-compute all text embeddings and then deallocate T5 (see the sketch after this list).

`--tokenizer_max_length=77`: T5 has a longer default sequence length, but the default IF encoding procedure uses a shorter one.

`--text_encoder_use_attention_mask`: T5 needs the attention mask to be passed to the text encoder.

`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline with a T5 loaded from the original model.

`--use_8bit_adam`: When training the full model, the optimizer states are large, so we recommend using 8-bit Adam.
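
For intuition, here is a minimal sketch of what `--pre_compute_text_embeddings` does conceptually. This is not the training script's exact code; the prompt, dtype, and device handling are illustrative:

```py
import gc

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16).to("cuda")

# Encode the prompt once with T5; the embeddings can be reused for every step
prompt_embeds, negative_embeds = pipe.encode_prompt("a sks dog")

# T5 is no longer needed, so deallocate it to free GPU memory
del pipe.text_encoder
pipe.text_encoder = None
gc.collect()
torch.cuda.empty_cache()

# Downstream calls take the precomputed embeddings instead of a raw prompt
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
```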

### Tips and Tricks
We find LoRA to be sufficient for finetuning the stage I model, as the low resolution of the model makes representing fine-grained detail hard regardless.

For common and/or visually simple object concepts, you can get away with not finetuning the upscaler. Just be sure to adjust the prompt passed to the upscaler to remove the new token from the instance prompt, e.g. if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt (see the sketch below).

For fine-grained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than LoRA finetuning of stage II.

For fine-grained detail like faces, we find that lower learning rates work best.

For stage II, we find that lower learning rates are also needed.
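
As a concrete version of the prompt adjustment above, here is a sketch of generating with a LoRA-finetuned stage I model and upscaling with a stock (not finetuned) stage II model. The LoRA path and file name are placeholders, not from this doc:

```py
import torch
from diffusers import DiffusionPipeline

# Stage I is finetuned, so the prompt keeps the new token
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16)
stage_1.load_lora_weights("<lora weights path>")
stage_1.scheduler = stage_1.scheduler.__class__.from_config(stage_1.scheduler.config, variance_type="fixed_small")
stage_1.enable_model_cpu_offload()
prompt_embeds, negative_embeds = stage_1.encode_prompt("a sks dog")
image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

# Stage II is not finetuned, so drop the new token from the prompt
stage_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", torch_dtype=torch.float16)
stage_2.enable_model_cpu_offload()
prompt_embeds, negative_embeds = stage_2.encode_prompt("a dog")
upscaled = stage_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
upscaled.save("dog_upscaled.png")
```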

### IF Stage I LoRA Dreambooth

This training configuration requires ~28 GB VRAM.

```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_lora"

accelerate launch train_dreambooth_lora.py \
  --report_to wandb \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a sks dog" \
  --resolution=64 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --scale_lr \
  --max_train_steps=1200 \
  --validation_prompt="a sks dog" \
  --validation_epochs=25 \
  --checkpointing_steps=100 \
  --pre_compute_text_embeddings \
  --tokenizer_max_length=77 \
  --text_encoder_use_attention_mask
```

### IF Stage II LoRA Dreambooth

A few of the flags change relative to stage I:

`--validation_images`: These images are upscaled during validation steps.

`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II.

`--learning_rate=1e-6`: A lower learning rate than stage I.

`--resolution=256`: The upscaler expects higher-resolution inputs.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="image_1.png image_2.png image_3.png image_4.png"

python train_dreambooth_lora.py \
  --report_to wandb \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a sks dog" \
  --resolution=256 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --max_train_steps=2000 \
  --validation_prompt="a sks dog" \
  --validation_epochs=100 \
  --checkpointing_steps=500 \
  --pre_compute_text_embeddings \
  --tokenizer_max_length=77 \
  --text_encoder_use_attention_mask \
  --validation_images $VALIDATION_IMAGES \
  --class_labels_conditioning=timesteps
```
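
If you do finetune the upscaler with LoRA as above, keep the instance prompt for stage II and remember to switch to the fixed variance schedule when loading the weights. A sketch, assuming the `$OUTPUT_DIR` above and a 64x64 stage I output saved to disk (the file name is a placeholder):

```py
import torch
from diffusers import DiffusionPipeline
from PIL import Image

stage_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", torch_dtype=torch.float16)
stage_2.load_lora_weights("dreambooth_dog_upscale")

# Finetuned IF models require a fixed variance schedule
stage_2.scheduler = stage_2.scheduler.__class__.from_config(stage_2.scheduler.config, variance_type="fixed_small")
stage_2.enable_model_cpu_offload()

# The upscaler was finetuned on the instance prompt, so keep the new token
image = Image.open("sks_dog_64.png")
prompt_embeds, negative_embeds = stage_2.encode_prompt("a sks dog")
upscaled = stage_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images[0]
```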

### IF Stage I Full Dreambooth

`--use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8-bit Adam.

`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates; with higher learning rates, model quality will degrade.

Using 8-bit Adam and a batch size of 4, the model can be trained in ~48 GB VRAM.

```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_if"

accelerate launch train_dreambooth.py \
  --report_to wandb \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=64 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-7 \
  --max_train_steps=150 \
  --validation_prompt "a photo of sks dog" \
  --validation_steps 25 \
  --text_encoder_use_attention_mask \
  --tokenizer_max_length 77 \
  --pre_compute_text_embeddings \
  --use_8bit_adam \
  --set_grads_to_none \
  --skip_save_text_encoder
```
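
Since `--skip_save_text_encoder` leaves T5 out of the saved pipeline, reattach a text encoder from the original checkpoint when loading. A sketch, assuming the `$OUTPUT_DIR` above:

```py
from diffusers import DiffusionPipeline
from transformers import T5EncoderModel

# Load the original T5; it was not saved with the finetuned pipeline
text_encoder = T5EncoderModel.from_pretrained("DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder")

pipe = DiffusionPipeline.from_pretrained("dreambooth_if", text_encoder=text_encoder)
image = pipe("a photo of sks dog").images[0]
```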

### IF Stage II Full Dreambooth

`--learning_rate=1e-8`: An even lower learning rate than stage I full Dreambooth.

`--resolution=256`: The upscaler expects higher-resolution inputs.

```sh
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="image_1.png image_2.png image_3.png image_4.png"

accelerate launch train_dreambooth.py \
  --report_to wandb \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a sks dog" \
  --resolution=256 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=2 \
  --learning_rate=1e-8 \
  --max_train_steps=2000 \
  --validation_prompt="a sks dog" \
  --validation_steps=150 \
  --checkpointing_steps=500 \
  --pre_compute_text_embeddings \
  --tokenizer_max_length=77 \
  --text_encoder_use_attention_mask \
  --validation_images $VALIDATION_IMAGES \
  --class_labels_conditioning timesteps
```
