### Differentiable argmax / Soft argmax

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input, Dense, Conv1D, Reshape

  from ._conv import register_converters as _register_converters


In [2]:
class DifferentiableArgmax(Layer):
    '''
    Differentiable argmax / Soft argmax
    
    Numerically stable in classical neural network value range (-3..3)
    Likely to work between -6..6 but not fully tested
    
    For using with convolutional filters you might want to swap axes:
    
    c = Conv1D(4, 3) (input)
    argmax = DifferentiableArgmax() (c)
    argmax -> maximum of filters per width step (Which filter fire stronger at width steps?)
    
    vs
    
    c = Conv1D(4, 3) (input)
    argmax = tf.transpose(c, perm=[0,2,1]) # (batch, width, filter) -> (batch, filter, width)
    argmax = DifferentiableArgmax() (argmax)
    argmax -> maximum in width dimension per filter (Where the filters fire strongest?)
    '''
    
    def __init__(self):
        pass
    
    def __call__(self, inputs):  
        # Step 1: Scaling for numerical stability
        # During step 2, the sum will never be zero
        # Theoretically float related rounding error possible but neural networks don't use that value range.
        scaling = tf.math.exp(inputs)
        
        # Step 2: Binarize input. Ideally highest value 1, everything else is zero.
        # The 'binarized' variable might not be ideal after the first step if input values are close to each other.
        # Solution is to repeat binarization function.
        # Trying higher power im binarization function will result in numerical instability quickly.
        def binarize(x):
            # Step 1: Make small values smaller, high values higher
            # Highest value will likely to be at least a magnitude higher than the rest.
            a = tf.math.pow(x, 10) 
            # Step 2: Divide all value with the sum.
            # Due to the magnitude difference the highest value / sum will be close to 1.
            # Everything else will be closer to zero.
            sum_a = tf.reduce_sum(a, axis=-1)
            sum_a = tf.expand_dims(sum_a, axis=-1)
            onehot = tf.divide(a, sum_a)            
            return onehot
        
        binarized = binarize(scaling)
        binarized = binarize(binarized)
        binarized = binarize(binarized)
        binarized = binarize(binarized)
        binarized = binarize(binarized)
        binarized = binarize(binarized)
        
        # Step 3: Get argmax of one-hot encoded input
        cumsum = tf.cumsum(binarized, axis = -1, exclusive = True, reverse = True)
        rounding = 2*(tf.clip_by_value(cumsum, clip_value_min=.5, clip_value_max=1) - .5)
        token = tf.reduce_sum(rounding, axis = -1)
        token = tf.expand_dims(token, axis=-1)
        
        return [inputs, token]

#### Example: Dense layer

In [3]:
# Model (for random values test)
input0 =  Input((1))
d = Dense(4, use_bias=False, name='output') (input0)
argmax = DifferentiableArgmax()(d)
model = Model(input0, [argmax])
model.compile(optimizer='adam', loss='mae')
#model.summary()

# Random value test
prediction = model.predict([-2, -1, 0, 1, 2])
print('Input of argmax:\n', prediction[0])
print('\nOutput of argmax:\n', prediction[1])

Input of argmax:
 [[ 1.6582605   0.7575089  -2.009025   -1.6150696 ]
 [ 0.82913023  0.37875444 -1.0045125  -0.8075348 ]
 [ 0.          0.          0.          0.        ]
 [-0.82913023 -0.37875444  1.0045125   0.8075348 ]
 [-1.6582605  -0.7575089   2.009025    1.6150696 ]]

Output of argmax:
 [[0. ]
 [0. ]
 [0.5]
 [2. ]
 [2. ]]


#### Example: Get indexes where convolutional filters fire

In [4]:
timesteps = 7
channels= 1

# Model
input0 =  Input((timesteps, channels))
c = Conv1D(4, 3, padding='valid') (input0)
argmax = tf.transpose(c, perm=[0,2,1]) # (batch, width, filter) -> (batch, filter, width)
argmax = DifferentiableArgmax() (argmax)
model = Model(input0, argmax)
model.compile(optimizer='adam', loss='mse')
#model.summary()

# Print values
samples = [[np.random.random_sample((timesteps,channels))] * 2]
prediction = model.predict(samples)
print('Input of argmax:\n', prediction[0])
print('\nOutput of argmax:\n', prediction[1])

