Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions keras_hub/src/models/backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def test_save_to_preset(self):
self.assertTrue("build_config" not in backbone_config)
self.assertTrue("compile_config" not in backbone_config)

# Check the metadata.
metadata_config = load_json(save_dir, METADATA_FILE)
self.assertTrue("keras_version" in metadata_config)
self.assertTrue("keras_hub_version" in metadata_config)
self.assertTrue("parameter_count" in metadata_config)
self.assertTrue("TextClassifier" in metadata_config["tasks"])
self.assertTrue("CausalLM" not in metadata_config["tasks"])

# Try config class.
self.assertEqual(BertBackbone, check_config_class(backbone_config))

Expand Down
9 changes: 9 additions & 0 deletions keras_hub/src/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,14 +765,23 @@ def _save_serialized_object(self, layer, config_file):
config_file.write(json.dumps(config, indent=4))

def _save_metadata(self, layer):
from keras_hub.src.models.task import Task
from keras_hub.src.version_utils import __version__ as keras_hub_version

# Find all tasks that are compatible with the backbone.
# E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
# For `ResNetBackbone` we would have `ImageClassifier`.
tasks = list_subclasses(Task)
tasks = filter(lambda x: x.backbone_cls == type(layer), tasks)
tasks = [task.__base__.__name__ for task in tasks]

keras_version = keras.version() if hasattr(keras, "version") else None
metadata = {
"keras_version": keras_version,
"keras_hub_version": keras_hub_version,
"parameter_count": layer.count_params(),
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
"tasks": tasks,
}
metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
with open(metadata_path, "w") as metadata_file:
Expand Down
Loading