Add quantization support for Gemma, Gemma2 and PaliGemma #1670
```diff
@@ -75,11 +75,7 @@ def __init__(self, *args, dtype=None, **kwargs):
             id(layer) for layer in self._flatten_layers()
         )
         self._initialized = True
-        if dtype is not None:
-            if isinstance(dtype, keras.DTypePolicy):
-                self.dtype_policy = dtype
-            else:
-                self.dtype_policy = keras.DTypePolicy(dtype)
+        self.dtype_policy = keras.dtype_policies.get(dtype)

     def __setattr__(self, name, value):
         # Work around setattr issues for Keras 2 and Keras 3 torch backend.
```
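For reference, the new one-liner works because `keras.dtype_policies.get()` normalizes the accepted `dtype` forms itself. Below is a minimal sketch of the behavior relied on here (Keras 3; the variable names are illustrative, and the `None` fallback is an assumption based on this diff rather than something stated in it):

```python
import keras

# A policy name and an existing policy object both resolve to a DTypePolicy,
# so the isinstance() branching removed above is no longer needed.
policy_from_name = keras.dtype_policies.get("bfloat16")
policy_from_object = keras.dtype_policies.get(keras.DTypePolicy("mixed_float16"))

# Passing None is expected to fall back to the current global dtype policy
# (assumption), which is why the `if dtype is not None` guard could be dropped.
default_policy = keras.dtype_policies.get(None)
print(policy_from_name.name, policy_from_object.name, default_policy.name)
```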
```diff
@@ -107,11 +103,20 @@ def token_embedding(self, value):
     def get_config(self):
         # Don't chain to super here. `get_config()` for functional models is
         # a nested layer config and cannot be passed to Backbone constructors.
-        return {
+        config = {
             "name": self.name,
             "trainable": self.trainable,
         }
+
+        # Add quantization support by utilizing `DTypePolicyMap`
+        policy_map = keras.dtype_policies.DTypePolicyMap()
+        for layer in self._flatten_layers():
+            if layer.quantization_mode is not None:
+                policy_map[layer.path] = layer.dtype_policy
+        if len(policy_map) > 0:
+            config.update({"dtype": policy_map})
+        return config

     @classmethod
     def from_config(cls, config):
         # The default `from_config()` for functional models will return a
```

This is great! This should buy us support for all models, right? If possible, we should consider extending our common backbone tests for this... Doing so would test quantization for the whole library. Seems like it should be doable: call quantize, assert output. WDYT? If we run into failures for certain models, we could add an option to …

Yeah, it is doable.

I think having the coverage is important. Let's pull this in, and see if we can improve the runtime efficiency as a follow-up. Saving is slow, so maybe we can just do something like: …

Will try this in another PR.
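The test extension discussed above might look roughly like the sketch below. It is a sketch only: the helper name `run_quantization_test` and its `cls` / `init_kwargs` / `input_data` parameters are assumptions, not the PR's actual test code.

```python
def run_quantization_test(cls, init_kwargs, input_data):
    # Hypothetical shared backbone test: quantize, run a forward pass,
    # and check that the config round-trips the quantized dtype policies.
    model = cls(**init_kwargs)
    model.quantize("int8")

    # The quantized model should still produce an output.
    output = model(input_data)
    assert output is not None

    # `get_config()` should now carry a `DTypePolicyMap` under "dtype", so a
    # model rebuilt from that config keeps its per-layer quantized policies.
    restored = cls.from_config(model.get_config())
    assert any(
        layer.quantization_mode is not None
        for layer in restored._flatten_layers()
    )
```

As noted in the thread, running full save/load round-trips for every backbone would be slow, so a lightweight check along these lines is one way to keep the test runtime down.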
Could we chain to super here to keep most of the logic, and just handle the `if mode == "int8" and not self.tie_weights` case below? Would be great to keep as much logic on the super class as we can.
I'm afraid not.

The raising of `NotImplementedError` in `keras.layers.Embedding` is intentional and unavoidable. The idea is to prevent undefined behavior when users call `Model.quantize`.

I can introduce an argument like `type_check=True` in `keras.layers.Embedding` to support chaining to `super` in the future. However, for now, we can only implement `quantize` from scratch.

EDITED: keras-team/keras#19949
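For context, the guard being described follows roughly the pattern sketched below. This is a simplified paraphrase of the behavior discussed in this thread, not an excerpt from Keras, so treat the class name, check, and message as assumptions.

```python
# Simplified illustration of why `super().quantize(mode)` cannot be reused:
# the base embedding layer rejects quantization of unknown subclasses, since
# they may own extra variables (e.g. untied reverse embeddings) that the base
# logic would silently ignore.
class BaseEmbedding:
    def quantize(self, mode):
        if type(self) is not BaseEmbedding:
            raise NotImplementedError(
                f"Layer {type(self).__name__} does not have a `quantize()` "
                "implementation."
            )
        # ... base int8/float8 quantization of the embedding matrix ...
```

The `type_check=True` argument mentioned above (keras-team/keras#19949) is intended to relax exactly this kind of check, which should let subclasses call into the base implementation and then handle only their extra variables.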
I see, thanks for the explainer.

Not to solve on this PR, but I wonder if we can make the contract between Keras and downstream here more public and minimal. I see `_int_8_call()`, `_int_8_build()`, `_quantization_mode_error()`, `_tracker`, and `_untrack_variable()` all used here. That's a pretty significant level of private usage, which could easily break.

Separate question: will this work with older versions of Keras 3? Or are there small changes we could make so we don't break older versions?
I agree that these methods are too verbose for downstream projects. I will try to simplify the contract in the future, but currently I don't have a good idea for it.

I haven't checked the compatibility. My rough guess is that users will need `keras>=3.4.0` due to the introduction of `DTypePolicyMap`.
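One low-cost way to keep older Keras 3 releases working (an option sketched here, not something this PR implements) is to feature-detect `DTypePolicyMap` instead of comparing version strings:

```python
import keras

def has_dtype_policy_map():
    # `DTypePolicyMap` is assumed to exist only in newer Keras 3 releases
    # (roughly keras>=3.4.0, per the discussion above), so detect it directly.
    return hasattr(keras.dtype_policies, "DTypePolicyMap")

# `get_config()` could then attach the quantization-aware "dtype" entry only
# when the running Keras actually supports it, e.g.:
#     if has_dtype_policy_map():
#         policy_map = keras.dtype_policies.DTypePolicyMap()
#         ...
```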