In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.nn import max_pool_with_argmax

In [None]:
class MaxPooling1DWithArgmax(keras.layers.MaxPooling1D):
    """`keras.layers.MaxPooling1D` that can also record where each maximum came from.

    While ``store_argmax`` is False (the default) the layer simply defers to
    ``MaxPooling1D.call``. When ``store_argmax`` is True, ``call`` pools with
    ``tf.nn.max_pool_with_argmax`` instead and keeps its results on the layer:

    * ``self.ret``    -- the ``(output, argmax)`` result of
      ``tf.nn.max_pool_with_argmax`` on the input with a unit height axis
      inserted, i.e. tensors of shape ``(batch, 1, out_steps, channels)``;
    * ``self.argmax`` -- ``self.ret.argmax``: for each pooled value, the
      flattened index ``step * channels + channel`` of the maximum inside its
      batch item.

    Both attributes are ``None`` until a call is made with ``store_argmax``
    enabled. They keep their last value when the flag is turned off again.

    Notes:
        ``store_argmax`` is a plain Python attribute read inside ``call``. If
        the layer is traced into a ``tf.function``/compiled model, the branch
        is fixed when tracing happens.
        For ``data_format='channels_first'`` the input is transposed to
        channels-last before pooling, so ``ret``/``argmax`` use the
        ``(steps, channels)`` layout.
    """

    def __init__(self, pool_size=2, strides=None,
                 padding='valid', data_format='channels_last', **kwargs):
        super().__init__(
            pool_size=pool_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            **kwargs)

        # Runtime switch; off by default so the layer acts as a plain MaxPooling1D.
        self.store_argmax = False
        # tf.nn.max_pool_with_argmax spells padding as 'VALID' / 'SAME'.
        self.padding_upper = padding.upper()
        # Filled in by `call` when `store_argmax` is True.
        self.ret = None
        self.argmax = None

    def call(self, inputs):
        """Max-pool ``inputs``; additionally store the argmax if ``store_argmax`` is set."""
        if not self.store_argmax:
            return super().call(inputs)

        channels_first = self.data_format == 'channels_first'
        # tf.nn.max_pool_with_argmax only supports NHWC, i.e. channels-last input.
        channels_last_inputs = (tf.transpose(inputs, (0, 2, 1))
                                if channels_first else inputs)
        # Insert a unit "height" axis so the 1D pooling becomes a 1 x pool_size
        # 2D pooling: (batch, steps, channels) -> (batch, 1, steps, channels).
        ret = tf.nn.max_pool_with_argmax(tf.expand_dims(channels_last_inputs, 1),
                                         ksize=(1, self.pool_size[0]),
                                         strides=(1, self.strides[0]),
                                         padding=self.padding_upper)
        self.ret = ret
        self.argmax = ret.argmax

        # Reuse the values already pooled above instead of pooling a second time
        # through super().call; squeeze only the inserted height axis.
        output = tf.squeeze(ret.output, axis=1)
        if channels_first:
            output = tf.transpose(output, (0, 2, 1))
        return output

In [None]:
# Shared pooling configuration, so the stock layer and the argmax-tracking
# layer below are directly comparable.
pool_size = 4
padding = 'valid'
name = 'max_pool_1d'

# Non-overlapping windows: stride equals the window size.
pool_kwargs = dict(pool_size=pool_size, strides=pool_size, padding=padding)

# Reference: stock Keras 1D max pooling.
max_pool = keras.layers.MaxPooling1D(name=name, **pool_kwargs)

# Same pooling, with optional argmax tracking.
max_pool_argmax = MaxPooling1DWithArgmax(name=name + '_with_argmax', **pool_kwargs)

In [None]:
# Check the argmax-tracking layer against the stock layer on random data.
input_shape = (50, 2385, 16)  # (batch, steps, channels): both layers use channels_last
tf.random.set_seed(5061983)
x = tf.random.uniform(input_shape)

max_pool_argmax.store_argmax = True
try:
    y1 = max_pool(x)
    y2 = max_pool_argmax(x)
finally:
    # Restore the default even if a call raises, so later cells see the plain path.
    max_pool_argmax.store_argmax = False

# Squeeze only the inserted height axis; a bare tf.squeeze would also drop
# batch/channel axes that happen to have size 1.
y3 = tf.squeeze(max_pool_argmax.ret.output, axis=1)

# The pooled values must match the stock layer, both as returned and as stored.
np.testing.assert_array_equal(y2.numpy(), y1.numpy())
np.testing.assert_array_equal(y3.numpy(), y1.numpy())

# Each stored index (step * channels + channel within a batch item) must point
# at the value that was pooled.
argmax_idx = tf.squeeze(max_pool_argmax.argmax, axis=1)
flat_x = tf.reshape(x, (input_shape[0], -1))
flat_argmax_idx = tf.reshape(argmax_idx, (input_shape[0], -1))
x_at_argmax = tf.reshape(tf.gather(flat_x, flat_argmax_idx, batch_dims=1), tf.shape(y1))
np.testing.assert_array_equal(x_at_argmax.numpy(), y1.numpy())

In [None]:
%%timeit
# Baseline: stock keras MaxPooling1D on the same input `x`.
y1 = max_pool(x)

In [None]:
%%timeit
y2 = max_pool_argmax(x)

In [None]:
# Turn argmax tracking on for the following benchmark; it stays on afterwards.
max_pool_argmax.store_argmax = True

In [None]:
%%timeit
y2 = max_pool_argmax(x)