
Possibly unintended regression (though maybe intended?) #19561

Closed
LukeWood opened this issue Apr 19, 2024 · 3 comments · Fixed by #19564

Comments

@LukeWood
Contributor

Hello! I have been doing some limit testing on Keras 3 and exploring possible edge cases, and I believe I've found two in the saving/loading flow.

Here's one:


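Minimal repro:

```python
import keras

input_one = keras.layers.Input((10, 10))
input_two = keras.layers.Input((10,))

# Indexing a KerasTensor with [..., None] adds a trailing axis: (None, 10) -> (None, 10, 1),
# which then broadcasts against input_one's (None, 10, 10) in the Add layer.
y = keras.layers.Add()([input_one, input_two[..., None]])

result = keras.Model({'input_one': input_one, 'input_two': input_two}, y)

# Saving fails with the TypeError shown below.
result.save('test.keras')
```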

results in:

TypeError: Cannot serialize object Ellipsis of type <class 'ellipsis'>. To be serializable, a class must implement the `get_config()` method.

Is this a bug?

While the model trains on every backend, it fails to save (and therefore cannot be loaded). Based on the "Calling TF ops on a Keras Tensor" section of https://keras.io/guides/migrating_to_keras_3/, I believe this is unintended, though perhaps it is intended?

A trivial workaround is to wrap the indexing in a Lambda layer:

```python
expanded_input_two = keras.layers.Lambda(lambda x: x[..., None])(input_two)
```

Perhaps the "workaround" is the intended new flow?


If it's a bug, I'm happy to take a deeper look at a solution.

@fchollet
Member

When you do input_two[..., None] (where input_two is a KerasTensor), you are hitting keras/src/ops/numpy.py:GetItem.compute_output_spec(), which supports Ellipsis.

This is why your model can be built and trained; otherwise you would see an error at the line that does the indexing.
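To make the shape-inference step concrete, here is a simplified, standalone sketch of how an indexing key like `(..., None)` can be resolved against a static shape: the `Ellipsis` expands into full slices over the axes not otherwise consumed, and a `None` in the key inserts a size-1 axis. The function `getitem_output_shape` is hypothetical and is not Keras's actual `GetItem.compute_output_spec()`; it handles only `Ellipsis`, `None`, integers, and slices, and uses `None` in the *shape* to mean an unknown dimension such as the batch axis.

```python
def getitem_output_shape(shape, key):
    """Static result shape of x[key] for a tensor whose shape is `shape`.

    Hypothetical, simplified sketch (not Keras's GetItem implementation).
    Supports Ellipsis, None (new axis), ints (drop an axis) and slices.
    Unknown dimensions in `shape` are represented by None.
    """
    if not isinstance(key, tuple):
        key = (key,)
    if sum(k is Ellipsis for k in key) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")

    # Entries that consume an existing axis: everything except None and Ellipsis.
    consuming = sum(1 for k in key if k is not None and k is not Ellipsis)
    if consuming > len(shape):
        raise IndexError("too many indices for tensor")

    # Replace the Ellipsis with as many full slices as there are unconsumed axes.
    expanded = []
    for k in key:
        if k is Ellipsis:
            expanded.extend([slice(None)] * (len(shape) - consuming))
        else:
            expanded.append(k)
    # Without an Ellipsis, trailing axes not mentioned in the key are kept as-is.
    used = sum(1 for k in expanded if k is not None)
    expanded.extend([slice(None)] * (len(shape) - used))

    out = []
    axis = 0
    for k in expanded:
        if k is None:
            out.append(1)  # None in the key inserts a new axis of size 1
            continue
        dim = shape[axis]
        axis += 1
        if isinstance(k, slice):
            # Unknown (None) dimensions stay unknown after slicing.
            out.append(None if dim is None else len(range(*k.indices(dim))))
        elif isinstance(k, int):
            pass  # an integer index removes the axis
        else:
            raise TypeError(f"unsupported index: {k!r}")
    return tuple(out)


# The indexing from the repro: (batch, 10) -> (batch, 10, 1)
print(getitem_output_shape((None, 10), (..., None)))  # (None, 10, 1)
```

The resulting `(None, 10, 1)` shape is what lets the `Add` layer broadcast against `input_one`'s `(None, 10, 10)`, so graph construction succeeds even though the key itself is never serialized.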

However, the framework doesn't know how to serialize Ellipsis. I think we can support it in the general case by adding a case for it in serialize_keras_object/deserialize_keras_object (in keras/src/saving/serialization_lib.py), similar to what we do for __slice__. If you're able, please open a PR.
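A minimal standalone sketch of the suggested approach, assuming a JSON-style config: `Ellipsis` gets its own marker dict, just as a `slice` is stored under a `"__slice__"` class name. The helpers `serialize_value`/`deserialize_value` and the `"__ellipsis__"` marker are hypothetical; the real `serialize_keras_object`/`deserialize_keras_object` handle many more types and may encode these values differently.

```python
import json


def serialize_value(obj):
    """Turn an indexing key into JSON-compatible data (hypothetical sketch)."""
    if obj is Ellipsis:
        # Ellipsis has no get_config(), so encode it as a marker dict instead.
        return {"class_name": "__ellipsis__", "config": {}}
    if isinstance(obj, slice):
        # Same idea as the "__slice__" case: store start/stop/step recursively.
        return {
            "class_name": "__slice__",
            "config": {
                "start": serialize_value(obj.start),
                "stop": serialize_value(obj.stop),
                "step": serialize_value(obj.step),
            },
        }
    if isinstance(obj, tuple):
        # JSON has no tuple type: store as a list, turned back into a tuple on load.
        return [serialize_value(x) for x in obj]
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    raise TypeError(f"Cannot serialize object {obj!r} of type {type(obj)}")


def deserialize_value(data):
    """Inverse of serialize_value (hypothetical sketch)."""
    if isinstance(data, list):
        return tuple(deserialize_value(x) for x in data)
    if isinstance(data, dict):
        name = data["class_name"]
        if name == "__ellipsis__":
            return Ellipsis
        if name == "__slice__":
            cfg = data["config"]
            return slice(
                deserialize_value(cfg["start"]),
                deserialize_value(cfg["stop"]),
                deserialize_value(cfg["step"]),
            )
        raise ValueError(f"Unknown class_name: {name!r}")
    return data


key = (..., None)  # the indexing key from the repro
blob = json.dumps(serialize_value(key))
print(blob)  # [{"class_name": "__ellipsis__", "config": {}}, null]
assert deserialize_value(json.loads(blob)) == key
```

Using a dedicated marker keeps `Ellipsis` distinguishable from ordinary strings or `None` in the saved config, so loading can restore the exact indexing key.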

@LukeWood
Contributor Author

Sounds good! I'll give this a shot later today and see if I can add support plus some tests.
