diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index 900692d877..cc1766ecb1 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -80,7 +80,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
 
     If a more custom pre-tokenization step is desired, the layer can be
     configured to apply only the strict WordPiece algorithm by passing
-    `lowercase=False`, `strip_accents=False` and `split_pattern=None`. In
+    `lowercase=False`, `strip_accents=False` and `split=False`. In
     this case, inputs should be pre-split string tensors or ragged tensors.
 
     By default, the layer will output a `tf.RaggedTensor` where the last
@@ -101,10 +101,11 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
             tokenization.
         strip_accents: If true, all accent marks will be removed from text
             before tokenization.
-        split_pattern: A regex pattern to match delimiters to split, or None
-            indicating that the input is pre-split and no splitting should be
-            performed. By default, all whitespace and punctuation marks will
-            be split on.
+        split: If true, input will be split according to `split_pattern`
+            and `keep_pattern`. If false, input should be split before calling
+            the layer.
+        split_pattern: A regex pattern to match delimiters to split. By default,
+            all whitespace and punctuation marks will be split on.
         keep_pattern: A regex pattern of delimiters contained in the
             `split_pattern` of delimeters that should be kept as independent
             tokens. By default, all punctuation marks will be kept as tokens.
@@ -167,8 +168,9 @@ def __init__(
         sequence_length: int = None,
         lowercase: bool = True,
         strip_accents: bool = True,
-        split_pattern: str = WHITESPACE_AND_PUNCTUATION_REGEX,
-        keep_pattern: str = PUNCTUATION_REGEX,
+        split: bool = True,
+        split_pattern: str = None,
+        keep_pattern: str = None,
         suffix_indicator: str = "##",
         oov_token: str = "[UNK]",
         **kwargs,
@@ -201,9 +203,16 @@ def __init__(
         if oov_token is None:
             raise ValueError("`oov_token` cannot be None.")
 
+        if split_pattern is None:
+            split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
+
+        if keep_pattern is None:
+            keep_pattern = PUNCTUATION_REGEX
+
         self.sequence_length = sequence_length
         self.lowercase = lowercase
         self.strip_accents = strip_accents
+        self.split = split
         self.split_pattern = split_pattern
         self.keep_pattern = keep_pattern
         self.suffix_indicator = suffix_indicator
@@ -257,6 +266,7 @@ def get_config(self) -> Dict[str, Any]:
                 "sequence_length": self.sequence_length,
                 "lowercase": self.lowercase,
                 "strip_accents": self.strip_accents,
+                "split": self.split,
                 "split_pattern": self.split_pattern,
                 "keep_pattern": self.keep_pattern,
                 "suffix_indicator": self.suffix_indicator,
@@ -280,7 +290,7 @@ def tokenize(self, inputs):
             inputs = tf_text.normalize_utf8(inputs, "NFD")
             # Remove the accent marks.
             inputs = tf.strings.regex_replace(inputs, r"\p{Mn}", "")
-        if self.split_pattern:
+        if self.split:
             inputs = tf_text.regex_split(
                 inputs,
                 delim_regex_pattern=self.split_pattern,
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_test.py b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
index c7430ae380..fc3e6237e4 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer_test.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -135,9 +135,7 @@ def test_custom_spliting(self):
 
     def test_no_spliting(self):
         input_data = ["t o k e n", "m i s s i n g", "t o k e n"]
         vocab_data = ["[UNK]", "t o k e n"]
-        tokenizer = WordPieceTokenizer(
-            vocabulary=vocab_data, split_pattern=None
-        )
+        tokenizer = WordPieceTokenizer(vocabulary=vocab_data, split=False)
         call_output = tokenizer(input_data)
         self.assertAllEqual(call_output, [1, 0, 1])
@@ -148,7 +146,7 @@ def test_word_piece_only(self):
             vocabulary=vocab_data,
             lowercase=False,
             strip_accents=False,
-            split_pattern=None,
+            split=False,
         )
         call_output = tokenizer(input_data)
         self.assertAllEqual(call_output, [1, 2, 3, 4, 5, 6])
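For reviewers, a minimal usage sketch of the new flag. The first vocabulary is illustrative; the second snippet mirrors `test_no_spliting` above; the import path is assumed from the test module's location.

from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer

# Default behavior (`split=True`): split on whitespace and punctuation
# (keeping punctuation as tokens), then apply WordPiece.
vocab = ["[UNK]", "the", "qu", "##ick", "fox", "."]  # illustrative vocabulary
tokenizer = WordPieceTokenizer(vocabulary=vocab)
tokenizer(["the quick fox."])  # -> ids for ["the", "qu", "##ick", "fox", "."]

# Pre-split input: what was previously `split_pattern=None` is now `split=False`.
pre_split = ["t o k e n", "m i s s i n g", "t o k e n"]
tokenizer = WordPieceTokenizer(
    vocabulary=["[UNK]", "t o k e n"],
    split=False,
)
tokenizer(pre_split)  # -> [1, 0, 1], as asserted in `test_no_spliting`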