Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,91 @@ def load_textual_inversion(
self.enable_sequential_cpu_offload()

# / Unsafe Code >

def unload_textual_inversion(
self,
tokens: Optional[Union[str, List[str]]] = None,
):
r"""
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]

Example:
```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

# Example 1
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

# Remove all token embeddings
pipeline.unload_textual_inversion()

# Example 2
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

# Remove just one token
pipeline.unload_textual_inversion("<moe-bius>")
```
"""

tokenizer = getattr(self, "tokenizer", None)
text_encoder = getattr(self, "text_encoder", None)

# Get textual inversion tokens and ids
token_ids = []
last_special_token_id = None

if tokens:
if isinstance(tokens, str):
tokens = [tokens]
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
if not added_token.special:
if added_token.content in tokens:
token_ids.append(added_token_id)
else:
last_special_token_id = added_token_id
if len(token_ids) == 0:
raise ValueError("No tokens to remove found")
else:
tokens = []
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
if not added_token.special:
token_ids.append(added_token_id)
tokens.append(added_token.content)
else:
last_special_token_id = added_token_id

# Delete from tokenizer
for token_id, token_to_remove in zip(token_ids, tokens):
del tokenizer._added_tokens_decoder[token_id]
del tokenizer._added_tokens_encoder[token_to_remove]

# Make all token ids sequential in tokenizer
key_id = 1
for token_id in tokenizer.added_tokens_decoder:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain why do we need this block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this block to make all token ids sequential after one of the added tokens is removed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for explaining this! @fabiorigano

I'm not very familiar with the use case
cc @apolinario ad @linoytsaban here can you take a look to see if we need to reorder the remaining added tokens after we remove some?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @fabiorigano! I looked a bit but I'm actually not quite sure why it's necessary to reorder 🤔

Copy link
Contributor Author

@fabiorigano fabiorigano Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @linoytsaban, the reordering block is useful to have the same indeces as the text embeddings in the encoder, so multiple unload_textual_inversion(<token>) calls will remove the correct text embeddings. If the reordering is not done, when re-executing unload_textual_inversion(<another-token>) the last for loop may fail, because it may remove a different text embedding than expected.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gotcha! that makes total sense, thanks for explaining! 🤗

if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
token = tokenizer._added_tokens_decoder[token_id]
tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
del tokenizer._added_tokens_decoder[token_id]
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()

# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
to_append = []
for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
if i not in token_ids:
to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
if len(to_append) > 0:
to_append = torch.cat(to_append, dim=0)
text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
text_embeddings_filtered.weight.data = text_embedding_weights
text_encoder.set_input_embeddings(text_embeddings_filtered)