[docs] LCM training #5796
Conversation
Thanks for the quality library, and for writing up docs on how to use everything. I tried LCM training today (for SD2.1 base) and the script failed to run because the machine didn't have the AWS CLI installed. Then, after installing it, it turned out the S3 bucket used in your example isn't publicly accessible. 😓

Perhaps you want to call out that this is an example of how you could run it, rather than a runnable example you can actually use? One more thing: the scripts are technically broken right now (I guess there are no tests that run for them) because they use a `parse_args` argument that isn't defined. I added a guessed definition in my local copy to get it to run at all:

```python
parser.add_argument(
    "--unet_time_cond_proj_dim",
    type=int,
    default=32,  # I guessed at this because 32 was used elsewhere in the diffusers codebase
    help="helpful description here",
)
```
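For context, a flag named like `--unet_time_cond_proj_dim` usually sets the `time_cond_proj_dim` of the student UNet, i.e. the width of the projection that embeds the guidance scale into the model. The sketch below is a hypothetical illustration of how such a flag could be consumed, not the script's actual code; the value 256 and the SD2.1-base repo id are assumptions (256 is a common choice for LCM-style guidance embeddings).

```python
# Hypothetical sketch -- not the actual train_lcm_distill_sd_wds.py code.
from diffusers import UNet2DConditionModel

# Load the frozen teacher UNet (SD2.1 base, matching the report above).
teacher_unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet"
)

# Create the student from the teacher's config, adding a guidance-scale
# embedding projection of the requested width (256 is an assumed value).
student_unet = UNet2DConditionModel.from_config(
    teacher_unet.config, time_cond_proj_dim=256
)

# The teacher has no weights for the new projection, hence strict=False.
student_unet.load_state_dict(teacher_unet.state_dict(), strict=False)
```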
Hey @justindujardin, could you try to open a new issue for this, please?
@patrickvonplaten I opened this issue for it: #5829. Feel free to move on with this ticket without addressing my comments. 👍
Awesome, thanks a lot!
Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
Suggested change:

```diff
-Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
+Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so we'll focus on the parameters that are relevant to latent consistency distillation in this guide.
```
maybe?
I prefer "you" because it's the user who is doing the task :)
- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify a better [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) (only applicable if you're distilling an SDXL model)
- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation
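For illustration, here is a rough sketch of how the guidance-scale range and the DDIM timestep count are typically used during latent consistency distillation; the values and variable names are assumptions, not the script's exact code.

```python
# Illustrative sketch only; the script's exact sampling code may differ.
import torch

w_min, w_max = 3.0, 15.0        # assumed example range for --w_min / --w_max
num_ddim_timesteps = 50         # assumed value for --num_ddim_timesteps
num_train_timesteps = 1000      # the teacher's training schedule length
batch_size = 4

# Sample a guidance scale uniformly from [w_min, w_max] per example,
# so the student learns to handle a whole range of guidance strengths.
w = (w_max - w_min) * torch.rand((batch_size,)) + w_min

# Build the coarse DDIM timestep grid the teacher is stepped along during distillation.
step_ratio = num_train_timesteps // num_ddim_timesteps
ddim_timesteps = torch.arange(1, num_ddim_timesteps + 1) * step_ratio - 1
```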
Maybe some guidance here? I think Huber is generally preferred for latent consistency training? /cc @sayakpaul
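For reference, a minimal sketch of the two options, assuming the (pseudo-)Huber formulation commonly used for consistency training; the script's exact expression and `huber_c` value may differ.

```python
import torch
import torch.nn.functional as F

def distillation_loss(model_pred, target, loss_type="huber", huber_c=0.001):
    if loss_type == "l2":
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    if loss_type == "huber":
        # Quadratic near zero like L2, but linear in the tails, so outlier
        # residuals pull on the student less aggressively.
        diff = model_pred.float() - target.float()
        return torch.mean(torch.sqrt(diff**2 + huber_c**2) - huber_c)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```

The robustness to large residuals is usually the argument given for preferring Huber here.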
## LoRA

LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.
Suggested change:

```diff
-LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.
+LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) scripts to train with LoRA.
```
How about linking to the scripts designed for HF datasets? To be included after they are merged?
Yeah, let's include them after they're merged!
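As background for the parameter-count claim in the quoted paragraph, here is a toy illustration of why LoRA weights stay small: only a rank-`r` update is trained while the base weight is frozen. This is a generic sketch, not the LoRA wiring used by the distillation scripts.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update B @ A."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 8.0):
        super().__init__()
        self.base = base.requires_grad_(False)                    # frozen pretrained weight
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero-init: no change at start
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(768, 768), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable} / total: {total}")  # 12,288 trainable vs ~600k total
```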
If you're training on a GPU with limited vRAM, try enabling `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` to reduce memory usage and speed up training. You can reduce your memory usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.

This guide will explore the [train_lcm_distill_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
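To make the memory-saving options in the quoted text concrete, here is a hedged sketch of what such flags typically toggle inside a diffusers training script; the flag-to-call mapping and the values are assumptions, so double-check them against the script's `parse_args` (and note that xFormers and bitsandbytes must be installed).

```python
import bitsandbytes as bnb
from accelerate import Accelerator
from diffusers import UNet2DConditionModel

# --gradient_accumulation_steps / --mixed_precision are usually forwarded to Accelerator.
accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="fp16")

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet"
)
# --gradient_checkpointing: recompute activations in the backward pass to save memory.
unet.enable_gradient_checkpointing()
# Memory-efficient attention kernels from xFormers.
unet.enable_xformers_memory_efficient_attention()

# bitsandbytes' 8-bit AdamW keeps optimizer state in 8 bits instead of 32.
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-6)
```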
I'd maybe clarify from the outset that this script is designed for webdataset datasets, but we can do it when we add the HF datasets version.
Will do after we've added the other version :)
Ref: #5908 for suggested use of public wds datasets.
Force-pushed from 0bcf493 to 3d98685.
Very cool!
* first draft
* feedback
A first draft of the training docs for LCM. Please feel free to add more details where necessary and suggest any clarifications! ❤️
Complementary to the LCM inference docs in #5782