
Use ops.rsqrt, improve normalization layers and enable ops fusion in tflite #892

Merged
merged 6 commits into keras-team:main from improve-normalization-layers
Sep 16, 2023

Conversation

@james77777778 (Contributor) commented Sep 15, 2023

Fixes #824

This PR accomplishes the following:

  1. Add support for rsqrt in the numpy backend (using jax's implementation)
  2. Replace 1 / ops.sqrt(x) with ops.rsqrt for improved speed
  3. Reorder the ops in the normalization layers to unify the implementations and match the expression of tf.nn.batch_normalization (link); see the sketch below
  4. Ensure 100% unit test coverage for all normalization layers

After step 3, tflite recognizes the CONV+BN+ReLU pattern and fuses the ops successfully.
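
A minimal sketch of the reordered expression (hypothetical helper and variable names, not the actual layer code): the scale and shift are folded into two tensors so the normalization reduces to a single multiply followed by a single add, which is the pattern the TFLite converter can fold into the preceding convolution.

from keras_core import ops

def batch_norm_sketch(x, mean, variance, gamma, beta, epsilon=1e-3):
    # Fold gamma into the reciprocal square root (ops.rsqrt) ...
    inv = ops.rsqrt(variance + epsilon) * gamma
    # ... and fold mean/beta into a single shift term.
    shift = -mean * inv + beta
    # Final pattern: one multiply + one add, matching tf.nn.batch_normalization.
    return x * inv + shift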

standalone MobileNetV3 export script
import tensorflow as tf

from keras_core.applications.mobilenet_v3 import MobileNetV3Small

keras_core_model = MobileNetV3Small(
    input_shape=(224, 224, 3), minimalistic=True
)

# Wrap the model's call in a tf.function with a fixed input signature.
tf_callable = tf.function(
    keras_core_model.call,
    input_signature=[tf.TensorSpec((1, 224, 224, 3), tf.float32)],
    autograph=True,
    jit_compile=True,
)
tf_concrete_function = tf_callable.get_concrete_function()

# Convert the concrete function to a tflite flatbuffer with default optimizations.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_concrete_function], tf_callable
)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

The visualization from netron: (before this PR vs. after this PR)
(netron screenshot omitted)
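
As an alternative to netron, one way to inspect the op list of the converted model is TensorFlow's model analyzer (not part of this PR; assumes a TF version where tf.lite.experimental.Analyzer is available, roughly 2.7+):

import tensorflow as tf

# Print the op-level structure of model.tflite; after this PR the CONV_2D ops
# should carry fused activations instead of separate MUL/ADD/RELU ops.
tf.lite.experimental.Analyzer.analyze(model_path="model.tflite")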

benchmark script
from keras_core import layers
from keras_core import mixed_precision
from keras_core import models
from keras_core import ops

# "float32"
# "mixed_float16"
# "mixed_bfloat16"
dtype_policy = "float32"
mixed_precision.set_dtype_policy(dtype_policy)

x_train = ops.random.uniform(shape=(512, 64, 64, 64))
y_train = ops.random.uniform(shape=(512, 64, 64, 64))

# layers.BatchNormalization
# layers.GroupNormalization
# layers.LayerNormalization
normalization_cls = layers.LayerNormalization
normalization_args = {}
if normalization_cls is layers.GroupNormalization:
    normalization_args = {"groups": -1}

model = models.Sequential(
    [
        layers.InputLayer(shape=(64, 64, 64)),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
        normalization_cls(**normalization_args),
    ]
)
model.compile(loss="mse", optimizer="adam")
model.fit(x_train, y_train, batch_size=128, epochs=3)

And the improvement:

Backend    | Layer              | Before this PR | After this PR
tensorflow | BatchNormalization | 48ms/step      | 46ms/step
jax        | BatchNormalization | 49ms/step      | 48ms/step
torch      | BatchNormalization | 127ms/step     | 127ms/step
tensorflow | GroupNormalization | 50ms/step      | 49ms/step
jax        | GroupNormalization | 51ms/step      | 50ms/step
torch      | GroupNormalization | 129ms/step     | 129ms/step
tensorflow | LayerNormalization | 54ms/step      | 53ms/step
jax        | LayerNormalization | 55ms/step      | 54ms/step
torch      | LayerNormalization | 165ms/step     | 122ms/step

codecov bot commented Sep 15, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.25% 🎉

Comparison is base (94b5361) 76.56% compared to head (10e4a03) 76.82%.
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #892      +/-   ##
==========================================
+ Coverage   76.56%   76.82%   +0.25%     
==========================================
  Files         329      329              
  Lines       31429    31426       -3     
  Branches     6114     6111       -3     
==========================================
+ Hits        24064    24143      +79     
+ Misses       5786     5719      -67     
+ Partials     1579     1564      -15     
Flag Coverage Δ
keras_core 76.72% <100.00%> (+0.25%) ⬆️

Flags with carried forward coverage won't be shown.

Files Changed Coverage Δ
keras_core/backend/numpy/math.py 82.43% <100.00%> (+0.24%) ⬆️
...s_core/layers/normalization/batch_normalization.py 100.00% <100.00%> (ø)
...s_core/layers/normalization/group_normalization.py 97.64% <100.00%> (+8.63%) ⬆️
...s_core/layers/normalization/layer_normalization.py 100.00% <100.00%> (+2.59%) ⬆️
...as_core/layers/normalization/unit_normalization.py 100.00% <100.00%> (+7.69%) ⬆️

... and 11 files with indirect coverage changes


@fchollet (Member) left a comment

Thank you for the PR!



def rsqrt(x):
    return np.array(jax_rsqrt(x))
@fchollet (Member) commented:

Why not 1. / sqrt(x)? It's numpy native, and we're not worried about performance for the numpy backend.

@james77777778 (Contributor, Author) replied:

I was thinking of being consistent with the other backends, but it should be okay to use 1. / sqrt(x).
Fixed.
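
For reference, a numpy-native version along the lines of the suggestion would presumably look like this (a sketch, not necessarily the exact code that landed):

import numpy as np

def rsqrt(x):
    # Reciprocal square root using plain numpy, as suggested above.
    return 1.0 / np.sqrt(x)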

res = res + beta

# Note: Folding BatchNormalization depends on the precise order of ops
# that are generated by the expression below
@fchollet (Member) commented:

Good comment!

@fchollet (Member) left a comment

LGTM -- thank you for the great contribution!

@fchollet merged commit c663efd into keras-team:main on Sep 16, 2023 (8 checks passed).
@james77777778 deleted the improve-normalization-layers branch on September 17, 2023.
Linked issue: tflite cannot fuse BatchNormalization in Keras Core as effectively as in the original Keras