-
Notifications
You must be signed in to change notification settings - Fork 97
/
model.py
509 lines (430 loc) · 24.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
import attr
from typing import Optional, Text, List, Sequence, Tuple
from sleap.nn.config.utils import oneof
@attr.s(auto_attribs=True)
class SingleInstanceConfmapsHeadConfig:
    """Head configuration for single instance confidence maps.

    Single instance models assume that exactly one of each body part appears
    in the image, so every output channel contains a single peak that can be
    located with global peak finding. Do not use this head when more than one
    animal may be present in the image.

    Attributes:
        part_names: Names of the body parts (nodes) this head will produce,
            one output channel per part. If `None`, every body part in the
            skeleton is used.
        sigma: Spread of the Gaussians rendered into the confidence maps, in
            pixels of the model input image (i.e., after any input scaling).
            Smaller values localize peaks more precisely but are harder to
            learn due to their low density in image space; larger values are
            easier to learn but less precise at the peak coordinate.
        output_stride: Stride of the output confidence maps relative to the
            input image — the reciprocal of the output resolution (e.g., a
            stride of 2 yields maps at 0.5x input size). Increasing it can
            considerably speed up the model and reduce memory, at the cost of
            spatial resolution.
        offset_refinement: If `True`, the model additionally outputs an offset
            refinement map used for learned subpixel peak localization, which
            can improve accuracy at the cost of extra memory, training and
            inference time. If `False` (the default), subpixel localization
            can instead be approximated post-hoc with deterministic
            refinement, which needs no extra resources or training but may be
            less accurate than the learned variant.
    """

    part_names: Optional[List[Text]] = attr.ib(default=None)
    sigma: float = attr.ib(default=5.0)
    output_stride: int = attr.ib(default=1)
    offset_refinement: bool = attr.ib(default=False)
@attr.s(auto_attribs=True)
class CentroidsHeadConfig:
    """Head configuration for centroid confidence maps.

    Topdown models first detect instance centroids so that each instance can
    be cropped before the remaining body parts are predicted. One centroid
    may be present per instance, so their coordinates are recovered at
    inference time via local peak finding.

    Attributes:
        anchor_part: Name of the body part (node) used as the anchor point.
            When `None`, or whenever the anchor part is not visible in an
            instance, the midpoint of the bounding box of all visible points
            is used instead. A reliable anchor can substantially improve
            topdown accuracy by keeping the geometry of the body parts
            consistent relative to the image center.
        sigma: Spread of the Gaussians rendered into the confidence maps, in
            pixels of the model input image (i.e., after any input scaling).
            Smaller values localize peaks more precisely but are harder to
            learn due to their low density in image space; larger values are
            easier to learn but less precise at the peak coordinate.
        output_stride: Stride of the output confidence maps relative to the
            input image — the reciprocal of the output resolution (e.g., a
            stride of 2 yields maps at 0.5x input size). Increasing it can
            considerably speed up the model and reduce memory, at the cost of
            spatial resolution.
        offset_refinement: If `True`, the model additionally outputs an offset
            refinement map used for learned subpixel peak localization, which
            can improve accuracy at the cost of extra memory, training and
            inference time. If `False` (the default), subpixel localization
            can instead be approximated post-hoc with deterministic
            refinement, which needs no extra resources or training but may be
            less accurate than the learned variant.
    """

    anchor_part: Optional[Text] = attr.ib(default=None)
    sigma: float = attr.ib(default=5.0)
    output_stride: int = attr.ib(default=1)
    offset_refinement: bool = attr.ib(default=False)
