From b583038567ce014b3b3b3293ac1788c49d03be89 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Thu, 5 Dec 2024 13:39:14 -0800 Subject: [PATCH] Fix dtype when creating a task with a task.json --- keras_hub/src/models/task_test.py | 4 ++++ keras_hub/src/utils/preset_utils.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index bf57c88912..ac4f8def5b 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -147,6 +147,10 @@ def test_save_to_preset(self): new_out = restored_task.backbone.predict(data) self.assertAllClose(ref_out, new_out) + # Check setting dtype. + restored_task = TextClassifier.from_preset(save_dir, dtype="float16") + self.assertEqual("float16", restored_task.backbone.dtype_policy.name) + @pytest.mark.large def test_save_to_preset_custom_backbone_and_preprocessor(self): preprocessor = keras.layers.Rescaling(1 / 255.0) diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 52aad373a0..4f80f5227f 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -658,6 +658,12 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): cls, load_weights, load_task_weights, **kwargs ) # We found a `task.json` with a complete config for our class. + # Forward backbone args. + backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) + if "backbone" in task_config["config"]: + backbone_config = task_config["config"]["backbone"]["config"] + backbone_config = {**backbone_config, **backbone_kwargs} + task_config["config"]["backbone"]["config"] = backbone_config task = load_serialized_object(task_config, **kwargs) if task.preprocessor and hasattr( task.preprocessor, "load_preset_assets"