Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DoRA training feature to sdxl dreambooth lora script #7235

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions examples/dreambooth/README_sdxl.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,29 @@ accelerate launch train_dreambooth_lora_sdxl.py \

> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".

### DoRA training
The script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.

> [!NOTE]
> 💡DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account your training:
> 1. **LoRA seem to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may be working well for a DoRA)
> 2. **DoRA quality superior to LoRA especially in lower ranks** the difference in quality of DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64 for example.
> This is also aligned with some of the quantitative analysis shown in the paper.

**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding this flag
```bash
--use_dora
```
**Inference**
The inference is the same as if you train a regular LoRA 🤗
11 changes: 11 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,15 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--use_dora",
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -1147,6 +1156,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
Expand All @@ -1158,6 +1168,7 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
Expand Down