@attr.s(auto_attribs=True)
class CenteredInstanceConfmapsHeadConfig:
    """Head configuration for centered instance confidence maps.

    Used in topdown multi-instance models, which assume that an instance is
    reliably centered in the cropped input image. These heads work well when
    centroids are easy to detect, since they can learn complex relationships
    between the geometry of body parts even under occlusion.

    The trade-off is a strong dependence on accurate instance-centered
    cropping — accuracy is effectively bounded by the centroid model — and,
    because one crop is evaluated per instance, inference cost grows linearly
    with the number of animals in the frame.

    Prefer this head when centroids are easy to detect (ideally anchored to a
    consistent body part) and when few animals occupy a small portion of the
    full frame.

    Attributes:
        anchor_part: Name of the body part (node) used as the anchor point.
            When `None`, or whenever the anchor part is not visible in an
            instance, the midpoint of the bounding box of all visible points
            is used instead. A reliable anchor can substantially improve
            topdown accuracy by keeping the geometry of the body parts
            consistent relative to the image center.
        part_names: Names of the body parts (nodes) this head will produce,
            one output channel per part. If `None`, every body part in the
            skeleton is used.
        sigma: Spread of the Gaussians rendered into the confidence maps, in
            pixels of the model input image (i.e., after any input scaling).
            Smaller values localize peaks more precisely but are harder to
            learn; larger values are easier to learn but less precise at the
            peak coordinate.
        output_stride: Stride of the output confidence maps relative to the
            input image — the reciprocal of the output resolution (e.g., a
            stride of 2 yields maps at 0.5x input size). Increasing it can
            considerably speed up the model and reduce memory, at the cost of
            spatial resolution.
        offset_refinement: If `True`, the model additionally outputs an offset
            refinement map used for learned subpixel peak localization, which
            can improve accuracy at the cost of extra memory, training and
            inference time. If `False` (the default), subpixel localization
            can instead be approximated post-hoc with deterministic
            refinement, which needs no extra resources or training but may be
            less accurate than the learned variant.
    """

    anchor_part: Optional[Text] = attr.ib(default=None)
    part_names: Optional[List[Text]] = attr.ib(default=None)
    sigma: float = attr.ib(default=5.0)
    output_stride: int = attr.ib(default=1)
    offset_refinement: bool = attr.ib(default=False)
@attr.s(auto_attribs=True)
class MultiInstanceConfmapsHeadConfig:
    """Head configuration for multi-instance confidence maps.

    Used in bottom-up multi-instance models that make no assumption about body
    part connectivity. Each body part type generates multiple local peaks (one
    per instance), which must be recovered via local peak finding.

    On its own this head detects multiple copies of each part type but carries
    no information about which detections belong to the same instance. When
    grouping is required, pair it with a head that provides connectivity or
    grouping information, e.g., part affinity fields.

    Prefer this head when multiple instances of each body part are present and
    either need no grouping or will be grouped using additional information.
    Unlike topdown models, which evaluate one crop per instance, the whole
    frame is evaluated once to find all peaks — constant scaling in the number
    of instances, which is especially beneficial with many animals per frame.

    Attributes:
        part_names: Names of the body parts (nodes) this head will produce,
            one output channel per part. If `None`, every body part in the
            skeleton is used.
        sigma: Spread of the Gaussians rendered into the confidence maps, in
            pixels of the model input image (i.e., after any input scaling).
            Smaller values localize peaks more precisely but are harder to
            learn; larger values are easier to learn but less precise at the
            peak coordinate.
        output_stride: Stride of the output confidence maps relative to the
            input image — the reciprocal of the output resolution (e.g., a
            stride of 2 yields maps at 0.5x input size). Increasing it can
            considerably speed up the model and reduce memory, at the cost of
            spatial resolution.
        loss_weight: Scalar weight for this head's loss term during training.
            Increase it to push the optimization toward improving this
            particular output in multi-head models.
        offset_refinement: If `True`, the model additionally outputs an offset
            refinement map used for learned subpixel peak localization, which
            can improve accuracy at the cost of extra memory, training and
            inference time. If `False` (the default), subpixel localization
            can instead be approximated post-hoc with deterministic
            refinement, which needs no extra resources or training but may be
            less accurate than the learned variant.
    """

    part_names: Optional[List[Text]] = attr.ib(default=None)
    sigma: float = attr.ib(default=5.0)
    output_stride: int = attr.ib(default=1)
    loss_weight: float = attr.ib(default=1.0)
    offset_refinement: bool = attr.ib(default=False)
