## Neural Arithmetic Logic Units
Google DeepMind's research paper: https://arxiv.org/abs/1808.00508

In [73]:
import tensorflow as tf
import numpy as np

![title](images/naluandnac.png)
**The Neural Accumulator (NAC)** is a linear transformation of its inputs.
The transformation matrix is the elementwise product of **tanh(W)** and **sigmoid(M)**.

**The Neural Arithmetic Logic Unit (NALU)** uses two NACs with tied weights.

In [74]:
# The Neural Arithmetic Logic Unit
def NALU(in_dim, out_dim):
    """Build a single NALU layer on top of an input tensor.

    Despite its name, `in_dim` is the input tensor itself (it is passed
    straight to `tf.matmul`); only the size of its last axis is used as the
    input width of the weight matrices.

    Args:
        in_dim: input tensor. `int()` is applied to the size of its last
            axis, so that size has to be known when the graph is built.
        out_dim: number of output units.

    Returns:
        The tensor `g * a + (1 - g) * m`, where `a` is the additive NAC
        output, `m` is the multiplicative (log-space) output and `g` is a
        learned sigmoid gate.
    """
    eps = 1e-7  # added to |x| before the log so a zero input does not give log(0)
    weight_shape = (int(in_dim.shape[-1]), out_dim)

    def _init_param():
        # Every trainable matrix uses the same small truncated-normal init.
        return tf.Variable(tf.truncated_normal(weight_shape, stddev=0.02))

    # Trainable parameters, created in the order W_hat, M_hat, G
    w_hat = _init_param()
    m_hat = _init_param()
    gate_weights = _init_param()

    # NAC weight matrix: elementwise tanh(W_hat) * sigmoid(M_hat)
    nac_weights = tf.tanh(w_hat) * tf.sigmoid(m_hat)

    # Forward propagation of the additive path
    additive = tf.matmul(in_dim, nac_weights)

    # Multiplicative path: the same NAC weights applied in log-space
    log_magnitude = tf.log(tf.abs(in_dim) + eps)
    multiplicative = tf.exp(tf.matmul(log_magnitude, nac_weights))

    # Gate that blends the two paths
    gate = tf.sigmoid(tf.matmul(in_dim, gate_weights))

    return gate * additive + (1 - gate) * multiplicative

### Helper Function

In [75]:
def generate_dataset(size=10000, seed=None):
    """Create a toy multiplication dataset.

    Args:
        size: number of examples (rows) to generate.
        seed: optional integer seed. When given, a dedicated
            `np.random.RandomState(seed)` is used so the dataset is
            reproducible. When None (the default), NumPy's global random
            state is used, exactly as before.

    Returns:
        X: integer array of shape (size, 2) with values drawn from 0..8.
        Y: integer array of shape (size, 1) holding the product of each
           row of X.
    """
    # Fall back to the module-level generator so unseeded calls are unchanged.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # input data: randint(9) draws from [0, 9), i.e. 0..8 inclusive
    X = rng.randint(9, size=(size, 2))
    # output data (labels): row-wise product, kept as a column vector
    Y = np.prod(X, axis=1, keepdims=True)

    return X, Y

### Train NALU on generated data

In [76]:
# Hyperparameters
BATCH_SIZE = 10        # rows fed to the network per training step
LEARNING_RATE = 0.001  # step size handed to the Adam optimizer
EPOCHS = 200           # number of full passes over the generated dataset

In [77]:
# create dataset: operand pairs (X_data) and their products (Y_data),
# generated with generate_dataset's default size
X_data, Y_data = generate_dataset()

In [78]:
# define placeholders and network
# The batch dimension is left as None so that any number of rows can be fed
# (a short final batch, or the whole dataset for evaluation) rather than
# exactly BATCH_SIZE rows.
X = tf.placeholder(tf.float32, shape=[None, 2])

Y_true = tf.placeholder(tf.float32, shape=[None, 1])

# NALU layer mapping the two input operands to one predicted value
Y_pred = NALU(X, 1)

In [79]:
# Training objective: TensorFlow's l2_loss applied to the prediction error
# (Y_pred - Y_true).
loss = tf.nn.l2_loss(Y_pred - Y_true) 
    
# NOTE: despite its name, `optimizer` holds the value returned by
# .minimize(loss); it is what the training loop passes to sess.run each step.
optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

In [80]:
# create session
sess = tf.Session()
# create writer to store tensorboard graph
writer = tf.summary.FileWriter('/tmp', sess.graph)
# FileWriter writes events asynchronously; flush now so the graph is on disk
# for TensorBoard even though the writer is never explicitly closed.
writer.flush()

init = tf.global_variables_initializer()

sess.run(init)

# Run training loop
for i in range(EPOCHS):
    # number of predictions in this epoch that matched their label
    correct = 0

    # walk over the dataset in consecutive slices of BATCH_SIZE rows
    for j in range(0, len(X_data), BATCH_SIZE):
        xs, ys = X_data[j:j + BATCH_SIZE], Y_data[j:j + BATCH_SIZE]

        # one optimisation step; also fetch predictions and the batch loss
        _, ys_pred, l = sess.run([optimizer, Y_pred, loss],
                                 feed_dict={X: xs, Y_true: ys})

        # calculate number of correct predictions from batch
        correct += np.sum(np.isclose(ys, ys_pred, atol=1e-4, rtol=1e-4))

    acc = correct / len(Y_data)

    # NOTE: `l` is the loss of the last batch of the epoch, not an epoch mean
    print(f'epoch {i}, loss: {l}, accuracy: {acc}')

epoch 0, loss: 2732.2294921875, accuracy: 0.0062
epoch 1, loss: 530.1744384765625, accuracy: 0.0104
epoch 2, loss: 195.937744140625, accuracy: 0.0101
epoch 3, loss: 92.95536041259766, accuracy: 0.0101
epoch 4, loss: 49.38507080078125, accuracy: 0.0101
epoch 5, loss: 27.861385345458984, accuracy: 0.0101
epoch 6, loss: 16.27603530883789, accuracy: 0.0101
epoch 7, loss: 9.71354866027832, accuracy: 0.014
epoch 8, loss: 5.8773369789123535, accuracy: 0.0101
epoch 9, loss: 3.588534355163574, accuracy: 0.0101
epoch 10, loss: 2.204646587371826, accuracy: 0.0101
epoch 11, loss: 1.360314965248108, accuracy: 0.0101
epoch 12, loss: 0.8419667482376099, accuracy: 0.0101
epoch 13, loss: 0.5223231315612793, accuracy: 0.0101
epoch 14, loss: 0.3245672583580017, accuracy: 0.017
epoch 15, loss: 0.20196343958377838, accuracy: 0.0226
epoch 16, loss: 0.1257719099521637, accuracy: 0.0163
epoch 17, loss: 0.07839271426200867, accuracy: 0.0336
epoch 18, loss: 0.04889478161931038, accuracy: 0.0538
epoch 19, loss: 

epoch 151, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 152, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 153, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 154, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 155, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 156, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 157, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 158, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 159, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 160, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 161, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 162, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 163, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 164, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 165, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 166, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 167, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 168, loss: 3.7393220464476684e-11, accuracy: 1.0
epoch 169,

### Uncomment to run TensorBoard

In [72]:
# !tensorboard --logdir /tmp

TensorBoard 1.9.0 at http://Akils-Air-2.home:6006 (Press CTRL+C to quit)
^C
