diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index bf57c88912..ac4f8def5b 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -147,6 +147,10 @@ def test_save_to_preset(self): new_out = restored_task.backbone.predict(data) self.assertAllClose(ref_out, new_out) + # Check setting dtype. + restored_task = TextClassifier.from_preset(save_dir, dtype="float16") + self.assertEqual("float16", restored_task.backbone.dtype_policy.name) + @pytest.mark.large def test_save_to_preset_custom_backbone_and_preprocessor(self): preprocessor = keras.layers.Rescaling(1 / 255.0) diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 52aad373a0..4f80f5227f 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -658,6 +658,12 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): cls, load_weights, load_task_weights, **kwargs ) # We found a `task.json` with a complete config for our class. + # Forward backbone args. + backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) + if "backbone" in task_config["config"]: + backbone_config = task_config["config"]["backbone"]["config"] + backbone_config = {**backbone_config, **backbone_kwargs} + task_config["config"]["backbone"]["config"] = backbone_config task = load_serialized_object(task_config, **kwargs) if task.preprocessor and hasattr( task.preprocessor, "load_preset_assets"