Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ jobs:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
version: [latest]
include:
- backend: torch
version: 3.1
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
Expand All @@ -42,6 +46,11 @@ jobs:
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
- name: Pin Keras version
if: ${{ matrix.version == '3.1'}}
run: |
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
- name: Test with pytest
run: |
pytest keras_nlp/
Expand Down
20 changes: 15 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import keras
from keras import ops
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support


@keras_nlp_export("keras_nlp.layers.ReversibleEmbedding")
Expand Down Expand Up @@ -107,7 +109,10 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)
if not self.tie_weights and self.quantization_mode != "int8":
if (
not self.tie_weights
and getattr(self, "quantization_mode", None) != "int8"
):
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
Expand Down Expand Up @@ -142,11 +147,15 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved by the super() call above.
# From Keras 3.2 onward, the reverse weight must be saved manually below.
if parse(keras.version()) < parse("3.2.0"):
return
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable
Expand All @@ -158,7 +167,7 @@ def load_own_variables(self, store):
if not self.tie_weights:
# Last weights in the stores are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
Expand Down Expand Up @@ -226,10 +235,11 @@ def _int8_call(self, inputs, reverse=False):

return super()._int8_call(inputs)

def quantize(self, mode):
def quantize(self, mode, type_check=True):
import gc

if type(self) is not ReversibleEmbedding:
assert_quantization_support()
if type_check and type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
"method implemented."
Expand Down
7 changes: 7 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ReversibleEmbedding,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.keras_utils import has_quantization_support


class ReversibleEmbeddingTest(TestCase):
Expand Down Expand Up @@ -103,6 +104,9 @@ def test_reverse_dtype(self):
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
Expand Down Expand Up @@ -151,6 +155,9 @@ def test_quantize_int8(self, tie_weights):
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
Expand Down
35 changes: 25 additions & 10 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
Expand Down Expand Up @@ -75,7 +76,14 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
Expand All @@ -100,6 +108,10 @@ def token_embedding(self):
def token_embedding(self, value):
self._token_embedding = value

def quantize(self, mode, **kwargs):
    """Quantize the backbone after checking the Keras version.

    Args:
        mode: Forwarded unchanged to the parent class `quantize()`.
        **kwargs: Extra keyword arguments forwarded to the parent class.

    Returns:
        Whatever the parent class `quantize()` returns.
    """
    # Check for quantization support first so an unsupported Keras
    # install fails before any work is delegated to the parent class.
    assert_quantization_support()
    result = super().quantize(mode, **kwargs)
    return result

def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
Expand All @@ -109,15 +121,18 @@ def get_config(self):
}

# Add quantization support by utilizing `DTypePolicyMap`
if isinstance(self.dtype_policy, keras.dtype_policies.DTypePolicyMap):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
if hasattr(keras.dtype_policies, "DTypePolicyMap"):
if isinstance(
self.dtype_policy, keras.dtype_policies.DTypePolicyMap
):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion keras_nlp/src/models/pali_gemma/pali_gemma_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,14 @@ def __init__(
classifier_activation
)
self.image_sequence_length = int((image_size / patch_size) ** 2)
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def get_config(self):
config = super().get_config()
Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from keras_nlp.src import layers as keras_nlp_layers
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
from keras_nlp.src.utils.keras_utils import has_quantization_support
from keras_nlp.src.utils.tensor_utils import is_float_dtype


Expand Down Expand Up @@ -445,7 +446,7 @@ def run_backbone_test(
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check:
if run_quantization_check and has_quantization_support():
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_task_test(
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import keras
from absl import logging
from packaging.version import parse

from keras_nlp.src.utils.tensor_utils import is_tensor_type

Expand Down Expand Up @@ -102,3 +103,15 @@ def print_msg(message, line_break=True):
@keras.saving.register_keras_serializable(package="keras_nlp")
def gelu_approximate(x):
    """Apply `keras.activations.gelu` to `x` with `approximate=True`.

    Registered as a Keras serializable under the `keras_nlp` package.
    """
    activated = keras.activations.gelu(x, approximate=True)
    return activated


def has_quantization_support():
    """Return whether the installed Keras is at least version 3.4.0.

    Returns:
        `True` when `keras.version()` parses to a version that is not lower
        than 3.4.0, `False` otherwise.
    """
    installed = parse(keras.version())
    minimum = parse("3.4.0")
    return not installed < minimum


def assert_quantization_support():
    """Raise if the installed Keras cannot be used for quantization.

    Raises:
        ValueError: If `has_quantization_support()` returns `False`. The
            message includes the value reported by `keras.version()`.
    """
    if has_quantization_support():
        return
    installed_version = keras.version()
    message = (
        "Quantization API requires Keras >= 3.4.0 to function "
        f"correctly. Received: '{installed_version}'"
    )
    raise ValueError(message)