Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
version: [keras-stable]
include:
- backend: jax
version: keras-3.1
version: keras-3.5
- backend: jax
version: keras-nightly
runs-on: ubuntu-latest
Expand Down Expand Up @@ -48,11 +48,11 @@ jobs:
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
- name: Pin Keras 3.1
if: ${{ matrix.version == 'keras-3.1'}}
- name: Pin Keras 3.5
if: ${{ matrix.version == 'keras-3.5'}}
run: |
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
pip install keras==3.5.0 --progress-bar off
- name: Pin Keras Nightly
if: ${{ matrix.version == 'keras-nightly'}}
run: |
Expand Down
19 changes: 3 additions & 16 deletions keras_hub/src/layers/modeling/reversible_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import keras
from keras import ops
from packaging.version import parse

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import assert_quantization_support


@keras_hub_export("keras_hub.layers.ReversibleEmbedding")
Expand Down Expand Up @@ -145,10 +143,6 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved in the super() call.
# After Keras 3.2, the reverse weight must be saved manually.
if parse(keras.version()) < parse("3.2.0"):
return
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
Expand Down Expand Up @@ -239,9 +233,7 @@ def _int8_call(self, inputs, reverse=False):

def quantize(self, mode, type_check=True):
import gc
import inspect

assert_quantization_support()
if type_check and type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
Expand All @@ -250,14 +242,9 @@ def quantize(self, mode, type_check=True):
self._check_quantize_args(mode, self.compute_dtype)

def abs_max_quantize(inputs, axis):
sig = inspect.signature(keras.quantizers.abs_max_quantize)
if "to_numpy" in sig.parameters:
return keras.quantizers.abs_max_quantize(
inputs, axis=axis, to_numpy=True
)
else:
# `keras<=3.4.1` doesn't support `to_numpy`
return keras.quantizers.abs_max_quantize(inputs, axis=axis)
return keras.quantizers.abs_max_quantize(
inputs, axis=axis, to_numpy=True
)

self._tracker.unlock()
if mode == "int8":
Expand Down
7 changes: 0 additions & 7 deletions keras_hub/src/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ReversibleEmbedding,
)
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.keras_utils import has_quantization_support


class ReversibleEmbeddingTest(TestCase):
Expand Down Expand Up @@ -97,9 +96,6 @@ def test_reverse_dtype(self):
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
Expand Down Expand Up @@ -148,9 +144,6 @@ def test_quantize_int8(self, tie_weights):
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
Expand Down
5 changes: 0 additions & 5 deletions keras_hub/src/models/backbone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.utils.keras_utils import assert_quantization_support
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import get_preset_saver
Expand Down Expand Up @@ -83,10 +82,6 @@ def token_embedding(self):
def token_embedding(self, value):
self._token_embedding = value

def quantize(self, mode, **kwargs):
assert_quantization_support()
return super().quantize(mode, **kwargs)

def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
Expand Down
3 changes: 1 addition & 2 deletions keras_hub/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
from keras_hub.src.tokenizers.tokenizer import Tokenizer
from keras_hub.src.utils.keras_utils import has_quantization_support
from keras_hub.src.utils.tensor_utils import is_float_dtype


Expand Down Expand Up @@ -487,7 +486,7 @@ def run_backbone_test(
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check and has_quantization_support():
if run_quantization_check:
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_vision_backbone_test(
Expand Down
13 changes: 0 additions & 13 deletions keras_hub/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import keras
from absl import logging
from packaging.version import parse

try:
import tensorflow as tf
Expand Down Expand Up @@ -41,18 +40,6 @@ def gelu_approximate(x):
return keras.activations.gelu(x, approximate=True)


def has_quantization_support():
return False if parse(keras.version()) < parse("3.4.0") else True


def assert_quantization_support():
if not has_quantization_support():
raise ValueError(
"Quantization API requires Keras >= 3.4.0 to function "
f"correctly. Received: '{keras.version()}'"
)


def standardize_data_format(data_format):
if data_format is None:
return keras.config.image_data_format()
Expand Down
7 changes: 2 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ def get_version(rel_path):
author_email="keras-hub@google.com",
license="Apache License 2.0",
install_requires=[
"keras>=3.5",
"absl-py",
"numpy",
"packaging",
"regex",
"rich",
"kagglehub",
# Don't require tensorflow-text on MacOS, there are no binaries for ARM.
# Also, we rely on tensorflow *transitively* through tensorflow-text.
# This avoid a slowdown during `pip install keras-hub` where pip would
# download many version of both libraries to find compatible versions.
"tensorflow-text; platform_system != 'Darwin'",
"tensorflow-text",
],
extras_require={
"extras": [
Expand Down
Loading