Skip to content

Commit

Permalink
DataKey: add 'image' as alias of 'input' (#2193)
Browse files Browse the repository at this point in the history
* DataKey: add 'image' as alias of 'input'

* Document new 'image' key

* Test image in addition to input

* flake8 fix
  • Loading branch information
adamjstewart committed Feb 11, 2023
1 parent 7c45aff commit e6e17b6
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion kornia/augmentation/_2d/mix/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class MixAugmentationBaseV2(_BasicAugmentationBase):
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
to the batch form ``False``.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion kornia/augmentation/_2d/mix/jigsaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class RandomJigsaw(MixAugmentationBaseV2):
ensure_perm: to ensure the nonidentical patch permutation generation against
the original one.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
p: probability of applying the transformation for the whole batch.
same_on_batch: apply the same transformation across the batch.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
Expand Down
2 changes: 1 addition & 1 deletion kornia/augmentation/_2d/mix/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class RandomMosaic(MixAugmentationBaseV2):
each output will mix 4 images in a 2x2 grid.
min_bbox_size: minimum area of bounding boxes. Default to 0.
data_keys: the input type sequential for applying augmentations.
Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
Accepts "input", "image", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
p: probability of applying the transformation for the whole batch.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
to the batch form ``False``.
Expand Down
6 changes: 3 additions & 3 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class AugmentationSequential(ImageSequential):
Args:
*args: a list of kornia augmentation modules.
data_keys: the input type sequential for applying augmentations. Accepts "input", "mask", "bbox", "bbox_xyxy",
"bbox_xywh", "keypoints".
data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
"bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
settings.
Expand All @@ -52,7 +52,7 @@ class AugmentationSequential(ImageSequential):
strategies.
.. note::
Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only be working with "input" data key.
Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only be working with "input"/"image" data key.
It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.
.. note::
Expand Down
1 change: 1 addition & 0 deletions kornia/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def to_torch(cls, value: TKEnum['DType']) -> torch.dtype:

# TODO: (low-priority) add INPUT3D, MASK3D, BBOX3D, LAFs etc.
class DataKey(Enum, metaclass=_KORNIA_EnumMeta):
IMAGE = 0
INPUT = 0
MASK = 1
BBOX = 2
Expand Down
6 changes: 3 additions & 3 deletions test/augmentation/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_forward(self, random_apply, device, dtype):

class TestAugmentationSequential:
@pytest.mark.parametrize(
'data_keys', ["input", ["mask", "input"], ["input", "bbox_yxyx"], [0, 10], [BorderType.REFLECT]]
'data_keys', ["input", "image", ["mask", "input"], ["input", "bbox_yxyx"], [0, 10], [BorderType.REFLECT]]
)
@pytest.mark.parametrize("augmentation_list", [K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)])
def test_exception(self, augmentation_list, data_keys, device, dtype):
Expand Down Expand Up @@ -292,7 +292,7 @@ def test_random_flips(self, device, dtype):
)

aug_hor = K.AugmentationSequential(
K.RandomHorizontalFlip(p=1.0), data_keys=["input", "bbox"], same_on_batch=False
K.RandomHorizontalFlip(p=1.0), data_keys=["image", "bbox"], same_on_batch=False
)

out_ver = aug_ver(inp.clone(), bbox.clone())
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_random_crops_and_flips(self, device, dtype):
def test_random_erasing(self, device, dtype):
fill_value = 0.5
input = torch.randn(3, 3, 100, 100, device=device, dtype=dtype)
aug = K.AugmentationSequential(K.RandomErasing(p=1.0, value=fill_value), data_keys=["input", "mask"])
aug = K.AugmentationSequential(K.RandomErasing(p=1.0, value=fill_value), data_keys=["image", "mask"])

reproducibility_test((input, input), aug)

Expand Down

0 comments on commit e6e17b6

Please sign in to comment.