diff --git a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
index 14a86f8968..19c1240596 100644
--- a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
+++ b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py
@@ -67,6 +67,9 @@ class SentencePieceTokenizer(tokenizer.Tokenizer):
             for more details on the format.
         sequence_length: If set, the output will be converted to a dense
             tensor and padded/trimmed so all outputs are of `sequence_length`.
+        add_bos: Add beginning of sentence token to the result.
+        add_eos: Add end of sentence token to the result. Token is always
+            truncated if output is longer than specified `sequence_length`.
 
     References:
      - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226)
@@ -116,6 +119,8 @@ def __init__(
         proto=None,
         sequence_length=None,
         dtype="int32",
+        add_bos=False,
+        add_eos=False,
         **kwargs,
     ) -> None:
         if not is_int_dtype(dtype) and not is_string_dtype(dtype):
@@ -128,6 +133,8 @@
 
         self.proto = None
         self.sequence_length = sequence_length
+        self.add_bos = add_bos
+        self.add_eos = add_eos
         self.set_proto(proto)
         self.file_assets = [VOCAB_FILENAME]
 
@@ -172,6 +179,8 @@ def set_proto(self, proto):
         self._sentence_piece = tf_text.SentencepieceTokenizer(
             model=proto_bytes,
             out_type=self.compute_dtype,
+            add_bos=self.add_bos,
+            add_eos=self.add_eos,
         )
         # Keras cannot serialize a bytestring, so we base64 encode the model
         # byte array as a string for saving.
@@ -212,6 +221,8 @@ def get_config(self):
             {
                 "proto": None,  # Save vocabulary via an asset!
                 "sequence_length": self.sequence_length,
+                "add_bos": self.add_bos,
+                "add_eos": self.add_eos,
             }
         )
         return config
diff --git a/keras_nlp/src/tokenizers/sentence_piece_tokenizer_test.py b/keras_nlp/src/tokenizers/sentence_piece_tokenizer_test.py
index e4d09f44b6..16a76f810c 100644
--- a/keras_nlp/src/tokenizers/sentence_piece_tokenizer_test.py
+++ b/keras_nlp/src/tokenizers/sentence_piece_tokenizer_test.py
@@ -70,6 +70,29 @@ def test_string_tokenize(self):
             [["▁the", "▁quick", "▁brown", "▁fox."]],
         )
 
+    def test_scalar_bos_eos(self):
+        input_data = "the quick brown fox."
+        tokenizer = SentencePieceTokenizer(
+            proto=self.proto,
+            add_bos=True,
+            add_eos=True,
+        )
+        output_data = tokenizer(input_data)
+        self.assertAllEqual(output_data, [1, 6, 5, 3, 4, 2])
+
+    def test_string_bos_eos(self):
+        input_data = ["the quick brown fox."]
+        tokenizer = SentencePieceTokenizer(
+            proto=self.proto,
+            dtype="string",
+            add_bos=True,
+            add_eos=True,
+        )
+        output_data = tokenizer(input_data)
+        self.assertAllEqual(
+            output_data, [["", "▁the", "▁quick", "▁brown", "▁fox.", ""]]
+        )
+
     def test_detokenize(self):
         tokenizer = SentencePieceTokenizer(proto=self.proto)
         outputs = tokenizer.detokenize([6, 5, 3, 4])
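For context, here is a minimal usage sketch of the new flags. It is not part of the patch: the import path, the `vocab.spm` filename, and the BOS/EOS token ids are assumptions that depend on how keras-nlp exposes the tokenizer and on the SentencePiece model in use.

```python
# Sketch of the new `add_bos` / `add_eos` options. Assumes a SentencePiece
# model file "vocab.spm" (hypothetical) trained with BOS/EOS pieces enabled;
# the exact token ids depend on that model's vocabulary.
from keras_nlp.tokenizers import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(
    proto="vocab.spm",  # path to a serialized SentencePiece proto
    add_bos=True,       # prepend the model's BOS id to every output
    add_eos=True,       # append the model's EOS id to every output
)

token_ids = tokenizer("the quick brown fox.")
# token_ids now starts with the BOS id and ends with the EOS id. Per the
# docstring change above, if `sequence_length` is set and the tokenized
# text overflows it, the EOS token is truncated along with the overflow.
```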