diff --git a/configs/_base_/datasets/levir_256x256.py b/configs/_base_/datasets/levir_256x256.py index a2a69aa9e9..6e018a5ae4 100644 --- a/configs/_base_/datasets/levir_256x256.py +++ b/configs/_base_/datasets/levir_256x256.py @@ -11,7 +11,16 @@ train_pipeline = [ dict(type='LoadMultipleRSImageFromFile'), dict(type='LoadAnnotations'), - dict(type='Albu', transforms=albu_train_transforms), + dict( + type='Albu', + keymap={ + 'img': 'image', + 'img2': 'image2', + 'gt_seg_map': 'mask' + }, + transforms=albu_train_transforms, + additional_targets={'image2': 'image'}, + bgr_to_rgb=False), dict(type='ConcatCDInput'), dict(type='PackSegInputs') ] diff --git a/configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py b/configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py index 663f769d73..7b90686dd5 100644 --- a/configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py +++ b/configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py @@ -8,7 +8,8 @@ size=crop_size, type='SegDataPreProcessor', mean=[123.675, 116.28, 103.53, 123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375]) + std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375], + bgr_to_rgb=False) model = dict( data_preprocessor=data_preprocessor, diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 082ae5b440..64e23230c6 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -2329,14 +2329,19 @@ class Albu(BaseTransform): Args: transforms (list[dict]): A list of albu transformations keymap (dict): Contains {'input key':'albumentation-style key'} + additional_targets(dict): Allows applying same augmentations to \ + multiple objects of same type. update_pad_shape (bool): Whether to update padding shape according to \ the output shape of the last transform + bgr_to_rgb (bool): Whether to convert the band order to RGB """ def __init__(self, transforms: List[dict], keymap: Optional[dict] = None, - update_pad_shape: bool = False): + additional_targets: Optional[dict] = None, + update_pad_shape: bool = False, + bgr_to_rgb: bool = True): if not ALBU_INSTALLED: raise ImportError( 'albumentations is not installed, ' @@ -2349,9 +2354,12 @@ def __init__(self, self.transforms = transforms self.keymap = keymap + self.additional_targets = additional_targets self.update_pad_shape = update_pad_shape + self.bgr_to_rgb = bgr_to_rgb - self.aug = Compose([self.albu_builder(t) for t in self.transforms]) + self.aug = Compose([self.albu_builder(t) for t in self.transforms], + additional_targets=self.additional_targets) if not keymap: self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'} @@ -2417,12 +2425,27 @@ def transform(self, results): results = self.mapper(results, self.keymap_to_albu) # Convert to RGB since Albumentations works with RGB images - results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB) - + if self.bgr_to_rgb: + results['image'] = cv2.cvtColor(results['image'], + cv2.COLOR_BGR2RGB) + if self.additional_targets: + for key, value in self.additional_targets.items(): + if value == 'image': + results[key] = cv2.cvtColor(results[key], + cv2.COLOR_BGR2RGB) + + # Apply Transform results = self.aug(**results) # Convert back to BGR - results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR) + if self.bgr_to_rgb: + results['image'] = cv2.cvtColor(results['image'], + cv2.COLOR_RGB2BGR) + if self.additional_targets: + for key, value in self.additional_targets.items(): + if value == 'image': + results[key] = cv2.cvtColor(results['image2'], + cv2.COLOR_RGB2BGR) # back to the original format results = self.mapper(results, self.keymap_back)