minor training optimizations
dvruette committed Mar 5, 2023
1 parent 90bc0c9 commit fd00393
Showing 2 changed files with 9 additions and 3 deletions.
model/model_training/trainer_sft.py (4 additions, 1 deletion)
@@ -212,7 +212,10 @@ def argument_parsing(notebook=False, notebook_args=None):
 
     train, evals = get_dataset(training_conf)
     train_collate_fn = DialogueDataCollator(
-        tokenizer, max_length=training_conf.max_length, samples_mixing=training_conf.samples_mixing
+        tokenizer,
+        max_length=training_conf.max_length,
+        samples_mixing=training_conf.samples_mixing,
+        pad_to_multiple_of=32,
     )
     eval_collate_fn = DialogueDataCollator(tokenizer, max_length=training_conf.max_length, samples_mixing=False)
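Note: pad_to_multiple_of=32 makes the collator pad each training batch up to the next multiple of 32 tokens, so batches fall into a small set of shapes whose dimensions stay aligned for fp16 tensor-core kernels. A minimal sketch of the rounding arithmetic, assuming the collator rounds the longest sequence in the batch (the helper name below is illustrative, not part of the repository):

import math

def padded_length(seq_len: int, multiple: int = 32) -> int:
    """Round a sequence length up to the next multiple, as pad_to_multiple_of does."""
    return math.ceil(seq_len / multiple) * multiple

# A batch whose longest dialogue is 100 tokens gets padded to 128;
# lengths 97..128 all map to the same shape, reducing shape churn.
assert padded_length(100) == 128
assert padded_length(128) == 128
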
model/model_training/utils.py (5 additions, 2 deletions)
@@ -1,4 +1,5 @@
 import copy
+import math
 import random
 from distutils.util import strtobool
 from pathlib import Path
@@ -248,10 +249,12 @@ def get_model(conf, tokenizer):
         conf.model_name, cache_dir=conf.cache_dir, quantization=conf.quantization, seq2seqmodel=conf.seq2seqmodel
     )
 
-    if len(tokenizer) != model.get_input_embeddings().num_embeddings:
+    n_embs = model.get_input_embeddings().num_embeddings
+    if len(tokenizer) != n_embs:
         assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
 
-    model.resize_token_embeddings(len(tokenizer))
+    if (len(tokenizer) != n_embs or n_embs % 8 == 0) and not conf.freeze_layer:
+        model.resize_token_embeddings(math.ceil(len(tokenizer) / 8) * 8)
 
     if conf.freeze_layer:
         model = freeze_top_n_layers(model, conf.freeze_layer)
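Note: embedding matrices whose row count is divisible by 8 map cleanly onto fp16 tensor-core GEMMs, which is why the resize now rounds the tokenizer size up rather than matching it exactly; the surplus rows are simply unused padding entries. A minimal sketch of the rounding, with an illustrative helper name:

import math

def rounded_vocab_size(vocab_size: int, multiple: int = 8) -> int:
    """Round a vocabulary size up to the next multiple of 8,
    mirroring the resize_token_embeddings call above."""
    return math.ceil(vocab_size / multiple) * multiple

# Example: a 50,265-token vocabulary produces a 50,272-row embedding matrix;
# the 7 surplus rows are never indexed by real tokens.
assert rounded_vocab_size(50_265) == 50_272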
