<img src="images/nalu.png" align="right" width="40%">
# Neural Arithmetic Logic Units
Author: Jin Yeom (jinyeom@utexas.edu)

## Contents
- [Neural accumulator](#Neural-accumulator)
- [Neural arithmetic logic unit](#Neural-arithmetic-logic-unit)
- [Experiments](#Experiments)
    - [Simple function learning tasks](#Simple-function-learning-tasks)

In [5]:
import math

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchsummary import summary
from tqdm import tqdm_notebook as tqdm

In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device =", device)

device = cpu


## Neural accumulator

In [11]:
class NAC(nn.Module):
    """Neural accumulator.

    Computes ``y = x @ W.T`` with ``W = tanh(W_hat) * sigmoid(M_hat)``. The
    tanh/sigmoid product keeps every entry of ``W`` inside (-1, 1) and biases
    it toward -1, 0 or 1, so the layer learns to add/subtract its inputs.
    There is no bias term.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super(NAC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.W_hat = nn.Parameter(torch.Tensor(out_features, in_features))
        self.M_hat = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    @property
    def W(self):
        """Effective weight matrix of shape ``(out_features, in_features)``.

        Recomputed from ``W_hat`` and ``M_hat`` on every access. Caching it
        in ``__init__`` would freeze it at the uninitialised values, ignore
        optimizer updates and device moves, and reuse a single autograd graph
        (so a second ``backward()`` would fail).
        """
        return torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)

    def reset_parameters(self):
        """Re-initialise ``W_hat`` and ``M_hat`` from U(-1/sqrt(in), 1/sqrt(in))."""
        stdv = 1.0 / math.sqrt(self.in_features)
        self.W_hat.data.uniform_(-stdv, stdv)
        self.M_hat.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Apply the accumulator to ``x`` of shape ``(..., in_features)``."""
        return F.linear(x, self.W, None)

## Neural arithmetic logic unit

In [13]:
class NALU(nn.Module):
    """Neural arithmetic logic unit.

    Combines two neural accumulators:

    * ``a = nac1(x)`` learns addition/subtraction,
    * ``m = exp(nac2(log(|x| + eps)))`` learns multiplication/division by
      accumulating in log space,

    and mixes them with a learned per-output gate
    ``g = sigmoid(x @ G.T)``: ``y = g * a + (1 - g) * m``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        eps: small constant added to ``|x|`` so ``log`` never sees zero.
    """

    def __init__(self, in_features, out_features, eps=1e-8):
        super(NALU, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.nac1 = NAC(in_features, out_features)  # add/sub
        self.nac2 = NAC(in_features, out_features)  # mul/div (in log space)
        # One gate per output unit. A 1-D G of shape (in_features,) would give
        # a single scalar gate per sample, of shape (batch,), which does not
        # broadcast against the (batch, out_features) accumulator outputs.
        self.G = nn.Parameter(torch.Tensor(out_features, in_features))
        self.eps = eps
        # G is allocated uninitialised above, so it must be initialised here.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the gate and both accumulators."""
        stdv = 1.0 / math.sqrt(self.in_features)
        self.G.data.uniform_(-stdv, stdv)
        self.nac1.reset_parameters()
        self.nac2.reset_parameters()

    def forward(self, x):
        """Apply the unit to ``x`` of shape ``(..., in_features)``.

        Returns a tensor of shape ``(..., out_features)``.
        """
        g = torch.sigmoid(F.linear(x, self.G))
        a = self.nac1(x)
        m = torch.exp(self.nac2(torch.log(torch.abs(x) + self.eps)))
        return g * a + (1 - g) * m

## Experiments

### Simple function learning tasks

In [None]:
def generate_data(num_train, num_test, op, dim=100, sub_dim=10, range_=(5, 10)):
    """Generate a synthetic dataset for the simple function learning tasks.

    A single random vector of ``dim`` values is drawn uniformly from
    ``range_``. For each sample, two index subsets of size ``sub_dim`` are
    drawn (with replacement, so they may overlap). The selected values are
    summed into ``a`` and ``b``, and the target is ``op(a, b)``.

    Args:
        num_train: number of training samples.
        num_test: number of test samples.
        op: binary function applied to ``a`` and ``b``. It receives 0-dim
            tensors and must return something ``float()`` accepts.
        dim: length of the shared random vector.
        sub_dim: number of entries summed into each of ``a`` and ``b``.
        range_: ``(low, high)`` bounds of the uniform distribution.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``. Each ``X_*`` has shape
        ``(num_*, 2)`` and holds the ``(a, b)`` pairs; each ``y_*`` has shape
        ``(num_*,)``.
    """
    data = torch.empty(dim).uniform_(*range_)
    X, y = [], []
    for _ in tqdm(range(num_train + num_test)):
        # Row 0 selects the terms of `a`, row 1 those of `b`.
        idx = torch.randint(dim, (2, sub_dim))
        a = data[idx[0]].sum()
        b = data[idx[1]].sum()
        X.append((a.item(), b.item()))
        y.append(float(op(a, b)))
    # reshape keeps the (N, 2) layout even when no samples were requested.
    X = torch.tensor(X, dtype=torch.float32).reshape(-1, 2)
    y = torch.tensor(y, dtype=torch.float32)
    return (X[:num_train], y[:num_train]), (X[num_train:], y[num_train:])