Skip to content

Commit

Permalink
Add back old code, and add a config option
Browse files Browse the repository at this point in the history
  • Loading branch information
Detry322 committed Apr 13, 2018
1 parent 4f3f3f7 commit 6538e3c
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 30 deletions.
43 changes: 42 additions & 1 deletion lib/layer_utils/proposal_layer.py
Expand Up @@ -13,8 +13,47 @@
from model.bbox_transform import bbox_transform_inv, clip_boxes, bbox_transform_inv_tf, clip_boxes_tf
from model.nms_wrapper import nms


def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors):
"""A simplified version compared to fast/er RCNN
For details please see the technical report
"""
if type(cfg_key) == bytes:
cfg_key = cfg_key.decode('utf-8')
pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
nms_thresh = cfg[cfg_key].RPN_NMS_THRESH

# Get the scores and bounding boxes
scores = rpn_cls_prob[:, :, :, num_anchors:]
rpn_bbox_pred = rpn_bbox_pred.reshape((-1, 4))
scores = scores.reshape((-1, 1))
proposals = bbox_transform_inv(anchors, rpn_bbox_pred)
proposals = clip_boxes(proposals, im_info[:2])

# Pick the top region proposals
order = scores.ravel().argsort()[::-1]
if pre_nms_topN > 0:
order = order[:pre_nms_topN]
proposals = proposals[order, :]
scores = scores[order]

# Non-maximal suppression
keep = nms(np.hstack((proposals, scores)), nms_thresh)

# Pick th top region proposals after NMS
if post_nms_topN > 0:
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep]

# Only support single image as input
batch_inds = np.zeros((proposals.shape[0], 1), dtype=np.float32)
blob = np.hstack((batch_inds, proposals.astype(np.float32, copy=False)))

return blob, scores


def proposal_layer_tf(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors):
if type(cfg_key) == bytes:
cfg_key = cfg_key.decode('utf-8')
pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
Expand Down Expand Up @@ -42,3 +81,5 @@ def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride,
blob = tf.concat([batch_inds, boxes], 1)

return blob, scores


43 changes: 43 additions & 0 deletions lib/layer_utils/proposal_top_layer.py
Expand Up @@ -11,6 +11,8 @@
from model.bbox_transform import bbox_transform_inv, clip_boxes, bbox_transform_inv_tf, clip_boxes_tf

import tensorflow as tf
import numpy as np
import numpy.random as npr

def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, anchors, num_anchors):
"""A layer that just selects the top region proposals
Expand All @@ -19,6 +21,47 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho
"""
rpn_top_n = cfg.TEST.RPN_TOP_N

scores = rpn_cls_prob[:, :, :, num_anchors:]

rpn_bbox_pred = rpn_bbox_pred.reshape((-1, 4))
scores = scores.reshape((-1, 1))

length = scores.shape[0]
if length < rpn_top_n:
# Random selection, maybe unnecessary and loses good proposals
# But such case rarely happens
top_inds = npr.choice(length, size=rpn_top_n, replace=True)
else:
top_inds = scores.argsort(0)[::-1]
top_inds = top_inds[:rpn_top_n]
top_inds = top_inds.reshape(rpn_top_n, )

# Do the selection here
anchors = anchors[top_inds, :]
rpn_bbox_pred = rpn_bbox_pred[top_inds, :]
scores = scores[top_inds]

# Convert anchors into proposals via bbox transformations
proposals = bbox_transform_inv(anchors, rpn_bbox_pred)

# Clip predicted boxes to image
proposals = clip_boxes(proposals, im_info[:2])

# Output rois blob
# Our RPN implementation only supports a single input image, so all
# batch inds are 0
batch_inds = np.zeros((proposals.shape[0], 1), dtype=np.float32)
blob = np.hstack((batch_inds, proposals.astype(np.float32, copy=False)))
return blob, scores


def proposal_top_layer_tf(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, anchors, num_anchors):
"""A layer that just selects the top region proposals
without using non-maximal suppression,
For details please see the technical report
"""
rpn_top_n = cfg.TEST.RPN_TOP_N

scores = rpn_cls_prob[:, :, :, num_anchors:]
rpn_bbox_pred = tf.reshape(rpn_bbox_pred, shape=(-1, 4))
scores = tf.reshape(scores, shape=(-1,))
Expand Down
21 changes: 19 additions & 2 deletions lib/layer_utils/snippets.py
Expand Up @@ -11,7 +11,25 @@
import numpy as np
from layer_utils.generate_anchors import generate_anchors

def generate_anchors_pre(height, width, feat_stride=16, anchor_scales=(8, 16, 32), anchor_ratios=(0.5, 1, 2)):
def generate_anchors_pre(height, width, feat_stride, anchor_scales=(8,16,32), anchor_ratios=(0.5,1,2)):
""" A wrapper function to generate anchors given different scales
Also return the number of anchors in variable 'length'
"""
anchors = generate_anchors(ratios=np.array(anchor_ratios), scales=np.array(anchor_scales))
A = anchors.shape[0]
shift_x = np.arange(0, width) * feat_stride
shift_y = np.arange(0, height) * feat_stride
shift_x, shift_y = np.meshgrid(shift_x, shift_y)
shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel())).transpose()
K = shifts.shape[0]
# width changes faster, so here it is H, W, C
anchors = anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))
anchors = anchors.reshape((K * A, 4)).astype(np.float32, copy=False)
length = np.int32(anchors.shape[0])

