In [1]:
using UnboundedBNN, Optimisers, ProgressMeter, Zygote, LinearAlgebra, JLD2, Random, Plots, Statistics
using UnboundedBNN: TransformedDistribution

In [2]:
"""
    to_one_hot(p)

Encode a vector of ±1 labels as a 2×length(p) Boolean matrix.

Row 1 marks entries equal to `-1`, row 2 marks entries equal to `1`. Any other value
yields an all-`false` column.
"""
function to_one_hot(p)
    is_negative = p .== -1
    is_positive = p .== 1
    return vcat(transpose(is_negative), transpose(is_positive))
end

to_one_hot (generic function with 1 method)

In [3]:
"""
    binarycrossentropy(p::BitMatrix, q::Matrix)

Return `-sum(q .* p)`: the negated sum of the entries of `q` selected by the one-hot
target mask `p`.

NOTE(review): this equals a cross-entropy only when `q` holds log-probabilities —
confirm that the model outputs fed in here are log-probabilities.
"""
binarycrossentropy(p::BitMatrix, q::Matrix) = -sum(q .* p)

binarycrossentropy (generic function with 1 method)

In [4]:
"""
    predict(model, x; nr_evals=100)

Predict ±1 class labels for the columns of `x` from a Monte-Carlo estimate of the
posterior-predictive distribution.

The model is called `nr_evals` times. NOTE(review): each `model(x)` call presumably
samples fresh Bayes-by-backprop weights — confirm against `UnboundedBNN`. For every call,
each per-depth output `output[l]` goes through `UnboundedBNN.softmax` (over `dims=1`) and
is weighted by the depth-posterior probability `post_pdf[l]`. These mixtures are summed
over all calls. The sum is left unnormalised because only the column-wise arg-max is used.

# Arguments
- `model`: model exposing a `posterior` field and callable on `x`.
- `x`: input matrix, one sample per column.

# Keywords
- `nr_evals`: number of stochastic forward passes; must be at least 1.

# Returns
A `Vector{Int}` with one label per column of `x`: `-1` when row 1 has the larger
score and `+1` when row 2 does.

# Throws
- `ArgumentError` if `nr_evals < 1` (otherwise every label would silently be `-1`).
"""
function predict(model, x; nr_evals=100)
    nr_evals >= 1 || throw(ArgumentError("nr_evals must be at least 1, got $nr_evals"))

    posterior = UnboundedBNN.transform(model.posterior)
    lower, upper = UnboundedBNN.support(posterior)
    post_pdf = UnboundedBNN.pmf.(Ref(posterior), lower:upper)

    # Two score rows because labels are binary (see the ±1 mapping at the end).
    predictions = zeros(Float64, 2, size(x, 2))
    for _ in 1:nr_evals
        output = model(x)
        # output[l] is paired with post_pdf[l] (both indexed from 1).
        predictions .+= mapreduce(
            l -> UnboundedBNN.softmax(output[l]; dims=1) .* post_pdf[l],
            +,
            1:length(post_pdf),
        )
    end

    # Row index of each column's maximum (1 or 2), mapped to -1 / +1 without splatting.
    winners = vec(getindex.(argmax(predictions; dims=1), 1))
    return 2 .* winners .- 3
end

predict (generic function with 1 method)

In [5]:
"""
    accuracy(y::Vector, yhat::Vector)

Return the fraction of positions where `y` and `yhat` agree (element-wise `==`,
broadcast). An empty input gives `NaN`.
"""
function accuracy(y::Vector, yhat::Vector)
    hits = y .== yhat
    return count(hits) / length(hits)
end

accuracy (generic function with 1 method)

In [6]:
"""
    loss(y, x, model; batch_prop=1.0f0, include_weight_kl=false)

Mini-batch training objective for the model.

Data term: the per-depth outputs `output[l]` are mixed with the depth-posterior
probabilities `post_pdf[l]` and scored against the one-hot targets `y` with
`binarycrossentropy`. NOTE(review): this treats `output[l]` as log-probabilities —
confirm against the model's output layer.

Regularisation term: `batch_prop * kl_total`. Here `kl_total` is always the KL between
the transformed depth posterior and `model.prior`. The per-layer weight KLs (input,
hidden and output layers) are added only when `include_weight_kl=true`.

The default `false` reproduces the previous total, in which those terms were commented
out. They are no longer evaluated (and differentiated through) only to be discarded.

# Arguments
- `y::BitMatrix`: one-hot targets, one column per sample.
- `x::Matrix{T}`: inputs, one column per sample.
- `model`: must expose `posterior` and `prior`. With `include_weight_kl=true` it must
  also expose `input_layer`, `hidden_layers` and `output_layers`.

# Keywords
- `batch_prop`: multiplier on the KL term (the caller passes the mini-batch fraction
  `length(batch) / N`).
- `include_weight_kl`: also add the weight-space KL terms.

# Returns
The scalar loss.
"""
function loss(y::BitMatrix, x::Matrix{T}, model; batch_prop = 1.0f0,
              include_weight_kl::Bool = false) where {T}
    posterior = UnboundedBNN.transform(model.posterior)
    lower, upper = UnboundedBNN.support(posterior)
    post_pdf = UnboundedBNN.pmf.(Ref(posterior), lower:upper)

    output = model(x)
    # Posterior-weighted mixture of per-depth outputs; output[l] pairs with post_pdf[l].
    logln = mapreduce(l -> post_pdf[l] .* output[l], +, 1:length(post_pdf)) :: Matrix{T}

    kl_total = KL_loss(posterior, model.prior)
    if include_weight_kl
        kl_total += _weight_kl(model, post_pdf, lower, upper)
    end
    return binarycrossentropy(y, logln) + batch_prop * kl_total
end

# Weight-space KL: input-layer KL plus, for each depth l in lower:upper, the
# post_pdf-weighted KL of the hidden layers and the output head used at that depth.
function _weight_kl(model, post_pdf, lower, upper)
    kl_input = KL_loss(model.input_layer)
    # NOTE(review): depth l sums hidden layers 1:l+1 but uses output head l+1 — confirm
    # whether the hidden range should be 1:l. A plain `sum(f, 1:0)` over an empty range
    # may throw for l == 0, so any change needs an `init`.
    kl_hidden = mapreduce(
        l -> post_pdf[l-lower+1] * sum(li -> KL_loss(model.hidden_layers[li]), 1:l+1),
        +,
        lower:upper,
    )
    kl_output = mapreduce(
        l -> post_pdf[l-lower+1] * KL_loss(model.output_layers[l+1]),
        +,
        lower:upper,
    )
    return kl_input + kl_output + kl_hidden
end

loss (generic function with 1 method)

In [7]:
"""
    create_model(dims::Pair, dimmid; max_layers=30)

Build an `Unbounded` network for `dims.first` inputs and `dims.second` outputs with
hidden width `dimmid`.

Components, in construction order:
- expansion block: `LinearBBB(dims.first => dimmid)` followed by `LeakyReLU()`;
- `max_layers` hidden blocks: `LinearBBB(dimmid => dimmid)` followed by `LeakyReLU()`;
- `max_layers` output heads: `LinearBBB(dimmid => dims.second)` followed by
  `Softmax(dims.second)`;
- depth prior: `discretize(H() * Normal(0.0f0, 1.15f0))`;
- depth posterior: a `SafeNormal([0.0f0], [invsoftplus(1.8f0)])` base passed through
  `H() * x`, `truncate_to_quantiles(x, 0.025f0, 0.975f0)`,
  `expand_truncation_to_ints`, and `discretize`. The second `SafeNormal` argument is
  presumably a softplus-parameterised scale — confirm.

NOTE(review): `max_layers` presumably must cover the largest depth the posterior can
place mass on — confirm.
"""
function create_model(dims::Pair, dimmid; max_layers=30)
    nin, nout = dims.first, dims.second

    # Kept in the original order (expansion, hidden, heads, prior, posterior) in case the
    # layer constructors draw from the global RNG.
    expansion = Chain(LinearBBB(nin => dimmid), LeakyReLU())
    hidden = ntuple(_ -> Chain(LinearBBB(dimmid => dimmid), LeakyReLU()), max_layers)
    heads = ntuple(_ -> Chain(LinearBBB(dimmid => nout), Softmax(nout)), max_layers)

    depth_prior = discretize(H() * Normal(0.0f0, 1.15f0))

    base = SafeNormal([0.0f0], [invsoftplus(1.8f0)])
    pipeline = (
        x -> H() * x,
        x -> truncate_to_quantiles(x, 0.025f0, 0.975f0),
        x -> expand_truncation_to_ints(x),
        x -> discretize(x),
    )
    depth_posterior = TransformedDistribution(base, pipeline)

    return Unbounded(expansion, hidden, heads, depth_prior, depth_posterior)
end

create_model (generic function with 1 method)

In [8]:
"""
    create_optimiser(model; lr=0.005f0)

Set up `Adam` optimiser state for `model`.

Every leaf's learning rate is set to `lr`. The state subtree for `model.posterior` is then
set to `lr / 10`.
"""
function create_optimiser(model; lr=0.005f0)
    state = Optimisers.setup(Adam(), model)
    Optimisers.adjust!(state, lr)

    # The depth-posterior parameters take a ten-times smaller step.
    posterior_lr = lr / 10
    Optimisers.adjust!(state.posterior, posterior_lr)

    return state
end

create_optimiser (generic function with 1 method)

