Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 #22313

Merged
merged 10 commits into from Apr 4, 2023
41 changes: 39 additions & 2 deletions docs/source/en/model_doc/nllb.mdx
Expand Up @@ -12,8 +12,45 @@ specific language governing permissions and limitations under the License.

# NLLB

**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=bug&template=bug-report.yml) and assign
@LysandreJik
**DISCLAIMER:** The default behaviour for the tokenizer has recently been fixed (and thus changed)!

The previous version adds `[self.eos_token_id, self.cur_lang_code]` at the end of the token sequence for both target and source tokenization. This is wrong as the NLLB paper mentions (page 48, 6.1.1. Model Architecture) :

*Note that we prefix the source sequence with the source language, as opposed to the target
language as previously done in several works (Arivazhagan et al., 2019; Johnson et al.,
2017). This is primarily because we prioritize optimizing zero-shot performance of our
model on any pair of 200 languages at a minor cost to supervised performance.*

Previous behaviour:

```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[13374, 1398, 4260, 4039, 248130, 2, 256047]

>>> # 2: '</s>'
>>> # 256047 : 'eng_Latn'
```
New behaviour

```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[256047, 13374, 1398, 4260, 4039, 248130, 2]
```

Enabling the old behaviour can be done as follows:
```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)
```

For more details, feel free to check the linked [PR](https://github.com/huggingface/transformers/pull/22313) and [Issue](https://github.com/huggingface/transformers/issues/19943).

## Overview of NLLB

Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb.py
Expand Up @@ -140,12 +140,14 @@ def __init__(
tgt_lang=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
additional_special_tokens=None,
legacy_behaviour=False,
**kwargs,
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.legacy_behaviour = legacy_behaviour

super().__init__(
bos_token=bos_token,
Expand All @@ -160,13 +162,13 @@ def __init__(
tgt_lang=tgt_lang,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
legacy_behaviour=legacy_behaviour,
**kwargs,
)

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
Expand Down Expand Up @@ -388,13 +390,27 @@ def _switch_to_target_mode(self):
return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
"""Reset the special tokens to the source lang setting.
- In legacy mode: No prefix and suffix=[eos, src_lang_code].
- In default mode: Prefix=[src_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
"""Reset the special tokens to the target lang setting.
- In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
- In default mode: Prefix=[tgt_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
31 changes: 24 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb_fast.py
Expand Up @@ -151,11 +151,12 @@ def __init__(
src_lang=None,
tgt_lang=None,
additional_special_tokens=None,
legacy_behaviour=False,
**kwargs,
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

self.legacy_behaviour = legacy_behaviour
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
Expand All @@ -169,6 +170,7 @@ def __init__(
src_lang=src_lang,
tgt_lang=tgt_lang,
additional_special_tokens=additional_special_tokens,
legacy_behaviour=legacy_behaviour,
**kwargs,
)

Expand Down Expand Up @@ -287,10 +289,18 @@ def _switch_to_target_mode(self):
return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
"""Reset the special tokens to the source lang setting.
- In legacy mode: No prefix and suffix=[eos, src_lang_code].
- In default mode: Prefix=[src_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
Expand All @@ -302,10 +312,17 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
)

def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
"""Reset the special tokens to the target lang setting.
- In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
- In default mode: Prefix=[tgt_lang_code], suffix = [eos]
"""
self.cur_lang_code = self.convert_tokens_to_ids(lang)
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
if self.legacy_behaviour:
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
else:
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
Expand Down
32 changes: 25 additions & 7 deletions tests/models/nllb/test_tokenization_nllb.py
Expand Up @@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
" face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [
256047,
16297,
134408,
8165,
Expand All @@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
108,
49486,
2,
256047,
]

@classmethod
Expand Down Expand Up @@ -355,8 +355,8 @@ def test_enro_tokenizer_truncation(self):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(ids[-1], 2)
self.assertEqual(ids[0], EN_CODE)
self.assertEqual(len(ids), desired_max_length)

def test_mask_token(self):
Expand Down Expand Up @@ -389,10 +389,10 @@ def test_enro_tokenizer_prepare_batch(self):
self.assertEqual((2, 15), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
Expand All @@ -419,9 +419,27 @@ def test_tokenizer_translation(self):
nested_simplify(inputs),
{
# A, test, EOS, en_XX
"input_ids": [[70, 7356, 2, 256047]],
"input_ids": [[256047, 70, 7356, 2]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 256057,
},
)

@require_torch
def test_legacy_behaviour(self):
self.tokenizer.legacy_behaviour = True
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
)

self.tokenizer.legacy_behaviour = False
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
)