Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-tensorflow backends in KerasCV's preprocessing layers #2240

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions keras_cv/layers/preprocessing/aug_mix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def test_return_shapes(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_in_single_image_and_mask(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -54,8 +54,8 @@ def test_in_single_image_and_mask(self):
)

ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 3])
self.assertEqual(xs.shape, (512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 3))

# greyscale
xs = tf.cast(
Expand All @@ -69,8 +69,8 @@ def test_in_single_image_and_mask(self):
dtype=tf.float32,
)
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [512, 512, 1])
self.assertEqual(xs.shape, (512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (512, 512, 1))

def test_non_square_images_and_masks(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -80,16 +80,16 @@ def test_non_square_images_and_masks(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 3])
self.assertEqual(xs.shape, (2, 256, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 3))

# greyscale
xs = tf.ones((2, 256, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 256, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 256, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 256, 512, 1])
self.assertEqual(xs.shape, (2, 256, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 256, 512, 1))

def test_single_input_args(self):
layer = preprocessing.AugMix([0, 255])
Expand All @@ -99,16 +99,16 @@ def test_single_input_args(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))

def test_many_augmentations(self):
layer = preprocessing.AugMix([0, 255], chain_depth=[25, 26])
Expand All @@ -118,13 +118,13 @@ def test_many_augmentations(self):
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 3))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 3])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 3])
self.assertEqual(xs.shape, (2, 512, 512, 3))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 3))

# greyscale
xs = tf.ones((2, 512, 512, 1))
xs = layer(xs)
ys_segmentation_masks = tf.ones((2, 512, 512, 1))
ys_segmentation_masks = layer(ys_segmentation_masks)
self.assertEqual(xs.shape, [2, 512, 512, 1])
self.assertEqual(ys_segmentation_masks.shape, [2, 512, 512, 1])
self.assertEqual(xs.shape, (2, 512, 512, 1))
self.assertEqual(ys_segmentation_masks.shape, (2, 512, 512, 1))
66 changes: 33 additions & 33 deletions keras_cv/layers/preprocessing/auto_contrast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import tensorflow as tf

from keras_cv.backend import ops
from keras_cv.layers import preprocessing
from keras_cv.tests.test_case import TestCase


class AutoContrastTest(TestCase):
def test_constant_channels_dont_get_nanned(self):
img = tf.constant([1, 1], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([1, 1], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))

def test_auto_contrast_expands_value_range(self):
img = tf.constant([0, 128], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_different_values_per_channel(self):
img = tf.constant(
img = np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=tf.float32,
dtype=np.float32,
)
img = tf.expand_dims(img, axis=0)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 0.0))

self.assertTrue(tf.math.reduce_any(ys[0, ..., 0] == 255.0))
self.assertTrue(tf.math.reduce_any(ys[0, ..., 1] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 255.0))

self.assertAllClose(
ys,
Expand All @@ -71,25 +71,25 @@ def test_auto_contrast_different_values_per_channel(self):
)

def test_auto_contrast_expands_value_range_uint8(self):
img = tf.constant([0, 128], dtype=tf.uint8)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 128], dtype=np.uint8)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 255))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 255.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))

def test_auto_contrast_properly_converts_value_range(self):
img = tf.constant([0, 0.5], dtype=tf.float32)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=-1)
img = tf.expand_dims(img, axis=0)
img = np.array([0, 0.5], dtype=np.float32)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=-1)
img = np.expand_dims(img, axis=0)

layer = preprocessing.AutoContrast(value_range=(0, 1))
ys = layer(img)

self.assertTrue(tf.math.reduce_any(ys[0] == 0.0))
self.assertTrue(tf.math.reduce_any(ys[0] == 1.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))
self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))
40 changes: 30 additions & 10 deletions keras_cv/layers/preprocessing/base_image_augmentation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import keras
import tensorflow as tf
import tree

if hasattr(keras, "src"):
keras_backend = keras.src.backend
Expand All @@ -23,8 +24,8 @@
from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import scope
from keras_cv.backend.config import keras_3
from keras_cv.utils import preprocessing

