In [1]:
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Lambda
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, Input
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils

Using TensorFlow backend.


In [2]:
import numpy as np

In [42]:
class SoftMaxPool(Layer):
    '''Soft-max pooling along one axis.

    Each slice along `axis` is collapsed to `sum(x * softmax(x))`, i.e. the
    average of its entries weighted by their own softmax. The pooled axis is
    kept with length 1.

    # Arguments
        axis: integer axis to pool over; defaults to the last axis.
        **kwargs: forwarded unchanged to the base `Layer` constructor.
    '''
    def __init__(self, axis=-1, **kwargs):

        super(SoftMaxPool, self).__init__(**kwargs)

        self.axis = axis

    def get_output_shape_for(self, input_shape):
        '''Return `input_shape` as a tuple with the pooled axis set to 1.'''
        shape = list(input_shape)
        shape[self.axis] = 1
        return tuple(shape)

    def compute_output_shape(self, input_shape):
        '''Shape hook under its Keras 2 name.

        Delegates to `get_output_shape_for`, so the layer reports the pooled
        shape under either API version.
        '''
        return self.get_output_shape_for(input_shape)

    def call(self, x, mask=None):
        '''Return the softmax-weighted sum of `x` over `self.axis`.

        `mask` is accepted but not used.
        '''
        # Shift by the per-slice max: exp(x - m) = exp(x) / exp(m), so the
        # normalised weights are unchanged while every exponent is <= 0.
        m = K.max(x, axis=self.axis, keepdims=True)
        sm = K.exp(x - m)
        w = sm / K.sum(sm, axis=self.axis, keepdims=True)
        return K.sum(x * w, axis=self.axis, keepdims=True)

    def get_config(self):
        '''Return the base layer config extended with `axis`.'''
        config = {'axis': self.axis}
        base_config = super(SoftMaxPool, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [64]:
def _keras_smp(x):
    """Softmax-weighted sum of `x` over its last axis, using Keras backend ops.

    The reduced axis is kept with length 1.
    """
    last = -1
    # Shifting by the max along the axis leaves the normalised weights
    # unchanged (exp(x - m) = exp(x) / exp(m)) and keeps every exponent <= 0.
    shifted = x - K.max(x, axis=last, keepdims=True)
    exps = K.exp(shifted)
    weights = exps / K.sum(exps, axis=last, keepdims=True)
    return K.sum(x * weights, axis=last, keepdims=True)
    
def _keras_smp_shape(input_shape):
    shape = list(input_shape)
    shape[-1] = 1
    return tuple(shape)

In [65]:
# Numpy implementation for reference
def smp(x, axis=-1):
    """Soft-max pooling of `x` along `axis`, computed with NumPy.

    Each slice along `axis` is reduced to ``sum(x * softmax(x))``; the reduced
    axis is kept with length 1.

    Parameters
    ----------
    x : array_like
        Values to pool. Lists and tuples are accepted as well as ndarrays.
    axis : int, optional
        Axis to pool over (default: the last axis).

    Returns
    -------
    ndarray
        Same shape as `x` except that `axis` has length 1.
    """
    # Accept any array-like; ndarray (and subclass) inputs pass through as-is.
    x = np.asanyarray(x)

    m = x.max(axis=axis, keepdims=True)
    sm = np.exp(x - m)   # exp(x - m) = exp(x) / exp(m); every exponent is <= 0
    w = sm / np.sum(sm, axis=axis, keepdims=True)

    return (x * w).sum(axis=axis, keepdims=True)

In [66]:
# The same pooling wrapped as a Lambda layer, built from the two helper
# functions (forward computation and output-shape inference).
LSoftMaxPool = Lambda(
    _keras_smp,
    output_shape=_keras_smp_shape,
)

In [67]:
# Fixed seed so the sample input is identical on every Restart & Run All;
# a private RandomState is used so NumPy's global RNG is not reseeded.
x = np.random.RandomState(0).randn(10, 5)

In [68]:
# Show the (10, 5) sample input.
x

array([[ 1.15473962, -1.11731115,  0.41548108, -1.78850882, -1.13391072],
       [ 0.13043004, -0.33131481,  0.27255506,  1.11169833, -0.71478944],
       [-0.74502153, -0.08293015,  0.10975802,  0.50268116, -1.49282068],
       [ 0.75382782,  1.42877109,  0.80678389,  0.79888669, -0.77715123],
       [-0.93036336, -1.78314324, -0.81846917, -0.04201394,  1.1983932 ],
       [-0.68162087, -1.03825424,  0.8065906 , -1.74972556, -0.62325689],
       [ 0.63785869,  0.15823746,  0.26292411, -1.52545145, -1.59493191],
       [-1.19780559,  0.43175324, -0.65264047,  2.1262413 , -0.9754042 ],
       [-1.00404137, -0.01230852,  0.40263937,  0.17507494,  0.55413324],
       [-0.40114422,  0.98943446, -0.69655683,  0.47005621, -1.16881525]])

In [69]:
# NumPy reference: pool the squared inputs along the last axis (one value per row).
y = smp(x**2, axis=-1)

In [70]:
# Show the NumPy reference result (keepdims leaves a trailing axis of length 1).
y

array([[ 2.53212754],
       [ 0.640086  ],
       [ 1.53513054],
       [ 1.34285324],
       [ 2.55659086],
       [ 2.42890102],
       [ 2.14244243],
       [ 4.18876931],
       [ 0.46563264],
       [ 0.86693713]])

In [71]:
# One symbolic input feeds both implementations, so a single predict() call
# returns their outputs side by side.
inputs = Input(shape=(10, 5))

lpredictions = LSoftMaxPool(inputs)
predictions = SoftMaxPool(axis=-1)(inputs)

# Pass (inputs, outputs) positionally: the keyword names differ between
# Keras 1 (`input`/`output`) and Keras 2 (`inputs`/`outputs`).
model = Model(inputs, [lpredictions, predictions])
# NOTE(review): the visible cells only call predict() on this model, so the
# optimizer/loss choice looks arbitrary — confirm before using it for training.
model.compile(optimizer='rmsprop', loss='binary_crossentropy')

In [72]:
# Wrap x in a leading batch axis (np.asarray([x]) is (1, 10, 5)), square it, and
# evaluate both Keras implementations; the result is a list with one array per output.
model.predict(np.asarray([x])**2)

[array([[[ 2.53212786],
         [ 0.64008605],
         [ 1.5351305 ],
         [ 1.34285331],
         [ 2.5565908 ],
         [ 2.42890072],
         [ 2.14244223],
         [ 4.18876934],
         [ 0.46563268],
         [ 0.8669371 ]]], dtype=float32), array([[[ 2.53212786],
         [ 0.64008605],
         [ 1.5351305 ],
         [ 1.34285331],
         [ 2.5565908 ],
         [ 2.42890072],
         [ 2.14244223],
         [ 4.18876934],
         [ 0.46563268],
         [ 0.8669371 ]]], dtype=float32)]

In [73]:
# NumPy reference on the same squared input, for comparison with the Keras outputs.
smp(x**2)

array([[ 2.53212754],
       [ 0.640086  ],
       [ 1.53513054],
       [ 1.34285324],
       [ 2.55659086],
       [ 2.42890102],
       [ 2.14244243],
       [ 4.18876931],
       [ 0.46563264],
       [ 0.86693713]])

In [74]:
# Large-offset check: the softmax weights do not change when a constant c is added
# to every entry and they sum to 1, so smp(c + x) - c equals smp(x) up to rounding.
# smp subtracts the row max before exponentiating, so np.exp never sees values near 1e7.
smp(10000000 + x) - 10000000

array([[ 0.59302298],
       [ 0.49227469],
       [ 0.04283264],
       [ 0.96162534],
       [ 0.55035915],
       [ 0.1197499 ],
       [ 0.2254924 ],
       [ 1.56620865],
       [ 0.24131747],
       [ 0.422051  ]])