In [1]:
using Pkg 
Pkg.activate("..")

using Revise 
using RxInfer
using Random, LinearAlgebra, SpecialFunctions, Plots, StableRNGs, DomainSets , LaTeXStrings , Statistics, StatsFuns
using Optim, ForwardDiff, Flux, Zygote
using CSV, DataFrames
using KernelFunctions, MAT,LoopVectorization
using Plots 
using JLD

[32m[1m  Activating[22m[39m project at `~/PhD/GaussianProcessNode`


In [2]:
include("../GPnode/UniSGPnode.jl")
include("../helper_functions/derivative_helper.jl")
include("../helper_functions/gp_helperfunction.jl")

error_rate (generic function with 1 method)

## Load data

In [3]:
# Read the banana dataset: the first two columns form the 2-D input of each
# sample and the last column holds its label.
data = CSV.read("../data/banana/banana.csv", DataFrame);
x_data = [[row[1], row[2]] for row in eachrow(data)];
# Relabel -1 as 0 and convert the labels to floating point.
label = float(replace(data[:, end], -1 => 0));

# The first `Ntrain` samples are used for training, the remainder for testing.
Ntrain = 4000;
xtrain, xtest = x_data[1:Ntrain], x_data[Ntrain+1:end];
ytrain, ytest = label[1:Ntrain], label[Ntrain+1:end];

# Cut the training set into mini-batches with the project helper `split2batch`.
data_training = (xtrain, ytrain);
batch_size = 200;
xtrain_minibatches, ytrain_minibatches = split2batch(data_training, batch_size);
nbatches = length(ytrain_minibatches);

## Configure GP

In [4]:
# Fix the global RNG so the choice of inducing inputs is reproducible.
Random.seed!(1)
M = 500;

# Inducing inputs: M training points picked from a random permutation.
pos = randperm(Ntrain)[1:M]
Xu = xtrain[pos];

"""
    kernel_gp(θ)

Build a scaled squared-exponential kernel from the unconstrained parameters `θ`.
`θ[1]` gives the scale and `θ[2]`, `θ[3]` the per-dimension lengthscales; each is
passed through `softplus` before use.
"""
function kernel_gp(θ)
    variance = StatsFuns.softplus(θ[1])
    lengthscales = StatsFuns.softplus.([θ[2], θ[3]])
    return variance * with_lengthscale(SEKernel(), lengthscales)
end
ndims_params = 3;
# Initial point in unconstrained space, chosen so that softplus maps it to 1.
θ_init = fill(StatsFuns.invsoftplus(1.0), ndims_params);

# Kernel statistics evaluated at the first training input with the initial
# hyperparameters; they are handed to `UniSGPMeta` later on.
Ψ0 = [1.0;;]
Ψ1_trans, Ψ2 = let k_init = kernel_gp(θ_init), x_first = [xtrain[1]]
    K_ux = kernelmatrix(k_init, Xu, x_first)
    K_xu = kernelmatrix(k_init, x_first, Xu)
    (K_ux, K_ux * K_xu)
end;

## Model definition

In [5]:
# Probabilistic model for binary classification, written with the RxInfer
# `@model` DSL.
#
# Arguments
# - `y`, `x`        : observations and their inputs; one `f[i]`/`y[i]` pair is
#                     created per index of `y`.
# - `θ`             : kernel hyperparameters forwarded to every `UniSGP` node.
# - `mv`, `Σv`      : mean and covariance of the Gaussian prior on `v`.
# - `shape`, `rate` : parameters of the Gamma prior on `w`.
#
# NOTE(review): `v` presumably holds the function values at the inducing inputs
# and `w` a precision — confirm against the `UniSGP` node definition.
@model function gp_classification(y,x,θ,mv, Σv, shape, rate)
    # Priors shared by all observations.
    v ~ MvNormalMeanCovariance(mv, Σv)
    w ~ GammaShapeRate(shape,rate)
    for i in eachindex(y)
        # Latent value for input x[i], produced by the custom UniSGP node.
        f[i] ~ UniSGP(x[i],v,w,θ) 
        # Observation linked to the latent value through a Probit node.
        y[i] ~ Probit(f[i])
    end
end

# Factorisation constraint on the variational posterior: the joint over
# `f`, `v` and `w` is approximated by the product q(f)q(v)q(w).
gp_constraints = @constraints begin
    q(f,v,w) = q(f)q(v)q(w)
end

# Node metadata for the model. Every `UniSGP` node receives a `UniSGPMeta` built
# from the given inducing inputs, kernel statistics, Cholesky factors and kernel
# constructor; every `Probit` node receives `ProbitMeta(32)`.
#
# Note: `batch_size` is not an argument — it is read from the enclosing (global)
# scope when the meta is built.
# NOTE(review): the meaning of the leading `nothing`, the `0`, and the `32`
# passed to `ProbitMeta` is not visible here — confirm against their definitions.
@meta function meta_gp_classification(Xu,Ψ0,Ψ1_trans,Ψ2,KuuL,kernel,Lu)
    UniSGP() -> UniSGPMeta(nothing,Xu,Ψ0,Ψ1_trans,Ψ2,KuuL,kernel,Lu,0,batch_size)
    Probit() -> ProbitMeta(32)
end


# Initial marginals for the inference procedure: q(v) starts as a Gaussian with
# mean `μv` and covariance `Σv`, q(w) as a Gamma with shape `α` and rate `β`.
@initialization function gp_initialization(μv,Σv, α, β)
    q(v) = MvNormalMeanCovariance(μv, Σv)
    q(w) = GammaShapeRate(α,β)
end
;

## Perform inference

In [6]:
"""
    my_free_energy(θ; xbatch, ybatch, mv, Σv, shape, rate)

Run a single RxInfer iteration of `gp_classification` on one mini-batch and return
the posteriors `(q(v), q(f), q(w))`.

Despite its name, this function does not evaluate the free energy (`infer` is
called with `free_energy = false`); it only performs the variational update.

# Keywords
- `xbatch`, `ybatch`: inputs and labels of the mini-batch.
- `mv`, `Σv`: mean and covariance used both as the prior on `v` and as the
  initial marginal `q(v)`.
- `shape`, `rate`: Gamma parameters used both as the prior on `w` and as the
  initial marginal `q(w)`.

Reads the globals `Xu`, `Ψ0`, `Ψ1_trans`, `Ψ2`, `kernel_gp` and `gp_constraints`.
"""
function my_free_energy(θ; xbatch, ybatch, mv, Σv, shape, rate)
    # Lower Cholesky factor of the jittered inducing-point kernel matrix.
    jittered_Kuu = kernelmatrix(kernel_gp(θ), Xu) + 1e-8 * I
    KuuL = fastcholesky!(jittered_Kuu).L

    # Upper Cholesky factor of mv*mv' + Σv.
    second_moment = mv * mv' + Σv
    Lu = fastcholesky!(second_moment).U

    result = infer(;
        model = gp_classification(θ = θ, mv = mv, Σv = Σv, shape = shape, rate = rate),
        data = (y = ybatch, x = xbatch),
        constraints = gp_constraints,
        initialization = gp_initialization(mv, Σv, shape, rate),
        meta = meta_gp_classification(Xu, Ψ0, Ψ1_trans, Ψ2, KuuL, kernel_gp, Lu),
        returnvars = (v = KeepLast(), f = KeepLast(), w = KeepLast()),
        iterations = 1,
        free_energy = false,
    )

    posteriors = result.posteriors
    return (posteriors[:v], posteriors[:f], posteriors[:w])
end


"""
    PerformInference(θ; epochs = 1)

Alternate, over `epochs` passes through the mini-batches, between

1. a variational update of `q(v)`, `q(f)`, `q(w)` via `my_free_energy`, and
2. one `Flux.AdaMax` step on the kernel hyperparameters, using the gradient
   written by `grad_llh_new!`.

The posterior of `v` and `w` obtained on one mini-batch is used as the prior for
the next one. Returns `(q_v, q_w, θ_optimal)`.

`θ` is only the starting point: the optimiser works on a copy, so the caller's
vector is left untouched.

Reads the globals `M`, `nbatches`, `xtrain_minibatches`, `ytrain_minibatches`,
`kernel_gp` and `Xu`.
"""
function PerformInference(θ; epochs = 1)
    # Prior for the first mini-batch.
    μ_v = zeros(M)
    Σ_v = 50 * diageye(M)
    shape = 0.01
    rate = 0.01

    # `Flux.Optimise.update!` modifies its parameter array in place; copy so the
    # caller's `θ` is not overwritten.
    θ_optimal = copy(θ)
    grad = similar(θ_optimal)
    optimizer = Flux.AdaMax()

    for epoch in 1:epochs
        for b in 1:nbatches
            xbatch = xtrain_minibatches[b]
            ybatch = ytrain_minibatches[b]

            # Step 1: variational update of v, f and w on this mini-batch.
            qv, qf, qw = my_free_energy(
                θ_optimal;
                xbatch = xbatch,
                ybatch = ybatch,
                mv = μ_v,
                Σv = Σ_v,
                shape = shape,
                rate = rate,
            )

            # The posterior becomes the prior of the next mini-batch.
            μ_v, Σ_v = mean_cov(qv)
            shape, rate = params(qw)

            # Step 2: one optimiser step on the hyperparameters.
            # Upper Cholesky factor of Σ_v + μ_v*μ_v'.
            Uv = fastcholesky!(Σ_v + μ_v * μ_v').U
            grad_llh_new!(
                grad,
                θ_optimal;
                y_data = mean.(qf),
                x_data = xbatch,
                v = μ_v,
                Uv = Uv,
                w = mean(qw),
                kernel = kernel_gp,
                Xu = Xu,
                chunk_size = 2,
            )
            Flux.Optimise.update!(optimizer, θ_optimal, grad)
        end
    end

    q_v = MvNormalMeanCovariance(μ_v, Σ_v)
    q_w = GammaShapeRate(shape, rate)
    return q_v, q_w, θ_optimal
end
;

In [74]:
# Full training run: 500 epochs, roughly 50 minutes.
@time (qv, qw, θ_opt) = PerformInference(θ_init; epochs = 500);

# To skip training, load the previously saved result instead:
# qv = load("../savefiles/qv_banana.jld")["qv"]
# qw = load("../savefiles/qw_banana.jld")["qw"]
# θ_opt = load("../savefiles/params_optimal_banana.jld")["params_optimal"]

2965.757395 seconds (8.62 G allocations: 4.856 TiB, 10.07% gc time, 0.01% compilation time: 99% of which was recompilation)


In [75]:
# Optimised hyperparameters mapped through softplus, as `kernel_gp` does.
map(StatsFuns.softplus, θ_opt)

3-element Vector{Float64}:
 0.9856325025617136
 1.0280572449420289
 1.0215368195768595

## Prediction and validation

In [70]:
"""
    predict_new(x_test, qv, qw, qθ, meta)

Predict the label distribution at a single input `x_test`.

The outgoing message of the `UniSGP` node is computed from the point-mass input
and the marginals `qv`, `qw`, `qθ` (with node metadata `meta`); that message is
then pushed through the `Probit` node (with `ProbitMeta(32)`) and the result is
returned.
"""
function predict_new(x_test, qv, qw, qθ, meta)
    latent_msg = @call_rule UniSGP(:out, Marginalisation) (q_in = PointMass(x_test), q_v = qv, q_w = qw, q_θ = qθ, meta = meta)
    label_msg = @call_rule Probit(:out, Marginalisation) (m_in = latent_msg, meta = ProbitMeta(32))
    return label_msg
end

predict_new (generic function with 1 method)

In [76]:
# Evaluate the trained model on the held-out test set.
# Typed accumulators (rather than untyped `[]`) keep the results as Vector{Float64}.
predict_mean = Float64[]
predict_var = Float64[]
predict_y = Float64[]

# Cholesky factors of the jittered inducing-point kernel matrix at θ_opt.
# `cholesky` must run before `fastcholesky!`, which is an in-place routine.
# NOTE(review): during training `Lu` is the factor of mv*mv' + Σv, whereas here
# it is the factor of Kuu — confirm which one the UniSGP prediction rule expects.
Kuu = kernelmatrix(kernel_gp(θ_opt), Xu) + 1e-8 * I
Lu = cholesky(Kuu).U;
KuuL = fastcholesky!(Kuu).L

for i in eachindex(ytest)
    # A fresh meta object is built for every test point, as in the training setup.
    prediction_y = predict_new(xtest[i], qv, qw, PointMass(θ_opt), UniSGPMeta(nothing, Xu, Ψ0, Ψ1_trans, Ψ2, KuuL, kernel_gp, Lu, 0, batch_size))
    prob = mean(prediction_y)
    push!(predict_mean, prob)
    push!(predict_var, var(prediction_y))
    # Threshold the predictive mean at 0.5 to obtain a hard label.
    push!(predict_y, prob >= 0.5 ? 1.0 : 0.0)
end

println("Number of error:", num_error(ytest, predict_y))
println("Error rate: ", error_rate(ytest, predict_y))

Number of error:125.0
Error rate: 0.09615384615384616


In [77]:
# Persist the inference results and validation metrics with JLD.
save("../savefiles/qv_banana.jld", "qv", qv)
save("../savefiles/qw_banana.jld", "qw", qw)
# `FE` is not produced by the training code shown in this notebook (no free
# energy is recorded), so only save it when it exists; an unconditional save
# raises UndefVarError and prevents the remaining files from being written.
if @isdefined(FE)
    save("../savefiles/FE_banana.jld", "FE", FE)
end
save("../savefiles/params_optimal_banana.jld", "params_optimal", θ_opt)
save("../savefiles/error_rate_banana.jld", "error_rate", error_rate(ytest, predict_y))
save("../savefiles/number_error_banana.jld", "number_error", num_error(ytest, predict_y))
save("../savefiles/Xu_banana.jld", "Xu", Xu)