Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
fix faster r-cnn
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Jun 6, 2017
1 parent 1e2366a commit 7c0bf7f
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 78 deletions.
6 changes: 3 additions & 3 deletions chainercv/links/model/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def predict(self, imgs):
* **bboxes**: A list of float arrays of shape :math:`(R, 4)`, \
where :math:`R` is the number of bounding boxes in a image. \
Each bouding box is organized by \
:obj:`(x_min, y_min, x_max, y_max)` \
:obj:`(y_min, x_min, y_max, x_max)` \
in the second axis.
* **labels** : A list of integer arrays of shape :math:`(R,)`. \
Each value indicates the class of the bounding box. \
Expand Down Expand Up @@ -305,9 +305,9 @@ def predict(self, imgs):
cls_bbox = cls_bbox.reshape(-1, self.n_class * 4)
# clip bounding box
cls_bbox[:, slice(0, 4, 2)] = self.xp.clip(
cls_bbox[:, slice(0, 4, 2)], 0, W / scale)
cls_bbox[:, slice(0, 4, 2)], 0, H / scale)
cls_bbox[:, slice(1, 4, 2)] = self.xp.clip(
cls_bbox[:, slice(1, 4, 2)], 0, H / scale)
cls_bbox[:, slice(1, 4, 2)], 0, W / scale)

prob = F.softmax(roi_score).data

Expand Down
18 changes: 14 additions & 4 deletions chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class FasterRCNNVGG16(FasterRCNN):
'voc07': {
'n_fg_class': 20,
'url': 'https://github.com/yuyu2172/share-weights/releases/'
'download/0.0.2/faster_rcnn_vgg16_voc07_2017_05_24.npz'
'download/0.0.3/faster_rcnn_vgg16_voc07_2017_06_06.npz'
}
}
feat_stride = 16
Expand Down Expand Up @@ -227,10 +227,11 @@ def __call__(self, x, rois, roi_indices, test=True):
"""
roi_indices = roi_indices.astype(np.float32)
rois = self.xp.concatenate(
indices_and_rois = self.xp.concatenate(
(roi_indices[:, None], rois), axis=1)
pool = F.roi_pooling_2d(
x, rois, self.roi_size, self.roi_size, self.spatial_scale)
pool = _roi_pooling_2d_yx(
x, indices_and_rois, self.roi_size, self.roi_size,
self.spatial_scale)

fc6 = _relu(self.fc6(pool))
fc7 = _relu(self.fc7(fc6))
Expand Down Expand Up @@ -291,5 +292,14 @@ def __call__(self, x, test=True):
return h


def _roi_pooling_2d_yx(x, indices_and_rois, outh, outw, spatial_scale):
xy_indices_and_rois = indices_and_rois.copy()
xy_indices_and_rois[:, [1, 3]] = indices_and_rois[:, [2, 4]]
xy_indices_and_rois[:, [2, 4]] = indices_and_rois[:, [1, 3]]
pool = F.roi_pooling_2d(
x, xy_indices_and_rois, outh, outw, spatial_scale)
return pool


def _max_pooling_2d(x):
return F.max_pooling_2d(x, ksize=2)
12 changes: 6 additions & 6 deletions chainercv/links/model/faster_rcnn/region_proposal_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def __call__(self, x, img_size, scale=1., test=True):
"""
n, _, hh, ww = x.shape
anchor = _enumerate_shifted_anchor(
self.xp.array(self.anchor_base), self.feat_stride, ww, hh)
n_anchor = anchor.shape[0] // (ww * hh)
self.xp.array(self.anchor_base), self.feat_stride, hh, ww)
n_anchor = anchor.shape[0] // (hh * ww)
h = F.relu(self.conv1(x))

rpn_locs = self.loc(h)
Expand Down Expand Up @@ -139,19 +139,19 @@ def __call__(self, x, img_size, scale=1., test=True):
return rpn_locs, rpn_scores, rois, roi_indices, anchor


def _enumerate_shifted_anchor(anchor_base, feat_stride, width, height):
def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width):
# Enumerate all shifted anchors:
#
# add A anchors (1, A, 4) to
# cell K shifts (K, 1, 4) to get
# shift anchors (K, A, 4)
# reshape to (K*A, 4) shifted anchors
xp = cuda.get_array_module(anchor_base)
shift_x = xp.arange(0, width * feat_stride, feat_stride)
shift_y = xp.arange(0, height * feat_stride, feat_stride)
shift_x = xp.arange(0, width * feat_stride, feat_stride)
shift_x, shift_y = xp.meshgrid(shift_x, shift_y)
shift = xp.stack((shift_x.ravel(), shift_y.ravel(),
shift_x.ravel(), shift_y.ravel()), axis=1)
shift = xp.stack((shift_y.ravel(), shift_x.ravel(),
shift_y.ravel(), shift_x.ravel()), axis=1)