return anchors, length

def generate_anchors_pre_tf(height, width, feat_stride=16, anchor_scales=(8, 16, 32), anchor_ratios=(0.5, 1, 2)):
shift_x = tf.range(width) * feat_stride # width
shift_y = tf.range(height) * feat_stride # height
shift_x, shift_y = tf.meshgrid(shift_x, shift_y)
Expand All @@ -29,4 +47,3 @@ def generate_anchors_pre(height, width, feat_stride=16, anchor_scales=(8, 16, 32
anchors_tf = tf.reshape(tf.add(anchor_constant, shifts), shape=(length, 4))

return tf.cast(anchors_tf, dtype=tf.float32), length

5 changes: 5 additions & 0 deletions lib/model/config.py
Expand Up @@ -269,6 +269,11 @@
# Use GPU implementation of non-maximum suppression
__C.USE_GPU_NMS = True

# Use an end-to-end tensorflow model.
# Note: models in E2E tensorflow mode have only been tested in feed-forward mode,
# but these models are exportable to other tensorflow instances as GraphDef files.
__C.USE_E2E_TF = True

# Default pooling mode, only 'crop' is available
__C.POOLING_MODE = 'crop'

Expand Down
74 changes: 47 additions & 27 deletions lib/nets/network.py
Expand Up @@ -14,9 +14,9 @@

import numpy as np

from layer_utils.snippets import generate_anchors_pre
from layer_utils.proposal_layer import proposal_layer
from layer_utils.proposal_top_layer import proposal_top_layer
from layer_utils.snippets import generate_anchors_pre, generate_anchors_pre_tf
from layer_utils.proposal_layer import proposal_layer, proposal_layer_tf
from layer_utils.proposal_top_layer import proposal_top_layer, proposal_top_layer_tf
from layer_utils.anchor_target_layer import anchor_target_layer
from layer_utils.proposal_target_layer import proposal_target_layer
from utils.visualization import draw_bounding_boxes
Expand Down Expand Up @@ -87,30 +87,44 @@ def _softmax_layer(self, bottom, name):

def _proposal_top_layer(self, rpn_cls_prob, rpn_bbox_pred, name):
with tf.variable_scope(name) as scope:
rois, rpn_scores = proposal_top_layer(
rpn_cls_prob,
rpn_bbox_pred,
self._im_info,
self._feat_stride,
self._anchors,
self._num_anchors
)
if cfg.USE_E2E_TF:
rois, rpn_scores = proposal_top_layer_tf(
rpn_cls_prob,
rpn_bbox_pred,
self._im_info,
self._feat_stride,
self._anchors,
self._num_anchors
)
else:
rois, rpn_scores = tf.py_func(proposal_top_layer,
[rpn_cls_prob, rpn_bbox_pred, self._im_info,
self._feat_stride, self._anchors, self._num_anchors],
[tf.float32, tf.float32], name="proposal_top")

rois.set_shape([cfg.TEST.RPN_TOP_N, 5])
rpn_scores.set_shape([cfg.TEST.RPN_TOP_N, 1])

return rois, rpn_scores

def _proposal_layer(self, rpn_cls_prob, rpn_bbox_pred, name):
with tf.variable_scope(name) as scope:
rois, rpn_scores = proposal_layer(
rpn_cls_prob,
rpn_bbox_pred,
self._im_info,
self._mode,
self._feat_stride,
self._anchors,
self._num_anchors
)
if cfg.USE_E2E_TF:
rois, rpn_scores = proposal_layer_tf(
rpn_cls_prob,
rpn_bbox_pred,
self._im_info,
self._mode,
self._feat_stride,
self._anchors,
self._num_anchors
)
else:
rois, rpn_scores = tf.py_func(proposal_layer,
[rpn_cls_prob, rpn_bbox_pred, self._im_info, self._mode,
self._feat_stride, self._anchors, self._num_anchors],
[tf.float32, tf.float32], name="proposal")

rois.set_shape([None, 5])
rpn_scores.set_shape([None, 1])

Expand Down Expand Up @@ -198,13 +212,19 @@ def _anchor_component(self):
# just to get the shape right
height = tf.to_int32(tf.ceil(self._im_info[0] / np.float32(self._feat_stride[0])))
width = tf.to_int32(tf.ceil(self._im_info[1] / np.float32(self._feat_stride[0])))
anchors, anchor_length = generate_anchors_pre(
height,
width,
self._feat_stride,
self._anchor_scales,
self._anchor_ratios
)
if cfg.USE_E2E_TF:
anchors, anchor_length = generate_anchors_pre_tf(
height,
width,
self._feat_stride,
self._anchor_scales,
self._anchor_ratios
)
else:
anchors, anchor_length = tf.py_func(generate_anchors_pre,
[height, width,
self._feat_stride, self._anchor_scales, self._anchor_ratios],
[tf.float32, tf.int32], name="generate_anchors")
anchors.set_shape([None, 4])
anchor_length.set_shape([])
self._anchors = anchors
Expand Down

0 comments on commit 6538e3c

Please sign in to comment.