Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
269ccf8
initial script
kashif Sep 15, 2023
67d734d
formatting
kashif Sep 15, 2023
ba1c3b7
Merge branch 'main' into wuerstchen-train
kashif Sep 17, 2023
3c7ac6f
prior trainer wip
kashif Sep 18, 2023
b412828
add efficient_net_encoder
kashif Sep 18, 2023
a24131a
add CLIPTextModel
kashif Sep 18, 2023
b4f2cdb
add prior ema support
kashif Sep 19, 2023
3c8f6ed
optimizer
kashif Sep 19, 2023
34aab3e
fix typo
kashif Sep 19, 2023
9def4b5
add dataloader
kashif Sep 19, 2023
d8fb19c
prompt_embeds and image_embeds
kashif Sep 19, 2023
3fe9079
intial training loop
kashif Sep 19, 2023
3a22be0
fix output_dir
kashif Sep 19, 2023
6b5d2e7
fix add_noise
kashif Sep 19, 2023
8f9a683
accelerator check
kashif Sep 19, 2023
8d93fe5
make effnet_transforms dynamic
kashif Sep 20, 2023
7a46b1e
fix training loop
kashif Sep 20, 2023
61c845c
add validation logging
kashif Sep 21, 2023
98ab7f9
Merge branch 'main' into wuerstchen-train
kashif Sep 22, 2023
fdc2c92
use loaded text_encoder
kashif Sep 21, 2023
749f977
use PreTrainedTokenizerFast
kashif Sep 21, 2023
a2a9b97
load weigth from pickle
kashif Sep 23, 2023
81384fb
save_model_card
kashif Sep 23, 2023
64b3d30
remove unused file
kashif Sep 23, 2023
f20a6fc
fix typos
kashif Sep 23, 2023
d9e1d47
save prior pipeilne in its own folder
kashif Sep 23, 2023
67c37e3
fix imports
kashif Sep 23, 2023
021b0a4
fix pipe_t2i
kashif Sep 24, 2023
c2faf11
scale image_embeds
kashif Sep 25, 2023
77924ea
remove snr_gamma
kashif Sep 25, 2023
85efacd
format
kashif Sep 25, 2023
3433ebb
initial lora prior training
kashif Sep 25, 2023
10fb635
log_validation and save
kashif Sep 25, 2023
353d71e
Merge branch 'main' into wuerstchen-train
kashif Sep 26, 2023
0a7ffa9
initial gradient working
kashif Sep 26, 2023
d9b6b48
remove save/load hooks
kashif Sep 26, 2023
dbc238b
set set_attn_processor on prior_prior
kashif Sep 26, 2023
af4dcae
add lora script
kashif Sep 27, 2023
bc776dc
typos
kashif Sep 27, 2023
7989eae
use LoraLoaderMixin for prior pipeline
kashif Sep 27, 2023
70cd979
fix usage
kashif Sep 27, 2023
040de92
Merge branch 'main' into wuerstchen-train
kashif Sep 27, 2023
0454a87
make fix-copies
kashif Sep 27, 2023
7435c70
yse repo_id
kashif Sep 27, 2023
2eb5d9c
write_lora_layers is a staitcmethod
kashif Sep 27, 2023
234bebb
use defualts
kashif Sep 27, 2023
afb001c
fix defaults
kashif Sep 27, 2023
78f2aae
Merge branch 'main' into wuerstchen-train
kashif Sep 27, 2023
fd6f57f
Merge branch 'main' into wuerstchen-train
kashif Sep 28, 2023
47a31ab
undo
kashif Sep 28, 2023
5128f52
Merge branch 'main' into wuerstchen-train
kashif Sep 28, 2023
682f30e
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
kashif Sep 28, 2023
f0638ff
Update src/diffusers/loaders.py
kashif Sep 28, 2023
8957bf8
Update src/diffusers/loaders.py
kashif Sep 28, 2023
dddd553
Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
kashif Sep 28, 2023
9672ec0
Merge branch 'main' into wuerstchen-train
kashif Oct 2, 2023
2a98979
Merge branch 'main' into wuerstchen-train
patrickvonplaten Oct 2, 2023
d767819
Merge branch 'main' into wuerstchen-train
kashif Oct 2, 2023
1ab236f
Update src/diffusers/loaders.py
kashif Oct 4, 2023
402f305
Update src/diffusers/loaders.py
kashif Oct 4, 2023
e1e8f18
Merge branch 'main' into wuerstchen-train
kashif Oct 9, 2023
72e755f
add graident checkpoint support to prior
kashif Oct 9, 2023
43343c6
gradient_checkpointing
kashif Oct 9, 2023
15b2d11
formatting
kashif Oct 9, 2023
4de3fbe
Update examples/wuerstchen/text_to_image/README.md
kashif Oct 9, 2023
162500e
Update examples/wuerstchen/text_to_image/README.md
kashif Oct 9, 2023
b3e54cb
Update examples/wuerstchen/text_to_image/README.md
kashif Oct 9, 2023
12209ef
Update examples/wuerstchen/text_to_image/README.md
kashif Oct 9, 2023
a1527b2
Update examples/wuerstchen/text_to_image/README.md
kashif Oct 9, 2023
a28f5c0
Update examples/wuerstchen/text_to_image/train_text_to_image_lora_pri…
kashif Oct 9, 2023
cda5de4
Update src/diffusers/loaders.py
kashif Oct 9, 2023
d9964e2
Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py
kashif Oct 9, 2023
f2900d1
Merge branch 'main' into wuerstchen-train
kashif Oct 9, 2023
89fa22f
use default unet and text_encoder
kashif Oct 10, 2023
fb07d27
Merge branch 'main' into wuerstchen-train
kashif Oct 10, 2023
a2dd115
Merge branch 'main' into wuerstchen-train
patrickvonplaten Oct 11, 2023
c6fa49d
Merge branch 'main' into wuerstchen-train
kashif Oct 11, 2023
c23f272
Merge branch 'main' into wuerstchen-train
kashif Oct 13, 2023
cc3adb5
fix test
kashif Oct 14, 2023
a97caf6
Merge branch 'main' into wuerstchen-train
kashif Oct 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions examples/wuerstchen/text_to_image/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Würstchen text-to-image fine-tuning

## Running locally with PyTorch

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run
```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
```bash
huggingface-cli login
```

## Prior training

You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so you can use it for more GPU memory constrained setups.

<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_prior.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=768 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--dataloader_num_workers=4 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--checkpoints_total_limit=3 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--validation_prompts="A robot pokemon, 4k photo" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->

## Training with LoRA

Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.


### Prior Training

First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_prior_lora.py \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=768 \
--train_batch_size=8 \
--num_train_epochs=100 --checkpointing_steps=5000 \
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
--seed=42 \
--rank=4 \
--validation_prompt="cute dragon creature" \
--report_to="wandb" \
--push_to_hub \
--output_dir="wuerstchen-prior-pokemon-lora"
```
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class EfficientNetEncoder(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
super().__init__()

if effnet == "efficientnet_v2_s":
self.backbone = efficientnet_v2_s(weights="DEFAULT").features
else:
self.backbone = efficientnet_v2_l(weights="DEFAULT").features
self.mapper = nn.Sequential(
nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
)

def forward(self, x):
return self.mapper(self.backbone(x))
7 changes: 7 additions & 0 deletions examples/wuerstchen/text_to_image/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no its not technically required but the sample snippet in the README has the --report_to="wandb" option

huggingface-cli
bitsandbytes
deepspeed
Loading