Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions keras_hub/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
return config


def load_serialized_object(config, **kwargs):
    """Deserialize a Keras object from `config`, applying `kwargs` overrides.

    The serialized config's `dtype` entry may itself be a serialized
    `DTypePolicy` or `DTypePolicyMap`, so it is normalized via
    `set_dtype_in_config` before deserialization.
    """
    requested_dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, requested_dtype)

    # Merge any remaining keyword overrides into the inner config.
    merged = dict(config["config"])
    merged.update(kwargs)
    config["config"] = merged
    return keras.saving.deserialize_keras_object(config)


def check_config_class(config):
"""Validate a preset is being loaded on the correct class."""
registered_name = config["registered_name"]
Expand Down Expand Up @@ -631,26 +621,26 @@ def check_backbone_class(self):
return check_config_class(self.config)

def load_backbone(self, cls, load_weights, **kwargs):
backbone = load_serialized_object(self.config, **kwargs)
backbone = self._load_serialized_object(self.config, **kwargs)
if load_weights:
jax_memory_cleanup(backbone)
backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
return backbone

def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
tokenizer_config = load_json(self.preset, config_file)
tokenizer = load_serialized_object(tokenizer_config, **kwargs)
tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
if hasattr(tokenizer, "load_preset_assets"):
tokenizer.load_preset_assets(self.preset)
return tokenizer

def load_audio_converter(self, cls, **kwargs):
converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
return load_serialized_object(converter_config, **kwargs)
return self._load_serialized_object(converter_config, **kwargs)

def load_image_converter(self, cls, **kwargs):
converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
return load_serialized_object(converter_config, **kwargs)
return self._load_serialized_object(converter_config, **kwargs)

def load_task(self, cls, load_weights, load_task_weights, **kwargs):
# If there is no `task.json` or it's for the wrong class delegate to the
Expand All @@ -671,7 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
backbone_config = task_config["config"]["backbone"]["config"]
backbone_config = {**backbone_config, **backbone_kwargs}
task_config["config"]["backbone"]["config"] = backbone_config
task = load_serialized_object(task_config, **kwargs)
task = self._load_serialized_object(task_config, **kwargs)
if task.preprocessor and hasattr(
task.preprocessor, "load_preset_assets"
):
Expand Down Expand Up @@ -699,11 +689,20 @@ def load_preprocessor(
if not issubclass(check_config_class(preprocessor_json), cls):
return super().load_preprocessor(cls, **kwargs)
# We found a `preprocessing.json` with a complete config for our class.
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
if hasattr(preprocessor, "load_preset_assets"):
preprocessor.load_preset_assets(self.preset)
return preprocessor

def _load_serialized_object(self, config, **kwargs):
    """Deserialize a Keras object from `config`, applying `kwargs` overrides.

    `dtype` is popped from `kwargs` and folded into the config first,
    because the serialized `dtype` may be a `DTypePolicy` or
    `DTypePolicyMap` that needs normalizing via `set_dtype_in_config`.
    """
    requested_dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, requested_dtype)

    # Fold the remaining keyword overrides into the inner config dict.
    merged = dict(config["config"])
    merged.update(kwargs)
    config["config"] = merged
    return keras.saving.deserialize_keras_object(config)


class KerasPresetSaver:
def __init__(self, preset_dir):
Expand Down Expand Up @@ -787,6 +786,8 @@ def _save_metadata(self, layer):
tasks = list_subclasses(Task)
tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
tasks = [task.__base__.__name__ for task in tasks]
# Keep task list alphabetical.
tasks = sorted(tasks)

keras_version = keras.version() if hasattr(keras, "version") else None
metadata = {
Expand Down
17 changes: 0 additions & 17 deletions keras_hub/src/utils/preset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from keras_hub.src.models.bert.bert_backbone import BertBackbone
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.keras_utils import has_quantization_support
from keras_hub.src.utils.preset_utils import CONFIG_FILE
from keras_hub.src.utils.preset_utils import load_serialized_object
from keras_hub.src.utils.preset_utils import upload_preset


Expand Down Expand Up @@ -88,18 +86,3 @@ def test_upload_with_invalid_json(self):
# Verify error handling.
with self.assertRaisesRegex(ValueError, "is an invalid json"):
upload_preset("kaggle://test/test/test", local_preset_dir)

@parameterized.named_parameters(
("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
)
@pytest.mark.extra_large
def test_load_serialized_object(self, preset, dtype, is_quantized):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the test you mentioned it's broken and we don't run? It seems like a good test though!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could bring it back, but it's missing the mark on a few counts.

  • Too big to run on our regular CI.
  • Doesn't actually work.
  • Not testing anything an end user would hit, it's just testing an internal utility in a way that will be brittle. We could test similar logic with from_preset that would be much less brittle.

Let's ditch for now!

if is_quantized and not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

model = load_serialized_object(preset, dtype=dtype)
if is_quantized:
self.assertEqual(model.dtype_policy.name, "map_bfloat16")
else:
self.assertEqual(model.dtype_policy.name, "bfloat16")
File renamed without changes.
File renamed without changes.
File renamed without changes.
104 changes: 0 additions & 104 deletions tools/convert_legacy_presets.py

This file was deleted.

Loading