-
Notifications
You must be signed in to change notification settings - Fork 425
/
formatting.py
143 lines (120 loc) · 5.17 KB
/
formatting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData, LabelData

from mmselfsup.registry import TRANSFORMS
from mmselfsup.structures import SelfSupDataSample
@TRANSFORMS.register_module()
class PackSelfSupInputs(BaseTransform):
    """Pack data into the format compatible with the inputs of algorithm.

    Required Keys:

    - img (the actual key is configurable via ``key``)

    Added Keys:

    - data_samples
    - inputs

    Args:
        key (str): The key of image inputted into the model. Defaults to 'img'.
        algorithm_keys (List[str], optional): Keys of elements related
            to algorithms, e.g. mask. Supported keys are ``sample_idx``,
            ``mask``, ``gt_label`` and ``pred_label``. Defaults to None,
            which means no keys.
        pseudo_label_keys (List[str], optional): Keys set to be the
            attributes of pseudo_label. Defaults to None, which means no
            keys.
        meta_keys (List[str], optional): The keys of meta info of an image.
            Defaults to None, which means no keys.
    """

    def __init__(self,
                 key: str = 'img',
                 algorithm_keys: Optional[List[str]] = None,
                 pseudo_label_keys: Optional[List[str]] = None,
                 meta_keys: Optional[List[str]] = None) -> None:
        assert isinstance(key, str), (
            f'key should be the type of str, instead of {type(key)}.')
        self.key = key
        # Copy into fresh lists. Shared ``[]`` defaults would be aliased
        # across every instance, and copying also decouples this transform
        # from later mutation of the caller's lists.
        self.algorithm_keys = ([] if algorithm_keys is None else
                               list(algorithm_keys))
        self.pseudo_label_keys = ([] if pseudo_label_keys is None else
                                  list(pseudo_label_keys))
        self.meta_keys = [] if meta_keys is None else list(meta_keys)

    def transform(self, results: Dict) -> Dict[str, Any]:
        """Method to pack the data.

        Args:
            results (Dict): Result dict from the data pipeline. A list stored
                under ``self.key`` is not overwritten in place.

        Returns:
            Dict:

            - ``inputs`` (List[torch.Tensor]): The forward data of models.
              Only present when ``self.key`` is in ``results``.
            - ``data_samples`` (SelfSupDataSample): The annotation info of
              the forward data.

        Raises:
            ValueError: If an image is not 2, 3 or 5 dimensional.
            KeyError: If a configured algorithm, pseudo-label or meta key
                is missing from ``results``.
        """
        packed_results = dict()
        if self.key in results:
            imgs = results[self.key]
            # if img is not a list, convert it to a list
            if not isinstance(imgs, list):
                imgs = [imgs]
            # Build a new list instead of assigning into ``imgs`` so that a
            # list held by ``results`` keeps its original elements.
            packed_results['inputs'] = [
                self._image_to_tensor(img) for img in imgs
            ]

        data_sample = SelfSupDataSample()
        if self.pseudo_label_keys:
            data_sample.pseudo_label = InstanceData()

        # gt_label, sample_idx, mask, pred_label will be set here
        for key in self.algorithm_keys:
            self.set_algorithm_keys(data_sample, key, results)

        # keys, except for gt_label, sample_idx, mask, pred_label, will be
        # set as the attributes of pseudo_label (converted to torch.Tensor)
        for key in self.pseudo_label_keys:
            setattr(data_sample.pseudo_label, key, to_tensor(results[key]))

        data_sample.set_metainfo({key: results[key] for key in self.meta_keys})

        packed_results['data_samples'] = data_sample
        return packed_results

    @staticmethod
    def _image_to_tensor(img) -> torch.Tensor:
        """Convert one image array to a tensor with the channel axis first.

        A 2-D array gains a trailing channel axis and is then handled as a
        3-D array. A 3-D array (assumed ``(H, W, C)``) is transposed to a
        contiguous ``(C, H, W)`` array. A 5-D array (video data, assumed
        ``(B, C, T, H, W)``) is passed through unchanged.

        Raises:
            ValueError: If ``img`` is not 2, 3 or 5 dimensional.
        """
        if len(img.shape) == 2:
            # to handle the single channel image
            img = np.expand_dims(img, -1)
        if len(img.shape) == 3:
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
        elif len(img.shape) != 5:
            raise ValueError(
                'img should be 2, 3 or 5 dimensional, '
                f'instead of {len(img.shape)} dimensional.')
        return to_tensor(img)

    @classmethod
    def set_algorithm_keys(cls, data_sample: SelfSupDataSample, key: str,
                           results: dict) -> None:
        """Set the algorithm keys of SelfSupDataSample.

        ``sample_idx`` and ``mask`` are wrapped in ``InstanceData``;
        ``gt_label`` and ``pred_label`` are wrapped in ``LabelData``. In
        every case the raw value is converted with ``to_tensor`` and stored
        under the element's ``value`` field.

        Args:
            data_sample (SelfSupDataSample): An instance of SelfSupDataSample.
            key (str): The key, which may be used by the algorithm, such as
                gt_label, sample_idx, mask, pred_label. For more keys, please
                refer to the attribute of SelfSupDataSample.
            results (dict): The results from the data pipeline.

        Raises:
            AttributeError: If ``key`` is not one of the supported keys.
            KeyError: If ``key`` is supported but missing from ``results``.
        """
        element_types = {
            'sample_idx': InstanceData,
            'mask': InstanceData,
            'gt_label': LabelData,
            'pred_label': LabelData,
        }
        # Validate the key before touching ``results`` so an unsupported key
        # is reported clearly rather than as a lookup/conversion error.
        if key not in element_types:
            raise AttributeError(
                f'{key} is not an attribute of SelfSupDataSample')
        value = to_tensor(results[key])
        setattr(data_sample, key, element_types[key](value=value))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(key={self.key}, '
                f'algorithm_keys={self.algorithm_keys}, '
                f'pseudo_label_keys={self.pseudo_label_keys}, '
                f'meta_keys={self.meta_keys})')