Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge pull request #608 from knorth55/fcis-name-conv
Browse files Browse the repository at this point in the history
update fcis variable names
  • Loading branch information
yuyu2172 committed May 22, 2018
2 parents 6eb81ad + 3d93866 commit 59f805e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 87 deletions.
30 changes: 16 additions & 14 deletions chainercv/experimental/links/model/fcis/fcis.py
Expand Up @@ -119,11 +119,11 @@ def __call__(self, x, scale=1.):
Variable, Variable, Variable, array, array:
Returns tuple of five values listed below.
* **roi_cmask_scores**: Class-agnostic clipped mask scores for \
* **roi_ag_seg_scores**: Class-agnostic clipped mask scores for \
the proposed ROIs. Its shape is :math:`(R', 2, RH, RW)`
* **ag_locs**: Class-agnostic offsets and scalings for \
the proposed RoIs. Its shape is :math:`(R', 2, 4)`.
* **scores**: Class predictions for the proposed RoIs. \
* **roi_cls_scores**: Class predictions for the proposed RoIs. \
Its shape is :math:`(R', L + 1)`.
* **rois**: RoIs proposed by RPN. Its shape is \
:math:`(R', 4)`.
Expand All @@ -137,9 +137,10 @@ def __call__(self, x, scale=1.):
rpn_features, roi_features = self.extractor(x)
rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(
rpn_features, img_size, scale)
roi_cmask_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices = \
roi_ag_seg_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices = \
self.head(roi_features, rois, roi_indices, img_size)
return roi_cmask_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices
return roi_ag_seg_scores, roi_ag_locs, roi_cls_scores, \
rois, roi_indices

