-
Notifications
You must be signed in to change notification settings - Fork 7.3k
/
mask_head.py
206 lines (171 loc) · 8.3 KB
/
mask_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry
# Module-level registry of ROI mask head implementations. `build_mask_head`
# looks up the entry named by `cfg.MODEL.ROI_MASK_HEAD.NAME` in this registry
# and calls it with `(cfg, input_shape)`; classes add themselves with the
# `@ROI_MASK_HEAD_REGISTRY.register()` decorator.
ROI_MASK_HEAD_REGISTRY = Registry("ROI_MASK_HEAD")
ROI_MASK_HEAD_REGISTRY.__doc__ = """
Registry for mask heads, which predicts instance masks given
per-region features.
The registered object will be called with `obj(cfg, input_shape)`.
"""
def mask_rcnn_loss(pred_mask_logits, instances):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    As a side effect, three training statistics (accuracy, false-positive rate
    and false-negative rate at a 0.5 probability threshold) are written to the
    current event storage whenever at least one ground-truth mask is present.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or
            (B, 1, Hmask, Wmask) for class-specific or class-agnostic, where B is
            the total number of predicted masks in all images, C is the number of
            foreground classes, and Hmask, Wmask are the height and width of the
            mask predictions. The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number
            of images in the batch. These instances are in 1:1 correspondence with
            the pred_mask_logits. The ground-truth labels (class, box, mask, ...)
            associated with each instance are stored in fields.

    Returns:
        mask_loss (Tensor): A scalar tensor containing the loss.

    Raises:
        AssertionError: if the predicted masks are not square.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        # Rasterize each ground-truth mask inside its proposal box at the
        # resolution of the prediction.
        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor, mask_side_len
        ).to(device=pred_mask_logits.device)
        # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
        gt_masks.append(gt_masks_per_image)

    if not gt_masks:
        # No ground truth in the whole batch: return a zero that is still
        # computed from the logits rather than a detached constant.
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        gt_classes = cat(gt_classes, dim=0)
        # Build the row index on the same device as the class index, exactly as
        # mask_rcnn_inference does, instead of on the default (CPU) device.
        indices = torch.arange(total_num_masks, device=gt_classes.device)
        # Keep, for every mask, only the channel of its ground-truth class.
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        # Here we allow gt_masks to be float as well (depend on the implementation of rasterize())
        gt_masks_bool = gt_masks > 0.5

    # Log the training accuracy (using gt classes and 0.5 threshold).
    # A logit > 0 corresponds to a probability > 0.5.
    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    # Wrong pixels among ground-truth background pixels.
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0
    )
    # Wrong pixels among ground-truth foreground pixels.
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)

    mask_loss = F.binary_cross_entropy_with_logits(
        pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="mean"
    )
    return mask_loss
