### Differentiable argmax / Soft argmax

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input, Dense, Conv1D, Reshape

  from ._conv import register_converters as _register_converters


In [2]:
class DifferentiableArgmax(Layer):
    '''
    Differentiable argmax / Soft argmax over the last axis.

    The input is first turned into a sharp, nearly one-hot distribution
    ``softmax(sharpness * inputs)``. With the default ``sharpness=1e4`` this is
    mathematically the same as the earlier recipe of ``exp(x) ** 10`` followed
    by three more rounds of "raise to the 10th power and renormalise"
    (10 ** 4 = 1e4). ``tf.nn.softmax`` subtracts the maximum before
    exponentiating, so the result no longer overflows (``exp(10 * x)`` exceeds
    the float32 maximum for x >= ~9). It also no longer underflows to 0/0 = NaN
    (for x <= ~-11).

    The distribution is then converted to an index. For a one-hot at position
    k, the reverse exclusive cumulative sum is 1 before k and 0 from k on.
    Thresholding it at 0.5 and summing therefore gives k. Ties between equal
    maxima give fractional indices (e.g. four equal values give 0.5).

    Gradients are steep: they scale with ``sharpness``.

    For using with convolutional filters you might want to swap axes:

    c = Conv1D(4, 3) (input)
    argmax = DifferentiableArgmax() (c)
    argmax -> maximum of filters per width step (Which filter fire stronger at width steps?)

    vs

    c = Conv1D(4, 3) (input)
    argmax = tf.transpose(c, perm=[0,2,1]) # (batch, width, filter) -> (batch, filter, width)
    argmax = DifferentiableArgmax() (argmax)
    argmax -> maximum in width dimension per filter (Where the filters fire strongest?)

    Args:
        sharpness: multiplier applied to the inputs before the softmax. Larger
            values separate close inputs better but make gradients steeper.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.

    Returns (when called):
        ``[inputs, token]``. ``inputs`` is passed through unchanged. ``token``
        has the shape of ``inputs`` with the last axis reduced to size 1 and
        holds the (soft) argmax index as a float.
    '''

    def __init__(self, sharpness=1e4, **kwargs):
        super().__init__(**kwargs)
        self.sharpness = sharpness

    def call(self, inputs):
        # Stable softmax: the largest entry maps to exp(0) = 1, so the
        # normaliser is never 0 and never inf.
        onehot = tf.nn.softmax(self.sharpness * inputs, axis=-1)

        # Entry i = total probability mass strictly after position i.
        cumsum = tf.cumsum(onehot, axis=-1, exclusive=True, reverse=True)
        # "Soft rounding": map [0.5, 1] -> [0, 1] and anything below 0.5 to 0.
        rounding = 2 * (tf.clip_by_value(cumsum, clip_value_min=.5, clip_value_max=1) - .5)
        token = tf.reduce_sum(rounding, axis=-1, keepdims=True)

        # tf.identity makes the pass-through output a tensor produced by this
        # layer rather than the very object that was passed in.
        return [tf.identity(inputs), token]

    def get_config(self):
        config = super().get_config()
        config.update({'sharpness': self.sharpness})
        return config

#### Example: Dense layer

In [None]:
# Model: the Dense kernel is untrained (randomly initialised), so each scalar
# input is mapped to 5 pseudo-random values whose argmax is then looked up.
input0 = Input((1,))
d = Dense(5, use_bias=False, name='output')(input0)
argmax = DifferentiableArgmax()(d)
# argmax is already the list [inputs, token]; wrapping it in another list
# would nest the outputs one level deeper than prediction[0]/[1] expects.
model = Model(input0, argmax)
model.compile(optimizer='adam', loss='mae')
#model.summary()

# Random value test: 5 samples with one scalar feature each -> shape (5, 1)
prediction = model.predict(np.array([[-2.], [-1.], [0.], [1.], [2.]]))
print('Input of argmax (5 sample. 1 list of 5 random values per sample):\n', prediction[0])
print('\nOutput of argmax (index of maximum in lists):\n', prediction[1])

#### Example: Get indexes where convolutional filters fire

In [None]:
timesteps = 7
channels = 1

# Model: argmax over the width axis, i.e. "where does each filter fire strongest?"
input0 = Input((timesteps, channels))
c = Conv1D(4, 3, padding='valid')(input0)
argmax = tf.transpose(c, perm=[0, 2, 1])  # (batch, width, filter) -> (batch, filter, width)
argmax = DifferentiableArgmax()(argmax)
model = Model(input0, argmax)
model.compile(optimizer='adam', loss='mse')
#model.summary()

# Print values for a batch of 2 random sequences. Pass a real
# (batch, timesteps, channels) array; a nested Python list does not match the
# model's single input.
samples = np.random.random_sample((2, timesteps, channels))
prediction = model.predict(samples)
print('Input of argmax:\n', prediction[0])
print('\nOutput of argmax:\n', prediction[1])

#### Shape test

In [None]:
# Rank-1 input: a single list of four values, maximum at index 1
x = tf.constant([1.0, 4.0, 1.0, 1.0])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Rank-2 input: two rows, argmax taken independently per row
x = tf.constant([[1.0, 4.0, 1.0, 1.0], [1.0, 1.0, 1.0, 2.0]])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Rank-3 input: 2 x 2 rows of four values, argmax over the last axis
x = tf.constant([
    [[1.0, 4.0, 1.0, 1.0], [1.0, 4.0, 1.0, 1.0]],
    [[1.0, 4.0, 1.0, 1.0], [1.0, 1.0, 1.0, 2.0]],
])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

#### Decision test

In [None]:
# Every entry ties: shows what the soft argmax reports without a clear winner
x = tf.constant([0., 0., 0., 0.])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Index 1 beats the others by a growing margin (1e-4, 1e-3, 1e-2, 1e-1)
x0, x1, x2, x3 = (tf.constant([0., margin, 0, 0.]) for margin in (0.0001, 0.001, 0.01, 0.1))
print('x', x0, x1, x2, x3, '', sep='\n')

argmax0, argmax1, argmax2, argmax3 = (DifferentiableArgmax()(v) for v in (x0, x1, x2, x3))
print('argmax', *(out[1] for out in (argmax0, argmax1, argmax2, argmax3)), sep='\n')

#### Extreme value test

In [None]:
# Negative side: last case that still evaluates cleanly
x = tf.constant([-11., -11., -10., -11.])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Negative side: beginning of instability
x = tf.constant([-11., -12., -12., -12.])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Positive side: last case that still evaluates cleanly
x = tf.constant([7., 7., 7., 8.])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')

In [None]:
# Positive side: beginning of instability
x = tf.constant([8., 8., 8., 9.])
print('x', x, '', sep='\n')

argmax = DifferentiableArgmax()(x)
print('argmax', argmax[1], sep='\n')