A = anchor_base.shape[0]
K = shift.shape[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _get_inside_index(anchor, H, W):
index_inside = xp.where(
(anchor[:, 0] >= 0) &
(anchor[:, 1] >= 0) &
(anchor[:, 2] <= W) & # width
(anchor[:, 3] <= H) # height
(anchor[:, 2] <= H) &
(anchor[:, 3] <= W)
)[0]
return index_inside
41 changes: 21 additions & 20 deletions chainercv/links/model/faster_rcnn/utils/bbox2loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ def bbox2loc(src_bbox, dst_bbox):
Given bounding boxes, this function computes offsets and scales
to match the source bounding boxes to the target bounding boxes.
Mathematcially, given a bounding box whose center is :math:`p_x, p_y` and
size :math:`p_w, p_h` and the target bounding box whose center is
:math:`g_x, g_y` and size :math:`g_w, g_h`, the offsets and scales
:math:`t_x, t_y, t_w, t_h` can be computed by the following formulas.
Mathematcially, given a bounding box whose center is
:math:`(y, x) = p_y, p_x` and
size :math:`p_h, p_w` and the target bounding box whose center is
:math:`g_y, g_x` and size :math:`g_h, g_w`, the offsets and scales
:math:`t_y, t_x, t_h, t_w` can be computed by the following formulas.
* :math:`t_x = \\frac{(g_x - p_x)} {p_w}`
* :math:`t_y = \\frac{(g_y - p_y)} {p_h}`
* :math:`t_w = \\log(\\frac{g_w} {p_w})`
* :math:`t_x = \\frac{(g_x - p_x)} {p_w}`
* :math:`t_h = \\log(\\frac{g_h} {p_h})`
* :math:`t_w = \\log(\\frac{g_w} {p_w})`
The output is same type as the type of the inputs.
The encoding formulas are used in works such as R-CNN [#]_.
Expand All @@ -26,35 +27,35 @@ def bbox2loc(src_bbox, dst_bbox):
Args:
src_bbox (array): An image coordinate array whose shape is
:math:`(R, 4)`. :math:`R` is the number of bounding boxes.
These coordinates are used to compute :math:`p_x, p_y, p_w, p_h`.
These coordinates are used to compute :math:`p_y, p_x, p_h, p_w`.
dst_bbox (array): An image coordinate array whose shape is
:math:`(R, 4)`.
These coordinates are used to compute :math:`g_x, g_y, g_w, g_h`.
These coordinates are used to compute :math:`g_y, g_x, g_h, g_w`.
Returns:
array:
Bounding box offsets and scales from :obj:`src_bbox` \
to :obj:`dst_bbox`. \
This has shape :math:`(R, 4)`.
The second axis contains four values :math:`t_x, t_y, t_w, t_h`.
The second axis contains four values :math:`t_y, t_x, t_h, t_w`.
"""
xp = cuda.get_array_module(src_bbox)

width = src_bbox[:, 2] - src_bbox[:, 0]
height = src_bbox[:, 3] - src_bbox[:, 1]
ctr_x = src_bbox[:, 0] + 0.5 * width
ctr_y = src_bbox[:, 1] + 0.5 * height
height = src_bbox[:, 2] - src_bbox[:, 0]
width = src_bbox[:, 3] - src_bbox[:, 1]
ctr_y = src_bbox[:, 0] + 0.5 * height
ctr_x = src_bbox[:, 1] + 0.5 * width

base_width = dst_bbox[:, 2] - dst_bbox[:, 0]
base_height = dst_bbox[:, 3] - dst_bbox[:, 1]
base_ctr_x = dst_bbox[:, 0] + 0.5 * base_width
base_ctr_y = dst_bbox[:, 1] + 0.5 * base_height
base_height = dst_bbox[:, 2] - dst_bbox[:, 0]
base_width = dst_bbox[:, 3] - dst_bbox[:, 1]
base_ctr_y = dst_bbox[:, 0] + 0.5 * base_height
base_ctr_x = dst_bbox[:, 1] + 0.5 * base_width

dx = (base_ctr_x - ctr_x) / width
dy = (base_ctr_y - ctr_y) / height
dw = xp.log(base_width / width)
dx = (base_ctr_x - ctr_x) / width
dh = xp.log(base_height / height)
dw = xp.log(base_width / width)

loc = xp.vstack((dx, dy, dw, dh)).transpose()
loc = xp.vstack((dy, dx, dh, dw)).transpose()
return loc
14 changes: 7 additions & 7 deletions chainercv/links/model/faster_rcnn/utils/generate_anchor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,23 @@ def generate_anchor_base(base_size=16, ratios=[0.5, 1, 2],
~numpy.ndarray:
An array of shape :math:`(R, 4)`.
Each element is a set of coordinates of a bounding box.
The second axis corresponds to :obj:`x_min, y_min, x_max, y_max`
The second axis corresponds to :obj:`y_min, x_min, y_max, x_max`
of a bounding box.
"""
px = base_size / 2.
py = base_size / 2.
px = base_size / 2.

anchor_base = np.zeros((len(ratios) * len(anchor_scales), 4),
dtype=np.float32)
for i in six.moves.range(len(ratios)):
for j in six.moves.range(len(anchor_scales)):
w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i])
h = base_size * anchor_scales[j] * np.sqrt(ratios[i])
w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i])

