From 1d520a6554acac07c9424f9000b211cf473721aa Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 22 Feb 2022 15:29:11 +0100 Subject: [PATCH 1/3] Vectorized function on the batch dimension --- keras_cv/layers/preprocessing/grid_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index 44c1cac4d7..47597bab87 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -213,7 +213,7 @@ def _augment_images(self, images): images = tf.expand_dims(images, axis=0) # TODO: Make the batch operation vectorize. - output = tf.map_fn(lambda image: self._grid_mask(image), images) + output = tf.vectorized_map(lambda image: self._grid_mask(image), images, False) if unbatched: output = tf.squeeze(output, axis=0) From 50c7d07c761ae64ad2f90ea520b5033f7ecc9f9b Mon Sep 17 00:00:00 2001 From: bhack Date: Wed, 23 Feb 2022 20:32:27 +0100 Subject: [PATCH 2/3] Update grid_mask.py --- keras_cv/layers/preprocessing/grid_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index 47597bab87..99bb9f893d 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -213,7 +213,7 @@ def _augment_images(self, images): images = tf.expand_dims(images, axis=0) # TODO: Make the batch operation vectorize. - output = tf.vectorized_map(lambda image: self._grid_mask(image), images, False) + output = tf.vectorized_map(lambda image: self._grid_mask(image), images, True) if unbatched: output = tf.squeeze(output, axis=0) From 59449dbafe72b5b2fb1267bb255291154e351592 Mon Sep 17 00:00:00 2001 From: bhack Date: Thu, 24 Feb 2022 00:52:47 +0100 Subject: [PATCH 3/3] add explict arg keyword --- keras_cv/layers/preprocessing/grid_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index 99bb9f893d..c5d69bf953 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -213,7 +213,7 @@ def _augment_images(self, images): images = tf.expand_dims(images, axis=0) # TODO: Make the batch operation vectorize. - output = tf.vectorized_map(lambda image: self._grid_mask(image), images, True) + output = tf.vectorized_map(lambda image: self._grid_mask(image), images, fallback_to_while_loop=True) if unbatched: output = tf.squeeze(output, axis=0)