In [8]:
import numpy as np
import keras
import keras.backend as K

In [287]:
class NALU(keras.layers.Layer):
    """Gated blend of an additive and a multiplicative (log-space) linear path.

    With ``W = sigmoid(weight_sigmoid) * tanh(weight_tanh)`` the layer computes::

        adds  = inputs . W
        muls  = exp(log(|inputs| + epsilon) . W)
        gates = sigmoid(inputs . weight_gate)
        out   = gates * adds + (1 - gates) * muls

    which is the Neural Arithmetic Logic Unit (NALU) formulation: the additive
    path can represent sums/differences of the inputs, the log-space path can
    represent products/quotients of their magnitudes, and a learned gate mixes
    the two per output unit.

    # Arguments
        units: size of the last axis of the output.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        # Initialise the base layer first so that Keras' attribute tracking is
        # set up before any attribute is assigned on the instance.
        super(NALU, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the three ``(input_dim, units)`` weight matrices."""
        kernel_shape = (input_shape[-1], self.units)
        self.weight_sigmoid = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_normal', name='weight_sigmoid')
        self.weight_tanh = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_normal', name='weight_tanh')
        self.weight_gate = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_normal', name='weight_gate')
        super(NALU, self).build(input_shape)

    def call(self, inputs):
        """Return ``gates * adds + (1 - gates) * muls`` for ``inputs``."""
        # sigmoid is in (0, 1) and tanh in (-1, 1), so every entry of ``w``
        # lies in (-1, 1).
        w = K.sigmoid(self.weight_sigmoid) * K.tanh(self.weight_tanh)
        # K.dot is the backend matrix product; unlike the bare ``@`` operator
        # it also handles inputs with more than two axes.
        adds = K.dot(inputs, w)
        # Multiplicative path: a linear map in log-space. abs() + epsilon keeps
        # the logarithm finite for zero/negative inputs (the sign is discarded).
        muls = K.exp(K.dot(K.log(K.abs(inputs) + K.epsilon()), w))
        gates = K.sigmoid(K.dot(inputs, self.weight_gate))
        return gates * adds + (1 - gates) * muls

    def compute_output_shape(self, input_shape):
        """Replace the last input axis with ``units``; keep all leading axes."""
        return tuple(input_shape[:-1]) + (self.units,)

    def get_config(self):
        """Return the layer config including ``units`` so the layer can be
        re-created from a saved model."""
        config = super(NALU, self).get_config()
        config['units'] = self.units
        return config

In [288]:
# Minimal model: a 2-feature input feeding a single NALU unit, trained with
# the Adam optimizer on mean squared error.
X_input = keras.layers.Input(shape=(2,))
X = NALU(units=1)(X_input)
M = keras.Model(inputs=X_input, outputs=X)
M.compile(optimizer='adam', loss='mse')
M.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_80 (InputLayer)        (None, 2)                 0         
_________________________________________________________________
nalu_86 (NALU)               (None, 1)                 6         
Total params: 6
Trainable params: 6
Non-trainable params: 0
_________________________________________________________________


In [294]:
def toy_data(a, b):
    """Target function for the toy regression task: ``a + b + 1``.

    Works element-wise on anything supporting ``+`` (scalars, NumPy arrays).
    """
    pair_sum = a + b
    return pair_sum + 1
# Draw the inputs from a dedicated, seeded generator so that re-running the
# notebook reproduces the same dataset, without touching NumPy's global RNG.
data_rng = np.random.RandomState(0)
# 10000 samples with 2 features each, uniform on [1, 101).
data_input = data_rng.random_sample(size=(10000, 2)) * 100 + 1
# Regression targets: toy_data applied to the two feature columns.
data_label = toy_data(data_input[:, 0], data_input[:, 1])

In [295]:
# No callbacks are attached (neither learning-rate reduction nor early
# stopping), so training runs for the full epoch budget unless interrupted.
M.fit(
    x=data_input,
    y=data_label,
    epochs=1000,
    callbacks=[],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
Epoch 73/1000
[... log output for epochs 74-92 truncated ...]

Epoch 93/1000
Epoch 94/1000
Epoch 95/1000
Epoch 96/1000
Epoch 97/1000
Epoch 98/1000
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000
Epoch 102/1000
Epoch 103/1000
Epoch 104/1000
Epoch 105/1000
Epoch 106/1000
Epoch 107/1000
Epoch 108/1000
Epoch 109/1000
Epoch 110/1000
Epoch 111/1000
Epoch 112/1000
Epoch 113/1000
Epoch 114/1000
Epoch 115/1000
Epoch 116/1000
Epoch 117/1000
Epoch 118/1000
Epoch 119/1000
Epoch 120/1000
Epoch 121/1000
Epoch 122/1000
Epoch 123/1000
Epoch 124/1000
Epoch 125/1000
Epoch 126/1000
Epoch 127/1000
Epoch 128/1000
Epoch 129/1000
Epoch 130/1000
Epoch 131/1000
Epoch 132/1000
Epoch 133/1000
Epoch 134/1000
Epoch 135/1000
Epoch 136/1000
Epoch 137/1000
Epoch 138/1000
Epoch 139/1000
Epoch 140/1000
Epoch 141/1000
Epoch 142/1000
Epoch 143/1000
Epoch 144/1000
Epoch 145/1000
Epoch 146/1000
Epoch 147/1000
Epoch 148/1000
Epoch 149/1000
   32/10000 [..............................] - ETA: 1s - loss: 0.9276

KeyboardInterrupt: 

In [296]:
# Query the model on a single sample (a=3, b=2); for reference, the target
# function gives toy_data(3, 2) == 6.
M.predict(x=np.array([[3, 2]]))

array([[5.2539997]], dtype=float32)