Support tokenization of special tokens for word_piece_tokenizer #1397

Merged · 19 commits · Mar 20, 2024
25 changes: 16 additions & 9 deletions keras_nlp/models/bert/bert_tokenizer.py
@@ -49,6 +49,9 @@ class BertTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether the tokenizer should expect
special tokens in input strings and map them to their ids during
tokenization. Defaults to `False`.

Examples:
```python
@@ -76,6 +79,7 @@ def __init__(
self,
vocabulary=None,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
@@ -85,22 +89,20 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
lowercase=lowercase,
special_tokens=[
self.cls_token,
self.sep_token,
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

def set_vocabulary(self, vocabulary):
super().set_vocabulary(vocabulary)

if vocabulary is not None:
# Check for necessary special tokens.
for token in [self.cls_token, self.pad_token, self.sep_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.cls_token_id = self.token_to_id(self.cls_token)
self.sep_token_id = self.token_to_id(self.sep_token)
self.pad_token_id = self.token_to_id(self.pad_token)
@@ -114,3 +116,8 @@ def set_vocabulary(self, vocabulary):
@classproperty
def presets(cls):
return copy.deepcopy({**backbone_presets, **classifier_presets})

def get_config(self):
config = super().get_config()
del config["special_tokens"] # Not configurable; set in __init__.
return config
10 changes: 10 additions & 0 deletions keras_nlp/models/bert/bert_tokenizer_test.py
@@ -39,6 +39,16 @@ def test_lowercase(self):
output = tokenizer(self.input_data)
self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]])

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = BertTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

self.assertAllEqual(output_data, expected_output)

def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
BertTokenizer(vocabulary=["a", "b", "c"])
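To make the behavior concrete outside the test fixture (the test above pulls its vocabulary from `self.init_kwargs`, which is not shown in this diff), here is a minimal sketch; the toy vocabulary and the resulting ids are illustrative assumptions, not part of this PR:

```python
from keras_nlp.models import BertTokenizer

# Toy vocabulary (illustrative); ids follow list order: [PAD]=0, [UNK]=1, [CLS]=2, ...
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "quick", "brown", "fox"]

tokenizer = BertTokenizer(vocabulary=vocab, special_tokens_in_strings=True)

# "[CLS]", "[SEP]" and "[PAD]" are kept intact and mapped to their ids instead
# of being split on the "[" and "]" punctuation.
print(tokenizer(["[CLS] the quick fox [SEP] [PAD]"]))
# Expected with this toy vocabulary: [[2, 5, 6, 8, 3, 0]]
```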
25 changes: 16 additions & 9 deletions keras_nlp/models/distil_bert/distil_bert_tokenizer.py
@@ -46,6 +46,9 @@ class DistilBertTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether the tokenizer should expect
special tokens in input strings and map them to their ids during
tokenization. Defaults to `False`.

Examples:

@@ -74,6 +77,7 @@ def __init__(
self,
vocabulary,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
@@ -83,22 +87,20 @@ def __init__(
super().__init__(
vocabulary=vocabulary,
lowercase=lowercase,
special_tokens=[
self.cls_token,
self.sep_token,
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

def set_vocabulary(self, vocabulary):
super().set_vocabulary(vocabulary)

if vocabulary is not None:
# Check for necessary special tokens.
for token in [self.cls_token, self.pad_token, self.sep_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.cls_token_id = self.token_to_id(self.cls_token)
self.sep_token_id = self.token_to_id(self.sep_token)
self.pad_token_id = self.token_to_id(self.pad_token)
@@ -112,3 +114,8 @@ def set_vocabulary(self, vocabulary):
@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

def get_config(self):
config = super().get_config()
del config["special_tokens"] # Not configurable; set in __init__.
return config
10 changes: 10 additions & 0 deletions keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py
@@ -41,6 +41,16 @@ def test_lowercase(self):
output = tokenizer(self.input_data)
self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]])

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = DistilBertTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

self.assertAllEqual(output_data, expected_output)

def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
DistilBertTokenizer(vocabulary=["a", "b", "c"])
38 changes: 27 additions & 11 deletions keras_nlp/models/electra/electra_tokenizer.py
@@ -36,6 +36,9 @@ class ElectraTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether the tokenizer should expect
special tokens in input strings and map them to their ids during
tokenization. Defaults to `False`.

Examples:
```python
@@ -57,26 +60,34 @@ class ElectraTokenizer(WordPieceTokenizer):
```
"""

def __init__(self, vocabulary, lowercase=False, **kwargs):
def __init__(
self,
vocabulary,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
self.pad_token = "[PAD]"
self.mask_token = "[MASK]"
super().__init__(vocabulary=vocabulary, lowercase=lowercase, **kwargs)
super().__init__(
vocabulary=vocabulary,
lowercase=lowercase,
special_tokens=[
self.cls_token,
self.sep_token,
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

def set_vocabulary(self, vocabulary):
super().set_vocabulary(vocabulary)

if vocabulary is not None:
# Check for necessary special tokens.
for token in [self.cls_token, self.pad_token, self.sep_token]:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.cls_token_id = self.token_to_id(self.cls_token)
self.sep_token_id = self.token_to_id(self.sep_token)
self.pad_token_id = self.token_to_id(self.pad_token)
@@ -86,3 +97,8 @@ def set_vocabulary(self, vocabulary):
self.sep_token_id = None
self.pad_token_id = None
self.mask_token_id = None

def get_config(self):
config = super().get_config()
del config["special_tokens"] # Not configurable; set in __init__.
return config
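
A short note on the `get_config` override shared by the BERT, DistilBERT, and ELECTRA tokenizers above: each subclass hard-codes its `special_tokens` in `__init__`, so the key serialized by the base class is dropped to avoid passing it twice on reconstruction. A minimal sketch of the expected behavior, assuming `tokenizer` is an `ElectraTokenizer` built with a valid vocabulary:

```python
config = tokenizer.get_config()

# The base WordPieceTokenizer serializes "special_tokens" (see the change to
# word_piece_tokenizer.py below), but the subclass removes it because
# __init__ always re-adds "[CLS]", "[SEP]", "[PAD]" and "[MASK]" itself.
assert "special_tokens" not in config
assert "oov_token" in config  # other base-class options remain serialized
```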
10 changes: 10 additions & 0 deletions keras_nlp/models/electra/electra_tokenizer_test.py
@@ -37,6 +37,16 @@ def test_lowercase(self):
output = tokenizer(self.input_data)
self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]])

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = ElectraTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

self.assertAllEqual(output_data, expected_output)

def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
ElectraTokenizer(vocabulary=["a", "b", "c"])
70 changes: 69 additions & 1 deletion keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import re
from typing import Iterable
from typing import List

@@ -101,12 +102,19 @@
)


def get_special_tokens_pattern(special_tokens):
if special_tokens is None or len(special_tokens) == 0:
return None
return r"|".join([re.escape(token) for token in special_tokens])


def pretokenize(
text,
lowercase=False,
strip_accents=True,
split=True,
split_on_cjk=True,
special_tokens_pattern=None,
):
"""Helper function that takes in a dataset element and pretokenizes it.

Expand All @@ -124,7 +132,14 @@ def pretokenize(
split_on_cjk: bool. If `True`, input will be split
on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
Note that this is applicable only when `split` is `True`. Defaults to `True`.
Note that this is applicable only when `split` is `True`. Defaults
to `True`.
special_tokens_pattern: str. A regex pattern matching the special
tokens, which will never be split during the word-level splitting
applied before the word-piece encoding. This can be used to ensure
special tokens map to unique indices in the vocabulary, even if they
contain splittable characters such as punctuation.

Returns:
A tensor containing the pre-processed and pre-tokenized `text`.
@@ -154,6 +169,23 @@ def pretokenize(
else:
split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
keep_split_pattern = PUNCTUATION_REGEX
if special_tokens_pattern is not None:
# Pass the special tokens regex to the split function as a delimiter
# pattern, so the input is split on special tokens, while each special
# token is treated as a single unit that is never split further, even
# if it contains other delimiters (e.g. punctuation) inside it. The
# special tokens regex is also added to the keep pattern, so the
# tokens are kept in the output rather than removed.
split_pattern = r"|".join(
[
special_tokens_pattern,
split_pattern,
]
)
keep_split_pattern = r"|".join(
[special_tokens_pattern, keep_split_pattern]
)
text = tf_text.regex_split(
text,
delim_regex_pattern=split_pattern,
@@ -225,6 +257,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
oov_token: str. The string value to substitute for
an unknown token. It must be included in the vocab.
Defaults to `"[UNK]"`.
special_tokens: list. A list of special tokens. When
`special_tokens_in_strings` is set to `True`, the tokenizer maps every
special token in the input strings to its id, even if the token
contains characters that would normally be split before tokenization,
such as punctuation. All `special_tokens` must be included in
`vocabulary`.
special_tokens_in_strings: bool. Whether the tokenizer should expect
special tokens in input strings and map them to their ids during
tokenization. Defaults to `False`.

References:
- [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/)
@@ -303,6 +344,8 @@ def __init__(
split_on_cjk: bool = True,
suffix_indicator: str = "##",
oov_token: str = "[UNK]",
special_tokens: List[str] = None,
special_tokens_in_strings: bool = False,
dtype="int32",
**kwargs,
) -> None:
@@ -325,6 +368,19 @@ def __init__(
self.split_on_cjk = split_on_cjk
self.suffix_indicator = suffix_indicator
self.oov_token = oov_token
self.special_tokens = special_tokens
self._special_tokens_pattern = None
if self.split and special_tokens_in_strings:
# Compile the special tokens into a regex pattern. `pretokenize`
# uses it both as a split delimiter pattern, so each special token
# is treated as a single unit that is never split further (even if
# it contains other delimiters such as punctuation), and as a keep
# pattern, so the special tokens are not removed from the output.
self._special_tokens_pattern = get_special_tokens_pattern(
self.special_tokens
)
self.set_vocabulary(vocabulary)

def save_assets(self, dir_path):
Expand Down Expand Up @@ -365,6 +421,16 @@ def set_vocabulary(self, vocabulary):
"the `oov_token` argument when creating the tokenizer."
)

# Check that all special tokens are present in the vocabulary.
if self.special_tokens is not None:
for token in self.special_tokens:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self._fast_word_piece = tf_text.FastWordpieceTokenizer(
vocab=self.vocabulary,
token_out_type=self.compute_dtype,
@@ -413,6 +479,7 @@ def get_config(self):
"split": self.split,
"suffix_indicator": self.suffix_indicator,
"oov_token": self.oov_token,
"special_tokens": self.special_tokens,
}
)
return config
@@ -436,6 +503,7 @@ def tokenize(self, inputs):
self.strip_accents,
self.split,
self.split_on_cjk,
self._special_tokens_pattern,
)

# Apply WordPiece and coerce shape for outputs.
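Putting the base-class pieces together, here is a minimal sketch of using `WordPieceTokenizer` directly with the new arguments; the vocabulary, input string, and expected ids are illustrative assumptions rather than part of this PR:

```python
from keras_nlp.tokenizers import WordPieceTokenizer

# Toy vocabulary (illustrative); ids follow list order.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "the", "quick", "fox"]

tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    special_tokens=["[PAD]", "[CLS]", "[SEP]"],
    special_tokens_in_strings=True,
)

# Internally the special tokens are compiled into a single alternation,
# roughly r"\[PAD\]|\[CLS\]|\[SEP\]", which pretokenize() uses both to split
# the input on special tokens and to keep each one as an indivisible unit.
print(tokenizer(["[CLS] the quick fox [SEP]"]))
# Expected with this toy vocabulary: [[2, 4, 5, 6, 3]]

# Declaring a special token that is missing from the vocabulary raises a
# ValueError at construction time (see the new check in set_vocabulary above).
```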