Input of argmax:
 [[[-0.31248814 -0.31088996  0.03930115 -0.49696922 -0.07191789]
  [ 0.04657157 -0.10839222  0.26490483 -0.11642914  0.02724978]
  [ 0.18337694  0.42130166 -0.03512901  0.2450471   0.32034644]
  [-0.17251475 -0.12242899  0.03154935 -0.2943121   0.02682087]]

 [[-0.31248814 -0.31088996  0.03930115 -0.49696922 -0.07191789]
  [ 0.04657157 -0.10839222  0.26490483 -0.11642914  0.02724978]
  [ 0.18337694  0.42130166 -0.03512901  0.2450471   0.32034644]
  [-0.17251475 -0.12242899  0.03154935 -0.2943121   0.02682087]]]

Output of argmax:
 [[[2.]
  [2.]
  [1.]
  [2.]]

 [[2.]
  [2.]
  [1.]
  [2.]]]


#### Shape test

In [5]:
x = tf.constant([1.0, 4.0, 1.0, 1.0])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([1. 4. 1. 1.], shape=(4,), dtype=float32)

argmax
tf.Tensor([1.], shape=(1,), dtype=float32)


In [6]:
x = tf.constant([[1.0, 4.0, 1.0, 1.0], [1.0, 1.0, 1.0, 2.0]])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor(
[[1. 4. 1. 1.]
 [1. 1. 1. 2.]], shape=(2, 4), dtype=float32)

argmax
tf.Tensor(
[[1.]
 [3.]], shape=(2, 1), dtype=float32)


In [7]:
x = tf.constant([[[1.0, 4.0, 1.0, 1.0], [1.0, 4.0, 1.0, 1.0]],
                 [[1.0, 4.0, 1.0, 1.0], [1.0, 1.0, 1.0, 2.0]]])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor(
[[[1. 4. 1. 1.]
  [1. 4. 1. 1.]]

 [[1. 4. 1. 1.]
  [1. 1. 1. 2.]]], shape=(2, 2, 4), dtype=float32)

argmax
tf.Tensor(
[[[1.]
  [1.]]

 [[1.]
  [3.]]], shape=(2, 2, 1), dtype=float32)


#### Decision test

In [8]:
x = tf.constant([0., 0., 0., 0.])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([0. 0. 0. 0.], shape=(4,), dtype=float32)

argmax
tf.Tensor([0.5], shape=(1,), dtype=float32)


In [9]:
x0 = tf.constant([0., 0.000001, 0, 0.])
x1 = tf.constant([0., 0.00001, 0, 0.])
x2 = tf.constant([0., 0.0001, 0, 0.])
x3 = tf.constant([0., 0.001, 0, 0.])
print('x')
print(x0)
print(x1)
print(x2)
print(x3)
print()

argmax0 = DifferentiableArgmax() (x0)
argmax1 = DifferentiableArgmax() (x1)
argmax2 = DifferentiableArgmax() (x2)
argmax3 = DifferentiableArgmax() (x3)
print('argmax')
print(argmax0[1])
print(argmax1[1])
print(argmax2[1])
print(argmax3[1])

x
tf.Tensor([0.e+00 1.e-06 0.e+00 0.e+00], shape=(4,), dtype=float32)
tf.Tensor([0.e+00 1.e-05 0.e+00 0.e+00], shape=(4,), dtype=float32)
tf.Tensor([0.e+00 1.e-04 0.e+00 0.e+00], shape=(4,), dtype=float32)
tf.Tensor([0.    0.001 0.    0.   ], shape=(4,), dtype=float32)

argmax
tf.Tensor([0.6425544], shape=(1,), dtype=float32)
tf.Tensor([0.9999105], shape=(1,), dtype=float32)
tf.Tensor([1.], shape=(1,), dtype=float32)
tf.Tensor([1.], shape=(1,), dtype=float32)


#### Extreme value test

In [10]:
# Minus end of stability
x = tf.constant([-11., -11., -10., -11.])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([-11. -11. -10. -11.], shape=(4,), dtype=float32)

argmax
tf.Tensor([2.], shape=(1,), dtype=float32)


In [11]:
# Minus begining of instability
x = tf.constant([-11., -12., -12., -12.])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([-11. -12. -12. -12.], shape=(4,), dtype=float32)

argmax
tf.Tensor([nan], shape=(1,), dtype=float32)


In [12]:
# Plus end of stability
x = tf.constant([7., 7., 7., 8.])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([7. 7. 7. 8.], shape=(4,), dtype=float32)

argmax
tf.Tensor([3.], shape=(1,), dtype=float32)


In [13]:
# Plus begining of instability
x = tf.constant([8., 8., 8., 9.])
print('x')
print(x)
print()

argmax = DifferentiableArgmax() (x)
print('argmax')
print(argmax[1])

x
tf.Tensor([8. 8. 8. 9.], shape=(4,), dtype=float32)

argmax
tf.Tensor([nan], shape=(1,), dtype=float32)