# In order to support both unbatched and batched inputs, the horizontal
Expand All @@ -42,15 +43,8 @@
USE_TARGETS = "use_targets"


base_class = (
keras.src.layers.preprocessing.tf_data_layer.TFDataLayer
if keras_3()
else keras.layers.Layer
)


@keras_cv_export("keras_cv.layers.BaseImageAugmentationLayer")
class BaseImageAugmentationLayer(base_class):
class BaseImageAugmentationLayer(keras.layers.Layer):
"""Abstract base layer for image augmentation.

This layer contains base functionalities for preprocessing layers which
Expand Down Expand Up @@ -415,6 +409,19 @@ def get_random_transformation(
return None

def call(self, inputs):
# try to convert a given backend native tensor to TensorFlow tensor
# before passing it over to TFDataScope
contains_ragged = lambda y: any(
tree.map_structure(
lambda x: isinstance(x, (tf.RaggedTensor, tf.SparseTensor)),
tree.flatten(y),
)
)
inputs_contain_ragged = contains_ragged(inputs)
if not inputs_contain_ragged:
inputs = tree.map_structure(
lambda x: tf.convert_to_tensor(x), inputs
)
with scope.TFDataScope():
inputs = self._ensure_inputs_are_compute_dtype(inputs)
inputs, metadata = self._format_inputs(inputs)
Expand All @@ -431,7 +438,20 @@ def call(self, inputs):
"rank 3 (HWC) or 4D (NHWC) tensors. Got shape: "
f"{images.shape}"
)
return outputs
# convert the outputs to backend native tensors if none of them
# contain RaggedTensors. Note that if the user passed in Raggeds
# but the outputs are dense, we still don't want to convert to
# backend native tensors. This is to avoid breaking TF data
# pipelines that can't easily be ported to become backend
# agnostic.
if not inputs_contain_ragged and not contains_ragged(outputs):
outputs = tree.map_structure(
# some layers return None, handle that case when
# converting to tensors
lambda x: ops.convert_to_tensor(x) if x is not None else x,
outputs,
)
return outputs

def _augment(self, inputs):
raw_image = inputs.get(IMAGES, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import tensorflow as tf

from keras_cv import bounding_box
from keras_cv.backend import ops
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -78,17 +80,17 @@ def test_augment_dict_return_type(self):

def test_augment_casts_dtypes(self):
add_layer = RandomAddLayer(fixed_value=2.0)
images = tf.ones((2, 8, 8, 3), dtype="uint8")
images = np.ones((2, 8, 8, 3), dtype="uint8")
output = add_layer(images)

self.assertAllClose(
tf.ones((2, 8, 8, 3), dtype="float32") * 3.0, output
np.ones((2, 8, 8, 3), dtype="float32") * 3.0, output
)

def test_augment_batch_images(self):
add_layer = RandomAddLayer()
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
output = add_layer(images)
output = ops.convert_to_numpy(add_layer(images))

diff = output - images
# Make sure the first image and second image get different augmentation
Expand Down Expand Up @@ -118,8 +120,8 @@ def test_augment_batch_images_and_targets(self):
targets = np.random.random(size=(2, 1)).astype("float32")
output = add_layer({"images": images, "targets": targets})

image_diff = output["images"] - images
label_diff = output["targets"] - targets
image_diff = ops.convert_to_numpy(output["images"]) - images
label_diff = ops.convert_to_numpy(output["targets"]) - targets
# Make sure the first image and second image get different augmentation
self.assertNotAllClose(image_diff[0], image_diff[1])
self.assertNotAllClose(label_diff[0], label_diff[1])
Expand Down Expand Up @@ -225,6 +227,7 @@ def test_augment_batch_image_and_localization_data(self):
segmentation_mask_diff[0], segmentation_mask_diff[1]
)

@pytest.mark.tf_only
def test_augment_all_data_in_tf_function(self):
add_layer = RandomAddLayer()
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
Expand Down
Loading
Loading