Skip to content

Commit 7679f36

Browse files
committed
first draft
1 parent e516858 commit 7679f36

File tree

3 files changed

+272
-4
lines changed

3 files changed

+272
-4
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
title: Load safetensors
3535
- local: using-diffusers/other-formats
3636
title: Load different Stable Diffusion formats
37+
- local: using-diffusers/loading_adapters
38+
title: Load adapters
3739
- local: using-diffusers/push_to_hub
3840
title: Push files to the Hub
3941
title: Loading & Hub
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# Load adapters
14+
15+
[[open-in-colab]]
16+
17+
There are several [training](../training/overview) techniques for personalizing diffusion models to generate images of a specific subject or images in certain styles. Each of these training methods produce a different type of adapter. Some of the adapters generate an entirely new model, while other adapters only modify a smaller set of embeddings or weights. This means the loading process for each adapter is also different.
18+
19+
This guide will show you how to load DreamBooth, textual inversion, and LoRA weights.
20+
21+
<Tip>
22+
23+
You can start by looking at [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](multimodalart/LoraTheExplorer), and the [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) for checkpoints and embeddings to use.
24+
25+
</Tip>
26+
27+
## DreamBooth
28+
29+
[DreamBooth](https://dreambooth.github.io/) finetunes the *entire diffusion model* on just several images of a subject to generate images of the subject in new styles and settings. This method works by using a special word in the prompt that the model learns to associate with the subject image. Of all the training methods, DreamBooth produces the largest file size (usually a few GBs) because it is a full checkpoint model. But this also means loading a DreamBooth checkpoint is the same as loading any other checkpoint.
30+
31+
For example, the [herge_style](https://huggingface.co/sd-dreambooth-library/herge-style) checkpoint is trained on just 10 images drawn by Hergé and now it can generate images in that style. For it to work, you need to include the special word `herge_style` in your prompt to trigger the checkpoint:
32+
33+
```py
34+
from diffusers import AutoPipelineForText2Image
35+
import torch
36+
37+
pipeline = AutoPipelineForText2Image.from_pretrained("sd-dreambooth-library/herge-style", torch_dtype=torch.float16).to("cuda")
38+
prompt = "A cute herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
39+
image = pipeline(prompt).images[0]
40+
```
41+
42+
<div class="flex justify-center">
43+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_dreambooth.png" />
44+
</div>
45+
46+
## Textual inversion
47+
48+
[Textual inversion](https://textual-inversion.github.io/) is very similar to DreamBooth, and it can also personalize a diffusion model to generate certain concepts (styles, objects) from just a few images. This method works by training and finding new embeddings that represent the images you provide with a special word in the prompt. As a result, the diffusion model weights stays the same and the training process produces a relatively tiny (a few KBs) file.
49+
50+
Because textual inversion creates embeddings, you need to use textual inversion with another model. For example, load a model:
51+
52+
```py
53+
from diffusers import AutoPipelineForText2Image
54+
import torch
55+
56+
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
57+
```
58+
59+
Then you can load the textual inversion embeddings with the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method and generate some images. Let's load the [sd-concepts-library/gta5-artwork](https://huggingface.co/sd-concepts-library/gta5-artwork) embeddings, and you'll need to include the special word `<gta5-artwork>` in your prompt to trigger it:
60+
61+
```py
62+
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
63+
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, <gta5-artwork> style"
64+
image = pipeline(prompt).images[0]
65+
```
66+
67+
<div class="flex justify-center">
68+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_txt_embed.png" />
69+
</div>
70+
71+
Textual inversion can also be trained on undesirable things to create *negative embeddings* to discourage a model from generating images with those undesirable things like blurry images or extra fingers on a hand. This can be a easy way to improve your prompt. You'll load the embeddings the same way with the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method.
72+
73+
This time, you'll need two more parameters:
74+
75+
- `weight_name`: specifies the weight file to load if the file was saved in the 🤗 Diffusers format but saved with a specific name or if the file is in the A1111 format
76+
- `token`: specifies the special word to use in the prompt to trigger the embeddings
77+
78+
Let's load the [sayakpaul/EasyNegative-test](https://huggingface.co/sayakpaul/EasyNegative-test) embeddings:
79+
80+
```py
81+
pipeline.load_textual_inversion(
82+
"sayakpaul/EasyNegative-test", weight_name="EasyNegative.safetensors", token="EasyNegative"
83+
)
84+
```
85+
86+
Now you can use the `token` to generate an image with the negative embeddings:
87+
88+
```py
89+
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
90+
negative_prompt = "EasyNegative"
91+
92+
image = pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50).images[0]
93+
```
94+
95+
<div class="flex justify-center">
96+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png" />
97+
</div>
98+
99+
## LoRA
100+
101+
[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685) is a popular training technique because it is fast and generates smaller file sizes (a couple hundred MBs). Like the other methods in this guide, LoRA can train a model to learn new styles from just a few images. It works by inserting new weights into the diffusion model and then only the new weights are trained instead of the entire model. This makes LoRAs faster to train and easier to store.
102+
103+
LoRAs also need to be used with another model. For example, load a model:
104+
105+
```py
106+
from diffusers import AutoPipelineForText2Image
107+
import torch
108+
109+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
110+
```
111+
112+
Then use the [`~loaders.LoraLoaderMixin.load_lora_weights`] method to load the [ostris/super-cereal-sdxl-lora](https://huggingface.co/ostris/super-cereal-sdxl-lora) weights and specify the weights filename from the repository:
113+
114+
```py
115+
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors")
116+
prompt = "bears, pizza bites"
117+
image = pipeline(prompt).images[0]
118+
```
119+
120+
<div class="flex justify-center">
121+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_lora.png" />
122+
</div>
123+
124+
The [`~loaders.LoraLoaderMixin.load_lora_weights`] method loads LoRA weights into both the UNet and text encoder. It is the preferred way for loading LoRAs because it can handle cases where:
125+
126+
- the LoRA weights don't have separate identifiers for the UNet and text encoder
127+
- the LoRA weights have separate identifiers for the UNet and text encoder
128+
129+
But if you only need to load LoRA into the UNet, then you can use the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Let's load the [jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA:
130+
131+
```py
132+
from diffusers import AutoPipelineForText2Image
133+
import torch
134+
135+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
136+
pipeline.unet.load_attn_procs("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")
137+
138+
# use cnmt in the prompt to trigger the LoRA
139+
prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
140+
image = pipeline(prompt).images[0]
141+
```
142+
143+
<div class="flex justify-center">
144+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
145+
</div>
146+
147+
<Tip>
148+
149+
For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
150+
151+
</Tip>
152+
153+
To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
154+
155+
```py
156+
pipeline.unload_lora_weights()
157+
```
158+
159+
### Load multiple LoRAs
160+
161+
It can be fun to use multiple LoRAs together to create something entirely new and unique. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights with the original weights of the underlying model.
162+
163+
<Tip>
164+
165+
Fusing the weights can lead to a speedup in inference latency because you don't need to separately load the base model and LoRA! You can save your fused pipeline with [`~DiffusionPipeline.save_pretrained`] to avoid loading and fusng the weights every time you want to use the model.
166+
167+
</Tip>
168+
169+
Load an initial model:
170+
171+
```py
172+
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
173+
import torch
174+
175+
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
176+
pipeline = StableDiffusionXLPipeline.from_pretrained(
177+
"stabilityai/stable-diffusion-xl-base-1.0",
178+
vae=vae,
179+
torch_dtype=torch.float16,
180+
).to("cuda")
181+
```
182+
183+
Then load the LoRA checkpoint and fuse it with the original weights. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won't work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
184+
185+
If you need to reset the original model weights for any reason (use a different `lora_scale`), you should use the [`~loaders.LoraLoaderMixin.unfuse_lora`] method.
186+
187+
```py
188+
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl")
189+
pipeline.fuse_lora(lora_scale=0.7)
190+
191+
# to unfuse the LoRA weights
192+
pipeline.unfuse_lora()
193+
```
194+
195+
Then fuse this pipeline with the next set of LoRA weights:
196+
197+
```py
198+
pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
199+
pipeline.fuse_lora(lora_scale=0.7)
200+
```
201+
202+
<Tip warning={true}>
203+
204+
You can't unfuse multiple LoRA checkpoints so if you need to reset the model to its original weights, you'll need to reload it.
205+
206+
</Tip>
207+
208+
Now you can generate an image that uses the weights from both LoRAs:
209+
210+
```py
211+
prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
212+
image = pipeline(prompt).images[0]
213+
```
214+
215+
### Kohya and TheLastBen
216+
217+
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
218+
219+
Let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/):
220+
221+
```py
222+
!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors
223+
```
224+
225+
Load the LoRA checkpoint with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method, and specify filename in the `weight_name` parameter:
226+
227+
```py
228+
from diffusers import AutoPipelineForText2Image
229+
import torch
230+
231+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")
232+
pipeline.load_lora_weights("path/to/weights", weight_name="blueprintify-sd-xl-10.safetensors")
233+
```
234+
235+
Then you can generate an image as you normally would:
236+
237+
```py
238+
# use bl3uprint in the prompt to trigger the LoRA
239+
prompt = "bl3uprint, a highly detailed blueprint of the eiffel tower, explaining how to build all parts, many txt, blueprint grid backdrop"
240+
image = pipeline(prompt).images[0]
241+
```
242+
243+
<Tip warning={true}>
244+
245+
Some limitations of using Kohya LoRAs with 🤗 Diffusers include:
246+
247+
- Images may not look like those generated by UIs - like ComfyUI - for multiple reasons which are explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
248+
- [LyCORIS checkpoints](https://github.com/KohakuBlueleaf/LyCORIS) aren't fully supported. The [`~loaders.LoraLoaderMixin.load_lora_weights`] method loads LyCORIS checkpoints with LoRA and LoCon modules, but Hada and LoKR are not supported.
249+
250+
</Tip>
251+
252+
Loading a checkpoint from TheLastBen is very similar. For example, to load the [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) checkpoint:
253+
254+
```py
255+
from diffusers import AutoPipelineForText2Image
256+
import torch
257+
258+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
259+
pipeline.load_lora_weights("TheLastBen/William_Eggleston_Style_SDXL", weight_name="wegg.safetensors")
260+
261+
# use by william eggleston in the prompt to trigger the LoRA
262+
prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, beautiful"
263+
image = pipeline(prompt=prompt).images[0]
264+
```

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,10 +1235,12 @@ def forward(
12351235
deprecate(
12361236
"T2I should not use down_block_additional_residuals",
12371237
"1.3.0",
1238-
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
1239-
" and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
1240-
" be used for ControlNet. Please make sure use"
1241-
" `down_intrablock_additional_residuals` instead. ",
1238+
(
1239+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
1240+
" and will be removed in diffusers 1.3.0. `down_block_additional_residuals`"
1241+
" should only be used for ControlNet. Please make sure use"
1242+
" `down_intrablock_additional_residuals` instead. "
1243+
),
12421244
standard_warn=False,
12431245
)
12441246
down_intrablock_additional_residuals = down_block_additional_residuals

0 commit comments

Comments
 (0)