diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py
index de1d1bd741..53c283959e 100644
--- a/keras_nlp/models/t5/t5_backbone_test.py
+++ b/keras_nlp/models/t5/t5_backbone_test.py
@@ -25,7 +25,7 @@
 
 class T5Test(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
-        self.model = T5Backbone(
+        self.backbone = T5Backbone(
             vocabulary_size=4,
             num_layers=2,
             num_heads=2,
@@ -53,12 +53,20 @@ def setUp(self):
         ).batch(2)
 
     def test_valid_call_t5(self):
-        self.model(self.input_batch)
+        self.backbone(self.input_batch)
+
+    def test_token_embedding(self):
+        output = self.backbone.token_embedding(
+            self.input_batch["encoder_token_ids"]
+        )
+        self.assertEqual(output.shape, (2, 3, 4))
+
+    def test_name(self):
         # Check default name passed through
-        self.assertRegexpMatches(self.model.name, "t5_backbone")
+        self.assertRegexpMatches(self.backbone.name, "t5_backbone")
 
     def test_variable_sequence_length_call_t5(self):
-        for seq_length in (4, 5, 6):
+        for seq_length in (2, 3, 4):
             input_data = {
                 "encoder_token_ids": tf.ones(
                     (self.batch_size, seq_length), dtype="int32"
@@ -73,36 +81,29 @@ def test_variable_sequence_length_call_t5(self):
                     (self.batch_size, seq_length), dtype="int32"
                 ),
             }
-            outputs = self.model(input_data)
+            outputs = self.backbone(input_data)
             self.assertIn("encoder_sequence_output", outputs)
             self.assertIn("decoder_sequence_output", outputs)
 
-    @parameterized.named_parameters(
-        ("jit_compile_false", False), ("jit_compile_true", True)
-    )
-    def test_t5_compile(self, jit_compile):
-        self.model.compile(jit_compile=jit_compile)
-        outputs = self.model.predict(self.input_batch)
-        self.assertIn("encoder_sequence_output", outputs)
-        self.assertIn("decoder_sequence_output", outputs)
+    def test_predict(self):
+        self.backbone.predict(self.input_batch)
+        self.backbone.predict(self.input_dataset)
 
-    @parameterized.named_parameters(
-        ("jit_compile_false", False), ("jit_compile_true", True)
-    )
-    def test_t5_compile_batched_ds(self, jit_compile):
-        self.model.compile(jit_compile=jit_compile)
-        outputs = self.model.predict(self.input_dataset)
-        self.assertIn("encoder_sequence_output", outputs)
-        self.assertIn("decoder_sequence_output", outputs)
+    def test_serialization(self):
+        new_backbone = keras.utils.deserialize_keras_object(
+            keras.utils.serialize_keras_object(self.backbone)
+        )
+        self.assertEqual(new_backbone.get_config(), self.backbone.get_config())
 
     @parameterized.named_parameters(
         ("tf_format", "tf", "model"),
         ("keras_format", "keras_v3", "model.keras"),
     )
+    @pytest.mark.large  # Saving is slow, so mark these large.
     def test_saved_model(self, save_format, filename):
-        outputs = self.model(self.input_batch)
+        outputs = self.backbone(self.input_batch)
         save_path = os.path.join(self.get_temp_dir(), filename)
-        self.model.save(save_path, save_format=save_format)
+        self.backbone.save(save_path, save_format=save_format)
         restored_model = keras.models.load_model(save_path)
 
         # Check we got the real object back.
@@ -119,7 +120,7 @@ def test_saved_model(self, save_format, filename):
 class T5BackboneTPUTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
         with self.tpu_strategy.scope():
-            self.model = T5Backbone(
+            self.backbone = T5Backbone(
                 vocabulary_size=4,
                 num_layers=2,
                 num_heads=2,
@@ -136,6 +137,6 @@ def setUp(self):
 
     def test_predict(self):
-        self.model.compile()
-        outputs = self.model.predict(self.input_dataset)
+        self.backbone.compile()
+        outputs = self.backbone.predict(self.input_dataset)
         self.assertIn("encoder_sequence_output", outputs)
         self.assertIn("decoder_sequence_output", outputs)
diff --git a/keras_nlp/models/t5/t5_tokenizer_test.py b/keras_nlp/models/t5/t5_tokenizer_test.py
index dfaa700ddc..a4b62b75ac 100644
--- a/keras_nlp/models/t5/t5_tokenizer_test.py
+++ b/keras_nlp/models/t5/t5_tokenizer_test.py
@@ -17,6 +17,7 @@
 import io
 import os
 
+import pytest
 import sentencepiece
 import tensorflow as tf
 from absl.testing import parameterized
@@ -81,10 +82,19 @@ def test_errors_missing_special_tokens(self):
         with self.assertRaises(ValueError):
             T5Tokenizer(proto=bytes_io.getvalue())
 
+    def test_serialization(self):
+        config = keras.utils.serialize_keras_object(self.tokenizer)
+        new_tokenizer = keras.utils.deserialize_keras_object(config)
+        self.assertEqual(
+            new_tokenizer.get_config(),
+            self.tokenizer.get_config(),
+        )
+
     @parameterized.named_parameters(
         ("tf_format", "tf", "model"),
         ("keras_format", "keras_v3", "model.keras"),
     )
+    @pytest.mark.large  # Saving is slow, so mark these large.
     def test_saved_model(self, save_format, filename):
         input_data = tf.constant(["the quick brown fox"])