Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GeM layer #16747

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@
from keras.layers.pooling.max_pooling2d import MaxPooling2D
from keras.layers.pooling.max_pooling3d import MaxPool3D
from keras.layers.pooling.max_pooling3d import MaxPooling3D
from keras.layers.pooling.generalized_mean_pooling1d import (
BaseGeneralizedPooling,
innat marked this conversation as resolved.
Show resolved Hide resolved
)
from keras.layers.pooling.generalized_mean_pooling1d import (
GeneralizedMeanPooling1D,
)
from keras.layers.pooling.generalized_mean_pooling2d import (
GeneralizedMeanPooling2D,
)
from keras.layers.pooling.generalized_mean_pooling3d import (
GeneralizedMeanPooling3D,
)
from keras.layers.rnn.abstract_rnn_cell import AbstractRNNCell

# Recurrent layers.
Expand Down
51 changes: 51 additions & 0 deletions keras/layers/pooling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ py_library(
":max_pooling1d",
":max_pooling2d",
":max_pooling3d",
":generalized_mean_pooling1d",
":generalized_mean_pooling2d",
":generalized_mean_pooling3d"
],
)

Expand Down Expand Up @@ -232,6 +235,36 @@ py_library(
],
)

# Shared abstract base for the GeneralizedMeanPooling{1D,2D,3D} layers.
# The generalized_mean_pooling* rules below depend on this target, so it must
# be declared in this package.
# NOTE(review): dep labels mirror the imports in base_generalized_pooling.py
# (backend, engine.base_layer, engine.input_spec, utils.conv_utils) — confirm
# they match the target names in the respective BUILD files.
py_library(
    name = "base_generalized_pooling",
    srcs = ["base_generalized_pooling.py"],
    srcs_version = "PY3",
    deps = [
        "//keras:backend",
        "//keras/engine:base_layer",
        "//keras/engine:input_spec",
        "//keras/utils:conv_utils",
    ],
)

# GeneralizedMeanPooling1D layer. The source imports `tensorflow` directly,
# hence the explicit `expect_tensorflow_installed` dependency.
py_library(
    name = "generalized_mean_pooling1d",
    srcs = ["generalized_mean_pooling1d.py"],
    srcs_version = "PY3",
    deps = [
        ":base_generalized_pooling",
        "//:expect_tensorflow_installed",
        "//keras:backend",
    ],
)

# Library target for the GeneralizedMeanPooling2D layer
# (generalized_mean_pooling2d.py).
# NOTE(review): `:base_generalized_pooling` is not defined in the visible part
# of this BUILD file — confirm the target exists in this package.
py_library(
name = "generalized_mean_pooling2d",
srcs = ["generalized_mean_pooling2d.py"],
srcs_version = "PY3",
deps = [
":base_generalized_pooling",
"//keras:backend",
],
)

# Library target for the GeneralizedMeanPooling3D layer
# (generalized_mean_pooling3d.py).
# NOTE(review): `:base_generalized_pooling` is not defined in the visible part
# of this BUILD file — confirm the target exists in this package.
py_library(
name = "generalized_mean_pooling3d",
srcs = ["generalized_mean_pooling3d.py"],
srcs_version = "PY3",
deps = [
":base_generalized_pooling",
"//keras:backend",
],
)

tf_py_test(
name = "average_pooling_test",
size = "medium",
Expand Down Expand Up @@ -303,3 +336,21 @@ tf_py_test(
"//keras/testing_infra:test_utils",
],
)

