129 changes: 119 additions & 10 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -107,8 +107,7 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)

if not self.tie_weights:
if not self.tie_weights and self.quantization_mode != "int8":
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
@@ -143,20 +142,28 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved in the super() call.
# After Keras 3.2, the reverse weight must be saved manually.
if len(store.keys()) < len(self.weights):
# Store the reverse embedding as the last weight.
store[str(len(store.keys()))] = self.reverse_embeddings
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable

def load_own_variables(self, store):
if not self.built:
self.build()
super().load_own_variables(store)
if not self.tie_weights:
# Last weight in the store is the reverse embedding weights.
key = str(len(store.keys()) - 1)
self.reverse_embeddings.assign(store[key])
# Last weights in the stores are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
):
variable.assign(store[str(i)])

def compute_output_spec(self, inputs, reverse=False):
output_shape = list(inputs.shape)
@@ -165,3 +172,105 @@ def compute_output_spec(self, inputs, reverse=False):
else:
output_shape += [self.output_dim]
return keras.KerasTensor(output_shape, dtype=self.dtype)

# Quantization-related (int8) methods

def quantized_call(self, inputs, reverse=False):
# TODO (hongyu): This function could be removed once we add `*args` and
# `**kwargs` for `Embedding.quantized_call`
if self.quantization_mode == "int8":
return self._int8_call(inputs, reverse=reverse)
else:
self._quantization_mode_error(self.quantization_mode)

def _int8_build(
self,
embeddings_initializer="zeros",
embeddings_scale_initializer="ones",
reverse_embeddings_initializer="zeros",
reverse_embeddings_scale_initializer="ones",
):
super()._int8_build(
embeddings_initializer, embeddings_scale_initializer
)
self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
if not self.tie_weights:
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
initializer=reverse_embeddings_initializer,
dtype="int8",
trainable=False,
)
self.reverse_embeddings_scale = self.add_weight(
name="reverse_embeddings_scale",
shape=(self.input_dim,),
initializer=reverse_embeddings_scale_initializer,
trainable=False,
)

def _int8_call(self, inputs, reverse=False):
if reverse:
if self.tie_weights:
kernel = ops.transpose(self._embeddings)
scale = ops.transpose(self.embeddings_scale)
else:
kernel = self.reverse_embeddings
scale = self.reverse_embeddings_scale
inputs, inputs_scale = self.inputs_quantizer(inputs)
outputs = ops.matmul(inputs, kernel)
# De-scale outputs
outputs = ops.cast(outputs, self.compute_dtype)
outputs = ops.divide(outputs, ops.multiply(inputs_scale, scale))
return outputs

return super()._int8_call(inputs)

def quantize(self, mode):
Member:
Could we chain to super here to keep most of the logic, and just handle the `if mode == "int8" and not self.tie_weights` case below? It would be great to keep as much logic on the super class as we can.

@james77777778 (Collaborator, Author), Jul 3, 2024:
I'm afraid not.
Raising NotImplementedError in `keras.layers.Embedding` is intentional and unavoidable. The idea is to prevent undefined behavior when users call `Model.quantize`.

I could introduce an argument like `type_check=True` in `keras.layers.Embedding` to support chaining to super in the future. However, for now, we can only implement `quantize` from scratch.

EDITED: see keras-team/keras#19949
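To make the idea concrete, here is a minimal, hypothetical sketch of how a subclass could chain to super once such a `type_check` escape hatch exists. The subclass name and the `type_check` semantics are assumptions based on the discussion and keras-team/keras#19949, not a released API.

import keras

# Hypothetical sketch: assumes `keras.layers.Embedding.quantize()` accepts a
# `type_check` argument that skips the subclass check which currently raises
# NotImplementedError. With it, a subclass could reuse the base-class int8
# logic for the shared embedding matrix and only handle its extra weights.
class UntiedEmbedding(keras.layers.Embedding):  # hypothetical subclass
    def quantize(self, mode, type_check=False):
        if mode == "int8" and not getattr(self, "tie_weights", True):
            # Quantize the extra reverse-embedding weights here...
            pass
        # Delegate handling of the shared `_embeddings` to the base class.
        super().quantize(mode, type_check=type_check)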

Member:
I see, thanks for the explainer.

Not to solve on this PR, but I wonder if we can make the contract between Keras and downstream here more public and minimal. I see `_int8_call()`, `_int8_build()`, `_quantization_mode_error()`, `_tracker`, and `_untrack_variable()` all used here. That's a pretty significant level of private usage, which could easily break.

Separate question: will this work with older versions of Keras 3? Or are there small changes we could make so we don't break older versions?

@james77777778 (Collaborator, Author), Jul 4, 2024:
I agree that these methods are too verbose for downstream projects. I will try to simplify the contract in the future, but currently I don't have a good idea for it.

As for "will this work with older versions of Keras 3?": I haven't checked the compatibility. My rough guess is that users will need keras>=3.4.0 due to the introduction of DTypePolicyMap.
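As a side note (not part of the diff below), a minimal sketch of how downstream code might guard on the Keras version under that assumption; the 3.4.0 cutoff is the author's rough guess above, not a verified requirement.

import keras

# Minimal sketch, not part of this PR: skip DTypePolicyMap-based features on
# Keras releases assumed to predate it (cutoff taken from the guess above).
def supports_dtype_policy_map():
    major, minor, *_ = keras.__version__.split(".")
    return (int(major), int(minor)) >= (3, 4)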

import gc

if type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
"method implemented."
)
self._check_quantize_args(mode, self.compute_dtype)

self._tracker.unlock()
if mode == "int8":
embeddings, embeddings_scale = keras.quantizers.abs_max_quantize(
self._embeddings, axis=-1
)
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
self._untrack_variable(self._embeddings)
del self._embeddings
if not self.tie_weights:
reverse_embeddings, reverse_embeddings_scale = (
keras.quantizers.abs_max_quantize(
self.reverse_embeddings, axis=0
)
)
reverse_embeddings_scale = ops.squeeze(
reverse_embeddings_scale, axis=0
)
self._untrack_variable(self.reverse_embeddings)
del self.reverse_embeddings
else:
reverse_embeddings = None
reverse_embeddings_scale = None
self._int8_build(
lambda shape, dtype: embeddings,
lambda shape, dtype: embeddings_scale,
lambda shape, dtype: reverse_embeddings,
lambda shape, dtype: reverse_embeddings_scale,
)
else:
raise self._quantization_mode_error(mode)
self._tracker.lock()

if self.dtype_policy.quantization_mode is None:
policy = keras.dtype_policies.get(
f"{mode}_from_{self.dtype_policy.name}"
)
self.dtype_policy = policy
gc.collect()
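As an aside, the de-scaling step in `_int8_call` above follows the usual abs-max scheme: inputs and kernel are quantized with per-row and per-column scales, the matmul runs in integer space, and the result is divided by the product of the two scales. A small NumPy illustration of that identity (illustrative only, not the layer's code path):

import numpy as np

# Illustrative only: abs-max quantization with scale = 127 / max|value|, an
# integer matmul, then de-scaling by dividing with (input_scale * kernel_scale),
# mirroring the divide step in `_int8_call` above.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4)).astype("float32")  # float activations
w = rng.normal(size=(4, 3)).astype("float32")  # float kernel

x_scale = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per row
w_scale = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)   # per output column
x_q = np.round(x * x_scale).astype("int8")
w_q = np.round(w * w_scale).astype("int8")

y_int = x_q.astype("int32") @ w_q.astype("int32")
y = y_int / (x_scale * w_scale)  # de-scale back to float

print(np.max(np.abs(y - x @ w)))  # small quantization error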
68 changes: 68 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
@@ -98,3 +98,71 @@ def test_reverse_dtype(self):
output_data = embedding(input_data, reverse=True)
self.assertEqual(output_data.shape, (4, 10, 100))
self.assertDTypeEqual(output_data, "float16")

@parameterized.named_parameters(
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
layer = ReversibleEmbedding(**layer_config)
layer.build()
x = random.randint(shape=(64, 100), minval=0, maxval=9)
x_reverse = random.uniform(shape=(64, 32))
y_float = layer(x)
y_reverse_float = layer(x_reverse, reverse=True)
layer.quantize("int8")

# Verify weights dtype
if not tie_weights:
self.assertEqual(
keras.backend.standardize_dtype(layer.reverse_embeddings.dtype),
"int8",
)
self.assertEqual(
keras.backend.standardize_dtype(
layer.reverse_embeddings_scale.dtype
),
layer.variable_dtype,
)

# Try eager call and verify output correctness
y_quantized = layer(x)
y_reverse_quantized = layer(x_reverse, reverse=True)
mse = ops.mean(ops.square(y_float - y_quantized))
mse_reverse = ops.mean(
ops.square(y_reverse_float - y_reverse_quantized)
)
self.assertLess(mse, 1e-3) # A weak correctness test
self.assertLess(mse_reverse, 1e-3) # A weak correctness test

# Try saving and reloading the model
model = keras.models.Sequential([layer])
temp_filepath = os.path.join(
self.get_temp_dir(), "quantized_model.keras"
)
model.save(temp_filepath)
new_model = keras.models.load_model(temp_filepath)
self.assertAllClose(model.predict(x), new_model.predict(x))

@parameterized.named_parameters(
("tie_weights", True),
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
"input_dim": 100,
"output_dim": 32,
"tie_weights": tie_weights,
"embeddings_initializer": "HeNormal",
"dtype": "int8_from_float32",
},
input_data=random.randint(minval=0, maxval=100, shape=(4, 10)),
expected_output_shape=(4, 10, 32),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=2 if tie_weights else 4,
expected_num_non_trainable_variables=2 if tie_weights else 4,
)
17 changes: 11 additions & 6 deletions keras_nlp/src/models/backbone.py
@@ -75,11 +75,7 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if dtype is not None:
if isinstance(dtype, keras.DTypePolicy):
self.dtype_policy = dtype
else:
self.dtype_policy = keras.DTypePolicy(dtype)
self.dtype_policy = keras.dtype_policies.get(dtype)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
@@ -107,11 +103,20 @@ def token_embedding(self, value):
def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
return {
config = {
"name": self.name,
"trainable": self.trainable,
}

# Add quantization support by utilizing `DTypePolicyMap`
Member:
This is great! This should buy us support for all models, right? If possible, we should consider extending our common backbone tests for this:

https://github.com/keras-team/keras-nlp/blob/e4f09b24c699857edae27c8054aab44078e9cbd5/keras_nlp/src/tests/test_case.py#L359-L367

https://github.com/keras-team/keras-nlp/blob/e4f09b24c699857edae27c8054aab44078e9cbd5/keras_nlp/src/models/gemma/gemma_backbone_test.py#L39-L45

Doing so would test quantization for the whole library. Seems like it should be doable: call quantize, assert output. WDYT?

If we run into failures for certain models, we could add an option to `run_backbone_test`, called `run_quantization_check=True`, and set the option to false for the models that fail, with a TODO to investigate.

@james77777778 (Collaborator, Author):
Yeah, it is doable.
I have added `run_quantization_check` to `run_backbone_test`. Only Bloom and OPT failed the test.
However, there is a significant speed regression after adding this test. The CI time increased from ~19 minutes to ~27 minutes. Is this acceptable?

Member:
I think having the coverage is important. Let's pull this in, and see if we can improve the runtime efficiency as a follow-up.

Saving is slow, so maybe we can do something like:

  • Basic quantization tests do not hit saving. Just test get_config(), from_config(), and maybe assigning weights over.
  • Separate quantization testing in our saving test harness, marked with `large`, and only run on larger/faster hardware.

@james77777778 (Collaborator, Author):
Will try this in another PR.
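A rough sketch of what such a check might look like inside the shared test harness; the helper name and signature are assumptions for illustration (the real logic lives in keras_nlp/src/tests/test_case.py and may differ).

def run_quantization_check(backbone_cls, init_kwargs, input_data):
    # Sketch only: quantize, run a forward pass, and round-trip the config so
    # the DTypePolicyMap path added to Backbone.get_config() gets exercised.
    backbone = backbone_cls(**init_kwargs)
    backbone.quantize("int8")
    output = backbone(input_data)
    config = backbone.get_config()
    revived = backbone_cls.from_config(config)
    revived(input_data)
    return output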

policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
def from_config(cls, config):
# The default `from_config()` for functional models will return a
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bloom/bloom_backbone_test.py
@@ -39,6 +39,9 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
# TODO: Set to `True`. Error msg: Layer LayerNormalization does not
# have a `quantized_call()` method implemented.
run_quantization_check=False,
)

@pytest.mark.large
1 change: 0 additions & 1 deletion keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -171,7 +171,6 @@ def test_distribution_with_lora(self):
self.assertEqual(tuple(w.value.sharding.spec), (None, None))


@pytest.mark.keras_3_only
class Gemma2BackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
4 changes: 4 additions & 0 deletions keras_nlp/src/models/opt/opt_backbone_test.py
@@ -40,6 +40,10 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
# TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1
# variables, but received 0 variables during loading. Expected:
# ['embeddings']
run_quantization_check=False,
)

@pytest.mark.large