In [None]:
using Turing
using Random
using Statistics
using Distributions
using StatsPlots
using LaTeXStrings
using DataFrames
using FFTW
using Dierckx
using Zygote
using Interpolations

In [None]:
# Use Zygote (reverse-mode AD) for the gradients HMC needs.
# NOTE(review): presumably chosen because build_field goes through FFTW's
# `ifft`, which ForwardDiff dual numbers cannot pass through — confirm before
# switching backends.
Turing.setadbackend(:zygote)

In [None]:
# Notebook-wide Plots.jl defaults: 12 pt tick and axis-label fonts,
# 10 pt legend text, line width 2 and marker size 8.
default(
    xtickfont = font(12),
    ytickfont = font(12),
    guidefont = font(12),
    legendfontsize = 10,
    lw = 2,
    ms = 8,
)

In [None]:
# Problem setup: N intervals on [0, 1] and observation-noise standard deviation γ.
N = 8;
γ = 0.01;

# Interior nodes j/N for j = 1, …, N-1 (the endpoints 0 and 1 are dropped).
x = LinRange(0, 1, N + 1)[2:end-1];

# Observe at the grid nodes themselves. An off-grid alternative would be the
# interior of 0:0.1:1.
x_data = copy(x);
n_data = length(x_data);

# True field that we wish to recover.
uᵗ(x) = x * (1 - x);

# Noisy observations y = uᵗ(x_data) + γ ε with ε ~ N(0, 1). The seed makes them
# reproducible; `randn.()` inside the fused broadcast draws a fresh value per point.
Random.seed!(500);
y_data = uᵗ.(x_data) .+ γ .* randn.()

In [None]:
# Compare the true field with the noisy observations. Without labels the
# legend would only show the generic "y1"/"y2".
plot(x, uᵗ.(x), label="Truth")
scatter!(x_data, y_data, label="Data")

In [None]:
"""
`build_field` - Build a mean zero Gaussian random field with the (-Δ)^{-α} covariance operator in dimension one

### Fields
`ξ`   - Vector of N(0,1) values  
`α=1` - Smoothness parameter
"""
# function build_field(ξ;α=one(eltype(ξ)))
#     N = length(ξ)
    
#     # uhat = zeros(complex(eltype(ξ)),2*N); # preallocate space

#     # construct the eigenvalues
#     k = 1:N;
#     λ = @. 1/(π*k)^(2*α);

#     # fill in the nonzero entries
#     # NOTE we need to multiply by 2 *N for FFT scaling
#     # @. uhat[2:N+1] = 2 * N * sqrt(λ) * sqrt(2) * ξ;
    
#     uhat = [0;  2 * N * sqrt.(λ) * sqrt(2) .* ξ; zeros(N - 1)]


#     # invert and get the relevant imaginary part
#     u = @views imag.(ifft(uhat))[N+2:end];
#     return u
# end

function build_field(ξ; α=one(eltype(ξ)))
    N = length(ξ)
    
    # construct the eigenvalues
    πk = π * (1:N);
    # NOTE we need to multiply by 2 *N for FFT scaling
    c = 2N * sqrt(2)
    umid = @. c * ξ / πk^α;
    uhat = [0; umid; zeros(N - 1)]

    # invert and get the relevant imaginary part
    u = @views imag.(ifft(uhat)[N+2:end]);
    return u
end

In [None]:
# Turing model for recovering the field from noisy point observations:
#   ξ ~ N(0, I),  u = build_field(ξ) on the nodes `x_grid`,
#   y_data[i] ~ N(u(x_data[i]), σ²), with u linearly interpolated between nodes.
# `x_grid` and `σ` default to the notebook globals `x` and `γ`. Existing calls
# `mean_recovery(x_data, y_data)` therefore behave as before, but the body no
# longer reads the globals `x`, `N` and `γ` directly.
@model function mean_recovery(x_data, y_data, x_grid=x, σ=γ)
    # build_field maps n coefficients to n - 1 interior values, so a grid with
    # length(x_grid) nodes needs length(x_grid) + 1 modes.
    n_modes = length(x_grid) + 1
    ξ ~ MvNormal(zeros(n_modes), 1.0)
    u = build_field(ξ)

    # LinearInterpolation does not extrapolate by default, so every x_data
    # point must lie within [first(x_grid), last(x_grid)].
    # NOTE(review): confirm that Zygote propagates gradients with respect to `u`
    # through the interpolant for the installed Interpolations version.
    u_spl = LinearInterpolation(x_grid, u)
    u_pred = u_spl(x_data)

    # eachindex over both arrays throws if x_data and y_data differ in length.
    # This avoids silently ignoring extra observations.
    for i in eachindex(x_data, y_data)
        y_data[i] ~ Normal(u_pred[i], σ)
    end
end

In [None]:
# The model takes the observation locations as well as the observed values.
# The one-argument call `mean_recovery(y_data)` has no matching method.
model = mean_recovery(x_data, y_data)

In [None]:
# Hamiltonian Monte Carlo with leapfrog step size 0.01 and 10 leapfrog steps
# per proposal, drawing 10_000 samples.
chain = sample(model, HMC(1e-2, 10), 10_000)

In [None]:
# Visual diagnostics of the sampled parameters (StatsPlots recipe for chains).
chain |> plot

In [None]:
# Posterior draws as a plain numeric matrix for manual post-processing.
# NOTE(review): assumes rows are iterations and columns are ξ[1], …, ξ[N] in
# order (parameter section only). Confirm for the installed MCMCChains version.
chain_array = Array(chain);

In [None]:
# Push every posterior draw of ξ (one row of chain_array) through build_field.
# The result holds posterior samples of the field on the grid x: one row per
# draw, one column per grid point. The size comes from chain_array and x, so the
# rows match the matrix actually indexed. This matters if several chains are
# stacked, where length(chain) would count only one chain's iterations.
u_post = zeros(size(chain_array, 1), length(x));
for i in axes(chain_array, 1)
    u_post[i, :] .= build_field(chain_array[i, :])
end

In [None]:
# Pointwise posterior mean and variance of the field, reducing over draws (rows)
# and flattening the 1×n results to plain vectors.
u_mean = vec(mean(u_post; dims=1));
u_var = vec(var(u_post; dims=1));

In [None]:
plt = plot(x, uᵗ.(x), label="Truth")
plot!(plt, x, u_mean, label="Posterior Mean", ribbon = sqrt.(u_var))
n_samples = 100;
Random.seed!(500);
k_samples  = rand(1:length(chain), n_samples);
for k in k_samples
    plot!(plt, x, u_post[k,:], alpha=0.1, color = "#BBBBBB", label="")
end
display(plt)