Add support for WMT21 tokenizer in M2M100Tokenizer #14376

Merged 1 commit on Nov 13, 2021
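In summary, this change turns `FAIRSEQ_LANGUAGE_CODES` into a dict keyed by code set, adds a `language_codes` constructor argument that selects between the original `"m2m100"` codes and the eight WMT21 codes, and makes `num_madeup_words` configurable. A minimal usage sketch; the checkpoint path below is a placeholder, not something this diff ships:

```python
from transformers import M2M100Tokenizer

# "path/to/wmt21-checkpoint" is a placeholder for a directory holding the
# WMT21 SentencePiece model and vocab files; only language_codes="wmt21"
# is new in this PR.
tokenizer = M2M100Tokenizer.from_pretrained(
    "path/to/wmt21-checkpoint",
    language_codes="wmt21",
    src_lang="en",
    tgt_lang="de",
)

ids = tokenizer("The quick brown fox.")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
# Expected shape: ['__en__', ..., '</s>'], with the source language token prefixed.
```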
23 changes: 17 additions & 6 deletions src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -54,7 +54,10 @@
 }
 
 # fmt: off
-FAIRSEQ_LANGUAGE_CODES = ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
+FAIRSEQ_LANGUAGE_CODES = {
+    "m2m100": ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"],
+    "wmt21": ['en', 'ha', 'is', 'ja', 'cs', 'ru', 'zh', 'de']
+}
 # fmt: on
 
 
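Because the constant is now an ordinary dict, the available code sets can be inspected directly. A small sketch, assuming the module path from the file header above:

```python
from transformers.models.m2m_100.tokenization_m2m_100 import FAIRSEQ_LANGUAGE_CODES

print(FAIRSEQ_LANGUAGE_CODES["wmt21"])        # ['en', 'ha', 'is', 'ja', 'cs', 'ru', 'zh', 'de']
print(len(FAIRSEQ_LANGUAGE_CODES["m2m100"]))  # 100, the original M2M-100 code set
```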
@@ -86,6 +89,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
             token instead.
         pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
             The token used for padding, for example when batching sequences of different lengths.
+        language_codes (:obj:`str`, `optional`, defaults to :obj:`"m2m100"`):
+            What language codes to use. Should be one of :obj:`"m2m100"` or :obj:`"wmt21"`.
         sp_model_kwargs (:obj:`dict`, `optional`):
             Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece
             <https://github.com/google/sentencepiece/tree/master/python>`__ can be used, among other things, to set:
@@ -132,17 +137,21 @@ def __init__(
         sep_token="</s>",
         pad_token="<pad>",
         unk_token="<unk>",
+        language_codes="m2m100",
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        num_madeup_words=8,
         **kwargs,
     ) -> None:
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
-        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in FAIRSEQ_LANGUAGE_CODES}
+        self.language_codes = language_codes
+        fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
+        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in fairseq_language_code}
 
         kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
         kwargs["additional_special_tokens"] += [
             self.get_lang_token(lang_code)
-            for lang_code in FAIRSEQ_LANGUAGE_CODES
+            for lang_code in fairseq_language_code
             if self.get_lang_token(lang_code) not in kwargs["additional_special_tokens"]
         ]
 
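The constructor resolves the selected code set once and derives a `__code__` token for each entry, so the `"wmt21"` setting registers only eight additional special tokens. The mapping logic in isolation:

```python
# FAIRSEQ_LANGUAGE_CODES["wmt21"], written out for a self-contained sketch.
lang_codes = ['en', 'ha', 'is', 'ja', 'cs', 'ru', 'zh', 'de']
lang_code_to_token = {code: f"__{code}__" for code in lang_codes}
print(list(lang_code_to_token.values()))
# ['__en__', '__ha__', '__is__', '__ja__', '__cs__', '__ru__', '__zh__', '__de__']
```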
@@ -154,7 +163,9 @@ def __init__(
             sep_token=sep_token,
             unk_token=unk_token,
             pad_token=pad_token,
+            language_codes=language_codes,
             sp_model_kwargs=self.sp_model_kwargs,
+            num_madeup_words=num_madeup_words,
             **kwargs,
         )
 
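Passing `language_codes` and `num_madeup_words` through to `super().__init__` records them in the tokenizer's init kwargs, so both settings should survive a save/load round trip. A sketch reusing the tokenizer from the first example, with a hypothetical directory:

```python
from transformers import M2M100Tokenizer

# `tokenizer` is the wmt21 instance built in the first sketch above.
tokenizer.save_pretrained("./wmt21-tokenizer")
reloaded = M2M100Tokenizer.from_pretrained("./wmt21-tokenizer")
assert reloaded.language_codes == "wmt21"
assert reloaded.num_madeup_words == 8
```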
@@ -167,17 +178,17 @@ def __init__(
         self.encoder_size = len(self.encoder)
 
         self.lang_token_to_id = {
-            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)
         }
-        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)}
+        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)}
         self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
 
         self._src_lang = src_lang if src_lang is not None else "en"
         self.tgt_lang = tgt_lang
         self.cur_lang_id = self.get_lang_id(self._src_lang)
         self.set_src_lang_special_tokens(self._src_lang)
 
-        self.num_madeup_words = 8
+        self.num_madeup_words = num_madeup_words
 
     @property
     def vocab_size(self) -> int:
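Language-token ids are laid out immediately after the SentencePiece vocabulary, with `num_madeup_words` filler ids after them; the `vocab_size` property, whose body is truncated in this view, builds on this layout. The arithmetic with illustrative sizes:

```python
encoder_size = 128_000              # illustrative SentencePiece vocab size
num_lang_tokens = 8                 # len(FAIRSEQ_LANGUAGE_CODES["wmt21"])
num_madeup_words = 8                # constructor default, now configurable

first_lang_token_id = encoder_size  # id of "__en__" for the wmt21 set
vocab_size = encoder_size + num_lang_tokens + num_madeup_words
```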