diff --git a/keras_nlp/api/__init__.py b/keras_nlp/api/__init__.py index 0225dd69a0..d0dc4576c6 100644 --- a/keras_nlp/api/__init__.py +++ b/keras_nlp/api/__init__.py @@ -22,7 +22,6 @@ from keras_nlp.api import models from keras_nlp.api import samplers from keras_nlp.api import tokenizers -from keras_nlp.api import utils from keras_nlp.src.utils.preset_utils import upload_preset from keras_nlp.src.version_utils import __version__ from keras_nlp.src.version_utils import version diff --git a/keras_nlp/api/layers/__init__.py b/keras_nlp/api/layers/__init__.py index 73ad66b30b..8b92cc11b0 100644 --- a/keras_nlp/api/layers/__init__.py +++ b/keras_nlp/api/layers/__init__.py @@ -36,6 +36,8 @@ ) from keras_nlp.src.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.src.layers.modeling.transformer_encoder import TransformerEncoder +from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter +from keras_nlp.src.layers.preprocessing.image_converter import ImageConverter from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import ( MaskedLMMaskGenerator, ) @@ -44,4 +46,13 @@ ) from keras_nlp.src.layers.preprocessing.random_deletion import RandomDeletion from keras_nlp.src.layers.preprocessing.random_swap import RandomSwap +from keras_nlp.src.layers.preprocessing.resizing_image_converter import ( + ResizingImageConverter, +) from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) +from keras_nlp.src.models.whisper.whisper_audio_converter import ( + WhisperAudioConverter, +) diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py index e78271531c..64368e4c45 100644 --- a/keras_nlp/api/models/__init__.py +++ b/keras_nlp/api/models/__init__.py @@ -228,13 +228,7 @@ from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( - WhisperAudioFeatureExtractor, -) from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone -from keras_nlp.src.models.whisper.whisper_preprocessor import ( - WhisperPreprocessor, -) from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import ( XLMRobertaBackbone, diff --git a/keras_nlp/src/layers/preprocessing/audio_converter.py b/keras_nlp/src/layers/preprocessing/audio_converter.py new file mode 100644 index 0000000000..c552b86e04 --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/audio_converter.py @@ -0,0 +1,121 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.preprocessing_layer import ( + PreprocessingLayer, +) +from keras_nlp.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import find_subclass +from keras_nlp.src.utils.preset_utils import get_preset_loader +from keras_nlp.src.utils.preset_utils import list_presets +from keras_nlp.src.utils.preset_utils import list_subclasses +from keras_nlp.src.utils.preset_utils import save_serialized_object +from keras_nlp.src.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.layers.AudioConverter") +class AudioConverter(PreprocessingLayer): + """Convert raw audio for models that support audio input. + + This class converts from raw audio tensors of any length, to preprocessed + audio for pretrained model inputs. It is meant to be a convenient way to + write custom preprocessing code that is not model specific. This layer + should be instantiated via the `from_preset()` constructor, which will + create the correct subclass of this layer for the model preset. + + The layer will take as input a raw audio tensor with shape `(batch_size, + num_samples)`, and output a preprocessed audio input for modeling. The exact + structure of the preprocessed input will vary per model. Preprocessing + will often include computing a spectogram of the raw audio signal. + + Examples: + ```python + # Load an audio converter from a preset. + converter = keras_nlp.layers.AudioConverter.from_preset("whisper_base_en") + # Convert some raw audio input. + converter(np.ones(2, 1_000)) + ``` + """ + + backbone_cls = None + + @classproperty + def presets(cls): + """List built-in presets for a `Task` subclass.""" + presets = list_presets(cls) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets + + @classmethod + def from_preset( + cls, + preset, + **kwargs, + ): + """Instantiate a `keras_nlp.layers.AudioConverter` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as + one of: + + 1. a built-in preset identifier like `'whisper_base_en'` + 2. a Kaggle Models handle like + `'kaggle://user/whisper/keras/whisper_base_en'` + 3. a Hugging Face handle like `'hf://user/whisper_base_en'` + 4. a path to a local preset directory like `'./whisper_base_en'` + + You can run `cls.presets.keys()` to list all built-in presets available + on the class. + + This constructor can be called in one of two ways. Either from the base + class like `keras_nlp.models.AudioConverter.from_preset()`, or from a + model class like `keras_nlp.models.WhisperAudioConverter.from_preset()`. + If calling from the base class, the subclass of the returning object + will be inferred from the config in the preset directory. + + Args: + preset: string. A built-in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. + + Examples: + ```python + # Load an audio converter from a preset. + converter = keras_nlp.layers.AudioConverter.from_preset( + "whisper_base_en" + ) + # Convert some raw mono channel audio input. + converter(np.ones(2, 1_000)) + ``` + """ + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_audio_converter(cls, **kwargs) + + def save_to_preset(self, preset_dir): + """Save audio converter to a preset directory. + + Args: + preset_dir: The path to the local model preset directory. + """ + save_serialized_object( + self, + preset_dir, + config_file=AUDIO_CONVERTER_CONFIG_FILE, + ) diff --git a/keras_nlp/src/layers/preprocessing/audio_converter_test.py b/keras_nlp/src/layers/preprocessing/audio_converter_test.py new file mode 100644 index 0000000000..7ac6ab22fd --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/audio_converter_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib + +import numpy as np +import pytest + +from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.whisper.whisper_audio_converter import ( + WhisperAudioConverter, +) +from keras_nlp.src.tests.test_case import TestCase + + +class AudioConverterTest(TestCase): + def test_preset_accessors(self): + pali_gemma_presets = set(WhisperAudioConverter.presets.keys()) + all_presets = set(AudioConverter.presets.keys()) + self.assertContainsSubset(pali_gemma_presets, all_presets) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + AudioConverter.from_preset("whisper_tiny_en"), + WhisperAudioConverter, + ) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + AudioConverter.from_preset("bert_tiny_en_uncased") + with self.assertRaises(ValueError): + # No loading on a non-keras model. + AudioConverter.from_preset("hf://spacy/en_core_web_sm") + + @pytest.mark.large + def test_save_to_preset(self): + save_dir = self.get_temp_dir() + converter = AudioConverter.from_preset( + "whisper_tiny_en", + num_mels=40, + ) + converter.save_to_preset(save_dir) + # Save a backbone so the preset is valid. + backbone = Backbone.from_preset("whisper_tiny_en", load_weights=False) + backbone.save_to_preset(save_dir) + + # Check existence of files. + path = pathlib.Path(save_dir) + self.assertTrue(os.path.exists(path / "audio_converter.json")) + + # Check loading. + restored = AudioConverter.from_preset(save_dir) + test_audio = np.random.rand(1_000) + self.assertAllClose(restored(test_audio), converter(test_audio)) diff --git a/keras_nlp/src/layers/preprocessing/image_converter.py b/keras_nlp/src/layers/preprocessing/image_converter.py new file mode 100644 index 0000000000..28ace2aba3 --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/image_converter.py @@ -0,0 +1,130 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.preprocessing_layer import ( + PreprocessingLayer, +) +from keras_nlp.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE +from keras_nlp.src.utils.preset_utils import find_subclass +from keras_nlp.src.utils.preset_utils import get_preset_loader +from keras_nlp.src.utils.preset_utils import list_presets +from keras_nlp.src.utils.preset_utils import list_subclasses +from keras_nlp.src.utils.preset_utils import save_serialized_object +from keras_nlp.src.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.layers.ImageConverter") +class ImageConverter(PreprocessingLayer): + """Convert raw image for models that support image input. + + This class converts from raw images of any size, to preprocessed + images for pretrained model inputs. It is meant to be a convenient way to + write custom preprocessing code that is not model specific. This layer + should be instantiated via the `from_preset()` constructor, which will + create the correct subclass of this layer for the model preset. + + The layer will take as input a raw image tensor in the channels last or + channels first format, and output a preprocessed image input for modeling. + The exact structure of the output will vary per model, though in most cases + this layer will simply resize the image to the size needed by the model + input. + + Examples: + ```python + # Resize images for `"pali_gemma_3b_224"`. + converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_224") + converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) + # Resize images for `"pali_gemma_3b_448"`. + converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_448") + converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3) + ``` + """ + + backbone_cls = None + + @classproperty + def presets(cls): + """List built-in presets for a `Task` subclass.""" + presets = list_presets(cls) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets + + @classmethod + def from_preset( + cls, + preset, + **kwargs, + ): + """Instantiate a `keras_nlp.layers.ImageConverter` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as + one of: + + 1. a built-in preset identifier like `'pali_gemma_3b_224'` + 2. a Kaggle Models handle like + `'kaggle://user/paligemma/keras/pali_gemma_3b_224'` + 3. a Hugging Face handle like `'hf://user/pali_gemma_3b_224'` + 4. a path to a local preset directory like `'./pali_gemma_3b_224'` + + You can run `cls.presets.keys()` to list all built-in presets available + on the class. + + This constructor can be called in one of two ways. Either from the base + class like `keras_nlp.models.ImageConverter.from_preset()`, or from a + model class like + `keras_nlp.models.PaliGemmaImageConverter.from_preset()`. If calling + from the base class, the subclass of the returning object will be + inferred from the config in the preset directory. + + Args: + preset: string. A built-in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. + + Examples: + ```python + # Resize images for `"pali_gemma_3b_224"`. + converter = keras_nlp.layers.ImageConverter.from_preset( + "pali_gemma_3b_224" + ) + converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) + # Override arguments on the base class. + converter = keras_nlp.layers.ImageConverter.from_preset( + "pali_gemma_3b_448", + crop_to_aspect_ratio=False, + ) + converter(np.ones(2, 512, 512, 3)) # (2, 448, 448, 3) + ``` + """ + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_image_converter(cls, **kwargs) + + def save_to_preset(self, preset_dir): + """Save image converter to a preset directory. + + Args: + preset_dir: The path to the local model preset directory. + """ + save_serialized_object( + self, + preset_dir, + config_file=IMAGE_CONVERTER_CONFIG_FILE, + ) diff --git a/keras_nlp/src/layers/preprocessing/image_converter_test.py b/keras_nlp/src/layers/preprocessing/image_converter_test.py new file mode 100644 index 0000000000..c3ecefead0 --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/image_converter_test.py @@ -0,0 +1,84 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib + +import numpy as np +import pytest + +from keras_nlp.src.layers.preprocessing.image_converter import ImageConverter +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) +from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) +from keras_nlp.src.tests.test_case import TestCase + + +class ImageConverterTest(TestCase): + def test_preset_accessors(self): + pali_gemma_presets = set(PaliGemmaImageConverter.presets.keys()) + all_presets = set(ImageConverter.presets.keys()) + self.assertContainsSubset(pali_gemma_presets, all_presets) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + ImageConverter.from_preset("pali_gemma_3b_mix_224"), + PaliGemmaImageConverter, + ) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + ImageConverter.from_preset("bert_tiny_en_uncased") + with self.assertRaises(ValueError): + # No loading on a non-keras model. + ImageConverter.from_preset("hf://spacy/en_core_web_sm") + + @pytest.mark.large + def test_save_to_preset(self): + save_dir = self.get_temp_dir() + converter = ImageConverter.from_preset( + "pali_gemma_3b_mix_224", + interpolation="nearest", + ) + converter.save_to_preset(save_dir) + # Save a tiny backbone so the preset is valid. + backbone = PaliGemmaBackbone( + vocabulary_size=100, + image_size=224, + num_layers=1, + num_query_heads=1, + num_key_value_heads=1, + hidden_dim=8, + intermediate_dim=16, + head_dim=8, + vit_patch_size=14, + vit_num_heads=1, + vit_hidden_dim=8, + vit_num_layers=1, + ) + backbone.save_to_preset(save_dir) + + # Check existence of files. + path = pathlib.Path(save_dir) + self.assertTrue(os.path.exists(path / "image_converter.json")) + + # Check loading. + restored = ImageConverter.from_preset(save_dir) + test_image = np.random.rand(100, 100, 3) * 255 + self.assertAllClose(restored(test_image), converter(test_image)) diff --git a/keras_nlp/src/layers/preprocessing/masked_lm_mask_generator.py b/keras_nlp/src/layers/preprocessing/masked_lm_mask_generator.py index 2c14394437..fe78e0c172 100644 --- a/keras_nlp/src/layers/preprocessing/masked_lm_mask_generator.py +++ b/keras_nlp/src/layers/preprocessing/masked_lm_mask_generator.py @@ -18,7 +18,7 @@ PreprocessingLayer, ) from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -166,7 +166,7 @@ def __init__( random_token_rate=self.random_token_rate, ) - @tf_preprocessing_function + @preprocessing_function def call(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/layers/preprocessing/multi_segment_packer.py b/keras_nlp/src/layers/preprocessing/multi_segment_packer.py index a972bf551c..53625783cc 100644 --- a/keras_nlp/src/layers/preprocessing/multi_segment_packer.py +++ b/keras_nlp/src/layers/preprocessing/multi_segment_packer.py @@ -17,7 +17,7 @@ PreprocessingLayer, ) from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -282,7 +282,7 @@ def _combine_inputs( segment_ids = tf.concat(segment_ids_to_combine, 1) return token_ids, segment_ids - @tf_preprocessing_function + @preprocessing_function def call( self, inputs, diff --git a/keras_nlp/src/layers/preprocessing/multi_segment_packer_test.py b/keras_nlp/src/layers/preprocessing/multi_segment_packer_test.py index 463a43b412..2c34dcd0bd 100644 --- a/keras_nlp/src/layers/preprocessing/multi_segment_packer_test.py +++ b/keras_nlp/src/layers/preprocessing/multi_segment_packer_test.py @@ -31,7 +31,7 @@ def test_trim_single_input_ints(self): self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0]) def test_trim_single_input_strings(self): - input_data = np.array(["a", "b", "c", "d"]) + input_data = ["a", "b", "c", "d"] packer = MultiSegmentPacker( sequence_length=5, start_value="[CLS]", end_value="[SEP]" ) diff --git a/keras_nlp/src/layers/preprocessing/random_deletion.py b/keras_nlp/src/layers/preprocessing/random_deletion.py index bd1b6764fa..6df8b6ef28 100644 --- a/keras_nlp/src/layers/preprocessing/random_deletion.py +++ b/keras_nlp/src/layers/preprocessing/random_deletion.py @@ -21,7 +21,7 @@ from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype from keras_nlp.src.utils.tensor_utils import is_string_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -171,7 +171,7 @@ def __init__( default_value=False, ) - @tf_preprocessing_function + @preprocessing_function def call(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/layers/preprocessing/random_swap.py b/keras_nlp/src/layers/preprocessing/random_swap.py index 71b07193dc..32f7a08d26 100644 --- a/keras_nlp/src/layers/preprocessing/random_swap.py +++ b/keras_nlp/src/layers/preprocessing/random_swap.py @@ -21,7 +21,7 @@ from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype from keras_nlp.src.utils.tensor_utils import is_string_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -167,7 +167,7 @@ def __init__( default_value=False, ) - @tf_preprocessing_function + @preprocessing_function def call(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/layers/preprocessing/resizing_image_converter.py b/keras_nlp/src/layers/preprocessing/resizing_image_converter.py new file mode 100644 index 0000000000..8012acf5c2 --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/resizing_image_converter.py @@ -0,0 +1,97 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import keras + +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.image_converter import ImageConverter +from keras_nlp.src.utils.tensor_utils import preprocessing_function + + +@keras_nlp_export("keras_nlp.layers.ResizingImageConverter") +class ResizingImageConverter(ImageConverter): + """An `ImageConverter` that simply resizes the input image. + + The `ResizingImageConverter` is a subclass of `ImageConverter` for models + that simply need to resize image tensors before using them for modeling. + The layer will take as input a raw image tensor (batched or unbatched) in the + channels last or channels first format, and output a resize tensor. + + Args: + height: Integer, the height of the output shape. + width: Integer, the width of the output shape. + crop_to_aspect_ratio: If `True`, resize the images without aspect + ratio distortion. When the original aspect ratio differs + from the target aspect ratio, the output image will be + cropped so as to return the + largest possible window in the image (of size `(height, width)`) + that matches the target aspect ratio. By default + (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved. + interpolation: String, the interpolation method. + Supports `"bilinear"`, `"nearest"`, `"bicubic"`, + `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`. + data_format: String, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + + Examples: + ```python + # Resize images for `"pali_gemma_3b_224"`. + converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_224") + converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 224, 224, 3) + # Resize images for `"pali_gemma_3b_224"`. + converter = keras_nlp.layers.ImageConverter.from_preset("pali_gemma_3b_448") + converter(np.ones(2, 512, 512, 3)) # Output shape: (2, 448, 448, 3) + ``` + """ + + def __init__( + self, + height, + width, + crop_to_aspect_ratio=True, + interpolation="bilinear", + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + # By default, we just do a simple resize. Any model can subclass this + # layer for preprocessing of a raw image to a model image input. + self.resizing = keras.layers.Resizing( + height, + width, + crop_to_aspect_ratio=crop_to_aspect_ratio, + interpolation=interpolation, + data_format=data_format, + ) + + @preprocessing_function + def call(self, inputs): + return self.resizing(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "height": self.resizing.height, + "width": self.resizing.width, + "interpolation": self.resizing.interpolation, + "crop_to_aspect_ratio": self.resizing.crop_to_aspect_ratio, + } + ) + return config diff --git a/keras_nlp/src/layers/preprocessing/resizing_image_converter_test.py b/keras_nlp/src/layers/preprocessing/resizing_image_converter_test.py new file mode 100644 index 0000000000..f96a3a3488 --- /dev/null +++ b/keras_nlp/src/layers/preprocessing/resizing_image_converter_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from keras import ops + +from keras_nlp.src.layers.preprocessing.resizing_image_converter import ( + ResizingImageConverter, +) +from keras_nlp.src.tests.test_case import TestCase + + +class ResizingImageConverterTest(TestCase): + def test_resize_one(self): + converter = ResizingImageConverter(22, 22) + test_image = np.random.rand(10, 10, 3) * 255 + shape = ops.shape(converter(test_image)) + self.assertEqual(shape, (22, 22, 3)) + + def test_resize_batch(self): + converter = ResizingImageConverter(12, 12) + test_batch = np.random.rand(4, 10, 20, 3) * 255 + shape = ops.shape(converter(test_batch)) + self.assertEqual(shape, (4, 12, 12, 3)) + + def test_config(self): + converter = ResizingImageConverter( + width=12, + height=20, + crop_to_aspect_ratio=False, + interpolation="nearest", + ) + clone = ResizingImageConverter.from_config(converter.get_config()) + test_batch = np.random.rand(4, 10, 20, 3) * 255 + self.assertAllClose(converter(test_batch), clone(test_batch)) diff --git a/keras_nlp/src/layers/preprocessing/start_end_packer.py b/keras_nlp/src/layers/preprocessing/start_end_packer.py index e1712eba91..a6e7b4d068 100644 --- a/keras_nlp/src/layers/preprocessing/start_end_packer.py +++ b/keras_nlp/src/layers/preprocessing/start_end_packer.py @@ -18,7 +18,7 @@ PreprocessingLayer, ) from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -155,7 +155,7 @@ def check_special_value_type(value, value_name): self.pad_value = pad_value self.return_padding_mask = return_padding_mask - @tf_preprocessing_function + @preprocessing_function def call( self, inputs, diff --git a/keras_nlp/src/layers/preprocessing/start_end_packer_test.py b/keras_nlp/src/layers/preprocessing/start_end_packer_test.py index dc99eaa8cd..5fb77a930e 100644 --- a/keras_nlp/src/layers/preprocessing/start_end_packer_test.py +++ b/keras_nlp/src/layers/preprocessing/start_end_packer_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras import tensorflow as tf from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker @@ -30,13 +29,10 @@ def test_dense_input(self): def test_bfloat16_dtype(self): # Core Keras has a strange bug where it converts int to floats in # ops.convert_to_tensor only with jax and bfloat16. - floatx = keras.config.floatx() - keras.config.set_floatx("bfloat16") input_data = [5, 6, 7] start_end_packer = StartEndPacker(sequence_length=5, dtype="bfloat16") output = start_end_packer(input_data) self.assertDTypeEqual(output, "int32") - keras.config.set_floatx(floatx) def test_dense_2D_input(self): input_data = [[5, 6, 7]] diff --git a/keras_nlp/src/models/bart/bart_preprocessor.py b/keras_nlp/src/models/bart/bart_preprocessor.py index dc013779a3..0f9de65fab 100644 --- a/keras_nlp/src/models/bart/bart_preprocessor.py +++ b/keras_nlp/src/models/bart/bart_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.bart.bart_backbone import BartBackbone from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.BartPreprocessor") @@ -174,7 +174,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/bloom/bloom_preprocessor.py b/keras_nlp/src/models/bloom/bloom_preprocessor.py index 6c572591b7..8b2d7b2ba0 100644 --- a/keras_nlp/src/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/src/models/bloom/bloom_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.BloomPreprocessor") @@ -134,7 +134,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/causal_lm_preprocessor.py b/keras_nlp/src/models/causal_lm_preprocessor.py index 1713ce6566..6a0dad3bdf 100644 --- a/keras_nlp/src/models/causal_lm_preprocessor.py +++ b/keras_nlp/src/models/causal_lm_preprocessor.py @@ -16,8 +16,8 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import preprocessing_function from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @keras_nlp_export("keras_nlp.models.CausalLMPreprocessor") @@ -98,7 +98,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, @@ -124,7 +124,7 @@ def call( y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - @tf_preprocessing_function + @preprocessing_function def generate_preprocess( self, x, @@ -153,7 +153,7 @@ def generate_preprocess( "padding_mask": padding_mask, } - @tf_preprocessing_function + @preprocessing_function def generate_postprocess( self, x, diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py b/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py index 8096e76e37..6823d1a8eb 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py @@ -22,7 +22,7 @@ DebertaV3Tokenizer, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.DebertaV3MaskedLMPreprocessor") @@ -119,7 +119,7 @@ class DebertaV3MaskedLMPreprocessor(MaskedLMPreprocessor): backbone_cls = DebertaV3Backbone tokenizer_cls = DebertaV3Tokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py b/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py index 6f1996d3b4..9634b77238 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_text_classifier_preprocessor.py @@ -24,7 +24,7 @@ from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -156,7 +156,7 @@ class DebertaV3TextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = DebertaV3Backbone tokenizer_cls = DebertaV3Tokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py b/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py index 9004fa06ba..1a69de4e50 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py @@ -22,7 +22,7 @@ DistilBertTokenizer, ) from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.DistilBertMaskedLMPreprocessor") @@ -123,7 +123,7 @@ class DistilBertMaskedLMPreprocessor(MaskedLMPreprocessor): backbone_cls = DistilBertBackbone tokenizer_cls = DistilBertTokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py b/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py index 90ebe7ef3b..f4bda84a2c 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_text_classifier_preprocessor.py @@ -25,7 +25,7 @@ from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -126,7 +126,7 @@ class DistilBertTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = DistilBertBackbone tokenizer_cls = DistilBertTokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/electra/electra_preprocessor.py b/keras_nlp/src/models/electra/electra_preprocessor.py index 941e5d4d13..82dbb3310c 100644 --- a/keras_nlp/src/models/electra/electra_preprocessor.py +++ b/keras_nlp/src/models/electra/electra_preprocessor.py @@ -21,7 +21,7 @@ from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.ElectraPreprocessor") @@ -142,7 +142,7 @@ def get_config(self): ) return config - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): x = x if isinstance(x, tuple) else (x,) x = tuple(self.tokenizer(segment) for segment in x) diff --git a/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py b/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py index 7628f26ee4..3d0b625ef1 100644 --- a/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/f_net/f_net_masked_lm_preprocessor.py @@ -18,7 +18,7 @@ from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.FNetMaskedLMPreprocessor") @@ -121,7 +121,7 @@ class FNetMaskedLMPreprocessor(MaskedLMPreprocessor): backbone_cls = FNetBackbone tokenizer_cls = FNetTokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py b/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py index c1f1f3f71e..124fc803e2 100644 --- a/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/f_net/f_net_text_classifier_preprocessor.py @@ -21,7 +21,7 @@ from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -124,7 +124,7 @@ class FNetTextClassifierPreprocessor(TextClassifierPreprocessor): backbone_cls = FNetBackbone tokenizer_cls = FNetTokenizer - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): # FNet has not padding mask. output = super().call(x, y=y, sample_weight=sample_weight) diff --git a/keras_nlp/src/models/falcon/falcon_preprocessor.py b/keras_nlp/src/models/falcon/falcon_preprocessor.py index 1c7fd3c138..491f6e5fe2 100644 --- a/keras_nlp/src/models/falcon/falcon_preprocessor.py +++ b/keras_nlp/src/models/falcon/falcon_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.FalconPreprocessor") @@ -136,7 +136,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/gemma/gemma_preprocessor.py b/keras_nlp/src/models/gemma/gemma_preprocessor.py index 745f437ecd..dcbe531b56 100644 --- a/keras_nlp/src/models/gemma/gemma_preprocessor.py +++ b/keras_nlp/src/models/gemma/gemma_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.GemmaPreprocessor") @@ -151,7 +151,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/gpt2/gpt2_preprocessor.py b/keras_nlp/src/models/gpt2/gpt2_preprocessor.py index b645ac7682..c9af92fcbf 100644 --- a/keras_nlp/src/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/src/models/gpt2/gpt2_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.GPT2Preprocessor") @@ -136,7 +136,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 25875d03da..06a46ce470 100644 --- a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -19,7 +19,7 @@ from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.GPTNeoXPreprocessor") @@ -94,7 +94,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/llama/llama_preprocessor.py b/keras_nlp/src/models/llama/llama_preprocessor.py index 8b7c5772e6..75122856c6 100644 --- a/keras_nlp/src/models/llama/llama_preprocessor.py +++ b/keras_nlp/src/models/llama/llama_preprocessor.py @@ -18,7 +18,7 @@ from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone from keras_nlp.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.LlamaPreprocessor") @@ -149,7 +149,7 @@ def get_config(self): ) return config - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/masked_lm_preprocessor.py b/keras_nlp/src/models/masked_lm_preprocessor.py index 491b715280..c4dab122ab 100644 --- a/keras_nlp/src/models/masked_lm_preprocessor.py +++ b/keras_nlp/src/models/masked_lm_preprocessor.py @@ -21,7 +21,7 @@ MultiSegmentPacker, ) from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.MaskedLMPreprocessor") @@ -113,7 +113,7 @@ def build(self, input_shape): unselectable_token_ids=self.tokenizer.special_token_ids, ) - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): x = x if isinstance(x, tuple) else (x,) x = tuple(self.tokenizer(segment) for segment in x) diff --git a/keras_nlp/src/models/mistral/mistral_preprocessor.py b/keras_nlp/src/models/mistral/mistral_preprocessor.py index 0278103c54..c6d7731722 100644 --- a/keras_nlp/src/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/src/models/mistral/mistral_preprocessor.py @@ -19,7 +19,7 @@ from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.MistralPreprocessor") @@ -150,7 +150,7 @@ def get_config(self): ) return config - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/opt/opt_preprocessor.py b/keras_nlp/src/models/opt/opt_preprocessor.py index 3b39c82cbc..0cafaaec5e 100644 --- a/keras_nlp/src/models/opt/opt_preprocessor.py +++ b/keras_nlp/src/models/opt/opt_preprocessor.py @@ -20,7 +20,7 @@ from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.OPTPreprocessor") @@ -148,7 +148,7 @@ def get_config(self): ) return config - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py b/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py index be541597cd..44d80b307d 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py @@ -11,19 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os - import numpy as np +import pytest +from keras import ops from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( PaliGemmaBackbone, ) -from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( - PaliGemmaCausalLMPreprocessor, -) -from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import ( - PaliGemmaTokenizer, -) from keras_nlp.src.tests.test_case import TestCase @@ -34,15 +28,6 @@ def setUp(self): self.text_sequence_length = 64 self.image_size = 16 self.image_sequence_length = int((self.image_size / 4) ** 2) - - proto = "gemma_test_vocab.spm" - tokenizer = PaliGemmaTokenizer( - os.path.join(self.get_test_data_dir(), proto) - ) - self.preprocessor = PaliGemmaCausalLMPreprocessor( - tokenizer, self.text_sequence_length, False, False - ) - self.init_kwargs = { "vocabulary_size": self.vocabulary_size, "image_size": self.image_size, @@ -65,7 +50,6 @@ def setUp(self): dummy_text_token_ids = np.random.rand( self.batch_size, self.text_sequence_length ) - dummy_text = ["answer en the quick brown fox"] * self.batch_size self.input_data = { "token_ids": dummy_text_token_ids, "images": dummy_images, @@ -78,11 +62,6 @@ def setUp(self): dtype="int32", ), } - self.raw_input_data = { - "images": dummy_images, - "prompts": dummy_text, - "responses": dummy_text, - } def test_backbone_basics(self): self.run_backbone_test( @@ -98,15 +77,37 @@ def test_backbone_basics(self): run_mixed_precision_check=False, # TODO: Set to `True` ) - def test_pali_gemma_backbone_with_preprocessing(self): - model = PaliGemmaBackbone(**self.init_kwargs) - x, _, _ = self.preprocessor(self.raw_input_data) - output = model(x) - self.assertEqual( - ( - self.batch_size, - self.text_sequence_length + self.image_sequence_length, - 8, + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=PaliGemmaBackbone, + preset="pali_gemma_3b_mix_224", + input_data={ + "token_ids": ops.array([[1169, 2068, 7586, 21831, 13]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + "response_mask": ops.zeros((1, 5), dtype="int32"), + "images": ops.zeros((1, 224, 224, 3), dtype="float32"), + }, + expected_output_shape=(1, 261, 2048), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [-0.449851, 1.431027, -0.713446, 0.417485, -0.640859] ), - output.shape, ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in PaliGemmaBackbone.presets: + self.run_preset_test( + cls=PaliGemmaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py index 764b570fa2..01493ef454 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py @@ -17,22 +17,42 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( - GemmaCausalLMPreprocessor, -) +from keras_nlp.src.models.causal_lm_preprocessor import CausalLMPreprocessor from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( PaliGemmaBackbone, ) +from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.PaliGemmaCausalLMPreprocessor") -class PaliGemmaCausalLMPreprocessor(GemmaCausalLMPreprocessor): +class PaliGemmaCausalLMPreprocessor(CausalLMPreprocessor): backbone_cls = PaliGemmaBackbone tokenizer_cls = PaliGemmaTokenizer + image_converter_cls = PaliGemmaImageConverter + + def __init__( + self, + tokenizer, + image_converter=None, + sequence_length=1024, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__( + tokenizer=tokenizer, + sequence_length=sequence_length, + add_start_token=add_start_token, + add_end_token=add_end_token, + **kwargs, + ) + self.image_converter = image_converter def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -46,7 +66,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, @@ -58,6 +78,8 @@ def call( images, prompts, responses = x["images"], x["prompts"], x["responses"] prompts = self.tokenizer(prompts) responses = self.tokenizer(responses) + if self.image_converter: + images = self.image_converter(images) # Pad with one extra token to account for the truncation below. token_ids, segment_ids = self.packer( (prompts, responses), @@ -80,7 +102,7 @@ def call( sample_weight = response_mask[..., 1:] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - @tf_preprocessing_function + @preprocessing_function def generate_preprocess( self, x, @@ -103,6 +125,8 @@ def generate_preprocess( images, prompts = x["images"], x["prompts"] prompts = self.tokenizer(prompts) + if self.image_converter: + images = self.image_converter(images) if "responses" in x: responses = self.tokenizer(x["responses"]) segments = (prompts, responses) diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor_test.py b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor_test.py index cba98a9d97..5553bcd71a 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor_test.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor_test.py @@ -20,6 +20,9 @@ from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( PaliGemmaCausalLMPreprocessor, ) +from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, ) @@ -34,14 +37,16 @@ def setUp(self): self.get_test_data_dir(), "gemma_test_vocab.spm" ), ) + self.image_converter = PaliGemmaImageConverter(width=224, height=224) self.init_kwargs = { "tokenizer": self.tokenizer, + "image_converter": self.image_converter, "sequence_length": 8, } self.input_data = { "prompts": ["the quick"], "responses": ["brown fox"], - "images": [np.zeros([1, 224, 224, 3])], + "images": [np.zeros([512, 512, 3])], } def test_preprocessor_basics(self): @@ -54,7 +59,7 @@ def test_preprocessor_basics(self): "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], "response_mask": [[0, 0, 0, 1, 1, 1, 0, 0]], - "images": self.input_data["images"], + "images": np.zeros([1, 224, 224, 3]), }, [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. [[0, 0, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. @@ -65,7 +70,7 @@ def test_no_start_end_token(self): input_data = { "prompts": ["the quick"] * 4, "responses": ["brown fox"] * 4, - "images": [np.zeros([1, 224, 224, 3])] * 4, + "images": [np.zeros([512, 512, 3])] * 4, } preprocessor = PaliGemmaCausalLMPreprocessor( **self.init_kwargs, @@ -76,19 +81,21 @@ def test_no_start_end_token(self): self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) self.assertAllEqual(x["response_mask"], [[0, 0, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["images"], np.zeros([4, 224, 224, 3])) self.assertAllEqual(y, [[9, 5, 7, 0, 0, 0, 0, 0]] * 4) self.assertAllEqual(sw, [[0, 1, 1, 0, 0, 0, 0, 0]] * 4) def test_generate_preprocess(self): input_data = { "prompts": "the quick", - "images": np.zeros([1, 224, 224, 3]), + "images": np.zeros([1, 512, 512, 3]), } preprocessor = PaliGemmaCausalLMPreprocessor(**self.init_kwargs) x = preprocessor.generate_preprocess(input_data) self.assertAllEqual(x["token_ids"], [1, 4, 9, 0, 0, 0, 0, 0]) self.assertAllEqual(x["padding_mask"], [1, 1, 1, 0, 0, 0, 0, 0]) self.assertAllEqual(x["response_mask"], [0, 0, 0, 0, 0, 0, 0, 0]) + self.assertAllEqual(x["images"], np.zeros([1, 224, 224, 3])) def test_generate_postprocess(self): input_data = { diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_test.py b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_test.py index 51afdc5fb7..3009cbf944 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_test.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_test.py @@ -25,6 +25,9 @@ from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import ( PaliGemmaCausalLMPreprocessor, ) +from keras_nlp.src.models.pali_gemma.pali_gemma_image_converter import ( + PaliGemmaImageConverter, +) from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, ) @@ -39,17 +42,17 @@ def setUp(self): self.dummy_text = [ "the quick brown fox" for _ in range(self.batch_size) ] - self.dummy_images = np.random.uniform( - size=(self.batch_size, self.image_size, self.image_size, 3) - ) + self.dummy_images = np.random.uniform(size=(self.batch_size, 20, 20, 3)) proto = "gemma_test_vocab.spm" tokenizer = PaliGemmaTokenizer( os.path.join(self.get_test_data_dir(), proto) ) + image_converter = PaliGemmaImageConverter(16, 16) self.vocabulary_size = tokenizer.vocabulary_size() self.preprocessor = PaliGemmaCausalLMPreprocessor( tokenizer, + image_converter, self.text_sequence_length, add_start_token=False, add_end_token=False, @@ -96,7 +99,9 @@ def test_saved_model(self): "token_ids": np.random.rand( self.batch_size, self.text_sequence_length ), - "images": self.dummy_images, + "images": np.ones( + (self.batch_size, self.image_size, self.image_size, 3) + ), "padding_mask": np.ones( (self.batch_size, self.text_sequence_length), dtype="int32", diff --git a/keras_nlp/api/utils/__init__.py b/keras_nlp/src/models/pali_gemma/pali_gemma_image_converter.py similarity index 59% rename from keras_nlp/api/utils/__init__.py rename to keras_nlp/src/models/pali_gemma/pali_gemma_image_converter.py index a904b48f59..6bef25d8b4 100644 --- a/keras_nlp/api/utils/__init__.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_image_converter.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""DO NOT EDIT. +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.resizing_image_converter import ( + ResizingImageConverter, +) +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) -This file was autogenerated. Do not edit it by hand, -since your modifications would be overwritten. -""" -from keras_nlp.src.utils.tensor_utils import convert_from_tf -from keras_nlp.src.utils.tensor_utils import convert_to_tf +@keras_nlp_export("keras_nlp.layers.PaliGemmaImageConverter") +class PaliGemmaImageConverter(ResizingImageConverter): + backbone_cls = PaliGemmaBackbone diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_presets.py b/keras_nlp/src/models/pali_gemma/pali_gemma_presets.py index c8791f709b..8cabcc7038 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_presets.py @@ -25,7 +25,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/1", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/2", }, "pali_gemma_3b_mix_448": { "metadata": { @@ -37,7 +37,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/1", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/2", }, "pali_gemma_3b_224": { "metadata": { @@ -49,7 +49,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/1", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/2", }, "pali_gemma_3b_448": { "metadata": { @@ -61,7 +61,7 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/1", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/2", }, "pali_gemma_3b_896": { "metadata": { @@ -73,6 +73,6 @@ "path": "pali_gemma", "model_card": "https://www.kaggle.com/models/google/paligemma", }, - "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/1", + "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/2", }, } diff --git a/keras_nlp/src/models/phi3/phi3_preprocessor.py b/keras_nlp/src/models/phi3/phi3_preprocessor.py index caa5c9eab4..ce392b5088 100644 --- a/keras_nlp/src/models/phi3/phi3_preprocessor.py +++ b/keras_nlp/src/models/phi3/phi3_preprocessor.py @@ -18,7 +18,7 @@ from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.Phi3Preprocessor") @@ -150,7 +150,7 @@ def get_config(self): ) return config - @tf_preprocessing_function + @preprocessing_function def call( self, x, diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py index e498746685..fd24ea8b47 100644 --- a/keras_nlp/src/models/preprocessor.py +++ b/keras_nlp/src/models/preprocessor.py @@ -44,14 +44,18 @@ class Preprocessor(PreprocessingLayer): backbone_cls = None tokenizer_cls = None + audio_converter_cls = None + image_converter_cls = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._tokenizer = None + self._image_converter = None + self._audio_converter = None def __setattr__(self, name, value): # Work around torch setattr for properties. - if name in ["tokenizer"]: + if name in ["tokenizer", "audio_converter", "image_converter"]: return object.__setattr__(self, name, value) return super().__setattr__(name, value) @@ -64,15 +68,54 @@ def tokenizer(self): def tokenizer(self, value): self._tokenizer = value + @property + def audio_converter(self): + """The audio converter used to preprocess audio data.""" + return self._audio_converter + + @audio_converter.setter + def audio_converter(self, value): + self._audio_converter = value + + @property + def image_converter(self): + """The image converter used to preprocess image data.""" + return self._image_converter + + @image_converter.setter + def image_converter(self, value): + self._image_converter = value + def get_config(self): config = super().get_config() - config["tokenizer"] = keras.layers.serialize(self.tokenizer) + if self.tokenizer: + config["tokenizer"] = keras.layers.serialize(self.tokenizer) + if self.audio_converter: + config["audio_converter"] = keras.layers.serialize( + self.audio_converter + ) + if self.image_converter: + config["image_converter"] = keras.layers.serialize( + self.image_converter + ) return config @classmethod def from_config(cls, config): if "tokenizer" in config and isinstance(config["tokenizer"], dict): config["tokenizer"] = keras.layers.deserialize(config["tokenizer"]) + if "audio_converter" in config and isinstance( + config["audio_converter"], dict + ): + config["audio_converter"] = keras.layers.deserialize( + config["audio_converter"] + ) + if "image_converter" in config and isinstance( + config["image_converter"], dict + ): + config["image_converter"] = keras.layers.deserialize( + config["image_converter"] + ) return cls(**config) @classproperty @@ -95,7 +138,7 @@ def from_preset( """Instantiate a `keras_nlp.models.Preprocessor` from a model preset. A preset is a directory of configs, weights and other file assets used - to save and load a pre-trained model. The `preset` can be passed as a + to save and load a pre-trained model. The `preset` can be passed as one of: 1. a built-in preset identifier like `'bert_base_en'` @@ -155,4 +198,9 @@ def save_to_preset(self, preset_dir): preset_dir, config_file=PREPROCESSOR_CONFIG_FILE, ) - self.tokenizer.save_to_preset(preset_dir) + if self.tokenizer: + self.tokenizer.save_to_preset(preset_dir) + if self.audio_converter: + self.audio_converter.save_to_preset(preset_dir) + if self.image_converter: + self.image_converter.save_to_preset(preset_dir) diff --git a/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py b/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py index 4b3bf02b6c..369a34ce61 100644 --- a/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/roberta/roberta_masked_lm_preprocessor.py @@ -21,7 +21,7 @@ from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.RobertaMaskedLMPreprocessor") @@ -137,7 +137,7 @@ def build(self, input_shape): sequence_length=self.sequence_length, ) - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py b/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py index 377cc2cb75..a905156615 100644 --- a/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/roberta/roberta_text_classifier_preprocessor.py @@ -23,7 +23,7 @@ from keras_nlp.src.models.text_classifier_preprocessor import ( TextClassifierPreprocessor, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -151,7 +151,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py b/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py index 8da697dc99..27405f99d0 100644 --- a/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/src/models/seq_2_seq_lm_preprocessor.py @@ -16,8 +16,8 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.utils.tensor_utils import preprocessing_function from keras_nlp.src.utils.tensor_utils import strip_to_ragged -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function try: import tensorflow as tf @@ -114,7 +114,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call( self, x, @@ -154,7 +154,7 @@ def call( sample_weight = decoder_padding_mask[..., 1:] return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - @tf_preprocessing_function + @preprocessing_function def generate_preprocess( self, x, @@ -215,7 +215,7 @@ def generate_preprocess( "decoder_padding_mask": decoder_padding_mask, } - @tf_preprocessing_function + @preprocessing_function def generate_postprocess( self, x, diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 919d025d1a..8da47992f1 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -152,7 +152,7 @@ def from_preset( """Instantiate a `keras_nlp.models.Task` from a model preset. A preset is a directory of configs, weights and other file assets used - to save and load a pre-trained model. The `preset` can be passed as a + to save and load a pre-trained model. The `preset` can be passed as one of: 1. a built-in preset identifier like `'bert_base_en'` diff --git a/keras_nlp/src/models/text_classifier_preprocessor.py b/keras_nlp/src/models/text_classifier_preprocessor.py index 6b8639d045..9b34353938 100644 --- a/keras_nlp/src/models/text_classifier_preprocessor.py +++ b/keras_nlp/src/models/text_classifier_preprocessor.py @@ -18,7 +18,7 @@ MultiSegmentPacker, ) from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.TextClassifierPreprocessor") @@ -104,7 +104,7 @@ def build(self, input_shape): sequence_length=self.sequence_length, ) - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): x = x if isinstance(x, tuple) else (x,) x = tuple(self.tokenizer(segment) for segment in x) diff --git a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/src/models/whisper/whisper_audio_converter.py similarity index 92% rename from keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py rename to keras_nlp/src/models/whisper/whisper_audio_converter.py index 0fdc4374e0..ac15be457a 100644 --- a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/src/models/whisper/whisper_audio_converter.py @@ -15,24 +15,19 @@ import numpy as np +from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.layers.preprocessing.audio_converter import AudioConverter +from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone + try: import tensorflow as tf except ImportError: - raise ImportError( - "To use `keras_nlp`, please install Tensorflow: `pip install tensorflow`. " - "The TensorFlow package is required for data preprocessing with any backend." - ) + tf = None -from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.preprocessing_layer import ( - PreprocessingLayer, -) - -@keras_nlp_export("keras_nlp.models.WhisperAudioFeatureExtractor") -class WhisperAudioFeatureExtractor(PreprocessingLayer): - """ - Whisper audio feature extractor layer. +@keras_nlp_export("keras_nlp.layers.WhisperAudioConverter") +class WhisperAudioConverter(AudioConverter): + """Whisper audio converter layer. This layer takes in a batch of audio tensors, and computes the log-mel spectrogram features for each audio tensor. @@ -55,22 +50,25 @@ class WhisperAudioFeatureExtractor(PreprocessingLayer): `max_audio_length * sampling_rate`. Defaults to `30`. Examples: - ```python audio_tensor = tf.ones((8000,), dtype="float32") # Compute the log-mel spectrogram. - whisper_audio_feature_extractor = keras_nlp.models.WhisperAudioFeatureExtractor() - whisper_audio_feature_extractor(audio_tensor) + audio_converter = keras_nlp.models.WhisperAudioConverter.from_preset( + "whisper_base_en", + ) + audio_converter(audio_tensor) # Compute the log-mel spectrogram for a batch of audio tensors. audio_tensor_1 = tf.ones((8000,), dtype="float32") - audio_tensor_2 = tf.ones((10000,), dtype="float32" + audio_tensor_2 = tf.ones((10000,), dtype="float32") audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0) - whisper_audio_feature_extractor(audio_tensor) + audio_converter(audio_tensor) ``` """ + backbone_cls = WhisperBackbone + def __init__( self, num_mels=80, diff --git a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor_test.py b/keras_nlp/src/models/whisper/whisper_audio_converter_test.py similarity index 84% rename from keras_nlp/src/models/whisper/whisper_audio_feature_extractor_test.py rename to keras_nlp/src/models/whisper/whisper_audio_converter_test.py index 823acbf3d1..16923787f4 100644 --- a/keras_nlp/src/models/whisper/whisper_audio_feature_extractor_test.py +++ b/keras_nlp/src/models/whisper/whisper_audio_converter_test.py @@ -14,13 +14,13 @@ import tensorflow as tf -from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( - WhisperAudioFeatureExtractor, +from keras_nlp.src.models.whisper.whisper_audio_converter import ( + WhisperAudioConverter, ) from keras_nlp.src.tests.test_case import TestCase -class WhisperAudioFeatureExtractorTest(TestCase): +class WhisperAudioConverterTest(TestCase): def setUp(self): self.init_kwargs = { "num_mels": 80, @@ -38,14 +38,14 @@ def setUp(self): def test_feature_extractor_basics(self): self.run_preprocessing_layer_test( - cls=WhisperAudioFeatureExtractor, + cls=WhisperAudioConverter, init_kwargs=self.init_kwargs, input_data=self.input_data, ) def test_correctness(self): audio_tensor = tf.ones((2,), dtype="float32") - outputs = WhisperAudioFeatureExtractor(**self.init_kwargs)(audio_tensor) + outputs = WhisperAudioConverter(**self.init_kwargs)(audio_tensor) # Verify shape. self.assertEqual(outputs.shape, (5, 80)) diff --git a/keras_nlp/src/models/whisper/whisper_backbone.py b/keras_nlp/src/models/whisper/whisper_backbone.py index 5629e05f2d..e104494b2b 100644 --- a/keras_nlp/src/models/whisper/whisper_backbone.py +++ b/keras_nlp/src/models/whisper/whisper_backbone.py @@ -24,7 +24,6 @@ from keras_nlp.src.models.backbone import Backbone from keras_nlp.src.models.whisper.whisper_decoder import WhisperDecoder from keras_nlp.src.models.whisper.whisper_encoder import WhisperEncoder -from keras_nlp.src.utils.tensor_utils import assert_tf_backend def whisper_kernel_initializer(stddev=0.02): @@ -117,8 +116,6 @@ def __init__( dtype=None, **kwargs, ): - assert_tf_backend(self.__class__.__name__) - # === Layers === self.encoder_conv_layer_1 = keras.layers.Conv1D( filters=hidden_dim, diff --git a/keras_nlp/src/models/whisper/whisper_backbone_test.py b/keras_nlp/src/models/whisper/whisper_backbone_test.py index 0b34d95cce..33bc0e2871 100644 --- a/keras_nlp/src/models/whisper/whisper_backbone_test.py +++ b/keras_nlp/src/models/whisper/whisper_backbone_test.py @@ -19,7 +19,6 @@ from keras_nlp.src.tests.test_case import TestCase -@pytest.mark.tf_only class WhisperBackboneTest(TestCase): def setUp(self): self.init_kwargs = { diff --git a/keras_nlp/src/models/whisper/whisper_preprocessor.py b/keras_nlp/src/models/whisper/whisper_preprocessor.py deleted file mode 100644 index 8a65e6d004..0000000000 --- a/keras_nlp/src/models/whisper/whisper_preprocessor.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import keras -from absl import logging - -from keras_nlp.src.api_export import keras_nlp_export -from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.src.models.preprocessor import Preprocessor -from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( - WhisperAudioFeatureExtractor, -) -from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone -from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function - - -@keras_nlp_export("keras_nlp.models.WhisperPreprocessor") -class WhisperPreprocessor(Preprocessor): - """A Whisper preprocessing layer which handles audio and text input. - - This preprocessing layer will do three things: - - 1. Compute the log-mel spectrogram of the audio tensor inputs using - `audio_feature_extractor`. - 2. Tokenize decoder inputs using the `tokenizer`. - 2. Add the appropriate special tokens - `"<|startoftranscript|>", task - token, language token, `"<|endoftext|>"`, etc. - 3. Construct a dictionary with keys `"encoder_features"`, - `"decoder_token_ids"`, `"decoder_padding_mask"` that can be passed - directly to a Whisper model. - - Args: - tokenizer: A `keras_nlp.models.WhisperTokenizer` instance. - audio_feature_extractor: A - `keras_nlp.models.WhisperAudioFeatureExtractor` instance or `None`. - If `None` a feature extractor with default parameters will be - created. - decoder_sequence_length: The length of the packed decoder inputs. - language: string, language token. Should only be passed if your - tokenizer is multilingual. - task: string, task name. One of `"transcribe"`, `"translate"`. Should - only be passed if your tokenizer is multilingual. - no_timestamps: bool. If True, `"<|no_timestamps|>"` will be added as a - special token to your input. - - Call arguments: - x: A dictionary with `"encoder_audio"` and `"decoder_text"` as its keys. - `"encoder_audio"` should correspond to the input audio tensor. - `"decoder_text"` should be a tensor of single string sequences. - Inputs may be batched or unbatched. Raw python inputs will be - converted to tensors. - y: Any label data. Will be passed through unaltered. - sample_weight: Any label weight data. Will be passed through unaltered. - - Examples: - - Directly calling the layer on data. - ```python - preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset( - "whisper_tiny_en", - ) - - # Preprocess unbatched inputs. - input_data = { - "encoder_audio": tf.ones((200,)), - "decoder_text": "The quick brown fox jumped.", - } - preprocessor(input_data) - - # Preprocess batched inputs. - input_data = { - "encoder_audio": tf.ones((2, 200)), - "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."], - } - preprocessor(input_data) - - # Custom audio feature extractor and vocabulary. - audio_feature_extractor = keras_nlp.models.WhisperAudioFeatureExtractor( - num_mels=80, - num_fft_bins=400, - stride=100, - sampling_rate=100, - max_audio_length=5, - ) - - features = ["a quick fox.", "a fox quick."] - vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6} - merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] - merges += ["Ġ f", "o x", "Ġf ox"] - special_tokens = { - "<|startoftranscript|>": 9, - "<|endoftext|>": 10, - "<|notimestamps|>": 11, - "<|transcribe|>": 12, - "<|translate|>": 13, - } - - tokenizer = keras_nlp.models.WhisperTokenizer( - vocabulary=vocab, - merges=merges, - special_tokens=special_tokens, - ) - preprocessor = keras_nlp.models.WhisperPreprocessor( - audio_feature_extractor=audio_feature_extractor, - tokenizer=tokenizer, - ) - - input_data = { - "encoder_audio": tf.ones((200,)), - "decoder_text": "The quick brown fox jumped.", - } - preprocessor(input_data) - ``` - - Mapping with `tf.data.Dataset`. - ```python - preprocessor = keras_nlp.models.WhisperPreprocessor.from_preset( - "whisper_tiny_en") - - # Map labeled single sentences. - features = { - "encoder_audio": tf.ones((2, 200)), - "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."], - } - labels = tf.constant(["True", "False"]) - ds = tf.data.Dataset.from_tensor_slices((features, labels)) - ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) - - # Map unlabeled single sentences. - features = { - "encoder_audio": tf.ones((2, 200)), - "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."], - } - ds = tf.data.Dataset.from_tensor_slices(features) - ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) - ``` - """ - - backbone_cls = WhisperBackbone - tokenizer_cls = WhisperTokenizer - - def __init__( - self, - tokenizer, - audio_feature_extractor=None, - decoder_sequence_length=448, - language=None, - task=None, - no_timestamps=True, - **kwargs, - ): - super().__init__(**kwargs) - if audio_feature_extractor is None: - audio_feature_extractor = WhisperAudioFeatureExtractor() - self.audio_feature_extractor = audio_feature_extractor - self.tokenizer = tokenizer - self.decoder_packer = None - self.decoder_sequence_length = decoder_sequence_length - self.language = language - self.task = task - self.no_timestamps = no_timestamps - - def build(self, input_shape): - # Defer packer creation to `build()` so that we can be sure tokenizer - # assets have loaded when restoring a saved model. - - # Create list of tokens to be prepended to decoder inputs. - bos_tokens = [self.tokenizer.bos_token_id] - if self.tokenizer.language_tokens is not None: - if ( - self.language is None - or self.language not in self.tokenizer.language_tokens - ): - raise ValueError( - "You must pass a non-None value for `language` when using " - "a multilingual tokenizer. The value must be one of " - f'{",".join(self.tokenizer.language_tokens.keys())}. ' - f"Received: language={self.language}." - ) - if self.task is None or self.task not in [ - "transcribe", - "translate", - ]: - raise ValueError( - "You must pass a non-None value for `task` when using " - "a multilingual tokenizer. The value must be one of " - '`"transcribe"`, `"translate"`. ' - f"Received: task={self.task}." - ) - - bos_tokens += [self.tokenizer.language_tokens[self.language]] - - special_token_dict = self.tokenizer._special_token_dict - if self.task == "transcribe": - bos_tokens += [special_token_dict["<|transcribe|>"]] - elif self.task == "translate": - bos_tokens += [special_token_dict["<|translate|>"]] - else: - if self.language is not None: - logging.info( - "`tokenizer` is monolingual, and `language` has a " - "non-`None` value. Setting `language` to `None`." - ) - self.language = None - if self.task is not None: - logging.info( - "`tokenizer` is monolingual, and `task` has a " - "non-`None` value. Setting `task` to `None`." - ) - self.task = None - - if self.no_timestamps: - bos_tokens += [self.tokenizer.no_timestamps_token_id] - - # TODO: Use `MultiSegmentPacker` instead of `StartEndPacker` once we - # want to move to multi-segment packing and have improved - # `MultiSegmentPacker`'s performance. - self.decoder_packer = StartEndPacker( - start_value=bos_tokens, - end_value=self.tokenizer.eos_token_id, - pad_value=self.tokenizer.pad_token_id, - sequence_length=self.decoder_sequence_length, - return_padding_mask=True, - ) - - @tf_preprocessing_function - def call(self, x, y=None, sample_weight=None, decoder_sequence_length=None): - if not ( - isinstance(x, dict) - and ["encoder_audio", "decoder_text"] == list(x.keys()) - ): - raise ValueError( - '`x` must be a dictionary, containing the keys `"encoder_audio"`' - f' and `"decoder_text"`. Received x={x}.' - ) - - encoder_features = self.audio_feature_extractor(x["encoder_audio"]) - decoder_sequence_length = ( - decoder_sequence_length or self.decoder_sequence_length - ) - decoder_inputs = self.tokenizer(x["decoder_text"]) - decoder_token_ids, decoder_padding_mask = self.decoder_packer( - decoder_inputs, - sequence_length=decoder_sequence_length, - ) - - x = { - "encoder_features": encoder_features, - "decoder_token_ids": decoder_token_ids, - "decoder_padding_mask": decoder_padding_mask, - } - - return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) - - def get_config(self): - config = super().get_config() - config.update( - { - "audio_feature_extractor": keras.layers.serialize( - self.audio_feature_extractor - ), - "decoder_sequence_length": self.decoder_sequence_length, - "language": self.language, - "task": self.task, - "no_timestamps": self.no_timestamps, - } - ) - return config - - @classmethod - def from_config(cls, config): - if "tokenizer" in config and isinstance(config["tokenizer"], dict): - config["tokenizer"] = keras.layers.deserialize(config["tokenizer"]) - - if "audio_feature_extractor" in config and isinstance( - config["audio_feature_extractor"], dict - ): - config["audio_feature_extractor"] = keras.layers.deserialize( - config["audio_feature_extractor"] - ) - - return cls(**config) - - @property - def decoder_sequence_length(self): - """The padded length of decoder input sequences.""" - return self._decoder_sequence_length - - @decoder_sequence_length.setter - def decoder_sequence_length(self, value): - self._decoder_sequence_length = value - if self.decoder_packer is not None: - self.decoder_packer.sequence_length = value - - @property - def sequence_length(self): - """Alias for `decoder_sequence_length`.""" - return self.decoder_sequence_length - - @sequence_length.setter - def sequence_length(self, value): - self.decoder_sequence_length = value diff --git a/keras_nlp/src/models/whisper/whisper_preprocessor_test.py b/keras_nlp/src/models/whisper/whisper_preprocessor_test.py deleted file mode 100644 index 9721ddcad6..0000000000 --- a/keras_nlp/src/models/whisper/whisper_preprocessor_test.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( - WhisperAudioFeatureExtractor, -) -from keras_nlp.src.models.whisper.whisper_preprocessor import ( - WhisperPreprocessor, -) -from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer -from keras_nlp.src.tests.test_case import TestCase - - -class WhisperPreprocessorTest(TestCase): - def setUp(self): - self.audio_feature_extractor = WhisperAudioFeatureExtractor( - num_mels=80, - num_fft_bins=400, - stride=100, - sampling_rate=100, - max_audio_length=5, - ) - self.vocab = ["air", "Ġair", "plane", "Ġat", "port"] - self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) - self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] - self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] - self.merges += ["Ġai r", "Ġa i", "pla ne"] - self.special_tokens = { - "<|startoftranscript|>": 9, - "<|endoftext|>": 10, - "<|notimestamps|>": 11, - "<|transcribe|>": 12, - "<|translate|>": 13, - } - self.language_tokens = { - "<|en|>": 14, - "<|fr|>": 15, - } - self.tokenizer = WhisperTokenizer( - vocabulary=self.vocab, - merges=self.merges, - special_tokens=self.special_tokens, - language_tokens=self.language_tokens, - ) - self.init_kwargs = { - "audio_feature_extractor": self.audio_feature_extractor, - "tokenizer": self.tokenizer, - "decoder_sequence_length": 12, - "language": "<|en|>", - "task": "translate", - } - self.input_data = { - "encoder_audio": np.ones((2, 200)), - "decoder_text": [" airplane at airport", " airplane at"], - } - - def test_feature_extractor_basics(self): - self.run_preprocessor_test( - cls=WhisperPreprocessor, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - token_id_key="decoder_token_ids", - ) - - def test_sequence_length_override(self): - input_data = { - "encoder_audio": np.ones((200,)), - "decoder_text": " airplane at airport", - } - preprocessor = WhisperPreprocessor(**self.init_kwargs) - x = preprocessor(input_data, decoder_sequence_length=6) - self.assertAllEqual(x["decoder_token_ids"], [9, 14, 13, 11, 1, 10]) diff --git a/keras_nlp/src/models/whisper/whisper_presets.py b/keras_nlp/src/models/whisper/whisper_presets.py index 6684ac793f..1a7844bbbc 100644 --- a/keras_nlp/src/models/whisper/whisper_presets.py +++ b/keras_nlp/src/models/whisper/whisper_presets.py @@ -25,7 +25,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3", }, "whisper_base_en": { "metadata": { @@ -38,7 +38,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/3", }, "whisper_small_en": { "metadata": { @@ -51,7 +51,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/3", }, "whisper_medium_en": { "metadata": { @@ -64,7 +64,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/3", }, "whisper_tiny_multi": { "metadata": { @@ -77,7 +77,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/3", }, "whisper_base_multi": { "metadata": { @@ -90,7 +90,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/3", }, "whisper_small_multi": { "metadata": { @@ -103,7 +103,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/3", }, "whisper_medium_multi": { "metadata": { @@ -116,7 +116,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/3", }, "whisper_large_multi": { "metadata": { @@ -129,7 +129,7 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/3", }, "whisper_large_multi_v2": { "metadata": { @@ -143,6 +143,6 @@ "path": "whisper", "model_card": "https://github.com/openai/whisper/blob/main/model-card.md", }, - "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/2", + "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/3", }, } diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py index 8feb7d674f..d1f028aac4 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py @@ -25,7 +25,7 @@ from keras_nlp.src.models.xlm_roberta.xlm_roberta_tokenizer import ( XLMRobertaTokenizer, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLMPreprocessor") @@ -140,7 +140,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py index 756b935dd0..52984e9ad7 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_text_classifier_preprocessor.py @@ -28,7 +28,7 @@ from keras_nlp.src.models.xlm_roberta.xlm_roberta_tokenizer import ( XLMRobertaTokenizer, ) -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -169,7 +169,7 @@ def build(self, input_shape): ) self.built = True - @tf_preprocessing_function + @preprocessing_function def call(self, x, y=None, sample_weight=None): output = super().call(x, y=y, sample_weight=sample_weight) x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(output) diff --git a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py index 26096f5ec4..5eecf4cbfc 100644 --- a/keras_nlp/src/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/src/tokenizers/byte_pair_tokenizer.py @@ -31,7 +31,7 @@ from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype from keras_nlp.src.utils.tensor_utils import is_string_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -534,7 +534,7 @@ def _check_vocabulary(self): "layer." ) - @tf_preprocessing_function + @preprocessing_function def tokenize(self, inputs): self._check_vocabulary() if self.add_prefix_space: @@ -598,7 +598,7 @@ def process_unseen_tokens(): return tokens - @tf_preprocessing_function + @preprocessing_function def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/tokenizers/byte_tokenizer.py b/keras_nlp/src/tokenizers/byte_tokenizer.py index 594b2c2ffc..6b70df35ce 100644 --- a/keras_nlp/src/tokenizers/byte_tokenizer.py +++ b/keras_nlp/src/tokenizers/byte_tokenizer.py @@ -26,7 +26,7 @@ from keras_nlp.src.tokenizers import tokenizer from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow_text as tf_text @@ -212,7 +212,7 @@ def get_vocabulary(self): vocab[chr(i)] = i return vocab - @tf_preprocessing_function + @preprocessing_function def tokenize(self, inputs): unbatched = inputs.shape.rank == 0 if unbatched: @@ -243,7 +243,7 @@ def tokenize(self, inputs): tokens = tf.squeeze(tokens, 0) return tokens - @tf_preprocessing_function + @preprocessing_function def detokenize(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) # Remove trailing padding tokens, so that trailing "\x00" bytes don't diff --git a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py index 37d609128b..ea9f4d6f4a 100644 --- a/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/src/tokenizers/sentence_piece_tokenizer.py @@ -31,8 +31,8 @@ from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype from keras_nlp.src.utils.tensor_utils import is_string_dtype +from keras_nlp.src.utils.tensor_utils import preprocessing_function from keras_nlp.src.utils.tensor_utils import tensor_to_list -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function try: import tensorflow_text as tf_text @@ -235,7 +235,7 @@ def _check_vocabulary(self): "sure to pass a `proto` argument when creating the layer." ) - @tf_preprocessing_function + @preprocessing_function def tokenize(self, inputs): self._check_vocabulary() unbatched = inputs.shape.rank == 0 @@ -262,7 +262,7 @@ def tokenize(self, inputs): tf.ensure_shape(tokens, shape=[self.sequence_length]) return tokens - @tf_preprocessing_function + @preprocessing_function def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/tokenizers/tokenizer.py b/keras_nlp/src/tokenizers/tokenizer.py index f1304c8413..2f8a5c9549 100644 --- a/keras_nlp/src/tokenizers/tokenizer.py +++ b/keras_nlp/src/tokenizers/tokenizer.py @@ -27,7 +27,7 @@ from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.preset_utils import save_tokenizer_assets from keras_nlp.src.utils.python_utils import classproperty -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function @keras_nlp_export( @@ -201,7 +201,7 @@ def save_to_preset(self, preset_dir): ) save_tokenizer_assets(self, preset_dir) - @tf_preprocessing_function + @preprocessing_function def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) @@ -231,7 +231,7 @@ def from_preset( """Instantiate a `keras_nlp.models.Tokenizer` from a model preset. A preset is a directory of configs, weights and other file assets used - to save and load a pre-trained model. The `preset` can be passed as a + to save and load a pre-trained model. The `preset` can be passed as one of: 1. a built-in preset identifier like `'bert_base_en'` diff --git a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py index 16115fa199..30bf31afc1 100644 --- a/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py @@ -17,7 +17,7 @@ from keras_nlp.src.tokenizers import tokenizer from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -284,7 +284,7 @@ def get_vocabulary(self): vocab[chr(i)] = i return vocab - @tf_preprocessing_function + @preprocessing_function def tokenize(self, inputs): unbatched = inputs.shape.rank == 0 if unbatched: @@ -321,7 +321,7 @@ def tokenize(self, inputs): return tokens - @tf_preprocessing_function + @preprocessing_function def detokenize(self, inputs): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0)) diff --git a/keras_nlp/src/tokenizers/word_piece_tokenizer.py b/keras_nlp/src/tokenizers/word_piece_tokenizer.py index 8336ffe83b..f228afae2a 100644 --- a/keras_nlp/src/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/src/tokenizers/word_piece_tokenizer.py @@ -23,7 +23,7 @@ from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.src.utils.tensor_utils import is_int_dtype from keras_nlp.src.utils.tensor_utils import is_string_dtype -from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function +from keras_nlp.src.utils.tensor_utils import preprocessing_function try: import tensorflow as tf @@ -470,7 +470,7 @@ def _check_vocabulary(self): "to pass a `vocabulary` argument when creating the layer." ) - @tf_preprocessing_function + @preprocessing_function def tokenize(self, inputs): self._check_vocabulary() unbatched = inputs.shape.rank == 0 @@ -515,7 +515,7 @@ def tokenize(self, inputs): return tokens - @tf_preprocessing_function + @preprocessing_function def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 146f64ce12..21feaec0ae 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -60,6 +60,8 @@ # Config file names. CONFIG_FILE = "config.json" TOKENIZER_CONFIG_FILE = "tokenizer.json" +AUDIO_CONVERTER_CONFIG_FILE = "audio_converter.json" +IMAGE_CONVERTER_CONFIG_FILE = "image_converter.json" TASK_CONFIG_FILE = "task.json" PREPROCESSOR_CONFIG_FILE = "preprocessor.json" METADATA_FILE = "metadata.json" @@ -656,6 +658,14 @@ def load_tokenizer(self, cls, **kwargs): """Load a tokenizer layer from the preset.""" raise NotImplementedError + def load_audio_converter(self, cls, **kwargs): + """Load an audio converter layer from the preset.""" + raise NotImplementedError + + def load_image_converter(self, cls, **kwargs): + """Load an image converter layer from the preset.""" + raise NotImplementedError + def load_task(self, cls, load_weights, load_task_extras, **kwargs): """Load a task model from the preset. @@ -683,8 +693,16 @@ def load_preprocessor(self, cls, load_task_extras, **kwargs): arguments. This allow us to support transformers checkpoints by only converting the backbone and tokenizer. """ - if "tokenizer" not in kwargs: + if "tokenizer" not in kwargs and cls.tokenizer_cls: kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls) + if "audio_converter" not in kwargs and cls.audio_converter_cls: + kwargs["audio_converter"] = self.load_audio_converter( + cls.audio_converter_cls + ) + if "image_converter" not in kwargs and cls.image_converter_cls: + kwargs["image_converter"] = self.load_image_converter( + cls.image_converter_cls + ) return cls(**kwargs) @@ -705,6 +723,14 @@ def load_tokenizer(self, cls, **kwargs): tokenizer.load_preset_assets(self.preset) return tokenizer + def load_audio_converter(self, cls, **kwargs): + converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE) + return load_serialized_object(converter_config, **kwargs) + + def load_image_converter(self, cls, **kwargs): + converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE) + return load_serialized_object(converter_config, **kwargs) + def load_task(self, cls, load_weights, load_task_extras, **kwargs): # If there is no `task.json` or it's for the wrong class delegate to the # super class loader. diff --git a/keras_nlp/src/utils/tensor_utils.py b/keras_nlp/src/utils/tensor_utils.py index 4806400f2f..94594202ab 100644 --- a/keras_nlp/src/utils/tensor_utils.py +++ b/keras_nlp/src/utils/tensor_utils.py @@ -18,10 +18,9 @@ import threading import keras +import numpy as np from keras import ops -from keras_nlp.src.api_export import keras_nlp_export - try: import tensorflow as tf import tensorflow_text as tf_text @@ -47,7 +46,7 @@ def in_no_convert_scope(): return NO_CONVERT_COUNTER.count > 0 -def tf_preprocessing_function(fn): +def preprocessing_function(fn): """Wraps a preprocessing function to handle tf tensor conversion.""" if tf is None: return fn @@ -59,33 +58,33 @@ def tf_preprocessing_function(fn): @functools.wraps(fn) def wrapper(self, x, **kwargs): - x = convert_to_tf(x) + x = convert_preprocessing_inputs(x) with no_convert_scope(): x = fn(self, x, **kwargs) - return convert_from_tf(x) + return convert_preprocessing_outputs(x) else: @functools.wraps(fn) def wrapper(self, x, y=None, sample_weight=None, **kwargs): - x, y, sample_weight = convert_to_tf((x, y, sample_weight)) + x, y, sample_weight = convert_preprocessing_inputs( + (x, y, sample_weight) + ) with no_convert_scope(): x = fn(self, x, y=y, sample_weight=sample_weight, **kwargs) - return convert_from_tf(x) + return convert_preprocessing_outputs(x) return wrapper -@keras_nlp_export("keras_nlp.utils.convert_to_tf") -def convert_to_tf(x): - """Convert raw inputs to tf inputs for preprocessing. +def convert_preprocessing_inputs(x): + """Convert raw inputs for preprocessing. This function is used to convert raw inputs (strings, lists, `np.ndarray`s, - `jax.Array`s, `torch.Tensor`s) to tensorflow inputs for use with `tf.data` - and KerasNLP preprocessing layers. It will convert ragged inputs and string - inputs `tf.RaggedTensor` and `tf.Tensor` types. This will automatically be - called when running preprocessing layers or `keras_nlp.models.Task`s with - preprocessing included. + `jax.Array`s, `torch.Tensor`s, etc) to a canonical format for + preprocessing layers. All inputs will be converted to backend tensors if + possible, except ragged inputs and string inputs which be converted to tf + tensors regardless of backend. `tuple` and `list` elements are handled differently by this function. A `tuple` is assumed to enumerate separate inputs, and a `list` is assumed to @@ -97,11 +96,11 @@ def convert_to_tf(x): ```python # Two ragged arrays of token ids. x = ([[1, 2, 3], [4, 5]], [[1, 2], [3, 4, 5]]) - keras_nlp.utils.convert_to_tf(x) + keras_nlp.utils.convert_preprocessing_inputs(x) # A batch of three samples each with two string segments. x = (["hi", "hello", "hey"], ["bye", "later", "so long"]) - keras_nlp.utils.convert_to_tf(x) + keras_nlp.utils.convert_preprocessing_inputs(x) # A batch of features in a dictionary. x = { @@ -109,31 +108,57 @@ def convert_to_tf(x): "images": np.ones((3, 64, 64, 3)), "labels": [1, 0, 1], } - keras_nlp.utils.convert_to_tf(x) + keras_nlp.utils.convert_preprocessing_inputs(x) ``` """ + if not tf.executing_eagerly() or in_no_convert_scope(): + return x + if isinstance(x, dict): - return {k: convert_to_tf(x[k]) for k, v in x.items()} + return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()} if isinstance(x, tuple): - return tuple(convert_to_tf(v) for v in x) - if isinstance(x, list): - return tf.ragged.constant(x) + return tuple(convert_preprocessing_inputs(v) for v in x) if isinstance(x, str): return tf.constant(x) + if isinstance(x, list): + try: + numpy_x = np.array(x) + except ValueError as e: + # If numpy conversion failed, try converting to a ragged array. + try: + return tf.ragged.constant(x) + except ValueError: + # If ragged conversion failed return to the numpy error. + raise e + # If we have a string input, use tf.tensor. + if numpy_x.dtype.type is np.str_: + return tf.convert_to_tensor(x) + # Numpy will default to int64, int32 works with more ops. + if numpy_x.dtype == np.int64: + numpy_x = numpy_x.astype(np.int32) + # We have non-ragged, non-string input. Use backbend type. + x = ops.convert_to_tensor(numpy_x) + # Torch will complain about device placement for GPU tensors. + if keras.config.backend() == "torch": + x = x.cpu() + return x if is_tensor_type(x): + # String or ragged types we keep as tf. + if isinstance(x, tf.RaggedTensor) or x.dtype == tf.string: + return x + # If we have a string input, use tf.tensor. + if isinstance(x, np.ndarray) and x.dtype.type is np.str_: + return tf.convert_to_tensor(x) + x = ops.convert_to_tensor(x) # Torch will complain about device placement for GPU tensors. if keras.config.backend() == "torch": - import torch - - if isinstance(x, torch.Tensor): - x = x.cpu() - return tf.convert_to_tensor(x) + x = x.cpu() + return x return x -@keras_nlp_export("keras_nlp.utils.convert_from_tf") -def convert_from_tf(x): - """Convert tf outputs after preprocessing to a backend agnostic format. +def convert_preprocessing_outputs(x): + """Convert outputs after preprocessing to a backend agnostic format. This function is used to convert `tf.Tensor` and `tf.RaggedTensor` output from preprocessing layers to either: @@ -150,11 +175,11 @@ def convert_from_tf(x): ```python # Two ragged arrays of token ids. x = tf.ragged.constant([[1, 2, 3], [4, 5]]) - keras_nlp.utils.convert_from_tf(x) + keras_nlp.utils.convert_preprocessing_outputs(x) # A batch of three samples each with two string segments. x = (tf.constant["hi", "yo", "hey"]), tf.constant(["bye", "ciao", ""])) - keras_nlp.utils.convert_from_tf(x) + keras_nlp.utils.convert_preprocessing_outputs(x) # A batch of features in a dictionary. x = { @@ -162,7 +187,7 @@ def convert_from_tf(x): "images": tf.ones((3, 64, 64, 3)), "labels": tf.constant([1, 0, 1]), } - keras_nlp.utils.convert_from_tf(x) + keras_nlp.utils.convert_preprocessing_outputs(x) ``` """ if not tf.executing_eagerly() or in_no_convert_scope(): @@ -210,6 +235,8 @@ def tensor_to_list(inputs): def convert_to_ragged_batch(inputs): """Ensure a tf.Tensor is a ragged rank 2 tensor.""" + if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): + inputs = tf.convert_to_tensor(inputs) unbatched = inputs.shape.rank == 1 rectangular = isinstance(inputs, tf.Tensor) if unbatched: diff --git a/keras_nlp/src/utils/tensor_utils_test.py b/keras_nlp/src/utils/tensor_utils_test.py index 478f05c9ad..c0f34595c2 100644 --- a/keras_nlp/src/utils/tensor_utils_test.py +++ b/keras_nlp/src/utils/tensor_utils_test.py @@ -19,45 +19,44 @@ from keras_nlp.src.tests.test_case import TestCase from keras_nlp.src.utils.tensor_utils import any_equal -from keras_nlp.src.utils.tensor_utils import convert_from_tf +from keras_nlp.src.utils.tensor_utils import convert_preprocessing_inputs +from keras_nlp.src.utils.tensor_utils import convert_preprocessing_outputs from keras_nlp.src.utils.tensor_utils import convert_to_ragged_batch -from keras_nlp.src.utils.tensor_utils import convert_to_tf from keras_nlp.src.utils.tensor_utils import is_tensor_type from keras_nlp.src.utils.tensor_utils import tensor_to_list -class ConvertTf(TestCase): +class ConvertHelpers(TestCase): def test_basics(self): inputs = ops.array([1, 2, 3]) # Convert to tf. - outputs = convert_to_tf(inputs) - self.assertIsInstance(outputs, tf.Tensor) - self.assertAllEqual(outputs, tf.constant(ops.convert_to_numpy(inputs))) + outputs = convert_preprocessing_inputs(inputs) + self.assertAllEqual(outputs, ops.array(inputs)) # Convert from tf. - outputs = convert_from_tf(outputs) + outputs = convert_preprocessing_outputs(outputs) self.assertTrue(is_tensor_type(outputs)) self.assertAllEqual(outputs, inputs) def test_strings(self): inputs = ["one", "two"] # Convert to tf. - outputs = convert_to_tf(inputs) + outputs = convert_preprocessing_inputs(inputs) self.assertIsInstance(outputs, tf.Tensor) self.assertAllEqual(outputs, tf.constant(inputs)) # Convert from tf. - outputs = convert_from_tf(outputs) + outputs = convert_preprocessing_outputs(outputs) self.assertIsInstance(outputs, list) self.assertEqual(outputs, inputs) def test_ragged(self): inputs = [np.ones((1, 3)), np.ones((1, 2))] # Convert to tf. - outputs = convert_to_tf(inputs) + outputs = convert_preprocessing_inputs(inputs) self.assertIsInstance(outputs, tf.RaggedTensor) print(outputs, inputs) self.assertAllEqual(outputs, tf.ragged.constant(inputs)) # Convert from tf. - outputs = convert_from_tf(outputs) + outputs = convert_preprocessing_outputs(outputs) self.assertIsInstance(outputs, list) self.assertEqual(outputs, [[[1, 1, 1]], [[1, 1]]]) @@ -72,14 +71,14 @@ def test_composite(self): [3, 4], ) - outputs = convert_to_tf(inputs) + outputs = convert_preprocessing_inputs(inputs) self.assertIsInstance(outputs[0]["text"], tf.Tensor) self.assertIsInstance(outputs[0]["images"], tf.RaggedTensor) self.assertIsInstance(outputs[0]["ragged_ints"], tf.RaggedTensor) - self.assertIsInstance(outputs[1], tf.Tensor) - self.assertIsInstance(outputs[2], tf.Tensor) + self.assertTrue(is_tensor_type(outputs[1])) + self.assertTrue(is_tensor_type(outputs[2])) - outputs = convert_from_tf(outputs) + outputs = convert_preprocessing_outputs(outputs) self.assertIsInstance(outputs[0]["text"], list) self.assertIsInstance(outputs[0]["images"], list) self.assertIsInstance(outputs[0]["ragged_ints"], list) diff --git a/keras_nlp/src/utils/transformers/preset_loader.py b/keras_nlp/src/utils/transformers/preset_loader.py index 1a1ce928ba..7e93140979 100644 --- a/keras_nlp/src/utils/transformers/preset_loader.py +++ b/keras_nlp/src/utils/transformers/preset_loader.py @@ -71,3 +71,7 @@ def load_backbone(self, cls, load_weights, **kwargs): def load_tokenizer(self, cls, **kwargs): return self.converter.convert_tokenizer(cls, self.preset, **kwargs) + + def load_image_converter(self, cls, **kwargs): + # TODO: set image size for pali gemma checkpoints. + return None