@attr.s(auto_attribs=True)
class PartAffinityFieldsHeadConfig:
    """Head configuration for multi-instance part affinity fields.

    Used in bottom-up multi-instance models that need body part connectivity
    information to group multiple detections of each part type into distinct
    instances.

    Part affinity fields are an image-space encoding of the directed graph
    that defines the skeleton: pixels near the line (directed edge) between a
    pair of nodes of the same instance hold unit vectors pointing along the
    connection's direction. Comparing a candidate line against the average
    unit vector of the pixels beneath it yields a matching score used to
    associate candidate pairs of body part detections.

    Prefer this head when multiple instances of each body part are present and
    must be grouped into coherent instances. Like other bottom-up heads, the
    whole frame is evaluated once to find all peaks — constant scaling in the
    number of instances, in contrast to topdown models that evaluate one crop
    per instance — which is especially beneficial with many animals per frame.

    Attributes:
        edges: List of `(source_node, destination_node)` 2-tuples of text node
            names defining the directed edges of the graph. If `None`, every
            edge in the skeleton is used.
        sigma: Spread of the Gaussian that weighs the part affinity fields as
            a function of distance from the edge they represent, in pixels of
            the model input image (i.e., after any input scaling). Smaller
            values are more precise but harder to learn due to their low
            density in image space; larger values are easier to learn but
            less precise with respect to edge distance, and therefore less
            useful for disambiguating nearby, parallel edges.
        output_stride: Stride of the output PAFs relative to the input image —
            the reciprocal of the output resolution (e.g., a stride of 2
            yields PAFs at 0.5x input size). Increasing it can considerably
            speed up the model and reduce memory, at the cost of spatial
            resolution.
        loss_weight: Scalar weight for this head's loss term during training.
            Increase it to push the optimization toward improving this
            particular output in multi-head models.
    """

    edges: Optional[Sequence[Tuple[Text, Text]]] = attr.ib(default=None)
    sigma: float = attr.ib(default=15.0)
    output_stride: int = attr.ib(default=1)
    loss_weight: float = attr.ib(default=1.0)
@attr.s(auto_attribs=True)
class MultiInstanceConfig:
    """Combined configuration for multi-instance confidence map and PAF heads.

    Describes a multi-head model emitting both multi-instance confidence maps
    and part affinity fields, which together enable bottom-up multi-instance
    pose estimation — no instance cropping or centroids required.

    Attributes:
        confmaps: Part confidence map configuration (see
            `MultiInstanceConfmapsHeadConfig`).
        pafs: Part affinity fields configuration (see
            `PartAffinityFieldsHeadConfig`).
    """

    confmaps: MultiInstanceConfmapsHeadConfig = attr.ib(
        default=attr.Factory(MultiInstanceConfmapsHeadConfig)
    )
    pafs: PartAffinityFieldsHeadConfig = attr.ib(
        default=attr.Factory(PartAffinityFieldsHeadConfig)
    )
@oneof
@attr.s(auto_attribs=True)
class HeadsConfig:
    """Configurations related to the model output head type.

    Exactly one attribute of this class may be set; whichever one is set
    determines the model output type.

    Attributes:
        single_instance: An instance of `SingleInstanceConfmapsHeadConfig`.
        centroid: An instance of `CentroidsHeadConfig`.
        centered_instance: An instance of `CenteredInstanceConfmapsHeadConfig`.
        multi_instance: An instance of `MultiInstanceConfig`.
    """

    single_instance: Optional[SingleInstanceConfmapsHeadConfig] = attr.ib(default=None)
    centroid: Optional[CentroidsHeadConfig] = attr.ib(default=None)
    centered_instance: Optional[CenteredInstanceConfmapsHeadConfig] = attr.ib(
        default=None
    )
    multi_instance: Optional[MultiInstanceConfig] = attr.ib(default=None)
