From 7dfae95ed0ac60af7bba155f0ed4283b7267cb7d Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Mon, 16 Sep 2024 20:24:10 -0700 Subject: [PATCH] Fix saved classifier models from before 0.14 We switched the class name for `XXClassifier` models to `XXTextClasssifier`. However, a saved classifier before 0.16 will still be looking for the class under the old name. This updates our export helper to also registered the old name, so we can restore to the new class when loading the model. I also try to improve our error messages when we do encounter an unrecognized class. --- keras_nlp/src/api_export.py | 13 +++++++++++-- .../src/models/bert/bert_text_classifier_test.py | 9 +++++++++ keras_nlp/src/utils/preset_utils.py | 11 ++++++++++- keras_nlp/src/utils/preset_utils_test.py | 11 +++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/keras_nlp/src/api_export.py b/keras_nlp/src/api_export.py index cfa3519ce9..050f3ee964 100644 --- a/keras_nlp/src/api_export.py +++ b/keras_nlp/src/api_export.py @@ -22,7 +22,16 @@ namex = None -def maybe_register_serializable(symbol): +def maybe_register_serializable(path, symbol): + # If we have multiple export names, actually make sure to register these + # first. This makes sure we have a backward compat mapping of old serialized + # name to new class. + if isinstance(path, (list, tuple)): + for name in path: + name = name.split(".")[-1] + keras.saving.register_keras_serializable( + package="keras_nlp", name=name + )(symbol) if isinstance(symbol, types.FunctionType) or hasattr(symbol, "get_config"): keras.saving.register_keras_serializable(package="keras_nlp")(symbol) @@ -34,7 +43,7 @@ def __init__(self, path): super().__init__(package="keras_nlp", path=path) def __call__(self, symbol): - maybe_register_serializable(symbol) + maybe_register_serializable(self.path, symbol) return super().__call__(symbol) else: diff --git a/keras_nlp/src/models/bert/bert_text_classifier_test.py b/keras_nlp/src/models/bert/bert_text_classifier_test.py index 44b9a3d846..fa4f4cadfc 100644 --- a/keras_nlp/src/models/bert/bert_text_classifier_test.py +++ b/keras_nlp/src/models/bert/bert_text_classifier_test.py @@ -67,6 +67,15 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=BertTextClassifier, + preset="bert_tiny_en_uncased_sst2", + input_data=self.input_data, + expected_output_shape=(2, 2), + ) + @pytest.mark.extra_large def test_all_presets(self): for preset in BertTextClassifier.presets: diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 34c96de777..bb93e38e4d 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -582,7 +582,16 @@ def load_serialized_object(config, **kwargs): def check_config_class(config): """Validate a preset is being loaded on the correct class.""" - return keras.saving.get_registered_object(config["registered_name"]) + registered_name = config["registered_name"] + cls = keras.saving.get_registered_object(registered_name) + if cls is None: + raise ValueError( + f"Attempting to load class {registered_name} with " + "`from_preset()`, but there is no class registered with Keras " + f"for {registered_name}. Make sure to register any custom " + "classes with `register_keras_serializable()`." + ) + return cls def jax_memory_cleanup(layer): diff --git a/keras_nlp/src/utils/preset_utils_test.py b/keras_nlp/src/utils/preset_utils_test.py index 0e3e022d82..1c73c863b4 100644 --- a/keras_nlp/src/utils/preset_utils_test.py +++ b/keras_nlp/src/utils/preset_utils_test.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import keras import pytest from absl.testing import parameterized @@ -38,6 +40,15 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "Unknown preset identifier"): AlbertTextClassifier.from_preset("snaggle://bort/bort/bort") + backbone = BertBackbone.from_preset("bert_tiny_en_uncased") + preset_dir = self.get_temp_dir() + config = keras.utils.serialize_keras_object(backbone) + config["registered_name"] = "keras_nlp>BortBackbone" + with open(os.path.join(preset_dir, CONFIG_FILE), "w") as config_file: + config_file.write(json.dumps(config, indent=4)) + with self.assertRaisesRegex(ValueError, "class keras_nlp>BortBackbone"): + BertBackbone.from_preset(preset_dir) + def test_upload_empty_preset(self): temp_dir = self.get_temp_dir() empty_preset = os.path.join(temp_dir, "empty")