From da4125587e1977eaef4e0dc9a91d9011d041e90a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Fri, 30 Dec 2022 13:46:52 +0800 Subject: [PATCH] [Refactor] Support TTA (#2184) * tta init * use mmcv transform * test city * add multiscale * fix merge * add softmax to post process * add ut * add tta pipeline to other datasets * remove softmax * add encoder_decoder_tta ut * add encoder_decoder_tta ut * rename * rename file * rename config * rm aug_test * move flip to post process * fix channel --- configs/_base_/datasets/ade20k.py | 16 ++ configs/_base_/datasets/ade20k_640x640.py | 16 ++ configs/_base_/datasets/chase_db1.py | 16 ++ configs/_base_/datasets/cityscapes.py | 16 ++ configs/_base_/datasets/coco-stuff10k.py | 16 ++ configs/_base_/datasets/coco-stuff164k.py | 16 ++ configs/_base_/datasets/drive.py | 16 ++ configs/_base_/datasets/hrf.py | 16 ++ configs/_base_/datasets/isaid.py | 16 ++ configs/_base_/datasets/loveda.py | 16 ++ configs/_base_/datasets/pascal_context_59.py | 16 ++ configs/_base_/datasets/pascal_voc12.py | 16 ++ configs/_base_/datasets/pascal_voc12_aug.py | 17 +- configs/_base_/datasets/potsdam.py | 16 ++ configs/_base_/datasets/stare.py | 16 ++ configs/_base_/datasets/vaihingen.py | 16 ++ configs/_base_/default_runtime.py | 2 + mmseg/models/segmentors/__init__.py | 5 +- mmseg/models/segmentors/base.py | 14 +- mmseg/models/segmentors/seg_tta.py | 48 ++++ tests/test_datasets/test_tta.py | 256 ++++++++---------- .../test_segmentors/test_seg_tta_model.py | 60 ++++ tests/test_models/test_segmentors/utils.py | 4 +- tools/test.py | 7 + 24 files changed, 506 insertions(+), 147 deletions(-) create mode 100644 mmseg/models/segmentors/seg_tta.py create mode 100644 tests/test_models/test_segmentors/test_seg_tta_model.py diff --git a/configs/_base_/datasets/ade20k.py b/configs/_base_/datasets/ade20k.py index 4303b094c5..5840fc17ec 100644 --- a/configs/_base_/datasets/ade20k.py +++ b/configs/_base_/datasets/ade20k.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/ade20k_640x640.py b/configs/_base_/datasets/ade20k_640x640.py index 8478585915..998b06e15b 100644 --- a/configs/_base_/datasets/ade20k_640x640.py +++ b/configs/_base_/datasets/ade20k_640x640.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/chase_db1.py b/configs/_base_/datasets/chase_db1.py index 
8cd4f3c284..07604b4d5a 100644 --- a/configs/_base_/datasets/chase_db1.py +++ b/configs/_base_/datasets/chase_db1.py @@ -24,6 +24,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, diff --git a/configs/_base_/datasets/cityscapes.py b/configs/_base_/datasets/cityscapes.py index c2fdee473b..1698e04721 100644 --- a/configs/_base_/datasets/cityscapes.py +++ b/configs/_base_/datasets/cityscapes.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=2, num_workers=2, diff --git a/configs/_base_/datasets/coco-stuff10k.py b/configs/_base_/datasets/coco-stuff10k.py index b00db24691..0c2d55208e 100644 --- a/configs/_base_/datasets/coco-stuff10k.py +++ b/configs/_base_/datasets/coco-stuff10k.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/coco-stuff164k.py b/configs/_base_/datasets/coco-stuff164k.py index e879bdb2aa..f77a0fd65a 100644 --- a/configs/_base_/datasets/coco-stuff164k.py +++ b/configs/_base_/datasets/coco-stuff164k.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/drive.py b/configs/_base_/datasets/drive.py index 248dc8b102..c6242acdb8 100644 --- a/configs/_base_/datasets/drive.py +++ b/configs/_base_/datasets/drive.py @@ -24,6 +24,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + 
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/hrf.py b/configs/_base_/datasets/hrf.py index 11b66e7d52..c2fe84f170 100644 --- a/configs/_base_/datasets/hrf.py +++ b/configs/_base_/datasets/hrf.py @@ -24,6 +24,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/isaid.py b/configs/_base_/datasets/isaid.py index 8dafae8fd4..65e256c56d 100644 --- a/configs/_base_/datasets/isaid.py +++ b/configs/_base_/datasets/isaid.py @@ -30,6 +30,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/loveda.py b/configs/_base_/datasets/loveda.py index fcdb05865e..d69bdafceb 100644 --- a/configs/_base_/datasets/loveda.py +++ b/configs/_base_/datasets/loveda.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/pascal_context_59.py b/configs/_base_/datasets/pascal_context_59.py index 9103fe7e3f..0ca02cc94b 100644 --- a/configs/_base_/datasets/pascal_context_59.py +++ b/configs/_base_/datasets/pascal_context_59.py @@ -26,6 +26,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + 
dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/pascal_voc12.py b/configs/_base_/datasets/pascal_voc12.py index aeb38d0613..8b4b77c2f9 100644 --- a/configs/_base_/datasets/pascal_voc12.py +++ b/configs/_base_/datasets/pascal_voc12.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/pascal_voc12_aug.py b/configs/_base_/datasets/pascal_voc12_aug.py index cd0d3e8682..495595cdfb 100644 --- a/configs/_base_/datasets/pascal_voc12_aug.py +++ b/configs/_base_/datasets/pascal_voc12_aug.py @@ -25,7 +25,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] - +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] dataset_train = dict( type=dataset_type, data_root=data_root, diff --git a/configs/_base_/datasets/potsdam.py b/configs/_base_/datasets/potsdam.py index ef9761c76e..1f4b95df2e 100644 --- a/configs/_base_/datasets/potsdam.py +++ b/configs/_base_/datasets/potsdam.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/stare.py b/configs/_base_/datasets/stare.py index 7fccd71a54..cd12740b2e 100644 --- a/configs/_base_/datasets/stare.py +++ b/configs/_base_/datasets/stare.py @@ -24,6 +24,22 @@ dict(type='LoadAnnotations'), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/datasets/vaihingen.py 
b/configs/_base_/datasets/vaihingen.py index 2b52135567..ca0ad7915e 100644 --- a/configs/_base_/datasets/vaihingen.py +++ b/configs/_base_/datasets/vaihingen.py @@ -23,6 +23,22 @@ dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] + ]) +] train_dataloader = dict( batch_size=4, num_workers=4, diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index e9fa6e1035..272b4d2467 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -11,3 +11,5 @@ log_level = 'INFO' load_from = None resume = False + +tta_model = dict(type='SegTTAModel') diff --git a/mmseg/models/segmentors/__init__.py b/mmseg/models/segmentors/__init__.py index 387c858bd7..fec0d52c3a 100644 --- a/mmseg/models/segmentors/__init__.py +++ b/mmseg/models/segmentors/__init__.py @@ -2,5 +2,8 @@ from .base import BaseSegmentor from .cascade_encoder_decoder import CascadeEncoderDecoder from .encoder_decoder import EncoderDecoder +from .seg_tta import SegTTAModel -__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] +__all__ = [ + 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel' +] diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 1625addf6c..d9ffeceb39 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -124,11 +124,6 @@ def _forward(self, """ pass - @abstractmethod - def aug_test(self, batch_inputs, batch_img_metas): - """Placeholder for augmentation test.""" - pass - def postprocess_result(self, seg_logits: Tensor, data_samples: OptSampleList = None) -> list: @@ -170,6 +165,15 @@ def postprocess_result(self, padding_top:H - padding_bottom, padding_left:W - padding_right] + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + # resize as original shape i_seg_logits = resize( i_seg_logits, diff --git a/mmseg/models/segmentors/seg_tta.py b/mmseg/models/segmentors/seg_tta.py new file mode 100644 index 0000000000..eacb6c00a9 --- /dev/null +++ b/mmseg/models/segmentors/seg_tta.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.model import BaseTTAModel +from mmengine.structures import PixelData + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList + + +@MODELS.register_module() +class SegTTAModel(BaseTTAModel): + + def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[SampleList]): List of predictions + of all enhanced data. + + Returns: + SampleList: Merged prediction. 
+ """ + predictions = [] + for data_samples in data_samples_list: + seg_logits = data_samples[0].seg_logits.data + logits = torch.zeros(seg_logits.shape).to(seg_logits) + for data_sample in data_samples: + seg_logit = data_sample.seg_logits.data + if self.module.out_channels > 1: + logits += seg_logit.softmax(dim=0) + else: + logits += seg_logit.sigmoid() + logits /= len(data_samples) + if self.module.out_channels == 1: + seg_pred = (logits > self.module.decode_head.threshold + ).to(logits).squeeze(1) + else: + seg_pred = logits.argmax(dim=0) + data_sample = SegDataSample( + **{ + 'pred_sem_seg': PixelData(data=seg_pred), + 'gt_sem_seg': data_samples[0].gt_sem_seg + }) + predictions.append(data_sample) + return predictions diff --git a/tests/test_datasets/test_tta.py b/tests/test_datasets/test_tta.py index 6a433647a8..25b1ecdb53 100644 --- a/tests/test_datasets/test_tta.py +++ b/tests/test_datasets/test_tta.py @@ -1,151 +1,131 @@ # Copyright (c) OpenMMLab. All rights reserved. -# import os.path as osp +import os.path as osp -# import mmcv -# import pytest +import mmcv +import pytest -# from mmseg.datasets.transforms import * # noqa -# from mmseg.registry import TRANSFORMS +from mmseg.datasets.transforms import * # noqa +from mmseg.registry import TRANSFORMS -# TODO -# def test_multi_scale_flip_aug(): -# # test assertion if scales=None, scale_factor=1 (not float). -# with pytest.raises(AssertionError): -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=None, -# scale_factor=1, -# transforms=[dict(type='Resize', keep_ratio=False)], -# ) -# TRANSFORMS.build(tta_transform) -# # test assertion if scales=None, scale_factor=None. -# with pytest.raises(AssertionError): -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=None, -# scale_factor=None, -# transforms=[dict(type='Resize', keep_ratio=False)], -# ) -# TRANSFORMS.build(tta_transform) +def test_multi_scale_flip_aug(): + # test exception + with pytest.raises(TypeError): + tta_transform = dict( + type='TestTimeAug', + transforms=[dict(type='Resize', keep_ratio=False)], + ) + TRANSFORMS.build(tta_transform) -# # test assertion if scales=(512, 512), scale_factor=1 (not float). 
-# with pytest.raises(AssertionError): -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=(512, 512), -# scale_factor=1, -# transforms=[dict(type='Resize', keep_ratio=False)], -# ) -# TRANSFORMS.build(tta_transform) -# meta_keys = ('img', 'ori_shape', 'ori_height', 'ori_width', 'pad_shape', -# 'scale_factor', 'scale', 'flip') -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=[(256, 256), (512, 512), (1024, 1024)], -# allow_flip=False, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) + tta_transform = dict( + type='TestTimeAug', + transforms=[[ + dict(type='Resize', scale=scale, keep_ratio=False) + for scale in [(256, 256), (512, 512), (1024, 1024)] + ], [dict(type='mmseg.PackSegInputs')]]) + tta_module = TRANSFORMS.build(tta_transform) -# results = dict() -# # (288, 512, 3) -# img = mmcv.imread( -# osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') -# results['img'] = img -# results['ori_shape'] = img.shape -# results['ori_height'] = img.shape[0] -# results['ori_width'] = img.shape[1] -# # Set initial values for default meta_keys -# results['pad_shape'] = img.shape -# results['scale_factor'] = 1.0 + results = dict() + # (288, 512, 3) + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color') + results['img'] = img + results['ori_shape'] = img.shape + results['ori_height'] = img.shape[0] + results['ori_width'] = img.shape[1] + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 -# tta_results = tta_module(results.copy()) -# assert [data_sample.scale -# for data_sample in tta_results['data_sample']] == [(256, 256), -# (512, 512), -# (1024, 1024)] -# assert [data_sample.flip for data_sample in tta_results['data_sample'] -# ] == [False, False, False] + tta_results = tta_module(results.copy()) + assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256), + (3, 512, 512), + (3, 1024, 1024)] -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=[(256, 256), (512, 512), (1024, 1024)], -# allow_flip=True, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) -# tta_results = tta_module(results.copy()) -# assert [data_sample.scale -# for data_sample in tta_results['data_sample']] == [(256, 256), -# (256, 256), -# (512, 512), -# (512, 512), -# (1024, 1024), -# (1024, 1024)] -# assert [data_sample.flip for data_sample in tta_results['data_sample'] -# ] == [False, True, False, True, False, True] + tta_transform = dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale=scale, keep_ratio=False) + for scale in [(256, 256), (512, 512), (1024, 1024)] + ], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='mmseg.PackSegInputs')] + ]) + tta_module = TRANSFORMS.build(tta_transform) + tta_results: dict = tta_module(results.copy()) + assert [img.shape for img in tta_results['inputs']] == [(3, 256, 256), + (3, 256, 256), + (3, 512, 512), + (3, 512, 512), + (3, 1024, 1024), + (3, 1024, 1024)] + assert [ + data_sample.metainfo['flip'] + for data_sample in tta_results['data_samples'] + ] == [False, True, False, True, False, True] -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=[(512, 512)], -# 
allow_flip=False, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) -# tta_results = tta_module(results.copy()) -# assert [tta_results['data_sample'][0].scale] == [(512, 512)] -# assert [tta_results['data_sample'][0].flip] == [False] + tta_transform = dict( + type='TestTimeAug', + transforms=[[dict(type='Resize', scale=(512, 512), keep_ratio=False)], + [dict(type='mmseg.PackSegInputs')]]) + tta_module = TRANSFORMS.build(tta_transform) + tta_results = tta_module(results.copy()) + assert [tta_results['inputs'][0].shape] == [(3, 512, 512)] -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scales=[(512, 512)], -# allow_flip=True, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) -# tta_results = tta_module(results.copy()) -# assert [data_sample.scale -# for data_sample in tta_results['data_sample']] == [(512, 512), -# (512, 512)] -# assert [data_sample.flip -# for data_sample in tta_results['data_sample']] == [False, True] + tta_transform = dict( + type='TestTimeAug', + transforms=[ + [dict(type='Resize', scale=(512, 512), keep_ratio=False)], + [ + dict(type='RandomFlip', prob=0., direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='mmseg.PackSegInputs')] + ]) + tta_module = TRANSFORMS.build(tta_transform) + tta_results = tta_module(results.copy()) + assert [img.shape for img in tta_results['inputs']] == [(3, 512, 512), + (3, 512, 512)] + assert [ + data_sample.metainfo['flip'] + for data_sample in tta_results['data_samples'] + ] == [False, True] -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scale_factor=[0.5, 1.0, 2.0], -# allow_flip=False, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) -# tta_results = tta_module(results.copy()) -# assert [data_sample.scale -# for data_sample in tta_results['data_sample']] == [(256, 144), -# (512, 288), -# (1024, 576)] -# assert [data_sample.flip for data_sample in tta_results['data_sample'] -# ] == [False, False, False] + tta_transform = dict( + type='TestTimeAug', + transforms=[[ + dict(type='Resize', scale_factor=r, keep_ratio=False) + for r in [0.5, 1.0, 2.0] + ], [dict(type='mmseg.PackSegInputs')]]) + tta_module = TRANSFORMS.build(tta_transform) + tta_results = tta_module(results.copy()) + assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256), + (3, 288, 512), + (3, 576, 1024)] -# tta_transform = dict( -# type='MultiScaleFlipAug', -# scale_factor=[0.5, 1.0, 2.0], -# allow_flip=True, -# resize_cfg=dict(type='Resize', keep_ratio=False), -# transforms=[dict(type='mmseg.PackSegInputs', meta_keys=meta_keys)], -# ) -# tta_module = TRANSFORMS.build(tta_transform) -# tta_results = tta_module(results.copy()) -# assert [data_sample.scale -# for data_sample in tta_results['data_sample']] == [(256, 144), -# (256, 144), -# (512, 288), -# (512, 288), -# (1024, 576), -# (1024, 576)] -# assert [data_sample.flip for data_sample in tta_results['data_sample'] -# ] == [False, True, False, True, False, True] + tta_transform = dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale_factor=r, keep_ratio=True) + for r in [0.5, 1.0, 2.0] + ], + [ + dict(type='RandomFlip', prob=0., 
direction='horizontal'), + dict(type='RandomFlip', prob=1., direction='horizontal') + ], [dict(type='mmseg.PackSegInputs')] + ]) + tta_module = TRANSFORMS.build(tta_transform) + tta_results = tta_module(results.copy()) + assert [img.shape for img in tta_results['inputs']] == [(3, 144, 256), + (3, 144, 256), + (3, 288, 512), + (3, 288, 512), + (3, 576, 1024), + (3, 576, 1024)] + assert [ + data_sample.metainfo['flip'] + for data_sample in tta_results['data_samples'] + ] == [False, True, False, True, False, True] diff --git a/tests/test_models/test_segmentors/test_seg_tta_model.py b/tests/test_models/test_segmentors/test_seg_tta_model.py new file mode 100644 index 0000000000..c0e76b22f4 --- /dev/null +++ b/tests/test_models/test_segmentors/test_seg_tta_model.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine import ConfigDict +from mmengine.model import BaseTTAModel +from mmengine.structures import PixelData + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import register_all_modules +from .utils import * # noqa: F401,F403 + +register_all_modules() + + +def test_encoder_decoder_tta(): + + segmentor_cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict(type='ExampleDecodeHead'), + train_cfg=None, + test_cfg=dict(mode='whole')) + + cfg = ConfigDict(type='SegTTAModel', module=segmentor_cfg) + + model: BaseTTAModel = MODELS.build(cfg) + + imgs = [] + data_samples = [] + directions = ['horizontal', 'vertical'] + for i in range(12): + flip_direction = directions[0] if i % 3 == 0 else directions[1] + imgs.append(torch.randn(1, 3, 10 + i, 10 + i)) + data_samples.append([ + SegDataSample( + metainfo=dict( + ori_shape=(10, 10), + img_shape=(10 + i, 10 + i), + flip=(i % 2 == 0), + flip_direction=flip_direction), + gt_sem_seg=PixelData(data=torch.randint(0, 19, (1, 10, 10)))) + ]) + + model.test_step(dict(inputs=imgs, data_samples=data_samples)) + + # test out_channels == 1 + segmentor_cfg = ConfigDict( + type='EncoderDecoder', + backbone=dict(type='ExampleBackbone'), + decode_head=dict( + type='ExampleDecodeHead', + num_classes=2, + out_channels=1, + threshold=0.4), + train_cfg=None, + test_cfg=dict(mode='whole')) + model.module = MODELS.build(segmentor_cfg) + for data_sample in data_samples: + data_sample[0].gt_sem_seg.data = torch.randint(0, 2, (1, 10, 10)) + model.test_step(dict(inputs=imgs, data_samples=data_samples)) diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py index 9b155c0961..6b440df906 100644 --- a/tests/test_models/test_segmentors/utils.py +++ b/tests/test_models/test_segmentors/utils.py @@ -66,9 +66,9 @@ def forward(self, x): @MODELS.register_module() class ExampleDecodeHead(BaseDecodeHead): - def __init__(self, num_classes=19, out_channels=None): + def __init__(self, num_classes=19, out_channels=None, **kwargs): super().__init__( - 3, 3, num_classes=num_classes, out_channels=out_channels) + 3, 3, num_classes=num_classes, out_channels=out_channels, **kwargs) def forward(self, inputs): return self.cls_seg(inputs[0]) diff --git a/tools/test.py b/tools/test.py index ea1917d182..b21b990f26 100644 --- a/tools/test.py +++ b/tools/test.py @@ -43,6 +43,8 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') parser.add_argument('--local_rank', type=int, default=0) args 
= parser.parse_args() if 'LOCAL_RANK' not in os.environ: @@ -99,6 +101,11 @@ def main(): if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) + if args.tta: + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.tta_model.module = cfg.model + cfg.model = cfg.tta_model + # build the runner from config runner = Runner.from_cfg(cfg)
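
Usage sketch: with this patch applied, multi-scale and horizontal-flip test-time augmentation can be enabled during evaluation through the new flag added to tools/test.py. The config and checkpoint paths below are placeholders for illustration, not part of this change:

    python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --tta

When --tta is passed, the test dataset's pipeline is replaced by the dataset's tta_pipeline and the segmentor is wrapped in the SegTTAModel defined above, which averages the softmax scores (or sigmoid scores, for single-channel heads) over all augmented views before taking the final argmax or threshold.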