/
spr.py
306 lines (257 loc) · 12.1 KB
/
spr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from mmpose.registry import KEYPOINT_CODECS
from .base import BaseKeypointCodec
from .utils import (batch_heatmap_nms, generate_displacement_heatmap,
generate_gaussian_heatmaps, get_diagonal_lengths,
get_instance_root)
@KEYPOINT_CODECS.register_module()
class SPR(BaseKeypointCodec):
    """Encode/decode keypoints with Structured Pose Representation (SPR).

    See the paper `Single-stage multi-person pose machines`_
    by Nie et al (2019) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
            where [W, H] is the `heatmap_size`. If the keypoint heatmaps are
            generated together, the output heatmap shape is (K+1, H, W)
        - heatmap_weights (np.ndarray): The target weights for heatmaps which
            have the same shape as heatmaps.
        - displacements (np.ndarray): The dense keypoint displacement in
            shape (K*2, H, W).
        - displacement_weights (np.ndarray): The target weights for
            displacements which have the same shape as displacements.

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float or tuple, optional): The sigma values of the Gaussian
            heatmaps. If sigma is a tuple, it includes both sigmas for root
            and keypoint heatmaps. ``None`` means the sigmas are computed
            automatically from the heatmap size. Defaults to ``None``
        generate_keypoint_heatmaps (bool): Whether to generate Gaussian
            heatmaps for each keypoint. Defaults to ``False``
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        minimal_diagonal_length (int or float): The threshold of diagonal
            length of instance bounding box (in heatmap space). Small
            instances will not be used in training. Defaults to 5
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        decode_thr (float): The threshold of root response value in
            heatmaps. Defaults to 0.01
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_max_instances (int): The maximum number of instances
            to decode. Defaults to 30

    .. _`Single-stage multi-person pose machines`:
        https://arxiv.org/abs/1908.09220
    """

    field_mapping_table = dict(
        heatmaps='heatmaps',
        heatmap_weights='heatmap_weights',
        displacements='displacements',
        displacement_weights='displacement_weights',
    )

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[Union[float, Tuple[float, ...]]] = None,
        generate_keypoint_heatmaps: bool = False,
        root_type: str = 'kpt_center',
        minimal_diagonal_length: Union[int, float] = 5,
        background_weight: float = 0.1,
        decode_nms_kernel: int = 5,
        decode_max_instances: int = 30,
        decode_thr: float = 0.01,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
        self.root_type = root_type
        self.minimal_diagonal_length = minimal_diagonal_length
        self.background_weight = background_weight
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_max_instances = decode_max_instances
        self.decode_thr = decode_thr

        # Per-axis [x, y] ratio between input image space and heatmap space.
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

        if sigma is None:
            # Default sigma grows with the geometric mean of the heatmap sides.
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
            if generate_keypoint_heatmaps:
                # sigma for root heatmap and keypoint heatmaps
                # NOTE(review): floor division yields 0 when sigma < 2
                # (i.e. sqrt(W*H) < 64) -- confirm this is intended.
                self.sigma = (sigma, sigma // 2)
            else:
                self.sigma = (sigma, )
        else:
            if not isinstance(sigma, (tuple, list)):
                sigma = (sigma, )
            if generate_keypoint_heatmaps:
                assert len(sigma) == 2, 'sigma for keypoints must be given ' \
                                        'if `generate_keypoint_heatmaps` ' \
                                        'is True. e.g. sigma=(4, 2)'
            self.sigma = sigma

    def _get_heatmap_weights(self,
                             heatmaps,
                             fg_weight: float = 1,
                             bg_weight: float = 0):
        """Generate weight array for heatmaps.

        Pixels with a positive heatmap response get ``fg_weight``; all other
        pixels get ``bg_weight``.

        Args:
            heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
            fg_weight (float): Weight for foreground pixels. Defaults to 1.0
            bg_weight (float): Weight for background pixels. Defaults to 0.0

        Returns:
            np.ndarray: Heatmap weight array (float32) in the same shape as
            heatmaps
        """
        heatmap_weights = np.ones(heatmaps.shape, dtype=np.float32) * bg_weight
        heatmap_weights[heatmaps > 0] = fg_weight
        return heatmap_weights

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into root heatmaps and keypoint displacement
        fields. Note that the original keypoint coordinates should be in the
        input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K). ``None`` means all keypoints are visible.

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
                heatmaps are generated together, the shape is (K+1, H, W)
                with the root heatmap as the last channel
            - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
                which has the same shape as `heatmaps`
            - displacements (np.ndarray): The generated displacement fields in
                shape (K*D, H, W). The vector on each pixel represents the
                displacement of keypoints belonging to the associated instance
                from this pixel.
            - displacement_weights (np.ndarray): The pixel-wise weight for
                displacements which has the same shape as `displacements`
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # keypoint coordinates in heatmap
        _keypoints = keypoints / self.scale_factor

        # compute the root and scale of each instance
        roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
                                                 self.root_type)
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)

        # discard the small instances
        roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0

        # generate heatmaps
        heatmaps, _ = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=roots[:, None],
            keypoints_visible=roots_visible[:, None],
            sigma=self.sigma[0])
        heatmap_weights = self._get_heatmap_weights(
            heatmaps, bg_weight=self.background_weight)

        if self.generate_keypoint_heatmaps:
            keypoint_heatmaps, _ = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma[1])

            keypoint_heatmaps_weights = self._get_heatmap_weights(
                keypoint_heatmaps, bg_weight=self.background_weight)

            # keypoint channels first, root channel last (decode relies on it)
            heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
            heatmap_weights = np.concatenate(
                (keypoint_heatmaps_weights, heatmap_weights), axis=0)

        # generate displacements
        displacements, displacement_weights = \
            generate_displacement_heatmap(
                self.heatmap_size,
                _keypoints,
                keypoints_visible,
                roots,
                roots_visible,
                diagonal_lengths,
                self.sigma[0],
            )

        encoded = dict(
            heatmaps=heatmaps,
            heatmap_weights=heatmap_weights,
            displacements=displacements,
            displacement_weights=displacement_weights)

        return encoded

    def decode(
        self, heatmaps: Tensor, displacements: Tensor
    ) -> Tuple[Tensor, Tuple[Tensor, Optional[Tensor]]]:
        """Decode the keypoint coordinates from heatmaps and displacements. The
        decoded keypoint coordinates are in the input image space.

        Args:
            heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
                in shape (1, H, W) or (K+1, H, W); the root heatmap is the
                last channel
            displacements (Tensor): Encoded keypoints displacement fields
                in shape (K*D, H, W) with D == 2

        Returns:
            tuple:
            - keypoints (Tensor): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (tuple):
                - root_scores (Tensor): The root scores in shape (N, )
                - keypoint_scores (Tensor): The keypoint scores in
                    shape (N, K). If keypoint heatmaps are not generated,
                    `keypoint_scores` will be `None`
        """
        _k, h, w = displacements.shape
        k = _k // 2
        # reshape (not view) so non-contiguous inputs are accepted as well
        displacements = displacements.reshape(k, 2, h, w)

        # convert displacements to a dense keypoint prediction: each pixel's
        # (x, y) position plus its predicted offsets gives absolute coords
        xs = torch.arange(w, device=displacements.device)
        ys = torch.arange(h, device=displacements.device)
        regular_grid = torch.stack(
            (xs.view(1, w).expand(h, w), ys.view(h, 1).expand(h, w)),
            dim=0).to(displacements)
        posemaps = (regular_grid[None] + displacements).flatten(2)

        # find local maximum on root heatmap
        root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
                                               self.decode_nms_kernel)
        root_heatmap_peaks = root_heatmap_peaks.flatten()
        # topk raises if asked for more candidates than there are pixels
        num_candidates = min(self.decode_max_instances,
                             root_heatmap_peaks.numel())
        root_scores, pos_idx = root_heatmap_peaks.topk(num_candidates)
        mask = root_scores > self.decode_thr
        root_scores, pos_idx = root_scores[mask], pos_idx[mask]

        # (K, 2, N) -> (N, K, 2), still in heatmap space
        keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()

        if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
            # compute scores for each keypoint
            keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
        else:
            keypoint_scores = None

        # rescale [x, y] from heatmap space to input image space
        keypoints = keypoints * keypoints.new_tensor(self.scale_factor)

        return keypoints, (root_scores, keypoint_scores)

    def get_keypoint_scores(self, heatmaps: Tensor,
                            keypoints: Tensor) -> Tensor:
        """Calculate the keypoint scores with keypoints heatmaps and
        coordinates.

        Scores are bilinearly sampled from each keypoint heatmap at the
        (sub-pixel) keypoint location; out-of-range locations are clamped to
        the border.

        Args:
            heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
            keypoints (Tensor): Keypoint coordinates in heatmap space, in
                shape (N, K, D)

        Returns:
            Tensor: Keypoint scores in [N, K]
        """
        k, h, w = heatmaps.shape
        if keypoints.shape[0] == 0:
            # no decoded instances: nothing to sample
            return heatmaps.new_zeros((0, k))

        # Map pixel-centre coordinates to [-1, 1] so that pixel 0 -> -1 and
        # pixel (size - 1) -> 1. This is the ``align_corners=True``
        # convention of ``grid_sample``. ``max(.., 1)`` avoids a division by
        # zero for single-row / single-column heatmaps.
        keypoints = torch.stack((
            keypoints[..., 0] / max(w - 1, 1) * 2 - 1,
            keypoints[..., 1] / max(h - 1, 1) * 2 - 1,
        ),
                                dim=-1)
        # (N, K, 2) -> (K, 1, N, 2): one sampling grid per keypoint heatmap
        keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()

        keypoint_scores = torch.nn.functional.grid_sample(
            heatmaps.unsqueeze(1),
            keypoints,
            padding_mode='border',
            align_corners=True)
        # (K, 1, 1, N) -> (N, K)
        keypoint_scores = keypoint_scores.view(k, -1).transpose(0,
                                                                1).contiguous()

        return keypoint_scores