🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 (#22313)
* fix the prefix tokens

* update fast and test values

* add legacy behaviour

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* update disclaimer, link issue/PR and behavioural changes

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* styling

* make a quote

* quote this time

---------

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
3 people committed Apr 4, 2023
1 parent ad5e9b6 commit 00b5887
Showing 4 changed files with 111 additions and 23 deletions.
41 changes: 39 additions & 2 deletions docs/source/en/model_doc/nllb.mdx
@@ -12,8 +12,45 @@ specific language governing permissions and limitations under the License.

# NLLB

**DISCLAIMER:** If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=bug&template=bug-report.yml) and assign
@LysandreJik
**DISCLAIMER:** The default behaviour for the tokenizer has recently been fixed (and thus changed)!

The previous version appended `[self.eos_token_id, self.cur_lang_code]` to the end of the token sequence for both source and target tokenization. This is wrong, as the NLLB paper states (page 48, 6.1.1 Model Architecture):

*Note that we prefix the source sequence with the source language, as opposed to the target
language as previously done in several works (Arivazhagan et al., 2019; Johnson et al.,
2017). This is primarily because we prioritize optimizing zero-shot performance of our
model on any pair of 200 languages at a minor cost to supervised performance.*

Previous behaviour:

```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[13374, 1398, 4260, 4039, 248130, 2, 256047]

>>> # 2: '</s>'
>>> # 256047 : 'eng_Latn'
```
New behaviour:

```python
>>> from transformers import NllbTokenizer
>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tokenizer("How was your day?").input_ids
[256047, 13374, 1398, 4260, 4039, 248130, 2]
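>>> # 256047 : 'eng_Latn' (the language code now leads the sequence)
>>> # 2: '</s>' (EOS now closes it)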
```

Enabling the old behaviour can be done as follows:
```python
>>> from transformers import NllbTokenizer

>>> tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", legacy_behaviour=True)
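>>> # Sanity check: this should reproduce the previous behaviour shown above
>>> tokenizer("How was your day?").input_ids
[13374, 1398, 4260, 4039, 248130, 2, 256047]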
```

For more details, feel free to check the linked [PR](https://github.com/huggingface/transformers/pull/22313) and [Issue](https://github.com/huggingface/transformers/issues/19943).

## Overview of NLLB

30 changes: 23 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb.py
@@ -140,12 +140,14 @@ def __init__(
        tgt_lang=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        additional_special_tokens=None,
        legacy_behaviour=False,
        **kwargs,
    ):
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.legacy_behaviour = legacy_behaviour

        super().__init__(
            bos_token=bos_token,
@@ -160,13 +162,13 @@ def __init__(
            tgt_lang=tgt_lang,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            legacy_behaviour=legacy_behaviour,
            **kwargs,
        )

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        self.vocab_file = vocab_file

        # Original fairseq vocab and spm vocab must be "aligned":
        # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
        # -------- | ------- | ------- | ------ | ------- | ---- | ---- | ---- | ---- | ---- | ----
@@ -388,13 +390,27 @@ def _switch_to_target_mode(self):
        return self.set_tgt_lang_special_tokens(self.tgt_lang)

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        """Reset the special tokens to the source lang setting.
        - In legacy mode: No prefix and suffix=[eos, src_lang_code].
        - In default mode: Prefix=[src_lang_code], suffix = [eos]
        """
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        if self.legacy_behaviour:
            self.prefix_tokens = []
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.prefix_tokens = [self.cur_lang_code]
            self.suffix_tokens = [self.eos_token_id]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        """Reset the special tokens to the target lang setting.
        - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
        - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
        """
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        if self.legacy_behaviour:
            self.prefix_tokens = []
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.prefix_tokens = [self.cur_lang_code]
            self.suffix_tokens = [self.eos_token_id]
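
As an illustrative sketch (not part of this commit), assuming the `facebook/nllb-200-distilled-600M` checkpoint is available, the two modes above set the special tokens like this:

```python
>>> from transformers import NllbTokenizer

>>> # Default (fixed) mode: the source language code is the prefix, </s> the suffix
>>> tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
>>> tok.convert_ids_to_tokens(tok.prefix_tokens), tok.convert_ids_to_tokens(tok.suffix_tokens)
(['eng_Latn'], ['</s>'])

>>> # Legacy mode: no prefix, and the language code trails </s> in the suffix
>>> tok = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn", legacy_behaviour=True)
>>> tok.convert_ids_to_tokens(tok.prefix_tokens), tok.convert_ids_to_tokens(tok.suffix_tokens)
([], ['</s>', 'eng_Latn'])
```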
31 changes: 24 additions & 7 deletions src/transformers/models/nllb/tokenization_nllb_fast.py
@@ -151,11 +151,12 @@ def __init__(
        src_lang=None,
        tgt_lang=None,
        additional_special_tokens=None,
        legacy_behaviour=False,
        **kwargs,
    ):
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self.legacy_behaviour = legacy_behaviour
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
@@ -169,6 +170,7 @@ src_lang=src_lang,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            additional_special_tokens=additional_special_tokens,
            legacy_behaviour=legacy_behaviour,
            **kwargs,
        )

@@ -287,10 +289,18 @@ def _switch_to_target_mode(self):
        return self.set_tgt_lang_special_tokens(self.tgt_lang)

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        """Reset the special tokens to the source lang setting.
        - In legacy mode: No prefix and suffix=[eos, src_lang_code].
        - In default mode: Prefix=[src_lang_code], suffix = [eos]
        """
        self.cur_lang_code = self.convert_tokens_to_ids(src_lang)
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

        if self.legacy_behaviour:
            self.prefix_tokens = []
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.prefix_tokens = [self.cur_lang_code]
            self.suffix_tokens = [self.eos_token_id]

        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
@@ -302,10 +312,17 @@ def set_src_lang_special_tokens(self, src_lang) -> None:
        )

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        """Reset the special tokens to the target lang setting.
        - In legacy mode: No prefix and suffix=[eos, tgt_lang_code].
        - In default mode: Prefix=[tgt_lang_code], suffix = [eos]
        """
        self.cur_lang_code = self.convert_tokens_to_ids(lang)
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        if self.legacy_behaviour:
            self.prefix_tokens = []
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.prefix_tokens = [self.cur_lang_code]
            self.suffix_tokens = [self.eos_token_id]

        prefix_tokens_str = self.convert_ids_to_tokens(self.prefix_tokens)
        suffix_tokens_str = self.convert_ids_to_tokens(self.suffix_tokens)
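
A similar hedged check for the fast tokenizer (same assumed checkpoint): after this fix the fast and slow tokenizers are expected to emit identical ids for the docs example above.

```python
>>> from transformers import NllbTokenizerFast

>>> tok = NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M")
>>> tok("How was your day?").input_ids  # language code first, </s> last, matching the slow tokenizer
[256047, 13374, 1398, 4260, 4039, 248130, 2]
```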
32 changes: 25 additions & 7 deletions tests/models/nllb/test_tokenization_nllb.py
@@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
" face dec芒t s膬 卯nr膬ut膬牛easc膬 violen牛ele 艧i mizeria pentru milioane de oameni.",
]
expected_src_tokens = [
256047,
16297,
134408,
8165,
@@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
        108,
        49486,
        2,
        256047,
    ]

    @classmethod
@@ -355,8 +355,8 @@ def test_enro_tokenizer_truncation(self):
        assert isinstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(ids[-1], 2)
        self.assertEqual(ids[0], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)

    def test_mask_token(self):
@@ -389,10 +389,10 @@ def test_enro_tokenizer_prepare_batch(self):
        self.assertEqual((2, 15), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0])  # RO_CODE
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
        self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

    def test_seq2seq_max_length(self):
        batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
Expand All @@ -419,9 +419,27 @@ def test_tokenizer_translation(self):
            nested_simplify(inputs),
            {
                # A, test, EOS, en_XX
                "input_ids": [[70, 7356, 2, 256047]],
                "input_ids": [[256047, 70, 7356, 2]],
                "attention_mask": [[1, 1, 1, 1]],
                # fra_Latn
                "forced_bos_token_id": 256057,
            },
        )

    @require_torch
    def test_legacy_behaviour(self):
        self.tokenizer.legacy_behaviour = True
        inputs = self.tokenizer(
            "UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
        )
        self.assertEqual(
            inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
        )

        self.tokenizer.legacy_behaviour = False
        inputs = self.tokenizer(
            "UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
        )
        self.assertEqual(
            inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
        )
