In [19]:
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Lambda
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten, Input
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils

In [2]:
import numpy as np

In [3]:
def _keras_smp(x):
    m = K.max(x, axis=-1, keepdims=True)
    sm = K.exp(x - m)
    w = sm / K.sum(sm, axis=-1, keepdims=True)
    return K.sum(x * w, axis=-1)
    
def _keras_smp_shape(input_shape):
    shape = list(input_shape)
    return tuple(shape[:-1])

In [11]:
def smp(x, axis=-1):
    
    m = x.max(axis=axis, keepdims=True)
    sm = np.exp(x - m)   #exp(x - m) = exp(x) / exp(m)
    w = sm / np.sum(sm, axis=axis, keepdims=True)
    
    return (x * w).sum(axis=axis)

In [12]:
SoftMaxPool = Lambda(_keras_smp, output_shape=_keras_smp_shape)

In [13]:
x = np.random.randn(10, 5)

In [14]:
x

array([[-0.43883716, -0.99593852, -1.00922389, -0.02993694, -0.59930385],
       [ 0.08620084,  0.10199767, -0.33701471,  0.54612009,  2.19451448],
       [ 0.91566838, -0.34808806, -1.39974601, -0.7126849 ,  0.51767488],
       [ 2.38131497, -1.94506525,  0.05374397, -0.68452835, -0.82992932],
       [-1.15572444,  0.44791267,  1.57010151, -0.33031556,  0.08427308],
       [-0.27722465, -0.4196747 , -0.77073023,  0.3595472 ,  0.72007522],
       [-0.24289576, -0.68263   ,  0.43795361,  0.42691328,  0.40979546],
       [-1.12907955, -0.75063486,  0.04952153,  0.54253876,  0.68944537],
       [-0.15829633, -1.14638162, -0.51998859, -0.56066538,  0.43501487],
       [-0.89704661,  1.32197154, -0.33352295,  0.11153553, -1.03825087]])

In [15]:
y = smp(x, axis=-1)

In [17]:
y

array([-0.47505555,  1.51367474,  0.39433429,  1.91663342,  0.91116797,
        0.21472835,  0.24068545,  0.29497026, -0.11760838,  0.64767425])

In [20]:
model = Sequential()

inputs = Input(shape=(10, 5))

predictions = SoftMaxPool(inputs)

model = Model(input=inputs, output=predictions)
model.compile(optimizer='rmsprop', loss='binary_crossentropy')

In [21]:
model.predict(np.asarray([x]))

array([[-0.47505555,  1.51367474,  0.39433429,  1.91663337,  0.91116792,
         0.21472836,  0.24068543,  0.29497027, -0.1176084 ,  0.64767414]])

In [22]:
smp(x)

array([-0.47505555,  1.51367474,  0.39433429,  1.91663342,  0.91116797,
        0.21472835,  0.24068545,  0.29497026, -0.11760838,  0.64767425])

In [23]:
smp(10000000 + x) - 10000000

array([-0.47505555,  1.51367474,  0.39433429,  1.91663342,  0.91116797,
        0.21472835,  0.24068545,  0.29497026, -0.11760838,  0.64767425])