diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py
index 62e3b8c47b80..f33ad229258c 100644
--- a/pytorch_transformers/__init__.py
+++ b/pytorch_transformers/__init__.py
@@ -15,8 +15,8 @@
 from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
-                            BertForTokenClassification, BertForQuestionAnswering,
-                            load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                            BertForTokenClassification, BertForQuestionAnswering, BertLayerNorm,
+                            gelu, load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                             BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py
index 4fef0e34fb07..3203cf85ed8a 100644
--- a/pytorch_transformers/tokenization_utils.py
+++ b/pytorch_transformers/tokenization_utils.py
@@ -392,13 +392,14 @@ def __len__(self):
         return self.vocab_size + len(self.added_tokens_encoder)
 
-    def add_tokens(self, new_tokens):
+    def add_tokens(self, new_tokens, ids_start=None):
         """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
             vocabulary, they are added to it with indices starting from length of the current
             vocabulary.
 
         Args:
             new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not
                 already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
+            ids_start: new token ids will start from ids_start. If None, the new ids start from size of the vocabulary.
 
         Returns:
             Number of tokens added to the vocabulary.
@@ -424,7 +425,10 @@ def add_tokens(self, new_tokens):
                 to_add_tokens.append(token)
                 logger.info("Adding %s to the vocabulary", token)
 
-        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
+        if ids_start is None:
+            ids_start = len(self)
+
+        added_tok_encoder = dict((tok, ids_start + i) for i, tok in enumerate(to_add_tokens))
         added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
         self.added_tokens_decoder.update(added_tok_decoder)
@@ -432,7 +436,7 @@ def add_tokens(self, new_tokens):
         return len(to_add_tokens)
 
-    def add_special_tokens(self, special_tokens_dict):
+    def add_special_tokens(self, special_tokens_dict, ids_start=None):
         """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
             to class attributes. If special tokens are NOT in the vocabulary, they are added
@@ -444,6 +448,7 @@ def add_special_tokens(self, special_tokens_dict):
                 ``additional_special_tokens``].
                 Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
                 assign the index of the ``unk_token`` to them).
+            ids_start: new token ids will start from ids_start. If None, the new ids start from size of the vocabulary.
 
         Returns:
             Number of tokens added to the vocabulary.
@@ -470,10 +475,10 @@ def add_special_tokens(self, special_tokens_dict):
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
             if key == 'additional_special_tokens':
                 assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
-                added_tokens += self.add_tokens(value)
+                added_tokens += self.add_tokens(value, ids_start=ids_start)
             else:
                 assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
-                added_tokens += self.add_tokens([value])
+                added_tokens += self.add_tokens([value], ids_start=ids_start)
             logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)
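
Usage note (not part of the patch): a minimal sketch of how the new ids_start argument could be exercised, assuming the patch above is applied to a local pytorch_transformers checkout. The token strings, the chosen start ids, and the idea of reusing already-allocated embedding rows are illustrative assumptions, not prescribed by this diff.

# Illustrative sketch only; assumes the ids_start patch above is applied.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Default behaviour (unchanged): new tokens are appended after the current
# vocabulary, i.e. their ids start at len(tokenizer).
tokenizer.add_tokens(['new_tok1'])
print(tokenizer.convert_tokens_to_ids('new_tok1'))  # 30522 for bert-base-uncased

# With ids_start, the new ids begin at the given index instead, e.g. to point
# new tokens at ids that already have rows in the embedding matrix, so the
# model does not need to be resized. '[NEW_TOK]' and the start id 1 are
# arbitrary examples.
tokenizer.add_tokens(['[NEW_TOK]'], ids_start=1)
print(tokenizer.convert_tokens_to_ids('[NEW_TOK]'))  # 1

# add_special_tokens() simply forwards ids_start to add_tokens().
tokenizer.add_special_tokens({'additional_special_tokens': ['[ENT]']}, ids_start=2)
print(tokenizer.convert_tokens_to_ids('[ENT]'))  # 2

When ids_start is left as None, the behaviour is identical to the current code path, since ids_start is then set to len(self) before the new ids are assigned.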