@attr.s(auto_attribs=True)
class LEAPConfig:
    """LEAP backbone configuration.

    Attributes:
        max_stride: Sets the number of downsampling blocks in the network;
            larger strides enlarge the receptive field at the cost of network
            size.
        output_stride: Sets the number of upsampling blocks in the network.
        filters: Base number of convolution filters in the network.
        filters_rate: Multiplier applied to the filter count at each block.
        up_interpolate: If `True`, upsample with bilinear interpolation rather
            than transposed convolutions, saving computation at a possible
            small cost in overall accuracy.
        stacks: Number of repeated stacks of the network (excluding the stem).
    """

    max_stride: int = attr.ib(default=8)  # controls the number of down blocks
    output_stride: int = attr.ib(default=1)  # controls the number of up blocks
    filters: int = attr.ib(default=64)
    filters_rate: float = attr.ib(default=2)
    up_interpolate: bool = attr.ib(default=False)
    stacks: int = attr.ib(default=1)
@attr.s(auto_attribs=True)
class UNetConfig:
    """UNet backbone configuration.

    Attributes:
        stem_stride: When not `None`, controls how many stem blocks perform
            the initial downsampling. Stem blocks enable learned downsampling
            that retains spatial information while shrinking large inputs.
        max_stride: Sets the number of downsampling blocks in the network;
            larger strides enlarge the receptive field at the cost of network
            size.
        output_stride: Sets the number of upsampling blocks in the network.
        filters: Base number of convolution filters in the network.
        filters_rate: Multiplier applied to the filter count at each block.
        middle_block: If `True`, insert an intermediate block between the
            downsampling and upsampling branches for extra processing of
            features at the largest receptive field size; this does not add a
            pooling step.
        up_interpolate: If `True`, upsample with bilinear interpolation rather
            than transposed convolutions, saving computation at a possible
            small cost in overall accuracy.
        stacks: Number of repeated stacks of the network (excluding the stem).
    """

    stem_stride: Optional[int] = attr.ib(default=None)
    max_stride: int = attr.ib(default=16)
    output_stride: int = attr.ib(default=1)
    filters: int = attr.ib(default=64)
    filters_rate: float = attr.ib(default=2)
    middle_block: bool = attr.ib(default=True)
    up_interpolate: bool = attr.ib(default=False)
    stacks: int = attr.ib(default=1)
@attr.s(auto_attribs=True)
class HourglassConfig:
    """Hourglass backbone configuration.

    Attributes:
        stem_stride: Controls how many stem blocks to use for initial downsampling.
            These are useful for learned downsampling that is able to retain spatial
            information while reducing large input image sizes.
        max_stride: Determines the number of downsampling blocks in the network,
            increasing receptive field size at the cost of network size.
        output_stride: Determines the number of upsampling blocks in the network.
        stem_filters: Number of filters in the stem blocks.
        filters: Base number of filters in the network.
        filter_increase: Constant to increase the number of filters by at each block.
        stacks: Number of repeated stacks of the network (excluding the stem).
    """

    stem_stride: int = 4
    max_stride: int = 64
    output_stride: int = 4
    stem_filters: int = 128
    filters: int = 256
    filter_increase: int = 128
    stacks: int = 3
@attr.s(auto_attribs=True)
class UpsamplingConfig:
    """Upsampling stack configuration.

    Attributes:
        method: If "transposed_conv", use a strided transposed convolution to perform
            learnable upsampling. If "interpolation", bilinear upsampling will be used
            instead.
        skip_connections: If "add", incoming feature tensors form skip connection with
            upsampled features via element-wise addition. Height/width are matched via
            stride and a 1x1 linear conv is applied if the channel counts do not match
            up. If "concatenate", the skip connection is formed via channel-wise
            concatenation. If None, skip connections will not be formed.
        block_stride: The striding of the upsampling *layer* (not tensor). This is
            typically set to 2, such that the tensor doubles in size with each
            upsampling step, but can be set higher to upsample to the desired
            `output_stride` directly in fewer steps.
        filters: Integer that specifies the base number of filters in each convolution
            layer. This will be scaled by the `filters_rate` at every upsampling step.
        filters_rate: Factor to scale the number of filters in the convolution layers
            after each upsampling step. If set to 1, the number of filters won't change.
        refine_convs: If greater than 0, specifies the number of 3x3 convolutions that
            will be applied after the upsampling step for refinement. These layers can
            serve the purpose of "mixing" the skip connection fused features, or to
            refine the current feature map after upsampling, which can help to prevent
            aliasing and checkerboard effects. If 0, no additional convolutions will be
            applied.
        batch_norm: Specifies whether batch norm should be applied after each
            convolution (and before the ReLU activation).
        transposed_conv_kernel_size: Size of the kernel for the transposed convolution.
            No effect if bilinear upsampling is used.
    """

    method: Text = attr.ib(
        default="interpolation",
        validator=attr.validators.in_(["interpolation", "transposed_conv"]),
    )
    skip_connections: Optional[Text] = attr.ib(
        default=None,
        validator=attr.validators.optional(attr.validators.in_(["add", "concatenate"])),
    )
    block_stride: int = 2
    filters: int = 64
    filters_rate: float = 1
    refine_convs: int = 2
    batch_norm: bool = True
    transposed_conv_kernel_size: int = 4
