Add Support for CJK Char Splitting for WordPiece Tokenizer #318

Merged: 10 commits, Aug 29, 2022
60 changes: 55 additions & 5 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -43,7 +43,13 @@
r"[{-~]",
# Unicode punctuation class.
r"[\p{P}]",
# More unicode ranges.
]
)

# Matches CJK characters. Obtained from
# https://github.com/google-research/bert/blob/master/tokenization.py#L251.
CJK_REGEX = r"|".join(
[
r"[\x{4E00}-\x{9FFF}]",
r"[\x{3400}-\x{4DBF}]",
r"[\x{20000}-\x{2A6DF}]",
@@ -63,8 +69,30 @@
]
)

# Matches punctuation and CJK characters.
PUNCTUATION_AND_CJK_REGEX = r"|".join(
[
PUNCTUATION_REGEX,
CJK_REGEX,
]
)

# Matches whitespace, punctuation, and CJK characters.
WHITESPACE_PUNCTUATION_AND_CJK_REGEX = r"|".join(
[
WHITESPACE_AND_PUNCTUATION_REGEX,
CJK_REGEX,
]
)


def pretokenize(text, lowercase, strip_accents, split):
def pretokenize(
text,
lowercase=True,
strip_accents=True,
split=True,
split_on_cjk=True,
):
"""Helper function that takes in a dataset element and pretokenizes it.

Args:
@@ -78,6 +106,10 @@ def pretokenize(text, lowercase, strip_accents, split):
kept as tokens. If false, input should be split ("pre-tokenized")
before calling the tokenizer, and passed as a dense or ragged tensor
of whole words.
split_on_cjk: bool, defaults to `True`. If true, input will be split
on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
Note that this is applicable only when `split` is true.

Returns:
A tensor containing the pre-processed and pre-tokenized `text`.
@@ -91,6 +123,8 @@ def pretokenize(text, lowercase, strip_accents, split):
# Preprocess, lowercase, strip and split input data.
if text.shape.rank == 0:
text = tf.expand_dims(text, 0)
if split_on_cjk and split:

Member:
Wow, interesting. It looks like we were already splitting on these. In that case we could actually avoid this regex_replace call and just keep two different split regexes, WHITESPACE_AND_PUNCTUATION_REGEX and WHITESPACE_PUNCTUATION_AND_CJK_REGEX:

if split:
    if split_on_cjk:
        split_pattern = WHITESPACE_PUNCTUATION_AND_CJK_REGEX
    else:
        split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
    text = tf_text.regex_split(...)

Would this work?

Collaborator Author:
Hmmm, yeah. But we'll have to create a third regex, namely PUNCTUATION_AND_CJK_REGEX, to pass to keep_delim_regex_pattern. Isn't it easier to just keep tf.regex_replace(), then? What do you think?
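
For illustration, a minimal runnable sketch of the two-regex split being discussed, with toy patterns standing in for the real module-level constants (only a small subset of the CJK ranges; not the merged implementation):

import tensorflow as tf
import tensorflow_text as tf_text

# Toy stand-ins for the module-level constants discussed above.
PUNCTUATION_AND_CJK_REGEX = r"[!-/]|[\x{4E00}-\x{9FFF}]"
WHITESPACE_PUNCTUATION_AND_CJK_REGEX = r"\s|" + PUNCTUATION_AND_CJK_REGEX

text = tf.constant(["ah半推zz, ok"])
tokens = tf_text.regex_split(
    text,
    # Split on whitespace, punctuation, and CJK characters...
    delim_regex_pattern=WHITESPACE_PUNCTUATION_AND_CJK_REGEX,
    # ...but keep only the punctuation and CJK matches as tokens.
    keep_delim_regex_pattern=PUNCTUATION_AND_CJK_REGEX,
)
print(tokens)  # expected: [['ah', '半', '推', 'zz', ',', 'ok']]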

Collaborator Author:
Pushing changes for now; let me know if it's better to revert to the original.

Member:
This looks fine to me. Original is simpler, but we probably should care about efficiency here (and cutting an op should help that).

Collaborator Author:
Cool!
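
For comparison, a rough sketch of the regex_replace padding approach discussed in this thread, again with toy stand-in patterns; it should yield the same tokens at the cost of one extra string op:

import tensorflow as tf
import tensorflow_text as tf_text

# Toy stand-ins for the real module-level constants.
CJK_REGEX = r"[\x{4E00}-\x{9FFF}]"
PUNCTUATION_REGEX = r"[!-/]"
WHITESPACE_AND_PUNCTUATION_REGEX = r"\s|" + PUNCTUATION_REGEX

text = tf.constant(["ah半推zz"])
# Extra op: pad every CJK character with spaces, e.g. "ah半推zz" -> "ah 半  推 zz".
text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
tokens = tf_text.regex_split(
    text,
    delim_regex_pattern=WHITESPACE_AND_PUNCTUATION_REGEX,
    keep_delim_regex_pattern=PUNCTUATION_REGEX,
)
print(tokens)  # expected: [['ah', '半', '推', 'zz']]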

text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
if lowercase:
text = tf_text.case_fold_utf8(text)
if strip_accents:
@@ -99,10 +133,16 @@ def pretokenize(text, lowercase, strip_accents, split):
# Remove the accent marks.
text = tf.strings.regex_replace(text, r"\p{Mn}", "")
if split:
if split_on_cjk:
split_pattern = WHITESPACE_PUNCTUATION_AND_CJK_REGEX

Member:
delim_regex_pattern, I guess? Since you are agreeing with the other arg name.

Collaborator Author:
Hehe, I just changed it to keep_split_pattern :P

Member:
that works too!

keep_split_pattern = PUNCTUATION_AND_CJK_REGEX
else:
split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
keep_split_pattern = PUNCTUATION_REGEX
text = tf_text.regex_split(
text,
delim_regex_pattern=WHITESPACE_AND_PUNCTUATION_REGEX,
keep_delim_regex_pattern=PUNCTUATION_REGEX,
delim_regex_pattern=split_pattern,
keep_delim_regex_pattern=keep_split_pattern,
)
return text

@@ -159,6 +199,10 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
kept as tokens. If false, input should be split ("pre-tokenized")
before calling the tokenizer, and passed as a dense or ragged tensor
of whole words.
split_on_cjk: bool, defaults to `True`. If true, input will be split
on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
Note that this is applicable only when `split` is true.
suffix_indicator: str, defaults to "##". The characters prepended to a
WordPiece to indicate that it is a suffix to another subword.
E.g. "##ing".
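
For reference, a brief usage sketch of the new flag at the tokenizer level, mirroring the unit test added below; the tiny vocabulary is hypothetical, and this assumes a keras_nlp build that includes this change:

from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer

vocab = ["[UNK]", "ah", "zz", "半", "推"]  # hypothetical toy vocabulary
tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    split_on_cjk=True,  # the new flag; True is the default
    dtype="string",
)
print(tokenizer(["ah半推zz"]))  # expected: [['ah', '半', '推', 'zz']]
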
@@ -235,6 +279,7 @@ def __init__(
lowercase: bool = False,
strip_accents: bool = False,
split: bool = True,
split_on_cjk: bool = True,
suffix_indicator: str = "##",
oov_token: str = "[UNK]",
**kwargs,
@@ -271,6 +316,7 @@ def __init__(
self.lowercase = lowercase
self.strip_accents = strip_accents
self.split = split
self.split_on_cjk = split_on_cjk
self.suffix_indicator = suffix_indicator
self.oov_token = oov_token

@@ -335,7 +381,11 @@ def tokenize(self, inputs):

scalar_input = inputs.shape.rank == 0
inputs = pretokenize(
inputs, self.lowercase, self.strip_accents, self.split
inputs,
self.lowercase,
self.strip_accents,
self.split,
self.split_on_cjk,
)

# Apply WordPiece and coerce shape for outputs.
12 changes: 11 additions & 1 deletion keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -87,6 +87,16 @@ def test_special_tokens(self):
tf.ragged.constant([["qu", "@@ick", "br", "@@own", "@UNK@"]]),
)

def test_cjk_tokens(self):
input_data = ["ah半推zz"]
vocab_data = ["[UNK]", "推", "敐", "乐", "半", "偷", "匕", "ah", "zz"]
tokenizer = WordPieceTokenizer(vocabulary=vocab_data, dtype="string")
call_output = tokenizer(input_data)
self.assertAllEqual(
call_output,
tf.ragged.constant([["ah", "半", "推", "zz"]]),
)

def test_lowercase(self):
input_data = ["the QUicK brOWN FOX"]
vocab_data = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox"]
@@ -119,7 +129,7 @@ def test_skip_strip_accents(self):
call_output = tokenizer(input_data)
self.assertAllEqual(call_output, [[1, 2, 3, 4, 5]])

def test_no_spliting(self):
def test_no_splitting(self):
input_data = ["t o k e n", "m i s s i n g", "t o k e n"]
vocab_data = ["[UNK]", "t o k e n"]
tokenizer = WordPieceTokenizer(vocabulary=vocab_data, split=False)
9 changes: 8 additions & 1 deletion keras_nlp/tokenizers/word_piece_tokenizer_trainer.py
@@ -28,6 +28,7 @@ def compute_word_piece_vocabulary(
lowercase=False,
strip_accents=False,
split=True,
split_on_cjk=True,
suffix_indicator="##",
reserved_tokens=["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"],
):
@@ -55,6 +56,10 @@
before calling the tokenizer, and passed as a dense or ragged tensor
of whole words. `split` is required to be `True` when `data` is a
list of filenames.
split_on_cjk: bool, defaults to `True`. If true, input will be split
on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese
characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
Note that this is applicable only when `split` is true.
suffix_indicator: str, defaults to "##". The characters prepended to a
WordPiece to indicate that it is a suffix to another subword.
E.g. "##ing".
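
Likewise, a small sketch of the trainer-side flag, modeled directly on the new trainer test further down (same assumptions as the tokenizer sketch above):

import tensorflow as tf

from keras_nlp.tokenizers.word_piece_tokenizer_trainer import (
    compute_word_piece_vocabulary,
)

data = tf.data.Dataset.from_tensor_slices(["ah半推zz"])
# With split_on_cjk=True (the default), each CJK character is treated as its
# own word before the WordPiece vocabulary is learned.
vocab = compute_word_piece_vocabulary(
    data,
    4,
    split=True,
    split_on_cjk=True,
    lowercase=False,
    strip_accents=False,
)
print(vocab)
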
@@ -138,7 +143,9 @@ def normalize_and_split(x):
)

words_data = data.map(
lambda text: pretokenize(text, lowercase, strip_accents, split),
lambda text: pretokenize(
text, lowercase, strip_accents, split, split_on_cjk
),
num_parallel_calls=tf.data.AUTOTUNE,
)
word_counts = learner.count_words(words_data)
23 changes: 23 additions & 0 deletions keras_nlp/tokenizers/word_piece_tokenizer_trainer_test.py
@@ -110,6 +110,23 @@ def test_split(self):
)
self.assertAllEqual(output_vocab_1, output_vocab_2)

def test_split_on_cjk(self):
test_text = tf.data.Dataset.from_tensor_slices(["ah半推zz"])
test_text_split = tf.data.Dataset.from_tensor_slices(
["ah", "半", "推", "zz"]
)
output_vocab_1 = compute_word_piece_vocabulary(
test_text,
4,
split=True,
split_on_cjk=True,
lowercase=False,
strip_accents=False,
)
output_vocab_2 = compute_word_piece_vocabulary(
test_text_split,
4,
split=False,
split_on_cjk=False,
lowercase=False,
strip_accents=False,
)
self.assertAllEqual(output_vocab_1, output_vocab_2)

def test_skip_split(self):
test_text = tf.data.Dataset.from_tensor_slices(
[