diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index 574b89233cc1..a45a9ac6a5e7 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -457,7 +457,7 @@ def load_textual_inversion(
 
     def unload_textual_inversion(
         self,
-        tokens: Optional[Union[str, List[str]]] = None,
+        token: Optional[Union[str, List[str]]] = None,
         tokenizer: Optional["PreTrainedTokenizer"] = None,
         text_encoder: Optional["PreTrainedModel"] = None,
     ):
@@ -511,10 +511,10 @@ def unload_textual_inversion(
 
         # Unload explicitly from both text encoders abd tokenizers
         pipeline.unload_textual_inversion(
-            tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
+            token=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
         )
         pipeline.unload_textual_inversion(
-            tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
+            token=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
         )
         ```
         """
@@ -526,28 +526,28 @@ def unload_textual_inversion(
 
        token_ids = []
        last_special_token_id = None
 
-        if tokens:
-            if isinstance(tokens, str):
-                tokens = [tokens]
+        if token:
+            if isinstance(token, str):
+                token = [token]
             for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                 if not added_token.special:
-                    if added_token.content in tokens:
+                    if added_token.content in token:
                         token_ids.append(added_token_id)
                 else:
                     last_special_token_id = added_token_id
             if len(token_ids) == 0:
                 raise ValueError("No tokens to remove found")
         else:
-            tokens = []
+            token = []
             for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                 if not added_token.special:
                     token_ids.append(added_token_id)
-                    tokens.append(added_token.content)
+                    token.append(added_token.content)
                 else:
                     last_special_token_id = added_token_id
 
         # Delete from tokenizer
-        for token_id, token_to_remove in zip(token_ids, tokens):
+        for token_id, token_to_remove in zip(token_ids, token):
             del tokenizer._added_tokens_decoder[token_id]
             del tokenizer._added_tokens_encoder[token_to_remove]
 
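
For reference, a minimal usage sketch of the renamed keyword, which now mirrors the `token` argument of `load_textual_inversion`. The checkpoint and concept names below are placeholders chosen for illustration, not values taken from this diff:

```python
# Minimal sketch, assuming a diffusers install that includes this patch.
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Register a textual inversion embedding under a placeholder token.
pipeline.load_textual_inversion("sd-concepts-library/cat-toy", token="<cat-toy>")

# The keyword is now `token` (singular); the annotation stays
# Optional[Union[str, List[str]]], so a single string or a list still works.
pipeline.unload_textual_inversion(token="<cat-toy>")
```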