
[LoRA] Add LoRA training script #1884

Merged: 29 commits into main, Jan 18, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Jan 2, 2023

Update:

Training seems to work fine -> see some results here (after 4 minutes of training on an A100): https://wandb.ai/patrickvonplaten/stable_diffusion_lora/reports/LoRA-training-results--VmlldzozMzI4MTI3?accessToken=d7x29esww3nvbrilo18hyto784w4oep721jiqgophgzdhztytwko1stcscp38gld

Possible API:

The premise of LoRA is to add a small set of extra weights to the model and train only those, so that fine-tuning yields very small, portable weight files.

Therefore it is important to add a new "LoRA weights loading API", which is currently implemented as follows:

#!/usr/bin/env python3
import torch

from diffusers import StableDiffusionPipeline

# Load the base Stable Diffusion pipeline in fp16
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# Load only the small LoRA attention-processor weights from the Hub
pipeline.unet.load_attn_procs("patrickvonplaten/lora")
pipeline.to("cuda")

prompt = "A photo of sks dog in a bucket"

images = pipeline(prompt, num_images_per_prompt=4).images

for i, image in enumerate(images):
    image.save(f"/home/patrick_huggingface_co/images/dog_{i}.png")

The idea is the following: during training, only the LoRA layers are saved, which for the default rank=4 amount to only around 3 MB: https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin
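For completeness, the saving side during training is just the mirror of the loading call above. A minimal sketch, assuming a save_attn_procs counterpart to load_attn_procs (treat the exact helper name and the output directory as assumptions for illustration):

# `unet` here is the UNet whose LoRA attention processors were just trained;
# persist only those processors (~3 MB at rank=4), not the full model
unet.save_attn_procs("lora-weights")

# they can later be loaded into a fresh pipeline from a local path or a Hub repo id
pipeline.unet.load_attn_procs("lora-weights")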

Those weights can then be downloaded easily from the Hub via a novel load_lora loading function as implemented here:
https://github.com/huggingface/diffusers/pull/1884/files#r1069869084

Co-authors:
Co-authored-by: https://github.com/cloneofsimo - the first to come up with the idea of using LoRA for Stable Diffusion, in the popular "lora" repo: https://github.com/cloneofsimo/lora

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jan 2, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten changed the title [Lora] first upload [WIP, Lora] Add lora training script Jan 2, 2023
@patrickvonplaten patrickvonplaten changed the title [WIP, Lora] Add lora training script [WIP, LoRA] Add lora training script Jan 2, 2023
@jrd-rocks

jrd-rocks commented Jan 3, 2023

This is great, but should it be part of diffusers? Why not have this as an external library? Maybe this is more of a meta-comment, but imho there is no need for diffusers to be everything; it should be the base that other libraries can build on. To me this seems easier both for the contributors/maintainers of the "advanced" libraries and for diffusers itself, as there's bound to be a difference in development speed/cadence between these new and shiny methods and the core, and it won't be pleasant to have to update an amalgamation of code just because one of the new shiny embedded libraries advances. The cloneofsimo/lora repo works very well with diffusers; wouldn't it be better to do all LoRA-related development there (or, if this implementation is incompatible, in a new repo, so that there are two LoRA-flavored libraries, which I think is preferable to putting everything in diffusers)?

@cloneofsimo
Contributor

I would say it is an honor to have my project serve as inspiration for official Hugging Face source code, but I do have a feeling that we are reinventing the wheel here...

@patrickvonplaten
Contributor Author

Hey @jrd-rocks and @cloneofsimo,

Thanks for your comments - it's super nice to see that other repositories such as https://github.com/cloneofsimo/lora are using diffusers.

@cloneofsimo, would it be ok if we credit you as one of the authors of this script and link to your GitHub repo? (Or would you maybe like to help with this PR so that you become an author by commit?)

This example script will have a couple of differences:

  • 1.) We will use the new "set cross attention processor" method so that one can load LoRA checkpoints directly onto a pipeline created with from_pretrained(...) (see the sketch below)
  • 2.) We won't add all the features (such as hacking CLIP and the feed-forward layers) to begin with

=> We intend this script to be a long-term maintained example of how to use LoRA; we're happy to refer to yours as "the" LoRA training script if you'd like :-)
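(For illustration, a minimal sketch of wiring LoRA processors into the UNet, simplified from the training script in this PR; the class name LoRACrossAttnProcessor and its signature reflect the API added here, but treat the details as assumptions rather than a definitive implementation:)

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Build one LoRA processor per attention block; self-attention blocks (attn1) get no
# cross-attention dim. The hidden-size bookkeeping mirrors the training script.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)  # only these processors are trained and saved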

@patrickvonplaten
Contributor Author

Actually, the main reason we opened this PR was because the community asked for it here: #1715

@brian6091

brian6091 commented Jan 4, 2023

Hi @patrickvonplaten

Super cool to see this development. However, I'm wondering why it was necessary to create new CrossAttention classes for LoRA? I can't figure out how this differs from how people have been applying @cloneofsimo 's repo. In case you don't want to pollute the PR, I've posted the question in another discussion here: cloneofsimo/lora#107

Thanks for all your efforts and any insight(s) you can offer!

@patrickvonplaten
Contributor Author

Hey @brian6091,

The CrossAttention mechanism was not (just) introduced for LoRA; its main use is to be able to tweak attention weights at runtime, as explained in: #1639
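(To illustrate what "tweaking attention weights at runtime" can look like, here is a minimal sketch of a custom attention processor that rescales the attention probabilities. It follows the processor pattern this API introduces; the method names on `attn` are assumptions based on the CrossAttention module, so treat this as a sketch rather than a definitive implementation:)

import torch

class ScaledAttnProcessor:
    # rescales the attention probabilities by a constant factor at inference time
    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        query = attn.to_q(hidden_states)
        encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attention_probs = attention_probs * self.scale  # the runtime tweak

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # output projection + dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

# e.g. swap it into every attention block of the UNet:
# unet.set_attn_processor({name: ScaledAttnProcessor(0.5) for name in unet.attn_processors})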

@brian6091

Hey @brian6091,

The CrossAttention mechanism was not (just) introduced for LoRA; its main use is to be able to tweak attention weights at runtime, as explained in: #1639

Thanks for the context @patrickvonplaten, I better understand the design decision now.

@cloneofsimo
Contributor

Hi @patrickvonplaten, thank you for your kind explanations! I would love it if you reference it like that. Thanks for the hard work!

optimizer_class = torch.optim.AdamW

# Optimizer creation
params_to_optimize = itertools.chain(*[v.parameters() for v in unet.attn_processors.values()])
Member

Shouldn't this also contain the text encoder parameters if train_text_encoder is set to True?
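(For context, a minimal sketch of what including the text-encoder parameters could look like; the train_text_encoder flag and the text_encoder variable are assumptions for illustration, not features of this script:)

import itertools

# gather the trainable LoRA attention-processor parameters from the UNet
unet_lora_params = [proc.parameters() for proc in unet.attn_processors.values()]

if args.train_text_encoder:  # hypothetical flag, not in this script
    params_to_optimize = itertools.chain(*unet_lora_params, text_encoder.parameters())
else:
    params_to_optimize = itertools.chain(*unet_lora_params)

optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)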

pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
pipeline.set_progress_bar_config(disable=True)
sample_dir = "/home/patrick_huggingface_co/lora-tryout/samples"
Member

You have probably noted it already, but this needs to be more generic, I guess.

Comment on lines +880 to +881
for tracker in accelerator.trackers:
if tracker.name == "wandb":
Member

A cleaner way might be to

if wandb.run is not None:
   ...

Member

@sayakpaul This comes from a code snippet where I did different things depending on whether the tracker was tensorboard, wandb, or something else (there can be several trackers enabled). But yes, if we are only considering the wandb case we could maybe simplify it.
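(A minimal sketch of that dispatch pattern, roughly what the script does; the logged dict key and captions are illustrative:)

import wandb

for tracker in accelerator.trackers:
    if tracker.name == "wandb":
        # log the validation images with their prompt as caption
        tracker.log({"validation": [wandb.Image(img, caption=f"{i}: {args.validation_prompt}")
                                    for i, img in enumerate(images)]})
    # other trackers (e.g. tensorboard) could be handled in extra branches here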

Member

@patrickvonplaten the report does not include prompts for the logged images, did you prepare it from a previous run?

@patrickvonplaten patrickvonplaten changed the title [WIP, LoRA] Add lora training script [LoRA] Add LoRA training script Jan 17, 2023
@patrickvonplaten
Contributor Author

@pcuenca @patil-suraj @sayakpaul - I think this is ready for a final review :-)

@sayakpaul sayakpaul self-requested a review January 17, 2023 16:48
Contributor

@patil-suraj patil-suraj left a comment

Looks great to me, I like the loaders API! I just have two questions:

  1. Why is the cast to weight_dtype needed?
  2. Does generation work when mixed-precision training is enabled?

if global_step >= args.max_train_steps:
break

if args.validation_prompt is not None and epoch % 10 == 0:
Contributor

Should we let the user control how often to generate, rather than hardcoding the value here?

Contributor Author

Yes, good point; adding validation_epochs.
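(A minimal sketch of what that could look like; the flag name matches --validation_epochs used in the README example later in this thread, while the default value and help text are assumptions:)

parser.add_argument(
    "--validation_epochs",
    type=int,
    default=50,  # assumed default; see the actual script for the real one
    help="Run validation (generate sample images) every X epochs.",
)

# ... inside the training loop ...
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
    # run inference and log the images
    ...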

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
Contributor

Maybe use autocast here to do generation in fp16. Have you verified this with mixed precision?
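(A sketch of the autocast variant, assuming the surrounding generation code stays as in the excerpt above:)

# run the denoising loop under fp16 autocast
with torch.autocast("cuda"):
    images = pipeline(prompt, num_inference_steps=25, generator=generator).images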

Member

@pcuenca pcuenca left a comment

Impressive work!

**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
use *1e-4* instead of the usual *2e-6*.___**
Member

👍 Perfect


logger = logging.get_logger(__name__)


ATTN_WEIGHT_NAME = "pytorch_attn_procs.bin"
Contributor Author

Actually, maybe it would make more sense to make the naming more general here:

Suggested change:
- ATTN_WEIGHT_NAME = "pytorch_attn_procs.bin"
+ ATTN_WEIGHT_NAME = "embeddings.bin"

So that multiple loaders could be applied to the same file? cc @pcuenca @patil-suraj

Member

I thought about it, but wasn't sure. Another idea would be to make it more specific and descriptive, like lora_embeddings.bin and then use different names for others. Not sure what would be easiest to deal with going forward.

Contributor

Not sure embeddings is a good name, since these aren't embeddings. I agree with Pedro; maybe make it specific to the processors, for example lora_layers.bin or lora_weights.bin.

Contributor Author
@patrickvonplaten patrickvonplaten Jan 18, 2023

Hmm, not a big fan of making it super specific, in case we want to expand the functionality to more "adapter" layers in the future.

E.g. if someone wants to use both LoRA and textual inversion, it'd be nicer to have everything in one file, no?
=> Going for adapter_weights.bin now, ok for you?

Contributor Author

Alright, @patil-suraj convinced me to go for a LoRA-weights-specific name; this means in the longer run:

  • We have different files for different parts of the pipeline
  • The user will call multiple loading methods for different parts of the pipeline, e.g.:
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("...")
pipe.unet.load_attn_procs("...")
pipe.load_text_embeddings("...")

But I think that's fine!

Member

@sayakpaul sayakpaul left a comment

This is going to be an enabler. I am telling you!

patrickvonplaten and others added 4 commits January 18, 2023 15:11
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@patrickvonplaten patrickvonplaten merged commit ed616bd into main Jan 18, 2023
@patrickvonplaten
Contributor Author

patrickvonplaten commented Jan 18, 2023

Update

When running the script in mixed precision, training uses about 6.5 GB of GPU RAM, but memory goes up to 14 GB during inference; one can simply generate fewer samples during inference or run the pipeline in a loop.
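(A sketch of the loop variant; generating one image at a time keeps peak inference memory low. Purely illustrative:)

# generate the validation images one at a time instead of num_images_per_prompt=4
images = []
for _ in range(4):
    images.extend(pipeline(prompt, num_images_per_prompt=1).images)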

@jtoy

jtoy commented Jan 24, 2023

14 GB for inference!?!?! I do think this should be part of diffusers.

@jtoy

jtoy commented Jan 24, 2023

wandb is a requirement in the current code; that seems like a bug...

@pcuenca
Member

pcuenca commented Jan 24, 2023

@jtoy I could complete a fine-tuning run using a 2080 Ti (11 GB of RAM) :) And yes, I agree that wandb should not be in requirements.txt, will open a PR.
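(A sketch of the kind of guard such a PR could use so wandb is only needed when requested; the report_to argument and error message are assumptions for illustration:)

try:
    import wandb
except ImportError:
    wandb = None

if args.report_to == "wandb" and wandb is None:
    raise ImportError("Install wandb to use it for logging during training.")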

@jtoy

jtoy commented Jan 24, 2023

@pcuenca what args did you use? I used my Titan X 1080 with 12 GB and it dies with OOM. I used the example in the README:

accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0"

@pcuenca
Member

pcuenca commented Jan 24, 2023

@jtoy Sorry for the confusion, I was referring to a complete fine-tuning using the train_text_to_image_lora.py script, not Dreambooth. I haven't verified memory consumption on Dreambooth yet.

@FBehrad

FBehrad commented Jan 31, 2023

[quoting the PR description above]

Thank you for adding LoRA.
You said the default rank is 4, but when I was checking LoRACrossAttnProcessor, I saw that the rank is not used at all.
So how can I change the rank?

@asadm
Contributor

asadm commented Feb 1, 2023

@FBehrad I sent a fix for this in #2191
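(For context, a minimal sketch of how a rank parameter typically feeds into a LoRA layer: it sets the width of the down/up projections and hence the size of the saved weights. Illustrative only; see #2191 for the actual change in diffusers:)

import torch.nn as nn

class LoRALinearLayer(nn.Module):
    # low-rank adapter: out = up(down(x)); rank controls the adapter's parameter count
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))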

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* [Lora] first upload

* add first lora version

* upload

* more

* first training

* up

* correct

* improve

* finish loaders and inference

* up

* up

* fix more

* up

* finish more

* finish more

* up

* up

* change year

* revert year change

* Change lines

* Add cloneofsimo as co-author.

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>

* finish

* fix docs

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* upload

* finish

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>