def prepare(self, img):
"""Preprocess an image for feature extraction.
Expand Down Expand Up @@ -257,27 +258,28 @@ def predict(self, imgs):
# inference
img_var = chainer.Variable(self.xp.array(img[None]))
scale = img_var.shape[3] / size[1]
roi_cmask_scores, _, roi_cls_scores, bboxes, _ = \
roi_ag_seg_scores, _, roi_cls_scores, bboxes, _ = \
self.__call__(img_var, scale)

# We are assuming that batch size is 1.
roi_cmask_score = roi_cmask_scores.array
roi_ag_seg_score = roi_ag_seg_scores.array
roi_cls_score = roi_cls_scores.array
bbox = bboxes / scale

# shape: (n_rois, 4)
bbox[:, 0::2] = self.xp.clip(bbox[:, 0::2], 0, size[0])
bbox[:, 1::2] = self.xp.clip(bbox[:, 1::2], 0, size[1])

roi_cmask_prob = F.softmax(roi_cmask_score).array[:, 1, :, :]
# shape: (n_roi, roi_size, roi_size)
roi_seg_prob = F.softmax(roi_ag_seg_score).array[:, 1, :, :]
roi_cls_prob = F.softmax(roi_cls_score).array

roi_cmask_prob = chainer.cuda.to_cpu(roi_cmask_prob)
roi_seg_prob = chainer.cuda.to_cpu(roi_seg_prob)
roi_cls_prob = chainer.cuda.to_cpu(roi_cls_prob)
bbox = chainer.cuda.to_cpu(bbox)

roi_cmask_prob, bbox, label, roi_cls_prob = mask_voting(
roi_cmask_prob, bbox, roi_cls_prob, size,
roi_seg_prob, bbox, label, roi_cls_prob = mask_voting(
roi_seg_prob, bbox, roi_cls_prob, size,
self.score_thresh, self.nms_thresh,
self.mask_merge_thresh, self.binary_thresh,
limit=self.limit, bg_label=0)
Expand All @@ -287,18 +289,18 @@ def predict(self, imgs):
keep_indices = np.where(
(height > self.min_drop_size) &
(width > self.min_drop_size))[0]
roi_cmask_prob = roi_cmask_prob[keep_indices]
roi_seg_prob = roi_seg_prob[keep_indices]
bbox = bbox[keep_indices]
label = label[keep_indices]
roi_cls_prob = roi_cls_prob[keep_indices]

mask = np.zeros(
(len(roi_cmask_prob), size[0], size[1]), dtype=np.bool)
for i, (roi_cmsk_pb, bb) in enumerate(zip(roi_cmask_prob, bbox)):
(len(roi_seg_prob), size[0], size[1]), dtype=np.bool)
for i, (roi_seg_pb, bb) in enumerate(zip(roi_seg_prob, bbox)):
bb = np.round(bb).astype(np.int32)
y_min, x_min, y_max, x_max = bb
roi_msk_pb = resize(
roi_cmsk_pb.astype(np.float32)[None],
roi_seg_pb.astype(np.float32)[None],
(y_max - y_min, x_max - x_min))
roi_msk = (roi_msk_pb > self.binary_thresh)[0]
mask[i, y_min:y_max, x_min:x_max] = roi_msk
Expand Down
23 changes: 12 additions & 11 deletions chainercv/experimental/links/model/fcis/fcis_resnet101.py
Expand Up @@ -271,7 +271,7 @@ def __call__(self, x, rois, roi_indices, img_size):
h_ag_loc = self.ag_loc(h)

# PSROI pooling and regression
roi_cmask_scores, roi_ag_locs, roi_cls_scores = self._pool(
roi_ag_seg_scores, roi_ag_locs, roi_cls_scores = self._pool(
h_cls_seg, h_ag_loc, rois, roi_indices)
if self.iter2:
# 2nd Iteration
Expand All @@ -287,29 +287,30 @@ def __call__(self, x, rois, roi_indices, img_size):
rois2[:, 1::2] = self.xp.clip(rois2[:, 1::2], 0, img_size[1])

# PSROI pooling and regression
roi_cmask_scores2, roi_ag_locs2, roi_cls_scores2 = self._pool(
roi_ag_seg_scores2, roi_ag_locs2, roi_cls_scores2 = self._pool(
h_cls_seg, h_ag_loc, rois2, roi_indices)

# concat 1st and 2nd iteration results
rois = self.xp.concatenate((rois, rois2))
roi_indices = self.xp.concatenate((roi_indices, roi_indices))
roi_cmask_scores = F.concat(
(roi_cmask_scores, roi_cmask_scores2), axis=0)
roi_ag_seg_scores = F.concat(
(roi_ag_seg_scores, roi_ag_seg_scores2), axis=0)
roi_ag_locs = F.concat(
(roi_ag_locs, roi_ag_locs2), axis=0)
roi_cls_scores = F.concat(
(roi_cls_scores, roi_cls_scores2), axis=0)
return roi_cmask_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices
return roi_ag_seg_scores, roi_ag_locs, roi_cls_scores, \
rois, roi_indices

def _pool(
self, h_cls_seg, h_ag_loc, rois, roi_indices):
# PSROI Pooling
# shape: (n_roi, n_class*2, roi_size, roi_size)
roi_seg_scores = psroi_pooling_2d(
# shape: (n_roi, n_class, 2, roi_size, roi_size)
roi_cls_ag_seg_scores = psroi_pooling_2d(
h_cls_seg, rois, roi_indices,
self.n_class * 2, self.roi_size, self.roi_size,
self.spatial_scale, self.group_size)
roi_seg_scores = roi_seg_scores.reshape(
roi_cls_ag_seg_scores = roi_cls_ag_seg_scores.reshape(
(-1, self.n_class, 2, self.roi_size, self.roi_size))

# shape: (n_roi, 2*4, roi_size, roi_size)
Expand All @@ -320,7 +321,7 @@ def _pool(

# shape: (n_roi, n_class)
roi_cls_scores = _global_average_pooling_2d(
F.max(roi_seg_scores, axis=2))
F.max(roi_cls_ag_seg_scores, axis=2))

# Bbox Regression
# shape: (n_roi, 2*4)
Expand All @@ -331,10 +332,10 @@ def _pool(
# shape: (n_roi, n_class, 2, roi_size, roi_size)
max_cls_indices = roi_cls_scores.array.argmax(axis=1)
# shape: (n_roi, 2, roi_size, roi_size)
roi_cmask_scores = roi_seg_scores[
roi_ag_seg_scores = roi_cls_ag_seg_scores[
self.xp.arange(len(max_cls_indices)), max_cls_indices]

return roi_cmask_scores, roi_ag_locs, roi_cls_scores
return roi_ag_seg_scores, roi_ag_locs, roi_cls_scores


def _global_average_pooling_2d(x):
Expand Down
64 changes: 32 additions & 32 deletions chainercv/experimental/links/model/fcis/utils/mask_voting.py
Expand Up @@ -6,22 +6,22 @@


def _mask_aggregation(
bbox, cmask_prob, cmask_weight,
bbox, seg_prob, seg_weight,
size, binary_thresh
):
assert bbox.shape[0] == len(cmask_prob)
assert bbox.shape[0] == cmask_weight.shape[0]
assert bbox.shape[0] == len(seg_prob)
assert bbox.shape[0] == seg_weight.shape[0]

aggregated_msk = np.zeros(size, dtype=np.float32)
for bb, cmsk_pb, cmsk_w in zip(bbox, cmask_prob, cmask_weight):
for bb, seg_pb, seg_w in zip(bbox, seg_prob, seg_weight):
bb = np.round(bb).astype(np.int32)
y_min, x_min, y_max, x_max = bb
if y_max - y_min > 0 and x_max - x_min > 0:
cmsk_pb = resize(
cmsk_pb.astype(np.float32)[None],
seg_pb = resize(
seg_pb.astype(np.float32)[None],
(y_max - y_min, x_max - x_min))
cmsk_m = (cmsk_pb >= binary_thresh).astype(np.float32)[0]
aggregated_msk[y_min:y_max, x_min:x_max] += cmsk_m * cmsk_w
seg_m = (seg_pb >= binary_thresh).astype(np.float32)[0]
aggregated_msk[y_min:y_max, x_min:x_max] += seg_m * seg_w

y_indices, x_indices = np.where(aggregated_msk >= binary_thresh)
if len(y_indices) == 0 or len(x_indices) == 0:
Expand All @@ -40,7 +40,7 @@ def _mask_aggregation(


def mask_voting(
roi_cmask_prob, bbox, roi_cls_prob, size,
seg_prob, bbox, cls_prob, size,
score_thresh, nms_thresh,
mask_merge_thresh, binary_thresh,
limit=100, bg_label=0
Expand All @@ -54,13 +54,13 @@ def mask_voting(
predicted as the same object class.
Here are notations used.
* :math:`R'` is the total number of RoIs produced across batches.
* :math:`R` is the total number of RoIs produced in one image.
* :math:`L` is the number of classes excluding the background.
* :math:`RH` is the height of pooled image.
* :math:`RW` is the height of pooled image.
Args:
roi_cmask_prob (array): A mask probability array whose shape is
seg_prob (array): A mask probability array whose shape is
:math:`(R, RH, RW)`.
bbox (array): A bounding box array whose shape is
:math:`(R, 4)`.
Expand All @@ -78,7 +78,7 @@ def mask_voting(
Returns:
array, array, array, array:
* **v_cmask_prob**: Merged mask probability. Its shapes is \
* **v_seg_prob**: Merged mask probability. Its shapes is \
:math:`(N, RH, RW)`.
* **v_bbox**: Bounding boxes for the merged masks. Its shape is \
:math:`(N, 4)`.
Expand All @@ -89,10 +89,10 @@ def mask_voting(
"""

