4 changes: 2 additions & 2 deletions pytorch_transformers/__init__.py
@@ -15,8 +15,8 @@
 from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
-                            BertForTokenClassification, BertForQuestionAnswering,
-                            load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                            BertForTokenClassification, BertForQuestionAnswering, BertLayerNorm,
+                            gelu, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                             BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
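With these two symbols exported at the package root, BERT's activation and LayerNorm can be reused without importing from the modeling module directly. A minimal sketch, assuming this patch is applied (the CustomHead module and the 768 hidden size are hypothetical illustrations, not part of the library):

    import torch
    import torch.nn as nn
    from pytorch_transformers import BertLayerNorm, gelu

    # Hypothetical downstream module reusing BERT's building blocks.
    class CustomHead(nn.Module):
        def __init__(self, hidden_size):
            super(CustomHead, self).__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)
            self.norm = BertLayerNorm(hidden_size, eps=1e-12)

        def forward(self, hidden_states):
            return self.norm(gelu(self.dense(hidden_states)))

    head = CustomHead(768)
    out = head(torch.zeros(1, 5, 768))  # (batch, seq_len, hidden)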
19 changes: 14 additions & 5 deletions pytorch_transformers/tokenization_utils.py
@@ -392,13 +392,14 @@ def __len__(self):
         return self.vocab_size + len(self.added_tokens_encoder)


-    def add_tokens(self, new_tokens):
+    def add_tokens(self, new_tokens, ids_start=None):
         """
         Add a list of new tokens to the tokenizer class. If the new tokens are not in the
         vocabulary, they are added to it with indices starting from the length of the current vocabulary.

         Args:
             new_tokens: list of strings. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking whether the tokenizer assigns the index of the ``unk_token`` to them).
+            ids_start: new token ids will start from ``ids_start``. If ``None``, new ids start from the current size of the vocabulary.

         Returns:
             Number of tokens added to the vocabulary.
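A minimal usage sketch for the new argument, assuming this patch is applied (the checkpoint name, the tokens, and the 50000 offset are illustrative):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Default behaviour is unchanged: new ids continue from len(tokenizer).
    tokenizer.add_tokens(['new_tok1', 'new_tok2'])

    # With ids_start, the first new id is pinned explicitly.
    tokenizer.add_tokens(['new_tok3'], ids_start=50000)
    assert tokenizer.convert_tokens_to_ids('new_tok3') == 50000

Note that len(tokenizer) still returns vocab_size plus the number of added tokens, so a large ids_start leaves a gap: the model's embedding matrix must be sized to cover the highest assigned id, not just len(tokenizer).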
@@ -424,15 +425,18 @@ def add_tokens(self, new_tokens):
                 to_add_tokens.append(token)
                 logger.info("Adding %s to the vocabulary", token)

-        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
+        if ids_start is None:
+            ids_start = len(self)
+
+        added_tok_encoder = dict((tok, ids_start + i) for i, tok in enumerate(to_add_tokens))
         added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
         self.added_tokens_decoder.update(added_tok_decoder)

         return len(to_add_tokens)


-    def add_special_tokens(self, special_tokens_dict):
+    def add_special_tokens(self, special_tokens_dict, ids_start=None):
         """
         Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
         to class attributes. If special tokens are NOT in the vocabulary, they are added
@@ -444,6 +448,7 @@ def add_special_tokens(self, special_tokens_dict):
                 ``additional_special_tokens``].

             Tokens are only added if they are not already in the vocabulary (tested by checking whether the tokenizer assigns the index of the ``unk_token`` to them).
+            ids_start: new token ids will start from ``ids_start``. If ``None``, new ids start from the current size of the vocabulary.

         Returns:
             Number of tokens added to the vocabulary.
@@ -470,10 +475,14 @@ def add_special_tokens(self, special_tokens_dict):
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
             if key == 'additional_special_tokens':
                 assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
-                added_tokens += self.add_tokens(value)
+                num_added = self.add_tokens(value, ids_start=ids_start)
             else:
                 assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
-                added_tokens += self.add_tokens([value])
+                num_added = self.add_tokens([value], ids_start=ids_start)
+            added_tokens += num_added
+            if ids_start is not None:
+                # advance the offset so ids assigned for the next key cannot collide
+                ids_start += num_added
             logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)

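And the corresponding sketch for special tokens, again assuming the patch (the checkpoint, the tokens, and the 60000 offset are illustrative; the three ids land in the range 60000-60002 without collision because the offset advances after each key):

    from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    num_added = tokenizer.add_special_tokens(
        {'pad_token': '<PAD>', 'additional_special_tokens': ['<SPK1>', '<SPK2>']},
        ids_start=60000)
    assert tokenizer.pad_token == '<PAD>'

    # The embedding matrix must cover the highest pinned id, which here is
    # well beyond len(tokenizer).
    model.resize_token_embeddings(60000 + num_added)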