stable diffusion fine-tuning #356
Merged
59 commits
66a51ed
begin text2image script
patil-suraj d062da6
loading the datasets, preprocessing & transforms
anton-l 3ed3a34
handle input features correctly
anton-l 066af65
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patil-suraj ce569a1
add gradient checkpointing support
patil-suraj 837a586
fix output names
patil-suraj 3893029
run unet in train mode not text encoder
patil-suraj 61513b0
use no_grad instead of freezing params
patil-suraj ed8f4dd
default max steps None
patil-suraj e4fb478
pad to longest
patil-suraj 7414de1
don't pad when tokenizing
patil-suraj ce4a7a2
fix encode on multi gpu
patil-suraj 95d7836
fix stupid bug
patil-suraj 54b700d
add random flip
patil-suraj 725fb96
add ema
patil-suraj 584b3f7
fix ema
patil-suraj 0f0b098
put ema on cpu
patil-suraj 56a9fd0
improve EMA model
patil-suraj 5c05401
contiguous_format
patil-suraj ad42acb
don't warp vae and text encode in accelerate
patil-suraj 4e54ae2
remove no_grad
patil-suraj 9cf8d2b
use randn_like
patil-suraj 2feec19
fix resize
patil-suraj 7044b2d
improve few things
patil-suraj 4809770
log epoch loss
patil-suraj fdfbad3
set log level
patil-suraj 03d124b
don't log each step
patil-suraj abebd23
remove max_length from collate
patil-suraj 4779819
style
patil-suraj c6ad723
add report_to option
patil-suraj f4cd6ff
make scale_lr false by default
patil-suraj 4cc238d
add grad clipping
patil-suraj 438514c
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patil-suraj c643e94
add an option to use 8bit adam
patil-suraj 3caf7c6
fix logging in multi-gpu, log every step
patil-suraj 518448c
more comments
patil-suraj 926c20e
remove eval for now
patil-suraj 12d19df
adress review comments
patil-suraj ac0b09e
add requirements file
patil-suraj 48930b7
begin readme
patil-suraj 7964342
begin readme
patil-suraj 1c8387c
fix typo
patil-suraj 3eea0db
fix push to hub
patil-suraj eb8e6c3
populate readme
patil-suraj 2cb7c43
update readme
patil-suraj 7228818
remove use_auth_token from the script
patil-suraj 4ffbf57
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patil-suraj b08d85d
address some review comments
patil-suraj 64338d2
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patil-suraj db8e31a
better mixed precision support
patil-suraj 17cb6e7
remove redundant to
patil-suraj 25625ce
create ema model early
patil-suraj 5d71880
Apply suggestions from code review
patil-suraj 3a6e4f2
better description for train_data_dir
patil-suraj b05a860
Merge branch 'finetune-txt2img' of https://github.com/huggingface/dif…
patil-suraj 1c8b026
add diffusers in requirements
patil-suraj 5b22178
update dataset_name_mapping
patil-suraj f0b4357
update readme
patil-suraj f9a4025
add inference example
patil-suraj
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Stable Diffusion text-to-image fine-tuning | ||
|
||
The `train_text_to_image.py` script shows how to fine-tune a Stable Diffusion model on your own dataset. | |
|
||
___Note___: | ||
|
||
___This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___ | |
|
||
|
||
## Running locally | ||
### Installing the dependencies | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
```bash | ||
pip install git+https://github.com/huggingface/diffusers.git | ||
pip install -U -r requirements.txt | ||
``` | ||
|
||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
||
```bash | ||
accelerate config | ||
``` | ||
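
As a quick sanity check after installation, the following minimal sketch reports which of the training dependencies are present in the current environment. The package names are taken from the `requirements.txt` added in this PR; the helper name `installed_versions` is hypothetical and not part of the training script.

```python
from importlib import metadata

# Package names copied from the requirements.txt in this PR.
REQUIRED = [
    "diffusers",
    "accelerate",
    "torchvision",
    "transformers",
    "ftfy",
    "tensorboard",
    "modelcards",
]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

This only checks that the packages are installed; it does not compare versions against the pins in `requirements.txt`.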
|
||
### Pokemon example | ||
|
||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. | ||
|
||
You have to be a registered user of the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). | |
|
||
Run the following command to authenticate your token: | |
|
||
```bash | ||
huggingface-cli login | ||
``` | ||
|
||
If you have already cloned the model repo locally, then you won't need to go through these steps. | |
|
||
<br> | ||
|
||
#### Hardware | ||
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use GPUs with more than 30GB of memory. | |
|
||
```bash | ||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | ||
export dataset_name="lambdalabs/pokemon-blip-captions" | ||
|
||
accelerate launch train_text_to_image.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--dataset_name=$dataset_name \ | ||
--use_ema \ | ||
--resolution=512 --center_crop --random_flip \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--gradient_checkpointing \ | ||
--mixed_precision="fp16" \ | ||
--max_train_steps=15000 \ | ||
--learning_rate=1e-05 \ | ||
--max_grad_norm=1 \ | ||
--lr_scheduler="constant" --lr_warmup_steps=0 \ | ||
--output_dir="sd-pokemon-model" | ||
``` | ||
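
To make the batch-size flags in the command above concrete, here is a small arithmetic sketch. It assumes that `--max_train_steps` counts optimizer updates (one update per `gradient_accumulation_steps` forward passes) and that the run uses a single process; the function names and the 1,000-image dataset size are hypothetical and only for illustration.

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of samples that contribute to one optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def samples_seen(max_train_steps, effective_batch):
    """Total samples processed, assuming each step is one optimizer update."""
    return max_train_steps * effective_batch


def approx_epochs(total_samples, dataset_size):
    """Approximate number of passes over a dataset of the given size."""
    return total_samples / dataset_size


# Values from the command above: batch size 1, 4 accumulation steps, 15000 steps.
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
total = samples_seen(max_train_steps=15000, effective_batch=batch)
print(batch)  # 4
print(total)  # 60000

# Hypothetical dataset of 1000 images, used only to show the calculation.
print(approx_epochs(total, dataset_size=1000))  # 60.0
```

With a small dataset this many passes is considerable, which is one reason the overfitting warning in the note at the top applies.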
|
||
|
||
To run on your own training files, prepare the dataset in the format required by `datasets`. You can find instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). | |
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script. | |
|
||
```bash | ||
export MODEL_NAME="CompVis/stable-diffusion-v1-4" | ||
export TRAIN_DIR="path_to_your_dataset" | ||
|
||
accelerate launch train_text_to_image.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--train_data_dir=$TRAIN_DIR \ | ||
--use_ema \ | ||
--resolution=512 --center_crop --random_flip \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--gradient_checkpointing \ | ||
--mixed_precision="fp16" \ | ||
--max_train_steps=15000 \ | ||
--learning_rate=1e-05 \ | ||
--max_grad_norm=1 \ | ||
--lr_scheduler="constant" --lr_warmup_steps=0 \ | ||
--output_dir="sd-pokemon-model" | ||
``` | ||
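
For the `--train_data_dir` case, the image-folder-with-metadata format linked above expects a `metadata.jsonl` file next to the images, with one JSON object per line and a `file_name` key pointing at each image. The sketch below writes such a file. The caption key `text`, the helper name `write_metadata`, and the example file names are assumptions for illustration; check the training script for the caption column it actually reads.

```python
import json
from pathlib import Path


def write_metadata(train_dir, captions):
    """Write metadata.jsonl into train_dir, one JSON object per image.

    captions maps an image file name (relative to train_dir) to its caption.
    """
    train_dir = Path(train_dir)
    train_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = train_dir / "metadata.jsonl"
    with metadata_path.open("w", encoding="utf-8") as f:
        for file_name, text in captions.items():
            # "file_name" is required by the image folder loader;
            # "text" is an assumed name for the caption column.
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
    return metadata_path


if __name__ == "__main__":
    # Hypothetical file names and captions.
    path = write_metadata(
        "path_to_your_dataset",
        {
            "0001.png": "a drawing of a green creature with red eyes",
            "0002.png": "a cartoon bird with blue wings",
        },
    )
    print(f"wrote {path}")
```

The image files themselves still need to be placed in the same directory; this helper only creates the caption metadata.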
|
||
Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`: | |
|
||
|
||
```python | ||
import torch | |
from diffusers import StableDiffusionPipeline | |
|
||
model_path = "path_to_saved_model" | ||
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) | ||
pipe.to("cuda") | ||
|
||
image = pipe(prompt="yoda").images[0] | ||
image.save("yoda-pokemon.png") | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
diffusers==0.4.1 | ||
accelerate | ||
torchvision | ||
transformers>=4.21.0 | ||
ftfy | ||
tensorboard | ||
modelcards |