In [None]:
using Turing
using DifferentialEquations
using DataFrames
using Random
using Distributions
using Plots
using StatsPlots
using LaTeXStrings

In [None]:
# Global plot theme: 14-pt tick and axis-guide fonts, 12-pt legend text,
# line width 2 and marker size 8 for every subsequent plot.
default(
    xtickfont = font(14),
    ytickfont = font(14),
    guidefont = font(14),
    legendfontsize = 12,
    lw = 2,
    ms = 8,
)

In [None]:
# In-place right-hand side of the first-order decay ODE  du/dt = -λ u.
#
# `p` carries the rate λ in its first slot; it may be a plain number (as in the
# initial ODEProblem) or a one-element vector (as produced by `remake` in the model).
# `t` is unused because the system is autonomous. Writes the derivative into `du[1]`
# and, like an assignment expression, evaluates to that derivative value.
function decay_problem(du, u, p, t)
    rate = first(p)
    du[1] = -rate * u[1]
end

In [None]:
# Initial amount u(0) = 10 and the "true" decay rate used to simulate the data.
u0 = [10.0];
pᵗ = 0.3;
# Decay problem on t ∈ [0, 10], solved adaptively with Tsit5 as the reference curve.
prob_decay = ODEProblem(decay_problem, u0, (0.0, 10.0), pᵗ)
sol = solve(prob_decay, Tsit5())
plot(sol)

In [None]:
# Observation grid spacing; the model re-solves on this same grid.
Δt = 0.5;
sol_data = solve(prob_decay, Tsit5(); saveat = Δt)
# Standard deviation of the additive Gaussian measurement noise.
γ = 1;
Random.seed!(100);
# Noisy observations: exact solution on the grid plus γ-scaled standard normal noise.
y_data = let clean = Array(sol_data)
    clean .+ γ .* randn(size(clean))
end;
plot(sol)
scatter!(sol_data.t, permutedims(y_data))

In [None]:
# Turing model for the decay rate `k` of `prob_decay`, given noisy observations `y_data`.
#
# - `k` has a `LogNormal()` prior.
# - For each draw, the ODE is re-solved with `p = [k]` on the grid `saveat = Δt`.
# - Each observation is scored as `Normal(prediction, γ)`.
#
# `Δt` and `γ` are read from the enclosing (global) scope, so they must match the
# values used to generate `y_data`. `y_data` is indexed linearly, so a 1×N matrix
# (as built from `Array(sol_data)`) or a length-N vector both work.
@model function bayes_ode(y_data, prob_decay)
    k ~ LogNormal()
    p = [k]
    prob_ = remake(prob_decay, p = p)
    y_pred = solve(prob_, Tsit5(), saveat = Δt)
    # If the solver stops early (e.g. for an extreme k), it returns fewer states than
    # there are observations. Scoring only the states it produced would silently
    # drop data from the likelihood, so reject the draw instead.
    if length(y_pred.u) != length(y_data)
        Turing.@addlogprob! -Inf
        return nothing
    end
    # `y_pred.u[i]` is the state vector at the i-th save time; the system has one state.
    for i in eachindex(y_pred.u)
        y_data[i] ~ Normal(y_pred.u[i][1], γ)
    end
end

In [None]:
# Condition the model on the synthetic observations `y_data`; `prob_decay` is the ODE
# that the model re-solves (via `remake`) for each proposed decay rate.
model = bayes_ode(y_data, prob_decay)

In [None]:
# Posterior draws of k: Hamiltonian Monte Carlo with step size 0.01 and
# 100 leapfrog steps per proposal, 10 000 iterations.
chain = sample(model, HMC(0.01, 100), 10_000)

In [None]:
# Plot the sampled chain for `k` as a quick visual check of mixing/convergence.
plot(chain)

In [None]:
# Same number of draws taken from the prior alone, for comparison with the posterior.
chain_prior = sample(model, Prior(), 10_000);


In [None]:
# Overlay posterior and prior draws of `k` in the window around the true rate.
#
# Both histograms share one set of bins. If each picked its own bins, the prior
# (spread across the whole LogNormal support) would appear as a few very wide bars
# inside 0.2–0.4 and could not be compared bar-for-bar with the posterior. Both
# chains have the same number of draws, so raw counts are directly comparable.
let bins = range(0.2, 0.4; length = 41)
    histogram(Array(chain); bins = bins, alpha = 0.6, label = "Posterior")
    histogram!(Array(chain_prior); bins = bins, alpha = 0.6, label = "Prior")
    xlims!(0.2, 0.4)
end

In [None]:
# Posterior predictive check: re-solve the ODE for a random subset of posterior draws
# of k and overlay them (faint grey) on the reference solution, on a log10 y-axis.
n_samples = 100;
Random.seed!(500);
# Draw indices with replacement from the chain's iterations.
k_samples = Array(chain)[rand(1:length(chain), n_samples)]
plt = plot(sol, yscale = :log10)
for kᵢ in k_samples
    plot!(plt, solve(remake(prob_decay, p = [kᵢ]), Tsit5());
          alpha = 0.1, color = "#BBBBBB", label = "")
end
display(plt)