In [1]:
using Turing
using Distances
using PyPlot
using Random
using LinearAlgebra

┌ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]
└ @ Base loading.jl:1423
┌ Info: Precompiling PyPlot [d330b81b-6aea-500a-939a-2ce795aea3ee]
└ @ Base loading.jl:1423


In [2]:
# Squared-exponential covariance: k(d) = exp(-d^2 / (2 * phi)), applied elementwise to
# the distance matrix `D`. Note that `phi` acts as a *squared* length scale here.
sqexp_cov_fn(D, phi) = @. exp(-D^2 / (2 * phi))
# Exponential covariance: k(d) = exp(-d / phi), applied elementwise to `D`;
# `phi` is the length scale.
exp_cov_fn(D, phi) = @. exp(-D / phi)
# Squared-exponential covariance with amplitude `n` and squared length scale `l`,
# plus a diagonal jitter `eps` (default 0.1) added through the identity `I`.
sqexp_cov_fn2(D, n, l, eps=1e-1) = @.(n * exp(-D^2 / (2 * l))) + eps * LinearAlgebra.I

sqexp_cov_fn2 (generic function with 2 methods)

In [None]:
@model function marginal_GP(y, X, m=0, s=1, cov_fn=sqexp_cov_fn)
    # Marginal GP regression: the latent function is integrated out, so y is modelled
    # directly as MvNormal with a constant mean and covariance cov_fn(D, phi) + sig2 * I.
    # `m`, `s` are the mean and std of the Normal prior on the constant mean `mu`.

    # Number of observations (rows of X); the predictor dimension is not needed.
    N = size(X, 1)

    # Pairwise Euclidean distances between the rows of X.
    D = pairwise(Distances.Euclidean(), X, dims=1)

    # Priors.
    mu ~ Normal(m, s)        # constant mean
    sig2 ~ LogNormal(0, 1)   # noise variance added to the diagonal
    phi ~ LogNormal(0, 1)    # range parameter forwarded to cov_fn

    # Realized covariance function, plus observation noise on the diagonal.
    K = cov_fn(D, phi)
    y ~ MvNormal(mu * ones(N), K + sig2 * LinearAlgebra.I(N))
end

In [None]:
# Seed the global RNG so the synthetic data (and the downstream sampling and
# prediction draws) are reproducible across notebook runs.
Random.seed!(0)
N = 150
X = randn(N, 1) * 3                  # N×1 predictor matrix
y = sin.(vec(X)) + randn(N) * 0.1    # noisy sine response
plt.scatter(vec(X), y)

In [None]:
# Draw 200 samples from the marginal GP posterior with static HMC
# (step size 0.01, 100 leapfrog steps per iteration).
# NOTE(review): no warm-up draws are explicitly discarded here; confirm whether the
# early draws should be dropped before they are used for prediction.
chain = let model = marginal_GP(y, X), sampler = HMC(0.01, 100)
    sample(model, sampler, 200)
end

In [None]:
# Extract the draws of each scalar parameter as a matrix by taking slice 1 of the
# third axis of the chain's underlying array (presumably the first chain — confirm).
mu, sig2, phi = (group(chain, name).value.data[:, :, 1] for name in (:mu, :sig2, :phi));

In [None]:
"""
    make_gp_predict_fn(Xnew, y, X, cov_fn)

Build a closure that draws one posterior-predictive sample at the rows of `Xnew`,
conditioning on the observations `y` at the rows of `X`.

The returned function has signature `(mu, sig2, phi) -> Vector`: `mu` is the constant
mean, `sig2` the noise variance added to the diagonal of the joint covariance, and `phi`
is forwarded to `cov_fn(D, phi)`. Pairwise distances are computed once here and reused
by every call.
"""
function make_gp_predict_fn(Xnew, y, X, cov_fn)
    M = size(Xnew, 1)
    # Stack the new points first: rows 1:M are "new", rows M+1:end are "observed".
    Z = [Xnew; X]
    D = pairwise(Euclidean(), Z, dims=1)
    inew = 1:M
    iobs = (M + 1):size(Z, 1)

    return function (mu, sig2, phi)
        K = cov_fn(D, phi) + sig2 * LinearAlgebra.I
        Knn = K[inew, inew]
        Kno = K[inew, iobs]
        # Factor the observed block instead of forming its explicit inverse: cheaper
        # and numerically more stable. Assumes cov_fn yields a PSD matrix, so the
        # block is positive definite for sig2 > 0.
        F = LinearAlgebra.cholesky(LinearAlgebra.Symmetric(K[iobs, iobs]))
        # Conditional mean: mu + Kno * Koo^{-1} * (y - mu).
        m = Kno * (F \ (y .- mu)) .+ mu
        # Conditional covariance: Knn - Kno * Koo^{-1} * Kon == Knn - V' * V with V = L \ Kon.
        V = F.L \ Kno'
        # Hermitian wrapper removes rounding asymmetry before building the MvNormal.
        S = Matrix(LinearAlgebra.Hermitian(Knn - V' * V))
        return rand(MvNormal(m, S))
    end
end

In [None]:
# Prediction inputs: wider spread than the training inputs, sorted for line plotting.
Xnew = sort(randn(N, 1) * 4, dims=1)
gp_predict = make_gp_predict_fn(Xnew, y, X, sqexp_cov_fn)
# One posterior-predictive draw per posterior sample of (mu, sig2, phi);
# eachindex also checks that the three draw arrays have matching shapes.
ynew = [gp_predict(mu[i], sig2[i], phi[i]) for i in eachindex(mu, sig2, phi)]
# Stack the draws as columns: (number of prediction points) × (number of samples).
# reduce(hcat, ...) avoids splatting a large collection into varargs.
ynew = reduce(hcat, ynew);

In [None]:
# Pointwise summaries of the predictive draws (rows = prediction points).
pred_mean = mean(ynew, dims=2)
pred_std = std(ynew, dims=2)
# A 95% credible band needs the 2.5% / 97.5% quantiles; ±1 std covers only ~68%.
pred_lo = [quantile(r, 0.025) for r in eachrow(ynew)]
pred_hi = [quantile(r, 0.975) for r in eachrow(ynew)]

plt.plot(vec(Xnew), vec(pred_mean), color="red", label="Posterior predictive mean")
plt.scatter(vec(X), vec(y), color="grey", label="Data")
plt.fill_between(vec(Xnew), pred_hi, pred_lo, color="red", alpha=0.2, label="95% credible interval")
plt.legend(loc="lower right")
plt.title("GP Posterior predictive with 95% credible interval");