In [1]:
using UnboundedBNN, Optimisers, ProgressMeter, Zygote, LinearAlgebra, JLD2, Random, Statistics
using UnboundedBNN: TransformedDistribution

In [2]:
"""
    crossentropy(p, q)

Return the negative sum, over samples `n`, of the entry of `q` selected by label `p[n]`:
row 2 of column `n` when `p[n] == 1`, row 1 for any other label.

`q` holds one column per sample, so `size(q, 2) == length(p)` is required. In this notebook
it is the `post_pdf`-weighted sum of per-depth outputs built in `loss`. No logarithm is
applied here; the selected entries of `q` are summed as given.

Throws `DimensionMismatch` if the number of columns of `q` differs from `length(p)`.
An empty `p` (with a zero-column `q`) gives zero.
"""
function crossentropy(p::AbstractVector, q::AbstractMatrix)
    size(q, 2) == length(p) || throw(DimensionMismatch(
        "q has $(size(q, 2)) columns but p has $(length(p)) labels"))
    # Row per sample: `true + 1 == 2` for positive labels, `false + 1 == 1` otherwise.
    rows = (p .== 1) .+ 1
    # Gather every selected entry with one indexing call instead of a per-sample closure.
    return -sum(q[CartesianIndex.(rows, axes(q, 2))])
end

crossentropy (generic function with 1 method)

In [3]:
"""
    _accuracy_from_scores(scores, y)

Return the fraction of samples whose highest-scoring row in `scores` equals the row that
`crossentropy` targets for label `y[n]`: row 2 when `y[n] == 1`, row 1 otherwise.

`scores` has one column per sample. Throws `DimensionMismatch` if
`size(scores, 2) != length(y)`.
"""
function _accuracy_from_scores(scores::AbstractMatrix, y::AbstractVector)
    size(scores, 2) == length(y) || throw(DimensionMismatch(
        "scores has $(size(scores, 2)) columns but y has $(length(y)) labels"))
    # `argmax(...; dims=1)` is 1×N; `vec` makes it an N-vector so the comparison below is
    # element-wise instead of broadcasting against `target` into an N×N matrix.
    predicted = vec(getindex.(argmax(scores; dims=1), 1))
    # Same label -> row convention as `crossentropy`, so accuracy measures what is trained.
    target = (y .== 1) .+ 1
    return mean(predicted .== target)
end

"""
    accuracy(x, y, model; nr_evals = 100)

Estimate the classification accuracy of `model` on inputs `x` (one column per sample) with
labels `y`.

On each of `nr_evals` forward passes, the per-depth outputs `output[l]` are weighted by the
depth posterior's pmf over its support and summed. These two-row class scores are
accumulated across passes before taking the per-sample argmax. The label convention matches
`crossentropy` (label `1` -> row 2, anything else -> row 1).

Throws `ArgumentError` if `nr_evals < 1`.
"""
function accuracy(x::AbstractMatrix, y::AbstractVector, model; nr_evals = 100)
    nr_evals >= 1 || throw(ArgumentError("nr_evals must be at least 1, got $nr_evals"))

    posterior = UnboundedBNN.transform(model.posterior)
    lower, upper = UnboundedBNN.support(posterior)
    post_pdf = UnboundedBNN.pmf.(Ref(posterior), lower:upper)

    scores = zeros(Float64, 2, length(y))
    for _ in 1:nr_evals
        # NOTE(review): presumably each pass samples new weights (LinearBBB), which is why
        # several passes are accumulated — confirm.
        output = model(x)
        scores .+= mapreduce(l -> output[l] .* post_pdf[l], +, eachindex(post_pdf))
    end

    return _accuracy_from_scores(scores, y)
end

accuracy (generic function with 1 method)

In [4]:
"""
    loss(y, x, model; batch_prop = 1.0)

Training objective: `crossentropy(y, logln)` plus `batch_prop` times the sum of KL terms.

`logln` is the sum, over the depth posterior's support, of the per-depth outputs
`output[l]` weighted by the posterior pmf `post_pdf[l]`. The KL terms cover:
- the depth distribution (posterior vs `model.prior`);
- the input layer;
- the output layers.

`run_experiment` passes `batch_prop = length(batch) / N` for minibatch steps and uses the
default `1.0` for full-dataset evaluation. This is presumably the usual minibatch scaling of
a negative ELBO.

This function is differentiated with Zygote in `run_experiment`.
"""
function loss(y, x, model; batch_prop = 1.0)
    
    # Depth posterior as a discrete distribution; post_pdf[l] is its pmf at the l-th
    # integer of lower:upper.
    posterior = UnboundedBNN.transform(model.posterior)
    lower, upper = UnboundedBNN.support(posterior)
    post_pdf = UnboundedBNN.pmf.(Ref(posterior), lower:upper)

    # NOTE(review): output[l] (and output_layers[l] below) is paired with post_pdf[l], the
    # pmf at support point lower + l - 1. That lines up with depth l only if lower == 1 —
    # confirm what `model(x)` indexes by.
    output = model(x)
    logln = mapreduce(l -> output[l] .* post_pdf[l], +, 1:length(post_pdf))

    kl_poisson = KL_loss(posterior, model.prior)
    kl_input = KL_loss(model.input_layer)
    # NOTE(review): this "hidden" term iterates over `model.output_layers`, not the hidden
    # layers. It gives output layer i the weight post_pdf[i] * (length(post_pdf) - i + 1).
    # An expected hidden-layer KL would weight hidden layer i by sum(post_pdf[i:end]),
    # e.g. `post_pdf[l] * sum(KL_loss.(<hidden layers>[1:l]))`. The output layers' KL is
    # also counted again in kl_output. Confirm against the Unbounded field names.
    kl_hidden = mapreduce(l -> dot(post_pdf[1:l], KL_loss.(model.output_layers[1:l])), +, 1:length(post_pdf))
    # Output head l contributes its KL weighted by the posterior probability of depth l.
    kl_output = mapreduce(l -> post_pdf[l] * KL_loss(model.output_layers[l]), +, 1:length(post_pdf))
    kl_total = kl_poisson + kl_input + kl_output + kl_hidden

    return crossentropy(y, logln) + batch_prop * kl_total
    
end

loss (generic function with 1 method)

In [5]:
"""
    create_model(dims, dimmid; max_layers = 30)

Build an `Unbounded` network mapping `first(dims)` inputs to `last(dims)` outputs with
hidden width `dimmid`.

Components, constructed in this order:
- an input block `LinearBBB(first(dims) => dimmid)` followed by `LeakyReLU()`;
- `max_layers` hidden blocks, each `LinearBBB(dimmid => dimmid)` + `LeakyReLU()`;
- `max_layers` output heads, each `LinearBBB(dimmid => last(dims))` +
  `Softmax(last(dims))`;
- a depth prior `discretize(H() * Normal(0, 2))` (Float32 parameters);
- a depth posterior: a `SafeNormal` with parameters `[0]` and `[invsoftplus(2)]` (presumably
  a softplus-parameterised scale of 2 — confirm). It is mapped through `H()`, truncated to
  its 2.5%/97.5% quantiles, expanded to integer bounds, and discretized.
"""
function create_model(dims::Pair, dimmid; max_layers=30)
    n_in, n_out = dims

    # Layer construction order is kept (input, hidden, heads) so seeded initialisation is
    # unchanged.
    stem = Chain(LinearBBB(n_in => dimmid), LeakyReLU())
    hidden_blocks = ntuple(max_layers) do _
        Chain(LinearBBB(dimmid => dimmid), LeakyReLU())
    end
    heads = ntuple(max_layers) do _
        Chain(LinearBBB(dimmid => n_out), Softmax(n_out))
    end

    depth_prior = discretize(H() * Normal(0.0f0, 2.0f0))

    # Applied in sequence to the base SafeNormal to obtain the discrete depth posterior.
    posterior_steps = (
        d -> H() * d,
        d -> truncate_to_quantiles(d, 0.025f0, 0.975f0),
        d -> expand_truncation_to_ints(d),
        d -> discretize(d),
    )
    depth_posterior = TransformedDistribution(
        SafeNormal([0.0f0], [invsoftplus(2.0f0)]),
        posterior_steps,
    )

    return Unbounded(stem, hidden_blocks, heads, depth_prior, depth_posterior)
end

create_model (generic function with 1 method)

In [6]:
"""
    create_optimiser(model; lr = 0.005f0)

Return Adam optimiser state for `model`. Every parameter uses learning rate `lr` except the
`posterior` subtree, which uses `lr / 10`.
"""
function create_optimiser(model; lr=0.005f0)
    state = Optimisers.setup(Adam(), model)
    posterior_lr = lr / 10

    # Order matters: the first call sets the rate on every leaf, the second then overrides
    # only the depth-posterior subtree.
    Optimisers.adjust!(state, lr)
    Optimisers.adjust!(state.posterior, posterior_lr)

    return state
end

create_optimiser (generic function with 1 method)

In [7]:
"""
    run_experiment(folder, ω, dimmid; epochs = 5000, batch_size = 256, lr = 0.005f0,
                   max_layers = 30, N = 8192, rng = Random.default_rng())

Train one model for every combination of hidden width `dim in dimmid` and spiral parameter
`ωi in ω`.

For each combination it:
- generates `N`-sample train, validation and test sets with `generate_spiral(N, ωi)`;
- trains for `epochs` epochs of shuffled minibatches of size `batch_size`;
- records the full-set loss on all three sets after every epoch;
- keeps a snapshot of the model with the lowest validation loss.

Data, loss curves, the best model and its accuracies are written to
`folder/dim_<dim>_omega_<ωi>.jld2`. The `ω` values for a given `dim` run in parallel with
`Threads.@threads`.
"""
function run_experiment(folder, ω, dimmid; epochs=5000, batch_size=256, lr=0.005f0,
                        max_layers=30, N=8192, rng=Random.default_rng())
    mkpath(folder)

    for dim in dimmid
        progress = ProgressMeter.Progress(length(ω))

        # Each ωi has its own data, model and output file, so iterations are independent.
        # NOTE(review): `rng` is shared by all threads. The default (a task-local RNG on
        # Julia >= 1.7) is fine, but a user-supplied MersenneTwister would be raced on.
        Threads.@threads for ωi in ω
            # Seed per ωi so data and initialisation do not depend on thread scheduling.
            Random.seed!(ωi)
            x_train, y_train = generate_spiral(N, ωi)
            x_val,   y_val   = generate_spiral(N, ωi)
            x_test,  y_test  = generate_spiral(N, ωi)

            model = create_model(2 => 2, dim, max_layers=max_layers)
            opt = create_optimiser(model, lr=lr)

            loss_train = zeros(epochs)
            loss_val   = zeros(epochs)
            loss_test  = zeros(epochs)
            best_val = Inf
            # `Optimisers.update!` may mutate the model's arrays in place, so the best model
            # must be a snapshot, not an alias of `model`. Starting from a copy of the
            # initial model means something valid is saved even if no epoch improves
            # (epochs == 0, or NaN losses).
            best_model = deepcopy(model)

            for e in 1:epochs
                for n in Iterators.partition(randperm(rng, N), batch_size)
                    # The KL part is scaled by this minibatch's share of the training set.
                    _, gs = Zygote.withgradient(model) do m
                        loss(y_train[n], x_train[:, n], m; batch_prop = length(n) / N)
                    end
                    opt, model = Optimisers.update!(opt, model, gs[1])
                end
                loss_train[e] = loss(y_train, x_train, model)
                loss_val[e]   = loss(y_val, x_val, model)
                loss_test[e]  = loss(y_test, x_test, model)

                if loss_val[e] < best_val
                    best_val = loss_val[e]
                    best_model = deepcopy(model)
                end
            end

            jldopen("$folder/dim_$(dim)_omega_$(ωi).jld2", "w") do file
                file["data/x_train"] = x_train
                file["data/y_train"] = y_train
                file["data/x_val"] = x_val
                file["data/y_val"] = y_val
                file["data/x_test"] = x_test
                file["data/y_test"] = y_test
                file["model"] = best_model
                file["results/loss_train"] = loss_train
                file["results/loss_val"] = loss_val
                file["results/loss_test"] = loss_test
                file["results/accuracy_train"] = accuracy(x_train, y_train, best_model)
                file["results/accuracy_val"] = accuracy(x_val, y_val, best_model)
                file["results/accuracy_test"] = accuracy(x_test, y_test, best_model)
            end

            ProgressMeter.next!(progress)
        end
    end

    return nothing
end

run_experiment (generic function with 1 method)

In [8]:
let difficulties = 1:20, hidden_widths = (4, 8, 16)
    # Sweep spiral difficulties ω = 1..20 for hidden widths 4, 8 and 16.
    run_experiment("data/spiral", difficulties, hidden_widths)
end

Progress: 100%|█████████████████████████████████████████| Time: 0:17:01
