# Evaluating the NAC and NALU for learning basic functions

In [1]:
using Statistics
using Flux
include("nalu.jl")
include("data.jl");

In [2]:
loss(m, x, y) = Flux.mse(m(x), y)
val(m, val_data) = mean((loss(m, x, y) for (x, y) in val_data))

function do_experiment(f)
    (train_xs, train_ys), (perm, a_inds, b_inds, _) = binary_data(f, 10000, 100, 10)
    (interp_xs, interp_ys), _ = binary_data(f, 1000, 100, 10, perm=perm)
    (extrap_xs, extrap_ys), _ = binary_data(f, 1000, 100, 1000, perm=perm, min=10)
    
    interp_data = zip(interp_xs, interp_ys)
    extrap_data = zip(extrap_xs, extrap_ys)
    
    nac = Chain(NAC(100, 2), NAC(2, 1))
    nalu = Chain(NALU(100, 2), NALU(2, 1));

    n_epochs = 1000
    batch_size = 100
    
    # train NAC
    print("-- NAC --\n")
    for i = 1:n_epochs
        perm = randperm(length(train_xs))
        (batch_xs, batch_ys) = batch_dataset((train_xs[perm], train_ys[perm]), batch_size)
        train_data = zip(batch_xs, batch_ys)
        Flux.train!((x, y) -> loss(nac, x, y), train_data, ADAM(params(nac), 0.001))
        if i % 100 == 1
            tr_err = val(nac, zip(train_xs, train_ys))
            err = val(nac, interp_data)
            print("Epoch $i: Train Error: $tr_err, Valid Error: $err\n")
            flush(stdout)
        end
    end
    print(">\n")
    tr_err = val(nac, zip(train_xs, train_ys))
    in_err = val(nac, interp_data)
    ex_err = val(nac, extrap_data)
    print("> train error: $tr_err\n")
    print("> interpolation error: $in_err\n")
    print("> extrapolation error: $ex_err\n")
    W = tanh.(nac.layers[1].Ŵ) .* σ.(nac.layers[1].M̂)
    W2 = tanh.(nac.layers[2].Ŵ) .* σ.(nac.layers[2].M̂)
    println("> a_inds mean weight: ", [mean(W[1, a_inds]), mean(W[2, a_inds])])
    println("> b_inds mean weight: ", [mean(W[1, b_inds]), mean(W[2, b_inds])])
    println("> layer 2 W: ", W2)
    flush(stdout)

    # train NALU
    print("\n-- NALU --\n")
    for i = 1:3*n_epochs
        perm = randperm(length(train_xs))
        (batch_xs, batch_ys) = batch_dataset((train_xs[perm], train_ys[perm]), batch_size)
        train_data = zip(batch_xs, batch_ys)
        Flux.train!((x, y) -> loss(nalu, x, y), train_data, ADAM(params(nalu), 0.001))
        if i % 100 == 1
            tr_err = val(nalu, zip(train_xs, train_ys))
            err = val(nalu, interp_data)
            print("Epoch $i: Train Error: $tr_err, Valid Error: $err\n")
        end
    end
    print(">\n")
    tr_err = val(nalu, zip(train_xs, train_ys))
    in_err = val(nalu, interp_data)
    ex_err = val(nalu, extrap_data)
    print("> train error: $tr_err\n")
    print("> interpolation error: $in_err\n")
    print("> extrapolation error: $ex_err\n")
    
    W = tanh.(nalu.layers[1].nac.Ŵ) .* σ.(nalu.layers[1].nac.M̂)
    W2 = tanh.(nalu.layers[2].nac.Ŵ) .* σ.(nalu.layers[2].nac.M̂)
    println("> a_inds mean weight: ", [mean(W[1, a_inds]), mean(W[2, a_inds])])
    println("> b_inds mean weight: ", [mean(W[1, b_inds]), mean(W[2, b_inds])])
    println("> layer 2 W: ", W2)
    
    
    println("> layer 1 bias: ", nalu.layers[1].b)
    println("> layer 2 G weight: ", nalu.layers[2].G)
    println("> layer 2 bias: ", nalu.layers[2].b)
end

do_experiment (generic function with 1 method)

## Addition

In [3]:
do_experiment(+);

