Merge pull request #11 from helpmefindaname/fix_output_vocab_size
set the vocab size correctly when recreating the full embedding
helpmefindaname authored Dec 4, 2023
2 parents 00f49cf + b9a193e commit 90a2d49
Showing 2 changed files with 4 additions and 0 deletions.
tests/test_contextual_reduce.py: 3 additions & 0 deletions
@@ -69,6 +69,7 @@ def test_saving_while_reduction_can_be_loaded_afterwards():
         "Home sweet home",
         "ay ay ay",
     ]
+    initial_vocab_size = model.config.vocab_size
     with tempfile.TemporaryDirectory() as tdir:
         with reduce_train_vocab(model=model, tokenizer=tokenizer, texts=texts):
             model.save_pretrained(tdir)
@@ -77,3 +78,5 @@ def test_saving_while_reduction_can_be_loaded_afterwards():
         new_tokenizer = AutoTokenizer.from_pretrained(tdir)
     assert new_model.config.vocab_size == 13
     assert len(new_tokenizer) == 13
+    assert model.config.vocab_size == initial_vocab_size
+    assert len(tokenizer) == initial_vocab_size
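
For context, a minimal usage sketch of the round trip these new assertions pin down. The checkpoint name, the import path, and the body of the with-block are illustrative assumptions, not part of this diff; the texts are taken from the hunk's context lines:

from transformers import AutoModel, AutoTokenizer
from transformer_smaller_training_vocab import reduce_train_vocab

# Assumption: any Hugging Face checkpoint works; "prajjwal1/bert-tiny" is illustrative.
model = AutoModel.from_pretrained("prajjwal1/bert-tiny")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
texts = ["Home sweet home", "ay ay ay"]  # from the diff's context lines

initial_vocab_size = model.config.vocab_size
with reduce_train_vocab(model=model, tokenizer=tokenizer, texts=texts):
    pass  # train (or save) with the reduced vocabulary here

# On exit, recreate_embedding restores the full embedding; with this fix,
# config.vocab_size reports the restored size again.
assert model.config.vocab_size == initial_vocab_size
assert len(tokenizer) == initial_vocab_size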
transformer_smaller_training_vocab/modify_model.py: 1 addition & 0 deletions
@@ -54,5 +54,6 @@ def recreate_embedding(
     for reduced_id, full_id in enumerate(keep_token_ids):
         saved_embeddings[full_id] = embedding_weights[reduced_id]
     new_input_embedding = nn.Embedding(saved_embeddings.size(0), saved_embeddings.size(1), _weight=saved_embeddings)
+    model.config.vocab_size = saved_embeddings.size(0)
     model.set_input_embeddings(new_input_embedding)
     model.get_input_embeddings().to(model_device)
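
Why the one-line fix matters, as a hedged sketch (the failure mode is inferred from the commit message): before this change, recreate_embedding restored the full-size embedding matrix but config.vocab_size kept the reduced value, so a later model.save_pretrained(...) could write a config that disagreed with the weights. The invariant the new line establishes, assuming model has just passed through recreate_embedding:

# The restored embedding matrix has one row per token in the full vocabulary.
full_rows = model.get_input_embeddings().weight.size(0)

# With the fix, the config always agrees with the restored weights.
assert model.config.vocab_size == full_rows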
