Refactor keras.dtype_policies
#19711
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## master #19711 +/- ##
=======================================
Coverage 78.52% 78.53%
=======================================
Files 498 498
Lines 45769 45756 -13
Branches 8456 8454 -2
=======================================
- Hits 35942 35936 -6
+ Misses 8091 8087 -4
+ Partials 1736 1733 -3
Thanks for the PR!
@@ -202,6 +202,10 @@ def __repr__(self):
        return f'<FloatDTypePolicy "{self._name}">'


GLOBAL_DEFAULT_PLACEHOLDER = "global_default"
Please use a more explicit name, e.g. "DEFAULT_DTYPE_POLICY". Why use this string as the initial value, instead of e.g. None?
> Why use this string as the initial value, instead of e.g. None?

Currently, DTypePolicy and its subclasses rely on a string value for parsing.
It is not clear to me how we can pass None in combination with the quantization mode.
Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?
Ex:
policy = QuantizedDTypePolicy(mode="int8", source_dtype_policy="mixed_bfloat16")
> Currently, DTypePolicy and its subclasses rely on a string value for parsing.
> It is not clear to me how we can pass None in combination with the quantization mode.

We could just modify DTypePolicy to support None, meaning "default".

> Should we refactor QuantizedDTypePolicy to support a signature for both the quantization mode and the source dtype policy?

Yes, that's a great idea!
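For illustration, here is a minimal sketch of the signature discussed above, written in plain Python rather than taken from the Keras implementation. The class `SketchQuantizedPolicy`, the variable `_GLOBAL_POLICY_NAME`, and the helper `set_global_policy` are hypothetical stand-ins. Only two ideas come from the thread: passing the quantization mode and the source policy as separate arguments, and letting None mean "default".

```python
# Hypothetical stand-in for the global dtype policy; this is not Keras code.
_GLOBAL_POLICY_NAME = "float32"


def set_global_policy(name):
    """Set the (hypothetical) global default policy name."""
    global _GLOBAL_POLICY_NAME
    _GLOBAL_POLICY_NAME = name


class SketchQuantizedPolicy:
    """Toy policy that takes the quantization mode and source policy separately."""

    def __init__(self, mode, source_name=None):
        self.mode = mode
        # None means "default": resolve against the global policy at
        # construction time instead of parsing a placeholder string.
        if source_name is None:
            source_name = _GLOBAL_POLICY_NAME
        self.source_name = source_name
        self.name = f"{mode}_from_{source_name}"

    def __repr__(self):
        return f'<SketchQuantizedPolicy "{self.name}">'


explicit = SketchQuantizedPolicy(mode="int8", source_name="mixed_bfloat16")
default = SketchQuantizedPolicy(mode="int8")
print(explicit)  # <SketchQuantizedPolicy "int8_from_mixed_bfloat16">
print(default)   # <SketchQuantizedPolicy "int8_from_float32">
```

With this shape there is no need for a sentinel string such as "global_default"; the absence of a source policy is expressed by None alone.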
I've significantly refactored keras.dtype_policies, including QuantizedDTypePolicy. Some notes:
Incompatible warning:
To add flexibility to quantized dtype policy:

import keras
from keras import dtype_policies
from keras import layers
from keras import models


@keras.saving.register_keras_serializable("MyPackage")
class MySubclass(layers.Layer):
    def __init__(self, **kwargs):
        dtypes = kwargs.pop("dtypes", {})
        super().__init__(**kwargs)
        self.layer = layers.Dense(8, dtype=dtypes.pop("layer", None))

    def call(self, inputs, training=None):
        return self.layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.pop("dtype")
        if self.layer.dtype_policy.is_quantized:
            _config = dtype_policies.serialize(self.layer.dtype_policy)
            _config["config"]["source_name"] = None
            config.update({"dtypes": {"layer": _config}})
        return config


inputs = layers.Input(shape=[None, 4])
outputs = MySubclass()(inputs)
model = models.Model(inputs, outputs)

# global dtype policy (float32)
model.quantize("int8")
for layer in model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)
model.save("model.keras")

# global dtype policy (bfloat16)
keras.config.set_dtype_policy("bfloat16")
new_model = models.load_model("model.keras")
for layer in new_model._flatten_layers(include_self=False, recursive=True):
    print(layer.name, layer.dtype_policy)

The outputs:

# global dtype policy: float32
input_layer <FloatDTypePolicy "float32">
my_subclass <FloatDTypePolicy "float32">
dense <QuantizedDTypePolicy "int8_from_float32">
# global dtype policy: bfloat16
input_layer <FloatDTypePolicy "bfloat16">
my_subclass <FloatDTypePolicy "bfloat16">
dense_1 <QuantizedDTypePolicy "int8_from_bfloat16">
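The effect of setting source_name to None in get_config can be sketched without Keras. The toy below is an assumption-laden illustration, not the library's code: `GLOBAL_POLICY`, `serialize_policy`, `deserialize_policy`, and the `pin_source` flag are all hypothetical names. It shows only why a saved config that omits the source policy picks up whatever global policy is active at load time, while a config that pins the source does not.

```python
# Hypothetical stand-ins; none of these names are Keras APIs.
GLOBAL_POLICY = {"name": "float32"}


def serialize_policy(mode, source_name, pin_source=True):
    """Return a config dict; drop the source policy when pin_source is False."""
    return {"mode": mode, "source_name": source_name if pin_source else None}


def deserialize_policy(config):
    """Rebuild the policy name, falling back to the current global policy."""
    source = config["source_name"]
    if source is None:
        source = GLOBAL_POLICY["name"]
    return f'{config["mode"]}_from_{source}'


# Saved while the global policy is float32.
pinned = serialize_policy("int8", "float32", pin_source=True)
unpinned = serialize_policy("int8", "float32", pin_source=False)

# Loaded after the global policy has changed to bfloat16.
GLOBAL_POLICY["name"] = "bfloat16"
print(deserialize_policy(pinned))    # int8_from_float32
print(deserialize_policy(unpinned))  # int8_from_bfloat16
```

The unpinned case corresponds to the dense_1 line in the outputs above, where the reloaded layer reports int8_from_bfloat16.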
Nice work -- it's definitely cleaner this way! LGTM
Keras' output format was slightly changed in keras-team/keras#19711; in some cases dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras.
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras.
Keras' output format was slightly changed in keras-team/keras#19711; for non-input layers dtypes will now be exported as a config map instead of just a string. This fixes test breakages when using ToT keras. Alternative to #6855
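Downstream code that reads exported layer configs may therefore need to accept both forms. The helper below is a hedged sketch: the function name `dtype_policy_name` is hypothetical, and the shape of the config map (a dict with a nested "config" entry holding "name") is an assumption for illustration that should be checked against the Keras version in use. Quantized policies may carry different keys and are not handled here.

```python
def dtype_policy_name(dtype_field):
    """Return the policy name from a layer config's `dtype` entry.

    Accepts the older plain-string form and an assumed config-map form.
    """
    if dtype_field is None or isinstance(dtype_field, str):
        return dtype_field
    if isinstance(dtype_field, dict):
        # Assumed layout: {"class_name": ..., "config": {"name": ...}}.
        return dtype_field.get("config", {}).get("name")
    raise TypeError(f"Unsupported dtype entry: {type(dtype_field).__name__}")


old_style = "float32"
# Illustrative config map; the exact keys are an assumption.
new_style = {
    "module": "keras",
    "class_name": "FloatDTypePolicy",
    "config": {"name": "float32"},
    "registered_name": None,
}

print(dtype_policy_name(old_style))  # float32
print(dtype_policy_name(new_style))  # float32
```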
EDITED:
Please refer to #19711 (comment) for the new updates.
I think it would be beneficial to provide some flexibility to QuantizedDTypePolicy regarding the global dtype policy (keras.config.dtype_policy()).
Additionally, there is a new property in DTypePolicy, is_quantized, that should be useful for these quantization-related methods.
With this PR, we can do the following:
Outputs:
@mattdangerw has pointed out that currently the dtype policies of the quantized saves are immutable regarding the global dtype policy. keras-team/keras-nlp#1612 (comment)
With this PR, we can make a slight modification in get_config to support that feature.