In [None]:
# Dependencies.
import parameterized_transforms.core as ptc

import numpy as np

import typing as t

In [None]:
class RandomColorErasing(ptc.AtomicTransform):
    """Erases a random axis-aligned rectangle of an image with a random color.

    The transform is described by 7 integer parameters
    `(i, j, h, w, r, g, b)`:
        * `(i, j)` is the top-left corner of the erased rectangle
          (row index, column index),
        * `(h, w)` is its height and width,
        * `(r, g, b)` is the fill color, each component in `[0, 255]`.

    The image is converted with `np.asarray` and indexed as
    `[row, column, channel]`, so inputs are expected to be
    height x width x 3 images (e.g. RGB `PIL` images).
    """

    def __init__(
        self,
        tx_mode: ptc.TransformMode = ptc.TransformMode.CASCADE,
        default_params_mode: ptc.DefaultParamsMode = ptc.DefaultParamsMode.RANDOMIZED,
    ) -> None:
        """The initializer.

        :param tx_mode: The mode of the transform.
            DEFAULT: `ptc.TransformMode.CASCADE`.
            SUPPORTS: `ptc.TransformMode` enum type.
        :param default_params_mode: The mode of the default (identity) params.
            DEFAULT: `ptc.DefaultParamsMode.RANDOMIZED`.
            SUPPORTS: `ptc.DefaultParamsMode` enum type.
        """

        super().__init__()

        self.tx_mode = tx_mode  # Alternatively, pass tx_mode=tx_mode to the above super call.
        self.default_params_mode = default_params_mode

        self.param_count = self.set_param_count()  # Recommended.

    @staticmethod
    def _get_image_hw(img: ptc.IMAGE_TYPE) -> t.Tuple[int, int]:
        """Returns the `(height, width)` of the given image.

        Uses the same `np.asarray` view of the image that `apply_transform`
        modifies, so sampled rectangles are expressed in exactly the
        row/column coordinates that get erased.

        :param img: The input image.

        :return: The tuple `(height, width)` as Python integers.
        """
        height, width = np.asarray(img).shape[:2]
        return int(height), int(width)

    def set_param_count(self) -> int:
        """Returns the number of parameters for the transform.

        :return: The number of parameters.

        We encode the rectangle to be erased using 4 integers `i, j, h, w`.
        The tuple `(i, j)` encodes the top-left point of the rectangle
            (along the height and the width respectively).
        The integers `h, w` are the rectangle height and width.
        We encode the fill-color in terms of 3 integers in range `[0, 255]`
            representing the `r, g, b` components.

        Thus, in total, there are 7 parameters.
        """
        return 7

    def get_raw_params(self, img: ptc.IMAGE_TYPE) -> ptc.PARAM_TYPE:
        """Get raw parameters to decide which rectangle to erase and what
        should be the fill value.

        :param img: The input image.

        :return: The tuple `(i, j, h, w, fill_color)` where `fill_color` is a
            `np.uint8` array of shape `(3,)`.
        """

        # Get a randomly sampled fill color (one value per r, g, b channel).
        fill_color = np.random.randint(
            low=0, high=256, size=[3, ]
        ).astype(np.uint8)

        # Get the image dimensions.
        H, W = self._get_image_hw(img)
        # Select two row indices and two column indices; the rectangle
        # spans the region between them.
        i1, i2 = np.random.randint(low=0, high=H, size=[2, ]).tolist()
        j1, j2 = np.random.randint(low=0, high=W, size=[2, ]).tolist()
        # The smaller index is the top-left corner and the gap is the size;
        # sizes are clamped to at least 1 so the rectangle is never empty.
        i, h = min(i1, i2), max(abs(i2 - i1), 1)
        j, w = min(j1, j2), max(abs(j2 - j1), 1)

        # Return the tuple of raw parameters.
        return (i, j, h, w, fill_color)

    def apply_transform(
            self, img: ptc.IMAGE_TYPE, params: ptc.PARAM_TYPE, **kwargs
    ) -> ptc.IMAGE_TYPE:
        """Augments given data point using given parameters.

        :param img: The data point to be augmented.
        :param params: The raw parameters `(i, j, h, w, fill_color)` to be
            used for augmentation.

        :return: The augmented image as a `PIL.Image`.
        """

        # Unpack the raw parameters.
        i, j, h, w, fill_color = params

        # Convert the image to a `numpy` array; `astype` returns a copy, so
        # the array is writable and the input image is left untouched.
        img_np = np.asarray(img).astype(np.uint8)
        # Replace the rectangular region with the sampled color, broadcast
        # over all pixels of the rectangle.
        img_np[i: (i + h), j: (j + w), :] = fill_color

        # Convert this processed image back to a `PIL.Image`.
        # NOTE(review): `Image` (from `PIL`) is not imported in this
        # notebook's dependencies cell; add `from PIL import Image` there.
        aug = Image.fromarray(img_np)

        return aug

    def post_process_params(
        self, img: ptc.IMAGE_TYPE, params: ptc.PARAM_TYPE
    ) -> ptc.PARAM_TYPE:
        """Post-processes the parameters of augmentations before outputting.

        :param img: The data point to be augmented.
        :param params: The raw local parameters to be post-processed.

        :return: The flat tuple of scalars `(i, j, h, w, r, g, b)`.
        """

        # Unpack all parameters into a tuple of scalars.
        i, j, h, w, fill_color = params
        r, g, b = fill_color.tolist()

        return (i, j, h, w, r, g, b)

    def extract_params(
        self, params: ptc.PARAM_TYPE
    ) -> t.Tuple[ptc.PARAM_TYPE, ptc.PARAM_TYPE]:
        """Chunks the input parameters into two sets; the first required for
        the augmentation of the current data and the second to pass on to the
        next augmentations.

        :param params: The parameters remaining from the augmentations so far.

        :return: The tuple of the local and subsequent parameters.
        """

        return params[: self.param_count], params[self.param_count :]

    def pre_process_params(self, img: ptc.IMAGE_TYPE, params: ptc.PARAM_TYPE) -> ptc.PARAM_TYPE:
        """Pre-processes the parameters of augmentations after inputting.

        :param img: The data point to be augmented.
        :param params: The flat local parameters `(i, j, h, w, r, g, b)`.

        :return: The raw parameters `(i, j, h, w, fill_color)` ready for
            `apply_transform`.
        """

        # Unpack all parameters.
        i, j, h, w, r, g, b = params
        # Rebuild the fill color as the `np.uint8` array used by `apply_transform`.
        fill_color = np.array([r, g, b]).astype(np.uint8)

        # Return the raw parameters.
        return (i, j, h, w, fill_color)

    def get_default_params(self, img: ptc.IMAGE_TYPE, processed: bool = True) -> ptc.PARAM_TYPE:
        """Returns the parameters for preserving the input data information.

        The erased rectangle always has size 0, so applying these parameters
        leaves the image content unchanged.

        :param img: The data point to be augmented.
        :param processed: Whether we want the processed default parameters.

        :return: The no-augmentation params for the class.
        """

        # If the default parameters are required to be unique, ...
        if self.default_params_mode == ptc.DefaultParamsMode.UNIQUE:

            # Define the rectangle to erase at the origin, with size 0.
            i, j, h, w = 0, 0, 0, 0
            # Define the fill color to be `(r, g, b) = (0, 0, 0)`.
            fill_color = np.zeros([3, ]).astype(np.uint8)

            raw_params = (i, j, h, w, fill_color)

        # Else, ...
        else:  # `self.default_params_mode == ptc.DefaultParamsMode.RANDOMIZED`

            # Select any point of the image as the rectangle location.
            H, W = self._get_image_hw(img)
            i = int(np.random.randint(low=0, high=H))
            j = int(np.random.randint(low=0, high=W))
            # Set the size of the rectangle to 0, so nothing is erased.
            h, w = 0, 0

            # Select any fill color; it has no effect on an empty rectangle.
            fill_color = np.random.randint(
                low=0, high=256, size=[3, ]
            ).astype(np.uint8)

            raw_params = (i, j, h, w, fill_color)

        return (
            self.post_process_params(img=img, params=raw_params)
            if processed
            else raw_params
        )

    def __str__(self) -> str:
        """Defines the string representation of the transform.

        :return: The string representation of the transform.
        """
        return (
            f"RandomColorErasing("
            f"param_count={self.param_count}, "
            f"tx_mode={self.tx_mode}, "
            f"default_params_mode={self.default_params_mode}"
            f")"
        )