From ce0b8a00659d4aabdcfca8d1535bdecb5e9c2520 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 28 Oct 2020 16:36:25 +0800 Subject: [PATCH] add unittest --- mmaction/datasets/pipelines/__init__.py | 10 +-- mmaction/datasets/pipelines/augmentations.py | 14 ++--- tests/test_data/test_augmentations.py | 65 ++++++++++++++++++-- 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/mmaction/datasets/pipelines/__init__.py b/mmaction/datasets/pipelines/__init__.py index 71c5aa3c30..78b8a91744 100644 --- a/mmaction/datasets/pipelines/__init__.py +++ b/mmaction/datasets/pipelines/__init__.py @@ -1,8 +1,8 @@ -from .augmentations import (AudioAmplify, BoxClip, BoxCrop, BoxFlip, BoxPad, BoxRescale, - CenterCrop, ColorJitter, Flip, Fuse, MelSpectrogram, - MultiGroupCrop, MultiScaleCrop, Normalize, - RandomCrop, RandomResizedCrop, RandomScale, Resize, - TenCrop, ThreeCrop) +from .augmentations import (AudioAmplify, BoxClip, BoxCrop, BoxFlip, BoxPad, + BoxRescale, CenterCrop, ColorJitter, Flip, Fuse, + MelSpectrogram, MultiGroupCrop, MultiScaleCrop, + Normalize, RandomCrop, RandomResizedCrop, + RandomScale, Resize, TenCrop, ThreeCrop) from .compose import Compose from .formating import (Collect, FormatAudioShape, FormatShape, ImageToTensor, ToDataContainer, ToTensor, Transpose) diff --git a/mmaction/datasets/pipelines/augmentations.py b/mmaction/datasets/pipelines/augmentations.py index 6e5330b782..a2209f6e45 100644 --- a/mmaction/datasets/pipelines/augmentations.py +++ b/mmaction/datasets/pipelines/augmentations.py @@ -87,7 +87,7 @@ def __call__(self, results): @PIPELINES.register_module() -class RandomScale(object): +class RandomScale: def __init__(self, scales, mode='range', **kwargs): self.mode = mode @@ -135,7 +135,7 @@ def __repr__(self): @PIPELINES.register_module() -class BoxRescale(object): +class BoxRescale: def __call__(self, results): img_h, img_w = results['img_shape'] @@ -167,7 +167,7 @@ def __call__(self, results): @PIPELINES.register_module() -class BoxCrop(object): +class BoxCrop: def __call__(self, results): proposals = results['proposals'] @@ -194,7 +194,7 @@ def __call__(self, results): @PIPELINES.register_module() -class BoxFlip(object): +class BoxFlip: _directions = ['horizontal', 'vertical'] def __init__(self, direction='horizontal'): @@ -243,7 +243,7 @@ def __repr__(self): @PIPELINES.register_module() -class BoxClip(object): +class BoxClip: def __call__(self, results): proposals = results['proposals'] @@ -262,7 +262,7 @@ def __call__(self, results): @PIPELINES.register_module() -class BoxPad(object): +class BoxPad: def __init__(self, max_num_gts=None): self.max_num_gts = max_num_gts @@ -285,7 +285,7 @@ def __call__(self, results): padded_proposals = None results['proposals'] = padded_proposals - results['ann']['entity_boxes'] = entity_boxes + results['ann']['entity_boxes'] = padded_entity_boxes return results def __repr__(self): diff --git a/tests/test_data/test_augmentations.py b/tests/test_data/test_augmentations.py index 9ed1d7253a..94e74798ce 100644 --- a/tests/test_data/test_augmentations.py +++ b/tests/test_data/test_augmentations.py @@ -7,8 +7,8 @@ # yapf: disable from mmaction.datasets.pipelines import (AudioAmplify, BoxClip, BoxCrop, - BoxFlip, BoxRescale, CenterCrop, - ColorJitter, Flip, Fuse, + BoxFlip, BoxPad, BoxRescale, + CenterCrop, ColorJitter, Flip, Fuse, MelSpectrogram, MultiGroupCrop, MultiScaleCrop, Normalize, RandomCrop, RandomResizedCrop, RandomScale, @@ -1188,10 +1188,63 @@ def test_box_clip(self): target_keys = ['ann', 'proposals', 'img_shape'] results = dict( proposals=np.array([[-9.304, -9.688001, 207.079995, 333.928002]]), - img_shape=(520, 480), + img_shape=(335, 210), ann=dict( entity_boxes=np.array( [[-2.584, -7.608002, 212.120004, 338.920019]]))) - BoxClip() - assert target_keys - assert results + + box_clip = BoxClip() + results_ = copy.deepcopy(results) + results_ = box_clip(results_) + + self.check_keys_contain(results_.keys(), target_keys) + assert_array_equal(results_['ann']['entity_boxes'], + np.array([[0., 0., 209., 334.]])) + assert_array_equal(results_['proposals'], + np.array([[0., 0., 207.079995, 333.928002]])) + + results_ = copy.deepcopy(results) + results_['proposals'] = None + results_ = box_clip(results_) + assert results_['proposals'] is None + + def test_box_pad(self): + target_keys = ['ann', 'proposals', 'img_shape'] + results = dict( + proposals=np.array([[-9.304, -9.688001, 207.079995, 333.928002], + [-2.584, -7.608002, 212.120004, 338.920019]]), + img_shape=(335, 210), + ann=dict( + entity_boxes=np.array([[ + -2.584, -7.608002, 212.120004, 338.920019 + ], [-9.304, -9.688001, 207.079995, 333.928002]]))) + + box_pad_none = BoxPad() + results_ = copy.deepcopy(results) + results_ = box_pad_none(results_) + self.check_keys_contain(results_.keys(), target_keys) + assert_array_equal(results_['proposals'], results['proposals']) + assert_array_equal(results_['ann']['entity_boxes'], + results['ann']['entity_boxes']) + + box_pad = BoxPad(3) + results_ = copy.deepcopy(results) + results_ = box_pad(results_) + self.check_keys_contain(results_.keys(), target_keys) + assert_array_equal( + results_['proposals'], + np.array([[-9.304, -9.688001, 207.079995, 333.928002], + [-2.584, -7.608002, 212.120004, 338.920019], + [0., 0., 0., 0.]], + dtype=np.float32)) + assert_array_equal( + results_['ann']['entity_boxes'], + np.array([[-2.584, -7.608002, 212.120004, 338.920019], + [-9.304, -9.688001, 207.079995, 333.928002], + [0., 0., 0., 0.]], + dtype=np.float32)) + + results_ = copy.deepcopy(results) + results_['proposals'] = None + results_ = box_pad(results_) + assert results_['proposals'] is None