# Copyright (c) OpenMMLab. All rights reserved.
from numbers import Number
from typing import Any, Dict, List, Optional, Sequence

import torch
from mmengine.model import BaseDataPreprocessor

from mmseg.registry import MODELS
from mmseg.utils import stack_batch


@MODELS.register_module()
class SegDataPreProcessor(BaseDataPreprocessor):
"""Image pre-processor for segmentation tasks.
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
1. It won't do normalization if ``mean`` is not specified.
2. It does normalization and color space conversion after stacking batch.
3. It supports batch augmentations like mixup and cutmix.
It provides the data pre-processing as follows
- Collate and move data to the target device.
- Pad inputs to the input size with defined ``pad_val``, and pad seg map
with defined ``seg_pad_val``.
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
pad_val (float, optional): Padding value. Default: 0.
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
padding_mode (str): Type of padding. Default: constant.
- constant: pads with a constant value, this value is specified
with pad_val.
        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
            Defaults to False.
        batch_augments (list[dict], optional): Batch-level augmentations.
            Defaults to None.
        test_cfg (dict, optional): The padding size config used in testing;
            if not specified, the ``size`` and ``size_divisor`` parameters
            are used as defaults. Only the keys ``size`` and
            ``size_divisor`` are supported. Defaults to None.
"""
def __init__(
self,
        mean: Optional[Sequence[Number]] = None,
        std: Optional[Sequence[Number]] = None,
size: Optional[tuple] = None,
size_divisor: Optional[int] = None,
pad_val: Number = 0,
seg_pad_val: Number = 255,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
batch_augments: Optional[List[dict]] = None,
        test_cfg: Optional[dict] = None,
):
super().__init__()
self.size = size
self.size_divisor = size_divisor
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val
        assert not (bgr_to_rgb and rgb_to_bgr), (
            '`bgr_to_rgb` and `rgb_to_bgr` cannot be set to True at the '
            'same time')
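        # A BGR<->RGB conversion is the same channel flip in either
        # direction, so a single flag is enough to record that a flip
        # is needed.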
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
if mean is not None:
assert std is not None, 'To enable the normalization in ' \
'preprocessing, please specify both ' \
'`mean` and `std`.'
# Enable the normalization in preprocessing.
self._enable_normalize = True
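            # Reshape mean/std to (C, 1, 1) so they broadcast over (C, H, W)
            # images; persistent=False keeps the buffers out of checkpoints.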
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer('std',
torch.tensor(std).view(-1, 1, 1), False)
else:
self._enable_normalize = False
# TODO: support batch augmentations.
self.batch_augments = batch_augments
# Support different padding methods in testing
        self.test_cfg = test_cfg

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (dict): data sampled from dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Dict: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
inputs = data['inputs']
data_samples = data.get('data_samples', None)
# TODO: whether normalize should be after stack_batch
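        # Reverse the channel order with advanced indexing; the check on
        # size(0) skips non-3-channel inputs such as grayscale images.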
if self.channel_conversion and inputs[0].size(0) == 3:
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
inputs = [_input.float() for _input in inputs]
if self._enable_normalize:
inputs = [(_input - self.mean) / self.std for _input in inputs]
if training:
            assert data_samples is not None, (
                'During training, `data_samples` must be defined.')
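            # stack_batch pads each image (and its seg map) to a common shape
            # and stacks them into a single (N, C, H, W) batch tensor.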
inputs, data_samples = stack_batch(
inputs=inputs,
data_samples=data_samples,
size=self.size,
size_divisor=self.size_divisor,
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
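            # Batch-level augmentations (e.g. Mixup/CutMix) run only during
            # training, once the batch has been stacked.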
if self.batch_augments is not None:
inputs, data_samples = self.batch_augments(
inputs, data_samples)
else:
            assert len(inputs) == 1, (
                'Batch inference is not supported currently, '
                'as the image size might be different in a batch')
# pad images when testing
if self.test_cfg:
inputs, padded_samples = stack_batch(
inputs=inputs,
size=self.test_cfg.get('size', None),
size_divisor=self.test_cfg.get('size_divisor', None),
pad_val=self.pad_val,
seg_pad_val=self.seg_pad_val)
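                # Record the padding info on each sample's metainfo so that
                # predictions can later be mapped back to the original size.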
for data_sample, pad_info in zip(data_samples, padded_samples):
data_sample.set_metainfo({**pad_info})
else:
inputs = torch.stack(inputs, dim=0)
return dict(inputs=inputs, data_samples=data_samples)
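

# A minimal usage sketch (not part of the upstream module): it constructs
# the preprocessor with the common ImageNet mean/std (an illustrative
# choice, not a value required by this class) and pushes one fake image
# through the inference branch, which only accepts a batch size of 1.
if __name__ == '__main__':
    preprocessor = SegDataPreProcessor(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True)
    data = dict(
        inputs=[torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)],
        data_samples=None)
    out = preprocessor(data, training=False)
    print(out['inputs'].shape)  # expected: torch.Size([1, 3, 224, 224])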