Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ops/numpy.py: support key as list in GetItem #19310

Merged
merged 1 commit into from Mar 30, 2024

Conversation

tvogel
Copy link
Contributor

@tvogel tvogel commented Mar 14, 2024

Please take a look, I had trouble loading a model that did things like this:

kernel_output = keras.layers.Concatenate()(
  [ 
    kernel(input[:, 2*other:2*other+2] - input[:, 2*ref:2*ref+2]) 
    for ref in range(n)
    for other in range(n) 
    if other != ref 
  ])

Copy link

google-cla bot commented Mar 14, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Please add a unit test.

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Mar 15, 2024
@codecov-commenter
Copy link

codecov-commenter commented Mar 15, 2024

Codecov Report

Attention: Patch coverage is 50.00000% with 1 lines in your changes are missing coverage. Please review.

Project coverage is 75.99%. Comparing base (104fe8e) to head (3125ec6).

Files Patch % Lines
keras/ops/numpy.py 50.00% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19310      +/-   ##
==========================================
+ Coverage   75.97%   75.99%   +0.01%     
==========================================
  Files         366      366              
  Lines       40740    40742       +2     
  Branches     7944     7945       +1     
==========================================
+ Hits        30952    30960       +8     
+ Misses       8075     8071       -4     
+ Partials     1713     1711       -2     
Flag Coverage Δ
keras 75.84% <50.00%> (+0.01%) ⬆️
keras-jax 60.13% <50.00%> (+0.01%) ⬆️
keras-numpy 54.09% <50.00%> (+0.06%) ⬆️
keras-tensorflow 61.37% <50.00%> (+<0.01%) ⬆️
keras-torch 60.27% <50.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@tvogel
Copy link
Contributor Author

tvogel commented Mar 16, 2024

Thanks for the PR! Please add a unit test.

Just did! I was not entirely sure which test to add to - I hope, I made a reasonable choice. Otherwise, please let me know!

layer = input[:, 1]
model = keras.Model(input, layer)
serialized, restored, reserialized = self.roundtrip(model)
self.assertEqual(serialized, reserialized)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove all the other added code by just doing e.g.

serialized_str = str(serialized).replace("(", "[").replace(")", "]")
reserialized_str = str(reserialized)
self.assertEqual(serialized_str, reserialized_str)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you prefer it, I can do that but honestly, isn't that a bit ad-hoc to fix this only for this special case? Wouldn't that problem turn up in general for future tests? Tuples being left out in test_simple_objects() even though they appear a lot, e.g. for all shape specifications, seemed odd to me, too.

Also, plain string manipulation feels a bit arbitrary and could hide problems when the replacement happens e.g. inside string payload where in general, you would want to detect differences. Of course, for this specific case, it's not a problem but depending on how the serialization format evolves, I am not sure, it will stay like that. Are you?

So, please confirm, you prefer the string comparison and I will adjust it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet Please let me know!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a matter of complexity. The cost of complexity is high. The need to check object equality on one line in a test does not justify the added complexity and its cost. The alternative formulation is adhoc, but it is simple and easy to read, so its cost is almost 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, your project, your call. I have been developing complex software for many years now, and my take is to always prefer the proper solution in contrast to ad-hoc fix-ups but of course, in a unit-test only context, it does not pay so much to spend more time debating this. I will change it probably tomorrow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet Actually, did you try your suggestion? It does not work here because 'shape': (None, 2) is actually still handled as a tuple after reserialization in contrast to other things. So, I have to map both strings.

@fchollet
Copy link
Member

You'll also need to run sh shell/format.sh to format the code.

@tvogel
Copy link
Contributor Author

tvogel commented Mar 16, 2024

You'll also need to run sh shell/format.sh to format the code.

I adjusted the patch with shell/format.sh.

@tvogel
Copy link
Contributor Author

tvogel commented Mar 29, 2024

@fchollet I updated the PR!

PR Queue automation moved this from Assigned Reviewer to Approved by Reviewer Mar 29, 2024
@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 29, 2024
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Mar 30, 2024
When loading a model that contains `GetItem` nodes with multidimensional
indices/slices as `key`, the `key` argument is loaded from JSON as a `list`,
not a `tuple` (because JSON does not have the distinction).

So, treat the `key list` as equivalent to the `key tuple`.
Copying is important: otherwise, the later `pop()` will remove the bound
slice elements from the op itself.

`saving/serialization_lib_test.py`:

* Add `test_numpy_get_item_layer()`:
	test for consistent serialization/deserialization of a model which
	contains `ops.numpy.GetItem`;
@fchollet fchollet merged commit b57bfcd into keras-team:master Mar 30, 2024
6 checks passed
PR Queue automation moved this from Approved by Reviewer to Merged Mar 30, 2024
@tvogel tvogel deleted the patch-2 branch March 30, 2024 23:34
james77777778 added a commit to james77777778/keras that referenced this pull request Apr 3, 2024
* Refactor dtypes in codebase and add float8_* dtypes

* Update comments

Fix for JAX export on GPU. (keras-team#19404)

Fix formatting in export_lib. (keras-team#19405)

`ops/numpy.py`: Support `key` as `list` in `GetItem` (keras-team#19310)

When loading a model that contains `GetItem` nodes with multidimensional
indices/slices as `key`, the `key` argument is loaded from JSON as a `list`,
not a `tuple` (because JSON does not have the distinction).

So, treat the `key list` as equivalent to the `key tuple`.
Copying is important: otherwise, the later `pop()` will remove the bound
slice elements from the op itself.

`saving/serialization_lib_test.py`:

* Add `test_numpy_get_item_layer()`:
	test for consistent serialization/deserialization of a model which
	contains `ops.numpy.GetItem`;

feat(losses): add Dice loss implementation (keras-team#19409)

* feat(losses): add Dice loss implementation

* removed smooth parameter and type casting

* adjusted casting and dot operator

Update casting

Bump the github-actions group with 1 update (keras-team#19412)

Bumps the github-actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action).

Updates `github/codeql-action` from 3.24.6 to 3.24.9
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@8a470fd...1b1aada)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

Fix issue with shared layer deserialization

Remove dead code in saving lib (keras-team#19415)

Remove unused beta param for silu, use torch op directly (keras-team#19417)

The beta param was only accepted on the tensorflow/torch backends
and not in the `keras.ops` API, nor was it tested. I think best
just to ditch, since no one could be relying on it.

Fix print_fn for custom function (keras-team#19419)

Add fp8 to `EinsumDense`

Add test script
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
PR Queue
Merged
Development

Successfully merging this pull request may close these issues.

None yet

5 participants