@attr.s(auto_attribs=True)
class ResNetConfig:
    """ResNet backbone configuration.

    Attributes:
        version: Which ResNetV1 variant to build; one of "ResNet50",
            "ResNet101" or "ResNet152".
        weights: How the network weights are initialized: "random" (no
            pretraining), "frozen" (pretrained and kept fixed), or "tunable"
            (pretrained and trainable).
        upsampling: When not `None`, an `UpsamplingConfig` defining an
            upsampling branch.
        max_stride: Stride of the backbone feature activations; should be
            <= 32.
        output_stride: Stride of the final output. When no upsampling branch
            is defined, this is achieved through dilated convolutions or
            reduced pooling in the backbone.
    """

    version: Text = attr.ib(
        validator=attr.validators.in_(["ResNet50", "ResNet101", "ResNet152"]),
        default="ResNet50",
    )
    weights: Text = attr.ib(
        validator=attr.validators.in_(["random", "frozen", "tunable"]),
        default="frozen",
    )
    upsampling: Optional[UpsamplingConfig] = attr.ib(default=None)
    max_stride: int = attr.ib(default=32)
    output_stride: int = attr.ib(default=4)
@attr.s(auto_attribs=True)
class PretrainedEncoderConfig:
    """Configuration for a UNet-style backbone with a pretrained encoder.

    Attributes:
        encoder: Name of the network architecture used as the encoder.
        pretrained: If `True`, initialize the encoder with weights pretrained
            on ImageNet.
        decoder_filters: Base number of filters for the upsampling blocks in
            the decoder.
        decoder_filters_rate: Multiplier applied to the filter count at each
            consecutive upsampling block in the decoder.
        output_stride: Stride of the final output.
    """

    encoder: Text = "efficientnetb0"
    pretrained: bool = True
    decoder_filters: int = 256
    decoder_filters_rate: float = 1.0
    output_stride: int = 2
@oneof
@attr.s(auto_attribs=True)
class BackboneConfig:
    """Configurations related to the model backbone.

    Only one field can be set and will determine which backbone architecture to use.

    Attributes:
        leap: A `LEAPConfig` instance.
        unet: A `UNetConfig` instance.
        hourglass: A `HourglassConfig` instance.
        resnet: A `ResNetConfig` instance.
        pretrained_encoder: A `PretrainedEncoderConfig` instance.
    """

    leap: Optional[LEAPConfig] = None
    unet: Optional[UNetConfig] = None
    hourglass: Optional[HourglassConfig] = None
    resnet: Optional[ResNetConfig] = None
    pretrained_encoder: Optional[PretrainedEncoderConfig] = None
@attr.s(auto_attribs=True)
class ModelConfig:
    """Configurations related to model architecture.

    Attributes:
        backbone: Configuration of the main network architecture.
        heads: Configuration of the model output heads.
    """

    backbone: BackboneConfig = attr.ib(default=attr.Factory(BackboneConfig))
    heads: HeadsConfig = attr.ib(default=attr.Factory(HeadsConfig))