
use concatenated arrays instead of 'list of arrays'
yuyu2172 committed May 15, 2017
1 parent 4e25beb commit 0599501
Showing 9 changed files with 185 additions and 176 deletions.
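
The gist of the change: instead of passing proposals around as a Python list with one bounding box array per image, the model now passes a single concatenated array plus a batch_indices array recording which image each row belongs to. A minimal numpy sketch of the two layouts, with made-up boxes and a batch size of 2:

import numpy as np

# Old layout: a 'list of arrays', one (R_i, 4) bounding box array per image.
proposals_list = [
    np.array([[0., 0., 10., 10.], [5., 5., 20., 20.]], dtype=np.float32),  # image 0
    np.array([[2., 2., 8., 8.]], dtype=np.float32),                        # image 1
]

# New layout: one concatenated (R', 4) array plus a (R',) batch_indices array.
rois = np.concatenate(proposals_list, axis=0)
batch_indices = np.concatenate(
    [i * np.ones(len(b), dtype=np.int32) for i, b in enumerate(proposals_list)])

print(rois.shape)        # (3, 4)
print(batch_indices)     # [0 0 1]
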
54 changes: 28 additions & 26 deletions chainercv/links/model/faster_rcnn/faster_rcnn.py
@@ -57,8 +57,8 @@ def _decide_when_to_stop(self, layers):
return 'start'

rpn_outs = [
'feature', 'rpn_bbox_pred', 'rpn_cls_score',
'proposals', 'anchor']
'features', 'rpn_bboxes', 'rpn_scores',
'rois', 'batch_indices', 'anchor']
for layer in rpn_outs:
layers.pop(layer, None)

@@ -71,19 +71,20 @@ def _update_if_specified(self, target, source):
if key in target:
target[key] = source[key]

def __call__(self, x, scale=1., layers=['bbox_tfs', 'scores'],
test=True):
def __call__(self, x, scale=1.,
layers=['rois', 'roi_bboxes', 'roi_scores'], test=True):
"""Computes all the feature maps specified by :obj:`layers`.
Here is a list of the names of layers that can be collected.
* feature: Feature extractor output.
* rpn_bbox_pred: RPN output.
* rpn_cls_score: RPN output.
* proposals: RPN output.
* features: Feature extractor output.
* rpn_bboxes: RPN output.
* rpn_scores: RPN output.
* rois: RPN output.
* batch_indices: RPN output.
* anchor: RPN output.
* bbox_tfs: Head output.
* scores: Head output.
* roi_bboxes: Head output.
* roi_scores: Head output.
Args:
x (~chainer.Variable): Input variable.
@@ -106,24 +107,25 @@ def __call__(self, x, scale=1., layers=['bbox_tfs', 'scores'],
img_size = x.shape[2:][::-1]

h = self.feature(x, train=not test)
rpn_bbox_pred, rpn_cls_score, proposals, anchor =\
rpn_bboxes, rpn_scores, rois, batch_indices, anchor =\
self.rpn(h, img_size, scale, train=not test)

self._update_if_specified(
activations,
{'feature': h,
'rpn_bbox_pred': rpn_bbox_pred,
'rpn_cls_score': rpn_cls_score,
'proposals': proposals,
{'features': h,
'rpn_bboxes': rpn_bboxes,
'rpn_scores': rpn_scores,
'rois': rois,
'batch_indices': batch_indices,
'anchor': anchor})
if stop_at == 'rpn':
return activations

bbox_tfs, scores = self.head(h, proposals, train=False)
roi_bboxes, roi_scores = self.head(h, rois, batch_indices, train=False)
self._update_if_specified(
activations,
{'bbox_tfs': bbox_tfs,
'scores': scores})
{'roi_bboxes': roi_bboxes,
'roi_scores': roi_scores})
return activations

def _suppress(self, raw_bbox, raw_prob):
@@ -193,27 +195,27 @@ def predict(self, imgs):
H, W = img_var.shape[2:]
out = self.__call__(
img_var, scale=scale,
layers=['proposals', 'bbox_tfs', 'scores'])
bbox_tf = out['bbox_tfs'][0]
score = out['scores'][0]
layers=['rois', 'roi_bboxes', 'roi_scores'])
# We are assuming that batch size is 1.
roi_bbox = out['roi_bboxes'].data
roi_score = out['roi_scores'].data
roi = out['rois'] / scale

# Convert predictions to bounding boxes in image coordinates.
# Bounding boxes are scaled to the scale of the input images.
proposal = out['proposals'][0] / scale
bbox_tf_data = bbox_tf.data
mean = self.xp.tile(self.xp.asarray(self.bbox_normalize_mean),
self.n_class)
std = self.xp.tile(self.xp.asarray(self.bbox_normalize_std),
self.n_class)
bbox_tf_data = (bbox_tf_data * std + mean).astype(np.float32)
raw_bbox = bbox_regression_target_inv(proposal, bbox_tf_data)
roi_bbox = (roi_bbox * std + mean).astype(np.float32)
raw_bbox = bbox_regression_target_inv(roi, roi_bbox)
# clip bounding box
raw_bbox[:, slice(0, 4, 2)] = self.xp.clip(
raw_bbox[:, slice(0, 4, 2)], 0, W / scale)
raw_bbox[:, slice(1, 4, 2)] = self.xp.clip(
raw_bbox[:, slice(1, 4, 2)], 0, H / scale)

raw_prob = F.softmax(score).data
raw_prob = F.softmax(roi_score).data

raw_bbox = cuda.to_cpu(raw_bbox)
raw_prob = cuda.to_cpu(raw_prob)
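For reference, a small numpy sketch of the un-normalization that predict() applies to the head output before decoding the boxes; n_class and the normalization constants below are placeholder values, not the ones the real model uses:

import numpy as np

n_class = 3
bbox_normalize_mean = (0., 0., 0., 0.)
bbox_normalize_std = (0.1, 0.1, 0.2, 0.2)

# Pretend head output: one row of per-class regression targets per RoI.
roi_bbox = np.random.randn(5, n_class * 4).astype(np.float32)

# Tile the per-coordinate constants once per class, then undo the normalization.
mean = np.tile(np.asarray(bbox_normalize_mean), n_class)  # shape (n_class * 4,)
std = np.tile(np.asarray(bbox_normalize_std), n_class)
roi_bbox = (roi_bbox * std + mean).astype(np.float32)

# roi_bbox is now ready to be decoded against the RoIs with
# bbox_regression_target_inv and clipped to the image, as in predict() above.
print(roi_bbox.shape)  # (5, 12)
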
54 changes: 16 additions & 38 deletions chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py
@@ -1,8 +1,6 @@
import numpy as np
import six

import chainer
from chainer import cuda
import chainer.functions as F
import chainer.links as L
from chainer.links import VGG16Layers
@@ -14,6 +12,10 @@

class FasterRCNNVGG16(FasterRCNNBase):

"""FasterRCNN based on VGG16.
"""

feat_stride = 16

def __init__(self,
@@ -49,30 +51,6 @@ def __init__(self,
)


def _bboxes_to_roi(bboxes):
xp = cuda.get_array_module(bboxes[0])
bbox_concat = xp.concatenate(bboxes)

batch_index = []
for i, bbox in enumerate(bboxes):
batch_index.append(i * xp.ones(len(bbox), dtype=np.float32))
batch_index = xp.concatenate(batch_index)

roi = xp.concatenate((batch_index[:, None], bbox_concat), axis=1)
return roi


def _batch_to_list(x, separations):
ys = []

separations = [0] + separations
for i in six.moves.range(len(separations) - 1):
start = separations[i]
end = separations[i + 1]
ys.append(x[start:end])
return ys


class VGG16RoIPoolingHead(chainer.Chain):

"""Regress and classify bounding boxes based on RoI pooled features.
@@ -85,36 +63,36 @@ def __init__(self, n_class, roi_size, spatial_scale,
# these linear links take some time to initialize
fc6=L.Linear(25088, 4096, initialW=fc_initialW),
fc7=L.Linear(4096, 4096, initialW=fc_initialW),
bbox_tf=L.Linear(4096, n_class * 4, initialW=bbox_initialW),
cls_score=L.Linear(4096, n_class, initialW=cls_initialW),
bbox=L.Linear(4096, n_class * 4, initialW=bbox_initialW),
score=L.Linear(4096, n_class, initialW=cls_initialW),
)
self.roi_size = roi_size
self.spatial_scale = spatial_scale

def __call__(self, x, bboxes, train=False):
def __call__(self, x, rois, batch_indices, train=False):
"""Pool and forward batches of patches.
Args:
x (~chainer.Variable):
bboxes (list of arrays):
rois (array):
batch_indices (array):
Returns:
list of chainer.Variable, list of chainer.Variable
"""
lengths = [len(bbox) for bbox in bboxes]
roi = _bboxes_to_roi(bboxes)
batch_indices = batch_indices.astype(np.float32)
rois = self.xp.concatenate(
(batch_indices[:, None], rois), axis=1)
pool = F.roi_pooling_2d(
x, roi, self.roi_size, self.roi_size, self.spatial_scale)
x, rois, self.roi_size, self.roi_size, self.spatial_scale)

fc6 = F.dropout(F.relu(self.fc6(pool)), train=train)
fc7 = F.dropout(F.relu(self.fc7(fc6)), train=train)
bbox_tf = self.bbox_tf(fc7)
score = self.cls_score(fc7)
roi_bboxes = self.bbox(fc7)
roi_scores = self.score(fc7)

bbox_tfs = _batch_to_list(bbox_tf, lengths)
scores = _batch_to_list(score, lengths)
return bbox_tfs, scores
return roi_bboxes, roi_scores


class VGG16FeatureExtractor(VGG16Layers):
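A rough sketch of what the rewritten head does in place of the removed _bboxes_to_roi helper: the batch index is prepended to every RoI so that F.roi_pooling_2d receives a single (R', 5) array of (batch_index, x_min, y_min, x_max, y_max) rows. The boxes below are made up:

import numpy as np

rois = np.array([[0., 0., 10., 10.],
                 [5., 5., 20., 20.],
                 [2., 2., 8., 8.]], dtype=np.float32)
batch_indices = np.array([0, 0, 1], dtype=np.int32).astype(np.float32)

# Each row becomes (batch_index, x_min, y_min, x_max, y_max).
indices_and_rois = np.concatenate((batch_indices[:, None], rois), axis=1)
print(indices_and_rois.shape)  # (3, 5)
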
62 changes: 37 additions & 25 deletions chainercv/links/model/faster_rcnn/region_proposal_network.py
@@ -15,13 +15,17 @@ class RegionProposalNetwork(chainer.Chain):

"""Region Proposal Networks introduced in Faster RCNN.
This is Region Proposal Networks introduced in Faster RCNN.
This takes features extracted from an image and predict
This is Region Proposal Networks introduced in Faster RCNN [1].
This takes features extracted from an image and predicts
class agnostic bounding boxes around "objects".
.. [1] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \
Faster R-CNN: Towards Real-Time Object Detection with \
Region Proposal Networks. NIPS 2015.
Args:
n_in (int): Channel size of input.
n_mid (int): Channel size of the intermediate tensor.
in_channels (int): Channel size of input.
mid_channels (int): Channel size of the intermediate tensor.
ratios (list of floats): Anchors with ratios contained in this list
will be generated. Ratio is the ratio of the height by the width.
anchor_scales (list of numbers): Values in :obj:`scales` determine area
@@ -39,7 +43,7 @@ class agnostic bounding boxes around "objects".
"""

def __init__(
self, n_in=512, n_mid=512, ratios=[0.5, 1, 2],
self, in_channels=512, mid_channels=512, ratios=[0.5, 1, 2],
anchor_scales=[8, 16, 32], feat_stride=16,
proposal_creator_params={},
):
@@ -52,11 +56,11 @@ def __init__(
initializer = chainer.initializers.Normal(scale=0.01)
super(RegionProposalNetwork, self).__init__(
rpn_conv_3x3=L.Convolution2D(
n_in, n_mid, 3, 1, 1, initialW=initializer),
rpn_cls_score=L.Convolution2D(
n_mid, 2 * n_anchor, 1, 1, 0, initialW=initializer),
rpn_bbox_pred=L.Convolution2D(
n_mid, 4 * n_anchor, 1, 1, 0, initialW=initializer)
in_channels, mid_channels, 3, 1, 1, initialW=initializer),
rpn_score=L.Convolution2D(
mid_channels, 2 * n_anchor, 1, 1, 0, initialW=initializer),
rpn_bbox=L.Convolution2D(
mid_channels, 4 * n_anchor, 1, 1, 0, initialW=initializer)
)

def __call__(self, x, img_size, scale=1., train=False):
@@ -87,37 +91,45 @@ def __call__(self, x, img_size, scale=1., train=False):
Default value is :obj:`False`.
Returns:
(~chainer.Variable, ~chainer.Variable, list of arrays, array):
(~chainer.Variable, ~chainer.Variable, array, array, array):
This is a tuple of four following values.
This is a tuple of five following values.
* **rpn_bbox_pred**: Predicted regression targets for anchors. \
* **rpn_bboxes**: Predicted regression targets for anchors. \
Its shape is :math:`(1, 4 A, H, W)`.
* **rpn_cls_prob**: Predicted foreground probability for \
* **rpn_scores**: Predicted foreground scores for \
anchors. Its shape is :math:`(1, 2 A, H, W)`.
* **proposals**: List of bounding box arrays which contain RoI \
proposals for regions with high objectness. Its length is same\
as the batch size of the inputs.
* **rois**: A bounding box array containing coordinates of \
proposal boxes. The bounding box array is a concatenation of\
bounding box arrays \
from multiple images in the batch. \
Its shape is :math:`(R', 4)`. Given :math:`R_i` predicted \
bounding boxes for the :math:`i` th image and size of batch \
:math:`N`, :math:`R' = \\sum _{i=1} ^ N R_i`. \
Each bounding box is organized by \
:obj:`(x_min, y_min, x_max, y_max)` in the second axis. \
* **batch_indices**: An array containing the indices of the images \
to which the bounding boxes correspond. Its shape is :math:`(R',)`.
* **anchor**: Coordinates of anchors. This is an array of bounding\
boxes. Its length is :math:`A`.
"""
xp = cuda.get_array_module(x)
n = x.data.shape[0]
h = F.relu(self.rpn_conv_3x3(x))
rpn_cls_score = self.rpn_cls_score(h)
c, hh, ww = rpn_cls_score.shape[1:]
rpn_cls_prob = F.softmax(rpn_cls_score.reshape(n, 2, -1))
rpn_cls_prob = rpn_cls_prob.reshape(n, c, hh, ww)
rpn_bbox_pred = self.rpn_bbox_pred(h)
rpn_scores = self.rpn_score(h)
c, hh, ww = rpn_scores.shape[1:]
rpn_probs = F.softmax(rpn_scores.reshape(n, 2, -1))
rpn_probs = rpn_probs.reshape(n, c, hh, ww)
rpn_bboxes = self.rpn_bbox(h)

# enumerate all shifted anchors
anchor = _enumerate_shifted_anchor(
xp.array(self.anchor_base), self.feat_stride, ww, hh)
proposals = self.proposal_layer(
rpn_bbox_pred, rpn_cls_prob, anchor, img_size,
rois, batch_indices = self.proposal_layer(
rpn_bboxes, rpn_probs, anchor, img_size,
scale=scale, train=train)
return rpn_bbox_pred, rpn_cls_score, proposals, anchor
return rpn_bboxes, rpn_scores, rois, batch_indices, anchor


def _enumerate_shifted_anchor(anchor_base, feat_stride, width, height):
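If a caller still needs per-image proposals, they are easy to recover from the concatenated rois / batch_indices pair the RPN now returns. A small numpy sketch with made-up values:

import numpy as np

rois = np.array([[0., 0., 10., 10.],
                 [5., 5., 20., 20.],
                 [2., 2., 8., 8.]], dtype=np.float32)
batch_indices = np.array([0, 0, 1], dtype=np.int32)

batch_size = 2
rois_per_image = [rois[batch_indices == i] for i in range(batch_size)]
print([r.shape for r in rois_per_image])  # [(2, 4), (1, 4)]
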
