[View in Colaboratory](https://colab.research.google.com/github/juliusfrost/Research-Paper-Implementations/blob/master/Neural_Arithmetic_Logic_Units.ipynb)

# Neural Arithmetic Logic Units

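A Keras implementation of the Neural Arithmetic Logic Unit (NALU) from Trask et al., 2018 (https://arxiv.org/abs/1808.00508), trained here to add two numbers.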

In [0]:
import tensorflow as tf

## The NALU layer

In [0]:
class NALU(tf.keras.layers.Layer):
  """Neural Arithmetic Logic Unit layer (Trask et al., 2018).

  Combines an additive path and a multiplicative path through a learned gate:

    W = tanh(W_hat) * sigmoid(M_hat)
    a = x @ W                               (additive path)
    m = exp(log(|x| + epsilon) @ W)         (multiplicative path)
    g = sigmoid(x @ G)                      (gate)
    y = g * a + (1 - g) * m

  Args:
    output_dim: Number of output units.
    epsilon: Small constant added to |x| before the log so that zero inputs
      do not produce log(0). Defaults to 1e-4.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.
  """

  def __init__(self, output_dim, epsilon=0.0001, **kwargs):
    super(NALU, self).__init__(**kwargs)
    self.output_dim = output_dim
    self.epsilon = epsilon

  def build(self, input_shape):
    """Create the gate matrix G and the W_hat / M_hat parameters.

    Expects 2-D inputs (batch, features): input_shape[1] is used as the
    feature count, and all three matrices have shape (features, output_dim).
    """
    shape = tf.TensorShape((input_shape[1], self.output_dim))
    initializer = tf.keras.initializers.RandomUniform(minval=-1, maxval=1)
    self.G = self.add_weight(name='G', shape=shape, initializer=initializer, trainable=True)
    self.W_hat = self.add_weight(name='W_hat', shape=shape, initializer=initializer, trainable=True)
    self.M_hat = self.add_weight(name='M_hat', shape=shape, initializer=initializer, trainable=True)

    super(NALU, self).build(input_shape)

  def call(self, x):
    """Return the gated mix of the additive and multiplicative paths."""
    # tanh * sigmoid keeps each weight in (-1, 1), biasing it toward -1, 0 or 1.
    W = tf.tanh(self.W_hat) * tf.sigmoid(self.M_hat)
    a = tf.matmul(x, W)
    g = tf.sigmoid(tf.matmul(x, self.G))
    # Products become sums in log space; abs() discards the sign of x, so this
    # path cannot represent negative products.
    m = tf.exp(tf.matmul(tf.log(tf.abs(x) + self.epsilon), W))
    return g * a + (1 - g) * m

  def compute_output_shape(self, input_shape):
    """Same shape as the input, with the last dimension set to output_dim."""
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.output_dim
    return tf.TensorShape(shape)

  def get_config(self):
    """Return the layer config, including output_dim and epsilon."""
    config = super(NALU, self).get_config()
    config.update({'output_dim': self.output_dim, 'epsilon': self.epsilon})
    return config

  @classmethod
  def from_config(cls, config):
    """Rebuild a layer from the dict produced by get_config."""
    return cls(**config)


## Train

### Seed the random number generators

In [0]:
# Seed NumPy's global RNG and TensorFlow's graph-level RNG with one shared value
# so that re-running the notebook draws the same random numbers.
import numpy as np
from numpy.random import seed
from tensorflow import set_random_seed

SEED = 42
seed(SEED)
set_random_seed(SEED)

### Generate data and build the model

In [0]:
# 10,000 input pairs drawn uniformly from [-10, 10); the target is their sum.
x_train = np.random.random((10000, 2)) * 20 - 10
y_train = x_train[:, 0] + x_train[:, 1]

# Two stacked NALU layers: 2 inputs -> 8 hidden units -> 1 output.
model = tf.keras.Sequential([NALU(8), NALU(1)])

adam = tf.train.AdamOptimizer(learning_rate=0.001)
model.compile(
    optimizer=adam,
    loss='mse',       # mean squared error
    metrics=['mae'],  # mean absolute error
)

### Train the model

In [324]:
model.fit(x=x_train, y=y_train, epochs=100, batch_size=32, verbose=1)  # 100 passes, mini-batches of 32

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x7fa9aa74a668>

## Test

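Although the model was trained only on inputs in the range [-10, 10), it extrapolates well to inputs in the range [-100, 100). In the sample below, each prediction is within roughly 0.5 of the true sum, although the error tends to grow with the magnitude of the result.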

In [344]:
# Draw 10 test pairs from [-r, r), well outside the training range, and compare
# the model's predicted sums with the true sums.
r = 100
x_test = np.random.random((10, 2)) * r * 2 - r
y_test = x_test[:, 0] + x_test[:, 1]

predictions = model.predict(x_test)
for (lhs, rhs), pred, target in zip(x_test, predictions, y_test):
  report = '\n'.join([
      str(round(lhs, 2)) + ' + ' + str(round(rhs, 2)),
      'predicted: ' + str(round(pred[0], 2)),
      'actual: ' + str(round(target, 2)),
  ])
  print(report + '\n')

-57.39 + -84.0
predicted: -141.01
actual: -141.39

18.67 + -56.1
predicted: -37.2
actual: -37.43

68.94 + 57.42
predicted: 126.01
actual: 126.35

3.09 + 28.73
predicted: 31.73
actual: 31.82

64.61 + 93.88
predicted: 158.03
actual: 158.5

-68.33 + 34.98
predicted: -33.48
actual: -33.34

87.28 + 86.11
predicted: 172.87
actual: 173.39

82.12 + 85.11
predicted: 166.74
actual: 167.23

-14.31 + 6.01
predicted: -8.32
actual: -8.3

96.28 + -48.89
predicted: 47.07
actual: 47.4

