Skip to content

Commit f0216b7

Browse files
H3zisayakpaulyiyixuxufabioriganolinoytsaban
authored
allow explicit tokenizer & text_encoder in unload_textual_inversion (#6977)
* allow passing tokenizer & text_encoder to unload_textual_inversion --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: Fabio Rigano <fabio2rigano@gmail.com> Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent d5f444d commit f0216b7

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

src/diffusers/loaders/textual_inversion.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,8 @@ def load_textual_inversion(
457457
def unload_textual_inversion(
458458
self,
459459
tokens: Optional[Union[str, List[str]]] = None,
460+
tokenizer: Optional["PreTrainedTokenizer"] = None,
461+
text_encoder: Optional["PreTrainedModel"] = None,
460462
):
461463
r"""
462464
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
@@ -481,11 +483,28 @@ def unload_textual_inversion(
481483
482484
# Remove just one token
483485
pipeline.unload_textual_inversion("<moe-bius>")
486+
487+
# Example 3: unload from SDXL
488+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
489+
embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
490+
491+
# load embeddings to the text encoders
492+
state_dict = load_file(embedding_path)
493+
494+
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
495+
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
496+
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
497+
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
498+
499+
# Unload explicitly from both text encoders abd tokenizers
500+
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
501+
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
502+
484503
```
485504
"""
486505

487-
tokenizer = getattr(self, "tokenizer", None)
488-
text_encoder = getattr(self, "text_encoder", None)
506+
tokenizer = tokenizer or getattr(self, "tokenizer", None)
507+
text_encoder = text_encoder or getattr(self, "text_encoder", None)
489508

490509
# Get textual inversion tokens and ids
491510
token_ids = []

0 commit comments

Comments
 (0)