New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ops/numpy.py: support key as list in GetItem #19310
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! Please add a unit test.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #19310 +/- ##
==========================================
+ Coverage 75.97% 75.99% +0.01%
==========================================
Files 366 366
Lines 40740 40742 +2
Branches 7944 7945 +1
==========================================
+ Hits 30952 30960 +8
+ Misses 8075 8071 -4
+ Partials 1713 1711 -2
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Just did! I was not entirely sure which test to add to - I hope, I made a reasonable choice. Otherwise, please let me know! |
layer = input[:, 1] | ||
model = keras.Model(input, layer) | ||
serialized, restored, reserialized = self.roundtrip(model) | ||
self.assertEqual(serialized, reserialized) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can remove all the other added code by just doing e.g.
serialized_str = str(serialized).replace("(", "[").replace(")", "]")
reserialized_str = str(reserialized)
self.assertEqual(serialized_str, reserialized_str)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you prefer it, I can do that but honestly, isn't that a bit ad-hoc to fix this only for this special case? Wouldn't that problem turn up in general for future tests? Tuples being left out in test_simple_objects()
even though they appear a lot, e.g. for all shape specifications, seemed odd to me, too.
Also, plain string manipulation feels a bit arbitrary and could hide problems when the replacement happens e.g. inside string payload where in general, you would want to detect differences. Of course, for this specific case, it's not a problem but depending on how the serialization format evolves, I am not sure, it will stay like that. Are you?
So, please confirm, you prefer the string comparison and I will adjust it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fchollet Please let me know!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a matter of complexity. The cost of complexity is high. The need to check object equality on one line in a test does not justify the added complexity and its cost. The alternative formulation is adhoc, but it is simple and easy to read, so its cost is almost 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, your project, your call. I have been developing complex software for many years now, and my take is to always prefer the proper solution in contrast to ad-hoc fix-ups but of course, in a unit-test only context, it does not pay so much to spend more time debating this. I will change it probably tomorrow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fchollet Actually, did you try your suggestion? It does not work here because 'shape': (None, 2)
is actually still handled as a tuple after reserialization in contrast to other things. So, I have to map both strings.
You'll also need to run |
I adjusted the patch with |
@fchollet I updated the PR! |
When loading a model that contains `GetItem` nodes with multidimensional indices/slices as `key`, the `key` argument is loaded from JSON as a `list`, not a `tuple` (because JSON does not have the distinction). So, treat the `key list` as equivalent to the `key tuple`. Copying is important: otherwise, the later `pop()` will remove the bound slice elements from the op itself. `saving/serialization_lib_test.py`: * Add `test_numpy_get_item_layer()`: test for consistent serialization/deserialization of a model which contains `ops.numpy.GetItem`;
* Refactor dtypes in codebase and add float8_* dtypes * Update comments Fix for JAX export on GPU. (keras-team#19404) Fix formatting in export_lib. (keras-team#19405) `ops/numpy.py`: Support `key` as `list` in `GetItem` (keras-team#19310) When loading a model that contains `GetItem` nodes with multidimensional indices/slices as `key`, the `key` argument is loaded from JSON as a `list`, not a `tuple` (because JSON does not have the distinction). So, treat the `key list` as equivalent to the `key tuple`. Copying is important: otherwise, the later `pop()` will remove the bound slice elements from the op itself. `saving/serialization_lib_test.py`: * Add `test_numpy_get_item_layer()`: test for consistent serialization/deserialization of a model which contains `ops.numpy.GetItem`; feat(losses): add Dice loss implementation (keras-team#19409) * feat(losses): add Dice loss implementation * removed smooth parameter and type casting * adjusted casting and dot operator Update casting Bump the github-actions group with 1 update (keras-team#19412) Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action). Updates `github/codeql-action` from 3.24.6 to 3.24.9 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](github/codeql-action@8a470fd...1b1aada) --- updated-dependencies: - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Fix issue with shared layer deserialization Remove dead code in saving lib (keras-team#19415) Remove unused beta param for silu, use torch op directly (keras-team#19417) The beta param was only accepted on the tensorflow/torch backends and not in the `keras.ops` API, nor was it tested. I think best just to ditch, since no one could be relying on it. Fix print_fn for custom function (keras-team#19419) Add fp8 to `EinsumDense` Add test script
Please take a look, I had trouble loading a model that did things like this: