# Evaluating NAC and NALU models on learning basic arithmetic functions

In [1]:
using Flux
include("nalu.jl")
include("data.jl");

In [2]:
"""
    loss(m, x, y)

Sum of squared differences between the model output `m(x)` and the target `y`.
"""
loss(m, x, y) = sum((m(x) .- y) .^ 2)

"""
    val(m, val_data)

Average `loss` of model `m` over the `(x, y)` pairs in `val_data`, i.e. the
total loss divided by `length(val_data)`.
"""
function val(m, val_data)
    total = sum(loss(m, input, target) for (input, target) in val_data)
    return total / length(val_data)
end

"""
    do_experiment(f; lr = 0.001, hidden = 2)

Generate data for the function `f` with `gen_data`, then train and evaluate two
models on it:

- a NAC chain `NAC(100, hidden)` → `NAC(hidden, 1)`
- a NALU chain `NALU(100, hidden)` → `NALU(hidden, 1)`

Each model is trained with `Flux.train!` on the training set, using `SGD` with
learning rate `lr`. Afterwards its average `loss` (via `val`) on the
interpolation and extrapolation sets is printed with `@show`.

Returns the last shown value, i.e. the NALU extrapolation loss.
"""
function do_experiment(f; lr = 0.001, hidden = 2)
    # The fourth value returned by gen_data is not used by this experiment.
    train_data, interp_data, extrap_data, _ = gen_data(f, 10000, 100, 10, 100)

    # Names are deliberately generic: `f` may be any of +, -, *, /, and the
    # @show labels below should not suggest a particular operation.
    nac = Chain(NAC(100, hidden), NAC(hidden, 1))
    nalu = Chain(NALU(100, hidden), NALU(hidden, 1))

    # train NAC, then report its validation losses
    Flux.train!((x, y) -> loss(nac, x, y), train_data, SGD(params(nac), lr))
    @show val(nac, interp_data)
    @show val(nac, extrap_data)

    # train NALU, then report its validation losses
    Flux.train!((x, y) -> loss(nalu, x, y), train_data, SGD(params(nalu), lr))
    @show val(nalu, interp_data)
    @show val(nalu, extrap_data)
end

do_experiment (generic function with 1 method)

## Addition

In [3]:
# Run the NAC/NALU experiment with f = +; the trailing `;` suppresses the
# returned value, so only the @show output below is displayed.
do_experiment(+);

val(nac₊, interp₊) = 2.6058289440365295e-28 (tracked)
val(nac₊, extrap₊) = 2.7799138022460934e-26 (tracked)
val(nalu₊, interp₊) = 3364.9030834830737 (tracked)
val(nalu₊, extrap₊) = 335217.6585132444 (tracked)


## Subtraction

In [4]:
# Run the NAC/NALU experiment with f = -; the trailing `;` suppresses the
# returned value, so only the @show output below is displayed.
do_experiment(-);

val(nac₊, interp₊) = 2.884641908595823e-28 (tracked)
val(nac₊, extrap₊) = 2.7973425647505424e-26 (tracked)
val(nalu₊, interp₊) = 3754.957319221406 (tracked)
val(nalu₊, extrap₊) = 362387.35815813474 (tracked)


## Multiplication

In [5]:
# Run the NAC/NALU experiment with f = *; the trailing `;` suppresses the
# returned value, so only the @show output below is displayed.
do_experiment(*);

val(nac₊, interp₊) = 3.037412428562234e6 (tracked)
val(nac₊, extrap₊) = 2.5410421143908295e10 (tracked)
val(nalu₊, interp₊) = 3.036157077064866e6 (tracked)
val(nalu₊, extrap₊) = 2.540994593343057e10 (tracked)


## Division

In [6]:
# Run the NAC/NALU experiment with f = /; the trailing `;` suppresses the
# returned value, so only the @show output below is displayed.
do_experiment(/);

val(nac₊, interp₊) = 358.91951848286425 (tracked)
val(nac₊, extrap₊) = 744.258729700709 (tracked)
val(nalu₊, interp₊) = 2.6810934269898186e9 (tracked)
val(nalu₊, extrap₊) = 3.6883944620812e14 (tracked)