roi_cmask_size = roi_cmask_prob.shape[1:]
n_class = roi_cls_prob.shape[1]
seg_size = seg_prob.shape[1:]
n_class = cls_prob.shape[1]

v_cmask_prob = []
v_seg_prob = []
v_bbox = []
v_label = []
v_cls_prob = []
Expand All @@ -105,7 +105,7 @@ def mask_voting(
if label == bg_label:
continue
# non maximum suppression
score_l = roi_cls_prob[:, label]
score_l = cls_prob[:, label]
keep_indices = non_maximum_suppression(
bbox, nms_thresh, score_l)
bbox_l = bbox[keep_indices]
Expand All @@ -127,43 +127,43 @@ def mask_voting(
bbox_l = bbox_l[keep_indices]
score_l = score_l[keep_indices]

v_cmask_prob_l = []
v_seg_prob_l = []
v_bbox_l = []
v_score_l = []

for i, bb in enumerate(bbox_l):
iou = bbox_iou(bbox, bb[np.newaxis, :])
keep_indices = np.where(iou >= mask_merge_thresh)[0]
cmask_weight = roi_cls_prob[keep_indices, label]
cmask_weight = cmask_weight / cmask_weight.sum()
cmask_prob_i = roi_cmask_prob[keep_indices]
seg_weight = cls_prob[keep_indices, label]
seg_weight = seg_weight / seg_weight.sum()
seg_prob_i = seg_prob[keep_indices]
bbox_i = bbox[keep_indices]
m_cmask, m_bbox = _mask_aggregation(
bbox_i, cmask_prob_i, cmask_weight, size, binary_thresh)
if m_cmask is not None and m_bbox is not None:
m_cmask = resize(m_cmask, roi_cmask_size)
m_cmask = np.clip(m_cmask, 0.0, 1.0)
v_cmask_prob_l.append(m_cmask)
m_seg, m_bbox = _mask_aggregation(
bbox_i, seg_prob_i, seg_weight, size, binary_thresh)
if m_seg is not None and m_bbox is not None:
m_seg = resize(m_seg, seg_size)
m_seg = np.clip(m_seg, 0.0, 1.0)
v_seg_prob_l.append(m_seg)
v_bbox_l.append(m_bbox)
v_score_l.append(score_l[i])

if len(v_cmask_prob_l) > 0:
if len(v_seg_prob_l) > 0:
v_label_l = np.repeat(
label - 1, len(v_score_l)).astype(np.int32)

v_cmask_prob += v_cmask_prob_l
v_seg_prob += v_seg_prob_l
v_bbox += v_bbox_l
v_label.append(v_label_l)
v_cls_prob.append(v_score_l)

if len(v_cmask_prob) > 0:
v_cmask_prob = np.concatenate(v_cmask_prob)
if len(v_seg_prob) > 0:
v_seg_prob = np.concatenate(v_seg_prob)
v_bbox = np.concatenate(v_bbox)
v_label = np.concatenate(v_label)
v_cls_prob = np.concatenate(v_cls_prob)
else:
v_cmask_prob = np.empty((0, roi_cmask_size[0], roi_cmask_size[1]))
v_seg_prob = np.empty((0, seg_size[0], seg_size[1]))
v_bbox = np.empty((0, 4))
v_label = np.empty((0, ))
v_cls_prob = np.empty((0, ))
return v_cmask_prob, v_bbox, v_label, v_cls_prob
return v_seg_prob, v_bbox, v_label, v_cls_prob
Expand Up @@ -43,16 +43,16 @@ def __call__(self, x, rois, roi_indices, img_size):
_random_array(self.xp, (n_roi, 2, 4)))
# For each bbox, the score for a selected class is
# overwhelmingly higher than the scores for the other classes.
seg_scores = chainer.Variable(
ag_seg_scores = chainer.Variable(
_random_array(
self.xp, (n_roi, 2, self.roi_size, self.roi_size)))
score_idx = np.random.randint(
low=0, high=self.n_class, size=(n_roi,))
scores = self.xp.zeros((n_roi, self.n_class), dtype=np.float32)
scores[np.arange(n_roi), score_idx] = 100
scores = chainer.Variable(scores)
cls_scores = self.xp.zeros((n_roi, self.n_class), dtype=np.float32)
cls_scores[np.arange(n_roi), score_idx] = 100
cls_scores = chainer.Variable(cls_scores)

return seg_scores, ag_locs, scores, rois, roi_indices
return ag_seg_scores, ag_locs, cls_scores, rois, roi_indices


class DummyRegionProposalNetwork(chainer.Chain):
Expand Down Expand Up @@ -119,22 +119,22 @@ def check_call(self):
xp = self.link.xp

x1 = chainer.Variable(_random_array(xp, (1, 3, 600, 800)))
roi_seg_scores, roi_ag_locs, roi_scores, rois, roi_indices = \
roi_ag_seg_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices = \
self.link(x1)

self.assertIsInstance(roi_seg_scores, chainer.Variable)
self.assertIsInstance(roi_seg_scores.array, xp.ndarray)
self.assertIsInstance(roi_ag_seg_scores, chainer.Variable)
self.assertIsInstance(roi_ag_seg_scores.array, xp.ndarray)
self.assertEqual(
roi_seg_scores.shape,
roi_ag_seg_scores.shape,
(self.n_roi, 2, self.roi_size, self.roi_size))

self.assertIsInstance(roi_ag_locs, chainer.Variable)
self.assertIsInstance(roi_ag_locs.array, xp.ndarray)
self.assertEqual(roi_ag_locs.shape, (self.n_roi, 2, 4))

self.assertIsInstance(roi_scores, chainer.Variable)
self.assertIsInstance(roi_scores.array, xp.ndarray)
self.assertEqual(roi_scores.shape, (self.n_roi, self.n_class))
self.assertIsInstance(roi_cls_scores, chainer.Variable)
self.assertIsInstance(roi_cls_scores.array, xp.ndarray)
self.assertEqual(roi_cls_scores.shape, (self.n_roi, self.n_class))

self.assertIsInstance(rois, xp.ndarray)
self.assertEqual(rois.shape, (self.n_roi, 4))
Expand Down
Expand Up @@ -42,27 +42,27 @@ def check_call(self):
low=-1., high=1.,
size=(self.B, 3, feat_size[0] * 16, feat_size[1] * 16)
).astype(np.float32))
roi_seg_scores, roi_ag_locs, roi_scores, rois, roi_indices = \
roi_ag_seg_scores, roi_ag_locs, roi_cls_scores, rois, roi_indices = \
self.link(x)

