diff --git a/keras_cv/layers/preprocessing/grid_mask.py b/keras_cv/layers/preprocessing/grid_mask.py index 44c1cac4d7..c5d69bf953 100644 --- a/keras_cv/layers/preprocessing/grid_mask.py +++ b/keras_cv/layers/preprocessing/grid_mask.py @@ -213,7 +213,7 @@ def _augment_images(self, images): images = tf.expand_dims(images, axis=0) # TODO: Make the batch operation vectorize. - output = tf.map_fn(lambda image: self._grid_mask(image), images) + output = tf.vectorized_map(lambda image: self._grid_mask(image), images, fallback_to_while_loop=True) if unbatched: output = tf.squeeze(output, axis=0)