index = i * len(anchor_scales) + j
anchor_base[index, 0] = px - w / 2.
anchor_base[index, 1] = py - h / 2.
anchor_base[index, 2] = px + w / 2.
anchor_base[index, 3] = py + h / 2.
anchor_base[index, 0] = py - h / 2.
anchor_base[index, 1] = px - w / 2.
anchor_base[index, 2] = py + h / 2.
anchor_base[index, 3] = px + w / 2.
return anchor_base
46 changes: 23 additions & 23 deletions chainercv/links/model/faster_rcnn/utils/loc2bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ def loc2bbox(src_bbox, loc):
:meth:`bbox2loc`, this function decodes the representation to
coordinates in 2D image coordinates.
Given scales and offsets :math:`t_x, t_y, t_w, t_h` and a bounding
box whose center is :math:`p_x, p_y` and size :math:`p_w, p_h`,
the decoded bounding box's center :math:`\\hat{g}_x`, :math:`\\hat{g}_y`
and size :math:`\\hat{g}_w`, :math:`\\hat{g}_h` are calculated
Given scales and offsets :math:`t_y, t_x, t_h, t_w` and a bounding
box whose center is :math:`(y, x) = p_y, p_x` and size :math:`p_h, p_w`,
the decoded bounding box's center :math:`\\hat{g}_y`, :math:`\\hat{g}_x`
and size :math:`\\hat{g}_h`, :math:`\\hat{g}_w` are calculated
by the following formulas.
* :math:`\\hat{g}_x = p_w t_x + p_x`
* :math:`\\hat{g}_y = p_h t_y + p_y`
* :math:`\\hat{g}_w = p_w \\exp(t_w)`
* :math:`\\hat{g}_x = p_w t_x + p_x`
* :math:`\\hat{g}_h = p_h \\exp(t_h)`
* :math:`\\hat{g}_w = p_w \\exp(t_w)`
The decoding formulas are used in works such as R-CNN [#]_.
Expand All @@ -30,16 +30,16 @@ def loc2bbox(src_bbox, loc):
Args:
src_bbox (array): A coordinates of bounding boxes.
Its shape is :math:`(R, 4)`. These coordinates are used to
compute :math:`p_x, p_y, p_w, p_h`.
compute :math:`p_y, p_x, p_h, p_w`.
loc (array): An array with offsets and scales.
The shapes of :obj:`src_bbox` and :obj:`loc` should be same.
This contains values :math:`t_x, t_y, t_w, t_h`.
This contains values :math:`t_y, t_x, t_h, t_w`.
Returns:
array:
Decoded bounding box coordinates. Its shape is :math:`(R, 4)`. \
The second axis contains four values \
:math:`\\hat{g}_x, \\hat{g}_y, \\hat{g}_w, \\hat{g}_h`.
:math:`\\hat{g}_y, \\hat{g}_x, \\hat{g}_h, \\hat{g}_w`.
"""
xp = cuda.get_array_module(src_bbox)
Expand All @@ -49,25 +49,25 @@ def loc2bbox(src_bbox, loc):

src_bbox = src_bbox.astype(src_bbox.dtype, copy=False)

src_width = src_bbox[:, 2] - src_bbox[:, 0]
src_height = src_bbox[:, 3] - src_bbox[:, 1]
src_ctr_x = src_bbox[:, 0] + 0.5 * src_width
src_ctr_y = src_bbox[:, 1] + 0.5 * src_height
src_height = src_bbox[:, 2] - src_bbox[:, 0]
src_width = src_bbox[:, 3] - src_bbox[:, 1]
src_ctr_y = src_bbox[:, 0] + 0.5 * src_height
src_ctr_x = src_bbox[:, 1] + 0.5 * src_width

dx = loc[:, 0::4]
dy = loc[:, 1::4]
dw = loc[:, 2::4]
dh = loc[:, 3::4]
dy = loc[:, 0::4]
dx = loc[:, 1::4]
dh = loc[:, 2::4]
dw = loc[:, 3::4]

ctr_x = dx * src_width[:, xp.newaxis] + src_ctr_x[:, xp.newaxis]
ctr_y = dy * src_height[:, xp.newaxis] + src_ctr_y[:, xp.newaxis]
w = xp.exp(dw) * src_width[:, xp.newaxis]
ctr_x = dx * src_width[:, xp.newaxis] + src_ctr_x[:, xp.newaxis]
h = xp.exp(dh) * src_height[:, xp.newaxis]
w = xp.exp(dw) * src_width[:, xp.newaxis]

dst_bbox = xp.zeros(loc.shape, dtype=loc.dtype)
dst_bbox[:, 0::4] = ctr_x - 0.5 * w
dst_bbox[:, 1::4] = ctr_y - 0.5 * h
dst_bbox[:, 2::4] = ctr_x + 0.5 * w
dst_bbox[:, 3::4] = ctr_y + 0.5 * h
dst_bbox[:, 0::4] = ctr_y - 0.5 * h
dst_bbox[:, 1::4] = ctr_x - 0.5 * w
dst_bbox[:, 2::4] = ctr_y + 0.5 * h
dst_bbox[:, 3::4] = ctr_x + 0.5 * w

return dst_bbox
10 changes: 5 additions & 5 deletions chainercv/links/model/faster_rcnn/utils/proposal_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def __call__(self, loc, score,

# Clip predicted boxes to image.
roi[:, slice(0, 4, 2)] = np.clip(
roi[:, slice(0, 4, 2)], 0, img_size[1])
roi[:, slice(0, 4, 2)], 0, img_size[0])
roi[:, slice(1, 4, 2)] = np.clip(
roi[:, slice(1, 4, 2)], 0, img_size[0])
roi[:, slice(1, 4, 2)], 0, img_size[1])

# Remove predicted boxes with either height or width < threshold.
min_size = self.min_size * scale
ws = roi[:, 2] - roi[:, 0]
hs = roi[:, 3] - roi[:, 1]
keep = np.where((ws >= min_size) & (hs >= min_size))[0]
hs = roi[:, 2] - roi[:, 0]
ws = roi[:, 3] - roi[:, 1]
keep = np.where((hs >= min_size) & (ws >= min_size))[0]
roi = roi[keep, :]
score = score[keep]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def check_call(self):
x = chainer.Variable(
xp.random.uniform(
low=-1., high=1.,
size=(self.B, 3, feat_size[1] * 16, feat_size[0] * 16)
size=(self.B, 3, feat_size[0] * 16, feat_size[1] * 16)
).astype(np.float32), volatile=chainer.flag.ON)
roi_cls_locs, roi_scores, rois, roi_indices = self.link(
x, test=not self.train)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUp(self):
proposal_creator_params=self.proposal_creator_params
)
self.x = np.random.uniform(size=(self.B, C, H, W)).astype(np.float32)
self.img_size = (W * feat_stride, H * feat_stride)
self.img_size = (H * feat_stride, W * feat_stride)

def _check_call(self, x, img_size, test):
_, _, H, W = x.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class TestGenerateAnchorBase(unittest.TestCase):

def test_generaete_anchor_base(self):
gt = np.array(
[[-120., -24., 136., 40.],
[-248., -56., 264., 72.],
[-504., -120., 520., 136.],
[[-24., -120., 40., 136.],
[-56., -248., 72., 264.],
[-120., -504., 136., 520.],
[-56., -56., 72., 72.],
[-120., -120., 136., 136.],
[-248., -248., 264., 264.],
[-24., -120., 40., 136.],
[-56., -248., 72., 264.],
[-120., -504., 136., 520.]])
[-120., -24., 136., 40.],
[-248., -56., 264., 72.],
[-504., -120., 520., 136.]])

base_size = 16
anchor_scales = [8, 16, 32]
Expand Down

0 comments on commit 7c0bf7f

Please sign in to comment.