diff --git a/keras_nlp/src/api_export.py b/keras_nlp/src/api_export.py index cfa3519ce9..050f3ee964 100644 --- a/keras_nlp/src/api_export.py +++ b/keras_nlp/src/api_export.py @@ -22,7 +22,16 @@ namex = None -def maybe_register_serializable(symbol): +def maybe_register_serializable(path, symbol): + # If we have multiple export names, actually make sure to register these + # first. This makes sure we have a backward compat mapping of old serialized + # name to new class. + if isinstance(path, (list, tuple)): + for name in path: + name = name.split(".")[-1] + keras.saving.register_keras_serializable( + package="keras_nlp", name=name + )(symbol) if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"): keras.saving.register_keras_serializable(package="keras_nlp")(symbol) @@ -34,7 +43,7 @@ def __init__(self, path): super().__init__(package="keras_nlp", path=path) def __call__(self, symbol): - maybe_register_serializable(symbol) + maybe_register_serializable(self.path, symbol) return super().__call__(symbol) else: diff --git a/keras_nlp/src/models/bert/bert_text_classifier_test.py b/keras_nlp/src/models/bert/bert_text_classifier_test.py index 44b9a3d846..fa4f4cadfc 100644 --- a/keras_nlp/src/models/bert/bert_text_classifier_test.py +++ b/keras_nlp/src/models/bert/bert_text_classifier_test.py @@ -67,6 +67,15 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=BertTextClassifier, + preset="bert_tiny_en_uncased_sst2", + input_data=self.input_data, + expected_output_shape=(2, 2), + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in BertTextClassifier.presets: diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 34c96de777..bb93e38e4d 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -582,7 +582,16 @@ def load_serialized_object(config, **kwargs): def check_config_class(config): """Validate a preset is being loaded on the correct class.""" - return keras.saving.get_registered_object(config["registered_name"]) + registered_name = config["registered_name"] + cls = keras.saving.get_registered_object(registered_name) + if cls is None: + raise ValueError( + f"Attempting to load class {registered_name} with " + "`from_preset()`, but there is no class registered with Keras " + f"for {registered_name}. Make sure to register any custom " + "classes with `register_keras_serializable()`." + ) + return cls def jax_memory_cleanup(layer): diff --git a/keras_nlp/src/utils/preset_utils_test.py b/keras_nlp/src/utils/preset_utils_test.py index 0e3e022d82..1c73c863b4 100644 --- a/keras_nlp/src/utils/preset_utils_test.py +++ b/keras_nlp/src/utils/preset_utils_test.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import keras import pytest from absl.testing import parameterized @@ -38,6 +40,15 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "Unknown preset identifier"): AlbertTextClassifier.from_preset("snaggle://bort/bort/bort") + backbone = BertBackbone.from_preset("bert_tiny_en_uncased") + preset_dir = self.get_temp_dir() + config = keras.utils.serialize_keras_object(backbone) + config["registered_name"] = "keras_nlp>BortBackbone" + with open(os.path.join(preset_dir, CONFIG_FILE), "w") as config_file: + config_file.write(json.dumps(config, indent=4)) + with self.assertRaisesRegex(ValueError, "class keras_nlp>BortBackbone"): + BertBackbone.from_preset(preset_dir) + def test_upload_empty_preset(self): temp_dir = self.get_temp_dir() empty_preset = os.path.join(temp_dir, "empty")