# Bayesian inversion for a diffusion coefficient
In this case, letting $p$ solve
$$
-\frac{d}{dx}\left(a(x)\frac{dp}{dx}\right) = f, \quad 0<x<1,\quad p(0) = p(1) =0,
$$
we have the data
$$
y_i = p(x_i) +\eta_i, \quad \eta_i \sim N(0, \gamma^2).
$$
We wish to learn the distribution of $a(x)$ from the data. In this example we take $a$ to be log-normal, with
$$
a(x) = e^{u(x)}
$$
and then put the $N(0,(-\Delta)^{-1})$ prior on $u$, where $\Delta$ is the Laplacian on $(0,1)$ with Dirichlet boundary conditions; samples of $u$ are then a.s. continuous and vanish at $x=0$ and $x=1$.

This example uses the `Turing` package for posterior sampling (here with Hamiltonian Monte Carlo, HMC). Because HMC relies on automatic differentiation, the random field generator had to be modified to be compatible with the AD backend (`Zygote`).

As the problem is implemented here, the prior $\mu_0$ is really placed on the coefficients $\xi_k$ of the Karhunen–Loève series expansion (KLSE) of $u$. Under the prior,
$$
u = \sum_{k=1}^\infty \sqrt{\lambda_k}\xi_k \varphi_k(x)
$$
where $\lambda_k = (k\pi)^{-2}$ and $\varphi_k(x) = \sqrt{2}\sin(k\pi x)$ are the eigenvalues and eigenfunctions of the covariance operator $(-\Delta)^{-1}$ (Dirichlet boundary conditions), and the $\xi_k\sim N(0,1)$ are i.i.d. After truncating to $N$ terms, the sampler returns $M$ samples $(\xi^{(n)})_{n=1}^{M}$ with each $\xi^{(n)} \sim \mu^y$, the posterior; under the posterior the components are generally no longer i.i.d. $N(0,1)$. Consequently, to compare the recovered $u$ with the true $u^\dagger$, we must push each posterior sample through the truncated series expansion.

This example also uses an interpolant, so the points $x_i$ at which we have measurements need not coincide with the numerical mesh.

In [None]:
using Turing
using Random
using Statistics
using Distributions
using StatsPlots
using LaTeXStrings
using FFTW
using Zygote
using DataInterpolations
using LinearAlgebra
using Printf

In [None]:
# Select Zygote as the automatic-differentiation backend that Turing uses for the
# gradients HMC needs; the random field generator below is written to be compatible with it.
Turing.setadbackend(:zygote)

In [None]:
# Notebook-wide plot defaults: 12-pt tick/guide fonts, 10-pt legend,
# line width 2 and marker size 8.
default(
    xtickfont = font(12),
    ytickfont = font(12),
    guidefont = font(12),
    legendfontsize = 10,
    lw = 2,
    ms = 8,
)

In [None]:
"""
    solve_bvp(a, f, Δx)

Solve `-(a p')' = f` with `p = 0` at both ends by centred finite differences on a
uniform mesh of spacing `Δx`.

# Arguments
- `a`: coefficient values at the `n` cell midpoints
- `f`: right-hand side at the `n - 1` interior nodes
- `Δx`: mesh spacing

# Returns
The vector of `p` at the `n - 1` interior nodes.
"""
function solve_bvp(a, f, Δx)
    Δx² = Δx^2
    # neighbouring interior nodes i and i+1 are coupled through the midpoint between them
    coupling = -a[2:end-1] / Δx²
    # each interior node sees the midpoint values on its left and right
    diagonal = (a[1:end-1] + a[2:end]) / Δx²
    A = diagm(-1 => coupling, 0 => diagonal, 1 => coupling)
    return A \ f
end

In [None]:
# Mesh: N uniform cells on [0,1]. `xx` are the N-1 interior nodes where the discrete
# solution lives; `x_mid` are the N cell midpoints where the coefficient is sampled.
N = 8;
γ = 0.001;  # observation noise standard deviation
x = LinRange(0,1,N+1)|>collect;
xx = x[2:end-1];
Δx = x[2]-x[1];
x_mid = x[1:end-1].+Δx/2;

# recover on all points (works because the interpolant below includes the boundary nodes)
# x_data = copy(x);

# recover on particular points
x_data = 0:0.2:1;
x_data = x_data[2:end-1]|>collect;

n_data = length(x_data);

# true value that we wish to recover
uᵗ(x) = x*(1-x);
aᵗ(x) = exp(uᵗ(x));
f(x) = x^2;
p = solve_bvp(aᵗ.(x_mid), f.(xx), Δx);

# generate noisy data
Random.seed!(500); # set a seed for reproducibility
# Interpolate over the full mesh, appending the known boundary values p(0) = p(1) = 0,
# so measurement points anywhere in [0,1] are handled without extrapolating past `xx`.
p_spl = LinearInterpolation([0; p; 0], x);
# `@.` also dots `randn()`, so it is evaluated inside the fused broadcast (one draw per datum)
y_data = @. p_spl(x_data) + γ * randn()

In [None]:
# Noise-free solution (including its zero boundary values) with the noisy observations.
plot(x, [0; p; 0], label = "Truth")
scatter!(x_data, y_data, label = "Data")
xlabel!("x")
ylabel!("p")

