diff --git a/kornia/augmentation/_2d/mix/base.py b/kornia/augmentation/_2d/mix/base.py index ec0b37448e..4594cbc1ce 100644 --- a/kornia/augmentation/_2d/mix/base.py +++ b/kornia/augmentation/_2d/mix/base.py @@ -24,7 +24,7 @@ class MixAugmentationBaseV2(_BasicAugmentationBase): keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. data_keys: the input type sequential for applying augmentations. - Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". """ def __init__( diff --git a/kornia/augmentation/_2d/mix/jigsaw.py b/kornia/augmentation/_2d/mix/jigsaw.py index 814152c997..a6fd23fb39 100644 --- a/kornia/augmentation/_2d/mix/jigsaw.py +++ b/kornia/augmentation/_2d/mix/jigsaw.py @@ -24,7 +24,7 @@ class RandomJigsaw(MixAugmentationBaseV2): ensure_perm: to ensure the nonidentical patch permutation generation against the original one. data_keys: the input type sequential for applying augmentations. - Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". + Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". p: probability of applying the transformation for the whole batch. same_on_batch: apply the same transformation across the batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it diff --git a/kornia/augmentation/_2d/mix/mosaic.py b/kornia/augmentation/_2d/mix/mosaic.py index aa1b3b22c0..08ceb10af7 100644 --- a/kornia/augmentation/_2d/mix/mosaic.py +++ b/kornia/augmentation/_2d/mix/mosaic.py @@ -35,7 +35,7 @@ class RandomMosaic(MixAugmentationBaseV2): each output will mix 4 images in a 2x2 grid. min_bbox_size: minimum area of bounding boxes. Default to 0. data_keys: the input type sequential for applying augmentations. - Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". 
+ Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". p: probability of applying the transformation for the whole batch. keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch form ``False``. diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 606b32a8bd..383b36f7d0 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -24,8 +24,8 @@ class AugmentationSequential(ImageSequential): Args: *args: a list of kornia augmentation modules. - data_keys: the input type sequential for applying augmentations. Accepts "input", "mask", "bbox", "bbox_xyxy", - "bbox_xywh", "keypoints". + data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask", + "bbox", "bbox_xyxy", "bbox_xywh", "keypoints". same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise settings. @@ -52,7 +52,7 @@ class AugmentationSequential(ImageSequential): strategies. .. note:: - Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only be working with "input" data key. + Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only work with the "input"/"image" data key. It is not clear how to deal with the conversions of masks, bounding boxes and keypoints. .. note:: diff --git a/kornia/constants.py b/kornia/constants.py index 85dc6e96cc..4dbfa5bdc5 100644 --- a/kornia/constants.py +++ b/kornia/constants.py @@ -123,6 +123,7 @@ def to_torch(cls, value: TKEnum['DType']) -> torch.dtype: # TODO: (low-priority) add INPUT3D, MASK3D, BBOX3D, LAFs etc.
class DataKey(Enum, metaclass=_KORNIA_EnumMeta): + IMAGE = 0 INPUT = 0 MASK = 1 BBOX = 2 diff --git a/test/augmentation/test_container.py b/test/augmentation/test_container.py index 1ebd38d72f..2a5897f7d1 100644 --- a/test/augmentation/test_container.py +++ b/test/augmentation/test_container.py @@ -219,7 +219,7 @@ def test_forward(self, random_apply, device, dtype): class TestAugmentationSequential: @pytest.mark.parametrize( - 'data_keys', ["input", ["mask", "input"], ["input", "bbox_yxyx"], [0, 10], [BorderType.REFLECT]] + 'data_keys', ["input", "image", ["mask", "input"], ["input", "bbox_yxyx"], [0, 10], [BorderType.REFLECT]] ) @pytest.mark.parametrize("augmentation_list", [K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)]) def test_exception(self, augmentation_list, data_keys, device, dtype): @@ -292,7 +292,7 @@ def test_random_flips(self, device, dtype): ) aug_hor = K.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=["input", "bbox"], same_on_batch=False + K.RandomHorizontalFlip(p=1.0), data_keys=["image", "bbox"], same_on_batch=False ) out_ver = aug_ver(inp.clone(), bbox.clone()) @@ -371,7 +371,7 @@ def test_random_crops_and_flips(self, device, dtype): def test_random_erasing(self, device, dtype): fill_value = 0.5 input = torch.randn(3, 3, 100, 100, device=device, dtype=dtype) - aug = K.AugmentationSequential(K.RandomErasing(p=1.0, value=fill_value), data_keys=["input", "mask"]) + aug = K.AugmentationSequential(K.RandomErasing(p=1.0, value=fill_value), data_keys=["image", "mask"]) reproducibility_test((input, input), aug)