From d9ad1ac6d8834be1ca2d98dd549c7719c95af28d Mon Sep 17 00:00:00 2001
From: Susnato Dhar
Date: Mon, 20 Mar 2023 14:01:46 +0530
Subject: [PATCH 1/2] modified tests for whisper

---
 .../models/whisper/whisper_backbone_test.py | 95 ++++++++-----------
 1 file changed, 40 insertions(+), 55 deletions(-)

diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py
index df4f25742d..0dc80ba1b1 100644
--- a/keras_nlp/models/whisper/whisper_backbone_test.py
+++ b/keras_nlp/models/whisper/whisper_backbone_test.py
@@ -26,33 +26,19 @@ class WhisperBackboneTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
-        self.model = WhisperBackbone(
-            vocabulary_size=1000,
+        self.backbone = WhisperBackbone(
+            vocabulary_size=10,
             num_layers=2,
             num_heads=2,
-            hidden_dim=64,
-            intermediate_dim=128,
-            max_encoder_sequence_length=128,
-            max_decoder_sequence_length=96,
+            hidden_dim=2,
+            intermediate_dim=4,
+            max_encoder_sequence_length=6,
+            max_decoder_sequence_length=6,
         )
-        self.batch_size = 8
         self.input_batch = {
-            "encoder_features": tf.ones(
-                (
-                    self.batch_size,
-                    self.model.max_encoder_sequence_length,
-                    NUM_MELS,
-                ),
-                dtype="int32",
-            ),
-            "decoder_token_ids": tf.ones(
-                (self.batch_size, self.model.max_decoder_sequence_length),
-                dtype="int32",
-            ),
-            "decoder_padding_mask": tf.ones(
-                (self.batch_size, self.model.max_decoder_sequence_length),
-                dtype="int32",
-            ),
+            "encoder_features": tf.ones((2, 5, NUM_MELS), dtype="int32"),
+            "decoder_token_ids": tf.ones((2, 5), dtype="int32"),
+            "decoder_padding_mask": tf.ones((2, 5), dtype="int32"),
         }
 
         self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -60,55 +46,53 @@ def setUp(self):
         ).batch(2)
 
     def test_valid_call_whisper(self):
-        self.model(self.input_batch)
+        self.backbone(self.input_batch)
+
+    def test_token_embedding(self):
+        output = self.backbone.token_embedding(
+            self.input_batch["decoder_token_ids"]
+        )
+        self.assertEqual(output.shape, (2, 5, 2))
+
     def test_name(self):
         # Check default name passed through
-        self.assertRegexpMatches(self.model.name, "whisper_backbone")
+        self.assertRegexpMatches(self.backbone.name, "whisper_backbone")
 
     def test_variable_sequence_length_call_whisper(self):
-        for seq_length in (25, 50, 75):
+        for seq_length in (2, 3, 4):
             input_data = {
                 "encoder_features": tf.ones(
-                    (self.batch_size, seq_length, NUM_MELS),
-                    dtype="int32",
-                ),
-                "decoder_token_ids": tf.ones(
-                    (self.batch_size, seq_length), dtype="int32"
-                ),
-                "decoder_padding_mask": tf.ones(
-                    (self.batch_size, seq_length), dtype="int32"
+                    (2, seq_length, NUM_MELS), dtype="int32"
                 ),
+                "decoder_token_ids": tf.ones((2, seq_length), dtype="int32"),
+                "decoder_padding_mask": tf.ones((2, seq_length), dtype="int32"),
             }
-            self.model(input_data)
+            self.backbone(input_data)
 
-    @parameterized.named_parameters(
-        ("jit_compile_false", False), ("jit_compile_true", True)
-    )
-    def test_compile(self, jit_compile):
-        self.model.compile(jit_compile=jit_compile)
-        self.model.predict(self.input_batch)
+    def test_predict(self):
+        self.backbone.predict(self.input_batch)
+        self.backbone.predict(self.input_dataset)
 
-    @parameterized.named_parameters(
-        ("jit_compile_false", False), ("jit_compile_true", True)
-    )
-    def test_compile_batched_ds(self, jit_compile):
-        self.model.compile(jit_compile=jit_compile)
-        self.model.predict(self.input_dataset)
+    def test_serialization(self):
+        new_backbone = keras.utils.deserialize_keras_object(
+            keras.utils.serialize_keras_object(self.backbone)
+        )
+        self.assertEqual(new_backbone.get_config(), self.backbone.get_config())
 
     def test_key_projection_bias_absence(self):
         # Check only for the first encoder layer and first decoder layer.
         self.assertIsNone(
-            self.model.get_layer(
+            self.backbone.get_layer(
                 "transformer_encoder_layer_0"
             )._self_attention_layer._key_dense.bias
         )
         self.assertIsNone(
-            self.model.get_layer(
+            self.backbone.get_layer(
                 "transformer_decoder_layer_0"
            )._self_attention_layer._key_dense.bias
         )
         self.assertIsNone(
-            self.model.get_layer(
+            self.backbone.get_layer(
                 "transformer_decoder_layer_0"
             )._cross_attention_layer._key_dense.bias
         )
@@ -117,10 +101,11 @@ def test_key_projection_bias_absence(self):
         ("tf_format", "tf", "model"),
         ("keras_format", "keras_v3", "model.keras"),
     )
+    @pytest.mark.large  # Saving is slow, so mark these large.
     def test_saved_model(self, save_format, filename):
-        model_output = self.model(self.input_batch)
+        model_output = self.backbone(self.input_batch)
         save_path = os.path.join(self.get_temp_dir(), filename)
-        self.model.save(save_path, save_format=save_format)
+        self.backbone.save(save_path, save_format=save_format)
         restored_model = keras.models.load_model(save_path)
 
         # Check we got the real object back.
@@ -143,7 +128,7 @@ def test_saved_model(self, save_format, filename):
 class WhisperBackboneTPUTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
         with self.tpu_strategy.scope():
-            self.model = WhisperBackbone(
+            self.backbone = WhisperBackbone(
                 vocabulary_size=1000,
                 num_layers=2,
                 num_heads=2,
@@ -175,5 +160,5 @@ def setUp(self):
         ).batch(2)
 
     def test_predict(self):
-        self.model.compile()
-        self.model.predict(self.input_dataset)
+        self.backbone.compile()
+        self.backbone.predict(self.input_dataset)

From 381d6f859367962982e10bfabba1a86105a0942b Mon Sep 17 00:00:00 2001
From: Susnato Dhar
Date: Tue, 21 Mar 2023 19:11:49 +0530
Subject: [PATCH 2/2] comments

---
 keras_nlp/models/whisper/whisper_backbone_test.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py
index 0dc80ba1b1..b71bbfdaf0 100644
--- a/keras_nlp/models/whisper/whisper_backbone_test.py
+++ b/keras_nlp/models/whisper/whisper_backbone_test.py
@@ -129,13 +129,13 @@ class WhisperBackboneTPUTest(tf.test.TestCase, parameterized.TestCase):
     def setUp(self):
         with self.tpu_strategy.scope():
             self.backbone = WhisperBackbone(
-                vocabulary_size=1000,
+                vocabulary_size=10,
                 num_layers=2,
                 num_heads=2,
-                hidden_dim=64,
-                intermediate_dim=128,
-                max_encoder_sequence_length=128,
-                max_decoder_sequence_length=64,
+                hidden_dim=2,
+                intermediate_dim=4,
+                max_encoder_sequence_length=6,
+                max_decoder_sequence_length=6,
             )

             self.input_batch = {
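
For anyone who wants to poke at the change locally, below is a minimal standalone sketch (not part of the patch) of what the reworked tests exercise: the tiny-config forward pass and the new serialization round trip. It assumes WhisperBackbone is importable from keras_nlp.models.whisper.whisper_backbone and uses 80 mel bins as a stand-in for the NUM_MELS constant the test module imports.

import tensorflow as tf
from tensorflow import keras

# Assumed import path; the patched test file lives next to this module.
from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone

NUM_MELS = 80  # Assumption: stand-in for the constant imported by the test file.

# Tiny config mirroring the patched setUp(), so the suite stays fast.
backbone = WhisperBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=2,
    intermediate_dim=4,
    max_encoder_sequence_length=6,
    max_decoder_sequence_length=6,
)

inputs = {
    "encoder_features": tf.ones((2, 5, NUM_MELS)),
    "decoder_token_ids": tf.ones((2, 5), dtype="int32"),
    "decoder_padding_mask": tf.ones((2, 5), dtype="int32"),
}
outputs = backbone(inputs)  # forward pass, as in test_valid_call_whisper

# Config-preserving serialization round trip, as in the new test_serialization.
restored = keras.utils.deserialize_keras_object(
    keras.utils.serialize_keras_object(backbone)
)
assert restored.get_config() == backbone.get_config()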