15 changes: 9 additions & 6 deletions keras_nlp/src/models/backbone.py
@@ -109,12 +109,15 @@ def get_config(self):
}

# Add quantization support by utilizing `DTypePolicyMap`
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
if isinstance(self.dtype_policy, keras.dtype_policies.DTypePolicyMap):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
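Note: the updated `get_config` reuses the backbone's own `DTypePolicyMap` when one is already attached, and otherwise builds one from any quantized sublayers. A minimal sketch of the round-trip this enables (assuming a Keras version that ships `DTypePolicyMap`; the layer paths below are illustrative, not taken from this PR):

```python
import keras

# Map individual layer paths to quantized dtype policies; everything else
# falls back to the "float32" default.
policy_map = keras.dtype_policies.DTypePolicyMap("float32")
policy_map["token_embedding"] = keras.dtype_policies.get("int8_from_float32")
policy_map["transformer_layer_0/ffn_output_dense"] = keras.dtype_policies.get(
    "int8_from_float32"
)

# The map round-trips through its own config, which is what lets a quantized
# backbone carry per-layer policies inside the "dtype" entry of its config.
config = policy_map.get_config()
restored = keras.dtype_policies.DTypePolicyMap.from_config(config)
assert len(restored) == 2
```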
2 changes: 1 addition & 1 deletion keras_nlp/src/models/bloom/bloom_backbone.py
@@ -107,7 +107,7 @@ def __init__(
self.embeddings_layer_norm = keras.layers.LayerNormalization(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="token_embedding_layernorm",
name="embedding_layernorm",
)
self.transformer_layers = []
for i in range(num_layers):
3 changes: 0 additions & 3 deletions keras_nlp/src/models/bloom/bloom_backbone_test.py
@@ -39,9 +39,6 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
# TODO: Set to `True`. Error msg: Layer LayerNormalization does not
# have a `quantized_call()` method implemented.
run_quantization_check=False,
)

@pytest.mark.large
22 changes: 13 additions & 9 deletions keras_nlp/src/models/opt/opt_backbone.py
@@ -158,12 +158,16 @@ def __init__(
self.max_sequence_length = max_sequence_length

def get_config(self):
return {
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
}
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_sequence_length": self.max_sequence_length,
}
)
return config
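The previous `OPTBackbone.get_config` returned a fresh dict and therefore dropped everything contributed by the parent class, including the `dtype` entry (possibly a `DTypePolicyMap`) that `Backbone.get_config` now emits. The same pattern in miniature, using a hypothetical `ToyBlock` layer rather than the actual backbone:

```python
import keras


class ToyBlock(keras.layers.Layer):
    """Hypothetical layer, used only to illustrate the config pattern."""

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def get_config(self):
        # Build on the parent config instead of returning a fresh dict, so
        # fields owned by the base class (`name`, `dtype`, ...) survive
        # serialization alongside the subclass's own arguments.
        config = super().get_config()
        config.update({"units": self.units})
        return config


block = ToyBlock(units=4, name="toy_block")
restored = ToyBlock.from_config(block.get_config())
assert restored.units == 4 and restored.name == "toy_block"
```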
4 changes: 0 additions & 4 deletions keras_nlp/src/models/opt/opt_backbone_test.py
@@ -40,10 +40,6 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
# TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1
# variables, but received 0 variables during loading. Expected:
# ['embeddings']
run_quantization_check=False,
)

@pytest.mark.large
1 change: 1 addition & 0 deletions keras_nlp/src/models/xlnet/xlnet_backbone_test.py
@@ -40,6 +40,7 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
run_quantization_check=False, # TODO(hongyu): set to `True`
)

@pytest.mark.large
58 changes: 37 additions & 21 deletions keras_nlp/src/tests/test_case.py
@@ -29,6 +29,7 @@
from keras import ops
from keras import tree

from keras_nlp.src import layers as keras_nlp_layers
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
from keras_nlp.src.utils.tensor_utils import is_float_dtype

@@ -336,29 +337,44 @@ def run_precision_test(self, cls, init_kwargs, input_data):
self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)

def run_quantization_test(self, cls, init_kwargs, input_data):
policy = keras.DTypePolicy("float32")
def run_quantization_test(self, instance, cls, init_kwargs, input_data):
def _get_supported_layers(mode):
supported_layers = [keras.layers.Dense, keras.layers.EinsumDense]
if mode == "int8":
supported_layers.append(keras.layers.Embedding)
supported_layers.append(keras_nlp_layers.ReversibleEmbedding)
return supported_layers

for mode in ["int8", "float8"]:
layer = cls(**{**init_kwargs, "dtype": policy})
layer.quantize(mode)
# Try eager call
if isinstance(layer, keras.Model):
_ = layer(input_data)
# Manually configure DTypePolicyMap to avoid intensive computation
# in `Model.quantize`.
policy_map = keras.dtype_policies.DTypePolicyMap("float32")
for layer in instance._flatten_layers():
if type(layer) in _get_supported_layers(mode):
policy_map[layer.path] = keras.dtype_policies.get(
f"{mode}_from_float32"
)
# Instantiate the layer.
model = cls(**{**init_kwargs, "dtype": policy_map})
# Call layer eagerly.
if isinstance(model, keras.Model):
_ = model(input_data)
elif isinstance(input_data, dict):
_ = layer(**input_data)
_ = model(**input_data)
else:
_ = layer(input_data)
# Verify sublayer's dtype policy
for sublayer in layer._flatten_layers():
if type(sublayer) is keras.layers.Dense:
self.assertEqual(
f"{mode}_from_float32", sublayer.dtype_policy.name
)
# Try saving and reloading the model
temp_filepath = os.path.join(self.get_temp_dir(), "layer.keras")
layer.save(temp_filepath)
reloaded_layer = keras.models.load_model(temp_filepath)
self.assertAllClose(layer(input_data), reloaded_layer(input_data))
_ = model(input_data)
# Verify sublayer's dtype policy.
for sublayer in model._flatten_layers():
if type(sublayer) in _get_supported_layers(mode):
self.assertEqual(mode, sublayer.quantization_mode)
# `get_config` roundtrip.
cfg = model.get_config()
revived_model = cls.from_config(cfg)
revived_cfg = revived_model.get_config()
self.assertEqual(cfg, revived_cfg)
# Check weights loading.
weights = model.get_weights()
revived_model.set_weights(weights)

def run_model_saving_test(
self,
@@ -436,7 +452,7 @@ def run_backbone_test(

# Check quantization.
if run_quantization_check:
self.run_quantization_test(cls, init_kwargs, input_data)
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_task_test(
self,
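For reference, the policy-map construction performed by the revised test helper can be written as a standalone sketch. The helper name and the commented usage are illustrative; `_flatten_layers()` and `layer.path` are the same (private) Keras APIs the test relies on, so treat this as an approximation rather than a supported recipe:

```python
import keras


def build_quantization_policy_map(model, mode="int8"):
    # Quantize only layer types with a quantized implementation; Embedding
    # (and ReversibleEmbedding in keras_nlp) is only supported for "int8".
    supported = [keras.layers.Dense, keras.layers.EinsumDense]
    if mode == "int8":
        supported.append(keras.layers.Embedding)
    policy_map = keras.dtype_policies.DTypePolicyMap("float32")
    for layer in model._flatten_layers():
        if type(layer) in supported:
            policy_map[layer.path] = keras.dtype_policies.get(
                f"{mode}_from_float32"
            )
    return policy_map


# Hypothetical usage against a backbone instance:
# policy_map = build_quantization_policy_map(backbone, mode="int8")
# quantized = BackboneCls(**{**init_kwargs, "dtype": policy_map})
```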