[Wuerstchen] text to image training script #5052
Merged
80 commits
269ccf8 initial script (kashif)
67d734d formatting (kashif)
ba1c3b7 Merge branch 'main' into wuerstchen-train (kashif)
3c7ac6f prior trainer wip (kashif)
b412828 add efficient_net_encoder (kashif)
a24131a add CLIPTextModel (kashif)
b4f2cdb add prior ema support (kashif)
3c8f6ed optimizer (kashif)
34aab3e fix typo (kashif)
9def4b5 add dataloader (kashif)
d8fb19c prompt_embeds and image_embeds (kashif)
3fe9079 intial training loop (kashif)
3a22be0 fix output_dir (kashif)
6b5d2e7 fix add_noise (kashif)
8f9a683 accelerator check (kashif)
8d93fe5 make effnet_transforms dynamic (kashif)
7a46b1e fix training loop (kashif)
61c845c add validation logging (kashif)
98ab7f9 Merge branch 'main' into wuerstchen-train (kashif)
fdc2c92 use loaded text_encoder (kashif)
749f977 use PreTrainedTokenizerFast (kashif)
a2a9b97 load weigth from pickle (kashif)
81384fb save_model_card (kashif)
64b3d30 remove unused file (kashif)
f20a6fc fix typos (kashif)
d9e1d47 save prior pipeilne in its own folder (kashif)
67c37e3 fix imports (kashif)
021b0a4 fix pipe_t2i (kashif)
c2faf11 scale image_embeds (kashif)
77924ea remove snr_gamma (kashif)
85efacd format (kashif)
3433ebb initial lora prior training (kashif)
10fb635 log_validation and save (kashif)
353d71e Merge branch 'main' into wuerstchen-train (kashif)
0a7ffa9 initial gradient working (kashif)
d9b6b48 remove save/load hooks (kashif)
dbc238b set set_attn_processor on prior_prior (kashif)
af4dcae add lora script (kashif)
bc776dc typos (kashif)
7989eae use LoraLoaderMixin for prior pipeline (kashif)
70cd979 fix usage (kashif)
040de92 Merge branch 'main' into wuerstchen-train (kashif)
0454a87 make fix-copies (kashif)
7435c70 yse repo_id (kashif)
2eb5d9c write_lora_layers is a staitcmethod (kashif)
234bebb use defualts (kashif)
afb001c fix defaults (kashif)
78f2aae Merge branch 'main' into wuerstchen-train (kashif)
fd6f57f Merge branch 'main' into wuerstchen-train (kashif)
47a31ab undo (kashif)
5128f52 Merge branch 'main' into wuerstchen-train (kashif)
682f30e Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py (kashif)
f0638ff Update src/diffusers/loaders.py (kashif)
8957bf8 Update src/diffusers/loaders.py (kashif)
dddd553 Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py (kashif)
9672ec0 Merge branch 'main' into wuerstchen-train (kashif)
2a98979 Merge branch 'main' into wuerstchen-train (patrickvonplaten)
d767819 Merge branch 'main' into wuerstchen-train (kashif)
1ab236f Update src/diffusers/loaders.py (kashif)
402f305 Update src/diffusers/loaders.py (kashif)
e1e8f18 Merge branch 'main' into wuerstchen-train (kashif)
72e755f add graident checkpoint support to prior (kashif)
43343c6 gradient_checkpointing (kashif)
15b2d11 formatting (kashif)
4de3fbe Update examples/wuerstchen/text_to_image/README.md (kashif)
162500e Update examples/wuerstchen/text_to_image/README.md (kashif)
b3e54cb Update examples/wuerstchen/text_to_image/README.md (kashif)
12209ef Update examples/wuerstchen/text_to_image/README.md (kashif)
a1527b2 Update examples/wuerstchen/text_to_image/README.md (kashif)
a28f5c0 Update examples/wuerstchen/text_to_image/train_text_to_image_lora_pri… (kashif)
cda5de4 Update src/diffusers/loaders.py (kashif)
d9964e2 Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py (kashif)
f2900d1 Merge branch 'main' into wuerstchen-train (kashif)
89fa22f use default unet and text_encoder (kashif)
fb07d27 Merge branch 'main' into wuerstchen-train (kashif)
a2dd115 Merge branch 'main' into wuerstchen-train (patrickvonplaten)
c6fa49d Merge branch 'main' into wuerstchen-train (kashif)
c23f272 Merge branch 'main' into wuerstchen-train (kashif)
cc3adb5 fix test (kashif)
a97caf6 Merge branch 'main' into wuerstchen-train (kashif)
examples/wuerstchen/text_to_image/README.md
@@ -0,0 +1,93 @@

# Würstchen text-to-image fine-tuning

## Running locally with PyTorch

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then `cd` into the example folder and install its requirements:
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```

Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
In this example we store the trained weights directly on the Hub, so we need to be logged in and to pass the `--push_to_hub` flag to the training script. To log in, run:
```bash
huggingface-cli login
```

## Prior training

You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. The script currently supports `--gradient_checkpointing` for prior fine-tuning, which helps on setups with limited GPU memory.

<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot pokemon, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->

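To make the batch-size flags in the prior training command concrete, here is a small arithmetic sketch. It assumes that `--train_batch_size` is a per-device value and that `--max_train_steps` counts optimizer updates, which is how the 🤗 Diffusers example scripts generally behave. The helper name `effective_batch_size` and the single-process setting `num_processes=1` are illustrative assumptions, not part of the script.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_processes: int = 1) -> int:
    """Number of samples that contribute to a single optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the command: --train_batch_size=4, --gradient_accumulation_steps=4.
# num_processes=1 (one GPU) is an assumption; use your `accelerate config` value.
per_update = effective_batch_size(4, 4, num_processes=1)

# --max_train_steps=15000 optimizer updates (assumed to count updates, not micro-batches).
samples_seen = per_update * 15000

print(per_update)    # 16
print(samples_seen)  # 240000
```

With more processes the effective batch grows proportionally, so the learning rate in the command may need retuning when moving to a multi-GPU setup.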
## Training with LoRA

Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has several advantages:

- The pretrained weights are kept frozen, so the model is less prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers make it possible to control the extent to which the model is adapted toward new training images via a `scale` parameter.


### Prior Training

First, set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_lora_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=768 \
  --train_batch_size=8 \
  --num_train_epochs=100 --checkpointing_steps=5000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=4 \
  --validation_prompt="cute dragon creature" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-lora"
```
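The LoRA idea described in the README (a frozen weight plus a trainable low-rank pair, blended by a `scale` factor) can be sketched in plain Python. This is a minimal illustration, not the code used by the training script: the function names, the 1024×1024 layer size, and the tiny 2×2 example are all assumptions chosen for readability, and the sketch omits the `alpha / rank` scaling that some LoRA implementations apply.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def lora_param_counts(d_in: int, d_out: int, rank: int) -> Tuple[int, int]:
    """Return (parameters in the full weight, parameters in the LoRA pair)."""
    full = d_in * d_out             # frozen weight W: d_out x d_in
    lora = rank * (d_in + d_out)    # down: rank x d_in, up: d_out x rank
    return full, lora


def matvec(m: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(x: Vector, weight: Matrix, lora_down: Matrix,
                 lora_up: Matrix, scale: float = 1.0) -> Vector:
    """y = W x + scale * up(down(x)); only `lora_down`/`lora_up` would be trained."""
    base = matvec(weight, x)                         # frozen pretrained path
    update = matvec(lora_up, matvec(lora_down, x))   # low-rank adaptation path
    return [b + scale * u for b, u in zip(base, update)]


# Hypothetical 1024x1024 projection with --rank=4 as in the command above.
print(lora_param_counts(1024, 1024, 4))  # (1048576, 8192)

# Tiny rank-1 example: scale=0 recovers the frozen model, larger scale adapts more.
W = [[1.0, 0.0], [0.0, 1.0]]
down = [[1.0, 1.0]]      # 1 x 2
up = [[1.0], [0.0]]      # 2 x 1
print(lora_forward([1.0, 2.0], W, down, up, scale=0.0))  # [1.0, 2.0]
print(lora_forward([1.0, 2.0], W, down, up, scale=0.5))  # [2.5, 2.0]
```

In the hypothetical layer above, the LoRA pair holds under 1% of the parameters of the full weight, which is why the resulting checkpoints are small and easy to share.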
Empty file.
23 changes: 23 additions & 0 deletions
examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
@@ -0,0 +1,23 @@

    import torch.nn as nn
    from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

    from diffusers.configuration_utils import ConfigMixin, register_to_config
    from diffusers.models.modeling_utils import ModelMixin


    class EfficientNetEncoder(ModelMixin, ConfigMixin):
        @register_to_config
        def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
            super().__init__()

            if effnet == "efficientnet_v2_s":
                self.backbone = efficientnet_v2_s(weights="DEFAULT").features
            else:
                self.backbone = efficientnet_v2_l(weights="DEFAULT").features
            self.mapper = nn.Sequential(
                nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
                nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
            )

        def forward(self, x):
            return self.mapper(self.backbone(x))
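As a rough illustration of what this encoder produces, the sketch below computes the expected output shape without loading any weights. It assumes that torchvision's EfficientNetV2 feature extractor downsamples by a total factor of 32 and ends in 1280 channels (consistent with the `c_cond=1280` default), and that the 1×1 `mapper` convolution only changes the channel count. `encoder_output_shape` is a hypothetical helper, not part of the PR.

```python
from typing import Tuple


def encoder_output_shape(batch: int, height: int, width: int,
                         c_latent: int = 16,
                         total_stride: int = 32) -> Tuple[int, int, int, int]:
    """Expected (N, C, H, W) of EfficientNetEncoder.forward for an (N, 3, H, W) input.

    total_stride=32 is an assumption about the backbone's overall downsampling.
    """
    if height % total_stride or width % total_stride:
        raise ValueError("this sketch only handles sizes divisible by total_stride")
    # Backbone: (N, 3, H, W) -> (N, 1280, H/32, W/32); mapper: 1280 -> c_latent channels.
    return (batch, c_latent, height // total_stride, width // total_stride)


# --resolution=768 and --train_batch_size=4 from the README's prior training command.
print(encoder_output_shape(4, 768, 768))  # (4, 16, 24, 24)
```

Under these assumptions a 768×768 image is compressed to a 16-channel 24×24 latent, which is the kind of compact image embedding the prior is trained against.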
requirements.txt
@@ -0,0 +1,7 @@

    accelerate>=0.16.0
    torchvision
    transformers>=4.25.1
    wandb
    huggingface-cli
    bitsandbytes
    deepspeed
is it required?
No, it's not technically required, but the sample snippet in the README uses the `--report_to="wandb"` option.