-- NAC --
Epoch 1: Train Error: 27.675351058563557 (tracked), Valid Error: 29.657817508130222 (tracked)
Epoch 101: Train Error: 0.0014187921155979619 (tracked), Valid Error: 0.0014372415932729127 (tracked)
Epoch 201: Train Error: 0.00017523127317473398 (tracked), Valid Error: 0.00016540948591393427 (tracked)
Epoch 301: Train Error: 0.00015151769731826676 (tracked), Valid Error: 0.00014845050329704207 (tracked)
Epoch 401: Train Error: 8.3279958417224e-5 (tracked), Valid Error: 8.349644996012516e-5 (tracked)
Epoch 501: Train Error: 0.00019598648024221456 (tracked), Valid Error: 0.0001938798890899978 (tracked)
Epoch 601: Train Error: 0.00015110669462168672 (tracked), Valid Error: 0.00014389854415807385 (tracked)
Epoch 701: Train Error: 0.00017923562582576624 (tracked), Valid Error: 0.00017389543568526052 (tracked)
Epoch 801: Train Error: 0.00011509591531513736 (tracked), Valid Error: 0.00011202559015617389 (tracked)
Epoch 901: Train Error: 0.00016291028814279573 (tracked), Valid Error: 0.

## Subtraction

In [4]:
do_experiment(-);

-- NAC --
Epoch 1: Train Error: 27.70209507613458 (tracked), Valid Error: 28.652899998225468 (tracked)
Epoch 101: Train Error: 0.0021821752511707398 (tracked), Valid Error: 0.0022806933211133346 (tracked)
Epoch 201: Train Error: 0.000136851230427532 (tracked), Valid Error: 0.00013872765316376993 (tracked)
Epoch 301: Train Error: 9.353710890418413e-5 (tracked), Valid Error: 9.046970050591172e-5 (tracked)
Epoch 401: Train Error: 0.00017624524864953329 (tracked), Valid Error: 0.00018095727249750013 (tracked)
Epoch 501: Train Error: 0.00014494886787299853 (tracked), Valid Error: 0.00013415914620956866 (tracked)
Epoch 601: Train Error: 0.00010266674473387461 (tracked), Valid Error: 0.00010781793955825337 (tracked)
Epoch 701: Train Error: 8.478415648120279e-5 (tracked), Valid Error: 9.292216587625685e-5 (tracked)
Epoch 801: Train Error: 0.00011795438896525791 (tracked), Valid Error: 0.00012170284575268849 (tracked)
Epoch 901: Train Error: 0.00010926892139197582 (tracked), Valid Error: 0.0001

## Multiplication

In [5]:
do_experiment(*);

-- NAC --
Epoch 1: Train Error: 27.03159243885835 (tracked), Valid Error: 26.395735853530173 (tracked)
Epoch 101: Train Error: 26.81089105223812 (tracked), Valid Error: 26.701388389890923 (tracked)
Epoch 201: Train Error: 26.81028719906098 (tracked), Valid Error: 26.71447942206562 (tracked)
Epoch 301: Train Error: 26.809981446165317 (tracked), Valid Error: 26.72393392053178 (tracked)
Epoch 401: Train Error: 26.809769014033662 (tracked), Valid Error: 26.739240116226267 (tracked)
Epoch 501: Train Error: 26.80970432266569 (tracked), Valid Error: 26.72652831849909 (tracked)
Epoch 601: Train Error: 26.809456546826112 (tracked), Valid Error: 26.721292017479698 (tracked)
Epoch 701: Train Error: 26.809560920122493 (tracked), Valid Error: 26.73382069351075 (tracked)
Epoch 801: Train Error: 26.809647276092 (tracked), Valid Error: 26.73161009389696 (tracked)
Epoch 901: Train Error: 26.809245384161702 (tracked), Valid Error: 26.73566042966793 (tracked)
>
> train error: 26.809771274134885 (tracked)

## Division

In [6]:
do_experiment(/);

-- NAC --
Epoch 1: Train Error: 4.844155105341939 (tracked), Valid Error: 5.309706547373761 (tracked)
Epoch 101: Train Error: 4.808495802869758 (tracked), Valid Error: 5.356533975401366 (tracked)
Epoch 201: Train Error: 4.807874385590283 (tracked), Valid Error: 5.365930614729175 (tracked)
Epoch 301: Train Error: 4.8075130625973035 (tracked), Valid Error: 5.366770737173705 (tracked)
Epoch 401: Train Error: 4.807250300491777 (tracked), Valid Error: 5.364711877323474 (tracked)
Epoch 501: Train Error: 4.80712517353235 (tracked), Valid Error: 5.36192051818307 (tracked)
Epoch 601: Train Error: 4.806776836038288 (tracked), Valid Error: 5.3651561596705735 (tracked)
Epoch 701: Train Error: 4.806667362307201 (tracked), Valid Error: 5.365520027618493 (tracked)
Epoch 801: Train Error: 4.806655060502668 (tracked), Valid Error: 5.363824755467786 (tracked)
Epoch 901: Train Error: 4.806454057743202 (tracked), Valid Error: 5.363070065506841 (tracked)
>
> train error: 4.806267063851542 (tracked)
> inter