In [1]:
using Pkg 
Pkg.activate("..")
Pkg.instantiate();

using Revise 
using ReactiveMP, RxInfer
using Random, LinearAlgebra, SpecialFunctions, Plots, StableRNGs, DomainSets , LaTeXStrings, StatsFuns 
using Optim, ForwardDiff, Flux, Zygote
using KernelFunctions, MAT,LoopVectorization, PDMats
using Plots 
using JLD
import KernelFunctions: SqExponentialKernel, Matern52Kernel, with_lengthscale, kernelmatrix 

[32m[1m  Activating[22m[39m project at `~/PhD/GaussianProcessNode`


In [None]:
include("../helper_functions/gp_helperfunction.jl")
include("../GPnode/UniSGPnode.jl")
include("../helper_functions/derivative_helper.jl")

trace_blkmatrix (generic function with 1 method)

## Load data

In [3]:
# Load the kin40k train/test splits from their .mat files.
xtrain_data = matopen("../data/kin40k/kin40k_xtrain.mat");
xtest_data = matopen("../data/kin40k/kin40k_xtest.mat");
ytrain_data = matopen("../data/kin40k/kin40k_ytrain.mat");
ytest_data = matopen("../data/kin40k/kin40k_ytest.mat");

# Targets are flattened to plain vectors. `vec` performs the same flattening as
# `vcat(x...)` without splatting every element into a varargs call.
xtrain_matrix = read(xtrain_data, "xtrain");
ytrain = vec(read(ytrain_data, "ytrain"));
xtest_matrix = read(xtest_data, "xtest");
ytest = vec(read(ytest_data, "ytest"));

# Everything needed is in memory now; release the file handles instead of
# leaving them open for the lifetime of the session.
foreach(close, (xtrain_data, xtest_data, ytrain_data, ytest_data));

# One input vector per observation (row of the design matrix).
Ntrain = length(ytrain);
xtrain = [xtrain_matrix[i, :] for i in 1:Ntrain];
data_training = (xtrain, ytrain);

Ntest = length(ytest);
xtest = [xtest_matrix[i, :] for i in 1:Ntest];

# Split the training set into minibatches (split2batch comes from the included helpers).
batch_size = 500;
xtrain_minibatches, ytrain_minibatches = split2batch(data_training, batch_size);
nbatches = length(ytrain_minibatches);

## Configure GP

In [4]:
# Sparse-GP configuration: inducing inputs, kernel parametrisation, optimiser and
# the kernel quantities evaluated at the initial hyperparameters.
Random.seed!(1)
M = 600;  # number of inducing points

# Inducing inputs: M training inputs picked at random.
pos = first(randperm(Ntrain), M)
Xu = xtrain[pos];

# θ lives on the unconstrained real line; softplus maps it to positive values.
# θ[1] scales the kernel, θ[2:end] are the per-dimension lengthscales.
function kernel_gp(θ)
    scale = StatsFuns.softplus(θ[1])
    lengthscales = StatsFuns.softplus.(θ[2:end])
    return scale * with_lengthscale(SEKernel(), lengthscales)
end

dim_θ = 1 + size(xtrain_matrix, 2)
gpcache = GPCache();
# Chosen so that softplus.(θ_init) is a vector of ones.
θ_init = fill(StatsFuns.invsoftplus(1.0), dim_θ);
optimizer = Flux.AdaMax();

# Kernel quantities evaluated at the first training input.
Ψ0 = kernelmatrix(kernel_gp(θ_init), xtrain[1:1])
Ψ1_trans = kernelmatrix(kernel_gp(θ_init), Xu, xtrain[1:1])
Ψ2 = kernelmatrix(kernel_gp(θ_init), Xu, xtrain[1:1]) *
     kernelmatrix(kernel_gp(θ_init), xtrain[1:1], Xu);
w_val = 1e4
# Small diagonal jitter is added before the factorisation.
Kuu = kernelmatrix(kernel_gp(θ_init), Xu) + 1e-8 * I
Lu = cholesky(Kuu).U;

## Model Definition

In [5]:
# Probabilistic model for one minibatch.
#   y        : observations, one UniSGP factor is created per element
#   x        : inputs, indexed in lockstep with `y`
#   θ        : kernel hyperparameters handed to every UniSGP node
#   μ_v, Σ_v : mean and covariance of the Gaussian prior on `v`
# `w_val` is not an argument; it is read from the enclosing (global) scope.
# NOTE(review): `v` presumably holds the inducing-point values and `w_val` the
# observation precision — confirm against UniSGPnode.jl.
@model function gp_kin40k(y,x,θ, μ_v ,Σ_v)
    # Prior on the latent vector shared by all observations in the batch.
    v ~ MvNormalMeanCovariance(μ_v, Σ_v)
    @inbounds for i in eachindex(y)
        # Each observation is tied to its input and the shared `v` through a UniSGP node.
        y[i] ~ UniSGP(x[i],v,w_val,θ)
    end
end

# Meta specification: every UniSGP node receives a UniSGPMeta built from the given
# inducing inputs, kernel quantities, factor `KuuL`, kernel constructor and `Uv`.
# The leading `nothing`, the literal `0` and the global `batch_size` are passed
# through positionally.
# NOTE(review): the `0`/`batch_size` pair looks like a counter and its period —
# confirm against the UniSGPMeta definition.
@meta function meta_gp_regression(Xu,Ψ0,Ψ1_trans,Ψ2,KuuL,kernel,Uv)
    UniSGP() -> UniSGPMeta(nothing,Xu,Ψ0,Ψ1_trans,Ψ2,KuuL,kernel,Uv,0,batch_size)
end

meta_gp_regression (generic function with 1 method)

## Inference

In [6]:
"""
    my_free_energy(θ; xbatch, ybatch, μ_v, Σ_v)

Run a single inference iteration of `gp_kin40k` on one minibatch and return the
posterior marginal of `v`.

Despite its name, this function does not compute the free energy
(`free_energy = false`); only the posterior of `v` is returned.

# Keywords
- `xbatch`, `ybatch`: inputs and observations of the minibatch.
- `μ_v`, `Σ_v`: mean and covariance of the prior placed on `v`.

# Side effects
The global `Kuu` is overwritten in place with the kernel matrix at `θ` and is then
factorised in place.
"""
function my_free_energy(θ; xbatch, ybatch, μ_v, Σ_v)
    # Refresh the inducing-point kernel matrix for the current hyperparameters.
    kernelmatrix!(Kuu, kernel_gp(θ), Xu)
    chol_lower = fastcholesky!(Kuu).L

    node_meta = meta_gp_regression(Xu, Ψ0, Ψ1_trans, Ψ2, chol_lower, kernel_gp, Lu)
    result = infer(
        model = gp_kin40k(θ = θ, μ_v = μ_v, Σ_v = Σ_v),
        iterations = 1,
        data = (y = ybatch, x = xbatch),
        returnvars = (v = KeepLast(),),
        meta = node_meta,
        free_energy = false,
    )
    return result.posteriors[:v]
end
"""
    PerformInference(θ; epochs = 1)

Alternate between inferring `q(v)` and taking one optimiser step on the kernel
hyperparameters, minibatch by minibatch, for `epochs` passes over the training set.

Within an epoch the posterior of `v` from one minibatch becomes the prior for the
next. At the start of every epoch that prior is reset to zero mean and covariance
`50 * I`, while the hyperparameters carry over from epoch to epoch.

Returns `(q_v, θ_optimal)`: the Gaussian marginal of `v` after the last epoch and
the optimised hyperparameter vector. The input `θ` is not modified. With
`epochs < 1` the initial prior and a copy of `θ` are returned.
"""
function PerformInference(θ; epochs = 1)
    θ_optimal = copy(θ)
    grad = similar(θ)          # gradient buffer, refilled for every minibatch

    # Result holders; they keep these values only if no epoch is run.
    μ_v_marginal = zeros(M)
    Σ_v_marginal = 50 * diageye(M)

    for _ in 1:epochs
        # Fresh prior on v at the beginning of each pass over the data.
        μ_v = zeros(M)
        Σ_v = 50 * diageye(M)

        for b in 1:nbatches
            xbatch = xtrain_minibatches[b]
            ybatch = ytrain_minibatches[b]

            # Step 1: posterior of v for this minibatch under the current prior.
            posterior_v = my_free_energy(
                θ_optimal; xbatch = xbatch, ybatch = ybatch, μ_v = μ_v, Σ_v = Σ_v
            )

            # Step 2: the posterior becomes the next prior, then update θ.
            μ_v, Σ_v = mean_cov(posterior_v)
            # Upper Cholesky factor of the second moment Σ_v + μ_v μ_vᵀ.
            second_moment = Σ_v + μ_v * μ_v'
            Rv = fastcholesky!(second_moment).U
            grad_llh_new!(
                grad,
                θ_optimal;
                y_data = ybatch,
                x_data = xbatch,
                v = μ_v,
                Uv = Rv,
                w = w_val,
                kernel = kernel_gp,
                Xu = Xu,
                chunk_size = 4,
            )
            Flux.Optimise.update!(optimizer, θ_optimal, grad)
        end

        μ_v_marginal = μ_v
        Σ_v_marginal = Σ_v
    end

    return MvNormalMeanCovariance(μ_v_marginal, Σ_v_marginal), θ_optimal
end

PerformInference (generic function with 1 method)

In [None]:
# Train starting from the initial hyperparameters θ_init. θ_opt only exists after
# this call (or after the load below), so passing θ_opt here raised an
# UndefVarError on a fresh kernel.
qv, θ_opt = PerformInference(θ_init; epochs = 500); #epochs = 500, approx 3h30min

# If you can't wait, load the stored optimal result instead:
# qv = load("../savefiles/qv_kin40k.jld")["qv"]
# θ_opt = load("../savefiles/params_optimal_kin40k.jld")["params_optimal"]

In [130]:
# Optimised hyperparameters on the positive scale used by kernel_gp:
# the first entry scales the kernel, the remaining entries are the lengthscales.
map(StatsFuns.softplus, θ_opt)

9-element Vector{Float64}:
 0.17636613718898136
 2.994391934274809
 2.905302600576806
 1.7401945529137626
 2.2697267449222425
 2.0114338358466854
 1.5824668119572332
 1.533898096437981
 2.052099122165972

## Prediction and validation

In [140]:
# Predictive distribution for a single test input.
# Invokes the UniSGP rule for the `:out` interface directly, with the test input and
# θ_opt wrapped as point masses and `qv` / `qw` supplied as the marginals of v and w.
function predict_new(x_test, qv, qw, θ_opt, meta)
    return @call_rule UniSGP(:out, Marginalisation) (
        q_in = PointMass(x_test),
        q_v = qv,
        q_w = qw,
        q_θ = PointMass(θ_opt),
        meta = meta
    )
end
# Kernel matrix of the inducing inputs at the optimised hyperparameters (with a small
# diagonal jitter) and its lower Cholesky factor; Kuu is factorised in place.
Kuu = kernelmatrix(kernel_gp(θ_opt), Xu) + 1e-8 * I
KuuL = fastcholesky!(Kuu).L

# Predictive means for the test set. The container is typed and pre-sized instead of
# an untyped `[]` (Vector{Any}), so downstream numeric code works on Float64 values.
predict_mean = Float64[]
sizehint!(predict_mean, Ntest)
for i in 1:Ntest
    # A fresh meta object is built for every test point, exactly as before.
    # NOTE(review): UniSGPMeta may carry mutable state (the `0` argument), so it is
    # deliberately not hoisted out of the loop — confirm before optimising.
    meta = UniSGPMeta(nothing, Xu, Ψ0, Ψ1_trans, Ψ2, KuuL, kernel_gp, Lu, 0, batch_size)
    prediction = predict_new(xtest[i], qv, PointMass(w_val), θ_opt, meta)
    append!(predict_mean, mean(prediction))
end

In [136]:
# Report the SMSE of the predictive means on the test targets.
print("SMSE of GP node prediction: ", SMSE(ytest, predict_mean), "\n")

SMSE of GP node prediction: 0.08343114079545057


In [137]:
# Persist the results: one JLD file per artefact, stored as (file, key, value).
for (file, key, value) in (
    ("../savefiles/qv_kin40k.jld", "qv", qv),
    ("../savefiles/params_optimal_kin40k.jld", "params_optimal", θ_opt),
    ("../savefiles/SMSE_kin40k.jld", "SMSE", SMSE(ytest, predict_mean)),
    ("../savefiles/Xu_kin40k.jld", "Xu", Xu),
)
    save(file, key, value)
end
# save("savefiles/qw_kin40k.jld","qw",qw)
# save("savefiles/FE_kin40k.jld","FE",FE)