-
Notifications
You must be signed in to change notification settings - Fork 19.7k
add associative_scan #19938
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add associative_scan #19938
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
fchollet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! The code looks good to me.
keras/src/ops/core.py
Outdated
| Args: | ||
| f: A Python callable implementing an associative binary operation with | ||
| signature ``r = f(a, b)``. Function `f` must be associative, i.e., |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use single backticks for code keywords, throughout.
keras/src/ops/core.py
Outdated
| def associative_scan(f, elems, reverse=False, axis=0): | ||
| """Performs a scan with an associative binary operation, in parallel. | ||
| For an introduction to associative scans, see [BLE1990]_. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prefer in-lining the link with a proper title here.
|
|
||
| @keras_export("keras.ops.associative_scan") | ||
| def associative_scan(f, elems, reverse=False, axis=0): | ||
| """Performs a scan with an associative binary operation, in parallel. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Below, it would be worth introducing more info about use cases for this op. The docstring should answer, what does this do, and why would I want to use it? Perhaps contrast it with scan.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #19938 +/- ##
==========================================
- Coverage 79.00% 72.52% -6.48%
==========================================
Files 499 499
Lines 46531 46724 +193
Branches 8561 8617 +56
==========================================
- Hits 36761 33888 -2873
- Misses 8039 11106 +3067
+ Partials 1731 1730 -1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
| def associative_scan(f, elems, reverse=False, axis=0): | ||
| """Performs a scan with an associative binary operation, in parallel. | ||
| This operation has a similar use-case to scan. The key difference is that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the detailed feedback! I'd also like to add a little bit of context to this statement using some ad hoc benchmarking results from this script. I have results for different backends as well showing a similar trend.
fchollet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM -- thank you for the contribution. Will apply some minor docstring fixes afterwards.
|
It seems there is a device placement issue with torch on GPU, could you take a look? https://btx.cloud.google.com/invocations/0b42aedc-c253-401a-b110-1f441959ccc8/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fcontinuous/log |
Thanks for the info! I made a PR with a fix: #19940 |
|
@SeKim12 thanks for this nice contribution! Did you test using the pytorch backend? Are you able to compile the associative scan operator? I will be testing this myself on an isolated project, but wanted to get your feedback in case you already ran some benchmarks. Thanks! |
Hi @carlosluis! I was able to compile it and use it for the PyTorch backend. Here is a script that I used to quickly benchmark performance: https://gist.github.com/SeKim12/0b5a77fbb05c707e60dcee03cfd7c24b |
|
Hi @SamanehSaadat, My apologies -- I just saw this notification! It seems like the optree imports were handled by @fchollet. Thank you! Please let me know if there is anything else I can do! |
Original PR #19938 by SeKim12 Original: keras-team/keras#19938
Merged from original PR #19938 Original: keras-team/keras#19938
Original PR #19938 by SeKim12 Original: keras-team/keras#19938
Merged from original PR #19938 Original: keras-team/keras#19938

Addresses #19904, adds
associative_scansupport for all backends.This is my first time contributing, so would greatly appreciate any feedback or something I may be missing (and would be more than happy to apply them)! Thanks!