fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX#20842
Conversation
…cases) - Improves input structure validation in Model and Functional classes - Adds strict validation with clear error messages for mismatches
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #20842 +/- ##
==========================================
- Coverage 82.26% 82.20% -0.06%
==========================================
Files 561 561
Lines 52693 53035 +342
Branches 8146 8228 +82
==========================================
+ Hits 43347 43600 +253
- Misses 7344 7391 +47
- Partials 2002 2044 +42
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
- Make the linter happy :)
- Added shape logging for preprocessed and crossed features - Added debug messages in Functional model processing (To be removed)
|
TODO: Remove debugging prints. |
|
Ping: @mattdangerw @fchollet. |
jeffcarp
left a comment
There was a problem hiding this comment.
Thanks for the contribution! Left some comments.
|
Thanks for the review! Will fix it at the earliest. |
|
Please do look into the changes and let me know if it's the intended behavior. Looking forward to your guidance, thanks! |
|
Ping: @fchollet, looking forward to an update on this PR. Thanks! |
|
Thanks for the update!
Please resolve merge conflicts. @jeffcarp does the PR look good? |
|
@fchollet |
|
@fchollet, just checking in on this PR. Since the changes are approved, is there anything else needed before merging? Looking forward to your feedback, thanks! |
hertschuh
left a comment
There was a problem hiding this comment.
Thank you for your work on this. I can see that a lot of effort was put into this. However, I believe this can be simplified quite a bit.
There is a lot of logic around creating, maintaining, restoring self._input_names. However, it is only actually needed in export_utils.get_input_signature, and it's only used for one purpose, which is to know the order of inputs. But the order of inputs is already known, right? It's in self._input_struct. So you can just use self._input_struct in export_utils.get_input_signature and remove self._input_names.
Additionally, I don't think dicts constitute a special case in most cases. Layers can have 1 or more positional arguments as inputs, and each one of those is a nested structure that can have dicts at any layer.
Out of curiosity, I added the following 2 tests in saved_model_test.py, and they don't pass:
def test_export_with_two_dict_inputs_functional(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
inputs1 = {
"foo": layers.Input(shape=()),
"bar": layers.Input(shape=()),
}
inputs2 = {
"baz": layers.Input(shape=()),
}
outputs = layers.Add()([inputs1["foo"], inputs1["bar"], inputs2["baz"]])
model = models.Model((inputs1, inputs2), outputs)
ref_input = (
{"foo": tf.constant([1.0]), "bar": tf.constant([2.0])},
{"baz": tf.constant([2.0])},
)
ref_output = model(ref_input)
model.export(temp_filepath, format="tf_saved_model")
revived_model = tf.saved_model.load(temp_filepath)
revived_output = revived_model.serve(ref_input)
self.assertAllClose(ref_output, revived_output)
def test_export_with_two_dict_inputs_subclass(self):
class TwoDictModel(models.Model):
def call(self, inputs1, inputs2):
return ops.add(
ops.add(inputs1["foo"], inputs1["bar"]), inputs2["baz"]
)
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = TwoDictModel()
ref_input1 = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
ref_input2 = {"baz": tf.constant([2.0])}
ref_output = model(ref_input1, ref_input2)
model.export(temp_filepath, format="tf_saved_model")
revived_model = tf.saved_model.load(temp_filepath)
revived_output = revived_model.serve(ref_input1, ref_input2)
self.assertAllClose(ref_output, revived_output)This shows that dicts as second arguments are not supported.
| # Create a simplified wrapper that handles both dict and | ||
| # positional args, similar to TensorFlow implementation. | ||
| def wrapped_fn(arg, **kwargs): | ||
| return fn(arg) |
There was a problem hiding this comment.
I'm not following
- why this is needed
- what tensorflow implementation it follows
All it does is keep the first position argument and drop everything else. What if you have multiple positional arguments?
| jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape( | ||
| input_signature | ||
| ) | ||
| if isinstance(input_signature, dict): |
There was a problem hiding this comment.
This shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.
| decorated_fn = tf.function( | ||
| fn, input_signature=input_signature, autograph=False | ||
| ) | ||
| if isinstance(input_signature, dict): |
There was a problem hiding this comment.
Same comment about why this shouldn't be needed. input_signature should always be a list because the model can have multiple positional arguments. If the inputs are a dict, input_signature should be a list of one dict already. If you need to add the outer list here, it means that an incorrect input_signature was passed in the first place.
| if isinstance(input_signature, dict): | ||
| # Create a simplified wrapper that handles both dict and | ||
| # positional args. | ||
| def wrapped_fn(arg, **kwargs): |
There was a problem hiding this comment.
Same comment about why this is needed and dropping all but the first positional argument.
| tensor_spec = x | ||
| else: | ||
| return x | ||
| if isinstance(x, dict): |
There was a problem hiding this comment.
make_tf_tensor_spec is always called (and should always be called) within keras.tree.map_structure, so you'll never handles dicts here. Remove this case.
It is indeed not covered, showing it's not used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fexport%2Fexport_utils.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#4690cab8cef00d17f2a67990022886ff-R103
| def build_from_config(self, config): | ||
| if not config: | ||
| return | ||
| # Fetch the input structure from config if available. |
There was a problem hiding this comment.
Lines 416-424 is all dead code. "input_names" is never in the config because nothing adds it there. Remove.
| return self.layers[index] | ||
|
|
||
| if name is not None: | ||
| # Check if the name matches any of the input names. |
There was a problem hiding this comment.
I don't think this is needed, the fallback should work for this use case.
Also, coverage shows it's never used:
https://app.codecov.io/gh/keras-team/keras/pull/20842?src=pr&el=tree&filepath=keras%2Fsrc%2Fmodels%2Fmodel.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#b343e88c9bde68f6f2b9d00f13fa225e-R218
| else: | ||
| Layer.__init__(self, *args, **kwargs) | ||
| if args: | ||
| inputs = args[0] |
There was a problem hiding this comment.
This assumes there is only one input argument. That's true for Functional models, that is not necessarily true for subclass model.
| Layer.__init__(self, *args, **kwargs) | ||
| if args: | ||
| inputs = args[0] | ||
| # Only set _input_names if not already initialized by Functional. |
There was a problem hiding this comment.
You're in an else, you know for sure it's not a Functional model if you're here.
| if isinstance(input_signature, list) and len(input_signature) > 1: | ||
| input_signature = [input_signature] | ||
| if hasattr(model, "_input_names") and model._input_names: | ||
| if isinstance(model._inputs_struct, dict): |
There was a problem hiding this comment.
Why would _inputs_struct not have the right order?
|
This is lovely! Thanks for sharing your insights, @hertschuh. I believe I was missing the fact that I could make much better use of the existing |
|
@harshaljanjani @hertschuh are you still working on this PR? |
Apologies for the delay, I've been caught up in this PR for the past month: I'll leave a comment in the issue regarding the same. |
I can take over. I think I have a fix. |
Thanks a lot for your consideration! |
This fix goes beyond the requirements of the issue and adds support for handling Keras models with dictionary-based inputs, particularly when exporting to the
TFSavedModelformat for both the TensorFlow and JAX backends. Previously, models with dictionary inputs would fail during export with ValueErrors related to input structure mismatches.Key changes:
model._input_namesfor dictionary-based inputs in Functional and Model classesThis PR aims to fix #20835 where models with dictionary inputs would fail to export properly to
SavedModelformat.Example of fixed functionality: