Support EDM-style training in DreamBooth LoRA SDXL script #7126

Merged: sayakpaul merged 29 commits into main from playground-dreambooth-lora on Mar 3, 2024

@sayakpaul sayakpaul commented Feb 28, 2024

Command example:

CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic"  \
  --instance_data_dir="dog" \
  --output_dir="dog-playground-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --use_8bit_adam \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

WandB: https://wandb.ai/sayakpaul/dreambooth-lora-playground/runs/sxe4bavp

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 1456 to 1466
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
    timesteps = timesteps.to(accelerator.device)

    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
Contributor

For later:

We could think of making this more general by allowing to sample sigmas as presented in the paper, cf https://github.com/NVlabs/edm/blob/main/training/loss.py#L74
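
A rough sketch of what sampling sigmas "as presented in the paper" could look like. It assumes the log-normal scheme from the linked EDM `loss.py`, where ln(sigma) is drawn from a normal distribution with mean `P_mean = -1.2` and standard deviation `P_std = 1.2` (the EDM defaults). The helper name `sample_edm_sigmas` is hypothetical, and plain Python is used instead of torch so the snippet runs on its own. It is not what this PR implements, which looks sigmas up from the scheduler's table.

```python
import math
import random


def sample_edm_sigmas(batch_size, p_mean=-1.2, p_std=1.2, rng=None):
    """Draw one noise level per sample with ln(sigma) ~ N(p_mean, p_std**2).

    Hypothetical helper; the defaults are the EDM training values.
    """
    rng = rng or random.Random()
    return [math.exp(rng.gauss(p_mean, p_std)) for _ in range(batch_size)]


sigmas = sample_edm_sigmas(4, rng=random.Random(0))
print(sigmas)  # four positive noise levels, continuous rather than picked from a table
```

In a training loop the same idea would be applied to a tensor of shape `(batch, 1, 1, 1)` so that it broadcasts over the latents.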

)[0]

model_pred = model_pred * (-sigmas) + noisy_model_input
weighing = sigmas**-2.0
Contributor

For later: this could be made configurable, as there are multiple weighting alternatives. EDM uses the one defined here:
https://github.com/NVlabs/edm/blob/main/training/loss.py#L75
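
To make the comparison concrete, here is a small sketch of the two weightings side by side. The `sigma**-2` form is the one in the snippet under review. The EDM form, `(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2`, follows the linked `loss.py`, with `sigma_data = 0.5` assumed as the EDM default. Both function names are hypothetical and scalars stand in for tensors.

```python
def inverse_variance_weight(sigma):
    """The sigma**-2 weighting from the snippet under review."""
    return sigma ** -2.0


def edm_weight(sigma, sigma_data=0.5):
    """EDM-style weighting; sigma_data=0.5 is assumed to be the EDM default."""
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


for sigma in (0.1, 0.5, 5.0):
    print(sigma, inverse_variance_weight(sigma), edm_weight(sigma))
```

Algebraically the EDM weight equals `sigma**-2 + sigma_data**-2`, so the two mostly differ at large sigma, where `sigma**-2` goes to zero while the EDM weight levels off at `1 / sigma_data**2`.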

@sayakpaul
Member Author

sayakpaul commented Feb 28, 2024

@patil-suraj could you give this another look? Results are still blank: https://wandb.ai/sayakpaul/dreambooth-lora-playground/runs/i7aq50g0.

Experimenting with a lower LR.

@patil-suraj
Contributor

Can't seem to find anything else, will also try to run the script and see what's going on.

@patil-suraj
Contributor

Got a good run: https://wandb.ai/psuraj/dreambooth-lora-playground/runs/j34izml0?workspace=user-psuraj (still going on)

What fixed it:

  • Load the fp32 variant of the VAE.
  • Don't use autocast during generation.
  • Disable loss weighting.
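
The thread does not say why the fp32 VAE helps. One common cause of blank or NaN images with SDXL-family VAEs is activations that exceed the fp16 range, and that this is the cause here is an assumption, not something confirmed in this PR. The sketch below only illustrates the range limit itself, using the standard library's half-precision `struct` format; `fits_in_fp16` is a hypothetical helper.

```python
import struct


def fits_in_fp16(value):
    """Return True if `value` can be stored as an IEEE 754 half-precision float."""
    try:
        struct.pack("<e", value)  # "e" is the struct format code for binary16
        return True
    except (OverflowError, struct.error):
        return False


print(fits_in_fp16(65504.0))  # 65504 is the largest finite fp16 value
print(fits_in_fp16(70000.0))  # beyond the fp16 range
```

Keeping the VAE weights and its decode step in fp32 sidesteps this limit at the cost of some extra memory.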

@sayakpaul
Member Author

Applied the changes, @patil-suraj. Could you try another run?

@sayakpaul sayakpaul marked this pull request as ready for review February 28, 2024 13:03
@sayakpaul
Member Author

@patil-suraj ready for a review. Feel free to test the script too :)

@pcuenca feel free to give this a review as well.

Member

@pcuenca pcuenca left a comment

Looks good in general. I'd maybe try to avoid hardcoded references to the string "playgroundai" to make decisions, if possible.
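
One way to avoid branching on the repository name is to branch on the scheduler class recorded in the model's scheduler config, which is in line with the `"EDM" not in scheduler_type` check that appears later in this thread. The sketch below is an assumption about how that could look: the helper names are hypothetical and the two JSON strings are illustrative stand-ins, not copies of real config files. It relies only on the `_class_name` key that diffusers writes into scheduler configs.

```python
import json


def scheduler_class_name(scheduler_config_json):
    """Read the scheduler class name from the text of a scheduler_config.json file."""
    return json.loads(scheduler_config_json)["_class_name"]


def uses_edm_scheduler(scheduler_config_json):
    """Decide on the EDM code path from the config, not from the repository name."""
    return "EDM" in scheduler_class_name(scheduler_config_json)


playground_like = '{"_class_name": "EDMDPMSolverMultistepScheduler"}'
sdxl_like = '{"_class_name": "EulerDiscreteScheduler"}'
print(uses_edm_scheduler(playground_like), uses_edm_scheduler(sdxl_like))  # True False
```

With a check like this, any checkpoint that ships an EDM scheduler takes the EDM path automatically, whatever its owner or name.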


It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).

For the SDXL model, simple set:
Member

Suggested change
- For the SDXL model, simple set:
+ For the standard SDXL model, simply set:

Does it work with SDXL out of the box? 🤯

Member Author

There's a test that you can check but I haven't done a full-blown training run.

Contributor

For LoRA it might not work, but can def be fine-tuned with EDM.

Member Author

@patil-suraj elaborate?

Contributor

@patil-suraj patil-suraj left a comment

Looking good, left some comments. +1 to what pedro said.


@sayakpaul
Member Author

@patil-suraj ready for another review.

@sayakpaul
Member Author

@pcuenca @patil-suraj I have addressed all your comments. Would appreciate another review.

I am going to run with the command from the OP one more time and also with regular SDXL with --do_edm_style_training.

@sayakpaul sayakpaul changed the title Support DreamBooth LoRA for Playground Support EDM-style training in DreamBooth LoRA SDXL script Feb 29, 2024
@sayakpaul
Member Author

Started a regular SDXL run with EDM:

CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"  \
  --instance_data_dir="dog" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir="lora-sdxl-dog" \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --do_edm_style_training \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" 

Garbage results: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/bup8u1yc. Let me further tweak around some things.

@sayakpaul
Member Author

@pcuenca @patil-suraj the script should now work out of the box for SDXL when --do_edm_style_training is specified:

CUDA_VISIBLE_DEVICES=1 accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0"  \
  --instance_data_dir="dog" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir="lora-sdxl-dog" \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --do_edm_style_training \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" 

Feel free to train one yourselves. Here are my results: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/dz77sffl

Please review the changes so that we can ship this beast!

Contributor

@patil-suraj patil-suraj left a comment

Thanks for addressing the comments. The script is in a very good state for EDM. I would just suggest verifying the Euler bit before adding it here, or maybe even doing that in another PR. (Saw the other comment, all good.)

Also, are the VAE weights loaded and kept in fp32?

# There might be other alternatives for weighting as well:
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
if "EDM" not in scheduler_type:
    weighting = (sigmas**-2.0).float()
Contributor

We should verify if this works with euler

Member Author

It does: https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl/runs/dz77sffl. When do_edm_style_training is True and the scheduler is not one of the EDM* schedulers, we use EulerDiscrete. The run is from that setting.

Does that work?

Contributor

Sounds good!
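
For readers wondering why `sigmas**-2.0` is a natural weighting on the Euler path, here is a toy check. It assumes epsilon prediction and the forward process `noisy = x0 + sigma * noise`; under those assumptions, the transform `model_pred * (-sigmas) + noisy_model_input` seen earlier in this PR yields an estimate of the clean sample, and weighting its squared error by `sigma**-2` recovers the plain noise-prediction MSE. All values below are made up and plain Python lists stand in for tensors.

```python
import random

rng = random.Random(0)
sigma = 2.5
x0 = [rng.gauss(0, 1) for _ in range(16)]           # clean latents (toy values)
noise = [rng.gauss(0, 1) for _ in range(16)]        # true noise
noisy = [x + sigma * n for x, n in zip(x0, noise)]  # x0 + sigma * noise

# Pretend model output: an imperfect noise prediction.
eps_pred = [n + 0.1 * rng.gauss(0, 1) for n in noise]

# Same transform as in the script: noise prediction -> clean-sample estimate.
x0_pred = [e * (-sigma) + y for e, y in zip(eps_pred, noisy)]


def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


weighted_x0_loss = sigma ** -2.0 * mse(x0_pred, x0)
eps_loss = mse(eps_pred, noise)
print(weighted_x0_loss, eps_loss)  # equal up to floating-point rounding
```

The identity holds because `x0_pred - x0 = sigma * (noise - eps_pred)`, so the factor `sigma**2` from squaring cancels against the weight.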

@sayakpaul
Member Author

Also, are the VAE weights loaded and kept in fp32?

Yes, that is the case. I have addressed your other comment as well, @patil-suraj. LMK.

Contributor

@patil-suraj patil-suraj left a comment

Looks great now! Feel free to merge

@sayakpaul sayakpaul merged commit ccb93dc into main Mar 3, 2024
10 checks passed
@sayakpaul
Member Author

I keep getting a PermissionError: [Errno 13] Permission denied when trying to access the dog folder. Logged in to huggingface and running as administrator. All folders have full read/write access.

That is an issue quite unrelated to this PR.

@sayakpaul sayakpaul deleted the playground-dreambooth-lora branch March 3, 2024 03:59
Comment on lines +946 to +947
if args.do_edm_style_training and args.snr_gamma is not None:
    raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
Contributor

Do this earlier, so that the check fails before the model is loaded.

Member Author

It's at the beginning:

if args.do_edm_style_training and args.snr_gamma is not None:

way before the model loading code.
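
A minimal sketch of the fail-fast pattern being discussed, assuming the check sits at the end of argument parsing. Only the two flags involved are declared here; the real script defines many more, and this `parse_args` is a cut-down stand-in rather than the script's actual function.

```python
import argparse


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Sketch of early argument validation.")
    parser.add_argument("--do_edm_style_training", action="store_true")
    parser.add_argument("--snr_gamma", type=float, default=None)
    args = parser.parse_args(input_args)

    # Fail fast: this runs before any model or dataset is loaded.
    if args.do_edm_style_training and args.snr_gamma is not None:
        raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
    return args


args = parse_args(["--do_edm_style_training"])
print(args.do_edm_style_training, args.snr_gamma)  # True None
```

Validating incompatible flags at parse time means a bad combination costs a fraction of a second instead of a full model download and load.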

sayakpaul added a commit that referenced this pull request Mar 14, 2024
… of #7126) (#7182)

* add edm style training

* style

* finish adding edm training feature

* import fix

* fix latents mean

* minor adjustments

* add edm to readme

* style

* fix autocast and scheduler config issues when using edm

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>