diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index dda0965b44..d9c9360c5f 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -83,6 +83,7 @@
 )
 from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
 from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer
 from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone
 from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
 from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import (
diff --git a/keras_nlp/models/t5/t5_tokenizer.py b/keras_nlp/models/t5/t5_tokenizer.py
new file mode 100644
index 0000000000..d6c058c577
--- /dev/null
+++ b/keras_nlp/models/t5/t5_tokenizer.py
@@ -0,0 +1,76 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""T5 tokenizer."""
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
+
+
+@keras_nlp_export("keras_nlp.models.T5Tokenizer")
+class T5Tokenizer(SentencePieceTokenizer):
+    """T5 tokenizer layer based on SentencePiece.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the
+    underlying tokenizer, it will check for all special tokens needed by
+    T5 models and provides a `from_preset()` method to automatically
+    download a matching vocabulary for a T5 preset.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        proto: Either a `string` path to a SentencePiece proto file, or a
+            `bytes` object with a serialized SentencePiece proto. See the
+            [SentencePiece repository](https://github.com/google/sentencepiece)
+            for more details on the format.
+
+    Examples:
+
+    ```python
+    tokenizer = keras_nlp.models.T5Tokenizer(proto="model.spm")
+
+    # Batched inputs.
+    tokenizer(["the quick brown fox", "the earth is round"])
+
+    # Unbatched inputs.
+    tokenizer("the quick brown fox")
+
+    # Detokenization.
+    tokenizer.detokenize(tf.constant([[2, 14, 2231, 886, 2385, 3]]))
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        super().__init__(proto=proto, **kwargs)
+
+        # Check for necessary special tokens.
+        end_token = "</s>"
+        pad_token = "<pad>"
+        for token in [end_token, pad_token]:
+            if token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{token}'` in your "
+                    "`vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+        self.pad_token_id = self.token_to_id(pad_token)
+        self.end_token_id = self.token_to_id(end_token)
+        # T5 uses the same start token as end token, i.e., "</s>".
+        self.start_token_id = self.end_token_id
diff --git a/keras_nlp/models/t5/t5_tokenizer_test.py b/keras_nlp/models/t5/t5_tokenizer_test.py
new file mode 100644
index 0000000000..dfaa700ddc
--- /dev/null
+++ b/keras_nlp/models/t5/t5_tokenizer_test.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for T5 tokenizer."""
+
+import io
+import os
+
+import sentencepiece
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer
+
+
+class T5TokenizerTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        bytes_io = io.BytesIO()
+        vocab_data = tf.data.Dataset.from_tensor_slices(
+            ["the quick brown fox", "the earth is round"]
+        )
+        sentencepiece.SentencePieceTrainer.train(
+            sentence_iterator=vocab_data.as_numpy_iterator(),
+            model_writer=bytes_io,
+            vocab_size=11,
+            model_type="WORD",
+            bos_id=-1,
+            pad_id=0,
+            eos_id=1,
+            unk_id=2,
+            pad_piece="<pad>",
+            eos_piece="</s>",
+            unk_piece="<unk>",
+            user_defined_symbols="[MASK]",
+        )
+        self.proto = bytes_io.getvalue()
+
+        self.tokenizer = T5Tokenizer(proto=self.proto)
+
+    def test_tokenize(self):
+        input_data = "the quick brown fox"
+        output = self.tokenizer(input_data)
+        self.assertAllEqual(output, [4, 9, 5, 7])
+
+    def test_tokenize_batch(self):
+        input_data = tf.constant(["the quick brown fox", "the earth is round"])
+        output = self.tokenizer(input_data)
+        self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]])
+
+    def test_detokenize(self):
+        input_data = tf.constant([[4, 9, 5, 7]])
+        output = self.tokenizer.detokenize(input_data)
+        self.assertEqual(output, tf.constant(["the quick brown fox"]))
+
+    def test_vocabulary_size(self):
+        tokenizer = T5Tokenizer(proto=self.proto)
+        self.assertEqual(tokenizer.vocabulary_size(), 11)
+
+    def test_errors_missing_special_tokens(self):
+        bytes_io = io.BytesIO()
+        sentencepiece.SentencePieceTrainer.train(
+            sentence_iterator=iter(["abc"]),
+            model_writer=bytes_io,
+            vocab_size=5,
+            pad_id=-1,
+            eos_id=-1,
+            bos_id=-1,
+        )
+        with self.assertRaises(ValueError):
+            T5Tokenizer(proto=bytes_io.getvalue())
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        input_data = tf.constant(["the quick brown fox"])
+
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = self.tokenizer(inputs)
+        model = keras.Model(inputs, outputs)
+
+        path = os.path.join(self.get_temp_dir(), filename)
+        model.save(path, save_format=save_format)
+
+        restored_model = keras.models.load_model(path)
+        self.assertAllEqual(
+            model(input_data),
+            restored_model(input_data),
+        )
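
For reviewers, here is a rough end-to-end sketch of how the new tokenizer is exercised, mirroring the test fixture above. The toy corpus, vocabulary size, and printed outputs are illustrative, not part of the patch; it assumes `sentencepiece`, `tensorflow`, and `keras_nlp` (with this PR applied) are installed.

```python
import io

import sentencepiece

import keras_nlp

# Train a tiny SentencePiece model in memory. The corpus and trainer
# options here are illustrative, copied from the test fixture above so
# that the required <pad> and </s> pieces are present.
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(["the quick brown fox", "the earth is round"]),
    model_writer=bytes_io,
    vocab_size=11,
    model_type="WORD",
    bos_id=-1,
    pad_id=0,
    eos_id=1,
    unk_id=2,
    pad_piece="<pad>",
    eos_piece="</s>",
    unk_piece="<unk>",
    user_defined_symbols="[MASK]",
)

# Instantiate from the serialized proto bytes (a file path also works).
tokenizer = keras_nlp.models.T5Tokenizer(proto=bytes_io.getvalue())

# Tokenize a batch and round-trip it back to strings.
ids = tokenizer(["the quick brown fox"])
print(ids)
print(tokenizer.detokenize(ids))

# The special token ids resolved in __init__; start and end are shared.
print(tokenizer.pad_token_id, tokenizer.end_token_id, tokenizer.start_token_id)
```

If the trained vocabulary were missing `<pad>` or `</s>` (as in `test_errors_missing_special_tokens`), the constructor would raise a `ValueError` rather than silently producing a tokenizer with unusable special token ids.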