In [1]:
import numpy as np
import keras
import keras.backend as K

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
class NALU(keras.layers.Layer):
    """Neural Arithmetic Logic Unit layer (Trask et al., 2018).

    The layer mixes an additive path and a multiplicative path through a
    learned gate, all acting on the last axis of the input:

        W    = sigmoid(weight_sigmoid) * tanh(weight_tanh)
        adds = x . W
        muls = exp(log(|x| + epsilon) . W)
        g    = sigmoid(x . weight_gate)
        out  = g * adds + (1 - g) * muls

    Arguments:
        units: size of the last axis of the output.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        self.units = units
        super(NALU, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the three ``(input_dim, units)`` weight matrices."""
        kernel_shape = (input_shape[-1], self.units)
        # Creation order and names are kept stable so previously saved
        # weights still line up with the layer.
        self.weight_sigmoid = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_uniform', name='weight_sigmoid')
        self.weight_tanh = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_uniform', name='weight_tanh')
        self.weight_gate = self.add_weight(
            shape=kernel_shape,
            initializer='glorot_uniform', name='weight_gate')
        super(NALU, self).build(input_shape)

    def call(self, inputs):
        """Apply the gated add/multiply transform to ``inputs``."""
        # The product of a sigmoid and a tanh lies strictly inside (-1, 1).
        w = K.sigmoid(self.weight_sigmoid) * K.tanh(self.weight_tanh)
        # K.dot is the backend-level product; unlike a bare ``@`` it also
        # accepts inputs with more than two axes.
        adds = K.dot(inputs, w)
        # Multiplicative path computed in log space; only |inputs| is used,
        # and epsilon keeps log() finite when an input is exactly zero.
        muls = K.exp(K.dot(K.log(K.abs(inputs) + K.epsilon()), w))
        gates = K.sigmoid(K.dot(inputs, self.weight_gate))
        return gates * adds + (1 - gates) * muls

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last axis replaced by ``units``."""
        return tuple(input_shape[:-1]) + (self.units,)

    def get_config(self):
        """Return the layer config, including ``units``.

        Without this entry the layer cannot be rebuilt from its config,
        because the constructor requires ``units``.
        """
        config = super(NALU, self).get_config()
        config['units'] = self.units
        return config

In [35]:
# Functional-API model: 2 input features -> NALU(8) -> NALU(1).
X_input = keras.layers.Input(shape=(2,))
X = NALU(units=8)(X_input)
X = NALU(units=1)(X)
M = keras.Model(inputs=X_input, outputs=X)
# Nadam optimizer with a mean-squared-error loss.
M.compile(optimizer='nadam', loss='mse')
M.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nalu_14 (NALU)               (None, 8)                 48        
_________________________________________________________________
nalu_15 (NALU)               (None, 1)                 24        
Total params: 72
Trainable params: 72
Non-trainable params: 0
_________________________________________________________________


In [36]:
def toy_data(a, b):
    """Return the toy regression target ``(2 * a) / b``.

    Uses only ``*`` and ``/``, so it accepts plain numbers as well as NumPy
    arrays (element-wise).
    """
    doubled = 2 * a
    return doubled / b
# Draw the inputs from a dedicated fixed-seed generator so that a fresh
# "Restart & Run All" rebuilds exactly the same dataset.
DATA_SEED = 0
data_rng = np.random.RandomState(DATA_SEED)
# 10000 rows of (a, b), each value uniform in [1, 101); the +1 offset keeps
# the divisor b away from zero.
data_input = data_rng.random_sample(size=(10000, 2)) * 100 + 1
# Swap the axes so the two columns unpack as the arguments a and b.
data_label = toy_data(*np.swapaxes(data_input, 0, 1))

In [37]:
# Both callbacks watch the training loss (no validation data is passed):
# one lowers the learning rate after 3 epochs without improvement, the other
# stops training after 10 such epochs.
lr_schedule = keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=3, verbose=1)
early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=10, verbose=1)
M.fit(
    x=data_input,
    y=data_label,
    batch_size=16,
    epochs=1000,
    callbacks=[lr_schedule, early_stop],
)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000

Epoch 00061: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epo

Epoch 88/1000
Epoch 89/1000

Epoch 00089: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-10.
Epoch 90/1000
Epoch 91/1000
Epoch 92/1000

Epoch 00092: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-11.
Epoch 93/1000
Epoch 94/1000
Epoch 95/1000

Epoch 00095: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-12.
Epoch 96/1000
Epoch 97/1000
Epoch 98/1000

Epoch 00098: ReduceLROnPlateau reducing learning rate to 2.000000208848829e-13.
Epoch 99/1000
Epoch 100/1000
Epoch 101/1000

Epoch 00101: ReduceLROnPlateau reducing learning rate to 2.0000002359538835e-14.
Epoch 00101: early stopping


<keras.callbacks.History at 0x7f4c43047710>

In [46]:
# Spot check on one hand-picked (a, b) pair: print the exact target first,
# then the model's prediction for the same row.
test = np.array([[100, 3]])
expected = toy_data(test[0, 0], test[0, 1])
predicted = M.predict(test)
print(expected)
print(predicted)

66.66666666666667
[[66.33641]]
