Premise: train a neural net to add two n-bit integers.

Input data in the form:  
X = [[1,0,...,1,0],  
     [1,0,...,0,1]]  
X.shape = (2,n), one operand per row, most significant bit first  
Output data:  
Y = X[0] + X[1] in binary, as a list of size (n+1) (the extra bit holds the final carry)  




In [142]:
from fastai import *
from fastai.data.all import *
from fastai.vision.all import *
import torch

In [239]:
# data generation

def binary(x, bits):
    """Expand every integer in tensor `x` into `bits` binary digits.

    The digits are ordered most significant first and returned as a uint8
    tensor of shape (*x.shape, bits). Only the lowest `bits` bits of each
    value are represented.

    e.g.:
    binary(torch.tensor([3, 4]), bits=3) -->
        [[0, 1, 1],   # 3 in binary
         [1, 0, 0]]   # 4 in binary
    """
    exponents = torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    place_values = 2 ** exponents
    # Broadcast each value against every place value, then test which bits are set.
    is_set = (x[..., None] & place_values) != 0
    return is_set.byte()

def generate_data(samples, bits = 4):
    """Draw `samples` random pairs of `bits`-bit integers together with their sums.

    Returns a tuple (X, Y) of bit tensors produced by `binary`
    (most significant bit first):
    X has shape (samples, 2, bits) -- the two operands of each sample
    Y has shape (samples, bits+1) -- their sum; the extra bit holds the carry
    """
    # Upper bound is exclusive, so each operand fits in `bits` bits.
    operands = torch.randint(0, 2**bits, (samples, 2))
    totals = operands.sum(dim=1)
    return binary(operands, bits), binary(totals, bits + 1)






In [319]:
def generate_data(samples, bits = 4):
    """Return `samples` random 0/1 matrices, each of shape (2, bits).

    Each matrix holds the binary digits of two operands. No sum is returned;
    labels are expected to be derived from the bits separately.
    Result: int64 tensor of shape (samples, 2, bits).
    """
    shape = (samples, 2, bits)
    return torch.randint(0, 2, shape)

In [327]:
def vec_bin_array(arr, m):
    """
    Expand every integer of `arr` into a vector of its lowest `m` bits.

    Arguments:
    arr: array-like of non-negative integers (each must be < 2**63)
    m: number of bits of each integer to retain

    Returns a new int8 array of shape arr.shape + (m,) whose last axis holds
    the bits, most significant first. Values needing more than `m` bits keep
    only their lowest `m` bits.
    """
    values = np.asarray(arr).astype(np.int64)
    shifts = np.arange(m - 1, -1, -1)
    # Shifting an int64 by >= 64 is not well defined; for non-negative values
    # every bit at position >= 63 is 0, and a shift by 63 yields exactly that.
    shifts = np.minimum(shifts, 63)
    bits = (values[..., np.newaxis] >> shifts) & 1
    return bits.astype(np.int8)

def get_y(x):
    """Decode each row of the bit array `x` (MSB first) as an integer and return their sum."""
    n_bits = x.shape[-1]
    place_values = 1 << np.arange(n_bits - 1, -1, -1)
    row_values = x.dot(place_values)
    return row_values.sum()


# Experiment: push the raw random bit tensors through a fastai DataBlock.
# `get_items=generate_data` presumably receives the source passed to
# `db.datasets` (here 5), giving 5 random (2, bits) bit matrices as items;
# `get_y` labels each item with the integer sum of its two rows.
# NOTE(review): the label is a plain int, not the (bits+1)-bit vector the
# premise describes — confirm this cell is only a pipeline sanity check.
db = DataBlock(
    get_items = generate_data,
    splitter = RandomSplitter(),
    get_y = get_y,
)
ds = db.datasets(5)
# Print every (bits, label) pair that landed in the training split.
for sample in ds.train:
    print(sample)

(array([[0, 0, 1, 1],
       [1, 0, 0, 1]], dtype=int64), 12)
(array([[0, 0, 0, 1],
       [0, 1, 1, 0]], dtype=int64), 7)
(array([[0, 0, 1, 0],
       [0, 0, 0, 0]], dtype=int64), 2)
(array([[0, 0, 0, 0],
       [0, 0, 1, 0]], dtype=int64), 2)


In [152]:
# data generation

def binary(x, bits):
    """Expand every integer in tensor `x` into `bits` binary digits as floats.

    The digits are ordered most significant first and returned as a float
    tensor of shape (*x.shape, bits), with each entry 0.0 or 1.0. Only the
    lowest `bits` bits of each value are represented.

    e.g.:
    binary(torch.tensor([3, 4]), bits=3) -->
        [[0., 1., 1.],   # 3 in binary
         [1., 0., 0.]]   # 4 in binary
    """
    exponents = torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    place_values = 2 ** exponents
    masked = torch.bitwise_and(x[..., None], place_values)
    return torch.ne(masked, 0).float()