n_roi = roi_seg_scores.shape[0]
n_roi = roi_ag_seg_scores.shape[0]
if self.train:
self.assertGreaterEqual(self.B * self.n_train_post_nms, n_roi)
else:
self.assertGreaterEqual(self.B * self.n_test_post_nms * 2, n_roi)

self.assertIsInstance(roi_seg_scores, chainer.Variable)
self.assertIsInstance(roi_seg_scores.array, xp.ndarray)
self.assertIsInstance(roi_ag_seg_scores, chainer.Variable)
self.assertIsInstance(roi_ag_seg_scores.array, xp.ndarray)
self.assertEqual(
roi_seg_scores.shape, (n_roi, 2, 21, 21))
roi_ag_seg_scores.shape, (n_roi, 2, 21, 21))

self.assertIsInstance(roi_ag_locs, chainer.Variable)
self.assertIsInstance(roi_ag_locs.array, xp.ndarray)
self.assertEqual(roi_ag_locs.shape, (n_roi, 2, 4))

self.assertIsInstance(roi_scores, chainer.Variable)
self.assertIsInstance(roi_scores.array, xp.ndarray)
self.assertEqual(roi_scores.shape, (n_roi, self.n_class))
self.assertIsInstance(roi_cls_scores, chainer.Variable)
self.assertIsInstance(roi_cls_scores.array, xp.ndarray)
self.assertEqual(roi_cls_scores.shape, (n_roi, self.n_class))

self.assertIsInstance(rois, xp.ndarray)
self.assertEqual(rois.shape, (n_roi, 4))
Expand Down

0 comments on commit 59f805e

Please sign in to comment.