@@ -457,6 +457,8 @@ def load_textual_inversion(
457
457
def unload_textual_inversion (
458
458
self ,
459
459
tokens : Optional [Union [str , List [str ]]] = None ,
460
+ tokenizer : Optional ["PreTrainedTokenizer" ] = None ,
461
+ text_encoder : Optional ["PreTrainedModel" ] = None ,
460
462
):
461
463
r"""
462
464
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
@@ -481,11 +483,28 @@ def unload_textual_inversion(
481
483
482
484
# Remove just one token
483
485
pipeline.unload_textual_inversion("<moe-bius>")
486
+
487
+ # Example 3: unload from SDXL
488
+ pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
489
+ embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
490
+
491
+ # load embeddings to the text encoders
492
+ state_dict = load_file(embedding_path)
493
+
494
+ # load embeddings of text_encoder 1 (CLIP ViT-L/14)
495
+ pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
496
+ # load embeddings of text_encoder 2 (CLIP ViT-G/14)
497
+ pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
498
+
499
+ # Unload explicitly from both text encoders abd tokenizers
500
+ pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
501
+ pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
502
+
484
503
```
485
504
"""
486
505
487
- tokenizer = getattr (self , "tokenizer" , None )
488
- text_encoder = getattr (self , "text_encoder" , None )
506
+ tokenizer = tokenizer or getattr (self , "tokenizer" , None )
507
+ text_encoder = text_encoder or getattr (self , "text_encoder" , None )
489
508
490
509
# Get textual inversion tokens and ids
491
510
token_ids = []
0 commit comments