Skip to content

Commit

Permalink
Add pooling options in MobileNetV2 (keras-team#10313)
Browse files Browse the repository at this point in the history
* Add pooling option

* Add pooling test
  • Loading branch information
fuzzythecat authored and taehoonlee committed May 30, 2018
1 parent 315a80a commit fe06696
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
19 changes: 19 additions & 0 deletions keras/applications/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from ..layers import Conv2D
from ..layers import DepthwiseConv2D
from ..layers import GlobalAveragePooling2D
from ..layers import GlobalMaxPooling2D
from ..layers import Add
from ..layers import Flatten
from ..layers import Dense
Expand Down Expand Up @@ -149,6 +150,7 @@ def MobileNetV2(input_shape=None,
include_top=True,
weights='imagenet',
input_tensor=None,
pooling=None,
classes=1000):
"""Instantiates the MobileNetV2 architecture.
Expand Down Expand Up @@ -187,6 +189,18 @@ def MobileNetV2(input_shape=None,
input_tensor: optional Keras tensor (i.e. output of
`layers.Input()`)
to use as image input for the model.
pooling: Optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model
will be the 4D tensor output of the
last convolutional layer.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional layer, and thus
the output of the model will be a
2D tensor.
- `max` means that global max pooling will
be applied.
classes: optional number of classes to classify images
into, only to be specified if `include_top` is True, and
if no `weights` argument is specified.
Expand Down Expand Up @@ -408,6 +422,11 @@ def MobileNetV2(input_shape=None,
x = GlobalAveragePooling2D()(x)
x = Dense(classes, activation='softmax',
use_bias=True, name='Logits')(x)
else:
if pooling == 'avg':
x = GlobalAveragePooling2D()(x)
elif pooling == 'max':
x = GlobalMaxPooling2D()(x)

# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
Expand Down
3 changes: 1 addition & 2 deletions tests/keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ def test_mobilenet():
_test_application_basic(app)
_test_application_notop(app, last_dim)
_test_application_variable_input_channels(app, last_dim)
if app == applications.MobileNet:
_test_app_pooling(app, last_dim)
_test_app_pooling(app, last_dim)


def test_densenet():
Expand Down

0 comments on commit fe06696

Please sign in to comment.