Skip to content

Commit

Permalink
Refactor RPN and modify resnet classifier.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hans Gaiser committed Jul 27, 2017
1 parent 6a13e5c commit 5b764be
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 295 deletions.
114 changes: 20 additions & 94 deletions keras_rcnn/backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,6 @@
import keras_rcnn.backend


def anchor(base_size=16, ratios=None, scales=None):
"""
Generates a regular grid of multi-aspect and multi-scale anchor boxes.
"""
if ratios is None:
ratios = keras.backend.cast([0.5, 1, 2], 'float32')

if scales is None:
scales = keras.backend.cast([8, 16, 32], 'float32')
base_anchor = keras.backend.cast([1, 1, base_size, base_size],
'float32') - 1
base_anchor = keras.backend.expand_dims(base_anchor, 0)

ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = _scale_enum(ratio_anchors, scales)

return anchors


def bbox_transform(ex_rois, gt_rois):
ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
Expand All @@ -48,8 +29,6 @@ def bbox_transform(ex_rois, gt_rois):


def clip(boxes, shape):
boxes = keras.backend.cast(boxes, dtype='int32')
shape = keras.backend.cast(shape, dtype='int32')
proposals = [
keras.backend.maximum(
keras.backend.minimum(boxes[:, 0::4], shape[1] - 1), 0),
Expand All @@ -64,76 +43,23 @@ def clip(boxes, shape):
return keras.backend.concatenate(proposals, axis=1)


def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""

col1 = keras.backend.reshape(x_ctr - 0.5 * (ws - 1), (-1, 1))
col2 = keras.backend.reshape(y_ctr - 0.5 * (hs - 1), (-1, 1))
col3 = keras.backend.reshape(x_ctr + 0.5 * (ws - 1), (-1, 1))
col4 = keras.backend.reshape(y_ctr + 0.5 * (hs - 1), (-1, 1))
anchors = keras.backend.concatenate((col1, col2, col3, col4), axis=1)

return anchors


def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""
# import pdb
# pdb.set_trace()
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = keras.backend.round(keras.backend.sqrt(size_ratios))
hs = keras.backend.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors


def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""

w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = keras.backend.expand_dims(w, 1) * scales
hs = keras.backend.expand_dims(h, 1) * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors


def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""
w = anchor[:, 2] - anchor[:, 0] + 1
h = anchor[:, 3] - anchor[:, 1] + 1
x_ctr = anchor[:, 0] + 0.5 * (w - 1)
y_ctr = anchor[:, 1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr


def shift(shape, stride):
def shift(shape, anchors, stride):
shift_x = keras.backend.arange(0, shape[0]) * stride
shift_y = keras.backend.arange(0, shape[1]) * stride

shift_x, shift_y = keras_rcnn.backend.meshgrid(shift_x, shift_y)
shift_x = keras.backend.reshape(shift_x, [-1])
shift_y = keras.backend.reshape(shift_y, [-1])

shifts = keras.backend.stack([
keras.backend.reshape(shift_x, [-1]),
keras.backend.reshape(shift_y, [-1]),
keras.backend.reshape(shift_x, [-1]),
keras.backend.reshape(shift_y, [-1])
shift_x,
shift_y,
shift_x,
shift_y,
], axis=0)

shifts = keras.backend.transpose(shifts)

anchors = keras_rcnn.backend.anchor()

number_of_anchors = keras.backend.shape(anchors)[0]

k = keras.backend.shape(shifts)[0] # number of base points = feat_h * feat_w
Expand Down Expand Up @@ -174,31 +100,31 @@ def filter_boxes(proposals, minimum):
ws = proposals[:, 2] - proposals[:, 0] + 1
hs = proposals[:, 3] - proposals[:, 1] + 1

indicies = keras_rcnn.backend.where((ws >= minimum) & (hs >= minimum))
indices = keras_rcnn.backend.where((ws >= minimum) & (hs >= minimum))

indicies = keras.backend.flatten(indicies)
indices = keras.backend.flatten(indices)

return keras.backend.cast(indicies, "int32")
return keras.backend.cast(indices, "int32")


def inside_image(y_pred, img_info):
"""
Calc indicies of anchors which are located completely inside of the image
Calc indices of boxes which are located completely inside of the image
whose size is specified by img_info ((height, width, scale)-shaped array).
:param y_pred: anchors
:param boxes: bounding boxes
:param img_info:
:return:
"""
indicies = keras_rcnn.backend.where(
(y_pred[:, 0] >= 0) &
(y_pred[:, 1] >= 0) &
(y_pred[:, 2] < img_info[1]) & # width
(y_pred[:, 3] < img_info[0]) # height
indices = keras_rcnn.backend.where(
(boxes[:, 0] >= 0) &
(boxes[:, 1] >= 0) &
(boxes[:, 2] < img_info[1]) & # width
(boxes[:, 3] < img_info[0]) # height
)

indicies = keras.backend.cast(indicies, "int32")
indices = keras.backend.cast(indices, "int32")

gathered = keras.backend.gather(y_pred, indicies)
gathered = keras.backend.gather(boxes, indices)

return indicies[:, 0], keras.backend.reshape(gathered, [-1, 4])
return indices[:, 0], keras.backend.reshape(gathered, [-1, 4])
21 changes: 5 additions & 16 deletions keras_rcnn/classifiers/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,18 @@
import keras_resnet.blocks


def residual(classes, mask=False):
def residual(classes, mask=False, features=512):
"""Resnet classifiers as in Mask R-CNN."""
def f(x):
if keras.backend.image_data_format() == "channels_last":
channel_axis = 3
else:
channel_axis = 1

y = keras.layers.TimeDistributed(keras.layers.Conv2D(1024, (1, 1)))(x)

# conv5 block as in Deep Residual Networks with first conv operates
# on a 7x7 RoI with stride 1 (instead of 14x14 / stride 2)
for i in range(3):
y = keras_resnet.blocks.time_distributed_bottleneck_2d(512, (1, 1), first=True)(y)

y = keras.layers.TimeDistributed(keras.layers.BatchNormalization(axis=channel_axis))(y)
y = keras.layers.TimeDistributed(keras.layers.Activation("relu"))(y)
y = keras_resnet.blocks.time_distributed_bottleneck_2d(features, stage=3, block=0, stride=1)(x)
y = keras_resnet.blocks.time_distributed_bottleneck_2d(features, stage=3, block=1, stride=1)(y)
y = keras_resnet.blocks.time_distributed_bottleneck_2d(features, stage=3, block=2, stride=1)(y)

# class and box branches
y = keras.layers.TimeDistributed(keras.layers.AveragePooling2D((7, 7)))(y)

y = keras.layers.TimeDistributed(keras.layers.GlobalAveragePooling2D())(y)
score = keras.layers.TimeDistributed(keras.layers.Dense(classes, activation="softmax"))(y)

boxes = keras.layers.TimeDistributed(keras.layers.Dense(4 * classes))(y)

# TODO{JihongJu} the mask branch
Expand Down
66 changes: 43 additions & 23 deletions keras_rcnn/layers/object_detection/_object_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

import keras_rcnn.backend


class ObjectProposal(keras.engine.topology.Layer):
def __init__(self, maximum_proposals=300, **kwargs):
self.output_dim = (None, maximum_proposals, 4)

def __init__(
self,
maximum_proposals=300,
**kwargs
):
self.maximum_proposals = maximum_proposals

super(ObjectProposal, self).__init__(**kwargs)
Expand All @@ -17,36 +18,55 @@ def build(self, input_shape):
super(ObjectProposal, self).build(input_shape)

def call(self, inputs, **kwargs):
return self.propose(inputs[0], inputs[1], self.maximum_proposals)
boxes, scores, anchors, im_info = inputs
return self.propose(boxes, scores, anchors, im_info)

def compute_output_shape(self, input_shape):
return self.output_dim
return [(input_shape[0], self.maximum_proposals, 4), (input_shape[0], self.maximum_proposals)]

@staticmethod
def propose(boxes, scores, maximum):
def propose(self, boxes, scores, anchors, im_info):
# 1. Generate proposals from bbox deltas and shifted anchors
shape = keras.backend.shape(boxes)[1:3]

shifted = keras_rcnn.backend.shift(shape, 16)
# shift the anchors to original image shape
shifted = keras_rcnn.backend.shift(shape, anchors, 16)
shifted = keras.backend.reshape(shifted, (-1, 1, 4))

proposals = keras.backend.reshape(boxes, (-1, 4))
# apply shifts to anchors
anchors = keras.backend.reshape(anchors, (1, -1, 4))
anchors = keras.backend.reshape(anchors + shifted, (-1, 4))

proposals = keras_rcnn.backend.bbox_transform_inv(shifted, proposals)
# reshape predicted bbox to get them into the same order as the anchor
boxes = keras.backend.reshape(boxes, (-1, 4))
scores = keras.backend.reshape(scores, (-1, 1))

proposals = keras_rcnn.backend.clip(proposals, shape)
# convert anchors into proposals via bbox transformations
proposals = keras_rcnn.backend.bbox_transform_inv(anchors, boxes)

indicies = keras_rcnn.backend.filter_boxes(proposals, 1)
# 2. Clip predicted boxes to image
proposals = keras_rcnn.backend.clip(proposals, im_info[:2])

proposals = keras.backend.gather(proposals, indicies)
scores = scores[:, :, :, :9]
scores = keras.backend.reshape(scores, (-1, 1))
scores = keras.backend.gather(scores, indicies)
scores = keras.backend.flatten(scores)
# 3. Remove predicted boxes with either height or width < threshold
indices = keras_rcnn.backend.filter_boxes(proposals, 16.0 * im_info[2])
proposals = keras.backend.gather(proposals, indices)
scores = keras.backend.gather(scores, indices)

# 4. sort all (proposal, score) pairs by score from highest to lowest
# 5. take top pre_nms_topN (e.g. 6000)
# TODO ?

proposals = keras.backend.cast(proposals, keras.backend.floatx())
scores = keras.backend.cast(scores, keras.backend.floatx())
# 6. apply nms (e.g. threshold = 0.7)
# 7. take after_nms_topN (e.g. 300) (#TODO)
# 8. return the top proposals (-> RoIs top) (#TODO)
indices = keras_rcnn.backend.non_maximum_suppression(proposals, scores, self.maximum_proposals, 0.7)
proposals = keras.backend.gather(proposals, indices)
scores = keras.backend.gather(scores, indices)

indicies = keras_rcnn.backend.non_maximum_suppression(proposals, scores, maximum, 0.7)
# These would have to be filled with the proposal labels and target proposals
#labels = keras.backend.placeholder((1,))
#targets = keras.backend.placeholder((1,))
#inside_weights = keras.backend.placeholder((1,))
#outside_weights = keras.backend.placeholder((1,))

proposals = keras.backend.gather(proposals, indicies)
return [keras.backend.expand_dims(proposals, 0), scores]

return keras.backend.expand_dims(proposals, 0)
97 changes: 97 additions & 0 deletions keras_rcnn/layers/object_detection/generate_anchors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------

import numpy as np

# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360

#array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])

def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""

base_anchor = np.array([1, 1, base_size, base_size]) - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
for i in range(ratio_anchors.shape[0])])
return anchors

def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""

w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr

def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""

ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1)))
return anchors

def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""

w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors

def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""

w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
Loading

0 comments on commit 5b764be

Please sign in to comment.