diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 28ee6ef64b..d2b9e6d446 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -102,7 +102,7 @@ jobs:
env:
KERAS_BACKEND: ${{ matrix.backend }}
run: |
- pytest --run_large keras_nlp/layers/modeling keras_nlp/samplers keras_nlp/tokenizers keras_nlp/metrics
+ pytest keras_nlp/
format:
name: Check the code format
runs-on: ubuntu-latest
diff --git a/keras_nlp/conftest.py b/keras_nlp/conftest.py
index 5f74c45c53..3a29b038c5 100644
--- a/keras_nlp/conftest.py
+++ b/keras_nlp/conftest.py
@@ -122,7 +122,12 @@ def pytest_collection_modifyitems(config, items):
item.add_marker(skip_tf_only)
+# Disable traceback filtering for quicker debugging of test failures.
+tf.debugging.disable_traceback_filtering()
if backend_config.multi_backend():
keras.config.disable_traceback_filtering()
-tf.debugging.disable_traceback_filtering()
+# One-off setup for dtensor tests.
+if not backend_config.multi_backend():
+ keras.backend.experimental.enable_tf_random_generator()
+ keras.utils.set_random_seed(1337)
diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py
index 201c31a707..ffa62bbbef 100644
--- a/keras_nlp/metrics/rouge_l.py
+++ b/keras_nlp/metrics/rouge_l.py
@@ -102,14 +102,13 @@ class RougeL(RougeBase):
3. Pass the metric to `model.compile()`.
>>> inputs = keras.Input(shape=(), dtype='string')
- >>> outputs = tf.strings.lower(inputs)
+ >>> outputs = keras.layers.Identity()(inputs)
>>> model = keras.Model(inputs, outputs)
>>> model.compile(metrics=[keras_nlp.metrics.RougeL()])
- >>> x = tf.constant(["HELLO THIS IS FUN"])
+ >>> y_pred = x = tf.constant(["hello this is fun"])
>>> y = tf.constant(["hello this is awesome"])
- >>> metric_dict = model.evaluate(x, y, return_dict=True)
- >>> metric_dict["f1_score"]
- 0.75
+ >>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"]
+ 0.75
"""
def __init__(
diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py
index 46a5f528c9..d90122fa8c 100644
--- a/keras_nlp/metrics/rouge_n.py
+++ b/keras_nlp/metrics/rouge_n.py
@@ -121,13 +121,12 @@ class RougeN(RougeBase):
3. Pass the metric to `model.compile()`.
>>> inputs = keras.Input(shape=(), dtype='string')
- >>> outputs = tf.strings.lower(inputs)
+ >>> outputs = keras.layers.Identity()(inputs)
>>> model = keras.Model(inputs, outputs)
>>> model.compile(metrics=[keras_nlp.metrics.RougeN()])
- >>> x = tf.constant(["HELLO THIS IS FUN"])
+ >>> y_pred = x = tf.constant(["hello this is fun"])
>>> y = tf.constant(["hello this is awesome"])
- >>> metric_dict = model.evaluate(x, y, return_dict=True)
- >>> metric_dict["f1_score"]
+ >>> model.compute_metrics(x, y, y_pred, sample_weight=None)["f1_score"]
0.6666666865348816
"""
diff --git a/keras_nlp/models/albert/albert_backbone_test.py b/keras_nlp/models/albert/albert_backbone_test.py
index 94d0f26297..a6bfdc18e1 100644
--- a/keras_nlp/models/albert/albert_backbone_test.py
+++ b/keras_nlp/models/albert/albert_backbone_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.tests.test_case import TestCase
@@ -38,9 +39,9 @@ def setUp(self):
)
self.batch_size = 8
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "segment_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "segment_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -57,9 +58,9 @@ def test_name(self):
def test_variable_sequence_length_call_albert(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "segment_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "segment_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -121,9 +122,9 @@ def setUp(self):
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "segment_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "segment_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py
index 53a51b0ac5..fe2bd7c7ff 100644
--- a/keras_nlp/models/albert/albert_classifier.py
+++ b/keras_nlp/models/albert/albert_classifier.py
@@ -22,7 +22,6 @@
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -192,7 +191,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py
index f22c921a50..0dd09d32d9 100644
--- a/keras_nlp/models/albert/albert_classifier_test.py
+++ b/keras_nlp/models/albert/albert_classifier_test.py
@@ -21,6 +21,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
@@ -77,15 +78,13 @@ def setUp(self):
activation=keras.activations.softmax,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -99,7 +98,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py
index 71bdd7721d..9843282c5a 100644
--- a/keras_nlp/models/albert/albert_masked_lm.py
+++ b/keras_nlp/models/albert/albert_masked_lm.py
@@ -26,7 +26,6 @@
)
from keras_nlp.models.albert.albert_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -135,7 +134,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
index a06cb9889b..522e7fcdda 100644
--- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
@@ -152,6 +152,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py
index 6ba68245dd..0e8869613d 100644
--- a/keras_nlp/models/albert/albert_masked_lm_test.py
+++ b/keras_nlp/models/albert/albert_masked_lm_test.py
@@ -85,14 +85,12 @@ def setUp(self):
preprocessor=None,
)
- self.raw_batch = tf.constant(
- [
- "quick brown fox",
- "eagle flew over fox",
- "the eagle flew quick",
- "a brown eagle",
- ]
- )
+ self.raw_batch = [
+ "quick brown fox",
+ "eagle flew over fox",
+ "the eagle flew quick",
+ "a brown eagle",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py
index 5234fb0047..14f6581b03 100644
--- a/keras_nlp/models/albert/albert_preprocessor_test.py
+++ b/keras_nlp/models/albert/albert_preprocessor_test.py
@@ -166,6 +166,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/albert/albert_presets_test.py b/keras_nlp/models/albert/albert_presets_test.py
index 535091f0c6..09904246a2 100644
--- a/keras_nlp/models/albert/albert_presets_test.py
+++ b/keras_nlp/models/albert/albert_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
@@ -53,7 +53,7 @@ def test_preprocessor_output(self):
("load_weights", True), ("no_load_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = AlbertClassifier.from_preset(
"albert_base_en_uncased",
num_classes=2,
@@ -67,9 +67,9 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "segment_ids": tf.constant([[0, 0, 0, 0]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "segment_ids": ops.array([[0, 0, 0, 0]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = AlbertClassifier.from_preset(
"albert_base_en_uncased",
@@ -85,9 +85,9 @@ def test_classifier_output_without_preprocessing(self, load_weights):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[2, 13, 1, 3]]),
- "segment_ids": tf.constant([[0, 0, 0, 0]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[2, 13, 1, 3]]),
+ "segment_ids": ops.array([[0, 0, 0, 0]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = AlbertBackbone.from_preset(
"albert_base_en_uncased", load_weights=load_weights
@@ -139,13 +139,11 @@ def test_load_albert(self, load_weights):
preset, load_weights=load_weights
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -159,7 +157,7 @@ def test_load_albert_classifier(self, load_weights):
num_classes=2,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["This quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -174,15 +172,13 @@ def test_load_albert_classifier_without_preprocessing(self, load_weights):
load_weights=load_weights,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/albert/albert_tokenizer_test.py b/keras_nlp/models/albert/albert_tokenizer_test.py
index 6e9decbc0d..010b1a46b9 100644
--- a/keras_nlp/models/albert/albert_tokenizer_test.py
+++ b/keras_nlp/models/albert/albert_tokenizer_test.py
@@ -56,14 +56,14 @@ def test_tokenize(self):
self.assertAllEqual(output, [5, 10, 6, 8])
def test_tokenize_batch(self):
- input_data = tf.constant(["the quick brown fox", "the earth is round"])
+ input_data = ["the quick brown fox", "the earth is round"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[5, 10, 6, 8], [5, 7, 9, 11]])
def test_detokenize(self):
- input_data = tf.constant([[5, 10, 6, 8]])
+ input_data = [[5, 10, 6, 8]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, ["the quick brown fox"])
def test_vocabulary_size(self):
tokenizer = AlbertTokenizer(proto=self.proto)
@@ -91,6 +91,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/bart/bart_backbone_test.py b/keras_nlp/models/bart/bart_backbone_test.py
index 206facf9f6..e0782d4d65 100644
--- a/keras_nlp/models/bart/bart_backbone_test.py
+++ b/keras_nlp/models/bart/bart_backbone_test.py
@@ -17,86 +17,70 @@
import pytest
import tensorflow as tf
-from absl.testing import parameterized
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.tests.test_case import TestCase
class BartBackboneTest(TestCase):
def setUp(self):
- self.model = BartBackbone(
- vocabulary_size=1000,
+ self.backbone = BartBackbone(
+ vocabulary_size=10,
num_layers=2,
num_heads=2,
- hidden_dim=64,
- intermediate_dim=128,
- max_sequence_length=128,
+ hidden_dim=3,
+ intermediate_dim=4,
+ max_sequence_length=5,
)
- self.batch_size = 8
self.input_batch = {
- "encoder_token_ids": tf.ones(
- (self.batch_size, self.model.max_sequence_length), dtype="int32"
- ),
- "encoder_padding_mask": tf.ones(
- (self.batch_size, self.model.max_sequence_length), dtype="int32"
- ),
- "decoder_token_ids": tf.ones(
- (self.batch_size, self.model.max_sequence_length), dtype="int32"
- ),
- "decoder_padding_mask": tf.ones(
- (self.batch_size, self.model.max_sequence_length), dtype="int32"
- ),
+ "encoder_token_ids": ops.ones((2, 5), dtype="int32"),
+ "encoder_padding_mask": ops.ones((2, 5), dtype="int32"),
+ "decoder_token_ids": ops.ones((2, 5), dtype="int32"),
+ "decoder_padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
).batch(2)
- def test_valid_call_bart(self):
- self.model(self.input_batch)
+ def test_valid_call(self):
+ self.backbone(self.input_batch)
+ def test_name(self):
# Check default name passed through
- self.assertRegexpMatches(self.model.name, "bart_backbone")
+ self.assertRegexpMatches(self.backbone.name, "bart_backbone")
- def test_variable_sequence_length_call_bart(self):
- for seq_length in (25, 50, 75):
+ def test_variable_sequence_length_call(self):
+ for seq_length in (2, 3, 4):
input_data = {
- "encoder_token_ids": tf.ones(
- (self.batch_size, seq_length), dtype="int32"
+ "encoder_token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "encoder_padding_mask": ops.ones(
+ (2, seq_length), dtype="int32"
),
- "encoder_padding_mask": tf.ones(
- (self.batch_size, seq_length), dtype="int32"
- ),
- "decoder_token_ids": tf.ones(
- (self.batch_size, seq_length), dtype="int32"
- ),
- "decoder_padding_mask": tf.ones(
- (self.batch_size, seq_length), dtype="int32"
+ "decoder_token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "decoder_padding_mask": ops.ones(
+ (2, seq_length), dtype="int32"
),
}
- self.model(input_data)
-
- @parameterized.named_parameters(
- ("jit_compile_false", False), ("jit_compile_true", True)
- )
- def test_compile(self, jit_compile):
- self.model.compile(jit_compile=jit_compile)
- self.model.predict(self.input_batch)
-
- @parameterized.named_parameters(
- ("jit_compile_false", False), ("jit_compile_true", True)
- )
- def test_compile_batched_ds(self, jit_compile):
- self.model.compile(jit_compile=jit_compile)
- self.model.predict(self.input_dataset)
+ self.backbone(input_data)
+
+ def test_predict(self):
+ self.backbone.predict(self.input_batch)
+ self.backbone.predict(self.input_dataset)
+
+ def test_serialization(self):
+ new_backbone = keras.saving.deserialize_keras_object(
+ keras.saving.serialize_keras_object(self.backbone)
+ )
+ self.assertEqual(new_backbone.get_config(), self.backbone.get_config())
@pytest.mark.large
def test_saved_model(self):
-        model_output = self.model(self.input_batch)
+        model_output = self.backbone(self.input_batch)
path = os.path.join(self.get_temp_dir(), "model.keras")
- self.model.save(path, save_format="keras_v3")
+ self.backbone.save(path, save_format="keras_v3")
restored_model = keras.models.load_model(path)
# Check we got the real object back.
@@ -119,7 +103,7 @@ def test_saved_model(self):
class BartBackboneTPUTest(TestCase):
def setUp(self):
with self.tpu_strategy.scope():
- self.model = BartBackbone(
+ self.backbone = BartBackbone(
vocabulary_size=1000,
num_layers=2,
num_heads=2,
@@ -128,15 +112,15 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "encoder_token_ids": tf.ones((8, 128), dtype="int32"),
- "encoder_padding_mask": tf.ones((8, 128), dtype="int32"),
- "decoder_token_ids": tf.ones((8, 128), dtype="int32"),
- "decoder_padding_mask": tf.ones((8, 128), dtype="int32"),
+ "encoder_token_ids": ops.ones((8, 128), dtype="int32"),
+ "encoder_padding_mask": ops.ones((8, 128), dtype="int32"),
+ "decoder_token_ids": ops.ones((8, 128), dtype="int32"),
+ "decoder_padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
).batch(2)
def test_predict(self):
- self.model.compile()
- self.model.predict(self.input_dataset)
+ self.backbone.compile()
+ self.backbone.predict(self.input_dataset)
diff --git a/keras_nlp/models/bart/bart_preprocessor_test.py b/keras_nlp/models/bart/bart_preprocessor_test.py
index 3849792284..866b2eb537 100644
--- a/keras_nlp/models/bart/bart_preprocessor_test.py
+++ b/keras_nlp/models/bart/bart_preprocessor_test.py
@@ -168,6 +168,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = {
"encoder_text": tf.constant([" airplane at airport"]),
diff --git a/keras_nlp/models/bart/bart_presets_test.py b/keras_nlp/models/bart/bart_presets_test.py
index 8c8b837289..324e840178 100644
--- a/keras_nlp/models/bart/bart_presets_test.py
+++ b/keras_nlp/models/bart/bart_presets_test.py
@@ -14,9 +14,9 @@
 """Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.tests.test_case import TestCase
@@ -58,10 +58,10 @@ def test_tokenizer_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]),
- "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]),
- "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]),
- "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
+ "encoder_token_ids": ops.array([[0, 133, 2119, 2]]),
+ "encoder_padding_mask": ops.array([[1, 1, 1, 1]]),
+ "decoder_token_ids": ops.array([[0, 7199, 14, 2119, 2]]),
+ "decoder_padding_mask": ops.array([[1, 1, 1, 1, 1]]),
}
model = BartBackbone.from_preset(
"bart_base_en", load_weights=load_weights
@@ -116,20 +116,20 @@ def test_load_bart(self, load_weights):
for preset in BartBackbone.presets:
model = BartBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
- "encoder_token_ids": tf.random.uniform(
+ "encoder_token_ids": ops.random.uniform(
shape=(1, 1024),
dtype="int64",
maxval=model.vocabulary_size,
),
- "encoder_padding_mask": tf.constant(
+ "encoder_padding_mask": ops.array(
[1] * 768 + [0] * 256, shape=(1, 1024)
),
- "decoder_token_ids": tf.random.uniform(
+ "decoder_token_ids": ops.random.uniform(
shape=(1, 1024),
dtype="int64",
maxval=model.vocabulary_size,
),
- "decoder_padding_mask": tf.constant(
+ "decoder_padding_mask": ops.array(
[1] * 489 + [0] * 535, shape=(1, 1024)
),
}
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py
index 20119e67fe..50868c4810 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py
@@ -15,20 +15,32 @@
import copy
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_presets import backbone_presets
from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
BartSeq2SeqLMPreprocessor,
)
from keras_nlp.models.generative_task import GenerativeTask
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
+# TODO: Extend and factor this out into keras_nlp.layers.
+class ReverseEmbedding(keras.layers.Layer):
+ def __init__(self, embedding, **kwargs):
+ super().__init__(**kwargs)
+ self.embedding = embedding
+
+ def call(self, inputs):
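+ # Weight tying: project hidden states onto vocabulary logits with the
+ # transposed token-embedding matrix, adding no new parameters.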
+ kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings))
+ return ops.matmul(inputs, kernel)
+
+ def compute_output_shape(self, input_shape):
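+ # Leading dims are preserved; only the feature axis becomes vocab size.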
+ return tuple(input_shape[:-1]) + (self.embedding.embeddings.shape[0],)
+
+
@keras_nlp_export("keras_nlp.models.BartSeq2SeqLM")
class BartSeq2SeqLM(GenerativeTask):
"""An end-to-end BART model for seq2seq language modeling.
@@ -192,11 +204,10 @@ def __init__(
x = backbone(inputs)["decoder_sequence_output"]
# Use token embedding weights to project from the token representation
# to vocabulary logits.
- outputs = tf.matmul(
- x,
- backbone.token_embedding.embeddings,
- transpose_b=True,
- )
+ outputs = ReverseEmbedding(
+ backbone.token_embedding,
+ name="reverse_embedding",
+ )(x)
# Instantiate using Functional API Model constructor.
super().__init__(
@@ -216,7 +227,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
@@ -305,11 +316,11 @@ def call_decoder_with_cache(
# Every decoder layer has a separate cache for the self-attention layer
# and the cross-attention layer. We update all of them separately.
- self_attention_caches = tf.unstack(self_attention_cache, axis=1)
- cross_attention_caches = tf.unstack(cross_attention_cache, axis=1)
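+ # Slice each layer's cache from the stacked tensor instead of
+ # tf.unstack, then re-stack below; this keeps the loop backend-agnostic.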
+ self_attention_caches = []
+ cross_attention_caches = []
for i in range(self.backbone.num_layers):
- current_self_attention_cache = self_attention_caches[i]
- current_cross_attention_cache = cross_attention_caches[i]
+ current_self_attention_cache = self_attention_cache[:, i, ...]
+ current_cross_attention_cache = cross_attention_cache[:, i, ...]
(
x,
@@ -326,22 +337,17 @@ def call_decoder_with_cache(
)
if self_attention_cache_update_index is not None:
- self_attention_caches[i] = next_self_attention_cache
+ self_attention_caches.append(next_self_attention_cache)
if cross_attention_cache_update_index is not None:
- cross_attention_caches[i] = next_cross_attention_cache
+ cross_attention_caches.append(next_cross_attention_cache)
if self_attention_cache_update_index is not None:
- self_attention_cache = tf.stack(self_attention_caches, axis=1)
+ self_attention_cache = ops.stack(self_attention_caches, axis=1)
if cross_attention_cache_update_index is not None:
- cross_attention_cache = tf.stack(cross_attention_caches, axis=1)
+ cross_attention_cache = ops.stack(cross_attention_caches, axis=1)
hidden_states = x
-
- logits = tf.matmul(
- hidden_states,
- self.backbone.get_layer("token_embedding").embeddings,
- transpose_b=True,
- )
+ logits = self.get_layer("reverse_embedding")(x)
return (
logits,
hidden_states,
@@ -375,9 +381,9 @@ def call_encoder(self, token_ids, padding_mask):
def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
"""Initializes empty self-attention cache and cross-attention cache."""
- batch_size = tf.shape(encoder_token_ids)[0]
- encoder_max_length = tf.shape(encoder_token_ids)[1]
- decoder_max_length = tf.shape(decoder_token_ids)[1]
+ batch_size = ops.shape(encoder_token_ids)[0]
+ encoder_max_length = ops.shape(encoder_token_ids)[1]
+ decoder_max_length = ops.shape(decoder_token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
@@ -391,10 +397,10 @@ def _initialize_cache(self, encoder_token_ids, decoder_token_ids):
num_heads,
head_dim,
]
- self_attention_cache = tf.zeros(shape, dtype=self.compute_dtype)
+ self_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
shape[3] = encoder_max_length
- cross_attention_cache = tf.zeros(shape, dtype=self.compute_dtype)
+ cross_attention_cache = ops.zeros(shape, dtype=self.compute_dtype)
return (self_attention_cache, cross_attention_cache)
@@ -464,7 +470,7 @@ def generate_step(
inputs["decoder_padding_mask"],
)
- batch_size = tf.shape(encoder_token_ids)[0]
+ batch_size = ops.shape(encoder_token_ids)[0]
# Create and seed cache with a single forward pass.
(
@@ -476,24 +482,21 @@ def generate_step(
encoder_token_ids, encoder_padding_mask, decoder_token_ids
)
# Compute the lengths of all user-inputted token ids.
- row_lengths = tf.math.reduce_sum(
- tf.cast(decoder_padding_mask, "int32"), axis=-1
- )
+ row_lengths = ops.sum(ops.cast(decoder_padding_mask, "int32"), axis=-1)
# Start at the first index that has no user-inputted id.
- index = tf.math.reduce_min(row_lengths)
+ index = ops.min(row_lengths)
def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_index = index - 1
- prompt = tf.slice(prompt, [0, cache_index], [-1, 1])
-
- num_samples = tf.shape(prompt)[0]
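+ # ops.slice takes explicit sizes (no -1 sentinel), so read the batch
+ # size before slicing out the previous token.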
+ num_samples = ops.shape(prompt)[0]
+ prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
def repeat_tensor(x):
"""Repeats tensors along batch axis to match dim for beam search."""
- if tf.shape(x)[0] == num_samples:
+ if ops.shape(x)[0] == num_samples:
return x
- return tf.repeat(x, repeats=num_samples // batch_size, axis=0)
+ return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
logits, hidden_states, cache, _ = self.call_decoder_with_cache(
encoder_hidden_states=repeat_tensor(encoder_hidden_states),
@@ -505,8 +508,8 @@ def repeat_tensor(x):
cross_attention_cache_update_index=None,
)
return (
- tf.squeeze(logits, axis=1),
- tf.squeeze(hidden_states, axis=1),
+ ops.squeeze(logits, axis=1),
+ ops.squeeze(hidden_states, axis=1),
cache,
)
@@ -527,14 +530,17 @@ def repeat_tensor(x):
end_locations = (decoder_token_ids == end_token_id) & (
~decoder_padding_mask
)
- end_locations = tf.cast(end_locations, "int32")
+ end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after `end_locations`.
- overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
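+ # ops.cumsum has no `exclusive` argument; subtracting `end_locations`
+ # from the inclusive cumsum recovers the exclusive version.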
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+ overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
- decoder_padding_mask = ~tf.cast(overflow, "bool")
+ decoder_padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
- decoder_padding_mask = tf.ones_like(decoder_token_ids, dtype="bool")
+ decoder_padding_mask = ops.ones_like(
+ decoder_token_ids, dtype="bool"
+ )
return {
"decoder_token_ids": decoder_token_ids,
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
index 034a10cd9f..d54a83984c 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor_test.py
@@ -156,6 +156,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = {
"encoder_text": tf.constant([" airplane at airport"]),
@@ -163,8 +164,12 @@ def test_saved_model(self):
}
inputs = {
- "encoder_text": keras.Input(dtype="string", shape=()),
- "decoder_text": keras.Input(dtype="string", shape=()),
+ "encoder_text": keras.Input(
+ dtype="string", name="encoder_text", shape=()
+ ),
+ "decoder_text": keras.Input(
+ dtype="string", name="decoder_text", shape=()
+ ),
}
outputs, y, sw = self.preprocessor(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
index 87f448c2e4..a3b9db4bd0 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
@@ -16,11 +16,11 @@
import os
from unittest.mock import patch
-import numpy as np
import pytest
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM
from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import (
@@ -32,10 +32,6 @@
class BartSeq2SeqLMTest(TestCase):
def setUp(self):
- # For DTensor.
- keras.backend.experimental.enable_tf_random_generator()
- keras.utils.set_random_seed(1337)
-
self.vocab = {
"": 0,
"": 1,
@@ -75,12 +71,8 @@ def setUp(self):
)
self.raw_batch = {
- "encoder_text": tf.constant(
- [" airplane at airport", " airplane at airport"]
- ),
- "decoder_text": tf.constant(
- [" kohli is the best", " kohli is the best"]
- ),
+ "encoder_text": [" airplane at airport", " airplane at airport"],
+ "decoder_text": [" kohli is the best", " kohli is the best"],
}
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
@@ -171,8 +163,10 @@ def wrapper(*args, **kwargs):
self_attention_cache,
cross_attention_cache,
) = call_decoder_with_cache(*args, **kwargs)
- logits = np.zeros(logits.shape.as_list())
- logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
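+ # Overwrite the end-token logits with a huge value via backend-agnostic
+ # slice_update, so sampling immediately emits the end token.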
+ index = self.preprocessor.tokenizer.end_token_id
+ update = ops.ones_like(logits)[:, :, index] * 1.0e9
+ update = ops.expand_dims(update, axis=-1)
+ logits = ops.slice_update(logits, (0, 0, index), update)
return (
logits,
hidden_states,
@@ -194,10 +188,9 @@ def wrapper(*args, **kwargs):
# We should immediately abort and output the prompt.
self.assertAllEqual(inputs["decoder_text"], output)
- self.assertEqual(
- self.seq_2_seq_lm.call_decoder_with_cache.call_count, 2
- )
+ # TODO: fix beam search.
+ @pytest.mark.tf_only
def test_beam_search(self):
seq_2_seq_lm = BartSeq2SeqLM(
backbone=self.backbone,
diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
index f93849b898..3f750dd084 100644
--- a/keras_nlp/models/bart/bart_tokenizer_test.py
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -58,7 +58,7 @@ def test_tokenize_special_tokens(self):
self.assertAllEqual(output, [0, 3, 4, 5, 3, 6, 0, 1])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]])
@@ -74,7 +74,16 @@ def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
BartTokenizer(vocabulary=["a", "b", "c"], merges=[])
+ def test_serialization(self):
+ config = keras.saving.serialize_keras_object(self.tokenizer)
+ new_tokenizer = keras.saving.deserialize_keras_object(config)
+ self.assertEqual(
+ new_tokenizer.get_config(),
+ self.tokenizer.get_config(),
+ )
+
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py
index 746343e657..056086fe28 100644
--- a/keras_nlp/models/bert/bert_backbone.py
+++ b/keras_nlp/models/bert/bert_backbone.py
@@ -168,12 +168,13 @@ def __init__(
# Construct the two BERT outputs. The pooled output is a dense layer on
# top of the [CLS] token.
sequence_output = x
- pooled_output = keras.layers.Dense(
+ x = keras.layers.Dense(
hidden_dim,
kernel_initializer=bert_kernel_initializer(),
activation="tanh",
name="pooled_dense",
- )(x[:, cls_token_index, :])
+ )(x)
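+ # Slicing [CLS] after the dense layer matches slicing before it, since
+ # the dense layer acts on each sequence position independently.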
+ pooled_output = x[:, cls_token_index, :]
# Instantiate using Functional API Model constructor
super().__init__(
diff --git a/keras_nlp/models/bert/bert_backbone_test.py b/keras_nlp/models/bert/bert_backbone_test.py
index 275220af6b..5c26fdc122 100644
--- a/keras_nlp/models/bert/bert_backbone_test.py
+++ b/keras_nlp/models/bert/bert_backbone_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.tests.test_case import TestCase
@@ -34,9 +35,9 @@ def setUp(self):
max_sequence_length=5,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "segment_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "segment_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
@@ -56,9 +57,9 @@ def test_name(self):
def test_variable_sequence_length_call_bert(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "segment_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "segment_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -84,9 +85,7 @@ def test_saved_model(self):
# Check that output matches.
restored_output = restored_model(self.input_batch)
- self.assertAllClose(
- model_output["pooled_output"], restored_output["pooled_output"]
- )
+ self.assertAllClose(model_output, restored_output)
@pytest.mark.tpu
@@ -103,9 +102,9 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "segment_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "segment_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py
index 04b6f8c8d1..3dff53ed8d 100644
--- a/keras_nlp/models/bert/bert_classifier.py
+++ b/keras_nlp/models/bert/bert_classifier.py
@@ -23,7 +23,6 @@
from keras_nlp.models.bert.bert_presets import backbone_presets
from keras_nlp.models.bert.bert_presets import classifier_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -177,7 +176,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/bert/bert_classifier_test.py b/keras_nlp/models/bert/bert_classifier_test.py
index 2ea2836e72..1ad831c3dc 100644
--- a/keras_nlp/models/bert/bert_classifier_test.py
+++ b/keras_nlp/models/bert/bert_classifier_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
@@ -52,15 +53,13 @@ def setUp(self):
)
# Setup data.
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -74,7 +73,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
@@ -84,6 +83,7 @@ def test_classifier_fit(self):
def test_classifier_fit_no_xla(self):
self.classifier.preprocessor = None
self.classifier.compile(
+ optimizer="adam",
loss="sparse_categorical_crossentropy",
jit_compile=False,
)
diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py
index 56f6e8499a..1b1cd4853f 100644
--- a/keras_nlp/models/bert/bert_masked_lm.py
+++ b/keras_nlp/models/bert/bert_masked_lm.py
@@ -25,7 +25,6 @@
)
from keras_nlp.models.bert.bert_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -136,7 +135,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
index c993253387..004c30e495 100644
--- a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
@@ -136,6 +136,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/bert/bert_masked_lm_test.py b/keras_nlp/models/bert/bert_masked_lm_test.py
index 5f1db45408..46c777b0c2 100644
--- a/keras_nlp/models/bert/bert_masked_lm_test.py
+++ b/keras_nlp/models/bert/bert_masked_lm_test.py
@@ -56,34 +56,33 @@ def setUp(self):
)
# Setup data.
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
- def test_valid_call_classifier(self):
+ def test_valid_call(self):
self.masked_lm(self.preprocessed_batch[0])
- def test_classifier_predict(self):
+ def test_predict(self):
self.masked_lm.predict(self.raw_batch)
self.masked_lm.preprocessor = None
self.masked_lm.predict(self.preprocessed_batch[0])
- def test_classifier_fit(self):
+ def test_fit(self):
self.masked_lm.fit(self.raw_dataset)
self.masked_lm.preprocessor = None
self.masked_lm.fit(self.preprocessed_dataset)
- def test_classifier_fit_no_xla(self):
+ def test_fit_no_xla(self):
self.masked_lm.preprocessor = None
self.masked_lm.compile(
+ optimizer="adam",
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
jit_compile=False,
)
diff --git a/keras_nlp/models/bert/bert_preprocessor_test.py b/keras_nlp/models/bert/bert_preprocessor_test.py
index 10c1c22017..28f12a84dd 100644
--- a/keras_nlp/models/bert/bert_preprocessor_test.py
+++ b/keras_nlp/models/bert/bert_preprocessor_test.py
@@ -121,6 +121,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["THE QUICK BROWN FOX."])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/bert/bert_presets_test.py b/keras_nlp/models/bert/bert_presets_test.py
index e9e7f36cac..6728a7ded5 100644
--- a/keras_nlp/models/bert/bert_presets_test.py
+++ b/keras_nlp/models/bert/bert_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
@@ -55,9 +55,9 @@ def test_preprocessor_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "segment_ids": tf.constant([[0, 0, 0, 0]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "segment_ids": ops.array([[0, 0, 0, 0]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = BertBackbone.from_preset(
"bert_tiny_en_uncased", load_weights=load_weights
@@ -78,7 +78,7 @@ def test_backbone_output(self, load_weights):
("load_weights", True), ("no_load_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = BertClassifier.from_preset(
"bert_tiny_en_uncased",
num_classes=2,
@@ -92,9 +92,9 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "segment_ids": tf.constant([[0, 0, 0, 0]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "segment_ids": ops.array([[0, 0, 0, 0]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = BertClassifier.from_preset(
"bert_tiny_en_uncased",
@@ -187,13 +187,11 @@ def test_load_bert(self, load_weights):
for preset in BertBackbone.presets:
model = BertBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -207,7 +205,7 @@ def test_load_bert_classifier(self, load_weights):
num_classes=2,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["This quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -222,15 +220,13 @@ def test_load_bert_classifier_without_preprocessing(self, load_weights):
load_weights=load_weights,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/bert/bert_tokenizer_test.py b/keras_nlp/models/bert/bert_tokenizer_test.py
index 1023626a2e..b56a9f5c19 100644
--- a/keras_nlp/models/bert/bert_tokenizer_test.py
+++ b/keras_nlp/models/bert/bert_tokenizer_test.py
@@ -36,7 +36,7 @@ def test_tokenize(self):
self.assertAllEqual(output, [5, 6, 7, 8, 1])
def test_tokenize_batch(self):
- input_data = tf.constant(["THE QUICK BROWN FOX.", "THE FOX."])
+ input_data = ["THE QUICK BROWN FOX.", "THE FOX."]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[5, 6, 7, 8, 1], [5, 8, 1]])
@@ -67,6 +67,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["THE QUICK BROWN FOX."])
tokenizer = BertTokenizer(vocabulary=self.vocab)
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
index 616b1b6131..1c889fc36c 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -25,6 +25,7 @@
)
from keras_nlp.models.deberta_v3.relative_embedding import RelativeEmbedding
from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.tensor_utils import assert_tf_backend
def deberta_kernel_initializer(stddev=0.02):
@@ -110,6 +111,8 @@ def __init__(
bucket_size=256,
**kwargs,
):
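+ # DeBERTa currently requires the TensorFlow backend; fail fast elsewhere.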
+ assert_tf_backend(self.__class__.__name__)
+
# Inputs
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py
index 9ced19a2c2..2cd49914df 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone_test.py
@@ -19,10 +19,12 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class DebertaV3BackboneTest(TestCase):
def setUp(self):
self.backbone = DebertaV3Backbone(
@@ -36,8 +38,8 @@ def setUp(self):
)
self.batch_size = 8
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -57,8 +59,8 @@ def test_token_embedding(self):
def test_variable_sequence_length_call_deberta(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
output = self.backbone(input_data)
self.assertAllEqual(
@@ -93,6 +95,7 @@ def test_saved_model(self):
@pytest.mark.tpu
@pytest.mark.usefixtures("tpu_test_class")
+@pytest.mark.tf_only
class DebertaV3BackboneTPUTest(TestCase):
def setUp(self):
with self.tpu_strategy.scope():
@@ -106,8 +109,8 @@ def setUp(self):
bucket_size=2,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py
index f91e159bd6..2d15a95ab9 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py
@@ -26,7 +26,6 @@
)
from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -208,7 +207,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py
index 4342a6fd9a..da69036472 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py
@@ -21,6 +21,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
from keras_nlp.models.deberta_v3.deberta_v3_classifier import (
DebertaV3Classifier,
@@ -32,6 +33,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class DebertaV3ClassifierTest(TestCase):
def setUp(self):
bytes_io = io.BytesIO()
@@ -75,15 +77,13 @@ def setUp(self):
hidden_dim=4,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -97,7 +97,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py
index 4d97b77e82..8fe6e0c919 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py
@@ -27,7 +27,6 @@
)
from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -142,7 +141,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py
index b54da09fcb..65fc464027 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor_test.py
@@ -148,6 +148,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py
index 0b62f27548..768fc25661 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_test.py
@@ -30,6 +30,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class DebertaV3MaskedLMTest(TestCase):
def setUp(self):
bytes_io = io.BytesIO()
@@ -70,12 +71,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the eagle flew over fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the eagle flew over fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py
index fbf6a7e09f..0526b4b548 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor_test.py
@@ -145,6 +145,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py b/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py
index 0df222fd0b..334ea8a263 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
from keras_nlp.models.deberta_v3.deberta_v3_classifier import (
DebertaV3Classifier,
@@ -29,6 +29,7 @@
@pytest.mark.large
+@pytest.mark.tf_only
class DebertaV3PresetSmokeTest(TestCase):
"""
A smoke test for DeBERTa presets we run continuously.
@@ -67,8 +68,8 @@ def test_preprocessor_mask_token(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[0, 581, 63773, 2]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[0, 581, 63773, 2]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = DebertaV3Backbone.from_preset(
"deberta_v3_extra_small_en", load_weights=load_weights
@@ -83,7 +84,7 @@ def test_backbone_output(self, load_weights):
("preset_weights", True), ("random_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = DebertaV3Classifier.from_preset(
"deberta_v3_extra_small_en",
num_classes=2,
@@ -97,8 +98,8 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[0, 581, 63773, 2]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[0, 581, 63773, 2]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = DebertaV3Classifier.from_preset(
"deberta_v3_extra_small_en",
@@ -133,6 +134,7 @@ def test_unknown_preset_error(self, cls, kwargs):
@pytest.mark.extra_large
+@pytest.mark.tf_only
class DebertaV3PresetFullTest(TestCase):
"""
Test the full enumeration of our preset.
@@ -151,10 +153,10 @@ def test_load_deberta(self, load_weights):
preset, load_weights=load_weights
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -168,7 +170,7 @@ def test_load_deberta_classifier(self, load_weights):
num_classes=4,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["The quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -183,12 +185,12 @@ def test_load_deberta_classifier_without_preprocessing(self, load_weights):
preprocessor=None,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py
index 0019a28841..88f3936ec2 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py
@@ -55,19 +55,19 @@ def test_tokenize(self):
self.assertAllEqual(output, [4, 9, 5, 7])
def test_tokenize_batch(self):
- input_data = tf.constant(["the quick brown fox", "the earth is round"])
+ input_data = ["the quick brown fox", "the earth is round"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 3]])
def test_detokenize(self):
- input_data = tf.constant([[4, 9, 5, 7]])
+ input_data = [[4, 9, 5, 7]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, ["the quick brown fox"])
def test_detokenize_mask_token(self):
- input_data = tf.constant([[4, 9, 5, 7, self.tokenizer.mask_token_id]])
+ input_data = [[4, 9, 5, 7, self.tokenizer.mask_token_id]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, ["the quick brown fox"])
def test_vocabulary_size(self):
self.assertEqual(self.tokenizer.vocabulary_size(), 11)
@@ -103,6 +103,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py
index 6c3a872131..594e7188ca 100644
--- a/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py
+++ b/keras_nlp/models/deberta_v3/disentangled_attention_encoder.py
@@ -75,11 +75,7 @@ def __init__(
bias_initializer="zeros",
**kwargs
):
- # Work around for model saving
- self._input_shape = kwargs.pop("build_input_shape", None)
-
super().__init__(**kwargs)
-
self.intermediate_dim = intermediate_dim
self.num_heads = num_heads
self.max_position_embeddings = max_position_embeddings
@@ -92,15 +88,9 @@ def __init__(
self._built = False
self.supports_masking = True
- if self._input_shape is not None:
- self._build(self._input_shape)
-
- def _build(self, input_shape):
- # Create layers based on input shape.
- self._built = True
- self._input_shape = input_shape
+ def build(self, inputs_shape):
# Infer the dimension of our hidden feature size from the build shape.
- hidden_dim = input_shape[-1]
+ hidden_dim = inputs_shape[-1]
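+        # Build every sub-layer explicitly below so its weights exist as
+        # soon as this layer is built; this replaces the old deferred
+        # `_build()`-on-first-call workaround used for model saving.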
# Self attention layers.
self._self_attention_layer = DisentangledSelfAttention(
@@ -112,9 +102,11 @@ def _build(self, input_shape):
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
)
+ self._self_attention_layer.build(inputs_shape)
self._self_attention_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
+ self._self_attention_layernorm.build(inputs_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
)
@@ -123,20 +115,26 @@ def _build(self, input_shape):
self._feedforward_layernorm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
+ self._feedforward_layernorm.build(inputs_shape)
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
)
+ self._feedforward_intermediate_dense.build(inputs_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
)
+ intermediate_shape = list(inputs_shape)
+ intermediate_shape[-1] = self.intermediate_dim
+ self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
)
+ self.built = True
def call(
self,
@@ -163,10 +161,6 @@ def call(
Returns:
A Tensor of the same shape as the `inputs`.
"""
-
- if not self._built:
- self._build(inputs.shape)
-
x = inputs
# Compute self attention mask.
@@ -177,7 +171,7 @@ def call(
# Self attention block.
residual = x
x = self._self_attention_layer(
- hidden_states=x,
+ x,
rel_embeddings=rel_embeddings,
attention_mask=self_attention_mask,
)
@@ -212,7 +206,9 @@ def get_config(self):
"bias_initializer": keras.initializers.serialize(
self.bias_initializer
),
- "build_input_shape": self._input_shape,
}
)
return config
+
+ def compute_output_shape(self, inputs_shape):
+ return inputs_shape
diff --git a/keras_nlp/models/deberta_v3/disentangled_self_attention.py b/keras_nlp/models/deberta_v3/disentangled_self_attention.py
index f1b8d51721..5f1ab8a473 100644
--- a/keras_nlp/models/deberta_v3/disentangled_self_attention.py
+++ b/keras_nlp/models/deberta_v3/disentangled_self_attention.py
@@ -82,8 +82,7 @@ def __init__(
float(num_type_attn * self.attn_head_size)
)
- # Layers.
-
+ def build(self, inputs_shape, rel_embeddings_shape=None):
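+        # `rel_embeddings_shape` mirrors the extra argument accepted by
+        # `call()`; only `inputs_shape` is needed to create weights.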
# Q, K, V linear layers.
self._query_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
@@ -92,6 +91,7 @@ def __init__(
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="query",
)
+ self._query_dense.build(inputs_shape)
self._key_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, self.attn_head_size),
@@ -99,6 +99,7 @@ def __init__(
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="key",
)
+ self._key_dense.build(inputs_shape)
self._value_dense = keras.layers.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, self.num_heads, self.attn_head_size),
@@ -106,6 +107,7 @@ def __init__(
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="value",
)
+ self._value_dense.build(inputs_shape)
# Relative attention.
self._position_dropout_layer = keras.layers.Dropout(self.dropout)
@@ -123,6 +125,8 @@ def __init__(
**self._get_common_kwargs_for_sublayer(use_bias=True),
name="attention_output",
)
+ self._output_dense.build(inputs_shape)
+ self.built = True
def _get_common_kwargs_for_sublayer(self, use_bias=True):
common_kwargs = {}
@@ -322,7 +326,7 @@ def _compute_disentangled_attention(
def call(
self,
- hidden_states,
+ inputs,
rel_embeddings,
attention_mask=None,
return_attention_scores=False,
@@ -330,9 +334,9 @@ def call(
):
# `query`, `key`, `value` shape:
# `(batch_size, sequence_length, num_heads, attn_head_size)`.
- query = self._query_dense(hidden_states)
- key = self._key_dense(hidden_states)
- value = self._value_dense(hidden_states)
+ query = self._query_dense(inputs)
+ key = self._key_dense(inputs)
+ value = self._value_dense(inputs)
attention_output, attention_scores = self._compute_attention(
query=query,
diff --git a/keras_nlp/models/deberta_v3/relative_embedding.py b/keras_nlp/models/deberta_v3/relative_embedding.py
index 06bc1f66ad..dddab8f9c8 100644
--- a/keras_nlp/models/deberta_v3/relative_embedding.py
+++ b/keras_nlp/models/deberta_v3/relative_embedding.py
@@ -88,3 +88,6 @@ def get_config(self):
}
)
return config
+
+ def compute_output_shape(self, input_shape):
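+        # The layer emits one embedding per relative position bucket,
+        # independent of the input sequence length.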
+ return (input_shape[0],) + (self.bucket_size * 2, self.hidden_dim)
diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone_test.py b/keras_nlp/models/distil_bert/distil_bert_backbone_test.py
index 3d3e2dcd78..f16c058f71 100644
--- a/keras_nlp/models/distil_bert/distil_bert_backbone_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_backbone_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone
from keras_nlp.tests.test_case import TestCase
@@ -36,8 +37,8 @@ def setUp(self):
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -54,9 +55,9 @@ def test_token_embedding(self):
def test_variable_sequence_length_call_distilbert(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "mask_positions": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "mask_positions": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -99,8 +100,8 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py
index d13f3def43..1ae94cb2fa 100644
--- a/keras_nlp/models/distil_bert/distil_bert_classifier.py
+++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py
@@ -26,7 +26,6 @@
)
from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -193,7 +192,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier_test.py b/keras_nlp/models/distil_bert/distil_bert_classifier_test.py
index 518842ffc0..1d9b2fef0c 100644
--- a/keras_nlp/models/distil_bert/distil_bert_classifier_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_classifier_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone
from keras_nlp.models.distil_bert.distil_bert_classifier import (
DistilBertClassifier,
@@ -59,15 +60,13 @@ def setUp(self):
hidden_dim=4,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -81,7 +80,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py
index 4e7d29258b..1bc7f0dbf1 100644
--- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py
+++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py
@@ -27,7 +27,6 @@
)
from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -140,7 +139,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py
index 16870de81a..192cd6b2b1 100644
--- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py
@@ -113,6 +113,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" THE QUICK BROWN FOX."])
diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py
index 154970f74a..e949a01ab1 100644
--- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py
@@ -55,12 +55,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py
index 9ca119ddee..22127a90fe 100644
--- a/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor_test.py
@@ -111,6 +111,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["THE QUICK BROWN FOX."])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/distil_bert/distil_bert_presets_test.py b/keras_nlp/models/distil_bert/distil_bert_presets_test.py
index 71e56911c0..fcef57e0c0 100644
--- a/keras_nlp/models/distil_bert/distil_bert_presets_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone
from keras_nlp.models.distil_bert.distil_bert_classifier import (
DistilBertClassifier,
@@ -61,8 +61,8 @@ def test_preprocessor_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = DistilBertBackbone.from_preset(
"distil_bert_base_en_uncased", load_weights=load_weights
@@ -76,7 +76,7 @@ def test_backbone_output(self, load_weights):
("preset_weights", True), ("random_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = DistilBertClassifier.from_preset(
"distil_bert_base_en_uncased",
num_classes=2,
@@ -89,8 +89,8 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = DistilBertClassifier.from_preset(
"distil_bert_base_en_uncased",
@@ -142,10 +142,10 @@ def test_load_distilbert(self, load_weights):
preset, load_weights=load_weights
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -159,7 +159,7 @@ def test_load_distilbert_classifier(self, load_weights):
num_classes=2,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["This quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -174,12 +174,12 @@ def test_load_distilbert_classifier_no_preprocessing(self, load_weights):
preprocessor=None,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py
index 11007ecc21..7e3481f6f3 100644
--- a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py
+++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py
@@ -38,7 +38,7 @@ def test_tokenize(self):
self.assertAllEqual(output, [5, 6, 7, 8, 1])
def test_tokenize_batch(self):
- input_data = tf.constant(["THE QUICK BROWN FOX.", "THE FOX."])
+ input_data = ["THE QUICK BROWN FOX.", "THE FOX."]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[5, 6, 7, 8, 1], [5, 8, 1]])
@@ -69,6 +69,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["THE QUICK BROWN FOX."])
tokenizer = DistilBertTokenizer(vocabulary=self.vocab)
diff --git a/keras_nlp/models/f_net/f_net_backbone_test.py b/keras_nlp/models/f_net/f_net_backbone_test.py
index 57930fe2c6..c65b845755 100644
--- a/keras_nlp/models/f_net/f_net_backbone_test.py
+++ b/keras_nlp/models/f_net/f_net_backbone_test.py
@@ -19,10 +19,12 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class FNetBackboneTest(TestCase):
def setUp(self):
self.backbone = FNetBackbone(
@@ -34,8 +36,8 @@ def setUp(self):
num_segments=4,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "segment_ids": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "segment_ids": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -51,8 +53,8 @@ def test_valid_call_f_net(self):
def test_variable_sequence_length_call_f_net(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "segment_ids": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "segment_ids": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -97,8 +99,8 @@ def setUp(self):
num_segments=4,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "segment_ids": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "segment_ids": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py
index 9dab025296..ca15a52392 100644
--- a/keras_nlp/models/f_net/f_net_classifier.py
+++ b/keras_nlp/models/f_net/f_net_classifier.py
@@ -23,7 +23,6 @@
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -144,7 +143,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/f_net/f_net_classifier_test.py b/keras_nlp/models/f_net/f_net_classifier_test.py
index 7131067481..9c8978b6f1 100644
--- a/keras_nlp/models/f_net/f_net_classifier_test.py
+++ b/keras_nlp/models/f_net/f_net_classifier_test.py
@@ -21,6 +21,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
@@ -28,6 +29,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class FNetClassifierTest(TestCase):
def setUp(self):
# Setup Model
@@ -74,15 +76,13 @@ def setUp(self):
)
# Setup data.
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -96,7 +96,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_fnet_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py
index 8c6c713b04..f6ecd549d8 100644
--- a/keras_nlp/models/f_net/f_net_masked_lm.py
+++ b/keras_nlp/models/f_net/f_net_masked_lm.py
@@ -23,7 +23,6 @@
)
from keras_nlp.models.f_net.f_net_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -134,7 +133,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py
index 2b2e0ca952..81ee99a4a4 100644
--- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor_test.py
@@ -130,6 +130,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/f_net/f_net_masked_lm_test.py b/keras_nlp/models/f_net/f_net_masked_lm_test.py
index 23d67ddde7..4b102cd59a 100644
--- a/keras_nlp/models/f_net/f_net_masked_lm_test.py
+++ b/keras_nlp/models/f_net/f_net_masked_lm_test.py
@@ -28,6 +28,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class FNetMaskedLMTest(TestCase):
def setUp(self):
# Setup Model.
@@ -68,12 +69,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox",
- "the slow brown fox",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox",
+ "the slow brown fox",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/f_net/f_net_preprocessor_test.py b/keras_nlp/models/f_net/f_net_preprocessor_test.py
index 3230998298..8034c09e6c 100644
--- a/keras_nlp/models/f_net/f_net_preprocessor_test.py
+++ b/keras_nlp/models/f_net/f_net_preprocessor_test.py
@@ -148,6 +148,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/f_net/f_net_presets_test.py b/keras_nlp/models/f_net/f_net_presets_test.py
index 4c30e0b0b2..37d0f21a47 100644
--- a/keras_nlp/models/f_net/f_net_presets_test.py
+++ b/keras_nlp/models/f_net/f_net_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_classifier import FNetClassifier
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
@@ -25,6 +25,7 @@
@pytest.mark.large
+@pytest.mark.tf_only
class FNetPresetSmokeTest(TestCase):
"""
A smoke test for FNet presets we run continuously.
@@ -55,9 +56,9 @@ def test_preprocessor_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "segment_ids": tf.constant([[0, 0, 0, 0]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "segment_ids": ops.array([[0, 0, 0, 0]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = FNetBackbone.from_preset(
"f_net_base_en", load_weights=load_weights
@@ -78,7 +79,7 @@ def test_backbone_output(self, load_weights):
("load_weights", True), ("no_load_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = FNetClassifier.from_preset(
"f_net_base_en",
num_classes=2,
@@ -127,12 +128,10 @@ def test_load_f_net(self, load_weights):
for preset in FNetBackbone.presets:
model = FNetBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
}
model(input_data)
@@ -146,7 +145,7 @@ def test_load_fnet_classifier(self, load_weights):
num_classes=2,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["The quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -161,15 +160,13 @@ def test_load_fnet_classifier_without_preprocessing(self, load_weights):
load_weights=load_weights,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "segment_ids": tf.constant(
- [0] * 200 + [1] * 312, shape=(1, 512)
- ),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "segment_ids": ops.array([0] * 200 + [1] * 312, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/f_net/f_net_tokenizer_test.py b/keras_nlp/models/f_net/f_net_tokenizer_test.py
index ea0bd3e7e7..380eb920de 100644
--- a/keras_nlp/models/f_net/f_net_tokenizer_test.py
+++ b/keras_nlp/models/f_net/f_net_tokenizer_test.py
@@ -25,6 +25,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class FNetTokenizerTest(TestCase):
def setUp(self):
bytes_io = io.BytesIO()
@@ -56,14 +57,14 @@ def test_tokenize(self):
self.assertAllEqual(output, [2, 10, 6, 8])
def test_tokenize_batch(self):
- input_data = tf.constant(["the quick brown fox", "the earth is round"])
+ input_data = ["the quick brown fox", "the earth is round"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[2, 10, 6, 8], [2, 7, 9, 11]])
def test_detokenize(self):
- input_data = tf.constant([[2, 10, 6, 8]])
+ input_data = [[2, 10, 6, 8]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, ["the quick brown fox"])
def test_vocabulary_size(self):
tokenizer = FNetTokenizer(proto=self.proto)
@@ -91,6 +92,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py
index 76771a4f3e..e479bab196 100644
--- a/keras_nlp/models/generative_task.py
+++ b/keras_nlp/models/generative_task.py
@@ -13,9 +13,13 @@
# limitations under the License.
"""Base class for Generative Task models."""
+import itertools
+
import tensorflow as tf
+from keras_nlp.backend import config
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.task import Task
from keras_nlp.samplers.serialization import get as get_sampler
from keras_nlp.utils.tensor_utils import tensor_to_list
@@ -54,15 +58,80 @@ def make_generate_function(self):
if self.generate_function is not None:
return self.generate_function
- if self.run_eagerly:
- self.generate_function = self.generate_step
- else:
+ self.generate_function = self.generate_step
+ if config.backend() == "torch":
+ import torch
+
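+        # There is no compiled path on torch here; wrapping generation in
+        # `torch.no_grad()` is enough, since generation never needs grads.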
+ def wrapped_generate_function(
+ inputs,
+ end_token_id=None,
+ ):
+ with torch.no_grad():
+ return self.generate_step(inputs, end_token_id)
+
+ self.generate_function = wrapped_generate_function
+ elif config.backend() == "tensorflow" and not self.run_eagerly:
# `jit_compile` is a property of keras.Model after TF 2.12.
# Use `getattr()` for backwards compatibility.
jit_compile = getattr(self, "jit_compile", True)
self.generate_function = tf.function(
self.generate_step, jit_compile=jit_compile
)
+ elif config.backend() == "jax" and not self.run_eagerly:
+ import jax
+
+ @jax.jit
+ def compiled_generate_function(inputs, end_token_id, state):
+ (
+ sampler_variables,
+ trainable_variables,
+ non_trainable_variables,
+ ) = state
+ mapping = itertools.chain(
+ zip(self._sampler.variables, sampler_variables),
+ zip(self.trainable_variables, trainable_variables),
+ zip(self.non_trainable_variables, non_trainable_variables),
+ )
+
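+            # Inside this scope, reads of Keras variables return the values
+            # passed in `state`, and writes are recorded on the scope rather
+            # than mutating the variables, keeping the function pure so that
+            # `jax.jit` can trace it.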
+ with keras.StatelessScope(state_mapping=mapping) as scope:
+ outputs = self.generate_step(inputs, end_token_id)
+
+ # Get updated sampler variables from the stateless scope.
+ sampler_variables = []
+ for v in self._sampler.variables:
+ new_v = scope.get_current_value(v)
+ sampler_variables.append(new_v if new_v is not None else v)
+ state = (
+ sampler_variables,
+ trainable_variables,
+ non_trainable_variables,
+ )
+ return outputs, state
+
+ def wrapped_generate_function(
+ inputs,
+ end_token_id=None,
+ ):
+ # Create an explicit tuple of all variable state.
+ state = (
+ self._sampler.variables,
+ self.trainable_variables,
+ self.non_trainable_variables,
+ )
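+            # Convert inputs to backend tensors up front so the jitted
+            # function always sees a consistent input structure.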
+ inputs = tf.nest.map_structure(ops.convert_to_tensor, inputs)
+ outputs, state = compiled_generate_function(
+ inputs,
+ end_token_id,
+ state,
+ )
+ # Only assign the sampler variables (random seeds), as other
+ # model variables should never be updated in generation.
+ for ref_v, v in zip(self._sampler.variables, state[0]):
+ ref_v.assign(v)
+ return outputs
+
+ self.generate_function = wrapped_generate_function
+
return self.generate_function
def _normalize_generate_inputs(
diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index 44826b2b3d..40aefda88b 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -18,6 +18,7 @@
from tensorflow.experimental import dtensor
from tensorflow.experimental.dtensor import Layout
+from tensorflow.keras.dtensor.experimental import LayoutMap
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
@@ -239,7 +240,7 @@ def create_layout_map(cls, mesh):
_, model_dim = mesh.dim_names
unshard_dim = dtensor.UNSHARDED
- layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+ layout_map = LayoutMap(mesh=mesh)
# Embedding sharding
layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
diff --git a/keras_nlp/models/gpt2/gpt2_backbone_test.py b/keras_nlp/models/gpt2/gpt2_backbone_test.py
index 3757a79f2e..d9134be51b 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone_test.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone_test.py
@@ -19,16 +19,13 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.tests.test_case import TestCase
class GPT2Test(TestCase):
def setUp(self):
- # For DTensor.
- keras.backend.experimental.enable_tf_random_generator()
- keras.utils.set_random_seed(1337)
-
self.backbone = GPT2Backbone(
vocabulary_size=10,
num_layers=2,
@@ -38,9 +35,9 @@ def setUp(self):
max_sequence_length=5,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "segment_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "segment_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
@@ -60,8 +57,8 @@ def test_name(self):
def test_variable_sequence_length(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -121,8 +118,8 @@ def setUp(self):
max_sequence_length=5,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py
index b2417f98d2..2bc67088e7 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -15,20 +15,32 @@
import copy
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
+# TODO: Extend and factor this out into keras_nlp.layers.
+class ReverseEmbedding(keras.layers.Layer):
+ def __init__(self, embedding, **kwargs):
+ super().__init__(**kwargs)
+ self.embedding = embedding
+
+ def call(self, inputs):
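+        # Tied input/output embeddings: project hidden states back onto
+        # the vocabulary with the transposed token embedding matrix.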
+ kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings))
+ return ops.matmul(inputs, kernel)
+
+ def compute_output_shape(self, input_shape):
+        return input_shape[:-1] + (self.embedding.embeddings.shape[0],)
+
+
@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
class GPT2CausalLM(GenerativeTask):
"""An end-to-end GPT2 model for causal langauge modeling.
@@ -162,11 +174,10 @@ def __init__(
x = backbone(inputs)
# Use token embedding weights to project from the token representation
# to vocabulary logits.
- outputs = tf.matmul(
- x,
- backbone.token_embedding.embeddings,
- transpose_b=True,
- )
+ outputs = ReverseEmbedding(
+ backbone.token_embedding,
+ name="reverse_embedding",
+ )(x)
# Instantiate using Functional API Model constructor.
super().__init__(
@@ -185,7 +196,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
@@ -234,33 +245,30 @@ def call_with_cache(
)
x = self.backbone.get_layer("embeddings_dropout")(x)
# Each decoder layer has a cache; we update them separately.
- caches = tf.unstack(cache, axis=1)
+ caches = []
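+        # Slice each layer's cache out of the stacked tensor instead of
+        # unstacking it, keeping this loop backend-agnostic.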
for i in range(self.backbone.num_layers):
- current_cache = caches[i]
+ current_cache = cache[:, i, ...]
x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")(
x,
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
- caches[i] = next_cache
- cache = tf.stack(caches, axis=1)
+ caches.append(next_cache)
+ cache = ops.stack(caches, axis=1)
x = self.backbone.get_layer("layer_norm")(x)
hidden_states = x
- logits = tf.matmul(
- hidden_states,
- self.backbone.get_layer("token_embedding").embeddings,
- transpose_b=True,
- )
+ logits = self.get_layer("reverse_embedding")(x)
return logits, hidden_states, cache
def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
- batch_size, max_length = tf.shape(token_ids)[0], tf.shape(token_ids)[1]
+ batch_size = ops.shape(token_ids)[0]
+ max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
- cache = tf.zeros(shape, dtype=self.compute_dtype)
+ cache = ops.zeros(shape, dtype=self.compute_dtype)
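+        # Cache layout: [batch, layer, (key, value), seq, heads, head_dim].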
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache
@@ -287,24 +295,23 @@ def generate_step(
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted token ids.
- row_lengths = tf.math.reduce_sum(
- tf.cast(padding_mask, "int32"), axis=-1
- )
+ row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
- index = tf.math.reduce_min(row_lengths)
+ index = ops.min(row_lengths)
def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
- prompt = tf.slice(prompt, [0, cache_update_index], [-1, 1])
+ batch_size = ops.shape(prompt)[0]
+ prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
- tf.squeeze(logits, axis=1),
- tf.squeeze(hidden_states, axis=1),
+ ops.squeeze(logits, axis=1),
+ ops.squeeze(hidden_states, axis=1),
cache,
)
@@ -323,14 +330,15 @@ def next(prompt, cache, index):
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = (token_ids == end_token_id) & (~padding_mask)
- end_locations = tf.cast(end_locations, "int32")
+ end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
- overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+ overflow = cumsum - end_locations
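+            # Subtracting `end_locations` turns the inclusive cumsum into
+            # the exclusive cumsum previously computed with
+            # `tf.math.cumsum(..., exclusive=True)`, which `ops.cumsum`
+            # does not support directly.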
# Our padding mask is the inverse of these overflow locations.
- padding_mask = ~tf.cast(overflow, "bool")
+ padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
- padding_mask = tf.ones_like(token_ids, dtype="bool")
+ padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
@@ -353,7 +361,7 @@ def create_layout_map(cls, mesh):
distribution, and the second for model parallel distribution.
Returns:
- A `tf.keras.dtensor.experimental.LayoutMap` which contains the
+ A `keras.dtensor.experimental.LayoutMap` which contains the
proper layout to weights mapping for the model parallel setting.
Examples:
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
index 8441b90430..b0a0a8adc6 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
@@ -129,6 +129,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["airplane at airport"])
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
index 836b40b0be..e52274da7f 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -16,11 +16,11 @@
import os
from unittest.mock import patch
-import numpy as np
import pytest
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
@@ -32,10 +32,6 @@
class GPT2CausalLMTest(TestCase):
def setUp(self):
- # For DTensor.
- keras.backend.experimental.enable_tf_random_generator()
- keras.utils.set_random_seed(1337)
-
self.vocab = {
"!": 0,
"air": 1,
@@ -65,12 +61,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- [
- " airplane at airport",
- " airplane at airport",
- ]
- )
+ self.raw_batch = [
+ " airplane at airport",
+ " airplane at airport",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
@@ -126,8 +120,10 @@ def test_early_stopping(self):
def wrapper(*args, **kwargs):
"""Modify output logits to always favor end_token_id"""
logits, hidden_states, cache = call_with_cache(*args, **kwargs)
- logits = np.zeros(logits.shape.as_list())
- logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
+ index = self.preprocessor.tokenizer.end_token_id
+ update = ops.ones_like(logits)[:, :, index] * 1.0e9
+ update = ops.expand_dims(update, axis=-1)
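+            # `ops.slice_update` writes the boosted scores back at
+            # `[:, :, index]`, replacing the old in-place numpy assignment.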
+ logits = ops.slice_update(logits, (0, 0, index), update)
return logits, hidden_states, cache
with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper):
@@ -135,7 +131,6 @@ def wrapper(*args, **kwargs):
output = self.causal_lm.generate(prompt)
# We should immediately abort and output the prompt.
self.assertEqual(prompt, output)
- self.assertEqual(self.causal_lm.call_with_cache.call_count, 2)
def test_generate_compilation(self):
# Assert we do not recompile with successive calls.
diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
index b535dd65bb..d7b921ee14 100644
--- a/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
+++ b/keras_nlp/models/gpt2/gpt2_preprocessor_test.py
@@ -110,6 +110,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["airplane at airport"])
diff --git a/keras_nlp/models/gpt2/gpt2_presets_test.py b/keras_nlp/models/gpt2/gpt2_presets_test.py
index 396d5b995c..e9cdd707d7 100644
--- a/keras_nlp/models/gpt2/gpt2_presets_test.py
+++ b/keras_nlp/models/gpt2/gpt2_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.tests.test_case import TestCase
@@ -42,8 +42,8 @@ def test_tokenizer_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[1169, 2068, 7586, 21831, 13]]),
- "padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
+ "token_ids": ops.array([[1169, 2068, 7586, 21831, 13]]),
+ "padding_mask": ops.array([[1, 1, 1, 1, 1]]),
}
model = GPT2Backbone.from_preset(
"gpt2_base_en", load_weights=load_weights
@@ -95,12 +95,12 @@ def test_load_gpt2(self, load_weights):
for preset in GPT2Backbone.presets:
model = GPT2Backbone.from_preset(preset, load_weights=load_weights)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 1024),
dtype="int64",
maxval=model.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 1024, shape=(1, 1024)),
+ "padding_mask": ops.array([1] * 1024, shape=(1, 1024)),
}
model(input_data)
diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
index e4e5baee86..51347a30da 100644
--- a/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
+++ b/keras_nlp/models/gpt2/gpt2_tokenizer_test.py
@@ -74,7 +74,7 @@ def test_tokenize_end_token(self):
self.assertAllEqual(output, [1, 2, 3, 1, 4, 0])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[1, 2, 3, 1, 4], [5, 6, 7, 8, 9]])
@@ -99,6 +99,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
index 7c82eaf469..3a3ce2f3e1 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder
+from keras_nlp.utils.tensor_utils import assert_tf_backend
def _gpt_neo_x_kernel_initializer(stddev=0.02):
@@ -76,6 +75,8 @@ def __init__(
max_sequence_length=512,
**kwargs,
):
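+        # Restrict GPTNeoX to the TensorFlow backend for now; raise a
+        # clear error when another backend is active.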
+ assert_tf_backend(self.__class__.__name__)
+
# Inputs
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
@@ -116,7 +117,7 @@ def __init__(
name="layer_norm",
axis=-1,
epsilon=layer_norm_epsilon,
- dtype=tf.float32,
+ dtype="float32",
)(x)
# Instantiate using Functional API Model constructor
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py
index d8fe095d99..9e6a2973b5 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone_test.py
@@ -19,10 +19,12 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models import GPTNeoXBackbone
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class GPTNeoXTest(TestCase):
def setUp(self):
self.backbone = GPTNeoXBackbone(
@@ -34,8 +36,8 @@ def setUp(self):
max_sequence_length=10,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
@@ -55,8 +57,8 @@ def test_name(self):
def test_variable_sequence_length(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -99,8 +101,8 @@ def setUp(self):
max_sequence_length=10,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
index c60f9ee0ef..e8f8396d4a 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
@@ -129,6 +129,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["airplane at airport"])
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
index 16fdb028ff..b20fb8157d 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
@@ -15,8 +15,6 @@
"""GPTNeoX preprocessor layer."""
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
-
-# from keras_nlp.models.gpt_neo_x.gpt_neo_x_presets import backbone_presets
from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import (
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py
index 5cf3f485a9..463bfb7881 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor_test.py
@@ -113,6 +113,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["airplane at airport"])
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
index f227953099..4e2ca8961b 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
@@ -74,7 +74,7 @@ def test_tokenize_end_token(self):
self.assertAllEqual(output, [1, 2, 3, 1, 4, 0])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[1, 2, 3, 1, 4], [5, 6, 7, 8, 9]])
@@ -99,6 +99,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py
index ff22c9cf40..fb7bb0bd70 100644
--- a/keras_nlp/models/opt/opt_backbone.py
+++ b/keras_nlp/models/opt/opt_backbone.py
@@ -18,6 +18,7 @@
from tensorflow.experimental import dtensor
from tensorflow.experimental.dtensor import Layout
+from tensorflow.keras.dtensor.experimental import LayoutMap
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
@@ -217,7 +218,7 @@ def create_layout_map(cls, mesh):
_, model_dim = mesh.dim_names
unshard_dim = dtensor.UNSHARDED
- layout_map = keras.dtensor.experimental.LayoutMap(mesh=mesh)
+ layout_map = LayoutMap(mesh=mesh)
# Embedding sharding
layout_map[r".*embeddings"] = Layout([unshard_dim, model_dim], mesh)
diff --git a/keras_nlp/models/opt/opt_backbone_test.py b/keras_nlp/models/opt/opt_backbone_test.py
index df53f493ac..6746075795 100644
--- a/keras_nlp/models/opt/opt_backbone_test.py
+++ b/keras_nlp/models/opt/opt_backbone_test.py
@@ -19,16 +19,13 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.tests.test_case import TestCase
-class OPTTest(TestCase):
+class OPTBackboneTest(TestCase):
def setUp(self):
- # For DTensor.
- keras.backend.experimental.enable_tf_random_generator()
- keras.utils.set_random_seed(1337)
-
self.backbone = OPTBackbone(
vocabulary_size=10,
num_layers=2,
@@ -38,8 +35,8 @@ def setUp(self):
max_sequence_length=5,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -60,8 +57,8 @@ def test_name(self):
def test_variable_sequence_length_call_opt(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
self.backbone(input_data)
@@ -121,8 +118,8 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py
index 9395f8feb6..0bdedf1ed9 100644
--- a/keras_nlp/models/opt/opt_causal_lm.py
+++ b/keras_nlp/models/opt/opt_causal_lm.py
@@ -15,20 +15,32 @@
import copy
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.generative_task import GenerativeTask
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
OPTCausalLMPreprocessor,
)
from keras_nlp.models.opt.opt_presets import backbone_presets
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
+# TODO: Extend and factor this out into keras_nlp.layers.
+class ReverseEmbedding(keras.layers.Layer):
+ def __init__(self, embedding, **kwargs):
+ super().__init__(**kwargs)
+ self.embedding = embedding
+
+ def call(self, inputs):
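+        # Reuse the transposed token embedding matrix to map hidden states
+        # back to vocabulary logits (tied input/output embeddings).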
+ kernel = ops.transpose(ops.convert_to_tensor(self.embedding.embeddings))
+ return ops.matmul(inputs, kernel)
+
+ def compute_output_shape(self, input_shape):
+        return input_shape[:-1] + (self.embedding.embeddings.shape[0],)
+
+
@keras_nlp_export("keras_nlp.models.OPTCausalLM")
class OPTCausalLM(GenerativeTask):
"""An end-to-end OPT model for causal langauge modeling.
@@ -162,11 +174,10 @@ def __init__(
x = backbone(inputs)
# Use token embedding weights to project from the token representation
# to vocabulary logits.
- outputs = tf.matmul(
- x,
- backbone.token_embedding.embeddings,
- transpose_b=True,
- )
+ outputs = ReverseEmbedding(
+ backbone.token_embedding,
+ name="reverse_embedding",
+ )(x)
# Instantiate using Functional API Model constructor.
super().__init__(
@@ -185,7 +196,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
@@ -229,33 +240,30 @@ def call_with_cache(
token_ids, start_index=cache_update_index
)
# Each decoder layer has a cache; we update them separately.
- caches = tf.unstack(cache, axis=1)
+ caches = []
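+        # Slice out each layer's cache rather than unstacking the full
+        # tensor, which keeps this loop backend-agnostic.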
for i in range(self.backbone.num_layers):
- current_cache = caches[i]
+ current_cache = cache[:, i, ...]
x, next_cache = self.backbone.get_layer(f"transformer_layer_{i}")(
x,
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
- caches[i] = next_cache
- cache = tf.stack(caches, axis=1)
+ caches.append(next_cache)
+ cache = ops.stack(caches, axis=1)
x = self.backbone.get_layer("layer_norm")(x)
hidden_states = x
- logits = tf.matmul(
- hidden_states,
- self.backbone.token_embedding.embeddings,
- transpose_b=True,
- )
+ logits = self.get_layer("reverse_embedding")(x)
return logits, hidden_states, cache
def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
- batch_size, max_length = tf.shape(token_ids)[0], tf.shape(token_ids)[1]
+ batch_size = ops.shape(token_ids)[0]
+ max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
- cache = tf.zeros(shape, dtype=self.compute_dtype)
+ cache = ops.zeros(shape, dtype=self.compute_dtype)
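+        # Cache layout: [batch, layer, (key, value), seq, heads, head_dim].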
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache
@@ -282,24 +290,23 @@ def generate_step(
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted token ids.
- row_lengths = tf.math.reduce_sum(
- tf.cast(padding_mask, "int32"), axis=-1
- )
+ row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
- index = tf.math.reduce_min(row_lengths)
+ index = ops.min(row_lengths)
def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
- prompt = tf.slice(prompt, [0, cache_update_index], [-1, 1])
+ batch_size = ops.shape(prompt)[0]
+ prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
- tf.squeeze(logits, axis=1),
- tf.squeeze(hidden_states, axis=1),
+ ops.squeeze(logits, axis=1),
+ ops.squeeze(hidden_states, axis=1),
cache,
)
@@ -318,14 +325,15 @@ def next(prompt, cache, index):
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = (token_ids == end_token_id) & (~padding_mask)
- end_locations = tf.cast(end_locations, "int32")
+ end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
- overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+ overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
- padding_mask = ~tf.cast(overflow, "bool")
+ padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
- padding_mask = tf.ones_like(token_ids, dtype="bool")
+ padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
index f6fd07e1e7..c64ebb343a 100644
--- a/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor_test.py
@@ -131,6 +131,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/opt/opt_causal_lm_test.py b/keras_nlp/models/opt/opt_causal_lm_test.py
index 4b523eb761..b4bff2c7d3 100644
--- a/keras_nlp/models/opt/opt_causal_lm_test.py
+++ b/keras_nlp/models/opt/opt_causal_lm_test.py
@@ -16,11 +16,11 @@
import os
from unittest.mock import patch
-import numpy as np
import pytest
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_causal_lm import OPTCausalLM
from keras_nlp.models.opt.opt_causal_lm_preprocessor import (
@@ -32,10 +32,6 @@
class OPTCausalLMTest(TestCase):
def setUp(self):
- # For DTensor.
- keras.backend.experimental.enable_tf_random_generator()
- keras.utils.set_random_seed(1337)
-
self.vocab = {
"": 0,
"": 1,
@@ -71,12 +67,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- [
- " airplane at airport",
- " airplane at airport",
- ]
- )
+ self.raw_batch = [
+ " airplane at airport",
+ " airplane at airport",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
@@ -132,8 +126,10 @@ def test_early_stopping(self):
def wrapper(*args, **kwargs):
"""Modify output logits to always favor end_token_id"""
logits, hidden_states, cache = call_with_cache(*args, **kwargs)
- logits = np.zeros(logits.shape.as_list())
- logits[:, :, self.preprocessor.tokenizer.end_token_id] = 1.0e9
+ index = self.preprocessor.tokenizer.end_token_id
+ update = ops.ones_like(logits)[:, :, index] * 1.0e9
+ update = ops.expand_dims(update, axis=-1)
+ logits = ops.slice_update(logits, (0, 0, index), update)
return logits, hidden_states, cache
with patch.object(self.causal_lm, "call_with_cache", wraps=wrapper):
@@ -141,7 +137,6 @@ def wrapper(*args, **kwargs):
output = self.causal_lm.generate(prompt)
# We should immediately abort and output the prompt.
self.assertEqual(prompt, output)
- self.assertEqual(self.causal_lm.call_with_cache.call_count, 2)
def test_generate_compilation(self):
# Assert we do not recompile with successive calls.
@@ -167,7 +162,7 @@ def test_saved_model(self):
keras.utils.set_random_seed(42)
model_output = self.causal_lm.predict(self.raw_batch)
path = os.path.join(self.get_temp_dir(), "model.keras")
- self.seq_2_seq_lm.save(path, save_format="keras_v3")
+ self.causal_lm.save(path, save_format="keras_v3")
restored_model = keras.models.load_model(path)
# Check we got the real object back.
diff --git a/keras_nlp/models/opt/opt_preprocessor_test.py b/keras_nlp/models/opt/opt_preprocessor_test.py
index aaac994942..a06243ece9 100644
--- a/keras_nlp/models/opt/opt_preprocessor_test.py
+++ b/keras_nlp/models/opt/opt_preprocessor_test.py
@@ -112,6 +112,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/opt/opt_presets_test.py b/keras_nlp/models/opt/opt_presets_test.py
index ec95e5ecc6..744d9d11ef 100644
--- a/keras_nlp/models/opt/opt_presets_test.py
+++ b/keras_nlp/models/opt/opt_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.opt.opt_backbone import OPTBackbone
from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.tests.test_case import TestCase
@@ -42,8 +42,8 @@ def test_tokenizer_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[133, 2119, 6219, 23602, 4]]),
- "padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
+ "token_ids": ops.array([[133, 2119, 6219, 23602, 4]]),
+ "padding_mask": ops.array([[1, 1, 1, 1, 1]]),
}
model = OPTBackbone.from_preset(
"opt_125m_en", load_weights=load_weights
@@ -95,12 +95,12 @@ def test_load_opt(self, load_weights):
for preset in OPTBackbone.presets:
model = OPTBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 1024),
dtype="int64",
maxval=model.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 1024, shape=(1, 1024)),
+ "padding_mask": ops.array([1] * 1024, shape=(1, 1024)),
}
model(input_data)
diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py
index 3be91276b0..e0a5d70a00 100644
--- a/keras_nlp/models/opt/opt_tokenizer.py
+++ b/keras_nlp/models/opt/opt_tokenizer.py
@@ -48,39 +48,26 @@ class OPTTokenizer(BytePairTokenizer):
merge entities separated by a space.
Examples:
-
- Batched inputs.
- >>> vocab = {"": 1, "": 2, "a": 3, "Ä quick": 4, "Ä fox": 5}
- >>> merges = ["Ä q", "u i", "c k", "ui ck", "Ä q uick"]
- >>> merges += ["Ä f", "o x", "Ä f ox"]
- >>> tokenizer = keras_nlp.models.OPTTokenizer(
- ... vocabulary=vocab,
- ... merges=merges,
- ... )
- >>> tokenizer(["a quick fox", "a fox quick"])
-
-
- Unbatched input.
- >>> vocab = {"": 1, "": 2, "a": 3, "Ä quick": 4, "Ä fox": 5}
- >>> merges = ["Ä q", "u i", "c k", "ui ck", "Ä q uick"]
- >>> merges += ["Ä f", "o x", "Ä f ox"]
- >>> tokenizer = keras_nlp.models.OPTTokenizer(
- ... vocabulary=vocab,
- ... merges=merges,
- ... )
- >>> tokenizer("a quick fox")
-
-
- Detokenization.
- >>> vocab = {"": 1, "": 2, "Ä quick": 4, "Ä fox": 5}
- >>> merges = ["Ä q", "u i", "c k", "ui ck", "Ä q uick"]
- >>> merges += ["Ä f", "o x", "Ä f ox"]
- >>> tokenizer = keras_nlp.models.OPTTokenizer(
- ... vocabulary=vocab,
- ... merges=merges,
- ... )
- >>> tokenizer.detokenize(tokenizer(" quick fox")).numpy().decode('utf-8')
- ' quick fox'
+ ```python
+ # Unbatched input.
+ tokenizer = keras_nlp.models.OPTTokenizer.from_preset(
+ "opt_125m_en",
+ )
+ tokenizer("The quick brown fox jumped.")
+
+ # Batched input.
+ tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+ # Detokenization.
+ tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+ # Custom vocabulary.
+ vocab = {"": 1, "": 2, "Ä quick": 4, "Ä fox": 5}
+ merges = ["Ä q", "u i", "c k", "ui ck", "Ä q uick"]
+ merges += ["Ä f", "o x", "Ä f ox"]
+ tokenizer = keras_nlp.models.OPTTokenizer(vocabulary=vocab, merges=merges)
+ tokenizer("The quick brown fox jumped.")
+ ```
"""
def __init__(
diff --git a/keras_nlp/models/opt/opt_tokenizer_test.py b/keras_nlp/models/opt/opt_tokenizer_test.py
index 1b3a34a559..7fa37385f5 100644
--- a/keras_nlp/models/opt/opt_tokenizer_test.py
+++ b/keras_nlp/models/opt/opt_tokenizer_test.py
@@ -58,7 +58,7 @@ def test_tokenize_special_tokens(self):
self.assertAllEqual(output, [1, 2, 3, 4, 2, 5, 1, 0])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[2, 3, 4, 2, 5], [6, 7, 8, 9, 10]])
@@ -83,6 +83,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py
index df40dcb753..f24b9203a1 100644
--- a/keras_nlp/models/preprocessor.py
+++ b/keras_nlp/models/preprocessor.py
@@ -13,18 +13,28 @@
# limitations under the License.
from keras_nlp.backend import keras
+from keras_nlp.layers.preprocessing.preprocessing_layer import (
+ PreprocessingLayer,
+)
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring
@keras.saving.register_keras_serializable(package="keras_nlp")
-class Preprocessor(keras.layers.Layer):
+class Preprocessor(PreprocessingLayer):
"""Base class for model preprocessors."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._tokenizer = None
+ def __setattr__(self, name, value):
+ # Work around torch setattr for properties.
+ if name in ["tokenizer"]:
+ object.__setattr__(self, name, value)
+ else:
+ super().__setattr__(name, value)
+
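The `__setattr__` override is needed because, under the torch backend, `torch.nn.Module.__setattr__` registers module-typed values in `_modules` before a class-level property setter ever runs; forcing `object.__setattr__` restores the normal descriptor protocol. A standalone sketch of the failure mode and the bypass (hypothetical `Holder`/`child` names, assuming torch is installed):

```python
import torch


class Holder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._child = None

    @property
    def child(self):
        return self._child

    @child.setter
    def child(self, value):
        print("property setter ran")
        self._child = value


holder = Holder()
# `nn.Module.__setattr__` stores module values in `_modules` directly,
# so the property setter is silently skipped here:
holder.child = torch.nn.Linear(2, 2)

# Forcing the default protocol dispatches to the data descriptor,
# which is what the `tokenizer` override above restores:
object.__setattr__(holder, "child", torch.nn.Linear(2, 2))
```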
@property
def tokenizer(self):
"""The tokenizer used to tokenize strings."""
diff --git a/keras_nlp/models/roberta/roberta_backbone_test.py b/keras_nlp/models/roberta/roberta_backbone_test.py
index f72cc180fa..4867823be9 100644
--- a/keras_nlp/models/roberta/roberta_backbone_test.py
+++ b/keras_nlp/models/roberta/roberta_backbone_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.tests.test_case import TestCase
@@ -35,8 +36,8 @@ def setUp(self):
)
self.batch_size = 8
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -66,13 +67,13 @@ def test_serialization(self):
def test_variable_sequence_length_call_roberta(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
output = self.backbone(input_data)
self.assertAllEqual(
- tf.shape(output),
- [2, seq_length, self.backbone.hidden_dim],
+ ops.shape(output),
+ (2, seq_length, self.backbone.hidden_dim),
)
@pytest.mark.large # Saving is slow, so mark these large.
@@ -104,8 +105,8 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py
index d9461ed812..3f838c1d4c 100644
--- a/keras_nlp/models/roberta/roberta_classifier.py
+++ b/keras_nlp/models/roberta/roberta_classifier.py
@@ -22,7 +22,6 @@
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -187,7 +186,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(2e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def get_config(self):
diff --git a/keras_nlp/models/roberta/roberta_classifier_test.py b/keras_nlp/models/roberta/roberta_classifier_test.py
index 9f32494ad9..a512df0c85 100644
--- a/keras_nlp/models/roberta/roberta_classifier_test.py
+++ b/keras_nlp/models/roberta/roberta_classifier_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
@@ -71,15 +72,13 @@ def setUp(self):
)
# Setup data.
- self.raw_batch = tf.constant(
- [
- " airplane at airport",
- " the airplane is the best",
- ]
- )
+ self.raw_batch = [
+ " airplane at airport",
+ " the airplane is the best",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -93,7 +92,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py
index c68877871f..659e6c2e00 100644
--- a/keras_nlp/models/roberta/roberta_masked_lm.py
+++ b/keras_nlp/models/roberta/roberta_masked_lm.py
@@ -25,7 +25,6 @@
)
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.task import Task
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -139,7 +138,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py
index f827b87bd4..8bcba68264 100644
--- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor_test.py
@@ -149,6 +149,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/roberta/roberta_masked_lm_test.py b/keras_nlp/models/roberta/roberta_masked_lm_test.py
index f6afa41676..838360d7b6 100644
--- a/keras_nlp/models/roberta/roberta_masked_lm_test.py
+++ b/keras_nlp/models/roberta/roberta_masked_lm_test.py
@@ -73,12 +73,10 @@ def setUp(self):
preprocessor=None,
)
- self.raw_batch = tf.constant(
- [
- " airplane at airport",
- " the airplane is the best",
- ]
- )
+ self.raw_batch = [
+ " airplane at airport",
+ " the airplane is the best",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/roberta/roberta_preprocessor_test.py b/keras_nlp/models/roberta/roberta_preprocessor_test.py
index 785a090769..2df4519561 100644
--- a/keras_nlp/models/roberta/roberta_preprocessor_test.py
+++ b/keras_nlp/models/roberta/roberta_preprocessor_test.py
@@ -148,6 +148,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/roberta/roberta_presets_test.py b/keras_nlp/models/roberta/roberta_presets_test.py
index c5cb97d255..c3f5f63aaf 100644
--- a/keras_nlp/models/roberta/roberta_presets_test.py
+++ b/keras_nlp/models/roberta/roberta_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_masked_lm import RobertaMaskedLM
@@ -56,8 +56,8 @@ def test_preprocessor_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[0, 133, 2119, 2]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[0, 133, 2119, 2]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = RobertaBackbone.from_preset(
"roberta_base_en", load_weights=load_weights
@@ -84,8 +84,8 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = RobertaClassifier.from_preset(
"roberta_base_en",
@@ -112,9 +112,9 @@ def test_masked_lm_output(self, load_weights):
)
def test_masked_lm_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[101, 1996, 4248, 102]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
- "mask_positions": tf.constant([[0, 0]]),
+ "token_ids": ops.array([[101, 1996, 4248, 102]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
+ "mask_positions": ops.array([[0, 0]]),
}
model = RobertaMaskedLM.from_preset(
"roberta_base_en",
@@ -168,10 +168,10 @@ def test_load_roberta(self, load_weights):
preset, load_weights=load_weights
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -183,7 +183,7 @@ def test_load_roberta_classifier(self, load_weights):
classifier = RobertaClassifier.from_preset(
preset, num_classes=4, load_weights=load_weights
)
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -198,12 +198,12 @@ def test_load_roberta_classifier_without_preprocessing(self, load_weights):
load_weights=load_weights,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
@@ -215,7 +215,7 @@ def test_load_roberta_masked_lm(self, load_weights):
classifier = RobertaMaskedLM.from_preset(
preset, load_weights=load_weights
)
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -229,13 +229,13 @@ def test_load_roberta_masked_lm_without_preprocessing(self, load_weights):
load_weights=load_weights,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
- "mask_positions": tf.constant([1] * 128, shape=(1, 128)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
+ "mask_positions": ops.array([1] * 128, shape=(1, 128)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py
index 2df2c07658..9f5f390f51 100644
--- a/keras_nlp/models/roberta/roberta_tokenizer_test.py
+++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py
@@ -59,7 +59,7 @@ def test_tokenize_special_tokens(self):
self.assertAllEqual(output, [0, 3, 4, 5, 3, 6, 0, 1])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]])
@@ -84,6 +84,7 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py
index f3b1ca2edd..b23e0becd4 100644
--- a/keras_nlp/models/t5/t5_backbone.py
+++ b/keras_nlp/models/t5/t5_backbone.py
@@ -14,14 +14,13 @@
"""T5 backbone model."""
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.tensor_utils import assert_tf_backend
@keras_nlp_export("keras_nlp.models.T5Backbone")
@@ -79,6 +78,8 @@ def __init__(
layer_norm_epsilon=1e-06,
**kwargs,
):
+ assert_tf_backend(self.__class__.__name__)
+
# Encoder inputs
encoder_token_ids = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
@@ -112,8 +113,7 @@ def __init__(
name="encoder_embedding_dropout",
)(token_embedding)
- # Encoder attention mask is just our padding mask.
- encoder_attention_mask = encoder_padding_mask[:, tf.newaxis, :]
+ encoder_attention_mask = encoder_padding_mask[:, None, :]
position_bias = None
for i in range(num_layers):
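The `[:, None, :]` indexing is a backend-neutral replacement for `tf.newaxis`: it inserts a query axis so the per-token padding mask of shape `(batch, key_len)` broadcasts against attention scores of shape `(batch, query_len, key_len)`. A quick NumPy shape check:

```python
import numpy as np

padding_mask = np.array([[1, 1, 1, 0]])  # (batch=1, key_len=4)

# Insert a query axis so the mask broadcasts over attention scores.
attention_mask = padding_mask[:, None, :]
print(attention_mask.shape)  # (1, 1, 4)

scores = np.random.rand(1, 4, 4)  # (batch, query_len, key_len)
masked = np.where(attention_mask, scores, -1e9)
print(masked.shape)  # (1, 4, 4)
```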
diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py
index be71ba8bac..8666168e8b 100644
--- a/keras_nlp/models/t5/t5_backbone_test.py
+++ b/keras_nlp/models/t5/t5_backbone_test.py
@@ -19,10 +19,12 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.t5.t5_backbone import T5Backbone
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class T5Test(TestCase):
def setUp(self):
self.backbone = T5Backbone(
@@ -35,16 +37,16 @@ def setUp(self):
self.batch_size = 2
seq_length = 3
self.input_batch = {
- "encoder_token_ids": tf.ones(
+ "encoder_token_ids": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "encoder_padding_mask": tf.ones(
+ "encoder_padding_mask": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "decoder_token_ids": tf.ones(
+ "decoder_token_ids": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "decoder_padding_mask": tf.ones(
+ "decoder_padding_mask": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
}
@@ -68,16 +70,16 @@ def test_name(self):
def test_variable_sequence_length_call_t5(self):
for seq_length in (2, 3, 4):
input_data = {
- "encoder_token_ids": tf.ones(
+ "encoder_token_ids": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "encoder_padding_mask": tf.ones(
+ "encoder_padding_mask": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "decoder_token_ids": tf.ones(
+ "decoder_token_ids": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
- "decoder_padding_mask": tf.ones(
+ "decoder_padding_mask": ops.ones(
(self.batch_size, seq_length), dtype="int32"
),
}
@@ -124,8 +126,8 @@ def setUp(self):
intermediate_dim=4,
)
self.input_batch = {
- "token_ids": tf.ones((8, 4), dtype="int32"),
- "padding_mask": tf.ones((8, 4), dtype="int32"),
+ "token_ids": ops.ones((8, 4), dtype="int32"),
+ "padding_mask": ops.ones((8, 4), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/t5/t5_layer_norm.py b/keras_nlp/models/t5/t5_layer_norm.py
index 8eec5c8bf9..dab361e279 100644
--- a/keras_nlp/models/t5/t5_layer_norm.py
+++ b/keras_nlp/models/t5/t5_layer_norm.py
@@ -27,6 +27,7 @@ def build(self, input_shape):
shape=(input_shape[-1],),
initializer="ones",
)
+ self.built = True
def call(self, hidden_states):
variance = tf.math.reduce_mean(
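For context on the `call` method above: T5's layer norm is an RMS-style norm that scales by the root mean square over the feature axis and, unlike standard layer norm, subtracts no mean and adds no bias. A minimal NumPy sketch of that computation:

```python
import numpy as np

def t5_layer_norm(hidden_states, weight, epsilon=1e-6):
    # Mean of squares over the feature axis; no mean subtraction.
    variance = np.mean(np.square(hidden_states), axis=-1, keepdims=True)
    return weight * hidden_states / np.sqrt(variance + epsilon)

x = np.random.randn(2, 3, 4).astype("float32")
weight = np.ones((4,), dtype="float32")  # the "ones"-initialized scale above
print(t5_layer_norm(x, weight).shape)  # (2, 3, 4)
```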
diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py
index 3288d4f93e..298cf67b18 100644
--- a/keras_nlp/models/t5/t5_multi_head_attention.py
+++ b/keras_nlp/models/t5/t5_multi_head_attention.py
@@ -82,7 +82,6 @@ def __init__(
)
self.dropout_layer = keras.layers.Dropout(dropout)
- def build(self, input_shape):
if self.use_relative_attention_bias:
self.relative_attention_bias = self.add_weight(
name="embeddings",
diff --git a/keras_nlp/models/t5/t5_tokenizer_test.py b/keras_nlp/models/t5/t5_tokenizer_test.py
index ca9a7cba97..a06435fa8c 100644
--- a/keras_nlp/models/t5/t5_tokenizer_test.py
+++ b/keras_nlp/models/t5/t5_tokenizer_test.py
@@ -55,14 +55,14 @@ def test_tokenize(self):
self.assertAllEqual(output, [4, 9, 5, 7])
def test_tokenize_batch(self):
- input_data = tf.constant(["the quick brown fox", "the earth is round"])
+ input_data = ["the quick brown fox", "the earth is round"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]])
def test_detokenize(self):
- input_data = tf.constant([[4, 9, 5, 7]])
+ input_data = [[4, 9, 5, 7]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["the quick brown fox"]))
+ self.assertEqual(output, ["the quick brown fox"])
def test_vocabulary_size(self):
tokenizer = T5Tokenizer(proto=self.proto)
@@ -90,6 +90,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py
index 67d8547134..6de3aee0ec 100644
--- a/keras_nlp/models/task.py
+++ b/keras_nlp/models/task.py
@@ -15,11 +15,12 @@
import os
+import keras_core
+import rich
import tensorflow as tf
from keras_nlp.backend import keras
from keras_nlp.utils.keras_utils import print_msg
-from keras_nlp.utils.keras_utils import print_row
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring
@@ -34,7 +35,7 @@ def __init__(self, *args, **kwargs):
self._preprocessor = None
super().__init__(*args, **kwargs)
- def _check_for_loss_mismatch(self):
+ def _check_for_loss_mismatch(self, loss):
"""Check for a softmax/from_logits mismatch after compile.
We cannot handle this in the general case, but we can handle this for
@@ -42,13 +43,13 @@ def _check_for_loss_mismatch(self):
loss, and a `None` or `"softmax"` activation.
"""
# Only handle a single loss.
- if tf.nest.is_nested(self.loss):
+ if tf.nest.is_nested(loss):
return
# Only handle tasks with activation.
if not hasattr(self, "activation"):
return
- loss = keras.losses.get(self.loss)
+ loss = keras.losses.get(loss)
activation = keras.activations.get(self.activation)
if isinstance(loss, keras.losses.SparseCategoricalCrossentropy):
from_logits = loss.get_config()["from_logits"]
@@ -77,13 +78,20 @@ def _check_for_loss_mismatch(self):
"`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)`. "
)
- def compile(self, *args, **kwargs):
- super().compile(*args, **kwargs)
- self._check_for_loss_mismatch()
+ def compile(self, optimizer="rmsprop", loss=None, **kwargs):
+ self._check_for_loss_mismatch(loss)
+ super().compile(optimizer=optimizer, loss=loss, **kwargs)
def preprocess_samples(self, x, y=None, sample_weight=None):
return self.preprocessor(x, y=y, sample_weight=sample_weight)
+ def __setattr__(self, name, value):
+ # Work around torch setattr for properties.
+ if name in ["backbone", "preprocessor"]:
+ object.__setattr__(self, name, value)
+ else:
+ super().__setattr__(name, value)
+
@property
def backbone(self):
"""A `keras.Model` instance providing the backbone submodel."""
@@ -110,7 +118,6 @@ def get_config(self):
"backbone": keras.layers.serialize(self.backbone),
"preprocessor": keras.layers.serialize(self.preprocessor),
"name": self.name,
- "trainable": self.trainable,
}
@classmethod
@@ -240,31 +247,67 @@ def summary(
**kwargs,
):
"""Override `model.summary()` to show a preprocessor if set."""
- # Defaults are copied from core Keras; we should try to stay in sync.
- line_length = line_length or 98
- positions = positions or [0.33, 0.55, 0.67, 1.0]
- if positions[-1] <= 1:
- positions = [int(line_length * p) for p in positions]
- if print_fn is None:
+ # Below is copied from keras-core for now.
+ # We should consider an API contract.
+ line_length = line_length or 108
+
+ if not print_fn and not keras.utils.is_interactive_logging_enabled():
print_fn = print_msg
+ def highlight_number(x):
+ return f"[color(45)]{x}[/]" if x is None else f"[color(34)]{x}[/]"
+
+ def highlight_symbol(x):
+ return f"[color(33)]{x}[/]"
+
+ def bold_text(x):
+ return f"[bold]{x}[/]"
+
if self.preprocessor:
- column_names = ["Tokenizer (type)", "Vocab #"]
+ # Create a rich console for printing. Capture for non-interactive logging.
+ if print_fn:
+ console = rich.console.Console(
+ highlight=False, force_terminal=False, color_system=None
+ )
+ console.begin_capture()
+ else:
+ console = rich.console.Console(highlight=False)
+
+ column_1 = rich.table.Column(
+ "Tokenizer (type)",
+ justify="left",
+ width=int(0.5 * line_length),
+ )
+ column_2 = rich.table.Column(
+ "Vocab #",
+ justify="right",
+ width=int(0.5 * line_length),
+ )
+ table = rich.table.Table(
+ column_1, column_2, width=line_length, show_lines=True
+ )
tokenizer = self.preprocessor.tokenizer
- column_values = [
- f"{tokenizer.name} ({tokenizer.__class__.__name__})",
- f"{tokenizer.vocabulary_size()}",
- ]
-
- print_fn(f'Preprocessor: "{self.preprocessor.name}"')
- print_fn("_" * line_length)
- print_row(column_names, positions[1:3], print_fn)
- print_fn("=" * line_length)
- print_row(column_values, positions[1:3], print_fn)
- print_fn("_" * line_length)
- print_fn(" " * line_length)
-
- super().summary(
+ tokenizer_name = rich.markup.escape(tokenizer.name)
+ tokenizer_class = highlight_symbol(
+ rich.markup.escape(tokenizer.__class__.__name__)
+ )
+ table.add_row(
+ f"{tokenizer_name} ({tokenizer_class})",
+ highlight_number(f"{tokenizer.vocabulary_size():,}"),
+ )
+
+ # Print the summary to the console.
+ preprocessor_name = rich.markup.escape(self.preprocessor.name)
+ console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
+ console.print(table)
+
+ # Output captured summary for non-interactive logging.
+ if print_fn:
+ print_fn(console.end_capture(), line_break=False)
+
+ # Hardcode summary from keras_core for now.
+ keras_core.Model.summary(
+ self,
line_length=line_length,
positions=positions,
print_fn=print_fn,
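The summary override prints through a `rich` console and, when a `print_fn` is supplied, captures the rendered table as plain text instead of writing to the terminal. A small standalone sketch of that capture pattern (the row contents are hypothetical):

```python
import rich.console
import rich.table

# Plain-text console whose output is captured rather than printed.
console = rich.console.Console(
    highlight=False, force_terminal=False, color_system=None
)
console.begin_capture()

table = rich.table.Table("Tokenizer (type)", "Vocab #")
table.add_row("opt_tokenizer (OPTTokenizer)", "50,000")  # hypothetical row
console.print(table)

captured = console.end_capture()
print(captured)  # a plain-text table, ready for any print_fn or logger
```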
diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py
index 506780c159..4155eb2e63 100644
--- a/keras_nlp/models/task_test.py
+++ b/keras_nlp/models/task_test.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from tensorflow.keras.losses import SparseCategoricalCrossentropy
-
from keras_nlp.backend import keras
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.task import Task
@@ -45,36 +43,52 @@ def test_summary_with_preprocessor(self):
preprocessor = SimplePreprocessor()
model = SimpleTask(preprocessor)
summary = []
- model.summary(print_fn=lambda x: summary.append(x))
+ model.summary(print_fn=lambda x, line_break: summary.append(x))
self.assertRegex("\n".join(summary), "Preprocessor:")
def test_summary_without_preprocessor(self):
model = SimpleTask()
summary = []
- model.summary(print_fn=lambda x: summary.append(x))
+ model.summary(print_fn=lambda x, line_break: summary.append(x))
self.assertNotRegex("\n".join(summary), "Preprocessor:")
def test_mismatched_loss(self):
# Logit output.
model = SimpleTask(activation=None)
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ )
# Non-standard losses should not throw.
model.compile(loss="mean_squared_error")
with self.assertRaises(ValueError):
model.compile(loss="sparse_categorical_crossentropy")
with self.assertRaises(ValueError):
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(
+ from_logits=False
+ )
+ )
# Probability output.
model = SimpleTask(activation="softmax")
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+ )
model.compile(loss="sparse_categorical_crossentropy")
# Non-standard losses should not throw.
model.compile(loss="mean_squared_error")
with self.assertRaises(ValueError):
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True
+ )
+ )
# Non-standard activations should not throw.
model = SimpleTask(activation="tanh")
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
- model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ )
+ model.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+ )
diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py
index 7bcfc80790..a8c8758f66 100644
--- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py
+++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py
@@ -82,6 +82,10 @@ def __init__(
super().__init__(**kwargs)
+ self._convert_input_args = False
+ self._allow_non_tensor_positional_args = True
+ self.built = True
+
self.num_mels = num_mels
self.num_fft_bins = num_fft_bins
self.stride = stride
diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py
index 8961c93e1b..ddf4b331fd 100644
--- a/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py
+++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor_test.py
@@ -79,6 +79,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
audio_tensor = tf.ones((2,), dtype="float32")
diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py
index fd85cc04c9..1f68853ba1 100644
--- a/keras_nlp/models/whisper/whisper_backbone.py
+++ b/keras_nlp/models/whisper/whisper_backbone.py
@@ -14,10 +14,9 @@
"""Whisper backbone model."""
-import tensorflow as tf
-
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.token_and_position_embedding import (
TokenAndPositionEmbedding,
@@ -25,12 +24,18 @@
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder
from keras_nlp.models.whisper.whisper_encoder import WhisperEncoder
+from keras_nlp.utils.tensor_utils import assert_tf_backend
def whisper_kernel_initializer(stddev=0.02):
return keras.initializers.TruncatedNormal(stddev=stddev)
+class Padder(keras.layers.Layer):
+ def call(self, x):
+ return ops.pad(x, [[0, 0], [1, 1], [0, 0]])
+
+
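The `Padder` layer exists because the second conv layer uses stride 2, where `padding="same"` would imply a fractional padding of 1.5 for kernel size 3; padding one frame per side by hand gives an exact halving. A worked check of the length arithmetic, using Whisper-like numbers as an assumption:

```python
import math

def conv1d_out_len(length, kernel_size, stride, pad):
    # Output length of a "valid" conv after explicit symmetric padding.
    return math.floor((length + 2 * pad - kernel_size) / stride) + 1

# Whisper-like numbers (assumed): 3000 frames, kernel 3, stride 2.
# "same" padding would need 1.5 frames per side; explicit padding of
# 1 frame per side halves the sequence exactly.
print(conv1d_out_len(3000, kernel_size=3, stride=2, pad=1))  # 1500
```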
@keras_nlp_export("keras_nlp.models.WhisperBackbone")
class WhisperBackbone(Backbone):
"""A Whisper encoder-decoder network for speech.
@@ -107,6 +112,8 @@ def __init__(
max_decoder_sequence_length=448,
**kwargs,
):
+ assert_tf_backend(self.__class__.__name__)
+
# Encoder inputs. Note that the encoder does not have a padding mask:
# https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132.
encoder_feature_input = keras.Input(
@@ -142,9 +149,7 @@ def __init__(
# For the second conv. layer, we cannot use `padding="same"` since
# that corresponds to a padding size of 1.5 (since stride is 2). Hence,
# we will manually pad the input.
- embedded_features = tf.pad(
- embedded_features, paddings=[[0, 0], [1, 1], [0, 0]]
- )
+ embedded_features = Padder()(embedded_features)
encoder_conv_layer_2 = keras.layers.Conv1D(
filters=hidden_dim,
kernel_size=3,
diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py
index 587855762f..60a2c89996 100644
--- a/keras_nlp/models/whisper/whisper_backbone_test.py
+++ b/keras_nlp/models/whisper/whisper_backbone_test.py
@@ -19,10 +19,12 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class WhisperBackboneTest(TestCase):
def setUp(self):
self.backbone = WhisperBackbone(
@@ -35,9 +37,9 @@ def setUp(self):
max_decoder_sequence_length=6,
)
self.input_batch = {
- "encoder_features": tf.ones((2, 5, 80), dtype="float32"),
- "decoder_token_ids": tf.ones((2, 5), dtype="int32"),
- "decoder_padding_mask": tf.ones((2, 5), dtype="int32"),
+ "encoder_features": ops.ones((2, 5, 80), dtype="float32"),
+ "decoder_token_ids": ops.ones((2, 5), dtype="int32"),
+ "decoder_padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
@@ -60,11 +62,13 @@ def test_name(self):
def test_variable_sequence_length_call_whisper(self):
for seq_length in (2, 3, 4):
input_data = {
- "encoder_features": tf.ones(
+ "encoder_features": ops.ones(
(2, seq_length, 80), dtype="float32"
),
- "decoder_token_ids": tf.ones((2, seq_length), dtype="int32"),
- "decoder_padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "decoder_token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "decoder_padding_mask": ops.ones(
+ (2, seq_length), dtype="int32"
+ ),
}
self.backbone(input_data)
@@ -134,7 +138,7 @@ def setUp(self):
)
self.input_batch = {
- "encoder_features": tf.ones(
+ "encoder_features": ops.ones(
(
8,
self.backbone.max_encoder_sequence_length,
@@ -142,10 +146,10 @@ def setUp(self):
),
dtype="int32",
),
- "decoder_token_ids": tf.ones(
+ "decoder_token_ids": ops.ones(
(8, self.backbone.max_decoder_sequence_length), dtype="int32"
),
- "decoder_padding_mask": tf.ones(
+ "decoder_padding_mask": ops.ones(
(8, self.backbone.max_decoder_sequence_length), dtype="int32"
),
}
diff --git a/keras_nlp/models/whisper/whisper_encoder.py b/keras_nlp/models/whisper/whisper_encoder.py
index 7127121498..4527e3fa51 100644
--- a/keras_nlp/models/whisper/whisper_encoder.py
+++ b/keras_nlp/models/whisper/whisper_encoder.py
@@ -22,7 +22,7 @@ class WhisperEncoder(TransformerEncoder):
"""A Whisper encoder.
Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the
- `_build` method so as to remove the bias term from the key projection layer.
+ `build` method so as to remove the bias term from the key projection layer.
"""
def build(self, inputs_shape):
diff --git a/keras_nlp/models/whisper/whisper_tokenizer_test.py b/keras_nlp/models/whisper/whisper_tokenizer_test.py
index 6fcb28123c..d94b7ad79a 100644
--- a/keras_nlp/models/whisper/whisper_tokenizer_test.py
+++ b/keras_nlp/models/whisper/whisper_tokenizer_test.py
@@ -23,6 +23,7 @@
from keras_nlp.tests.test_case import TestCase
+@pytest.mark.tf_only
class WhisperTokenizerTest(TestCase):
def setUp(self):
self.vocab = {
@@ -77,7 +78,7 @@ def test_tokenize_special_tokens(self):
self.assertAllEqual(output, [9, 14, 12, 11, 0, 1, 2, 0, 3, 10])
def test_tokenize_batch(self):
- input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ input_data = [" airplane at airport", " kohli is the best"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[0, 1, 2, 0, 3], [4, 5, 6, 7, 8]])
@@ -113,6 +114,7 @@ def test_errors_missing_special_tokens(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" airplane at airport"])
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py
index 76d9a38f37..00387b8d46 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone_test.py
@@ -19,6 +19,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
from keras_nlp.tests.test_case import TestCase
@@ -34,8 +35,8 @@ def setUp(self):
max_sequence_length=5,
)
self.input_batch = {
- "token_ids": tf.ones((2, 5), dtype="int32"),
- "padding_mask": tf.ones((2, 5), dtype="int32"),
+ "token_ids": ops.ones((2, 5), dtype="int32"),
+ "padding_mask": ops.ones((2, 5), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
@@ -55,12 +56,13 @@ def test_name(self):
def test_variable_sequence_length_call_xlm_roberta(self):
for seq_length in (2, 3, 4):
input_data = {
- "token_ids": tf.ones((2, seq_length), dtype="int32"),
- "padding_mask": tf.ones((2, seq_length), dtype="int32"),
+ "token_ids": ops.ones((2, seq_length), dtype="int32"),
+ "padding_mask": ops.ones((2, seq_length), dtype="int32"),
}
output = self.backbone(input_data)
self.assertAllEqual(
- tf.shape(output), [2, seq_length, self.backbone.hidden_dim]
+ ops.shape(output),
+ (2, seq_length, self.backbone.hidden_dim),
)
def test_predict(self):
@@ -102,8 +104,8 @@ def setUp(self):
max_sequence_length=128,
)
self.input_batch = {
- "token_ids": tf.ones((8, 128), dtype="int32"),
- "padding_mask": tf.ones((8, 128), dtype="int32"),
+ "token_ids": ops.ones((8, 128), dtype="int32"),
+ "padding_mask": ops.ones((8, 128), dtype="int32"),
}
self.input_dataset = tf.data.Dataset.from_tensor_slices(
self.input_batch
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py
index 9365ac4b37..344d19b8a6 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py
@@ -24,7 +24,6 @@
XLMRobertaPreprocessor,
)
from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -199,7 +198,7 @@ def __init__(
),
optimizer=keras.optimizers.Adam(5e-5),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
def preprocess_samples(self, x, y=None, sample_weight=None):
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py
index 72d7b98927..916b0caf57 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier_test.py
@@ -21,6 +21,7 @@
import tensorflow as tf
from keras_nlp.backend import keras
+from keras_nlp.backend import ops
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import (
XLMRobertaClassifier,
@@ -70,15 +71,13 @@ def setUp(self):
hidden_dim=4,
)
- self.raw_batch = tf.constant(
- [
- "the quick brown fox.",
- "the slow brown fox.",
- ]
- )
+ self.raw_batch = [
+ "the quick brown fox.",
+ "the slow brown fox.",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
- (self.raw_batch, tf.ones((2,)))
+ (self.raw_batch, ops.ones((2,)))
).batch(2)
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
@@ -92,7 +91,7 @@ def test_classifier_predict(self):
# Assert predictions match.
self.assertAllClose(preds1, preds2)
# Assert valid softmax output.
- self.assertAllClose(tf.reduce_sum(preds2, axis=-1), [1.0, 1.0])
+ self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0])
def test_classifier_fit(self):
self.classifier.fit(self.raw_dataset)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py
index 61b747c9d5..eeeb051366 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py
@@ -25,7 +25,6 @@
XLMRobertaMaskedLMPreprocessor,
)
from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets
-from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
@@ -142,7 +141,7 @@ def __init__(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
- jit_compile=is_xla_compatible(self),
+ jit_compile=True,
)
@classproperty
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py
index 06afc8621b..f8962b94db 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor_test.py
@@ -151,17 +151,18 @@ def test_serialization(self):
)
@pytest.mark.large
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant([" quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
- outputs = self.preprocessor(inputs)
+ outputs, y, sw = self.preprocessor(inputs)
model = keras.Model(inputs, outputs)
path = os.path.join(self.get_temp_dir(), "model.keras")
model.save(path, save_format="keras_v3")
restored_model = keras.models.load_model(path)
- outputs = model(input_data)[0]["token_ids"]
- restored_outputs = restored_model(input_data)[0]["token_ids"]
+ outputs = model(input_data)["token_ids"]
+ restored_outputs = restored_model(input_data)["token_ids"]
self.assertAllEqual(outputs, restored_outputs)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py
index a46985c8a7..b0129f6707 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_test.py
@@ -77,9 +77,10 @@ def setUp(self):
preprocessor=self.preprocessor,
)
- self.raw_batch = tf.constant(
- ["the quick brown fox", "the slow brown fox"]
- )
+ self.raw_batch = [
+ "the quick brown fox",
+ "the slow brown fox",
+ ]
self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
self.raw_dataset = tf.data.Dataset.from_tensor_slices(
self.raw_batch
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py
index 68f145d883..19049fdb56 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor_test.py
@@ -141,6 +141,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
inputs = keras.Input(dtype="string", shape=())
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py
index e1d74fbaea..1d914bbd60 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_presets_test.py
@@ -14,9 +14,9 @@
"""Tests for loading pretrained model presets."""
import pytest
-import tensorflow as tf
from absl.testing import parameterized
+from keras_nlp.backend import ops
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import (
XLMRobertaClassifier,
@@ -61,8 +61,8 @@ def test_preprocessor_output(self):
)
def test_backbone_output(self, load_weights):
input_data = {
- "token_ids": tf.constant([[0, 581, 63773, 2]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[0, 581, 63773, 2]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = XLMRobertaBackbone.from_preset(
"xlm_roberta_base_multi", load_weights=load_weights
@@ -77,7 +77,7 @@ def test_backbone_output(self, load_weights):
("preset_weights", True), ("random_weights", False)
)
def test_classifier_output(self, load_weights):
- input_data = tf.constant(["The quick brown fox."])
+ input_data = ["The quick brown fox."]
model = XLMRobertaClassifier.from_preset(
"xlm_roberta_base_multi", num_classes=2, load_weights=load_weights
)
@@ -89,8 +89,8 @@ def test_classifier_output(self, load_weights):
)
def test_classifier_output_without_preprocessing(self, load_weights):
input_data = {
- "token_ids": tf.constant([[0, 581, 63773, 2]]),
- "padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "token_ids": ops.array([[0, 581, 63773, 2]]),
+ "padding_mask": ops.array([[1, 1, 1, 1]]),
}
model = XLMRobertaClassifier.from_preset(
"xlm_roberta_base_multi",
@@ -143,10 +143,10 @@ def test_load_xlm_roberta(self, load_weights):
preset, load_weights=load_weights
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512), dtype="int64", maxval=model.vocabulary_size
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
model(input_data)
@@ -160,7 +160,7 @@ def test_load_xlm_roberta_classifier(self, load_weights):
num_classes=4,
load_weights=load_weights,
)
- input_data = tf.constant(["This quick brown fox"])
+ input_data = ["The quick brown fox."]
classifier.predict(input_data)
@parameterized.named_parameters(
@@ -177,12 +177,12 @@ def test_load_xlm_roberta_classifier_without_preprocessing(
preprocessor=None,
)
input_data = {
- "token_ids": tf.random.uniform(
+ "token_ids": ops.random.uniform(
shape=(1, 512),
dtype="int64",
maxval=classifier.backbone.vocabulary_size,
),
- "padding_mask": tf.constant([1] * 512, shape=(1, 512)),
+ "padding_mask": ops.array([1] * 512, shape=(1, 512)),
}
classifier.predict(input_data)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
index f4a469615b..f81d81d75e 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
@@ -52,7 +52,7 @@ def test_tokenize(self):
self.assertAllEqual(output, [4, 9, 5, 7])
def test_tokenize_batch(self):
- input_data = tf.constant(["the quick brown fox", "the earth is round"])
+ input_data = ["the quick brown fox", "the earth is round"]
output = self.tokenizer(input_data)
self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 10]])
@@ -63,9 +63,9 @@ def test_unk_token(self):
self.assertAllEqual(output, [4, 9, 5, 7, 3])
def test_detokenize(self):
- input_data = tf.constant([[4, 9, 5, 7]])
+ input_data = [[4, 9, 5, 7]]
output = self.tokenizer.detokenize(input_data)
- self.assertEqual(output, tf.constant(["brown round earth is"]))
+ self.assertEqual(output, ["brown round earth is"])
def test_vocabulary(self):
vocabulary = self.tokenizer.get_vocabulary()
@@ -116,6 +116,7 @@ def test_serialization(self):
)
@pytest.mark.large # Saving is slow, so mark these large.
+ @pytest.mark.tf_only
def test_saved_model(self):
input_data = tf.constant(["the quick brown fox"])
diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py
index 6dcf30e828..baf13cf875 100644
--- a/keras_nlp/samplers/random_sampler.py
+++ b/keras_nlp/samplers/random_sampler.py
@@ -66,11 +66,15 @@ def __init__(
):
super().__init__(**kwargs)
self.seed = seed
+ self.seed_generator = random.SeedGenerator(seed)
def get_next_token(self, probabilities):
# Sample the next token from the probability distribution.
next_token_id = random.categorical(
- ops.log(probabilities), 1, seed=self.seed, dtype="int32"
+ ops.log(probabilities),
+ 1,
+ seed=self.seed_generator,
+ dtype="int32",
)
return ops.squeeze(next_token_id, axis=-1)
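The new `SeedGenerator` gives samplers stateful, backend-agnostic randomness: every draw advances the generator's state variable, so repeated calls yield fresh samples even inside compiled code. A minimal sketch of the intended usage, assuming the Keras 3 / keras-core `random` API that `keras_nlp.backend.random` wraps:

```python
import keras

seed_gen = keras.random.SeedGenerator(1337)
logits = keras.ops.log([[0.5, 0.5]])

# Each draw advances the generator's state variable, so passing the
# same generator twice still yields independent samples.
a = keras.random.categorical(logits, 1, seed=seed_gen)
b = keras.random.categorical(logits, 1, seed=seed_gen)
```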
diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py
index 46a31ee784..7ca4c9b10f 100644
--- a/keras_nlp/samplers/sampler.py
+++ b/keras_nlp/samplers/sampler.py
@@ -14,8 +14,10 @@
"""Base sampler class."""
from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops
+from keras_nlp.backend import random
from keras_nlp.utils.python_utils import format_docstring
call_args_docstring = """next: A function which takes in the
@@ -93,6 +95,22 @@ def __init__(
temperature=1.0,
):
self.temperature = temperature
+ self._seed_generators = []
+
+ def __setattr__(self, name, value):
+ # We could update to the `Tracker` class from keras-core if our needs
+ # become more advanced (e.g. list assignment, nested trackables). For
+ # now, we only track `SeedGenerator` instances directly on the sampler.
+ if isinstance(value, random.SeedGenerator):
+ self._seed_generators.append(value)
+ return super().__setattr__(name, value)
+
+ @property
+ def variables(self):
+ variables = []
+ for sg in self._seed_generators:
+ variables.append(sg.state)
+ return variables
def __call__(
self,
@@ -135,17 +153,53 @@ def body(prompt, cache, index):
# Update the prompt with the next token.
next_token = next_token[:, None]
prompt = ops.slice_update(prompt, [0, index], next_token)
+
# Return the next prompt, cache and incremented index.
return (prompt, cache, index + 1)
- prompt, _, _ = ops.while_loop(
- cond=cond,
- body=body,
+ prompt, _, _ = self.run_loop(
+ cond,
+ body,
loop_vars=(prompt, cache, index),
maximum_iterations=(max_length - index),
)
return prompt
+ def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None):
+ """Run ops.while_loops with a `StatelessScope` if necessary."""
+ if config.backend() == "jax":
+
+ def stateless_cond(variables, *loop_vars):
+ return cond(*loop_vars)
+
+ def stateless_body(variables, *loop_vars):
+ mapping = zip(self.variables, variables)
+ with keras.StatelessScope(state_mapping=mapping) as scope:
+ loop_vars = body(*loop_vars)
+
+ variables = []
+ for v in self.variables:
+ new_v = scope.get_current_value(v)
+ variables.append(new_v if new_v is not None else v)
+ return variables, *loop_vars
+
+ variables = [ops.convert_to_tensor(v) for v in self.variables]
+ variables, *loop_vars = ops.while_loop(
+ cond=stateless_cond,
+ body=stateless_body,
+ loop_vars=(variables, *loop_vars),
+ maximum_iterations=maximum_iterations,
+ )
+ [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)]
+ else:
+ loop_vars = ops.while_loop(
+ cond=cond,
+ body=body,
+ loop_vars=(loop_vars),
+ maximum_iterations=maximum_iterations,
+ )
+ return loop_vars
+
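The JAX branch of `run_loop` threads variable values through the loop explicitly because `jax.lax.while_loop` traces a pure function: any state mutated in the body (here, sampler RNG state) must enter and leave as a loop variable. A stripped-down sketch of the same pattern in raw JAX:

```python
import jax
import jax.numpy as jnp

def cond(state):
    _, count = state
    return count < 5

def body(state):
    key, count = state
    key, subkey = jax.random.split(key)  # advance RNG state functionally
    _ = jax.random.uniform(subkey, ())
    return key, count + 1

state = (jax.random.PRNGKey(0), jnp.int32(0))
key, count = jax.lax.while_loop(cond, body, state)
print(count)  # 5
```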
def get_next_token(self, probabilities):
"""Get the next token.
Args:
diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py
index 355d216c6c..39d50d2ad4 100644
--- a/keras_nlp/samplers/top_k_sampler.py
+++ b/keras_nlp/samplers/top_k_sampler.py
@@ -69,6 +69,7 @@ def __init__(
super().__init__(**kwargs)
self.k = k
self.seed = seed
+ self.seed_generator = random.SeedGenerator(seed)
def get_next_token(self, probabilities):
# Filter out top-k tokens.
@@ -81,7 +82,7 @@ def get_next_token(self, probabilities):
sample_indices = random.categorical(
ops.log(top_k_pred),
1,
- seed=self.seed,
+ seed=self.seed_generator,
dtype="int32",
)
diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py
index d1c5350c49..8ff3147b72 100644
--- a/keras_nlp/samplers/top_p_sampler.py
+++ b/keras_nlp/samplers/top_p_sampler.py
@@ -79,6 +79,7 @@ def __init__(
self.p = p
self.k = k
self.seed = seed
+ self.seed_generator = random.SeedGenerator(seed)
def get_next_token(self, probabilities):
cutoff = ops.shape(probabilities)[1]
@@ -105,7 +106,7 @@ def get_next_token(self, probabilities):
sorted_next_token = random.categorical(
ops.log(probabilities),
1,
- seed=self.seed,
+ seed=self.seed_generator,
dtype="int32",
)
output = ops.take_along_axis(sorted_indices, sorted_next_token, axis=-1)
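`take_along_axis` maps a token sampled in sorted order back to its original vocabulary id. A compact NumPy sketch of the whole nucleus (top-p) filtering step, with hypothetical probabilities:

```python
import numpy as np

p = 0.9
probs = np.array([[0.05, 0.5, 0.1, 0.35]])

# Sort descending; keep the smallest prefix whose mass reaches p.
sorted_indices = np.argsort(-probs, axis=-1)  # [[1, 3, 2, 0]]
sorted_probs = np.take_along_axis(probs, sorted_indices, axis=-1)
keep = np.cumsum(sorted_probs, axis=-1) - sorted_probs < p
filtered = np.where(keep, sorted_probs, 0.0)
filtered /= filtered.sum(axis=-1, keepdims=True)

# Sample in sorted space, then map back to original token ids.
sorted_next = np.array([[np.random.choice(4, p=filtered[0])]])
next_token = np.take_along_axis(sorted_indices, sorted_next, axis=-1)
print(next_token)  # one of [[1]], [[3]] or [[2]]
```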
diff --git a/keras_nlp/tests/doc_tests/docstring_test.py b/keras_nlp/tests/doc_tests/docstring_test.py
index 486ecf9389..a484b9f1aa 100644
--- a/keras_nlp/tests/doc_tests/docstring_test.py
+++ b/keras_nlp/tests/doc_tests/docstring_test.py
@@ -48,6 +48,7 @@ def docstring_module(pytestconfig):
return pytestconfig.getoption("docstring_module")
+@pytest.mark.tf_only
def test_docstrings(docstring_module):
keras_nlp_modules = find_modules()
# As of this writing, it doesn't seem like pytest support load_tests
@@ -86,6 +87,7 @@ def test_docstrings(docstring_module):
assert result.wasSuccessful()
+@pytest.mark.tf_only
@pytest.mark.extra_large
@pytest.mark.skipif(
astor is None,
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index 60ff401bbe..1f759f8945 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -277,7 +277,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
Custom splitting.
>>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
- >>> inputs = ["The$quick$brown$fox"]
+ >>> inputs = "The$quick$brown$fox"
>>> tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
... vocabulary=vocab,
... split=False,
@@ -285,8 +285,9 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
... dtype='string',
... )
>>> split_inputs = tf.strings.split(inputs, sep="$")
- >>> tokenizer(split_inputs)
-
+ >>> outputs = tokenizer(split_inputs)
+ >>> np.array(outputs).astype("U")
+ array(['the', 'qu', '##ick', 'br', '##own', 'fox'], dtype='<U5')