def mask_rcnn_inference(pred_mask_logits, pred_instances):
    """
    Turn mask logits into soft foreground-probability masks and attach them to
    the predicted instances.

    For class-specific predictions only the channel matching each instance's
    predicted class is kept. The result is stored in a new "pred_masks" field
    on every element of pred_instances.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or
            (B, 1, Hmask, Wmask) for class-specific or class-agnostic, where B is
            the total number of predicted masks in all images, C is the number of
            foreground classes, and Hmask, Wmask are the height and width of the
            mask predictions. The values are logits.
        pred_instances (list[Instances]): A list of N Instances, where N is the
            number of images in the batch. Each Instances must have field
            "pred_classes" (only read in the class-specific case).

    Returns:
        None. Each element of pred_instances gains a "pred_masks" field of shape
        (Bi, 1, Hmask, Wmask), where Bi is the number of instances in that image.
        The masks are soft (non-quantized) and at the resolution predicted by the
        network; resizing them to the original image resolution and/or binarizing
        them is left to the caller.
    """
    if pred_mask_logits.size(1) == 1:
        # Class-agnostic: the single channel is already the mask.
        probs = pred_mask_logits.sigmoid()
    else:
        # Class-specific: pick, for every box, the channel of its predicted class.
        mask_count = pred_mask_logits.shape[0]
        predicted_classes = cat([inst.pred_classes for inst in pred_instances])
        rows = torch.arange(mask_count, device=predicted_classes.device)
        chosen_logits = pred_mask_logits[rows, predicted_classes]
        # Re-insert the channel axis so both branches yield (B, 1, Hmask, Wmask).
        probs = chosen_logits[:, None].sigmoid()

    # Hand each image the slice of masks that belongs to its boxes.
    boxes_per_image = [len(inst) for inst in pred_instances]
    per_image_probs = probs.split(boxes_per_image, dim=0)
    for inst, image_probs in zip(pred_instances, per_image_probs):
        inst.pred_masks = image_probs  # (Bi, 1, Hmask, Wmask)
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsampleHead(nn.Module):
    """
    A mask head built from a stack of 3x3 conv layers, followed by a stride-2
    `ConvTranspose2d` upsample layer and a 1x1 conv that outputs mask logits.
    """

    def __init__(self, cfg, input_shape: ShapeSpec):
        """
        Build the head from the config.

        Args:
            cfg: config object. The following entries are read:
                MODEL.ROI_HEADS.NUM_CLASSES: number of output channels when the
                    head is class-specific.
                MODEL.ROI_MASK_HEAD.CONV_DIM: the dimension of the conv layers.
                MODEL.ROI_MASK_HEAD.NORM: normalization for the conv layers.
                MODEL.ROI_MASK_HEAD.NUM_CONV: the number of conv layers.
                MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: if true, predict a single
                    mask channel instead of one per class.
            input_shape (ShapeSpec): only its `channels` attribute is used.
        """
        super().__init__()

        roi_heads_cfg = cfg.MODEL.ROI_HEADS
        mask_head_cfg = cfg.MODEL.ROI_MASK_HEAD
        class_count = roi_heads_cfg.NUM_CLASSES
        conv_dims = mask_head_cfg.CONV_DIM
        self.norm = mask_head_cfg.NORM
        conv_count = mask_head_cfg.NUM_CONV
        class_agnostic = mask_head_cfg.CLS_AGNOSTIC_MASK

        # The layers are registered as submodules `mask_fcn1`, `mask_fcn2`, ...
        # and also kept in a plain list so `forward` can iterate them in order.
        self.conv_norm_relus = []
        # Channel count fed to the next layer; becomes conv_dims after the
        # first conv and stays at the input width when there are no convs.
        in_channels = input_shape.channels
        for position in range(1, conv_count + 1):
            block = Conv2d(
                in_channels,
                conv_dims,
                kernel_size=3,
                stride=1,
                padding=1,
                # A bias is only used when no normalization is configured.
                bias=not self.norm,
                norm=get_norm(self.norm, conv_dims),
                activation=F.relu,
            )
            self.add_module("mask_fcn{}".format(position), block)
            self.conv_norm_relus.append(block)
            in_channels = conv_dims

        self.deconv = ConvTranspose2d(
            in_channels,
            conv_dims,
            kernel_size=2,
            stride=2,
            padding=0,
        )

        if cls_agnostic := class_agnostic:
            mask_channels = 1
        else:
            mask_channels = class_count
        self.predictor = Conv2d(conv_dims, mask_channels, kernel_size=1, stride=1, padding=0)

        # MSRA fill for every conv layer and the upsample layer, in order.
        for module in (*self.conv_norm_relus, self.deconv):
            weight_init.c2_msra_fill(module)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.predictor.weight, std=0.001)
        predictor_bias = self.predictor.bias
        if predictor_bias is not None:
            nn.init.constant_(predictor_bias, 0)

    def forward(self, x):
        """
        Run the conv stack, the upsample layer with ReLU, then the predictor.

        Args:
            x: per-region input features.

        Returns:
            The output of the final 1x1 predictor conv (mask logits).
        """
        features = x
        for block in self.conv_norm_relus:
            features = block(features)
        upsampled = F.relu(self.deconv(features))
        return self.predictor(upsampled)
def build_mask_head(cfg, input_shape):
    """
    Instantiate the mask head registered under `cfg.MODEL.ROI_MASK_HEAD.NAME`.

    The name is looked up in `ROI_MASK_HEAD_REGISTRY` and the registered object
    is called with `(cfg, input_shape)`; its result is returned.
    """
    head_factory = ROI_MASK_HEAD_REGISTRY.get(cfg.MODEL.ROI_MASK_HEAD.NAME)
    return head_factory(cfg, input_shape)