Skip to content
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras import layers
from keras_cv import core
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
from keras_cv.utils import fill_utils
from keras_cv.utils import preprocessing
def _center_crop(mask, width, height):
    """Cut a centered `height` x `width` window out of `mask`.

    Args:
        mask: tensor whose first two axes are read as (rows, columns).
            NOTE(review): assumes an unbatched `(height, width, channels)`
            tensor, which is what `GridMask.get_random_transformation`
            produces -- confirm before reusing elsewhere.
        width: number of columns to keep.
        height: number of rows to keep.

    Returns:
        The result of `tf.image.crop_to_bounding_box` for a window whose
        top-left corner sits half of the surplus rows/columns into `mask`.
    """
    shape = tf.shape(mask)
    surplus_rows = shape[0] - height
    surplus_cols = shape[1] - width
    # Half of the surplus goes before the window; the division result is
    # cast back to int32 so it can be used as a pixel offset.
    top = tf.cast(surplus_rows / 2, tf.int32)
    left = tf.cast(surplus_cols / 2, tf.int32)
    return tf.image.crop_to_bounding_box(
        mask,
        offset_height=top,
        offset_width=left,
        target_height=height,
        target_width=width,
    )
@tf.keras.utils.register_keras_serializable(package="keras_cv")
class GridMask(BaseImageAugmentationLayer):
    """GridMask class for grid-mask augmentation.

    A regular grid of square blocks is generated, randomly shifted and
    rotated, and the image pixels covered by the blocks are replaced by a
    fill (a constant value or gaussian noise).

    Input shape:
        Int or float tensor with values in the range [0, 255].
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format

    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format

    Args:
        ratio_factor: A float, tuple of two floats, or
            `keras_cv.FactorSampler`. Ratio determines the ratio from
            spacings to grid masks. Lower values make the grid size smaller,
            and higher values make the grid mask large. Floats should be in
            the range [0, 1]. 0.5 indicates that grid and spacing will be of
            equal size. To always use the same value, pass a
            `keras_cv.ConstantFactorSampler()`. Defaults to `(0, 0.5)`.
        rotation_factor: The rotation_factor will be used to randomly rotate
            the grid_mask during training. Defaults to 0.15, which results
            in an output rotating by a random amount in the range
            [-15% * 2pi, 15% * 2pi]. A float represented as fraction of
            2 Pi, or a tuple of size 2 representing lower and upper bound
            for rotating clockwise and counter-clockwise. A positive value
            means rotating counter clock-wise, while a negative value means
            clock-wise. When represented as a single float, this value is
            used for both the upper and lower bound. For instance,
            factor=(-0.2, 0.3) results in an output rotation by a random
            amount in the range [-20% * 2pi, 30% * 2pi]. factor=0.2 results
            in an output rotating by a random amount in the range
            [-20% * 2pi, 20% * 2pi]. A `keras_cv.FactorSampler` is not
            accepted for this argument and raises a `ValueError`.
        fill_mode: Pixels inside the gridblock are filled according to the
            given mode (one of `{"constant", "gaussian_noise"}`).
            Defaults to `"constant"`.
            - *constant*: Pixels are filled with the same constant value.
            - *gaussian_noise*: Pixels are filled with random gaussian noise.
            The value `"random"` also passes validation; it takes the same
            code path as `"gaussian_noise"`.
        fill_value: a number in the range [0, 255] used to fill the
            gridblock when `fill_mode="constant"`. Defaults to `0.0`.
        seed: Integer. Used to create a random seed.

    Raises:
        ValueError: if `rotation_factor` is a `FactorSampler`, if
            `fill_value` is not a number in [0, 255], or if `fill_mode` is
            not one of the accepted strings.

    Usage:
    ```python
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    random_gridmask = keras_cv.layers.preprocessing.GridMask()
    augmented_images = random_gridmask(images)
    ```

    References:
        - [GridMask paper](https://arxiv.org/abs/2001.04086)
    """

    def __init__(
        self,
        ratio_factor=(0, 0.5),
        rotation_factor=0.15,
        fill_mode="constant",
        fill_value=0.0,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        self.ratio_factor = preprocessing.parse_factor(
            ratio_factor, param_name="ratio_factor"
        )

        if isinstance(rotation_factor, core.FactorSampler):
            raise ValueError(
                "Currently `GridMask.rotation_factor` does not support the "
                "`FactorSampler` API. This will be supported in the next Keras "
                "release. For now, please pass a float for the "
                "`rotation_factor` argument."
            )

        self.fill_mode = fill_mode
        self.fill_value = fill_value
        # Validate before building sub-layers so bad arguments fail early.
        self._check_parameter_values()

        self.rotation_factor = rotation_factor
        # The mask (not the image) is rotated; regions rotated in from
        # outside the mask are filled with 0, i.e. "not masked".
        self.random_rotate = layers.RandomRotation(
            factor=rotation_factor,
            fill_mode="constant",
            fill_value=0.0,
            seed=seed,
        )
        # NOTE(review): presumably makes the base layer process one image at
        # a time (the methods below index the shape as (height, width, ...))
        # -- confirm against BaseImageAugmentationLayer.
        self.auto_vectorize = False
        self.seed = seed

    def _check_parameter_values(self):
        """Validate `fill_value` and `fill_mode`.

        Raises:
            ValueError: if `fill_value` is not a number in [0, 255] or if
                `fill_mode` is not an accepted mode string.
        """
        fill_mode, fill_value = self.fill_mode, self.fill_value

        # A plain bounds comparison accepts fractional values such as 127.5
        # (membership in `range(0, 256)` only matches whole numbers).
        # Non-numeric values cannot be compared and are reported as invalid.
        try:
            fill_value_in_range = bool(0 <= fill_value <= 255)
        except TypeError:
            fill_value_in_range = False
        if not fill_value_in_range:
            raise ValueError(
                f"fill_value should be in the range [0, 255]. Got {fill_value}"
            )

        if fill_mode not in ("constant", "gaussian_noise", "random"):
            raise ValueError(
                '`fill_mode` should be "constant", '
                f'"gaussian_noise", or "random". Got `fill_mode`={fill_mode}'
            )

    def get_random_transformation(
        self, image=None, label=None, bounding_boxes=None, **kwargs
    ):
        """Sample the rotated grid mask and the fill tensor for one image.

        Args:
            image: the image to be augmented; only its shape is used.
            label: unused.
            bounding_boxes: unused.

        Returns:
            A `(mask, fill_value)` tuple. `mask` is a float32 single-channel
            square mask (side = image diagonal) after random rotation;
            `fill_value` has the shape of `image` and dtype
            `self.compute_dtype`.
        """
        ratio = self.ratio_factor()

        # compute grid mask
        input_shape = tf.shape(image)
        mask = self._compute_grid_mask(input_shape, ratio=ratio)

        # convert mask to single-channel image so it can be rotated
        mask = tf.cast(mask, tf.float32)
        mask = tf.expand_dims(mask, axis=-1)

        # randomly rotate mask
        mask = self.random_rotate(mask)

        # compute fill
        if self.fill_mode == "constant":
            fill_value = tf.fill(input_shape, self.fill_value)
            fill_value = tf.cast(fill_value, dtype=self.compute_dtype)
        else:
            # gaussian noise (also reached for fill_mode="random")
            fill_value = self._random_generator.random_normal(
                shape=input_shape, dtype=self.compute_dtype
            )

        return mask, fill_value

    def _compute_grid_mask(self, input_shape, ratio):
        """Build an unrotated boolean grid mask.

        Args:
            input_shape: shape of the image; index 0 is read as height and
                index 1 as width.
            ratio: fraction of each grid unit occupied by the masked square.

        Returns:
            A boolean tensor of shape `(side, side)` where `side` is the
            ceiling of the image diagonal; `True` marks masked pixels.
        """
        height = tf.cast(input_shape[0], tf.float32)
        width = tf.cast(input_shape[1], tf.float32)

        # The mask side equals the image diagonal, so the mask still covers
        # the whole image after any rotation followed by a center crop.
        input_diagonal_len = tf.sqrt(tf.square(width) + tf.square(height))
        mask_side_len = tf.math.ceil(input_diagonal_len)

        # grid unit size (one masked square plus its spacing)
        unit_size = self._random_generator.random_uniform(
            shape=(),
            minval=tf.math.minimum(height * 0.5, width * 0.3),
            maxval=tf.math.maximum(height * 0.5, width * 0.3) + 1,
            dtype=tf.float32,
        )
        rectangle_side_len = tf.cast((ratio) * unit_size, tf.float32)

        # sample x and y offset for grid units randomly between 0 and
        # unit_size
        delta_x = self._random_generator.random_uniform(
            shape=(), minval=0.0, maxval=unit_size, dtype=tf.float32
        )
        delta_y = self._random_generator.random_uniform(
            shape=(), minval=0.0, maxval=unit_size, dtype=tf.float32
        )

        # grid size (number of diagonal units in grid)
        grid_size = mask_side_len // unit_size + 1
        grid_size_range = tf.range(1, grid_size + 1)

        # diagonal corner coordinates
        unit_size_range = grid_size_range * unit_size
        x1 = unit_size_range - delta_x
        x0 = x1 - rectangle_side_len
        y1 = unit_size_range - delta_y
        y0 = y1 - rectangle_side_len

        # compute grid coordinates
        x0, y0 = tf.meshgrid(x0, y0)
        x1, y1 = tf.meshgrid(x1, y1)

        # flatten mesh grid
        x0 = tf.reshape(x0, [-1])
        y0 = tf.reshape(y0, [-1])
        x1 = tf.reshape(x1, [-1])
        y1 = tf.reshape(y1, [-1])

        # convert coordinates to mask
        corners = tf.stack([x0, y0, x1, y1], axis=-1)
        mask_side_len = tf.cast(mask_side_len, tf.int32)
        # NOTE(review): presumably yields one boolean mask per rectangle
        # along axis 0 -- confirm against fill_utils.corners_to_mask.
        rectangle_masks = fill_utils.corners_to_mask(
            corners, mask_shape=(mask_side_len, mask_side_len)
        )
        # a pixel is masked if it falls inside any rectangle
        grid_mask = tf.reduce_any(rectangle_masks, axis=0)

        return grid_mask

    def augment_image(self, image, transformation=None, **kwargs):
        """Replace the masked pixels of `image` with the sampled fill.

        Args:
            image: the image to augment.
            transformation: the `(mask, fill_value)` tuple returned by
                `get_random_transformation`.

        Returns:
            A tensor equal to `fill_value` where the cropped mask is
            non-zero and to `image` elsewhere.
        """
        mask, fill_value = transformation
        input_shape = tf.shape(image)

        # center crop mask down to the image size
        input_height = input_shape[0]
        input_width = input_shape[1]
        mask = _center_crop(mask, input_width, input_height)

        # convert back to boolean mask
        mask = tf.cast(mask, tf.bool)

        return tf.where(mask, fill_value, image)

    def augment_bounding_boxes(self, bounding_boxes, **kwargs):
        """Return `bounding_boxes` unchanged."""
        return bounding_boxes

    def augment_label(self, label, transformation=None, **kwargs):
        """Return `label` unchanged."""
        return label

    def augment_segmentation_mask(
        self, segmentation_mask, transformation=None, **kwargs
    ):
        """Return `segmentation_mask` unchanged."""
        return segmentation_mask

    def get_config(self):
        """Return the base layer config extended with this layer's arguments."""
        config = {
            "ratio_factor": self.ratio_factor,
            "rotation_factor": self.rotation_factor,
            "fill_mode": self.fill_mode,
            "fill_value": self.fill_value,
            "seed": self.seed,
        }
        base_config = super().get_config()
        # Entries defined here take precedence over base-class entries.
        return {**base_config, **config}