diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 28ee6ef64b..d2b9e6d446 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -102,7 +102,7 @@ jobs: env: KERAS_BACKEND: ${{ matrix.backend }} run: | - pytest --run_large keras_nlp/layers/modeling keras_nlp/samplers keras_nlp/tokenizers keras_nlp/metrics + pytest keras_nlp/ format: name: Check the code format runs-on: ubuntu-latest diff --git a/keras_nlp/conftest.py b/keras_nlp/conftest.py index 5f74c45c53..3a29b038c5 100644 --- a/keras_nlp/conftest.py +++ b/keras_nlp/conftest.py @@ -122,7 +122,12 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_tf_only) +# Disable traceback filtering for quicker debugging of test failures. +tf.debugging.disable_traceback_filtering() if backend_config.multi_backend(): keras.config.disable_traceback_filtering() -tf.debugging.disable_traceback_filtering() +# One-off setup for dtensor tests. +if not backend_config.multi_backend(): + keras.backend.experimental.enable_tf_random_generator() + keras.utils.set_random_seed(1337) diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py index 201c31a707..ffa62bbbef 100644 --- a/keras_nlp/metrics/rouge_l.py +++ b/keras_nlp/metrics/rouge_l.py @@ -102,14 +102,13 @@ class RougeL(RougeBase): 3. Pass the metric to `model.compile()`. >>> inputs = keras.Input(shape=(), dtype='string') - >>> outputs = tf.strings.lower(inputs) + >>> outputs = keras.layers.Identity()(inputs) >>> model = keras.Model(inputs, outputs) >>> model.compile(metrics=[keras_nlp.metrics.RougeL()]) - >>> x = tf.constant(["HELLO THIS IS FUN"]) + >>> y_pred = x = tf.constant(["hello this is fun"]) >>> y = tf.constant(["hello this is awesome"]) - >>> metric_dict = model.evaluate(x, y, return_dict=True) - >>> metric_dict["f1_score"] - 0.75 + >>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"] + 0.75 """ def __init__( diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py index 46a5f528c9..d90122fa8c 100644 --- a/keras_nlp/metrics/rouge_n.py +++ b/keras_nlp/metrics/rouge_n.py @@ -121,13 +121,12 @@ class RougeN(RougeBase): 3. Pass the metric to `model.compile()`.
>>> inputs = keras.Input(shape=(), dtype='string') - >>> outputs = tf.strings.lower(inputs) + >>> outputs = keras.layers.Identity()(inputs) >>> model = keras.Model(inputs, outputs) >>> model.compile(metrics=[keras_nlp.metrics.RougeN()]) - >>> x = tf.constant(["HELLO THIS IS FUN"]) + >>> y_pred = x = tf.constant(["hello this is fun"]) >>> y = tf.constant(["hello this is awesome"]) - >>> metric_dict = model.evaluate(x, y, return_dict=True) - >>> metric_dict["f1_score"] + >>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"] 0.6666666865348816 """ diff --git a/keras_nlp/models/albert/albert_backbone_test.py b/keras_nlp/models/albert/albert_backbone_test.py index 94d0f26297..a6bfdc18e1 100644 --- a/keras_nlp/models/albert/albert_backbone_test.py +++ b/keras_nlp/models/albert/albert_backbone_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.albert.albert_backbone import AlbertBackbone from keras_nlp.tests.test_case import TestCase @@ -38,9 +39,9 @@ def setUp(self): ) self.batch_size = 8 self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "segment_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "segment_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -57,9 +58,9 @@ def test_name(self): def test_variable_sequence_length_call_albert(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "segment_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "segment_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -121,9 +122,9 @@ def setUp(self): ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "segment_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "segment_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 53a51b0ac5..fe2bd7c7ff 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -22,7 +22,6 @@ from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -192,7 +191,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py index f22c921a50..0dd09d32d9 100644 --- a/keras_nlp/models/albert/albert_classifier_test.py +++ b/keras_nlp/models/albert/albert_classifier_test.py @@ -21,6 +21,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from 
keras_nlp.backend import ops from keras_nlp.models.albert.albert_backbone import AlbertBackbone from keras_nlp.models.albert.albert_classifier import AlbertClassifier from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor @@ -77,15 +78,13 @@ def setUp(self): activation=keras.activations.softmax, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -99,7 +98,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. - self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index 71bdd7721d..9843282c5a 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -26,7 +26,6 @@ ) from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -135,7 +134,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs): loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index a06cb9889b..522e7fcdda 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -152,6 +152,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 6ba68245dd..0e8869613d 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -85,14 +85,12 @@ def setUp(self): preprocessor=None, ) - self.raw_batch = tf.constant( - [ - "quick brown fox", - "eagle flew over fox", - "the eagle flew quick", - "a brown eagle", - ] - ) + self.raw_batch = [ + "quick brown fox", + "eagle flew over fox", + "the eagle flew quick", + "a brown eagle", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index 5234fb0047..14f6581b03 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -166,6 +166,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) inputs = 
keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/albert/albert_presets_test.py b/keras_nlp/models/albert/albert_presets_test.py index 535091f0c6..09904246a2 100644 --- a/keras_nlp/models/albert/albert_presets_test.py +++ b/keras_nlp/models/albert/albert_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.albert.albert_backbone import AlbertBackbone from keras_nlp.models.albert.albert_classifier import AlbertClassifier from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor @@ -53,7 +53,7 @@ def test_preprocessor_output(self): ("load_weights", True), ("no_load_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = AlbertClassifier.from_preset( "albert_base_en_uncased", num_classes=2, @@ -67,9 +67,9 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "segment_ids": tf.constant([[0, 0, 0, 0]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "segment_ids": ops.array([[0, 0, 0, 0]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = AlbertClassifier.from_preset( "albert_base_en_uncased", @@ -85,9 +85,9 @@ def test_classifier_output_without_preprocessing(self, load_weights): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[2, 13, 1, 3]]), - "segment_ids": tf.constant([[0, 0, 0, 0]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[2, 13, 1, 3]]), + "segment_ids": ops.array([[0, 0, 0, 0]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = AlbertBackbone.from_preset( "albert_base_en_uncased", load_weights=load_weights @@ -139,13 +139,11 @@ def test_load_albert(self, load_weights): preset, load_weights=load_weights ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -159,7 +157,7 @@ def test_load_albert_classifier(self, load_weights): num_classes=2, load_weights=load_weights, ) - input_data = tf.constant(["This quick brown fox"]) + input_data = ["This quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -174,15 +172,13 @@ def test_load_albert_classifier_without_preprocessing(self, load_weights): load_weights=load_weights, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git a/keras_nlp/models/albert/albert_tokenizer_test.py b/keras_nlp/models/albert/albert_tokenizer_test.py index 6e9decbc0d..010b1a46b9 100644 
--- a/keras_nlp/models/albert/albert_tokenizer_test.py +++ b/keras_nlp/models/albert/albert_tokenizer_test.py @@ -56,14 +56,14 @@ def test_tokenize(self): self.assertAllEqual(output, [5, 10, 6, 8]) def test_tokenize_batch(self): - input_data = tf.constant(["the quick brown fox", "the earth is round"]) + input_data = ["the quick brown fox", "the earth is round"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[5, 10, 6, 8], [5, 7, 9, 11]]) def test_detokenize(self): - input_data = tf.constant([[5, 10, 6, 8]]) + input_data = [[5, 10, 6, 8]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, ["the quick brown fox"]) def test_vocabulary_size(self): tokenizer = AlbertTokenizer(proto=self.proto) @@ -91,6 +91,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/bart/bart_backbone_test.py b/keras_nlp/models/bart/bart_backbone_test.py index 206facf9f6..e0782d4d65 100644 --- a/keras_nlp/models/bart/bart_backbone_test.py +++ b/keras_nlp/models/bart/bart_backbone_test.py @@ -17,86 +17,70 @@ import pytest import tensorflow as tf -from absl.testing import parameterized from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.tests.test_case import TestCase class BartBackboneTest(TestCase): def setUp(self): - self.model = BartBackbone( - vocabulary_size=1000, + self.backbone = BartBackbone( + vocabulary_size=10, num_layers=2, num_heads=2, - hidden_dim=64, - intermediate_dim=128, - max_sequence_length=128, + hidden_dim=3, + intermediate_dim=4, + max_sequence_length=5, ) - self.batch_size = 8 self.input_batch = { - "encoder_token_ids": tf.ones( - (self.batch_size, self.model.max_sequence_length), dtype="int32" - ), - "encoder_padding_mask": tf.ones( - (self.batch_size, self.model.max_sequence_length), dtype="int32" - ), - "decoder_token_ids": tf.ones( - (self.batch_size, self.model.max_sequence_length), dtype="int32" - ), - "decoder_padding_mask": tf.ones( - (self.batch_size, self.model.max_sequence_length), dtype="int32" - ), + "encoder_token_ids": ops.ones((2, 5), dtype="int32"), + "encoder_padding_mask": ops.ones((2, 5), dtype="int32"), + "decoder_token_ids": ops.ones((2, 5), dtype="int32"), + "decoder_padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch ).batch(2) - def test_valid_call_bart(self): - self.model(self.input_batch) + def test_valid_call(self): + self.backbone(self.input_batch) + def test_name(self): # Check default name passed through - self.assertRegexpMatches(self.model.name, "bart_backbone") + self.assertRegexpMatches(self.backbone.name, "bart_backbone") - def test_variable_sequence_length_call_bart(self): - for seq_length in (25, 50, 75): + def test_variable_sequence_length_call(self): + for seq_length in (2, 3, 4): input_data = { - "encoder_token_ids": tf.ones( - (self.batch_size, seq_length), dtype="int32" + "encoder_token_ids": ops.ones((2, seq_length), dtype="int32"), + "encoder_padding_mask": ops.ones( + (2, seq_length), dtype="int32" ), - "encoder_padding_mask": tf.ones( - (self.batch_size, seq_length), dtype="int32" - ), - "decoder_token_ids": tf.ones( - (self.batch_size, seq_length), dtype="int32" - ), - "decoder_padding_mask": tf.ones( - (self.batch_size, seq_length), dtype="int32" + 
"decoder_token_ids": ops.ones((2, seq_length), dtype="int32"), + "decoder_padding_mask": ops.ones( + (2, seq_length), dtype="int32" ), } - self.model(input_data) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_compile(self, jit_compile): - self.model.compile(jit_compile=jit_compile) - self.model.predict(self.input_batch) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_compile_batched_ds(self, jit_compile): - self.model.compile(jit_compile=jit_compile) - self.model.predict(self.input_dataset) + self.backbone(input_data) + + def test_predict(self): + self.backbone.predict(self.input_batch) + self.backbone.predict(self.input_dataset) + + def test_serialization(self): + new_backbone = keras.saving.deserialize_keras_object( + keras.saving.serialize_keras_object(self.backbone) + ) + self.assertEqual(new_backbone.get_config(), self.backbone.get_config()) @pytest.mark.large def test_saved_model(self): model_output = self.backbone(self.input_batch) path = os.path.join(self.get_temp_dir(), "model.keras") - self.model.save(path, save_format="keras_v3") + self.backbone.save(path, save_format="keras_v3") restored_model = keras.models.load_model(path) # Check we got the real object back. @@ -119,7 +103,7 @@ def test_saved_model(self): class BartBackboneTPUTest(TestCase): def setUp(self): with self.tpu_strategy.scope(): - self.model = BartBackbone( + self.backbone = BartBackbone( vocabulary_size=1000, num_layers=2, num_heads=2, @@ -128,15 +112,15 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "encoder_token_ids": tf.ones((8, 128), dtype="int32"), - "encoder_padding_mask": tf.ones((8, 128), dtype="int32"), - "decoder_token_ids": tf.ones((8, 128), dtype="int32"), - "decoder_padding_mask": tf.ones((8, 128), dtype="int32"), + "encoder_token_ids": ops.ones((8, 128), dtype="int32"), + "encoder_padding_mask": ops.ones((8, 128), dtype="int32"), + "decoder_token_ids": ops.ones((8, 128), dtype="int32"), + "decoder_padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch ).batch(2) def test_predict(self): - self.model.compile() - self.model.predict(self.input_dataset) + self.backbone.compile() + self.backbone.predict(self.input_dataset) diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py index 3849792284..866b2eb537 100644 --- a/keras_nlp/models/bart/bart_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_preprocessor_test.py @@ -168,6 +168,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = { "encoder_text": tf.constant([" airplane at airport"]), diff --git a/keras_nlp/models/bart/bart_presets_test.py b/keras_nlp/models/bart/bart_presets_test.py index 8c8b837289..324e840178 100644 --- a/keras_nlp/models/bart/bart_presets_test.py +++ b/keras_nlp/models/bart/bart_presets_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
# Copyright 2023 The KerasNLP Authors # +from keras_nlp.backend import ops from keras_nlp.tests.test_case import TestCase # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,7 +30,6 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized from keras_nlp.models.bart.bart_backbone import BartBackbone @@ -58,10 +58,10 @@ def test_tokenizer_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]), - "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]), - "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]), - "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]), + "encoder_token_ids": ops.array([[0, 133, 2119, 2]]), + "encoder_padding_mask": ops.array([[1, 1, 1, 1]]), + "decoder_token_ids": ops.array([[0, 7199, 14, 2119, 2]]), + "decoder_padding_mask": ops.array([[1, 1, 1, 1, 1]]), } model = BartBackbone.from_preset( "bart_base_en", load_weights=load_weights @@ -116,20 +116,20 @@ def test_load_bart(self, load_weights): for preset in BartBackbone.presets: model = BartBackbone.from_preset(preset, load_weights=load_weights) input_data = { - "encoder_token_ids": tf.random.uniform( + "encoder_token_ids": ops.random.uniform( shape=(1, 1024), dtype="int64", maxval=model.vocabulary_size, ), - "encoder_padding_mask": tf.constant( + "encoder_padding_mask": ops.array( [1] * 768 + [0] * 256, shape=(1, 1024) ), - "decoder_token_ids": tf.random.uniform( + "decoder_token_ids": ops.random.uniform( shape=(1, 1024), dtype="int64", maxval=model.vocabulary_size, ), - "decoder_padding_mask": tf.constant( + "decoder_padding_mask": ops.array( [1] * 489 + [0] * 535, shape=(1, 1024) ), } diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index 20119e67fe..50868c4810 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -15,20 +15,32 @@ import copy -import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( BartSeq2SeqLMPreprocessor, ) from keras_nlp.models.generative_task import GenerativeTask -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty +# TODO: Extend and factor this out into keras_nlp.layers. +class ReverseEmbedding(keras.layers.Layer): + def __init__(self, embedding, **kwargs): + super().__init__(**kwargs) + self.embedding = embedding + + def call(self, inputs): + kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings)) + return ops.matmul(inputs, kernel) + + def compute_output_shape(self, input_shape): + return tuple(input_shape[:-1]) + (self.embedding.embeddings.shape[0],) + + @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM") class BartSeq2SeqLM(GenerativeTask): """An end-to-end BART model for seq2seq language modeling. @@ -192,11 +204,10 @@ def __init__( x = backbone(inputs)["decoder_sequence_output"] # Use token embedding weights to project from the token representation # to vocabulary logits.
- outputs = tf.matmul( - x, - backbone.token_embedding.embeddings, - transpose_b=True, - ) + outputs = ReverseEmbedding( + backbone.token_embedding, + name="reverse_embedding", + )(x) # Instantiate using Functional API Model constructor. super().__init__( @@ -216,7 +227,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty @@ -305,11 +316,11 @@ def call_decoder_with_cache( # Every decoder layer has a separate cache for the self-attention layer # and the cross-attention layer. We update all of them separately. - self_attention_caches = tf.unstack(self_attention_cache, axis=1) - cross_attention_caches = tf.unstack(cross_attention_cache, axis=1) + self_attention_caches = [] + cross_attention_caches = [] for i in range(self.backbone.num_layers): - current_self_attention_cache = self_attention_caches[i] - current_cross_attention_cache = cross_attention_caches[i] + current_self_attention_cache = self_attention_cache[:, i, ...] + current_cross_attention_cache = cross_attention_cache[:, i, ...] ( x, @@ -326,22 +337,17 @@ def call_decoder_with_cache( ) if self_attention_cache_update_index is not None: - self_attention_caches[i] = next_self_attention_cache + self_attention_caches.append(next_self_attention_cache) if cross_attention_cache_update_index is not None: - cross_attention_caches[i] = next_cross_attention_cache + cross_attention_caches.append(next_cross_attention_cache) if self_attention_cache_update_index is not None: - self_attention_cache = tf.stack(self_attention_caches, axis=1) + self_attention_cache = ops.stack(self_attention_caches, axis=1) if cross_attention_cache_update_index is not None: - cross_attention_cache = tf.stack(cross_attention_caches, axis=1) + cross_attention_cache = ops.stack(cross_attention_caches, axis=1) hidden_states = x - - logits = tf.matmul( - hidden_states, - self.backbone.get_layer("token_embedding").embeddings, - transpose_b=True, - ) + logits = self.get_layer("reverse_embedding")(x) return ( logits, hidden_states, @@ -375,9 +381,9 @@ def call_encoder(self, token_ids, padding_mask): def _initialize_cache(self, encoder_token_ids, decoder_token_ids): """Initializes empty self-attention cache and cross-attention cache.""" - batch_size = tf.shape(encoder_token_ids)[0] - encoder_max_length = tf.shape(encoder_token_ids)[1] - decoder_max_length = tf.shape(decoder_token_ids)[1] + batch_size = ops.shape(encoder_token_ids)[0] + encoder_max_length = ops.shape(encoder_token_ids)[1] + decoder_max_length = ops.shape(decoder_token_ids)[1] num_layers = self.backbone.num_layers num_heads = self.backbone.num_heads @@ -391,10 +397,10 @@ def _initialize_cache(self, encoder_token_ids, decoder_token_ids): num_heads, head_dim, ] - self_attention_cache = tf.zeros(shape, dtype=self.compute_dtype) + self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype) shape[3] = encoder_max_length - cross_attention_cache = tf.zeros(shape, dtype=self.compute_dtype) + cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype) return (self_attention_cache, cross_attention_cache) @@ -464,7 +470,7 @@ def generate_step( inputs["decoder_padding_mask"], ) - batch_size = tf.shape(encoder_token_ids)[0] + batch_size = ops.shape(encoder_token_ids)[0] # Create and seed cache with a single forward pass. 
( @@ -476,24 +482,21 @@ encoder_token_ids, encoder_padding_mask, decoder_token_ids ) # Compute the lengths of all user inputted token ids. - row_lengths = tf.math.reduce_sum( - tf.cast(decoder_padding_mask, "int32"), axis=-1 - ) + row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1) # Start at the first index that has no user inputted id. - index = tf.math.reduce_min(row_lengths) + index = ops.min(row_lengths) def next(prompt, cache, index): # The cache index is the index of our previous token. cache_index = index - 1 - prompt = tf.slice(prompt, [0, cache_index], [-1, 1]) - - num_samples = tf.shape(prompt)[0] + num_samples = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1]) def repeat_tensor(x): """Repeats tensors along batch axis to match dim for beam search.""" - if tf.shape(x)[0] == num_samples: + if ops.shape(x)[0] == num_samples: return x - return tf.repeat(x, repeats=num_samples // batch_size, axis=0) + return ops.repeat(x, repeats=num_samples // batch_size, axis=0) logits, hidden_states, cache, _ = self.call_decoder_with_cache( encoder_hidden_states=repeat_tensor(encoder_hidden_states), @@ -505,8 +508,8 @@ def repeat_tensor(x): cross_attention_cache_update_index=None, ) return ( - tf.squeeze(logits, axis=1), - tf.squeeze(hidden_states, axis=1), + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), cache, ) @@ -527,14 +530,17 @@ def repeat_tensor(x): end_locations = (decoder_token_ids == end_token_id) & ( ~decoder_padding_mask ) - end_locations = tf.cast(end_locations, "int32") + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after `end_locations`. - overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1) + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations # Our padding mask is the inverse of these overflow locations. - decoder_padding_mask = ~tf.cast(overflow, "bool") + decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool")) else: # Without early stopping, all locations will have been updated.
- decoder_padding_mask = tf.ones_like(decoder_token_ids, dtype="bool") + decoder_padding_mask = ops.ones_like( + decoder_token_ids, dtype="bool" + ) return { "decoder_token_ids": decoder_token_ids, diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py index 034a10cd9f..d54a83984c 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py @@ -156,6 +156,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = { "encoder_text": tf.constant([" airplane at airport"]), @@ -163,8 +164,12 @@ def test_saved_model(self): } inputs = { - "encoder_text": keras.Input(dtype="string", shape=()), - "decoder_text": keras.Input(dtype="string", shape=()), + "encoder_text": keras.Input( + dtype="string", name="encoder_text", shape=() + ), + "decoder_text": keras.Input( + dtype="string", name="decoder_text", shape=() + ), } outputs, y, sw = self.preprocessor(inputs) model = keras.Model(inputs=inputs, outputs=outputs) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py index 87f448c2e4..a3b9db4bd0 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py @@ -16,11 +16,11 @@ import os from unittest.mock import patch -import numpy as np import pytest import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( @@ -32,10 +32,6 @@ class BartSeq2SeqLMTest(TestCase): def setUp(self): - # For DTensor. - keras.backend.experimental.enable_tf_random_generator() - keras.utils.set_random_seed(1337) - self.vocab = { "": 0, "": 1, @@ -75,12 +71,8 @@ def setUp(self): ) self.raw_batch = { - "encoder_text": tf.constant( - [" airplane at airport", " airplane at airport"] - ), - "decoder_text": tf.constant( - [" kohli is the best", " kohli is the best"] - ), + "encoder_text": [" airplane at airport", " airplane at airport"], + "decoder_text": [" kohli is the best", " kohli is the best"], } self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] @@ -171,8 +163,10 @@ def wrapper(*args, **kwargs): self_attention_cache, cross_attention_cache, ) = call_decoder_with_cache(*args, **kwargs) - logits = np.zeros(logits.shape.as_list()) - logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9 + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) return ( logits, hidden_states, @@ -194,10 +188,9 @@ def wrapper(*args, **kwargs): # We should immediately abort and output the prompt. self.assertAllEqual(inputs["decoder_text"], output) - self.assertEqual( - self.seq_2_seq_lm.call_decoder_with_cache.call_count, 2 - ) + # TODO: fix beam search. 
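A note on the cumsum hunk in `generate_step` above: `tf.math.cumsum(..., exclusive=True)` has no direct counterpart in the backend-agnostic `ops` namespace, so the patch derives the exclusive cumulative sum from the inclusive one via the identity `exclusive_cumsum(x) == cumsum(x) - x`. A minimal standalone sketch of that arithmetic, using NumPy stand-ins and illustrative toy values (token id 2 plays the role of `end_token_id`):

import numpy as np

decoder_token_ids = np.array([[5, 7, 2, 0, 0]])
# Mark the position where the end token was generated.
end_locations = (decoder_token_ids == 2).astype("int32")  # [[0 0 1 0 0]]
# Inclusive cumsum, then subtract x itself to recover the exclusive cumsum:
# ones appear strictly after the end token.
cumsum = np.cumsum(end_locations, axis=-1)                # [[0 0 1 1 1]]
overflow = cumsum - end_locations                         # [[0 0 0 1 1]]
# The padding mask keeps everything up to and including the end token.
decoder_padding_mask = ~overflow.astype(bool)             # [[T T T F F]]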
+ @pytest.mark.tf_only def test_beam_search(self): seq_2_seq_lm = BartSeq2SeqLM( backbone=self.backbone, diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py index f93849b898..3f750dd084 100644 --- a/keras_nlp/models/bart/bart_tokenizer_test.py +++ b/keras_nlp/models/bart/bart_tokenizer_test.py @@ -58,7 +58,7 @@ def test_tokenize_special_tokens(self): self.assertAllEqual(output, [0, 3, 4, 5, 3, 6, 0, 1]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]]) @@ -74,7 +74,16 @@ def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): BartTokenizer(vocabulary=["a", "b", "c"], merges=[]) + def test_serialization(self): + config = keras.saving.serialize_keras_object(self.tokenizer) + new_tokenizer = keras.saving.deserialize_keras_object(config) + self.assertEqual( + new_tokenizer.get_config(), + self.tokenizer.get_config(), + ) + @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 746343e657..056086fe28 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -168,12 +168,13 @@ def __init__( # Construct the two BERT outputs. The pooled output is a dense layer on # top of the [CLS] token. sequence_output = x - pooled_output = keras.layers.Dense( + x = keras.layers.Dense( hidden_dim, kernel_initializer=bert_kernel_initializer(), activation="tanh", name="pooled_dense", - )(x[:, cls_token_index, :]) + )(x) + pooled_output = x[:, cls_token_index, :] # Instantiate using Functional API Model constructor super().__init__( diff --git a/keras_nlp/models/bert/bert_backbone_test.py b/keras_nlp/models/bert/bert_backbone_test.py index 275220af6b..5c26fdc122 100644 --- a/keras_nlp/models/bert/bert_backbone_test.py +++ b/keras_nlp/models/bert/bert_backbone_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.tests.test_case import TestCase @@ -34,9 +35,9 @@ def setUp(self): max_sequence_length=5, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "segment_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "segment_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch @@ -56,9 +57,9 @@ def test_name(self): def test_variable_sequence_length_call_bert(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "segment_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "segment_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -84,9 +85,7 @@ def test_saved_model(self): # Check that output matches. 
restored_output = restored_model(self.input_batch) - self.assertAllClose( - model_output["pooled_output"], restored_output["pooled_output"] - ) + self.assertAllClose(model_output, restored_output) @pytest.mark.tpu @@ -103,9 +102,9 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "segment_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "segment_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 04b6f8c8d1..3dff53ed8d 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -23,7 +23,6 @@ from keras_nlp.models.bert.bert_presets import backbone_presets from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -177,7 +176,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/bert/bert_classifier_test.py b/keras_nlp/models/bert/bert_classifier_test.py index 2ea2836e72..1ad831c3dc 100644 --- a/keras_nlp/models/bert/bert_classifier_test.py +++ b/keras_nlp/models/bert/bert_classifier_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.models.bert.bert_classifier import BertClassifier from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor @@ -52,15 +53,13 @@ def setUp(self): ) # Setup data. - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -74,7 +73,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. 
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) @@ -84,6 +83,7 @@ def test_classifier_fit(self): def test_classifier_fit_no_xla(self): self.classifier.preprocessor = None self.classifier.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", jit_compile=False, ) diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index 56f6e8499a..1b1cd4853f 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -25,7 +25,6 @@ ) from keras_nlp.models.bert.bert_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -136,7 +135,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py index c993253387..004c30e495 100644 --- a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py @@ -136,6 +136,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/bert/bert_masked_lm_test.py b/keras_nlp/models/bert/bert_masked_lm_test.py index 5f1db45408..46c777b0c2 100644 --- a/keras_nlp/models/bert/bert_masked_lm_test.py +++ b/keras_nlp/models/bert/bert_masked_lm_test.py @@ -56,34 +56,33 @@ def setUp(self): ) # Setup data. - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) - def test_valid_call_classifier(self): + def test_valid_call(self): self.masked_lm(self.preprocessed_batch[0]) - def test_classifier_predict(self): + def test_predict(self): self.masked_lm.predict(self.raw_batch) self.masked_lm.preprocessor = None self.masked_lm.predict(self.preprocessed_batch[0]) - def test_classifier_fit(self): + def test_fit(self): self.masked_lm.fit(self.raw_dataset) self.masked_lm.preprocessor = None self.masked_lm.fit(self.preprocessed_dataset) - def test_classifier_fit_no_xla(self): + def test_fit_no_xla(self): self.masked_lm.preprocessor = None self.masked_lm.compile( + optimizer="adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), jit_compile=False, ) diff --git a/keras_nlp/models/bert/bert_preprocessor_test.py b/keras_nlp/models/bert/bert_preprocessor_test.py index 10c1c22017..28f12a84dd 100644 --- a/keras_nlp/models/bert/bert_preprocessor_test.py +++ b/keras_nlp/models/bert/bert_preprocessor_test.py @@ -121,6 +121,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. 
+ @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["THE QUICK BROWN FOX."]) inputs = keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/bert/bert_presets_test.py b/keras_nlp/models/bert/bert_presets_test.py index e9e7f36cac..6728a7ded5 100644 --- a/keras_nlp/models/bert/bert_presets_test.py +++ b/keras_nlp/models/bert/bert_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.models.bert.bert_classifier import BertClassifier from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor @@ -55,9 +55,9 @@ def test_preprocessor_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "segment_ids": tf.constant([[0, 0, 0, 0]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "segment_ids": ops.array([[0, 0, 0, 0]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = BertBackbone.from_preset( "bert_tiny_en_uncased", load_weights=load_weights @@ -78,7 +78,7 @@ def test_backbone_output(self, load_weights): ("load_weights", True), ("no_load_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = BertClassifier.from_preset( "bert_tiny_en_uncased", num_classes=2, @@ -92,9 +92,9 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "segment_ids": tf.constant([[0, 0, 0, 0]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "segment_ids": ops.array([[0, 0, 0, 0]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = BertClassifier.from_preset( "bert_tiny_en_uncased", @@ -187,13 +187,11 @@ def test_load_bert(self, load_weights): for preset in BertBackbone.presets: model = BertBackbone.from_preset(preset, load_weights=load_weights) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -207,7 +205,7 @@ def test_load_bert_classifier(self, load_weights): num_classes=2, load_weights=load_weights, ) - input_data = tf.constant(["This quick brown fox"]) + input_data = ["This quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -222,15 +220,13 @@ def test_load_bert_classifier_without_preprocessing(self, load_weights): load_weights=load_weights, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git 
a/keras_nlp/models/bert/bert_tokenizer_test.py b/keras_nlp/models/bert/bert_tokenizer_test.py index 1023626a2e..b56a9f5c19 100644 --- a/keras_nlp/models/bert/bert_tokenizer_test.py +++ b/keras_nlp/models/bert/bert_tokenizer_test.py @@ -36,7 +36,7 @@ def test_tokenize(self): self.assertAllEqual(output, [5, 6, 7, 8, 1]) def test_tokenize_batch(self): - input_data = tf.constant(["THE QUICK BROWN FOX.", "THE FOX."]) + input_data = ["THE QUICK BROWN FOX.", "THE FOX."] output = self.tokenizer(input_data) self.assertAllEqual(output, [[5, 6, 7, 8, 1], [5, 8, 1]]) @@ -67,6 +67,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["THE QUICK BROWN FOX."]) tokenizer = BertTokenizer(vocabulary=self.vocab) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index 616b1b6131..1c889fc36c 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -25,6 +25,7 @@ ) from keras_nlp.models.deberta_v3.relative_embedding import RelativeEmbedding from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import assert_tf_backend def deberta_kernel_initializer(stddev=0.02): @@ -110,6 +111,8 @@ def __init__( bucket_size=256, **kwargs, ): + assert_tf_backend(self.__class__.__name__) + # Inputs token_id_input = keras.Input( shape=(None,), dtype="int32", name="token_ids" diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py index 9ced19a2c2..2cd49914df 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py @@ -19,10 +19,12 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class DebertaV3BackboneTest(TestCase): def setUp(self): self.backbone = DebertaV3Backbone( @@ -36,8 +38,8 @@ def setUp(self): ) self.batch_size = 8 self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -57,8 +59,8 @@ def test_token_embedding(self): def test_variable_sequence_length_call_deberta(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } output = self.backbone(input_data) self.assertAllEqual( @@ -93,6 +95,7 @@ def test_saved_model(self): @pytest.mark.tpu @pytest.mark.usefixtures("tpu_test_class") +@pytest.mark.tf_only class DebertaV3BackboneTPUTest(TestCase): def setUp(self): with self.tpu_strategy.scope(): @@ -106,8 +109,8 @@ def setUp(self): bucket_size=2, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git 
a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index f91e159bd6..2d15a95ab9 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -26,7 +26,6 @@ ) from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -208,7 +207,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py index 4342a6fd9a..da69036472 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py @@ -21,6 +21,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( DebertaV3Classifier, @@ -32,6 +33,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class DebertaV3ClassifierTest(TestCase): def setUp(self): bytes_io = io.BytesIO() @@ -75,15 +77,13 @@ def setUp(self): hidden_dim=4, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -97,7 +97,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. 
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index 4d97b77e82..8fe6e0c919 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -27,7 +27,6 @@ ) from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -142,7 +141,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py index b54da09fcb..65fc464027 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py @@ -148,6 +148,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py index 0b62f27548..768fc25661 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py @@ -30,6 +30,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class DebertaV3MaskedLMTest(TestCase): def setUp(self): bytes_io = io.BytesIO() @@ -70,12 +71,10 @@ def setUp(self): preprocessor=self.preprocessor, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the eagle flew over fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the eagle flew over fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py index fbf6a7e09f..0526b4b548 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py @@ -145,6 +145,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) inputs = keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py b/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py index 0df222fd0b..334ea8a263 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( DebertaV3Classifier, @@ -29,6 +29,7 @@ @pytest.mark.large 
+@pytest.mark.tf_only class DebertaV3PresetSmokeTest(TestCase): """ A smoke test for DeBERTa presets we run continuously. @@ -67,8 +68,8 @@ def test_preprocessor_mask_token(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[0, 581, 63773, 2]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[0, 581, 63773, 2]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = DebertaV3Backbone.from_preset( "deberta_v3_extra_small_en", load_weights=load_weights @@ -83,7 +84,7 @@ def test_backbone_output(self, load_weights): ("preset_weights", True), ("random_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = DebertaV3Classifier.from_preset( "deberta_v3_extra_small_en", num_classes=2, @@ -97,8 +98,8 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[0, 581, 63773, 2]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[0, 581, 63773, 2]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = DebertaV3Classifier.from_preset( "deberta_v3_extra_small_en", @@ -133,6 +134,7 @@ def test_unknown_preset_error(self, cls, kwargs): @pytest.mark.extra_large +@pytest.mark.tf_only class DebertaV3PresetFullTest(TestCase): """ Test the full enumeration of our preset. @@ -151,10 +153,10 @@ def test_load_deberta(self, load_weights): preset, load_weights=load_weights ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -168,7 +170,7 @@ def test_load_deberta_classifier(self, load_weights): num_classes=4, load_weights=load_weights, ) - input_data = tf.constant(["This quick brown fox"]) + input_data = ["The quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -183,12 +185,12 @@ def test_load_deberta_classifier_without_preprocessing(self, load_weights): preprocessor=None, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py index 0019a28841..88f3936ec2 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py @@ -55,19 +55,19 @@ def test_tokenize(self): self.assertAllEqual(output, [4, 9, 5, 7]) def test_tokenize_batch(self): - input_data = tf.constant(["the quick brown fox", "the earth is round"]) + input_data = ["the quick brown fox", "the earth is round"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 3]]) def test_detokenize(self): - input_data = tf.constant([[4, 9, 5, 7]]) + input_data = [[4, 9, 5, 7]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, ["the quick brown fox"]) def test_detokenize_mask_token(self): - input_data = tf.constant([[4, 9, 5, 
7, self.tokenizer.mask_token_id]]) + input_data = [[4, 9, 5, 7, self.tokenizer.mask_token_id]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, ["the quick brown fox"]) def test_vocabulary_size(self): self.assertEqual(self.tokenizer.vocabulary_size(), 11) @@ -103,6 +103,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py index 6c3a872131..594e7188ca 100644 --- a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py +++ b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py @@ -75,11 +75,7 @@ def __init__( bias_initializer="zeros", **kwargs ): - # Work around for model saving - self._input_shape = kwargs.pop("build_input_shape", None) - super().__init__(**kwargs) - self.intermediate_dim = intermediate_dim self.num_heads = num_heads self.max_position_embeddings = max_position_embeddings @@ -92,15 +88,9 @@ def __init__( self._built = False self.supports_masking = True - if self._input_shape is not None: - self._build(self._input_shape) - - def _build(self, input_shape): - # Create layers based on input shape. - self._built = True - self._input_shape = input_shape + def build(self, inputs_shape): # Infer the dimension of our hidden feature size from the build shape. - hidden_dim = input_shape[-1] + hidden_dim = inputs_shape[-1] # Self attention layers. self._self_attention_layer = DisentangledSelfAttention( @@ -112,9 +102,11 @@ def _build(self, input_shape): kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), ) + self._self_attention_layer.build(inputs_shape) self._self_attention_layernorm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, ) + self._self_attention_layernorm.build(inputs_shape) self._self_attention_dropout = keras.layers.Dropout( rate=self.dropout, ) @@ -123,20 +115,26 @@ def _build(self, input_shape): self._feedforward_layernorm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, ) + self._feedforward_layernorm.build(inputs_shape) self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), ) + self._feedforward_intermediate_dense.build(inputs_shape) self._feedforward_output_dense = keras.layers.Dense( hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), bias_initializer=clone_initializer(self.bias_initializer), ) + intermediate_shape = list(inputs_shape) + intermediate_shape[-1] = self.intermediate_dim + self._feedforward_output_dense.build(tuple(intermediate_shape)) self._feedforward_dropout = keras.layers.Dropout( rate=self.dropout, ) + self.built = True def call( self, @@ -163,10 +161,6 @@ def call( Returns: A Tensor of the same shape as the `inputs`. """ - - if not self._built: - self._build(inputs.shape) - x = inputs # Compute self attention mask. @@ -177,7 +171,7 @@ def call( # Self attention block. 
residual = x x = self._self_attention_layer( - hidden_states=x, + x, rel_embeddings=rel_embeddings, attention_mask=self_attention_mask, ) @@ -212,7 +206,9 @@ def get_config(self): "bias_initializer": keras.initializers.serialize( self.bias_initializer ), - "build_input_shape": self._input_shape, } ) return config + + def compute_output_shape(self, inputs_shape): + return inputs_shape diff --git a/keras_nlp/models/deberta_v3/disentangled_self_attention.py b/keras_nlp/models/deberta_v3/disentangled_self_attention.py index f1b8d51721..5f1ab8a473 100644 --- a/keras_nlp/models/deberta_v3/disentangled_self_attention.py +++ b/keras_nlp/models/deberta_v3/disentangled_self_attention.py @@ -82,8 +82,7 @@ def __init__( float(num_type_attn * self.attn_head_size) ) - # Layers. - + def build(self, inputs_shape, rel_embeddings_shape=None): # Q, K, V linear layers. self._query_dense = keras.layers.EinsumDense( equation="abc,cde->abde", @@ -92,6 +91,7 @@ def __init__( **self._get_common_kwargs_for_sublayer(use_bias=True), name="query", ) + self._query_dense.build(inputs_shape) self._key_dense = keras.layers.EinsumDense( equation="abc,cde->abde", output_shape=(None, self.num_heads, self.attn_head_size), @@ -99,6 +99,7 @@ def __init__( **self._get_common_kwargs_for_sublayer(use_bias=True), name="key", ) + self._key_dense.build(inputs_shape) self._value_dense = keras.layers.EinsumDense( equation="abc,cde->abde", output_shape=(None, self.num_heads, self.attn_head_size), @@ -106,6 +107,7 @@ def __init__( **self._get_common_kwargs_for_sublayer(use_bias=True), name="value", ) + self._value_dense.build(inputs_shape) # Relative attention. self._position_dropout_layer = keras.layers.Dropout(self.dropout) @@ -123,6 +125,8 @@ def __init__( **self._get_common_kwargs_for_sublayer(use_bias=True), name="attention_output", ) + self._output_dense.build(inputs_shape) + self.built = True def _get_common_kwargs_for_sublayer(self, use_bias=True): common_kwargs = {} @@ -322,7 +326,7 @@ def _compute_disentangled_attention( def call( self, - hidden_states, + inputs, rel_embeddings, attention_mask=None, return_attention_scores=False, @@ -330,9 +334,9 @@ def call( ): # `query`, `key`, `value` shape: # `(batch_size, sequence_length, num_heads, attn_head_size)`. 
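# (Each of the three projections below is an `EinsumDense` with equation `abc,cde->abde`, splitting the hidden axis into `(num_heads, attn_head_size)`; the new `build()` above constructs all of them eagerly from the input shape.)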
- query = self._query_dense(hidden_states) - key = self._key_dense(hidden_states) - value = self._value_dense(hidden_states) + query = self._query_dense(inputs) + key = self._key_dense(inputs) + value = self._value_dense(inputs) attention_output, attention_scores = self._compute_attention( query=query, diff --git a/keras_nlp/models/deberta_v3/relative_embedding.py b/keras_nlp/models/deberta_v3/relative_embedding.py index 06bc1f66ad..dddab8f9c8 100644 --- a/keras_nlp/models/deberta_v3/relative_embedding.py +++ b/keras_nlp/models/deberta_v3/relative_embedding.py @@ -88,3 +88,6 @@ def get_config(self): } ) return config + + def compute_output_shape(self, input_shape): + return (input_shape[0],) + (self.bucket_size * 2, self.hidden_dim) diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone_test.py b/keras_nlp/models/distil_bert/distil_bert_backbone_test.py index 3d3e2dcd78..f16c058f71 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone from keras_nlp.tests.test_case import TestCase @@ -36,8 +37,8 @@ def setUp(self): ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -54,9 +55,9 @@ def test_token_embedding(self): def test_variable_sequence_length_call_distilbert(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "mask_positions": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "mask_positions": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -99,8 +100,8 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index d13f3def43..1ae94cb2fa 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -26,7 +26,6 @@ ) from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -193,7 +192,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier_test.py b/keras_nlp/models/distil_bert/distil_bert_classifier_test.py index 518842ffc0..1d9b2fef0c 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier_test.py @@ -19,6 +19,7 
@@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone from keras_nlp.models.distil_bert.distil_bert_classifier import ( DistilBertClassifier, @@ -59,15 +60,13 @@ def setUp(self): hidden_dim=4, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -81,7 +80,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. - self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index 4e7d29258b..1bc7f0dbf1 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -27,7 +27,6 @@ ) from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -140,7 +139,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py index 16870de81a..192cd6b2b1 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py @@ -113,6 +113,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" THE QUICK BROWN FOX."]) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py index 154970f74a..e949a01ab1 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py @@ -55,12 +55,10 @@ def setUp(self): preprocessor=self.preprocessor, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py index 9ca119ddee..22127a90fe 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py @@ -111,6 +111,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = 
tf.constant(["THE QUICK BROWN FOX."]) inputs = keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/distil_bert/distil_bert_presets_test.py b/keras_nlp/models/distil_bert/distil_bert_presets_test.py index 71e56911c0..fcef57e0c0 100644 --- a/keras_nlp/models/distil_bert/distil_bert_presets_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone from keras_nlp.models.distil_bert.distil_bert_classifier import ( DistilBertClassifier, @@ -61,8 +61,8 @@ def test_preprocessor_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = DistilBertBackbone.from_preset( "distil_bert_base_en_uncased", load_weights=load_weights @@ -76,7 +76,7 @@ def test_backbone_output(self, load_weights): ("preset_weights", True), ("random_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = DistilBertClassifier.from_preset( "distil_bert_base_en_uncased", num_classes=2, @@ -89,8 +89,8 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = DistilBertClassifier.from_preset( "distil_bert_base_en_uncased", @@ -142,10 +142,10 @@ def test_load_distilbert(self, load_weights): preset, load_weights=load_weights ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -159,7 +159,7 @@ def test_load_distilbert_classifier(self, load_weights): num_classes=2, load_weights=load_weights, ) - input_data = tf.constant(["This quick brown fox"]) + input_data = ["This quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -174,12 +174,12 @@ def test_load_distilbert_classifier_no_preprocessing(self, load_weights): preprocessor=None, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py index 11007ecc21..7e3481f6f3 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py @@ -38,7 +38,7 @@ def test_tokenize(self): self.assertAllEqual(output, [5, 6, 7, 8, 1]) def test_tokenize_batch(self): - input_data = tf.constant(["THE QUICK BROWN FOX.", "THE FOX."]) + input_data = ["THE QUICK BROWN FOX.", "THE FOX."] output = 
self.tokenizer(input_data) self.assertAllEqual(output, [[5, 6, 7, 8, 1], [5, 8, 1]]) @@ -69,6 +69,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["THE QUICK BROWN FOX."]) tokenizer = DistilBertTokenizer(vocabulary=self.vocab) diff --git a/keras_nlp/models/f_net/f_net_backbone_test.py b/keras_nlp/models/f_net/f_net_backbone_test.py index 57930fe2c6..c65b845755 100644 --- a/keras_nlp/models/f_net/f_net_backbone_test.py +++ b/keras_nlp/models/f_net/f_net_backbone_test.py @@ -19,10 +19,12 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class FNetBackboneTest(TestCase): def setUp(self): self.backbone = FNetBackbone( @@ -34,8 +36,8 @@ def setUp(self): num_segments=4, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "segment_ids": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "segment_ids": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -51,8 +53,8 @@ def test_valid_call_f_net(self): def test_variable_sequence_length_call_f_net(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "segment_ids": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "segment_ids": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -97,8 +99,8 @@ def setUp(self): num_segments=4, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "segment_ids": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "segment_ids": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index 9dab025296..ca15a52392 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -23,7 +23,6 @@ from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -144,7 +143,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/f_net/f_net_classifier_test.py b/keras_nlp/models/f_net/f_net_classifier_test.py index 7131067481..9c8978b6f1 100644 --- a/keras_nlp/models/f_net/f_net_classifier_test.py +++ b/keras_nlp/models/f_net/f_net_classifier_test.py @@ -21,6 +21,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_classifier import FNetClassifier from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor @@ -28,6 +29,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class FNetClassifierTest(TestCase): def setUp(self): # Setup Model @@ -74,15 +76,13 @@ def setUp(self): ) # Setup data. 
- self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -96,7 +96,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. - self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_fnet_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index 8c6c713b04..f6ecd549d8 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -23,7 +23,6 @@ ) from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -134,7 +133,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py index 2b2e0ca952..81ee99a4a4 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py @@ -130,6 +130,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/f_net/f_net_masked_lm_test.py b/keras_nlp/models/f_net/f_net_masked_lm_test.py index 23d67ddde7..4b102cd59a 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_test.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_test.py @@ -28,6 +28,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class FNetMaskedLMTest(TestCase): def setUp(self): # Setup Model. 
@@ -68,12 +69,10 @@ def setUp(self): preprocessor=self.preprocessor, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox", - "the slow brown fox", - ] - ) + self.raw_batch = [ + "the quick brown fox", + "the slow brown fox", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/f_net/f_net_preprocessor_test.py b/keras_nlp/models/f_net/f_net_preprocessor_test.py index 3230998298..8034c09e6c 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor_test.py +++ b/keras_nlp/models/f_net/f_net_preprocessor_test.py @@ -148,6 +148,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) inputs = keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/f_net/f_net_presets_test.py b/keras_nlp/models/f_net/f_net_presets_test.py index 4c30e0b0b2..37d0f21a47 100644 --- a/keras_nlp/models/f_net/f_net_presets_test.py +++ b/keras_nlp/models/f_net/f_net_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_classifier import FNetClassifier from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor @@ -25,6 +25,7 @@ @pytest.mark.large +@pytest.mark.tf_only class FNetPresetSmokeTest(TestCase): """ A smoke test for FNet presets we run continuously. @@ -55,9 +56,9 @@ def test_preprocessor_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "segment_ids": tf.constant([[0, 0, 0, 0]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "segment_ids": ops.array([[0, 0, 0, 0]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = FNetBackbone.from_preset( "f_net_base_en", load_weights=load_weights @@ -78,7 +79,7 @@ def test_backbone_output(self, load_weights): ("load_weights", True), ("no_load_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = FNetClassifier.from_preset( "f_net_base_en", num_classes=2, @@ -127,12 +128,10 @@ def test_load_f_net(self, load_weights): for preset in FNetBackbone.presets: model = FNetBackbone.from_preset(preset, load_weights=load_weights) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), } model(input_data) @@ -146,7 +145,7 @@ def test_load_fnet_classifier(self, load_weights): num_classes=2, load_weights=load_weights, ) - input_data = tf.constant(["This quick brown fox"]) + input_data = ["The quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -161,15 +160,13 @@ def test_load_fnet_classifier_without_preprocessing(self, load_weights): load_weights=load_weights, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "segment_ids": tf.constant( - [0] * 200 + [1] * 312, shape=(1, 512) - ), - 
"padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git a/keras_nlp/models/f_net/f_net_tokenizer_test.py b/keras_nlp/models/f_net/f_net_tokenizer_test.py index ea0bd3e7e7..380eb920de 100644 --- a/keras_nlp/models/f_net/f_net_tokenizer_test.py +++ b/keras_nlp/models/f_net/f_net_tokenizer_test.py @@ -25,6 +25,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class FNetTokenizerTest(TestCase): def setUp(self): bytes_io = io.BytesIO() @@ -56,14 +57,14 @@ def test_tokenize(self): self.assertAllEqual(output, [2, 10, 6, 8]) def test_tokenize_batch(self): - input_data = tf.constant(["the quick brown fox", "the earth is round"]) + input_data = ["the quick brown fox", "the earth is round"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[2, 10, 6, 8], [2, 7, 9, 11]]) def test_detokenize(self): - input_data = tf.constant([[2, 10, 6, 8]]) + input_data = [[2, 10, 6, 8]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, ["the quick brown fox"]) def test_vocabulary_size(self): tokenizer = FNetTokenizer(proto=self.proto) @@ -91,6 +92,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 76771a4f3e..e479bab196 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -13,9 +13,13 @@ # limitations under the License. """Base class for Generative Task models.""" +import itertools + import tensorflow as tf +from keras_nlp.backend import config from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.task import Task from keras_nlp.samplers.serialization import get as get_sampler from keras_nlp.utils.tensor_utils import tensor_to_list @@ -54,15 +58,80 @@ def make_generate_function(self): if self.generate_function is not None: return self.generate_function - if self.run_eagerly: - self.generate_function = self.generate_step - else: + self.generate_function = self.generate_step + if config.backend() == "torch": + import torch + + def wrapped_generate_function( + inputs, + end_token_id=None, + ): + with torch.no_grad(): + return self.generate_step(inputs, end_token_id) + + self.generate_function = wrapped_generate_function + elif config.backend() == "tensorflow" and not self.run_eagerly: # `jit_compile` is a property of keras.Model after TF 2.12. # Use `getattr()` for backwards compatibility. jit_compile = getattr(self, "jit_compile", True) self.generate_function = tf.function( self.generate_step, jit_compile=jit_compile ) + elif config.backend() == "jax" and not self.run_eagerly: + import jax + + @jax.jit + def compiled_generate_function(inputs, end_token_id, state): + ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self._sampler.variables, sampler_variables), + zip(self.trainable_variables, trainable_variables), + zip(self.non_trainable_variables, non_trainable_variables), + ) + + with keras.StatelessScope(state_mapping=mapping) as scope: + outputs = self.generate_step(inputs, end_token_id) + + # Get updated sampler variables from the stateless scope. 
+ sampler_variables = [] + for v in self._sampler.variables: + new_v = scope.get_current_value(v) + sampler_variables.append(new_v if new_v is not None else v) + state = ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) + return outputs, state + + def wrapped_generate_function( + inputs, + end_token_id=None, + ): + # Create an explicit tuple of all variable state. + state = ( + self._sampler.variables, + self.trainable_variables, + self.non_trainable_variables, + ) + inputs = tf.nest.map_structure(ops.convert_to_tensor, inputs) + outputs, state = compiled_generate_function( + inputs, + end_token_id, + state, + ) + # Only assign the sampler variables (random seeds), as other + # model variables should never be updated in generation. + for ref_v, v in zip(self._sampler.variables, state[0]): + ref_v.assign(v) + return outputs + + self.generate_function = wrapped_generate_function + return self.generate_function def _normalize_generate_inputs( diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index 44826b2b3d..40aefda88b 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -18,6 +18,7 @@ from tensorflow.experimental import dtensor from tensorflow.experimental.dtensor import Layout +from tensorflow.keras.dtensor.experimental import LayoutMap from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -239,7 +240,7 @@ def create_layout_map(cls, mesh): _, model_dim = mesh.dim_names unshard_dim = dtensor.UNSHARDED - layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh) + layout_map = LayoutMap(mesh=mesh) # Embedding sharding layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh) diff --git a/keras_nlp/models/gpt2/gpt2_backbone_test.py b/keras_nlp/models/gpt2/gpt2_backbone_test.py index 3757a79f2e..d9134be51b 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone_test.py +++ b/keras_nlp/models/gpt2/gpt2_backbone_test.py @@ -19,16 +19,13 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.tests.test_case import TestCase class GPT2Test(TestCase): def setUp(self): - # For DTensor. 
- keras.backend.experimental.enable_tf_random_generator() - keras.utils.set_random_seed(1337) - self.backbone = GPT2Backbone( vocabulary_size=10, num_layers=2, @@ -38,9 +35,9 @@ def setUp(self): max_sequence_length=5, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "segment_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "segment_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch @@ -60,8 +57,8 @@ def test_name(self): def test_variable_sequence_length(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -121,8 +118,8 @@ def setUp(self): max_sequence_length=5, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index b2417f98d2..2bc67088e7 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -15,20 +15,32 @@ import copy -import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.generative_task import GenerativeTask from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( GPT2CausalLMPreprocessor, ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty +# TODO: Extend and factor this out into keras_nlp.layers. +class ReverseEmbedding(keras.layers.Layer): + def __init__(self, embedding, **kwargs): + super().__init__(**kwargs) + self.embedding = embedding + + def call(self, inputs): + kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings)) + return ops.matmul(inputs, kernel) + + def compute_output_shape(self, input_shape): + return (input_shape[0],) + (self.embedding.embeddings.shape[0],) + + @keras_nlp_export("keras_nlp.models.GPT2CausalLM") class GPT2CausalLM(GenerativeTask): """An end-to-end GPT2 model for causal language modeling. @@ -162,11 +174,10 @@ def __init__( x = backbone(inputs) # Use token embedding weights to project from the token representation # to vocabulary logits. - outputs = tf.matmul( - x, - backbone.token_embedding.embeddings, - transpose_b=True, - ) + outputs = ReverseEmbedding( + backbone.token_embedding, + name="reverse_embedding", + )(x) # Instantiate using Functional API Model constructor.
super().__init__( @@ -185,7 +196,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty @@ -234,33 +245,30 @@ def call_with_cache( ) x = self.backbone.get_layer("embeddings_dropout")(x) # Each decoder layer has a cache; we update them separately. - caches = tf.unstack(cache, axis=1) + caches = [] for i in range(self.backbone.num_layers): - current_cache = caches[i] + current_cache = cache[:, i, ...] x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( x, self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) - caches[i] = next_cache - cache = tf.stack(caches, axis=1) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) hidden_states = x - logits = tf.matmul( - hidden_states, - self.backbone.get_layer("token_embedding").embeddings, - transpose_b=True, - ) + logits = self.get_layer("reverse_embedding")(x) return logits, hidden_states, cache def _build_cache(self, token_ids): """Build an empty cache for use with `call_with_cache()`.""" - batch_size, max_length = tf.shape(token_ids)[0], tf.shape(token_ids)[1] + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] num_layers = self.backbone.num_layers num_heads = self.backbone.num_heads head_dim = self.backbone.hidden_dim // self.backbone.num_heads shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] - cache = tf.zeros(shape, dtype=self.compute_dtype) + cache = ops.zeros(shape, dtype=self.compute_dtype) # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache @@ -287,24 +295,23 @@ def generate_step( # Create and seed cache with a single forward pass. hidden_states, cache = self._build_cache(token_ids) # Compute the lengths of all user inputted tokens ids. - row_lengths = tf.math.reduce_sum( - tf.cast(padding_mask, "int32"), axis=-1 - ) + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) # Start at the first index that has no user inputted id. - index = tf.math.reduce_min(row_lengths) + index = ops.min(row_lengths) def next(prompt, cache, index): # The cache index is the index of our previous token. cache_update_index = index - 1 - prompt = tf.slice(prompt, [0, cache_update_index], [-1, 1]) + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, cache_update_index, ) return ( - tf.squeeze(logits, axis=1), - tf.squeeze(hidden_states, axis=1), + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), cache, ) @@ -323,14 +330,15 @@ def next(prompt, cache, index): # Build a mask of `end_token_id` locations not in the original # prompt (not in locations where `padding_mask` is True). end_locations = (token_ids == end_token_id) & (~padding_mask) - end_locations = tf.cast(end_locations, "int32") + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. - overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1) + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations # Our padding mask is the inverse of these overflow locations. 
- padding_mask = ~tf.cast(overflow, "bool") + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) else: # Without early stopping, all locations will have been updated. - padding_mask = tf.ones_like(token_ids, dtype="bool") + padding_mask = ops.ones_like(token_ids, dtype="bool") return { "token_ids": token_ids, "padding_mask": padding_mask, @@ -353,7 +361,7 @@ def create_layout_map(cls, mesh): distribution, and the second for model parallel distribution. Returns: - A `tf.keras.dtensor.experimental.LayoutMap` which contains the + A `keras.dtensor.experimental.LayoutMap` which contains the proper layout to weights mapping for the model parallel setting. Examples: diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py index 8441b90430..b0a0a8adc6 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py @@ -129,6 +129,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["airplane at airport"]) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py index 836b40b0be..e52274da7f 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py @@ -16,11 +16,11 @@ import os from unittest.mock import patch -import numpy as np import pytest import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( @@ -32,10 +32,6 @@ class GPT2CausalLMTest(TestCase): def setUp(self): - # For DTensor. - keras.backend.experimental.enable_tf_random_generator() - keras.utils.set_random_seed(1337) - self.vocab = { "!": 0, "air": 1, @@ -65,12 +61,10 @@ def setUp(self): preprocessor=self.preprocessor, ) - self.raw_batch = tf.constant( - [ - " airplane at airport", - " airplane at airport", - ] - ) + self.raw_batch = [ + " airplane at airport", + " airplane at airport", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch @@ -126,8 +120,10 @@ def test_early_stopping(self): def wrapper(*args, **kwargs): """Modify output logits to always favor end_token_id""" logits, hidden_states, cache = call_with_cache(*args, **kwargs) - logits = np.zeros(logits.shape.as_list()) - logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9 + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) return logits, hidden_states, cache with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper): @@ -135,7 +131,6 @@ def wrapper(*args, **kwargs): output = self.causal_lm.generate(prompt) # We should immediately abort and output the prompt. self.assertEqual(prompt, output) - self.assertEqual(self.causal_lm.call_with_cache.call_count, 2) def test_generate_compilation(self): # Assert we do not recompile with successive calls. 
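The `generate_step` hunk above also swaps `tf.math.cumsum(..., exclusive=True)` for an inclusive `ops.cumsum` minus `end_locations`, since the backend-agnostic `cumsum` has no `exclusive` argument (an exclusive cumsum is just the inclusive one with each position's own contribution subtracted). A minimal numpy sketch of the resulting early-stopping mask, with made-up values and `2` standing in for `end_token_id` (illustration only, not part of the diff):

```python
import numpy as np

token_ids = np.array([[5, 7, 2, 9, 9]])  # the model emitted end token 2 at position 2
padding_mask = np.array([[True, True, False, False, False]])  # user prompt is [5, 7]

# End tokens in generated (non-prompt) positions.
end_locations = ((token_ids == 2) & ~padding_mask).astype("int32")  # [[0, 0, 1, 0, 0]]
cumsum = np.cumsum(end_locations, axis=-1)                          # [[0, 0, 1, 1, 1]]
overflow = cumsum - end_locations  # exclusive cumsum               # [[0, 0, 0, 1, 1]]
padding_mask = ~overflow.astype(bool)  # keep prompt + end token, mask what follows
print(padding_mask)  # [[ True  True  True False False]]
```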
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py index b535dd65bb..d7b921ee14 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py @@ -110,6 +110,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["airplane at airport"]) diff --git a/keras_nlp/models/gpt2/gpt2_presets_test.py b/keras_nlp/models/gpt2/gpt2_presets_test.py index 396d5b995c..e9cdd707d7 100644 --- a/keras_nlp/models/gpt2/gpt2_presets_test.py +++ b/keras_nlp/models/gpt2/gpt2_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.tests.test_case import TestCase @@ -42,8 +42,8 @@ def test_tokenizer_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[1169, 2068, 7586, 21831, 13]]), - "padding_mask": tf.constant([[1, 1, 1, 1, 1]]), + "token_ids": ops.array([[1169, 2068, 7586, 21831, 13]]), + "padding_mask": ops.array([[1, 1, 1, 1, 1]]), } model = GPT2Backbone.from_preset( "gpt2_base_en", load_weights=load_weights @@ -95,12 +95,12 @@ def test_load_gpt2(self, load_weights): for preset in GPT2Backbone.presets: model = GPT2Backbone.from_preset(preset, load_weights=load_weights) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 1024), dtype="int64", maxval=model.vocabulary_size, ), - "padding_mask": tf.constant([1] * 1024, shape=(1, 1024)), + "padding_mask": ops.array([1] * 1024, shape=(1, 1024)), } model(input_data) diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py index e4e5baee86..51347a30da 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py @@ -74,7 +74,7 @@ def test_tokenize_end_token(self): self.assertAllEqual(output, [1, 2, 3, 1, 4, 0]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[1, 2, 3, 1, 4], [5, 6, 7, 8, 9]]) @@ -99,6 +99,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py index 7c82eaf469..3a3ce2f3e1 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.backbone import Backbone from keras_nlp.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder +from keras_nlp.utils.tensor_utils import assert_tf_backend def _gpt_neo_x_kernel_initializer(stddev=0.02): @@ -76,6 +75,8 @@ def __init__( max_sequence_length=512, **kwargs, ): + assert_tf_backend(self.__class__.__name__) + # Inputs token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") padding_mask = keras.Input( @@ -116,7 +117,7 @@ def __init__( name="layer_norm", axis=-1, epsilon=layer_norm_epsilon, - dtype=tf.float32, + dtype="float32", )(x) # Instantiate using Functional API Model constructor diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py index d8fe095d99..9e6a2973b5 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py @@ -19,10 +19,12 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models import GPTNeoXBackbone from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class GPTNeoXTest(TestCase): def setUp(self): self.backbone = GPTNeoXBackbone( @@ -34,8 +36,8 @@ def setUp(self): max_sequence_length=10, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch @@ -55,8 +57,8 @@ def test_name(self): def test_variable_sequence_length(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -99,8 +101,8 @@ def setUp(self): max_sequence_length=10, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py index c60f9ee0ef..e8f8396d4a 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py @@ -129,6 +129,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["airplane at airport"]) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 16fdb028ff..b20fb8157d 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -15,8 +15,6 @@ """GPTNeoX preprocessor layer.""" from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker - -# from keras_nlp.models.gpt_neo_x.gpt_neo_x_presets import backbone_presets from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from 
keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py index 5cf3f485a9..463bfb7881 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py @@ -113,6 +113,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["airplane at airport"]) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py index f227953099..4e2ca8961b 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py @@ -74,7 +74,7 @@ def test_tokenize_end_token(self): self.assertAllEqual(output, [1, 2, 3, 1, 4, 0]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[1, 2, 3, 1, 4], [5, 6, 7, 8, 9]]) @@ -99,6 +99,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index ff22c9cf40..fb7bb0bd70 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -18,6 +18,7 @@ from tensorflow.experimental import dtensor from tensorflow.experimental.dtensor import Layout +from tensorflow.keras.dtensor.experimental import LayoutMap from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -217,7 +218,7 @@ def create_layout_map(cls, mesh): _, model_dim = mesh.dim_names unshard_dim = dtensor.UNSHARDED - layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh) + layout_map = LayoutMap(mesh=mesh) # Embedding sharding layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh) diff --git a/keras_nlp/models/opt/opt_backbone_test.py b/keras_nlp/models/opt/opt_backbone_test.py index df53f493ac..6746075795 100644 --- a/keras_nlp/models/opt/opt_backbone_test.py +++ b/keras_nlp/models/opt/opt_backbone_test.py @@ -19,16 +19,13 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.tests.test_case import TestCase -class OPTTest(TestCase): +class OPTBackboneTest(TestCase): def setUp(self): - # For DTensor. 
- keras.backend.experimental.enable_tf_random_generator() - keras.utils.set_random_seed(1337) - self.backbone = OPTBackbone( vocabulary_size=10, num_layers=2, @@ -38,8 +35,8 @@ max_sequence_length=5, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -60,8 +57,8 @@ def test_name(self): def test_variable_sequence_length_call_opt(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } self.backbone(input_data) @@ -121,8 +118,8 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 9395f8feb6..0bdedf1ed9 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -15,20 +15,32 @@ import copy -import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.generative_task import GenerativeTask from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm_preprocessor import ( OPTCausalLMPreprocessor, ) from keras_nlp.models.opt.opt_presets import backbone_presets -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty +# TODO: Extend and factor this out into keras_nlp.layers. +class ReverseEmbedding(keras.layers.Layer): + def __init__(self, embedding, **kwargs): + super().__init__(**kwargs) + self.embedding = embedding + + def call(self, inputs): + kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings)) + return ops.matmul(inputs, kernel) + + def compute_output_shape(self, input_shape): + return (input_shape[0],) + (self.embedding.embeddings.shape[0],) + + @keras_nlp_export("keras_nlp.models.OPTCausalLM") class OPTCausalLM(GenerativeTask): """An end-to-end OPT model for causal language modeling. @@ -162,11 +174,10 @@ def __init__( x = backbone(inputs) # Use token embedding weights to project from the token representation # to vocabulary logits. - outputs = tf.matmul( - x, - backbone.token_embedding.embeddings, - transpose_b=True, - ) + outputs = ReverseEmbedding( + backbone.token_embedding, + name="reverse_embedding", + )(x) # Instantiate using Functional API Model constructor. super().__init__( @@ -185,7 +196,7 @@ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(2e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty @@ -229,33 +240,30 @@ def call_with_cache( token_ids, start_index=cache_update_index ) # Each decoder layer has a cache; we update them separately.
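# (The rewrite below indexes `cache[:, i, ...]` and restacks the per-layer caches with `ops.stack`, avoiding the TF-only `tf.unstack`.)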
- caches = tf.unstack(cache, axis=1) + caches = [] for i in range(self.backbone.num_layers): - current_cache = caches[i] + current_cache = cache[:, i, ...] x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")( x, self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) - caches[i] = next_cache - cache = tf.stack(caches, axis=1) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) x = self.backbone.get_layer("layer_norm")(x) hidden_states = x - logits = tf.matmul( - hidden_states, - self.backbone.token_embedding.embeddings, - transpose_b=True, - ) + logits = self.get_layer("reverse_embedding")(x) return logits, hidden_states, cache def _build_cache(self, token_ids): """Build an empty cache for use with `call_with_cache()`.""" - batch_size, max_length = tf.shape(token_ids)[0], tf.shape(token_ids)[1] + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] num_layers = self.backbone.num_layers num_heads = self.backbone.num_heads head_dim = self.backbone.hidden_dim // self.backbone.num_heads shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] - cache = tf.zeros(shape, dtype=self.compute_dtype) + cache = ops.zeros(shape, dtype=self.compute_dtype) # Seed the cache. _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache @@ -282,24 +290,23 @@ def generate_step( # Create and seed cache with a single forward pass. hidden_states, cache = self._build_cache(token_ids) # Compute the lengths of all user inputted tokens ids. - row_lengths = tf.math.reduce_sum( - tf.cast(padding_mask, "int32"), axis=-1 - ) + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) # Start at the first index that has no user inputted id. - index = tf.math.reduce_min(row_lengths) + index = ops.min(row_lengths) def next(prompt, cache, index): # The cache index is the index of our previous token. cache_update_index = index - 1 - prompt = tf.slice(prompt, [0, cache_update_index], [-1, 1]) + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) logits, hidden_states, cache = self.call_with_cache( prompt, cache, cache_update_index, ) return ( - tf.squeeze(logits, axis=1), - tf.squeeze(hidden_states, axis=1), + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), cache, ) @@ -318,14 +325,15 @@ def next(prompt, cache, index): # Build a mask of `end_token_id` locations not in the original # prompt (not in locations where `padding_mask` is True). end_locations = (token_ids == end_token_id) & (~padding_mask) - end_locations = tf.cast(end_locations, "int32") + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. - overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1) + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations # Our padding mask is the inverse of these overflow locations. - padding_mask = ~tf.cast(overflow, "bool") + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) else: # Without early stopping, all locations will have been updated. 
-            padding_mask = tf.ones_like(token_ids, dtype="bool")
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
         return {
             "token_ids": token_ids,
             "padding_mask": padding_mask,
diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
index f6fd07e1e7..c64ebb343a 100644
--- a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
@@ -131,6 +131,7 @@ def test_serialization(self):
         )

     @pytest.mark.large
+    @pytest.mark.tf_only
     def test_saved_model(self):
         input_data = tf.constant([" airplane at airport"])

diff --git a/keras_nlp/models/opt/opt_causal_lm_test.py b/keras_nlp/models/opt/opt_causal_lm_test.py
index 4b523eb761..b4bff2c7d3 100644
--- a/keras_nlp/models/opt/opt_causal_lm_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_test.py
@@ -16,11 +16,11 @@
 import os
 from unittest.mock import patch

-import numpy as np
 import pytest
 import tensorflow as tf

 from keras_nlp.backend import keras
+from keras_nlp.backend import ops
 from keras_nlp.models.opt.opt_backbone import OPTBackbone
 from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
 from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
@@ -32,10 +32,6 @@
 class OPTCausalLMTest(TestCase):
     def setUp(self):
-        # For DTensor.
-        keras.backend.experimental.enable_tf_random_generator()
-        keras.utils.set_random_seed(1337)
-
         self.vocab = {
             "<pad>": 0,
             "</s>": 1,
@@ -71,12 +67,10 @@ def setUp(self):
             preprocessor=self.preprocessor,
         )

-        self.raw_batch = tf.constant(
-            [
-                " airplane at airport",
-                " airplane at airport",
-            ]
-        )
+        self.raw_batch = [
+            " airplane at airport",
+            " airplane at airport",
+        ]
         self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
         self.raw_dataset = tf.data.Dataset.from_tensor_slices(
             self.raw_batch
@@ -132,8 +126,10 @@ def test_early_stopping(self):
         def wrapper(*args, **kwargs):
             """Modify output logits to always favor end_token_id"""
             logits, hidden_states, cache = call_with_cache(*args, **kwargs)
-            logits = np.zeros(logits.shape.as_list())
-            logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
+            index = self.preprocessor.tokenizer.end_token_id
+            update = ops.ones_like(logits)[:, :, index] * 1.0e9
+            update = ops.expand_dims(update, axis=-1)
+            logits = ops.slice_update(logits, (0, 0, index), update)
             return logits, hidden_states, cache

         with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper):
@@ -141,7 +137,6 @@ def wrapper(*args, **kwargs):
             output = self.causal_lm.generate(prompt)
             # We should immediately abort and output the prompt.
             self.assertEqual(prompt, output)
-            self.assertEqual(self.causal_lm.call_with_cache.call_count, 2)

     def test_generate_compilation(self):
         # Assert we do not recompile with successive calls.
@@ -167,7 +162,7 @@ def test_saved_model(self):
         keras.utils.set_random_seed(42)
         model_output = self.causal_lm.predict(self.raw_batch)
         path = os.path.join(self.get_temp_dir(), "model.keras")
-        self.seq_2_seq_lm.save(path, save_format="keras_v3")
+        self.causal_lm.save(path, save_format="keras_v3")
         restored_model = keras.models.load_model(path)

         # Check we got the real object back.
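The exclusive-cumsum replacement in `generate_step` above deserves a quick check: the multi-backend `ops.cumsum` has no `exclusive` argument, but an exclusive cumsum at each position is just the inclusive cumsum minus the element itself, so `cumsum - end_locations` is exact. A minimal NumPy sketch of the masking logic (array values are illustrative):

```python
import numpy as np

# One sequence where the end token was generated at position 3.
end_locations = np.array([[0, 0, 0, 1, 0, 0]], dtype="int32")

cumsum = np.cumsum(end_locations, axis=-1)  # [[0, 0, 0, 1, 1, 1]]
# Exclusive cumsum: ones only at positions strictly after the end token.
overflow = cumsum - end_locations           # [[0, 0, 0, 0, 1, 1]]
# Keep everything up to and including the end token itself.
padding_mask = ~overflow.astype(bool)       # [[T, T, T, T, F, F]]
```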
diff --git a/keras_nlp/models/opt/opt_preprocessor_test.py b/keras_nlp/models/opt/opt_preprocessor_test.py
index aaac994942..a06243ece9 100644
--- a/keras_nlp/models/opt/opt_preprocessor_test.py
+++ b/keras_nlp/models/opt/opt_preprocessor_test.py
@@ -112,6 +112,7 @@ def test_serialization(self):
         )

     @pytest.mark.large
+    @pytest.mark.tf_only
     def test_saved_model(self):
         input_data = tf.constant([" airplane at airport"])

diff --git a/keras_nlp/models/opt/opt_presets_test.py b/keras_nlp/models/opt/opt_presets_test.py
index ec95e5ecc6..744d9d11ef 100644
--- a/keras_nlp/models/opt/opt_presets_test.py
+++ b/keras_nlp/models/opt/opt_presets_test.py
@@ -14,9 +14,9 @@
 """Tests for loading pretrained model presets."""

 import pytest
-import tensorflow as tf
 from absl.testing import parameterized

+from keras_nlp.backend import ops
 from keras_nlp.models.opt.opt_backbone import OPTBackbone
 from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer
 from keras_nlp.tests.test_case import TestCase
@@ -42,8 +42,8 @@ def test_tokenizer_output(self):
     )
     def test_backbone_output(self, load_weights):
         input_data = {
-            "token_ids": tf.constant([[133, 2119, 6219, 23602, 4]]),
-            "padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
+            "token_ids": ops.array([[133, 2119, 6219, 23602, 4]]),
+            "padding_mask": ops.array([[1, 1, 1, 1, 1]]),
         }
         model = OPTBackbone.from_preset(
             "opt_125m_en", load_weights=load_weights
@@ -95,12 +95,12 @@ def test_load_opt(self, load_weights):
         for preset in OPTBackbone.presets:
             model = OPTBackbone.from_preset(preset, load_weights=load_weights)
             input_data = {
-                "token_ids": tf.random.uniform(
+                "token_ids": ops.random.uniform(
                     shape=(1, 1024),
                     dtype="int64",
                     maxval=model.vocabulary_size,
                 ),
-                "padding_mask": tf.constant([1] * 1024, shape=(1, 1024)),
+                "padding_mask": ops.array([1] * 1024, shape=(1, 1024)),
             }
             model(input_data)

diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py
index 3be91276b0..e0a5d70a00 100644
--- a/keras_nlp/models/opt/opt_tokenizer.py
+++ b/keras_nlp/models/opt/opt_tokenizer.py
@@ -48,39 +48,26 @@ class OPTTokenizer(BytePairTokenizer):
     merge entities separated by a space.

     Examples:
-
-    Batched inputs.
-    >>> vocab = {"<pad>": 1, "</s>": 2, "a": 3, "Ġquick": 4, "Ġfox": 5}
-    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
-    >>> merges += ["Ġ f", "o x", "Ġf ox"]
-    >>> tokenizer = keras_nlp.models.OPTTokenizer(
-    ...     vocabulary=vocab,
-    ...     merges=merges,
-    ... )
-    >>> tokenizer(["a quick fox", "a fox quick"])
-    <tf.RaggedTensor [[3, 4, 5], [3, 5, 4]]>
-
-    Unbatched input.
-    >>> vocab = {"<pad>": 1, "</s>": 2, "a": 3, "Ġquick": 4, "Ġfox": 5}
-    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
-    >>> merges += ["Ġ f", "o x", "Ġf ox"]
-    >>> tokenizer = keras_nlp.models.OPTTokenizer(
-    ...     vocabulary=vocab,
-    ...     merges=merges,
-    ... )
-    >>> tokenizer("a quick fox")
-    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([3, 4, 5], dtype=int32)>
-
-    Detokenization.
-    >>> vocab = {"<pad>": 1, "</s>": 2, "Ġquick": 4, "Ġfox": 5}
-    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
-    >>> merges += ["Ġ f", "o x", "Ġf ox"]
-    >>> tokenizer = keras_nlp.models.OPTTokenizer(
-    ...     vocabulary=vocab,
-    ...     merges=merges,
-    ... )
-    >>> tokenizer.detokenize(tokenizer(" quick fox")).numpy().decode('utf-8')
-    ' quick fox'
+    ```python
+    # Unbatched input.
+    tokenizer = keras_nlp.models.OPTTokenizer.from_preset(
+        "opt_125m_en",
+    )
+    tokenizer("The quick brown fox jumped.")
+
+    # Batched input.
+    tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+    # Detokenization.
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+    # Custom vocabulary.
+ vocab = {"": 1, "": 2, "Ä quick": 4, "Ä fox": 5} + merges = ["Ä  q", "u i", "c k", "ui ck", "Ä q uick"] + merges += ["Ä  f", "o x", "Ä f ox"] + tokenizer = keras_nlp.models.OPTTokenizer(vocabulary=vocab, merges=merges) + tokenizer("The quick brown fox jumped.") + ``` """ def __init__( diff --git a/keras_nlp/models/opt/opt_tokenizer_test.py b/keras_nlp/models/opt/opt_tokenizer_test.py index 1b3a34a559..7fa37385f5 100644 --- a/keras_nlp/models/opt/opt_tokenizer_test.py +++ b/keras_nlp/models/opt/opt_tokenizer_test.py @@ -58,7 +58,7 @@ def test_tokenize_special_tokens(self): self.assertAllEqual(output, [1, 2, 3, 4, 2, 5, 1, 0]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[2, 3, 4, 2, 5], [6, 7, 8, 9, 10]]) @@ -83,6 +83,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index df40dcb753..f24b9203a1 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -13,18 +13,28 @@ # limitations under the License. from keras_nlp.backend import keras +from keras_nlp.layers.preprocessing.preprocessing_layer import ( + PreprocessingLayer, +) from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @keras.saving.register_keras_serializable(package="keras_nlp") -class Preprocessor(keras.layers.Layer): +class Preprocessor(PreprocessingLayer): """Base class for model preprocessors.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._tokenizer = None + def __setattr__(self, name, value): + # Work around torch setattr for properties. 
+ if name in ["tokenizer"]: + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + @property def tokenizer(self): """The tokenizer used to tokenize strings.""" diff --git a/keras_nlp/models/roberta/roberta_backbone_test.py b/keras_nlp/models/roberta/roberta_backbone_test.py index f72cc180fa..4867823be9 100644 --- a/keras_nlp/models/roberta/roberta_backbone_test.py +++ b/keras_nlp/models/roberta/roberta_backbone_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.tests.test_case import TestCase @@ -35,8 +36,8 @@ def setUp(self): ) self.batch_size = 8 self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -66,13 +67,13 @@ def test_serialization(self): def test_variable_sequence_length_call_roberta(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } output = self.backbone(input_data) self.assertAllEqual( - tf.shape(output), - [2, seq_length, self.backbone.hidden_dim], + ops.shape(output), + (2, seq_length, self.backbone.hidden_dim), ) @pytest.mark.large # Saving is slow, so mark these large. @@ -104,8 +105,8 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index d9461ed812..3f838c1d4c 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -22,7 +22,6 @@ from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -187,7 +186,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(2e-5), metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def get_config(self): diff --git a/keras_nlp/models/roberta/roberta_classifier_test.py b/keras_nlp/models/roberta/roberta_classifier_test.py index 9f32494ad9..a512df0c85 100644 --- a/keras_nlp/models/roberta/roberta_classifier_test.py +++ b/keras_nlp/models/roberta/roberta_classifier_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor @@ -71,15 +72,13 @@ def setUp(self): ) # Setup data. 
- self.raw_batch = tf.constant( - [ - " airplane at airport", - " the airplane is the best", - ] - ) + self.raw_batch = [ + " airplane at airport", + " the airplane is the best", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -93,7 +92,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. - self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index c68877871f..659e6c2e00 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -25,7 +25,6 @@ ) from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.models.task import Task -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -139,7 +138,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py index f827b87bd4..8bcba68264 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py @@ -149,6 +149,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. 
+ @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/roberta/roberta_masked_lm_test.py b/keras_nlp/models/roberta/roberta_masked_lm_test.py index f6afa41676..838360d7b6 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_test.py @@ -73,12 +73,10 @@ def setUp(self): preprocessor=None, ) - self.raw_batch = tf.constant( - [ - " airplane at airport", - " the airplane is the best", - ] - ) + self.raw_batch = [ + " airplane at airport", + " the airplane is the best", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py index 785a090769..2df4519561 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor_test.py +++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py @@ -148,6 +148,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/roberta/roberta_presets_test.py b/keras_nlp/models/roberta/roberta_presets_test.py index c5cb97d255..c3f5f63aaf 100644 --- a/keras_nlp/models/roberta/roberta_presets_test.py +++ b/keras_nlp/models/roberta/roberta_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.models.roberta.roberta_masked_lm import RobertaMaskedLM @@ -56,8 +56,8 @@ def test_preprocessor_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[0, 133, 2119, 2]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[0, 133, 2119, 2]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = RobertaBackbone.from_preset( "roberta_base_en", load_weights=load_weights @@ -84,8 +84,8 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = RobertaClassifier.from_preset( "roberta_base_en", @@ -112,9 +112,9 @@ def test_masked_lm_output(self, load_weights): ) def test_masked_lm_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[101, 1996, 4248, 102]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), - "mask_positions": tf.constant([[0, 0]]), + "token_ids": ops.array([[101, 1996, 4248, 102]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), + "mask_positions": ops.array([[0, 0]]), } model = RobertaMaskedLM.from_preset( "roberta_base_en", @@ -168,10 +168,10 @@ def test_load_roberta(self, load_weights): preset, load_weights=load_weights ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -183,7 +183,7 @@ def 
test_load_roberta_classifier(self, load_weights): classifier = RobertaClassifier.from_preset( preset, num_classes=4, load_weights=load_weights ) - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -198,12 +198,12 @@ def test_load_roberta_classifier_without_preprocessing(self, load_weights): load_weights=load_weights, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) @@ -215,7 +215,7 @@ def test_load_roberta_masked_lm(self, load_weights): classifier = RobertaMaskedLM.from_preset( preset, load_weights=load_weights ) - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -229,13 +229,13 @@ def test_load_roberta_masked_lm_without_preprocessing(self, load_weights): load_weights=load_weights, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), - "mask_positions": tf.constant([1] * 128, shape=(1, 128)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), + "mask_positions": ops.array([1] * 128, shape=(1, 128)), } classifier.predict(input_data) diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py index 2df2c07658..9f5f390f51 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer_test.py +++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py @@ -59,7 +59,7 @@ def test_tokenize_special_tokens(self): self.assertAllEqual(output, [0, 3, 4, 5, 3, 6, 0, 1]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]]) @@ -84,6 +84,7 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index f3b1ca2edd..b23e0becd4 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -14,14 +14,13 @@ """T5 backbone model.""" -import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.backbone import Backbone from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import assert_tf_backend @keras_nlp_export("keras_nlp.models.T5Backbone") @@ -79,6 +78,8 @@ def __init__( layer_norm_epsilon=1e-06, **kwargs, ): + assert_tf_backend(self.__class__.__name__) + # Encoder inputs encoder_token_ids = keras.Input( shape=(None,), dtype="int32", name="encoder_token_ids" @@ -112,8 +113,7 @@ def __init__( name="encoder_embedding_dropout", )(token_embedding) - # Encoder attention mask is just our padding mask. 
- encoder_attention_mask = encoder_padding_mask[:, tf.newaxis, :] + encoder_attention_mask = encoder_padding_mask[:, None, :] position_bias = None for i in range(num_layers): diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py index be71ba8bac..8666168e8b 100644 --- a/keras_nlp/models/t5/t5_backbone_test.py +++ b/keras_nlp/models/t5/t5_backbone_test.py @@ -19,10 +19,12 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class T5Test(TestCase): def setUp(self): self.backbone = T5Backbone( @@ -35,16 +37,16 @@ def setUp(self): self.batch_size = 2 seq_length = 3 self.input_batch = { - "encoder_token_ids": tf.ones( + "encoder_token_ids": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "encoder_padding_mask": tf.ones( + "encoder_padding_mask": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "decoder_token_ids": tf.ones( + "decoder_token_ids": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "decoder_padding_mask": tf.ones( + "decoder_padding_mask": ops.ones( (self.batch_size, seq_length), dtype="int32" ), } @@ -68,16 +70,16 @@ def test_name(self): def test_variable_sequence_length_call_t5(self): for seq_length in (2, 3, 4): input_data = { - "encoder_token_ids": tf.ones( + "encoder_token_ids": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "encoder_padding_mask": tf.ones( + "encoder_padding_mask": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "decoder_token_ids": tf.ones( + "decoder_token_ids": ops.ones( (self.batch_size, seq_length), dtype="int32" ), - "decoder_padding_mask": tf.ones( + "decoder_padding_mask": ops.ones( (self.batch_size, seq_length), dtype="int32" ), } @@ -124,8 +126,8 @@ def setUp(self): intermediate_dim=4, ) self.input_batch = { - "token_ids": tf.ones((8, 4), dtype="int32"), - "padding_mask": tf.ones((8, 4), dtype="int32"), + "token_ids": ops.ones((8, 4), dtype="int32"), + "padding_mask": ops.ones((8, 4), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/t5/t5_layer_norm.py b/keras_nlp/models/t5/t5_layer_norm.py index 8eec5c8bf9..dab361e279 100644 --- a/keras_nlp/models/t5/t5_layer_norm.py +++ b/keras_nlp/models/t5/t5_layer_norm.py @@ -27,6 +27,7 @@ def build(self, input_shape): shape=(input_shape[-1],), initializer="ones", ) + self.built = True def call(self, hidden_states): variance = tf.math.reduce_mean( diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py index 3288d4f93e..298cf67b18 100644 --- a/keras_nlp/models/t5/t5_multi_head_attention.py +++ b/keras_nlp/models/t5/t5_multi_head_attention.py @@ -82,7 +82,6 @@ def __init__( ) self.dropout_layer = keras.layers.Dropout(dropout) - def build(self, input_shape): if self.use_relative_attention_bias: self.relative_attention_bias = self.add_weight( name="embeddings", diff --git a/keras_nlp/models/t5/t5_tokenizer_test.py b/keras_nlp/models/t5/t5_tokenizer_test.py index ca9a7cba97..a06435fa8c 100644 --- a/keras_nlp/models/t5/t5_tokenizer_test.py +++ b/keras_nlp/models/t5/t5_tokenizer_test.py @@ -55,14 +55,14 @@ def test_tokenize(self): self.assertAllEqual(output, [4, 9, 5, 7]) def test_tokenize_batch(self): - input_data = tf.constant(["the quick brown fox", "the earth is round"]) + input_data = ["the quick brown fox", "the earth is 
round"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]]) def test_detokenize(self): - input_data = tf.constant([[4, 9, 5, 7]]) + input_data = [[4, 9, 5, 7]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["the quick brown fox"])) + self.assertEqual(output, ["the quick brown fox"]) def test_vocabulary_size(self): tokenizer = T5Tokenizer(proto=self.proto) @@ -90,6 +90,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 67d8547134..6de3aee0ec 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -15,11 +15,12 @@ import os +import keras_core +import rich import tensorflow as tf from keras_nlp.backend import keras from keras_nlp.utils.keras_utils import print_msg -from keras_nlp.utils.keras_utils import print_row from keras_nlp.utils.pipeline_model import PipelineModel from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -34,7 +35,7 @@ def __init__(self, *args, **kwargs): self._preprocessor = None super().__init__(*args, **kwargs) - def _check_for_loss_mismatch(self): + def _check_for_loss_mismatch(self, loss): """Check for a softmax/from_logits mismatch after compile. We cannot handle this in the general case, but we can handle this for @@ -42,13 +43,13 @@ def _check_for_loss_mismatch(self): loss, and a `None` or `"softmax"` activation. """ # Only handle a single loss. - if tf.nest.is_nested(self.loss): + if tf.nest.is_nested(loss): return # Only handle tasks with activation. if not hasattr(self, "activation"): return - loss = keras.losses.get(self.loss) + loss = keras.losses.get(loss) activation = keras.activations.get(self.activation) if isinstance(loss, keras.losses.SparseCategoricalCrossentropy): from_logits = loss.get_config()["from_logits"] @@ -77,13 +78,20 @@ def _check_for_loss_mismatch(self): "`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)`. " ) - def compile(self, *args, **kwargs): - super().compile(*args, **kwargs) - self._check_for_loss_mismatch() + def compile(self, optimizer="rmsprop", loss=None, **kwargs): + self._check_for_loss_mismatch(loss) + super().compile(optimizer=optimizer, loss=loss, **kwargs) def preprocess_samples(self, x, y=None, sample_weight=None): return self.preprocessor(x, y=y, sample_weight=sample_weight) + def __setattr__(self, name, value): + # Work around torch setattr for properties. + if name in ["backbone", "preprocessor"]: + object.__setattr__(self, name, value) + else: + super().__setattr__(name, value) + @property def backbone(self): """A `keras.Model` instance providing the backbone submodel.""" @@ -110,7 +118,6 @@ def get_config(self): "backbone": keras.layers.serialize(self.backbone), "preprocessor": keras.layers.serialize(self.preprocessor), "name": self.name, - "trainable": self.trainable, } @classmethod @@ -240,31 +247,67 @@ def summary( **kwargs, ): """Override `model.summary()` to show a preprocessor if set.""" - # Defaults are copied from core Keras; we should try to stay in sync. - line_length = line_length or 98 - positions = positions or [0.33, 0.55, 0.67, 1.0] - if positions[-1] <= 1: - positions = [int(line_length * p) for p in positions] - if print_fn is None: + # Below is copied from keras-core for now. + # We should consider an API contract. 
+        line_length = line_length or 108
+
+        if not print_fn and not keras.utils.is_interactive_logging_enabled():
             print_fn = print_msg

+        def highlight_number(x):
+            return f"[color(45)]{x}[/]" if x is None else f"[color(34)]{x}[/]"
+
+        def highlight_symbol(x):
+            return f"[color(33)]{x}[/]"
+
+        def bold_text(x):
+            return f"[bold]{x}[/]"
+
         if self.preprocessor:
-            column_names = ["Tokenizer (type)", "Vocab #"]
+            # Create a rich console for printing. Capture for non-interactive logging.
+            if print_fn:
+                console = rich.console.Console(
+                    highlight=False, force_terminal=False, color_system=None
+                )
+                console.begin_capture()
+            else:
+                console = rich.console.Console(highlight=False)
+
+            column_1 = rich.table.Column(
+                "Tokenizer (type)",
+                justify="left",
+                width=int(0.5 * line_length),
+            )
+            column_2 = rich.table.Column(
+                "Vocab #",
+                justify="right",
+                width=int(0.5 * line_length),
+            )
+            table = rich.table.Table(
+                column_1, column_2, width=line_length, show_lines=True
+            )
             tokenizer = self.preprocessor.tokenizer
-            column_values = [
-                f"{tokenizer.name} ({tokenizer.__class__.__name__})",
-                f"{tokenizer.vocabulary_size()}",
-            ]
-
-            print_fn(f'Preprocessor: "{self.preprocessor.name}"')
-            print_fn("_" * line_length)
-            print_row(column_names, positions[1:3], print_fn)
-            print_fn("=" * line_length)
-            print_row(column_values, positions[1:3], print_fn)
-            print_fn("_" * line_length)
-            print_fn(" " * line_length)
-
-        super().summary(
+            tokenizer_name = rich.markup.escape(tokenizer.name)
+            tokenizer_class = highlight_symbol(
+                rich.markup.escape(tokenizer.__class__.__name__)
+            )
+            table.add_row(
+                f"{tokenizer_name} ({tokenizer_class})",
+                highlight_number(f"{tokenizer.vocabulary_size():,}"),
+            )
+
+            # Print the summary to the console.
+            preprocessor_name = rich.markup.escape(self.preprocessor.name)
+            console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
+            console.print(table)
+
+            # Output captured summary for non-interactive logging.
+            if print_fn:
+                print_fn(console.end_capture(), line_break=False)
+
+        # Hardcode summary from keras_core for now.
+        keras_core.Model.summary(
+            self,
             line_length=line_length,
             positions=positions,
             print_fn=print_fn,
diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py
index 506780c159..4155eb2e63 100644
--- a/keras_nlp/models/task_test.py
+++ b/keras_nlp/models/task_test.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from tensorflow.keras.losses import SparseCategoricalCrossentropy
-
 from keras_nlp.backend import keras
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.models.task import Task
@@ -45,36 +43,52 @@ def test_summary_with_preprocessor(self):
         preprocessor = SimplePreprocessor()
         model = SimpleTask(preprocessor)
         summary = []
-        model.summary(print_fn=lambda x: summary.append(x))
+        model.summary(print_fn=lambda x, line_break: summary.append(x))
         self.assertRegex("\n".join(summary), "Preprocessor:")

     def test_summary_without_preprocessor(self):
         model = SimpleTask()
         summary = []
-        model.summary(print_fn=lambda x: summary.append(x))
+        model.summary(print_fn=lambda x, line_break: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")

     def test_mismatched_loss(self):
         # Logit output.
model = SimpleTask(activation=None) - model.compile(loss=SparseCategoricalCrossentropy(from_logits=True)) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True) + ) # Non-standard losses should not throw. model.compile(loss="mean_squared_error") with self.assertRaises(ValueError): model.compile(loss="sparse_categorical_crossentropy") with self.assertRaises(ValueError): - model.compile(loss=SparseCategoricalCrossentropy(from_logits=False)) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy( + from_logits=False + ) + ) # Probability output. model = SimpleTask(activation="softmax") - model.compile(loss=SparseCategoricalCrossentropy(from_logits=False)) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False) + ) model.compile(loss="sparse_categorical_crossentropy") # Non-standard losses should not throw. model.compile(loss="mean_squared_error") with self.assertRaises(ValueError): - model.compile(loss=SparseCategoricalCrossentropy(from_logits=True)) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy( + from_logits=True + ) + ) # Non-standard activations should not throw. model = SimpleTask(activation="tanh") - model.compile(loss=SparseCategoricalCrossentropy(from_logits=True)) - model.compile(loss=SparseCategoricalCrossentropy(from_logits=False)) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True) + ) + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False) + ) diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index 7bcfc80790..a8c8758f66 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -82,6 +82,10 @@ def __init__( super().__init__(**kwargs) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + self.built = True + self.num_mels = num_mels self.num_fft_bins = num_fft_bins self.stride = stride diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py index 8961c93e1b..ddf4b331fd 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py @@ -79,6 +79,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. 
+ @pytest.mark.tf_only def test_saved_model(self): audio_tensor = tf.ones((2,), dtype="float32") diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index fd85cc04c9..1f68853ba1 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -14,10 +14,9 @@ """Whisper backbone model.""" -import tensorflow as tf - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.token_and_position_embedding import ( TokenAndPositionEmbedding, @@ -25,12 +24,18 @@ from keras_nlp.models.backbone import Backbone from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder from keras_nlp.models.whisper.whisper_encoder import WhisperEncoder +from keras_nlp.utils.tensor_utils import assert_tf_backend def whisper_kernel_initializer(stddev=0.02): return keras.initializers.TruncatedNormal(stddev=stddev) +class Padder(keras.layers.Layer): + def call(self, x): + return ops.pad(x, [[0, 0], [1, 1], [0, 0]]) + + @keras_nlp_export("keras_nlp.models.WhisperBackbone") class WhisperBackbone(Backbone): """A Whisper encoder-decoder network for speech. @@ -107,6 +112,8 @@ def __init__( max_decoder_sequence_length=448, **kwargs, ): + assert_tf_backend(self.__class__.__name__) + # Encoder inputs. Note that the encoder does not have a padding mask: # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132. encoder_feature_input = keras.Input( @@ -142,9 +149,7 @@ def __init__( # For the second conv. layer, we cannot use `padding="same"` since # that corresponds to a padding size of 1.5 (since stride is 2). Hence, # we will manually pad the input. 
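        # What `Padder` computes, with illustrative shapes: it zero-pads the
        # time axis by one frame on each side, so an input of shape
        # (batch, seq_len, hidden_dim), e.g. (2, 5, 80), becomes (2, 7, 80);
        # the strided convolution below then recovers the sequence-halving
        # behavior that `padding="same"` would have given.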
- embedded_features = tf.pad( - embedded_features, paddings=[[0, 0], [1, 1], [0, 0]] - ) + embedded_features = Padder()(embedded_features) encoder_conv_layer_2 = keras.layers.Conv1D( filters=hidden_dim, kernel_size=3, diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py index 587855762f..60a2c89996 100644 --- a/keras_nlp/models/whisper/whisper_backbone_test.py +++ b/keras_nlp/models/whisper/whisper_backbone_test.py @@ -19,10 +19,12 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class WhisperBackboneTest(TestCase): def setUp(self): self.backbone = WhisperBackbone( @@ -35,9 +37,9 @@ def setUp(self): max_decoder_sequence_length=6, ) self.input_batch = { - "encoder_features": tf.ones((2, 5, 80), dtype="float32"), - "decoder_token_ids": tf.ones((2, 5), dtype="int32"), - "decoder_padding_mask": tf.ones((2, 5), dtype="int32"), + "encoder_features": ops.ones((2, 5, 80), dtype="float32"), + "decoder_token_ids": ops.ones((2, 5), dtype="int32"), + "decoder_padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( @@ -60,11 +62,13 @@ def test_name(self): def test_variable_sequence_length_call_whisper(self): for seq_length in (2, 3, 4): input_data = { - "encoder_features": tf.ones( + "encoder_features": ops.ones( (2, seq_length, 80), dtype="float32" ), - "decoder_token_ids": tf.ones((2, seq_length), dtype="int32"), - "decoder_padding_mask": tf.ones((2, seq_length), dtype="int32"), + "decoder_token_ids": ops.ones((2, seq_length), dtype="int32"), + "decoder_padding_mask": ops.ones( + (2, seq_length), dtype="int32" + ), } self.backbone(input_data) @@ -134,7 +138,7 @@ def setUp(self): ) self.input_batch = { - "encoder_features": tf.ones( + "encoder_features": ops.ones( ( 8, self.backbone.max_encoder_sequence_length, @@ -142,10 +146,10 @@ def setUp(self): ), dtype="int32", ), - "decoder_token_ids": tf.ones( + "decoder_token_ids": ops.ones( (8, self.backbone.max_decoder_sequence_length), dtype="int32" ), - "decoder_padding_mask": tf.ones( + "decoder_padding_mask": ops.ones( (8, self.backbone.max_decoder_sequence_length), dtype="int32" ), } diff --git a/keras_nlp/models/whisper/whisper_encoder.py b/keras_nlp/models/whisper/whisper_encoder.py index 7127121498..4527e3fa51 100644 --- a/keras_nlp/models/whisper/whisper_encoder.py +++ b/keras_nlp/models/whisper/whisper_encoder.py @@ -22,7 +22,7 @@ class WhisperEncoder(TransformerEncoder): """A Whisper encoder. Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the - `_build` method so as to remove the bias term from the key projection layer. + `build` method so as to remove the bias term from the key projection layer. 
""" def build(self, inputs_shape): diff --git a/keras_nlp/models/whisper/whisper_tokenizer_test.py b/keras_nlp/models/whisper/whisper_tokenizer_test.py index 6fcb28123c..d94b7ad79a 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer_test.py +++ b/keras_nlp/models/whisper/whisper_tokenizer_test.py @@ -23,6 +23,7 @@ from keras_nlp.tests.test_case import TestCase +@pytest.mark.tf_only class WhisperTokenizerTest(TestCase): def setUp(self): self.vocab = { @@ -77,7 +78,7 @@ def test_tokenize_special_tokens(self): self.assertAllEqual(output, [9, 14, 12, 11, 0, 1, 2, 0, 3, 10]) def test_tokenize_batch(self): - input_data = tf.constant([" airplane at airport", " kohli is the best"]) + input_data = [" airplane at airport", " kohli is the best"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[0, 1, 2, 0, 3], [4, 5, 6, 7, 8]]) @@ -113,6 +114,7 @@ def test_errors_missing_special_tokens(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" airplane at airport"]) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py index 76d9a38f37..00387b8d46 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py @@ -19,6 +19,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.tests.test_case import TestCase @@ -34,8 +35,8 @@ def setUp(self): max_sequence_length=5, ) self.input_batch = { - "token_ids": tf.ones((2, 5), dtype="int32"), - "padding_mask": tf.ones((2, 5), dtype="int32"), + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch @@ -55,12 +56,13 @@ def test_name(self): def test_variable_sequence_length_call_xlm_roberta(self): for seq_length in (2, 3, 4): input_data = { - "token_ids": tf.ones((2, seq_length), dtype="int32"), - "padding_mask": tf.ones((2, seq_length), dtype="int32"), + "token_ids": ops.ones((2, seq_length), dtype="int32"), + "padding_mask": ops.ones((2, seq_length), dtype="int32"), } output = self.backbone(input_data) self.assertAllEqual( - tf.shape(output), [2, seq_length, self.backbone.hidden_dim] + ops.shape(output), + (2, seq_length, self.backbone.hidden_dim), ) def test_predict(self): @@ -102,8 +104,8 @@ def setUp(self): max_sequence_length=128, ) self.input_batch = { - "token_ids": tf.ones((8, 128), dtype="int32"), - "padding_mask": tf.ones((8, 128), dtype="int32"), + "token_ids": ops.ones((8, 128), dtype="int32"), + "padding_mask": ops.ones((8, 128), dtype="int32"), } self.input_dataset = tf.data.Dataset.from_tensor_slices( self.input_batch diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index 9365ac4b37..344d19b8a6 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -24,7 +24,6 @@ XLMRobertaPreprocessor, ) from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -199,7 +198,7 @@ def __init__( ), optimizer=keras.optimizers.Adam(5e-5), 
metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) def preprocess_samples(self, x, y=None, sample_weight=None): diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py index 72d7b98927..916b0caf57 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py @@ -21,6 +21,7 @@ import tensorflow as tf from keras_nlp.backend import keras +from keras_nlp.backend import ops from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import ( XLMRobertaClassifier, @@ -70,15 +71,13 @@ def setUp(self): hidden_dim=4, ) - self.raw_batch = tf.constant( - [ - "the quick brown fox.", - "the slow brown fox.", - ] - ) + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch) self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, tf.ones((2,))) + (self.raw_batch, ops.ones((2,))) ).batch(2) self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) @@ -92,7 +91,7 @@ def test_classifier_predict(self): # Assert predictions match. self.assertAllClose(preds1, preds2) # Assert valid softmax output. - self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0]) + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) def test_classifier_fit(self): self.classifier.fit(self.raw_dataset) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index 61b747c9d5..eeeb051366 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -25,7 +25,6 @@ XLMRobertaMaskedLMPreprocessor, ) from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -142,7 +141,7 @@ def __init__( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.Adam(5e-5), weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=is_xla_compatible(self), + jit_compile=True, ) @classproperty diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py index 06afc8621b..f8962b94db 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py @@ -151,17 +151,18 @@ def test_serialization(self): ) @pytest.mark.large + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant([" quick brown fox"]) inputs = keras.Input(dtype="string", shape=()) - outputs = self.preprocessor(inputs) + outputs, y, sw = self.preprocessor(inputs) model = keras.Model(inputs, outputs) path = os.path.join(self.get_temp_dir(), "model.keras") model.save(path, save_format="keras_v3") restored_model = keras.models.load_model(path) - outputs = model(input_data)[0]["token_ids"] - restored_outputs = restored_model(input_data)[0]["token_ids"] + outputs = model(input_data)["token_ids"] + restored_outputs = restored_model(input_data)["token_ids"] self.assertAllEqual(outputs, restored_outputs) diff --git 
a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py index a46985c8a7..b0129f6707 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py @@ -77,9 +77,10 @@ def setUp(self): preprocessor=self.preprocessor, ) - self.raw_batch = tf.constant( - ["the quick brown fox", "the slow brown fox"] - ) + self.raw_batch = [ + "the quick brown fox", + "the slow brown fox", + ] self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( self.raw_batch diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py index 68f145d883..19049fdb56 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py @@ -141,6 +141,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) inputs = keras.Input(dtype="string", shape=()) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py index e1d74fbaea..1d914bbd60 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py @@ -14,9 +14,9 @@ """Tests for loading pretrained model presets.""" import pytest -import tensorflow as tf from absl.testing import parameterized +from keras_nlp.backend import ops from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import ( XLMRobertaClassifier, @@ -61,8 +61,8 @@ def test_preprocessor_output(self): ) def test_backbone_output(self, load_weights): input_data = { - "token_ids": tf.constant([[0, 581, 63773, 2]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[0, 581, 63773, 2]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = XLMRobertaBackbone.from_preset( "xlm_roberta_base_multi", load_weights=load_weights @@ -77,7 +77,7 @@ def test_backbone_output(self, load_weights): ("preset_weights", True), ("random_weights", False) ) def test_classifier_output(self, load_weights): - input_data = tf.constant(["The quick brown fox."]) + input_data = ["The quick brown fox."] model = XLMRobertaClassifier.from_preset( "xlm_roberta_base_multi", num_classes=2, load_weights=load_weights ) @@ -89,8 +89,8 @@ def test_classifier_output(self, load_weights): ) def test_classifier_output_without_preprocessing(self, load_weights): input_data = { - "token_ids": tf.constant([[0, 581, 63773, 2]]), - "padding_mask": tf.constant([[1, 1, 1, 1]]), + "token_ids": ops.array([[0, 581, 63773, 2]]), + "padding_mask": ops.array([[1, 1, 1, 1]]), } model = XLMRobertaClassifier.from_preset( "xlm_roberta_base_multi", @@ -143,10 +143,10 @@ def test_load_xlm_roberta(self, load_weights): preset, load_weights=load_weights ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=model.vocabulary_size ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } model(input_data) @@ -160,7 +160,7 @@ def test_load_xlm_roberta_classifier(self, load_weights): num_classes=4, load_weights=load_weights, ) - 
input_data = tf.constant(["This quick brown fox"]) + input_data = ["The quick brown fox."] classifier.predict(input_data) @parameterized.named_parameters( @@ -177,12 +177,12 @@ def test_load_xlm_roberta_classifier_without_preprocessing( preprocessor=None, ) input_data = { - "token_ids": tf.random.uniform( + "token_ids": ops.random.uniform( shape=(1, 512), dtype="int64", maxval=classifier.backbone.vocabulary_size, ), - "padding_mask": tf.constant([1] * 512, shape=(1, 512)), + "padding_mask": ops.array([1] * 512, shape=(1, 512)), } classifier.predict(input_data) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py index f4a469615b..f81d81d75e 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py @@ -52,7 +52,7 @@ def test_tokenize(self): self.assertAllEqual(output, [4, 9, 5, 7]) def test_tokenize_batch(self): - input_data = tf.constant(["the quick brown fox", "the earth is round"]) + input_data = ["the quick brown fox", "the earth is round"] output = self.tokenizer(input_data) self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]]) @@ -63,9 +63,9 @@ def test_unk_token(self): self.assertAllEqual(output, [4, 9, 5, 7, 3]) def test_detokenize(self): - input_data = tf.constant([[4, 9, 5, 7]]) + input_data = [[4, 9, 5, 7]] output = self.tokenizer.detokenize(input_data) - self.assertEqual(output, tf.constant(["brown round earth is"])) + self.assertEqual(output, ["brown round earth is"]) def test_vocabulary(self): vocabulary = self.tokenizer.get_vocabulary() @@ -116,6 +116,7 @@ def test_serialization(self): ) @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.tf_only def test_saved_model(self): input_data = tf.constant(["the quick brown fox"]) diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index 6dcf30e828..baf13cf875 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -66,11 +66,15 @@ def __init__( ): super().__init__(**kwargs) self.seed = seed + self.seed_generator = random.SeedGenerator(seed) def get_next_token(self, probabilities): # Sample the next token from the probability distribution. next_token_id = random.categorical( - ops.log(probabilities), 1, seed=self.seed, dtype="int32" + ops.log(probabilities), + 1, + seed=self.seed, + dtype="int32", ) return ops.squeeze(next_token_id, axis=-1) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 46a31ee784..7ca4c9b10f 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -14,8 +14,10 @@ """Base sampler class.""" from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.backend import ops +from keras_nlp.backend import random from keras_nlp.utils.python_utils import format_docstring call_args_docstring = """next: A function which takes in the @@ -93,6 +95,22 @@ def __init__( temperature=1.0, ): self.temperature = temperature + self._seed_generators = [] + + def __setattr__(self, name, value): + # We could update to the `Tracker` class from keras-core if our needs + # become more advanced (e.g. list assignment, nested trackables). For + # now, we only track `SeedGenerator` instances directly on the sampler. 
+ if isinstance(value, random.SeedGenerator): + self._seed_generators.append(value) + return super().__setattr__(name, value) + + @property + def variables(self): + variables = [] + for sg in self._seed_generators: + variables.append(sg.state) + return variables def __call__( self, @@ -135,17 +153,53 @@ def body(prompt, cache, index): # Update the prompt with the next token. next_token = next_token[:, None] prompt = ops.slice_update(prompt, [0, index], next_token) + # Return the next prompt, cache and incremented index. return (prompt, cache, index + 1) - prompt, _, _ = ops.while_loop( - cond=cond, - body=body, + prompt, _, _ = self.run_loop( + cond, + body, loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), ) return prompt + def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None): + """Run ops.while_loops with a `StatelessScope` if necessary.""" + if config.backend() == "jax": + + def stateless_cond(variables, *loop_vars): + return cond(*loop_vars) + + def stateless_body(variables, *loop_vars): + mapping = zip(self.variables, variables) + with keras.StatelessScope(state_mapping=mapping) as scope: + loop_vars = body(*loop_vars) + + variables = [] + for v in self.variables: + new_v = scope.get_current_value(v) + variables.append(new_v if new_v is not None else v) + return variables, *loop_vars + + variables = [ops.convert_to_tensor(v) for v in self.variables] + variables, *loop_vars = ops.while_loop( + cond=stateless_cond, + body=stateless_body, + loop_vars=(variables, *loop_vars), + maximum_iterations=maximum_iterations, + ) + [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)] + else: + loop_vars = ops.while_loop( + cond=cond, + body=body, + loop_vars=(loop_vars), + maximum_iterations=maximum_iterations, + ) + return loop_vars + def get_next_token(self, probabilities): """Get the next token. Args: diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 355d216c6c..39d50d2ad4 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -69,6 +69,7 @@ def __init__( super().__init__(**kwargs) self.k = k self.seed = seed + self.seed_generator = random.SeedGenerator(seed) def get_next_token(self, probabilities): # Filter out top-k tokens. 
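The `run_loop` indirection added to the base `Sampler` above is there because, on the JAX backend, a traced while-loop body cannot mutate state such as a `SeedGenerator`: the RNG state has to travel through the loop carry and be written back only after the loop exits, which is what the `StatelessScope` plumbing and the final `ref_v.assign(v)` accomplish. A standalone JAX sketch of the same pattern (names and values are illustrative, not from keras-nlp):

```python
import jax
import jax.numpy as jnp

def cond(carry):
    _, i, _ = carry
    return i < 5

def body(carry):
    key, i, total = carry
    # "Update" the RNG functionally by splitting, rather than mutating state.
    key, subkey = jax.random.split(key)
    step = jax.random.randint(subkey, (), 0, 10)
    return key, i + 1, total + step

init = (jax.random.PRNGKey(0), jnp.int32(0), jnp.int32(0))
key, _, total = jax.lax.while_loop(cond, body, init)
# Only now is the final key visible outside the loop, mirroring the
# `ref_v.assign(v)` write-back at the end of `run_loop`.
```

Threading the state through the carry is what keeps `generate()` compilable under `jit` while still letting successive calls draw different samples.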
@@ -81,7 +82,7 @@ def get_next_token(self, probabilities): sample_indices = random.categorical( ops.log(top_k_pred), 1, - seed=self.seed, + seed=self.seed_generator, dtype="int32", ) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index d1c5350c49..8ff3147b72 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -79,6 +79,7 @@ def __init__( self.p = p self.k = k self.seed = seed + self.seed_generator = random.SeedGenerator(seed) def get_next_token(self, probabilities): cutoff = ops.shape(probabilities)[1] @@ -105,7 +106,7 @@ def get_next_token(self, probabilities): sorted_next_token = random.categorical( ops.log(probabilities), 1, - seed=self.seed, + seed=self.seed_generator, dtype="int32", ) output = ops.take_along_axis(sorted_indices, sorted_next_token, axis=-1) diff --git a/keras_nlp/tests/doc_tests/docstring_test.py b/keras_nlp/tests/doc_tests/docstring_test.py index 486ecf9389..a484b9f1aa 100644 --- a/keras_nlp/tests/doc_tests/docstring_test.py +++ b/keras_nlp/tests/doc_tests/docstring_test.py @@ -48,6 +48,7 @@ def docstring_module(pytestconfig): return pytestconfig.getoption("docstring_module") +@pytest.mark.tf_only def test_docstrings(docstring_module): keras_nlp_modules = find_modules() # As of this writing, it doesn't seem like pytest support load_tests @@ -86,6 +87,7 @@ def test_docstrings(docstring_module): assert result.wasSuccessful() +@pytest.mark.tf_only @pytest.mark.extra_large @pytest.mark.skipif( astor is None, diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 60ff401bbe..1f759f8945 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -277,7 +277,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer): Custom splitting. >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."] - >>> inputs = ["The$quick$brown$fox"] + >>> inputs = "The$quick$brown$fox" >>> tokenizer = keras_nlp.tokenizers.WordPieceTokenizer( ... vocabulary=vocab, ... split=False, @@ -285,8 +285,9 @@ class WordPieceTokenizer(tokenizer.Tokenizer): ... dtype='string', ... ) >>> split_inputs = tf.strings.split(inputs, sep="$") - >>> tokenizer(split_inputs) - + >>> outputs = tokenizer(split_inputs) + >>> np.array(outputs).astype("U") + array(['the', 'qu', '##ick', 'br', '##own', 'fox'], dtype='