Skip weight initialization when resizing text encoder token embeddings to accommodate new TI embeddings. This saves time.
RyanJDick authored and hipsterusername committed Jan 5, 2024
1 parent 8e17e29 commit f7f6978
Showing 1 changed file with 7 additions and 2 deletions.
invokeai/backend/model_management/lora.py: 9 changes (7 additions & 2 deletions)
@@ -13,6 +13,7 @@
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.shared.models import FreeUConfig
+from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
 
 from .models.lora import LoRAModel
 
@@ -211,8 +212,12 @@ def _get_ti_embedding(model_embeddings, ti):
                 for i in range(ti_embedding.shape[0]):
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
-            # modify text_encoder
-            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
+            # Modify text_encoder.
+            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
+            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
+            # time.
+            with skip_torch_weight_init():
+                text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
             model_embeddings = text_encoder.get_input_embeddings()
 
             for ti_name, ti in ti_list:
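Why this helps: resize_token_embeddings() builds a fresh torch.nn.Embedding and copies the existing rows into it, and the newly appended rows are about to be filled with the textual-inversion embeddings anyway, so their random initialization is wasted work. Below is a minimal sketch of how a context manager like skip_torch_weight_init() can be implemented, assuming it works by temporarily replacing the reset_parameters() hooks of common torch.nn layer types with no-ops; the real helper in invokeai.backend.model_management.model_load_optimizations may differ in detail.

```python
# Illustrative sketch only; the actual helper lives in
# invokeai.backend.model_management.model_load_optimizations and may differ.
from contextlib import contextmanager
from typing import Iterator

import torch


@contextmanager
def skip_torch_weight_init() -> Iterator[None]:
    """Temporarily make weight initialization a no-op for common torch.nn layers.

    Layers constructed inside the `with` block (e.g. the new torch.nn.Embedding
    built by resize_token_embeddings) are allocated but not randomly initialized,
    which saves time when their weights will be overwritten immediately afterwards.
    """
    # Extend this list (e.g. with Conv layers) as needed.
    layer_types = [torch.nn.Linear, torch.nn.Embedding]
    saved_init_fns = [layer.reset_parameters for layer in layer_types]
    try:
        for layer in layer_types:
            layer.reset_parameters = lambda self, *args, **kwargs: None
        yield
    finally:
        # Always restore the original initializers, even if an exception was raised.
        for layer, init_fn in zip(layer_types, saved_init_fns):
            layer.reset_parameters = init_fn
```

Because the class attribute is patched rather than individual instances, any layer constructed while the context is active skips initialization; wrapping only the resize_token_embeddings() call keeps that effect narrowly scoped.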
