
Region Proposal Network (RPN) classification and regression losses (WIP) #41

Merged
merged 17 commits into master from features/region-proposal-network-losses
Aug 17, 2017

Conversation

0x00b1
Contributor

@0x00b1 0x00b1 commented Jul 10, 2017

Implementation of the Region Proposal Network (RPN) classification and regression losses.

Noteworthy items:

  • We’ve refactored anchor generation into Keras backend functions so it runs on any Keras backend (e.g. TensorFlow or Theano).

  • The loss is computed on a Lambda layer. This is a workaround for Keras’ assumption that the y_true and y_pred values have identical shapes. We should probably refactor this into a custom layer to simplify model construction.

  • It’s buggy. While there are plentiful unit tests, they likely don’t provide sufficient coverage.

  • We need to write documentation.

@0x00b1 0x00b1 changed the title from "Region Proposal Network (RPN) classification and regression losses" to "Region Proposal Network (RPN) classification and regression losses (WIP)" on Jul 10, 2017
@JihongJu
Contributor

JihongJu commented Jul 10, 2017

@0x00b1

Keras’ assumption that the y_true and y_pred values have identical shapes.

That is interesting. I wasn't aware of it. Do you think we could add pre-processing to targets and generate y_true on the fly?
Keras model.compile L833

def example(x):
    y_true, y_pred = x

    return loss(y_true, y_pred)
Contributor

@JihongJu JihongJu Jul 10, 2017

I think loss here needs to be passed as an argument to the Lambda layer. Otherwise, it will cause problems while saving and loading the model. See #5396.

(Not tested)

rpn_proposal_loss = keras_rcnn.losses.rpn.proposal(9, (224, 224), 16)

def rpn_loss(x, loss):
    y_true, y_pred = x
    return loss(y_true, y_pred)

loss = keras.layers.Lambda(rpn_loss, arguments={'loss': rpn_proposal_loss})([y_true, y_pred])

Contributor

Cool! Incorporating that

@codecov-io

codecov-io commented Jul 10, 2017

Codecov Report

Merging #41 into master will increase coverage by 12.35%.
The diff coverage is 100%.

Impacted file tree graph

@@             Coverage Diff             @@
##           master      #41       +/-   ##
===========================================
+ Coverage   53.64%   65.99%   +12.35%     
===========================================
  Files          17       17               
  Lines         563      644       +81     
===========================================
+ Hits          302      425      +123     
+ Misses        261      219       -42
Impacted Files Coverage Δ
keras_rcnn/losses/rpn.py 100% <100%> (ø) ⬆️
keras_rcnn/backend/common.py 100% <100%> (+27.27%) ⬆️
...s_rcnn/layers/object_detection/_object_proposal.py 100% <100%> (ø) ⬆️
keras_rcnn/backend/tensorflow_backend.py 100% <100%> (+20.61%) ⬆️

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e341198...9a7a4c1. Read the comment docs.

@jhung0
Contributor

jhung0 commented Jul 11, 2017

The code to handle the different shapes was inspired by keras-team/keras#4781, which suggests using dummy targets.
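
For context, a minimal sketch of the dummy-target trick from that thread (assuming a model whose single output is already the loss tensor; model, images, and boxes stand in for objects built elsewhere in this conversation):

import numpy

# the model's output is the loss itself, so the compiled loss just passes it through
model.compile(optimizer="adam", loss=lambda y_true, y_pred: y_pred)

# dummy targets of the right batch size; their values are never used
model.fit([images, boxes], numpy.zeros((len(images), 1)), epochs=1)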

Contributor

@hgaiser hgaiser left a comment

I just noticed this repository yesterday and, although it is far from finished, it looks really well thought out. I'm hoping I can contribute actively to this project in the near future. I was working on my own port of py-faster-rcnn because I was unhappy with the existing ports for Keras / TensorFlow, but got stuck on the losses for the RPN network. I believe the issue is mainly with how Keras forces the loss to be computed in a certain way and I made an issue about it on the Keras GitHub page (keras-team/keras#7395). If you would like to contribute to that discussion, that'd be helpful. I'm curious how the RPN loss is currently handled in keras-rcnn; that part is still a bit unclear to me. I'm guessing it is some form of workaround, but which kind is it? :p

In my port I had made a custom layer, analogous to this example: https://github.com/fchollet/keras/blob/master/examples/variational_autoencoder.py#L56-L61. I got a training network with a loss that was nicely shrinking, but the resulting bboxes seem to favor scaling to the entire image. I can't figure out for the life of me why that happens.

Anyway, I'm sort of abusing this review to say hello and to see where I can contribute :)


def anchor(base_size=16, ratios=None, scales=None):
    """
    Generates a regular grid of multi-aspect and multi-scale anchor boxes.
    """
    if ratios is None:
        ratios = keras.backend.variable(numpy.array([0.5, 1, 2]))
        ratios = keras.backend.cast([0.5, 1, 2], 'float32')
Contributor

@hgaiser hgaiser Jul 21, 2017

Why not change the default value in the method definition to keras.backend.cast([0.5, 1, 2], 'float32')?

Better yet, https://www.tensorflow.org/api_docs/python/tf/contrib/keras/backend/floatx
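
A minimal sketch of that suggestion (the helper name is hypothetical):

import keras.backend

def default_ratios():
    # build the default aspect ratios with the backend's configured float type
    return keras.backend.cast([0.5, 1, 2], keras.backend.floatx())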


def anchor(base_size=16, ratios=None, scales=None):
    """
    Generates a regular grid of multi-aspect and multi-scale anchor boxes.
    """
    if ratios is None:
        ratios = keras.backend.variable(numpy.array([0.5, 1, 2]))
        ratios = keras.backend.cast([0.5, 1, 2], 'float32')

    if scales is None:
Contributor

Same here.

    return anchors


def _ratio_enum(anchor, ratios):
    """
    Enumerate a set of anchors for each aspect ratio wrt an anchor.
    """

    # import pdb
Contributor

Used for debugging? Can be removed?

Contributor

Yep, missed that one.

ws = proposals[:, 2] - proposals[:, 0] + 1
hs = proposals[:, 3] - proposals[:, 1] + 1

indicies = keras_rcnn.backend.where((ws >= minimum) & (hs >= minimum))
Contributor

Isn't it indices? :p

    def __init__(self, maximum_proposals=300, **kwargs):
        self.output_dim = (None, None, 4)
        self.output_dim = (None, maximum_proposals, 4)
Contributor

I may be mistaken, but I believe the name dim is generally reserved for a scalar value, representing the length of one dimension (https://keras.io/getting-started/sequential-model-guide/#specifying-the-input-shape)

Contributor

Yes, it seems so. What name would be good? output_shape?

Contributor

Yeah I think so.
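
A minimal sketch of the rename being discussed (the class and attribute names are assumptions; note that keras.layers.Layer already exposes an output_shape property, so a slightly different attribute name avoids shadowing it):

import keras

class ObjectProposal(keras.layers.Layer):
    def __init__(self, maximum_proposals=300, **kwargs):
        # store the intended output shape under a dedicated attribute
        self.proposal_shape = (None, maximum_proposals, 4)

        super(ObjectProposal, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        return self.proposal_shape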

@hgaiser hgaiser mentioned this pull request Jul 25, 2017
@0x00b1
Contributor Author

0x00b1 commented Jul 26, 2017

I just noticed this repository yesterday and, although it is far from finished, it looks really well thought out. I'm hoping I can contribute actively to this project in the near future. I was working on my own port of py-faster-rcnn because I was unhappy with the existing ports for Keras / TensorFlow, but got stuck on the losses for the RPN network. I believe the issue is mainly with how Keras forces the loss to be computed in a certain way and I made an issue about it on the Keras GitHub page (keras-team/keras#7395). If you would like to contribute to that discussion, that'd be helpful. I'm curious how the RPN loss is currently handled in keras-rcnn; that part is still a bit unclear to me. I'm guessing it is some form of workaround, but which kind is it? :p

Hi, @hgaiser! Let’s collaborate! It’s a tricky problem and we could use the help!

In this pull request, the approach:

  • For y_pred, the two feature maps produced by the convolutional layers that return the regression (for the reference bounding boxes) and the classification (for the reference bounding box’s objectness) are concatenated into one matrix (see the sketch after this list).
  • For y_true, we generate the anchors as described by the Faster R-CNN paper. It’s somewhat different from the py-faster-rcnn implementation by @rbgirshick in that the anchors are generated on the GPU (as part of the loss computation).
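
A minimal sketch of the y_pred side of this (layer sizes are illustrative and assume 9 anchors per location):

import keras

features = keras.layers.Input((14, 14, 512))

# objectness scores and box regressions from the RPN head, concatenated into one tensor
scores = keras.layers.Conv2D(1 * 9, (1, 1), activation="sigmoid")(features)
deltas = keras.layers.Conv2D(4 * 9, (1, 1))(features)

y_pred = keras.layers.concatenate([scores, deltas])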

I should add that it’s still wonky since we have an intermediate loss (i.e. a loss computed on an intermediate layer). However, in the past week or two I’ve been thinking about a possible RPN implementation that treats the RPN losses as regularization terms rather than traditional losses.

I’m curious, @hgaiser, is this a problem you’d like to help out with? If so, we should schedule some time to talk! This offer also includes anybody else who is lurking and wants to help!

@hgaiser
Contributor

hgaiser commented Jul 26, 2017

Hi,

I would very much like to collaborate! I had been working on a keras port by myself and it was driving me insane ;)

I've already made a few changes to this PR; I'd like to share them with you tomorrow in another PR. That is sort of my view on how I imagine it; I would like to hear what you had in mind and what you think of my approach. How shall we talk? I'm always online on Slack (kerasteam); maybe that could be a good channel.


image = keras.layers.Input((224, 224, 3))

y_true = keras.layers.Input((None, 4), name="y_true")
Contributor Author

This is a problem. You’d need to add a bounding box input to the predict method because the model will expect two inputs (an image and corresponding bounding boxes). A possible implementation is creating a custom layer that acts like keras.layers.Input during training but produces reference bounding boxes during test.
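
A minimal sketch of that idea, offered as an assumption rather than the PR's actual implementation (the class name is hypothetical, and keras_rcnn.backend.anchor is assumed to be the anchor generation refactored in this pull request):

import keras
import keras.backend

import keras_rcnn.backend

class ReferenceBoxes(keras.layers.Layer):
    def call(self, inputs, training=None):
        # ground-truth boxes are fed in during training
        ground_truth = inputs

        # at test time, fall back to generated reference boxes (anchors)
        anchors = keras.backend.expand_dims(keras_rcnn.backend.anchor(), 0)

        return keras.backend.in_train_phase(ground_truth, anchors, training=training)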

y_true = keras.layers.Input((None, 4), name="y_true")


features = keras.layers.Conv2D(64, **options)(image)
Contributor Author

I’d use a ResNet implementation from the keras-resnet package rather than VGG, since we’ll consider that the standard and it will match everything else.
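
A minimal sketch of that suggestion, mirroring the ResNet50 call that appears later in this pull request:

import keras
import keras_resnet.models

image = keras.layers.Input((224, 224, 3))

# shared convolutional features from a keras-resnet backbone instead of a VGG-style stack
features = keras_resnet.models.ResNet50(image, include_top=False).output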

loss_c, loss_r = loss(y_true, y_pred)
return loss_c + loss_r

# use the following for testing:
Contributor Author

You can remove these comments.


rpn_loss = keras_rcnn.losses.rpn.proposal(9, (224, 224), 16)

def example(x, loss):
Contributor Author

Use a more descriptive identifier.


y_true = numpy.expand_dims(y_true, 0)

model.fit([a, y_true], [None])
Contributor Author

You should also test model.predict and model.evaluate.

@mcquin

mcquin commented Aug 2, 2017

Remaining work: #52 (comment)


def anchor(base_size=16, ratios=None, scales=None):
    """
    Generates a regular grid of multi-aspect and multi-scale anchor boxes.
    """
    if ratios is None:
        ratios = keras.backend.variable(numpy.array([0.5, 1, 2]))
        ratios = keras.backend.cast([0.5, 1, 2], keras.backend.floatx())
Contributor

Isn't cast intended to cast Tensors from one type to another? An alternative is to use keras.backend.variable(numpy.array([0.5, 1, 2]), dtype=keras.backend.floatx()).


targets = keras.backend.transpose(targets)

return targets
return keras.backend.cast(targets, 'float32')
Contributor

keras.backend.floatx(), but is it even necessary? Isn't targets already float here?

    def __init__(self, anchors, **kwargs):
        self.anchors = anchors

        self.is_placeholder = True
Contributor

I noticed this in the example layer too, but upon further inspection I don't think it is necessary. Looking at Keras' source code it doesn't even appear to be necessary in the example.


        super(Classification, self).__init__(**kwargs)

    def _loss(self, y_true, y_pred):
Contributor

Hmm maybe a personal preference, but I think it's better to rename y_true and y_pred to rpn_cls_score and rpn_labels or something similar.


        super(Regression, self).__init__(**kwargs)

    def _loss(self, y_true, y_pred):
Contributor

Same here.


        self.add_loss(loss, inputs=inputs)

        return y_pred
Contributor

It's a bit of a semantic point, but we could return loss instead of y_pred here. It is ignored, but semantically it makes more sense (and it is potentially necessary in the future, if Keras decides to change the way losses are handled).
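
A minimal sketch of that suggestion (the layer and its loss are placeholders, not the RPN losses from this pull request):

import keras
import keras.backend

class ExampleLossLayer(keras.layers.Layer):
    def call(self, inputs):
        y_true, y_pred = inputs

        # placeholder loss; register it with the model via add_loss
        loss = keras.backend.mean(keras.backend.square(y_pred - y_true))

        self.add_loss(loss, inputs=inputs)

        # return the loss tensor itself rather than y_pred
        return loss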

import keras.engine

import keras_rcnn.backend

# FIXME: remove global
RPN_PRE_NMS_TOP_N = 12000
Contributor

Just a note, 12000 is only used during training, not testing.
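
A minimal sketch of how that could be captured, assuming the 12000/6000 pre-NMS limits that py-faster-rcnn uses for training and testing respectively:

# separate pre-NMS limits for training and testing
RPN_PRE_NMS_TOP_N_TRAIN = 12000
RPN_PRE_NMS_TOP_N_TEST = 6000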

@@ -1,21 +1,107 @@
import keras
Contributor

This entire file can be removed?

# ResNet50 as encoder
encoder = keras_resnet.models.ResNet50
image, _, _ = inputs
features = keras_resnet.models.ResNet50(image, blocks=blocks, include_top=False).output
Contributor

@hgaiser hgaiser Aug 3, 2017

I was thinking about this one... do we want to call the model as a function (i.e. keras_resnet.models.ResNet50(image, blocks=blocks, include_top=False)(image)) or do we want to incorporate every layer in the RCNN model? The difference is that in the summary you'd either see all the layers (the way it is now: res2-res4 and everything in between) or a single 'layer' called ResNet50. Let me know what you think. For example, I'm running an FCN network (with two branches, one processing RGB, the other processing depth) at the moment, and its summary looks like this:

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_25 (InputLayer)            (None, 1024, 1280, 3) 0                                            
____________________________________________________________________________________________________
input_26 (InputLayer)            (None, 1024, 1280, 1) 0                                            
____________________________________________________________________________________________________
resnet50-rgb (Model)             (None, 1024, 1280, 1) 33155975    input_25[0][0]                   
____________________________________________________________________________________________________
resnet50-depth (Model)           (None, 1024, 1280, 1) 33149703    input_26[0][0]                   
____________________________________________________________________________________________________
merged (Concatenate)             (None, 1024, 1280, 2) 0           resnet50-rgb[1][0]               
                                                                   resnet50-depth[1][0]             
____________________________________________________________________________________________________
segmentation (Conv2D)            (None, 1024, 1280, 1) 3           merged[0][0]                     
____________________________________________________________________________________________________
flatten_11 (Flatten)             (None, 1310720)       0           segmentation[0][0]               
====================================================================================================
Total params: 66,305,681
Trainable params: 66,199,441
Non-trainable params: 106,240


    def _loss(self, y_true, y_pred):
        # Binary classification loss
        x, y = y_pred[:, :, :, :], y_true[:, :, :, self.anchors:]
Contributor

I believe y_true is shaped (None, None) (as in (batch_id, label_id)). Slicing it like y_true[:, :, :, self.anchors:] won't work (why is slicing necessary on y_true anyway?)

Contributor Author

@0x00b1 0x00b1 Aug 4, 2017

I think the arguments are swapped.

        # Binary classification loss
        x, y = y_pred[:, :, :, :], y_true[:, :, :, self.anchors:]

        a = y_true[:, :, :, :self.anchors] * keras.backend.binary_crossentropy(x, y)
Contributor

I have two questions here: why use binary_crossentropy? And why multiply by y_true?
I'm thinking binary_crossentropy should be used when there is only one output, but we have two. One option is to use only one of the two outputs (the foreground output, probably) with binary_crossentropy (the option you seem to implement here); the other option is to use categorical_crossentropy and convert the labels to one-hot vectors. Which is preferable? I don't know :) Intuitively I'd guess categorical_crossentropy. I'd love to hear your thoughts about this.
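
A minimal sketch of the two options being compared, written out explicitly so it does not depend on the backend helpers' argument order (function and variable names are hypothetical):

import keras.backend

def binary_option(label, p_fg):
    # label is 1 for positive anchors and 0 for negative ones; p_fg is the
    # predicted foreground probability per anchor (sigmoid output)
    eps = keras.backend.epsilon()

    return -keras.backend.mean(
        label * keras.backend.log(p_fg + eps)
        + (1 - label) * keras.backend.log(1 - p_fg + eps))

def categorical_option(one_hot, p):
    # one_hot is a (background, foreground) one-hot vector per anchor; p is the
    # softmax output over the two classes
    eps = keras.backend.epsilon()

    return -keras.backend.mean(
        keras.backend.sum(one_hot * keras.backend.log(p + eps), axis=-1))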

        b = keras.backend.epsilon() + y_true[:, :, :, :self.anchors]
        b = keras.backend.sum(b)

        return 1.0 * (a / b)
Contributor

Why not return keras.backend.mean(keras.backend.binary_crossentropy(x, y)) directly? Why divide by anchor overlaps?


    def _loss(self, y_true, y_pred):
        # Robust L1 Loss
        x = y_true[:, :, :, 4 * self.anchors:] - y_pred
Contributor

Same as Classification, I believe it is sized (None, None, 4) (batch_id, box_id, box).

@0x00b1
Contributor Author

0x00b1 commented Aug 7, 2017

@hgaiser Good call. Fixed compute_output_shape for the loss layers.

@hgaiser
Contributor

hgaiser commented Aug 7, 2017 via email


    def compute_output_shape(self, input_shape):
        return [(None, 1), (None, 4)]

Contributor

I think a get_config() method would be useful for saving and loading a trained model.
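
A minimal sketch of such a get_config(), assuming anchors is the only constructor argument besides **kwargs (using the Classification layer shown earlier; adjust the class name for the other layers):

    def get_config(self):
        # record the constructor arguments so the layer can be re-instantiated on load
        config = {"anchors": self.anchors}

        base_config = super(Classification, self).get_config()

        return dict(list(base_config.items()) + list(config.items()))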

adjoint_b=False,
a_is_sparse=False,
b_is_sparse=False
a,
Contributor

Just my curiosity here, why 8 spaces?

@0x00b1 0x00b1 force-pushed the features/region-proposal-network-losses branch from 47860b3 to df93c4e Compare August 15, 2017 17:38
@0x00b1 0x00b1 merged commit 14add26 into master Aug 17, 2017
@0x00b1 0x00b1 deleted the features/region-proposal-network-losses branch August 17, 2017 19:44