🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 #22313

Merged
merged 10 commits on Apr 4, 2023
8 changes: 4 additions & 4 deletions src/transformers/models/nllb/tokenization_nllb.py
@@ -390,11 +390,11 @@ def _switch_to_target_mode(self):
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
         self.cur_lang_code = self.lang_code_to_id[src_lang]
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]

     def set_tgt_lang_special_tokens(self, lang: str) -> None:
         """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
         self.cur_lang_code = self.lang_code_to_id[lang]
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]
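In practice this changes the source-side encoding: the language code now comes first and the sequence ends with EOS, instead of EOS followed by the language code. A minimal sketch of the new behaviour (the checkpoint name and example sentence are illustrative, not taken from this diff):

from transformers import NllbTokenizer

# Illustrative checkpoint; any NLLB checkpoint should show the same pattern.
tokenizer = NllbTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)

ids = tokenizer("UN Chief says there is no military solution in Syria").input_ids
tokens = tokenizer.convert_ids_to_tokens(ids)

# Before this fix: [..., "</s>", "eng_Latn"]
# After this fix:  ["eng_Latn", ..., "</s>"]
print(tokens[0], tokens[-1])  # expected: eng_Latn </s>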
8 changes: 4 additions & 4 deletions src/transformers/models/nllb/tokenization_nllb_fast.py
@@ -289,8 +289,8 @@ def _switch_to_target_mode(self):
     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
         self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]

         prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
         suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
@@ -304,8 +304,8 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
     def set_tgt_lang_special_tokens(self, lang: str) -> None:
         """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
         self.cur_lang_code = self.convert_tokens_to_ids(lang)
-        self.prefix_tokens = []
-        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]

         prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
         suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
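The fast tokenizer changes in the same way, including the target side used for labels. A rough sketch of the resulting encodings, again with an illustrative checkpoint and sentences:

from transformers import NllbTokenizerFast

tokenizer = NllbTokenizerFast.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn", tgt_lang="ron_Latn"
)

enc = tokenizer("A test", text_target="Un test")

# Both encoder inputs and labels now start with their language code and end with </s>.
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))  # starts with 'eng_Latn', ends with '</s>'
print(tokenizer.convert_ids_to_tokens(enc["labels"]))     # starts with 'ron_Latn', ends with '</s>'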
14 changes: 7 additions & 7 deletions tests/models/nllb/test_tokenization_nllb.py
@@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         " face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
     ]
     expected_src_tokens = [
+        256047,
         16297,
         134408,
         8165,
@@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
         108,
         49486,
         2,
-        256047,
     ]

     @classmethod
@@ -355,8 +355,8 @@ def test_enro_tokenizer_truncation(self):
         assert isinstance(src_text[0], str)
         desired_max_length = 10
         ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
-        self.assertEqual(ids[-2], 2)
-        self.assertEqual(ids[-1], EN_CODE)
+        self.assertEqual(ids[-1], 2)
+        self.assertEqual(ids[0], EN_CODE)
         self.assertEqual(len(ids), desired_max_length)

     def test_mask_token(self):
@@ -389,10 +389,10 @@ def test_enro_tokenizer_prepare_batch(self):
         self.assertEqual((2, 15), batch.attention_mask.shape)
         result = batch.input_ids.tolist()[0]
         self.assertListEqual(self.expected_src_tokens, result)
-        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
+        self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0])  # EOS
         # Test that special tokens are reset
-        self.assertEqual(self.tokenizer.prefix_tokens, [])
-        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
+        self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
+        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

     def test_seq2seq_max_length(self):
         batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
@@ -419,7 +419,7 @@ def test_tokenizer_translation(self):
             nested_simplify(inputs),
             {
                 # A, test, EOS, en_XX
-                "input_ids": [[70, 7356, 2, 256047]],
+                "input_ids": [[256047, 70, 7356, 2]],
                 "attention_mask": [[1, 1, 1, 1]],
                 # ar_AR
                 "forced_bos_token_id": 256057,
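For downstream translation, the driving pattern is unchanged: the source language is set on the tokenizer and the target language is forced via forced_bos_token_id, as in the updated test above. A hedged end-to-end sketch (checkpoint, input text, and language codes are illustrative):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tokenizer("A test", return_tensors="pt")
# input_ids now look like [[256047, ..., 2]]: eng_Latn first, </s> last,
# matching the updated expected values in test_tokenizer_translation.

generated = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("ron_Latn"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))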