-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use ops.rsqrt
, improve normalization layers and enable ops fusion in tflite
#892
Use ops.rsqrt
, improve normalization layers and enable ops fusion in tflite
#892
Conversation
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## main #892 +/- ##
==========================================
+ Coverage 76.56% 76.82% +0.25%
==========================================
Files 329 329
Lines 31429 31426 -3
Branches 6114 6111 -3
==========================================
+ Hits 24064 24143 +79
+ Misses 5786 5719 -67
+ Partials 1579 1564 -15
Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR!
keras_core/backend/numpy/math.py
Outdated
|
||
|
||
def rsqrt(x): | ||
return np.array(jax_rsqrt(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not 1. / sqrt(x)
? It's numpy native, and we're not worried about performance for the numpy backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking of being consistent with other backends, but it should be okay to use 1. / sqrt(x)
Fixed.
res = res + beta | ||
|
||
# Note: Folding BatchNormalization depends on the precise order of ops | ||
# that are generated by the expression below |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good comment!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM -- thank you for the great contribution!
Fixes #824
This PR accomplishes the following:
rsqrt
in numpy backend (using jax's impl)1 / ops.sqrt(x)
withops.rsqrt
for improved speedtf.nn.batch_normalization
linkAfter completing 3, tflite recognizes the pattern of CONV+BN+ReLU, and the ops are fused successfully.
standalone MobileNetV3 export script
The visualization from netron: (before this PR vs. after this PR)
benchmark script
And the improvement: