1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -83,6 +83,7 @@
)
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer
from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import (
76 changes: 76 additions & 0 deletions keras_nlp/models/t5/t5_tokenizer.py
@@ -0,0 +1,76 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 tokenizer."""

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer


@keras_nlp_export("keras_nlp.models.T5Tokenizer")
class T5Tokenizer(SentencePieceTokenizer):
"""T5 tokenizer layer based on SentencePiece.

This tokenizer class will tokenize raw strings into integer sequences and
is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the
underlying tokenizer, it checks for all special tokens needed by
T5 models and provides a `from_preset()` method to automatically
download a matching vocabulary for a T5 preset.

If input is a batch of strings (rank > 0), the layer will output a
`tf.RaggedTensor` where the last dimension of the output is ragged.

If input is a scalar string (rank == 0), the layer will output a dense
`tf.Tensor` with static shape `[None]`.

Args:
proto: Either a `string` path to a SentencePiece proto file, or a
`bytes` object with a serialized SentencePiece proto. See the
[SentencePiece repository](https://github.com/google/sentencepiece)
for more details on the format.

Examples:

```python
tokenizer = keras_nlp.models.T5Tokenizer(proto="model.spm")

# Batched inputs.
tokenizer(["the quick brown fox", "the earth is round"])

# Unbatched inputs.
tokenizer("the quick brown fox")

# Detokenization.
tokenizer.detokenize(tf.constant([[2, 14, 2231, 886, 2385, 3]]))
```
"""

def __init__(self, proto, **kwargs):
super().__init__(proto=proto, **kwargs)

# Check for necessary special tokens.
end_token = "</s>"
pad_token = "<pad>"
for token in [end_token, pad_token]:
if token not in self.get_vocabulary():
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.pad_token_id = self.token_to_id(pad_token)
self.end_token_id = self.token_to_id(end_token)
# T5 uses the same start token as end token, i.e., "</s>".
self.start_token_id = self.end_token_id
102 changes: 102 additions & 0 deletions keras_nlp/models/t5/t5_tokenizer_test.py
@@ -0,0 +1,102 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for T5 tokenizer."""

import io
import os

import sentencepiece
import tensorflow as tf
from absl.testing import parameterized
from tensorflow import keras

from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer


class T5TokenizerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
sentence_iterator=vocab_data.as_numpy_iterator(),
model_writer=bytes_io,
vocab_size=11,
model_type="WORD",
bos_id=-1,
pad_id=0,
eos_id=1,
unk_id=2,
pad_piece="<pad>",
eos_piece="</s>",
unk_piece="<unk>",
user_defined_symbols="[MASK]",
)
self.proto = bytes_io.getvalue()

self.tokenizer = T5Tokenizer(proto=self.proto)

def test_tokenize(self):
input_data = "the quick brown fox"
output = self.tokenizer(input_data)
self.assertAllEqual(output, [4, 9, 5, 7])

def test_tokenize_batch(self):
input_data = tf.constant(["the quick brown fox", "the earth is round"])
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]])

def test_detokenize(self):
input_data = tf.constant([[4, 9, 5, 7]])
output = self.tokenizer.detokenize(input_data)
self.assertEqual(output, tf.constant(["the quick brown fox"]))

def test_vocabulary_size(self):
tokenizer = T5Tokenizer(proto=self.proto)
self.assertEqual(tokenizer.vocabulary_size(), 11)

def test_errors_missing_special_tokens(self):
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
sentence_iterator=iter(["abc"]),
model_writer=bytes_io,
vocab_size=5,
pad_id=-1,
eos_id=-1,
bos_id=-1,
)
with self.assertRaises(ValueError):
T5Tokenizer(proto=bytes_io.getvalue())

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
)
def test_saved_model(self, save_format, filename):
input_data = tf.constant(["the quick brown fox"])

inputs = keras.Input(dtype="string", shape=())
outputs = self.tokenizer(inputs)
model = keras.Model(inputs, outputs)

path = os.path.join(self.get_temp_dir(), filename)
model.save(path, save_format=save_format)

restored_model = keras.models.load_model(path)
self.assertAllEqual(
model(input_data),
restored_model(input_data),
)
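
For reference, the pieces above combine into a short end-to-end sketch: train a toy SentencePiece vocabulary in memory and hand the serialized proto to `T5Tokenizer`. This mirrors the test fixture, so the corpus, vocabulary size, and special-token settings are illustrative rather than from a real T5 release.

```python
import io

import sentencepiece
import tensorflow as tf

from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer

# Train a tiny word-level SentencePiece model in memory. A real T5
# checkpoint ships its own .spm proto; these settings mirror the test
# fixture above and are for illustration only.
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(["the quick brown fox", "the earth is round"]),
    model_writer=bytes_io,
    vocab_size=11,
    model_type="WORD",
    bos_id=-1,
    pad_id=0,
    eos_id=1,
    unk_id=2,
    pad_piece="<pad>",
    eos_piece="</s>",
    unk_piece="<unk>",
    user_defined_symbols="[MASK]",
)

tokenizer = T5Tokenizer(proto=bytes_io.getvalue())

# Scalar input -> dense tf.Tensor; batched input -> tf.RaggedTensor.
print(tokenizer("the quick brown fox"))
print(tokenizer(tf.constant(["the quick brown fox", "the earth is round"])))

# Round-trip token ids back to text.
ids = tokenizer(tf.constant(["the quick brown fox"]))
print(tokenizer.detokenize(ids))
```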