In [9]:
"""
    run_experiment(folder, ω, dimmid, runs; epochs=20000, batch_size=256, lr=0.005f0,
                   max_layers=30, N=1024, rng=Random.default_rng())

Train one model per (run, hidden width, spiral difficulty) combination. For each one,
save the data, loss curves, best-validation model, predictions, accuracies, and the
posterior/prior to `folder/run_<run>/dim_<dim>/dim_<dim>_omega_<ωi>.jld2`.

For each `run in 1:runs` the global RNG is reseeded with `run`. For each `dim in dimmid`
the difficulties in `ω` are processed in parallel with `Threads.@threads`.

# Keywords
- `epochs`: number of passes over the training set.
- `batch_size`: mini-batch size.
- `lr`: learning rate passed to `create_optimiser`.
- `max_layers`: passed to `create_model`.
- `N`: number of samples generated for each of the train/validation/test sets.
- `rng`: RNG used to shuffle mini-batches. NOTE(review): in Julia ≥ 1.7 the default
  `Random.default_rng()` is task-local, so each threaded task draws from its own stream.
  An explicitly passed non-task-local RNG (e.g. a `MersenneTwister`) would be shared
  across threads without synchronisation. Exact reproducibility presumably also
  depends on the thread count — confirm if it matters.
"""
function run_experiment(folder, ω, dimmid, runs; epochs=20000, batch_size=256, lr=0.005f0,
                        max_layers=30, N=1024, rng=Random.default_rng())
    for run in 1:runs
        # set seed
        Random.seed!(run)

        # loop over dimensions
        for dim in dimmid
            mkpath(folder * "/run_$(run)/dim_$(dim)")
            p = ProgressMeter.Progress(length(ω))

            # loop over difficulties
            Threads.@threads for ωi in ω
                _train_and_save(folder, run, dim, ωi; epochs=epochs, batch_size=batch_size,
                                lr=lr, max_layers=max_layers, N=N, rng=rng)
                ProgressMeter.next!(p)
            end
        end
    end
end

# Train a single model for difficulty `ωi` and hidden width `dim`, then write its results file.
function _train_and_save(folder, run, dim, ωi; epochs, batch_size, lr, max_layers, N, rng)
    # generate data
    x_train, y_train = generate_spiral(N, ωi)
    x_val,   y_val   = generate_spiral(N, ωi)
    x_test,  y_test  = generate_spiral(N, ωi)
    y_train_onehot = to_one_hot(y_train)
    y_val_onehot = to_one_hot(y_val)
    y_test_onehot = to_one_hot(y_test)

    # create model and optimiser
    model = create_model(2 => 2, dim, max_layers=max_layers)
    opt = create_optimiser(model, lr=lr)

    loss_train = zeros(epochs)
    loss_val   = zeros(epochs)
    loss_test  = zeros(epochs)
    best_val = Inf
    best_model = nothing

    for e in 1:epochs
        for n in Iterators.partition(randperm(rng, N), batch_size)
            _, gs = Zygote.withgradient(
                m -> loss(y_train_onehot[:, n], x_train[:, n], m; batch_prop = length(n) / N),
                model,
            )
            opt, model = Optimisers.update!(opt, model, gs[1])
        end
        loss_train[e] = loss(y_train_onehot, x_train, model)
        loss_val[e] = loss(y_val_onehot, x_val, model)
        loss_test[e] = loss(y_test_onehot, x_test, model)

        if loss_val[e] < best_val
            best_val = loss_val[e]
            # `Optimisers.update!` may mutate the model's arrays in place, so assigning
            # `model` directly would alias it and keep tracking later updates.
            best_model = deepcopy(model)
        end
    end

    if best_model === nothing
        # Every validation loss was NaN or Inf; save the final model instead of crashing.
        @warn "No finite validation loss; saving the final model" dim ωi
        best_model = model
    end

    # `predict` calls the model repeatedly; if the model samples weights, repeated calls
    # differ, so predict each split once and derive the accuracy from what is stored.
    predictions_train = predict(best_model, x_train)
    predictions_val = predict(best_model, x_val)
    predictions_test = predict(best_model, x_test)

    # save results
    jldopen(folder * "/run_$(run)/dim_$(dim)/dim_$(dim)_omega_$(ωi).jld2", "w") do file
        file["data/x_train"] = x_train
        file["data/y_train"] = y_train
        file["data/x_val"] = x_val
        file["data/y_val"] = y_val
        file["data/x_test"] = x_test
        file["data/y_test"] = y_test
        file["model"] = best_model
        file["results/loss_train"] = loss_train
        file["results/loss_val"] = loss_val
        file["results/loss_test"] = loss_test
        file["results/predictions_train"] = predictions_train
        file["results/predictions_val"] = predictions_val
        file["results/predictions_test"] = predictions_test
        file["results/accuracy_train"] = accuracy(y_train, predictions_train)
        file["results/accuracy_val"] = accuracy(y_val, predictions_val)
        file["results/accuracy_test"] = accuracy(y_test, predictions_test)
        file["results/posterior"] = UnboundedBNN.transform(best_model.posterior)
        file["results/prior"] = best_model.prior
    end
    return nothing
end

run_experiment (generic function with 1 method)

In [10]:
# Five seeded runs over spiral difficulties 0:30 with hidden width 32, saved under data/spiral/normal.
let folder = "data/spiral/normal", difficulties = 0:30, width = 32, nruns = 5
    run_experiment(folder, difficulties, width, nruns)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:45:24[39m[K
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:38:00[39m[K
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:35:25[39m[K
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:34:59[39m[K
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:39:51[39m[K