# Unit tests for the GeneralizedMeanPooling{1D,2D,3D} layers.
tf_py_test(
    name = "generalized_mean_pooling_test",
    size = "medium",
    srcs = ["generalized_mean_pooling_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "notsan",  # TODO(b/183962355)
    ],
    deps = [
        "//:expect_absl_installed",
        "//:expect_tensorflow_installed",
        "//keras",
        "//keras/testing_infra:test_combinations",
        "//keras/testing_infra:test_utils",
    ],
)
12 changes: 12 additions & 0 deletions keras/layers/pooling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
from keras.layers.pooling.average_pooling2d import AvgPool2D
from keras.layers.pooling.average_pooling3d import AveragePooling3D
from keras.layers.pooling.average_pooling3d import AvgPool3D
from keras.layers.pooling.generalized_mean_pooling1d import (
BaseGeneralizedPooling,
innat marked this conversation as resolved.
Show resolved Hide resolved
)
from keras.layers.pooling.generalized_mean_pooling1d import (
GeneralizedMeanPooling1D,
)
from keras.layers.pooling.generalized_mean_pooling2d import (
GeneralizedMeanPooling2D,
)
from keras.layers.pooling.generalized_mean_pooling3d import (
GeneralizedMeanPooling3D,
)
from keras.layers.pooling.global_average_pooling1d import GlobalAveragePooling1D
from keras.layers.pooling.global_average_pooling1d import GlobalAvgPool1D
from keras.layers.pooling.global_average_pooling2d import GlobalAveragePooling2D
Expand Down
121 changes: 121 additions & 0 deletions keras/layers/pooling/base_generalized_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private base class for generalized pooling 1D layers."""


from keras import backend
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
innat marked this conversation as resolved.
Show resolved Hide resolved
from keras.utils import conv_utils


class BaseGeneralizedPooling(Layer):
    """Abstract base class for generalized mean (GeM) pooling layers.

    Holds the logic that does not depend on the number of pooled
    dimensions: constructor-argument validation, rank-dependent
    normalization of the pooling arguments in `build()`, and
    serialization in `get_config()`. Subclasses implement `call()`.

    Args:
      power: Positive number, the exponent used by the generalized mean.
      pool_size: Integer or tuple of integers, size of the pooling window.
      strides: Integer, tuple of integers, or None. When None, the value of
        `pool_size` is used.
      padding: A string, passed through `conv_utils.normalize_padding` in
        `build()`.
      data_format: A string, `"channels_last"` or `"channels_first"`. When
        None, `backend.image_data_format()` is used.
      name: A string, the name of the layer.
      **kwargs: Forwarded to `Layer.__init__`.

    Raises:
      ValueError: If `power` is not strictly positive.
    """

    # `build()` overwrites `self.data_format` with the TF-native layout string
    # because subclasses hand it straight to `tf.nn` pooling ops. This table
    # maps those strings back to the Keras names that `__init__` accepts, so a
    # config taken from a built layer can be fed to `from_config` again.
    _TF_TO_KERAS_DATA_FORMAT = {
        "NWC": "channels_last",
        "NHWC": "channels_last",
        "NDHWC": "channels_last",
        "NCW": "channels_first",
        "NCHW": "channels_first",
        "NCDHW": "channels_first",
    }

    def __init__(
        self,
        power=1.0,
        pool_size=3,
        strides=3,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

        if power <= 0:
            raise ValueError(
                "The value of `power` in GeneralizedMeanPooling must "
                f"be positive number. Got: {power}"
            )

        if data_format is None:
            data_format = backend.image_data_format()

        if strides is None:
            strides = pool_size

        self.data_format = data_format
        self.strides = strides
        self.pool_size = pool_size
        self.padding = padding
        self.power = power

    def build(self, input_shape):
        """Normalizes the pooling arguments for the rank of `input_shape`.

        Rank 3, 4 and 5 inputs correspond to 1D, 2D and 3D pooling. After
        this call `pool_size` and `strides` have been passed through
        `conv_utils.normalize_tuple` with `rank - 2` entries, `padding` is
        upper-cased, and `data_format` has been converted with
        `conv_utils.convert_data_format`.

        Args:
          input_shape: Shape of the input; only its length is inspected.

        Raises:
          ValueError: If the input rank is not 3, 4 or 5.
        """
        rank = len(input_shape)
        if rank not in (3, 4, 5):
            raise ValueError(
                "Invalid input shape. Expected input should be 1D, 2D "
                f"and 3D data. Got {input_shape}"
            )

        # Batch and channel axes are never pooled over.
        num_pooled_dims = rank - 2

        self.pool_size = conv_utils.normalize_tuple(
            self.pool_size, num_pooled_dims, "pool_size"
        )
        self.strides = conv_utils.normalize_tuple(
            self.strides, num_pooled_dims, "strides", allow_zero=True
        )
        self.padding = conv_utils.normalize_padding(self.padding).upper()
        self.data_format = conv_utils.convert_data_format(
            self.data_format, rank
        )
        self.input_spec = InputSpec(ndim=rank)

    def call(self, inputs):
        """Applies the pooling operation; must be provided by subclasses."""
        raise NotImplementedError

    def get_config(self):
        """Returns the layer configuration as a dict.

        `padding` and `data_format` are reported in the form accepted by the
        constructor (lower-case padding, Keras data-format names) whether or
        not `build()` has already rewritten the attributes, so the result
        can be passed back to `from_config`.
        """
        padding = self.padding
        if isinstance(padding, str):
            padding = padding.lower()

        data_format = self.data_format
        if isinstance(data_format, str):
            # Unknown strings (including the Keras names) pass through as-is.
            data_format = self._TF_TO_KERAS_DATA_FORMAT.get(
                data_format, data_format
            )

        config = {
            "power": self.power,
            "pool_size": self.pool_size,
            "strides": self.strides,
            "padding": padding,
            "data_format": data_format,
        }
        base_config = super().get_config()
        return {**base_config, **config}
136 changes: 136 additions & 0 deletions keras/layers/pooling/generalized_mean_pooling1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import tensorflow as tf
from tensorflow.python.util.tf_export import keras_export

from keras.layers.pooling.base_generalized_pooling import BaseGeneralizedPooling


@keras_export("keras.layers.GeneralizedMeanPooling1D")
class GeneralizedMeanPooling1D(BaseGeneralizedPooling):
    """Generalized mean pooling operation for temporal data.

    Generalized Mean Pooling (GeM) computes the generalized mean of each
    channel in a tensor. It provides a parameter `power` that sets an
    exponent enabling the pooling to increase or decrease the contrast
    between salient features in the feature map.

    The GeM layer is a generalization of the average pooling layer and the
    max pooling layer. When `power = 1`, it acts as an average pooling
    layer, and as `power` approaches infinity it acts as a max pooling
    layer.

    Note: the layer raises its inputs to `power` and the pooled result to
    `1 / power`. With a non-integer `1 / power`, negative pooled values
    produce NaN, so the inputs are expected to be non-negative (for example
    activations that follow a ReLU).

    Examples:

    1. When pool_size=2, strides=1, padding='valid'

    >>> x = tf.constant([1., 2., 3., 4., 5.])
    >>> x = tf.reshape(x, [1, 5, 1])
    >>> gem_pool_1d = tf.keras.layers.GeneralizedMeanPooling1D(power=3,
    ...    pool_size=2, strides=1, padding='valid',
    ...    data_format='channels_last')
    >>> gem_pool_1d(x)
    <tf.Tensor: shape=(1, 4, 1), dtype=float32, numpy=
    array([[[1.6509637],
            [2.596247 ],
            [3.5700185],
            [4.5548835]]], dtype=float32)>

    2. When pool_size=2, strides=2, padding='valid'

    >>> x = tf.constant([1., 2., 3., 4., 5.])
    >>> x = tf.reshape(x, [1, 5, 1])
    >>> gem_pool_1d = tf.keras.layers.GeneralizedMeanPooling1D(power=3,
    ...    pool_size=2, strides=2, padding='valid',
    ...    data_format='channels_last')
    >>> gem_pool_1d(x)
    <tf.Tensor: shape=(1, 2, 1), dtype=float32, numpy=
    array([[[1.6509637],
            [3.5700185]]], dtype=float32)>

    3. When pool_size=2, strides=1, padding='same'

    >>> x = tf.constant([1., 2., 3., 4., 5.])
    >>> x = tf.reshape(x, [1, 5, 1])
    >>> gem_pool_1d = tf.keras.layers.GeneralizedMeanPooling1D(power=3,
    ...    pool_size=2, strides=1, padding='same',
    ...    data_format='channels_last')
    >>> gem_pool_1d(x)
    <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy=
    array([[[1.6509637],
            [2.596247 ],
            [3.5700185],
            [4.5548835],
            [5.0000005]]], dtype=float32)>

    Args:
      power: Float > 0, the exponent of the generalized mean. Inputs are
        raised to `power`, average-pooled, and the result is raised to
        `1 / power`. Values greater than 1 increase the contrast of the
        pooled feature map and emphasize salient features. `power = 1`
        gives average pooling; a very large `power` approximates max
        pooling.
      pool_size: Integer, size of the pooling window.
      strides: Integer, or None. Factor by which to downscale.
        E.g. 2 will halve the input.
        If None, it will default to `pool_size`.
      padding: A string. The padding method, either 'valid' or 'same'.
        `'valid'` means no padding. `'same'` results in padding evenly to
        the left/right of the input such that, with `strides=1`, the output
        has the same number of steps as the input.
      data_format: A string, one of `channels_last` or `channels_first`.
        The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape
        `(batch, steps, features)` while `channels_first` corresponds
        to inputs with shape `(batch, features, steps)`. If None, the
        value of `keras.backend.image_data_format()` is used.
      name: A string, the name of the layer.

    Input shape:
      - If `data_format='channels_last'`:
        3D tensor with shape `(batch_size, steps, features)`.
      - If `data_format='channels_first'`:
        3D tensor with shape `(batch_size, features, steps)`.

    Output shape:
      - If `data_format='channels_last'`:
        3D tensor with shape `(batch_size, downsampled_steps, features)`.
      - If `data_format='channels_first'`:
        3D tensor with shape `(batch_size, features, downsampled_steps)`.

    References:
      - https://arxiv.org/pdf/1711.02512.pdf
    """

    def __init__(
        self,
        power=3.0,
        pool_size=2,
        strides=None,
        padding="valid",
        data_format=None,
        name=None,
        **kwargs
    ):
        # Forward every argument: the base constructor validates `power`,
        # resolves the `None` defaults, and is the one that assigns the
        # attributes. Calling it without them would silently replace the
        # caller's values with the base-class defaults.
        super().__init__(
            power=power,
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            name=name,
            **kwargs
        )

    def call(self, inputs):
        """Computes `avg_pool(inputs ** power) ** (1 / power)`.

        Relies on `build()` having normalized `pool_size`, `strides`,
        `padding` and `data_format` into the form `tf.nn.avg_pool1d` takes.
        """
        x = tf.pow(inputs, self.power)
        x = tf.nn.avg_pool1d(
            x, self.pool_size, self.strides, self.padding, self.data_format
        )
        x = tf.pow(x, (1.0 / self.power))
        return x

    # `get_config` is inherited from `BaseGeneralizedPooling`, which already
    # serializes power, pool_size, strides, padding and data_format.