26 changes: 18 additions & 8 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -80,7 +80,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
 
     If a more custom pre-tokenization step is desired, the layer can be
     configured to apply only the strict WordPiece algorithm by passing
-    `lowercase=False`, `strip_accents=False` and `split_pattern=None`. In
+    `lowercase=False`, `strip_accents=False` and `split=False`. In
     this case, inputs should be pre-split string tensors or ragged tensors.
 
     By default, the layer will output a `tf.RaggedTensor` where the last
@@ -101,10 +101,11 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
             tokenization.
         strip_accents: If true, all accent marks will be removed from text
             before tokenization.
-        split_pattern: A regex pattern to match delimiters to split, or None
-            indicating that the input is pre-split and no splitting should be
-            performed. By default, all whitespace and punctuation marks will
-            be split on.
+        split: If true, input will be split according to `split_pattern`
+            and `keep_pattern`. If false, input should be split before calling
+            the layer.
+        split_pattern: A regex pattern to match delimiters to split. By default,
+            all whitespace and punctuation marks will be split on.
         keep_pattern: A regex pattern of delimiters contained in the
             `split_pattern` of delimiters that should be kept as independent
             tokens. By default, all punctuation marks will be kept as tokens.
@@ -167,8 +168,9 @@ def __init__(
         sequence_length: int = None,
         lowercase: bool = True,
         strip_accents: bool = True,
-        split_pattern: str = WHITESPACE_AND_PUNCTUATION_REGEX,
-        keep_pattern: str = PUNCTUATION_REGEX,
+        split: bool = True,
+        split_pattern: str = None,
+        keep_pattern: str = None,
         suffix_indicator: str = "##",
         oov_token: str = "[UNK]",
         **kwargs,
@@ -201,9 +203,16 @@ def __init__(
         if oov_token is None:
             raise ValueError("`oov_token` cannot be None.")
 
+        if split_pattern is None:
+            split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
+
+        if keep_pattern is None:
+            keep_pattern = PUNCTUATION_REGEX
+
         self.sequence_length = sequence_length
         self.lowercase = lowercase
         self.strip_accents = strip_accents
+        self.split = split
         self.split_pattern = split_pattern
         self.keep_pattern = keep_pattern
         self.suffix_indicator = suffix_indicator
@@ -257,6 +266,7 @@ def get_config(self) -> Dict[str, Any]:
                 "sequence_length": self.sequence_length,
                 "lowercase": self.lowercase,
                 "strip_accents": self.strip_accents,
+                "split": self.split,
                 "split_pattern": self.split_pattern,
                 "keep_pattern": self.keep_pattern,
                 "suffix_indicator": self.suffix_indicator,
@@ -280,7 +290,7 @@ def tokenize(self, inputs):
             inputs = tf_text.normalize_utf8(inputs, "NFD")
             # Remove the accent marks.
             inputs = tf.strings.regex_replace(inputs, r"\p{Mn}", "")
-        if self.split_pattern:
+        if self.split:
             inputs = tf_text.regex_split(
                 inputs,
                 delim_regex_pattern=self.split_pattern,
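
A quick usage sketch of the new flag (the vocabulary and input strings below are made up for illustration, and the snippet assumes the public API exactly as changed above, with `split_pattern`/`keep_pattern` falling back to the whitespace-and-punctuation regexes when left as `None`):

from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer

# Toy vocabulary, for illustration only.
vocab = ["[UNK]", "the", "qu", "##ick", "fox", "."]

# Default behavior: split=True, so the None defaults for split_pattern and
# keep_pattern are resolved to WHITESPACE_AND_PUNCTUATION_REGEX / PUNCTUATION_REGEX.
tokenizer = WordPieceTokenizer(vocabulary=vocab)
print(tokenizer(["The quick fox."]))  # whitespace/punctuation splitting, then WordPiece

# Pre-split input: turn off normalization and splitting and pass tokens that
# are already separated; each element is WordPiece-tokenized without further splitting.
presplit = WordPieceTokenizer(
    vocabulary=vocab,
    lowercase=False,
    strip_accents=False,
    split=False,
)
print(presplit(["the", "qu", "fox", "."]))
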
6 changes: 2 additions & 4 deletions keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -135,9 +135,7 @@ def test_custom_spliting(self):
     def test_no_spliting(self):
         input_data = ["t o k e n", "m i s s i n g", "t o k e n"]
         vocab_data = ["[UNK]", "t o k e n"]
-        tokenizer = WordPieceTokenizer(
-            vocabulary=vocab_data, split_pattern=None
-        )
+        tokenizer = WordPieceTokenizer(vocabulary=vocab_data, split=False)
         call_output = tokenizer(input_data)
         self.assertAllEqual(call_output, [1, 0, 1])
 
@@ -148,7 +146,7 @@ def test_word_piece_only(self):
             vocabulary=vocab_data,
             lowercase=False,
             strip_accents=False,
-            split_pattern=None,
+            split=False,
         )
         call_output = tokenizer(input_data)
         self.assertAllEqual(call_output, [1, 2, 3, 4, 5, 6])
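
Since `split` is now part of `get_config()`, the serialized config reflects both the flag and the resolved patterns. A minimal sketch (made-up vocabulary; note that a `None` `split_pattern` is replaced by the default regex inside `__init__`, so the stored value is the resolved pattern rather than `None`):

from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer

vocab = ["[UNK]", "token"]  # toy vocabulary, for illustration only
tokenizer = WordPieceTokenizer(vocabulary=vocab, split=False)

config = tokenizer.get_config()
assert config["split"] is False  # the new key added in this change
# The None default was resolved in __init__, so the config stores the
# actual whitespace-and-punctuation regex rather than None.
assert config["split_pattern"] is not None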