In [1]:
using Pkg
Pkg.activate("../OldEnvironment")

[32m[1m  Activating[22m[39m project at `~/PhD/GaussianProcessNode/OldEnvironment`


In [2]:
using ApproximateGPs
using ParameterHandling
using Zygote, StatsFuns
using PDMats: PDMat
using Distributions
using LinearAlgebra
using Optim, Flux
using IterTools: ncycle
using Plots
using JLD, MAT , CSV, DataFrames

## Toy dataset

In [3]:
# Load the pre-generated toy classification data set from JLD files.
# Each file stores a single variable under the key named after it.
data_path = "../savefiles/"

xtrain = load(joinpath(data_path, "xtrain_toyclassification.jld"))["xtrain"]
ytrain = load(joinpath(data_path, "ytrain_toyclassification.jld"))["ytrain"]
xtest = load(joinpath(data_path, "xtest_toyclassification.jld"))["xtest"]
ytest = load(joinpath(data_path, "ytest_toyclassification.jld"))["ytest"]

# Inducing inputs, plus the sizes used when initialising the variational parameters.
Xu = load(joinpath(data_path, "Xu_toyclassification.jld"))["Xu"];
M = length(Xu);      # number of inducing points
N = length(ytrain);  # number of training targets

In [4]:
# Initial values for every optimised quantity. The `positive` / `positive_definite`
# wrappers from ParameterHandling keep the constrained entries valid while the
# optimiser works on an unconstrained flat vector.
raw_initial_params = (
    k=(var=positive(1.0), precision=positive(1.0)),  # kernel hyperparameters
    m=fill(0.0, M),                                  # variational mean
    A=positive_definite(Matrix(1.0I, M, M)),         # variational covariance
);

# `flat_init_params` is the flat vector handed to the optimiser; `unflatten` rebuilds the
# NamedTuple, and `unpack` additionally strips the constraint wrappers.
(flat_init_params, unflatten) = ParameterHandling.flatten(raw_initial_params)
unpack = ParameterHandling.value ∘ unflatten;

In [5]:
lik = BernoulliLikelihood()
jitter = 1e-3  # added to aid numerical stability

"""
    build_SVGP(params::NamedTuple)

Build the sparse variational approximation for the toy problem from unpacked parameters.

`params` must provide `k.var`, `k.precision`, `m` and `A`. Reads the globals `lik`,
`jitter` and `Xu` (the inducing inputs).

Returns the tuple `(svgp, f)`: the `SparseVariationalApproximation` and the `LatentGP`
prior it was built from.
"""
function build_SVGP(params::NamedTuple)
    # NOTE(review): `k.precision` is passed to `with_lengthscale`, i.e. it is used as a
    # lengthscale despite its name.
    base_kernel = with_lengthscale(SqExponentialKernel(), params.k.precision)
    latent = LatentGP(GP(params.k.var * base_kernel), lik, jitter)
    q_u = MvNormal(params.m, params.A)
    return SparseVariationalApproximation(latent(Xu).fx, q_u), latent
end

"""
    loss(params::NamedTuple; x=xtrain, y=ytrain)

Return the negative ELBO of the SVGP built from `params`, evaluated on inputs `x` and
targets `y`.

# Keywords
- `x`: inputs at which the latent GP is evaluated (defaults to the global `xtrain`).
- `y`: targets matching `x` (defaults to the global `ytrain`).

The keywords are honoured by the body, so the objective can be evaluated on any data set
and not only on the training globals.
"""
function loss(params::NamedTuple; x=xtrain, y=ytrain)
    svgp, f = build_SVGP(params)
    fx = f(x)
    return -elbo(svgp, fx, y)
end;

In [6]:
# Minimise the negative ELBO over the flat parameter vector with L-BFGS.
# The gradient comes from Zygote; `inplace=false` because the gradient function returns
# a fresh vector instead of writing into a buffer.
opt = let objective = loss ∘ unpack
    optimize(
        objective,
        θ -> only(Zygote.gradient(objective, θ)),
        flat_init_params,
        LBFGS();
        inplace=false,
    )
end

 * Status: success (objective increased between iterations)

 * Candidate solution
    Final objective value:     3.066228e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 2.25e-09 ≰ 0.0e+00
    |x - x'|/|x'|          = 8.68e-10 ≰ 0.0e+00
    |f(x) - f(x')|         = 7.11e-15 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 2.32e-16 ≰ 0.0e+00
    |g(x)|                 = 9.15e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   3  (vs limit Inf)
    Iterations:    130
    f(x) calls:    371
    ∇f(x) calls:   371


In [7]:
# Map the optimiser's flat solution back to constrained, named parameters.
final_params = opt.minimizer |> unpack

# Rebuild the approximation at the optimum and form the approximate posterior GP.
(svgp_opt, f_opt) = build_SVGP(final_params)
post_opt = svgp_opt |> posterior
# Posterior wrapped with the Bernoulli likelihood (not used further in the visible cells).
l_post_opt = LatentGP(post_opt, BernoulliLikelihood(), jitter);

In [8]:
# Posterior mean of the latent function at the test inputs, squashed to (0, 1).
# NOTE(review): `normcdf` is a probit squashing, whereas `BernoulliLikelihood()` is
# presumably using its default (logistic) link — confirm. The 0.5 threshold below gives
# the same labels either way, since both map a latent mean of 0 to 0.5.
predict_mean = mean(post_opt(xtest))
p_predict = normcdf.(predict_mean)

# Hard labels: 1.0 where the squashed mean exceeds 0.5, otherwise 0.0.
predict_bin = ifelse.(p_predict .> 0.5, 1.0, 0.0);

In [9]:
"""
    num_error(ytrue, y)

Return the sum of absolute differences between predictions `y` and targets `ytrue`.
For labels encoded as 0/1 this is the number of misclassified points.
"""
num_error(ytrue, y) = sum(abs, y - ytrue)

"""
    error_rate(ytrue, y)

Return `num_error(ytrue, y)` divided by the number of targets in `ytrue`.
"""
function error_rate(ytrue, y)
    n_wrong = num_error(ytrue, y)
    return n_wrong / length(ytrue)
end

error_rate (generic function with 1 method)

In [10]:
# Report the misclassification count and rate on the toy test set.
let truth = ytest, guess = predict_bin
    println("Number of error:", num_error(truth, guess))
    println("Error rate: ", error_rate(truth, guess))
end

Number of error:34.0
Error rate: 0.085


## Banana dataset

In [6]:
# Read the banana data set; the first two columns are used as input features and the
# last column as the class label.
banana_data = CSV.read("../data/banana/banana.csv", DataFrame);

# One 2-element feature vector per row.
x_banana_data = [[u, v] for (u, v) in zip(banana_data[:, 1], banana_data[:, 2])];
# Recode the labels from {-1, 1} to {0, 1} and convert them to floating point.
banana_label = float(replace(banana_data[:, end], -1 => 0));

# The first `Ntrain` rows are used for training, the remainder for testing.
Ntrain = 4000;
xtrain_banana = x_banana_data[1:Ntrain];
ytrain_banana = banana_label[1:Ntrain];
xtest_banana = x_banana_data[(Ntrain + 1):end];
ytest_banana = banana_label[(Ntrain + 1):end];

batch_size = 200;

# Load the inducing points and derive the problem sizes from them.
Xu_banana = load(data_path * "Xu_banana.jld")["Xu"]
M_banana = length(Xu_banana)
input_dim = length(first(Xu_banana))

2

## Build the SVGP model for stochastic variational inference (SVI)

In [13]:
"""
    SVGPModel(k, m_u, A)

Container for the trainable quantities of the banana SVGP model.

# Fields
- `k`: unconstrained kernel parameters; `make_kernel_banana` maps them through softplus.
- `m_u`: mean of the variational distribution over the inducing outputs.
- `A`: square root of the variational covariance; `make_approx_banana` uses its lower
  triangle as a Cholesky factor.
"""
struct SVGPModel
    k::Vector{Float64}
    m_u::Vector{Float64}
    A::Matrix{Float64}
end

# Register the fields `k`, `m_u` and `A` with Flux so that `Flux.params` can collect them
# as the model's trainable arrays.
Flux.@functor SVGPModel (k, m_u, A);

"""
    make_kernel_banana(k_params)

Build a scaled squared-exponential kernel from unconstrained parameters.

`k_params[1]` is mapped through softplus to give the kernel variance; the remaining
entries are mapped element-wise through softplus and passed to `with_lengthscale`.
"""
function make_kernel_banana(k_params)
    σ² = StatsFuns.softplus(k_params[1])
    ℓ = StatsFuns.softplus.(k_params[2:end])
    base_kernel = with_lengthscale(SqExponentialKernel(), ℓ)
    return σ² * base_kernel
end

# NOTE(review): this rebinds the global `jitter` that `build_SVGP` also reads, so re-running
# the toy cells after this one uses 1e-5 instead of 1e-3.
jitter = 1e-5;

"""
    prior_banana(m::SVGPModel)

Return the `LatentGP` prior for the banana model: a GP with the kernel built from `m.k`,
a Bernoulli likelihood, and the global `jitter`.
"""
function prior_banana(m::SVGPModel)
    return LatentGP(GP(make_kernel_banana(m.k)), BernoulliLikelihood(), jitter)
end

"""
    make_approx_banana(m::SVGPModel, prior)

Build the `SparseVariationalApproximation` for the banana model.

The variational distribution is `MvNormal(m.m_u, S)` with `S = L * Lᵀ`, where `L` is the
lower triangle of `m.A`. The inducing inputs are the global `Xu_banana`, evaluated
under `prior`.
"""
function make_approx_banana(m::SVGPModel, prior)
    # Wrapping the lower triangle directly as a Cholesky factor yields S = L * Lᵀ without
    # forming the dense product and factorising it again.
    chol_factor = Cholesky(LowerTriangular(m.A))
    q_u = MvNormal(m.m_u, PDMat(chol_factor))
    return SparseVariationalApproximation(prior(Xu_banana).fx, q_u)
end;

"""
    model_posterior_banana(m::SVGPModel)

Return the approximate posterior GP obtained from the model's current parameters.
"""
function model_posterior_banana(m::SVGPModel)
    latent_prior = prior_banana(m)
    return posterior(make_approx_banana(m, latent_prior))
end;

# Calling the model on inputs `x` evaluates the approximate posterior GP at `x`.
(m::SVGPModel)(x) = model_posterior_banana(m)(x);

"""
    loss(m::SVGPModel, x, y; num_data=length(ytrain_banana))

Return the negative ELBO of the banana SVGP model on the (mini-)batch `(x, y)`.

# Keywords
- `num_data`: forwarded to `elbo`; defaults to the length of the global `ytrain_banana`.
  Presumably the full training-set size used to rescale a minibatch estimate — confirm
  against the `elbo` documentation.
"""
function loss(m::SVGPModel, x, y; num_data=length(ytrain_banana))
    latent_prior = prior_banana(m)
    approx = make_approx_banana(m, latent_prior)
    return -elbo(approx, latent_prior(x), y; num_data=num_data)
end;

In [14]:
# Initial kernel hyperparameters in the constrained (positive) domain.
init_variance = 1
init_lengthscale = ones(input_dim)

# The model stores unconstrained values, so apply the inverse softplus and stack the
# variance in front of the per-dimension lengthscales.
k_banana_init = vcat(
    StatsFuns.invsoftplus(init_variance), StatsFuns.invsoftplus.(init_lengthscale)
);
m_banana_init = zeros(M_banana)                          # variational mean
A_banana_init = Matrix{Float64}(I, M_banana, M_banana)   # covariance square root

model_banana = SVGPModel(k_banana_init, m_banana_init, A_banana_init);
# NOTE(review): `opt` previously held the Optim result of the toy experiment; it is
# rebound here to the Flux optimiser.
opt = Flux.AdaMax()
params_banana = Flux.params(model_banana);  # trainable arrays of the model

In [None]:
# Minibatch training: each batch holds `batch_size` (input, label) pairs and the loader is
# cycled 1000 times (1000 epochs). `num_data=Ntrain` is forwarded to the loss for every
# batch.
data_loader = Flux.DataLoader((xtrain_banana, ytrain_banana); batchsize=batch_size)
let minibatch_loss = (x, y) -> loss(model_banana, x, y; num_data=Ntrain)
    Flux.train!(
        minibatch_loss,
        Flux.params(model_banana),
        ncycle(data_loader, 1000),
        opt,
    )
end;

In [49]:
# Approximate posterior of the trained model, evaluated at the held-out inputs.
post_banana = model_banana |> model_posterior_banana
my_predict_banana = xtest_banana |> post_banana;

In [50]:
# Posterior mean of the latent function at the test inputs, squashed to (0, 1).
# NOTE(review): `normcdf` is a probit squashing, whereas `BernoulliLikelihood()` is
# presumably using its default (logistic) link — confirm. The 0.5 threshold below gives
# the same labels either way.
predict_banana = mean(my_predict_banana)
p_predict_banana = normcdf.(predict_banana)

# Hard labels: 1.0 where the squashed mean exceeds 0.5, otherwise 0.0.
predict_banana_bin = ifelse.(p_predict_banana .> 0.5, 1.0, 0.0);

In [51]:
# Report the misclassification count and rate on the banana test set.
let truth = ytest_banana, guess = predict_banana_bin
    println("Number of error:", num_error(truth, guess))
    println("Error rate: ", error_rate(truth, guess))
end

Number of error:121.0
Error rate: 0.09307692307692307


In [52]:
# Persist the trained banana results as JLD files: the approximate posterior under the
# key "posterior" and the `SVGPModel` parameters under the key "model".
save("../savefiles/VSGP_posterior_banana.jld","posterior",post_banana)
save("../savefiles/VSGP_model_banana.jld","model",model_banana)

In [None]:
# plot([opt_banana.trace[i].value for i=1:length(opt_banana.trace)])