
[TypeError] in DreamBooth SDXL LoRA training when use_dora parameter is False #9841

@adhiiisetiawan

Description

Describe the bug

When running the DreamBooth SDXL training script with LoRA, it throws a TypeError even when use_dora=False (the default). This happens because the script always passes the use_dora argument to LoraConfig, whether or not DoRA is enabled, and the LoraConfig in the installed PEFT version (0.7.0, see System Info) does not accept a use_dora keyword at all. I plan to submit a PR that fixes this by passing use_dora to LoraConfig only when it is True.
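A minimal sketch of the proposed fix, assuming the script builds the UNet config from `args.rank` and `args.use_dora` and targets the attention projections `to_k`, `to_q`, `to_v`, `to_out.0` (these field values are assumptions, not copied from the script). The helper `build_unet_lora_kwargs` is hypothetical. It adds `use_dora` only when it is True, so the resulting dict can be passed as `LoraConfig(**kwargs)` on PEFT releases that do not know the argument.

```python
from typing import Any, Dict


def build_unet_lora_kwargs(rank: int, use_dora: bool) -> Dict[str, Any]:
    """Hypothetical helper: keyword arguments for peft.LoraConfig.

    `use_dora` is only included when it is True, so older PEFT releases
    whose LoraConfig lacks that argument still accept the default path.
    """
    kwargs: Dict[str, Any] = {
        "r": rank,
        "lora_alpha": rank,
        "init_lora_weights": "gaussian",
        # Assumed target modules (UNet attention projections).
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    }
    if use_dora:
        # Forward the flag only when DoRA is actually requested.
        kwargs["use_dora"] = True
    return kwargs


# Usage inside the training script (requires peft):
#   unet_lora_config = LoraConfig(**build_unet_lora_kwargs(args.rank, args.use_dora))
```

With `use_dora=True` the flag is still forwarded, so a PEFT release without DoRA support would raise the same TypeError. That case needs a newer PEFT rather than a code change.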

Reproduction

  1. Run the DreamBooth SDXL training script in examples/dreambooth/train_dreambooth_lora_sdxl.py
  2. The script fails even with the default parameters (use_dora=False), for example with this command:
python3 train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --dataset_name="adhisetiawan/food-bakso-img" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir="lora-trained-xl" \
  --use_8bit_adam \
  --mixed_precision="bf16" \
  --train_text_encoder \
  --instance_prompt="a photo of bakso" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of bakso in a bowl" \
  --validation_epochs=5 \
  --seed=0
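The command above never passes `--use_dora`, so the flag keeps its default. A reduced, hypothetical parser illustrates this, assuming the script declares `--use_dora` as an argparse `store_true` flag and `--rank` as an integer (both declarations are assumptions about the script, not copied from it).

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of the training script's CLI.
    parser = argparse.ArgumentParser(description="DreamBooth LoRA (reduced sketch)")
    parser.add_argument("--rank", type=int, default=4)
    # Assumed declaration: the flag is off unless passed explicitly.
    parser.add_argument("--use_dora", action="store_true", default=False)
    return parser


# Omitting --use_dora, as in the reproduction command, leaves it False.
args = make_parser().parse_args(["--rank", "4"])
print(args.use_dora)  # False
```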

Logs

{'reverse_transformer_layers_per_block', 'attention_type', 'dropout'} was not found in config. Values will be initialized to default values.
Traceback (most recent call last):
  File "/workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1986, in <module>
    main(args)
  File "/workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1187, in main
    unet_lora_config = LoraConfig(
TypeError: LoraConfig.__init__() got an unexpected keyword argument 'use_dora'
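The error comes from Python's keyword-argument checking, not from the value of the flag: a constructor rejects an unknown keyword even when its value is False. A minimal sketch with a hypothetical stand-in `OldLoraConfig` dataclass (not PEFT's real class) that, like the installed `LoraConfig` in the traceback, has no `use_dora` field:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class OldLoraConfig:
    # Hypothetical stand-in: a config class with no `use_dora` field.
    r: int = 8
    lora_alpha: int = 8
    target_modules: List[str] = field(default_factory=list)


def try_build(**kwargs) -> str:
    """Return 'ok' if construction succeeds, else the TypeError message."""
    try:
        OldLoraConfig(**kwargs)
        return "ok"
    except TypeError as exc:
        return str(exc)


print(try_build(r=4, lora_alpha=4))                  # ok
print(try_build(r=4, lora_alpha=4, use_dora=False))  # TypeError message naming 'use_dora'
```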

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.4.0-153-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.1.0+cu118 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.2
  • Transformers version: 4.46.1
  • Accelerate version: 1.1.0
  • PEFT version: 0.7.0
  • Bitsandbytes version: 0.44.1
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A6000, 49140 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No
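Given the PEFT 0.7.0 install above, another possible guard is to check whether the installed `LoraConfig` accepts `use_dora` before passing it. This sketch uses the standard-library `inspect.signature`. `accepts_kwarg` is a hypothetical helper, and the two stand-in dataclasses replace PEFT so the snippet runs on its own. In the script one would call `accepts_kwarg(LoraConfig, "use_dora")` instead, assuming `LoraConfig` exposes a readable constructor signature (as dataclasses do).

```python
import inspect
from dataclasses import dataclass


def accepts_kwarg(func, name: str) -> bool:
    """True if `func` (a callable or class) accepts keyword argument `name`."""
    for p in inspect.signature(func).parameters.values():
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # **kwargs accepts any keyword
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


@dataclass
class ConfigWithoutDora:  # stand-in for an older LoraConfig
    r: int = 8


@dataclass
class ConfigWithDora:  # stand-in for a newer LoraConfig
    r: int = 8
    use_dora: bool = False


print(accepts_kwarg(ConfigWithoutDora, "use_dora"))  # False
print(accepts_kwarg(ConfigWithDora, "use_dora"))     # True
```

A check like this could let the script report a clear "upgrade PEFT" message when `--use_dora` is requested on an older PEFT, instead of surfacing the raw TypeError.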

Who can help?

@sayakpaul

Metadata

  • Assignees: No one assigned
  • Labels: bug (Something isn't working)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests