diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py index 936773ce05..931410a5be 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py @@ -93,3 +93,30 @@ def test_saved_model(self, save_format, filename): # Check that output matches. restored_output = restored_model(self.input_batch) self.assertAllClose(model_output, restored_output) + + +@pytest.mark.tpu +@pytest.mark.usefixtures("tpu_test_class") +class DebertaV3BackboneTPUTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + with self.tpu_strategy.scope(): + self.backbone = DebertaV3Backbone( + vocabulary_size=10, + num_layers=2, + num_heads=2, + hidden_dim=2, + intermediate_dim=4, + max_sequence_length=5, + bucket_size=2, + ) + self.input_batch = { + "token_ids": tf.ones((2, 5), dtype="int32"), + "padding_mask": tf.ones((2, 5), dtype="int32"), + } + self.input_dataset = tf.data.Dataset.from_tensor_slices( + self.input_batch + ).batch(2) + + def test_predict(self): + self.backbone.compile() + self.backbone.predict(self.input_dataset)