def generate_data_point(bits):
    """Generate ONE random addition example.

    Two integers are drawn uniformly from [0, 2**bits) and encoded with
    `binary` (most significant bit first).

    Returns:
        (x, y) where
        x: tensor of shape (2*bits,) -- the bits of both operands,
           concatenated (first operand first)
        y: tensor of shape (bits+1,) -- the bits of their sum; the extra
           bit holds the final carry
    """
    operands = torch.randint(0, 2**bits, (2,))
    total = operands.sum()
    # Flatten so the model sees both operands as a single input vector.
    return binary(operands, bits).flatten(), binary(total, bits + 1)

In [153]:
bits = 10


class RandDL(DataLoader):
    """DataLoader whose items are freshly generated random addition examples.

    Every item comes from `generate_data_point(bits)`, using the module-level
    `bits`. Before each item there is a ~1% chance of calling `stop()`, which
    ends the current pass, so epoch length is random.
    """

    def create_item(self, s):
        # `s` (the sample index supplied by the loader) is unused: every
        # item is generated on the fly.
        end_of_epoch = random.random() > 0.99
        if end_of_epoch:
            stop()
        return generate_data_point(bits)


# Training and validation loaders: both synthesize fresh random samples
# (RandDL) in batches of 6; drop_last=True discards any final partial batch.
dl = RandDL(bs=6, drop_last=True)
valid_dl = RandDL(bs=6, drop_last=True)
# Sanity-check one batch: x should be (6, 2*bits) and y (6, bits+1).
x1,y1 = first(dl)
print(x1.shape, y1.shape)
print(x1, y1)

# Bundle (train, valid) into a fastai DataLoaders for use with Learner.
dls = DataLoaders(dl, valid_dl)

torch.Size([6, 20]) torch.Size([6, 11])
tensor([[1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1.,
         0., 0.],
        [0., 0., 1., 1., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0.,
         0., 1.],
        [1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0.,
         0., 1.],
        [1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1.,
         1., 1.],
        [0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1.,
         1., 1.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.,
         1., 1.]]) tensor([[0., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0.],
        [0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [1., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1.],
        [1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1.]])


In [166]:
def adder_loss(Y_hat, Y):
    """Mean per-bit error between logits `Y_hat` and 0/1 targets `Y`.

    Each logit is turned into a probability with a sigmoid. The error is
    1 - p where the target bit is 1, and p where it is anything else.
    """
    probs = torch.sigmoid(Y_hat)
    per_bit_error = torch.where(Y == 1, 1 - probs, probs)
    return per_bit_error.mean()

def calc_grad(xb, yb, model):
    """Run `model` on batch `xb`, score it with `adder_loss`, and backpropagate.

    Gradients are accumulated into the model's parameters; nothing is returned.
    """
    batch_loss = adder_loss(model(xb), yb)
    batch_loss.backward()


def batch_accuracy(xb, yb):
    """Fraction of bits predicted correctly.

    `xb` holds the model's logits (not its inputs) and `yb` the 0/1 targets.
    A bit is predicted as 1 when its sigmoid exceeds 0.5.
    """
    predicted_bits = torch.sigmoid(xb) > 0.5
    hits = predicted_bits == yb
    return hits.float().mean()

def validate_epoch(model):
    """Mean per-batch `batch_accuracy` of `model` over the global `valid_dl`, rounded to 4 places.

    Returns float('nan') if `valid_dl` yields no batches in this pass.
    """
    # Evaluation only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    if not accs:
        # A RandDL pass can stop (1% chance per item) before one full batch
        # exists, and drop_last=True then discards it, leaving zero batches.
        # torch.stack([]) would raise, so report "no measurement" instead.
        return float('nan')
    return round(torch.stack(accs).mean().item(), 4)


def train_epoch(model, opt):
    """One pass over the global training loader `dl`: one optimizer step per batch."""
    for inputs, targets in dl:
        calc_grad(inputs, targets, model)
        opt.step()
        # Clear gradients so the next batch does not accumulate onto them.
        opt.zero_grad()


def train_model(model, opt, epochs):
    """Train for `epochs` passes, printing validation accuracy after each on one line."""
    for _epoch in range(epochs):
        train_epoch(model, opt)
        score = validate_epoch(model)
        print(score, end=' ')


# Baseline: one linear layer from the 2*bits input bits to one logit per
# output bit (bits+1 of them). With a single output, torch.where in
# adder_loss would broadcast that one logit across every target bit, so a
# single number would have to predict the whole sum.
linear_model = nn.Linear(2 * bits, bits + 1)


opt = SGD(linear_model.parameters(), 0.01)
train_model(linear_model, opt, 20)

0.5091 0.5182 0.4993 0.4864 0.4953 0.5202 0.5058 0.5175 0.5022 0.4738 0.5025 0.4996 0.3939 0.5421 0.5176 0.5146 0.4949 0.4978 0.5 0.5276 

In [156]:
# Same setup via fastai's Learner. The model needs bits+1 outputs (one logit
# per bit of the sum); with a single output, torch.where in adder_loss
# broadcasts that one logit across every target bit.
# NOTE(review): this cell has been seen to fail with "IndexError: tuple index
# out of range". The cause is not visible here and may come from how
# Learner interacts with the custom RandDL loaders — re-verify after rerunning.
learn = Learner(dls, nn.Linear(2*bits, bits + 1), opt_func=SGD, loss_func=adder_loss, metrics = batch_accuracy)
learn.fit(10, lr=1)

IndexError: tuple index out of range