
Commit

MbartTokenizer: do not hardcode vocab size (#5998)
sshleifer committed Jul 23, 2020
1 parent 6e16195 commit 9827d66
Showing 2 changed files with 43 additions and 30 deletions.
66 changes: 37 additions & 29 deletions src/transformers/tokenization_bart.py
@@ -58,6 +58,34 @@ class BartTokenizerFast(RobertaTokenizerFast):
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"

FAIRSEQ_LANGUAGE_CODES = [
"ar_AR",
"cs_CZ",
"de_DE",
"en_XX",
"es_XX",
"et_EE",
"fi_FI",
"fr_XX",
"gu_IN",
"hi_IN",
"it_IT",
"ja_XX",
"kk_KZ",
"ko_KR",
"lt_LT",
"lv_LV",
"my_MM",
"ne_NP",
"nl_XX",
"ro_RO",
"ru_RU",
"si_LK",
"tr_TR",
"vi_VN",
"zh_CN",
]


class MBartTokenizer(XLMRobertaTokenizer):
"""
@@ -81,40 +109,20 @@ class MBartTokenizer(XLMRobertaTokenizer):
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
lang_code_to_id = { # NOTE(SS): resize embeddings will break this
"ar_AR": 250001,
"cs_CZ": 250002,
"de_DE": 250003,
"en_XX": 250004,
"es_XX": 250005,
"et_EE": 250006,
"fi_FI": 250007,
"fr_XX": 250008,
"gu_IN": 250009,
"hi_IN": 250010,
"it_IT": 250011,
"ja_XX": 250012,
"kk_KZ": 250013,
"ko_KR": 250014,
"lt_LT": 250015,
"lv_LV": 250016,
"my_MM": 250017,
"ne_NP": 250018,
"nl_XX": 250019,
"ro_RO": 250020,
"ru_RU": 250021,
"si_LK": 250022,
"tr_TR": 250023,
"vi_VN": 250024,
"zh_CN": 250025,
}
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
cur_lang_code = lang_code_to_id["en_XX"]

prefix_tokens: List[int] = []
suffix_tokens: List[int] = []

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.cur_lang_code = self.lang_code_to_id["en_XX"]

self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
7 changes: 6 additions & 1 deletion tests/test_tokenization_mbart.py
@@ -113,10 +113,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.tokenizer: MBartTokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.pad_token_id = 1
return cls

def test_language_codes(self):
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)

def test_enro_tokenizer_prepare_translation_batch(self):
batch = self.tokenizer.prepare_translation_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
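Below is a minimal usage sketch of the behaviour introduced in src/transformers/tokenization_bart.py: the language-code IDs are derived from the size of the checkpoint's SentencePiece model at load time rather than read from a hardcoded table. It assumes the facebook/mbart-large-en-ro checkpoint can be downloaded, and it only touches attributes defined in this diff (sp_model_size, lang_code_to_id, id_to_lang_code, cur_lang_code) plus fairseq_offset from the XLMRobertaTokenizer base class.

from transformers import MBartTokenizer

# Language-code IDs are now computed from the loaded SentencePiece model,
# so checkpoints with different vocab sizes no longer depend on hardcoded IDs.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

# Each code is appended right after the SentencePiece vocabulary:
# id("ar_AR") == sp_model_size + 0 + fairseq_offset,
# id("cs_CZ") == sp_model_size + 1 + fairseq_offset, and so on.
print(tokenizer.sp_model_size)
print(tokenizer.lang_code_to_id["en_XX"])
print(tokenizer.id_to_lang_code[tokenizer.cur_lang_code])  # "en_XX" is the default current language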