In [None]:
# this has been modified to be compatible with `Zygote` automatic differentiation

"""
    build_field(ξ; α=1.)

Build a sample of the mean-zero Gaussian random field with covariance operator
(-Δ)^{-α} on (0,1) (Dirichlet boundary conditions), evaluated on the uniform mesh
`j/n`, `j = 0,…,n`, where `n = length(ξ)`.

The sample is the truncated Karhunen–Loève expansion

    u(x) = ∑_{k=1}^n √λ_k ξ_k φ_k(x),   λ_k = (πk)^{-2α},   φ_k(x) = √2 sin(kπx),

evaluated with a length-`2n` inverse FFT.

# Arguments
- `ξ`: vector of KL coefficients (i.i.d. N(0,1) under the prior)
- `α=1.`: smoothness parameter

# Returns
A vector of length `n + 1` with `u` at the mesh nodes; both endpoint values are zero.
Note that the mode `k = n` vanishes at every node (sin(jπ) = 0), so `ξ[end]` does not
affect the returned values.
"""
function build_field(ξ; α=1.)
    n = length(ξ)
    # eigenvalues of the covariance operator
    k = 1:n
    λ = @. 1/(π*k)^(2*α)

    # Fourier coefficients for modes 1..n, zero-padded to length 2n.
    # The factor 2n cancels the 1/(2n) normalisation applied by `ifft`.
    uhat = [0; @. 2 * n * sqrt(λ) * sqrt(2) * ξ; zeros(n-1)]

    # imag(ifft(uhat))[j+1] = ∑_k √λ_k ξ_k φ_k(j/n); keep the interior nodes j = 1,…,n-1
    # (entries beyond n+1 would give the reflected, negated field) and append the zero
    # boundary values.
    u = @views [0; imag.(ifft(uhat)[2:n]); 0]
    return u
end

In [None]:
# Turing model for the posterior of the KL coefficients `ξ` of u = log a, given noisy
# observations `y_data` of the BVP solution at the points `x_data`.
# Relies on the notebook globals N, γ, x, xx, x_mid, Δx and f.
@model function diffusion_coefficient_recovery(x_data, y_data)
    ξ ~ MvNormal(zeros(N), Diagonal(ones(N)))
    u = build_field(ξ)
    # interpolate u onto the midpoints for the finite difference scheme
    u_spl = LinearInterpolation(u, x)
    u_mid = u_spl.(x_mid)
    # get the diffusion coefficient
    a = exp.(u_mid)
    # solve the differential equation on the interior nodes
    p = solve_bvp(a, f.(xx), Δx)
    # Interpolate onto the measurement points. Appending the boundary values
    # p(0) = p(1) = 0 avoids extrapolating beyond the interior nodes.
    p_spl = LinearInterpolation([0; p; 0], x)
    p_pred = p_spl.(x_data)
    # Independent N(p_pred[i], γ^2) noise, written as one multivariate observation:
    # same log-likelihood as a per-datum loop, but cheaper under reverse-mode AD.
    y_data ~ MvNormal(p_pred, Diagonal(fill(γ^2, length(x_data))))
    return nothing
end

In [None]:
# condition the model on the observed data
model = diffusion_coefficient_recovery(x_data, y_data)

In [None]:
# HMC with step size 0.01 and 10 leapfrog steps per proposal, 10^4 draws
chain = let step_size = 0.01, n_leapfrog = 10, n_draws = 10^4
    sample(model, HMC(step_size, n_leapfrog), n_draws)
end

In [None]:
# diagnostic plots of the sampled parameters
chain |> plot

In [None]:
# Plain matrix of draws; each row is later passed to `build_field` as one ξ sample.
chain_array = chain |> Array;

In [None]:
# NOTE(review): expected to be (number of draws, N) — confirm against the chain
chain_array |> size

In [None]:
# Push each posterior ξ sample through the KL series: row i holds u at the mesh nodes.
u_post = zeros(length(chain), N + 1);
for (i, row) in enumerate(eachrow(u_post))
    row .= build_field(chain_array[i, :])
end

In [None]:
# Transform into a(x) = exp(u(x)) samples. Row i of `u_post` already holds
# build_field(chain_array[i, :]), so exponentiate it instead of rebuilding every field.
a_post = exp.(u_post);

In [None]:
# Pointwise posterior moments: reduce over the sample dimension (rows) and flatten.
u_mean = vec(mean(u_post; dims = 1));
u_var  = vec(var(u_post; dims = 1));

a_mean = vec(mean(a_post; dims = 1));
a_var  = vec(var(a_post; dims = 1));

In [None]:
# true u against the posterior mean, with a ±1 standard deviation band
plt = plot(x, uᵗ.(x); label = "Truth", xlabel = "x", ylabel = "u")
plot!(plt, x, u_mean; label = "Posterior Mean", ribbon = sqrt.(u_var))

In [None]:
# true diffusion coefficient against the posterior mean, with a ±1 standard deviation band
plt = plot(x, aᵗ.(x); label = "Truth", xlabel = "x", ylabel = "a = exp(u)")
plot!(plt, x, a_mean; label = "Posterior Mean", ribbon = sqrt.(a_var))