# More on Parameter Estimation

In [None]:
include("tools.jl")

In [None]:
# NOTE: 
#  N = number of basis functions (= maximal degree + 1), NOT the maximal degree!
#  M = number of observations

# chebbasis(x, N) = ... is defined in tools.jl 

# M i.i.d. samples cos(πt), t ~ U[0, 1): Chebyshev (arcsine) distributed in (-1, 1]
chebsamples(M) = map(t -> cos(pi * t), rand(M))

# M i.i.d. uniform samples in [-1, 1)
unifsamples(M) = map(u -> 2 * (u - 0.5), rand(M))

# the first N monomials evaluated at x: [1, x, x^2, …, x^(N-1)]
monobasis(x, N) = map(n -> x^n, 0:N-1)

"""
    designmatrix(X, N, basis)

Assemble the `length(X) × N` least-squares design matrix whose m-th row is
`basis(x, N)` for the m-th sample `x` of `X`.
"""
function designmatrix(X, N, basis)
    A = zeros(length(X), N)
    # fill row by row, pairing each row view with its sample point
    for (row, x) in zip(eachrow(A), X)
        row .= basis(x, N)
    end
    return A
end

"""
    lsqfit(X, F, N, basis)

Least-squares fit of the data `F` at the sample points `X` using the first `N`
functions of `basis`. Returns the fitted function
`x -> dot(basis(x, N), coefficients)`.
"""
function lsqfit(X, F, N, basis)
    # `\` with a rectangular design matrix returns the least-squares solution
    coeffs = designmatrix(X, N, basis) \ F
    return x -> dot(basis(x, N), coeffs)
end


In [None]:
# Showing some condition numbers! 
# cond(A) of the least-squares design matrix for the Chebyshev and the monomial
# basis, built on the SAME uniform samples, as the number N of basis functions
# grows (M = ⌈N^{3/2}⌉ observations).

NN = (2).^(3:9)
# alternative sample-size scaling:  MM = 2 * ceil.(Int, NN .* log.(NN))
MM = ceil.(Int, NN.^(3/2))

κc = Float64[]   # cond(A), Chebyshev basis
κm = Float64[]   # cond(A), monomial basis
for (N, M) in zip(NN, MM)
    X = unifsamples(M)
    push!(κc, cond(designmatrix(X, N, chebbasis)))
    push!(κm, cond(designmatrix(X, N, monobasis)))
end

plot(; xscale = :log10, yscale = :log10, 
        xlabel = "N", ylabel = L"\kappa", size = (400, 300), 
        legend = :topleft)
plot!(NN, κc, lw=3, m=:o, ms=6, label = "Chebyshev")
plot!(NN, κm, lw=3, m=:o, ms=6, label = "Monomials")

In [None]:
# ERROR IN LINEAR LEAST SQUARES  
# To demonstrate the numerical stability issues 
# we solve linear least squares with cheb basis
# and monomial basis and compare. Theoretically, 
# with infinite precision, the solutions should 
# be identical.

# samples = chebsamples 
samples = unifsamples 

f(x) = abs( (x - 0.7) * (x + 0.8) )^3
NN = (2).^(3:9)
MM = 2 * ceil.(Int, NN .* log.(NN))
xscale = :log10
rate = NN.^(-3.5)            # reference rate N^{-3.5}, drawn as a dashed line

# alternative test: analytic target, linear N axis, geometric reference rate
# f(x) = 1 / (1 + 100 * x^2)
# NN = 5:10:100
# MM = ceil.(Int, 2 * NN.^(3/2))
# xscale = :linear
# rate = 0.3*(1.1).^(-NN)

# root-mean-square error of p against f on 2000 fresh random samples
rmse(p) = (x = samples(2_000); norm(p.(x) - f.(x)) / sqrt(length(x)))

errc = Float64[]   # RMSE of the Chebyshev fits
errm = Float64[]   # RMSE of the monomial fits
for (N, M) in zip(NN, MM)
    X = samples(M)
    F = f.(X)
    pc = lsqfit(X, F, N, chebbasis)
    pm = lsqfit(X, F, N, monobasis)
    push!(errc, rmse(pc))
    push!(errm, rmse(pm))
end

plot(; xscale = xscale, yscale = :log10, size = (400, 300),
       xlabel = "N", ylabel = "RMSE")
plot!(NN, errc, lw=2, m=:o, ms=6, label = "chebyshev")
plot!(NN, errm, lw=2, m=:o, ms=6, label = "monomials")
plot!(NN, rate, c=:black, ls=:dash, label = "")

## More Parameter Estimation Examples

### Example 1:

$$
   \langle L_m, f \rangle = \int_{-1}^{1} x^m f(x) \,dx 
$$
We fit Chebyshev polynomials, but we observe moments with respect to the monomial basis. The transformation between these two bases is ill-conditioned, and this shows up in the conditioning of the least-squares problem.

In [None]:
using QuadGK

"""
    obs_moments(m, rtol=1e-6, atol=1e-8)

Return the observation functional `f -> ∫_{-1}^{1} f(x) x^m dx` (the m-th
monomial moment of `f` on [-1, 1]), computed numerically with `quadgk` using the
given relative / absolute tolerances. `f` may also be vector-valued (e.g. a whole
basis `x -> chebbasis(x, N)`), in which case the vector integral is returned.
"""
function obs_moments(m, rtol=1e-6, atol=1e-8)
    return function (f)
        integral, _errest = quadgk(x -> f(x) * x^m, -1, 1; rtol=rtol, atol=atol)
        return integral
    end
end

With the above implementation in mind, we can take an elegant functional approach to assembling the least-squares system. But for "real-world" problems (thousands to millions of data points and parameters) this would likely be inefficient and not a good approach.

In [None]:
"""
    lsqsys(f, fbasis, train)

Assemble the linear least-squares system `A θ ≈ Y` for fitting `f`.

* `f`      : target function
* `fbasis` : function evaluating the whole basis, `x -> [B_1(x), …, B_N(x)]`
* `train`  : non-empty, indexable list of observations (callable); `obs(g)`
             must return a number for the scalar `f` and a length-N vector for
             `fbasis`

Row `m` of `A` is `train[m](fbasis)` and `Y[m] = train[m](f)`; returns `(A, Y)`.
Throws an `ArgumentError` if `train` is empty.
"""
function lsqsys(f, fbasis, train)
    isempty(train) &&
        throw(ArgumentError("lsqsys: `train` must contain at least one observation"))
    M = length(train)
    # The first basis observation fixes the number of columns N. Keep it, so that
    # this (possibly expensive, e.g. quadrature) observation is not evaluated twice.
    row1 = train[1](fbasis)
    N = length(row1)
    A = zeros(M, N)
    Y = zeros(M)
    for (m, obs) in enumerate(train)
        Y[m] = obs(f)
        A[m, :] = m == 1 ? row1 : obs(fbasis)
    end
    return A, Y
end



In [None]:
# Example 1: M = 2N² moment observations, N Chebyshev basis functions.
NN = 3:2:15
MM = ceil.(Int, 2 * NN.^(2))
# left untyped (Any) on purpose: `[NN conds sig1]` is then a Matrix{Any},
# so the N column keeps its integer values in the table
conds = [] 
sig1 = []   # smallest singular value of A

for (N, M) in zip(NN, MM)
    fbasis = x -> chebbasis(x, N)
    train = [ obs_moments(m-1) for m= 1:M ]
    A, Y = lsqsys(f, fbasis, train)
    # one SVD serves both columns: cond(A) = σ_max / σ_min in the 2-norm
    σ = svdvals(A)
    push!(conds, maximum(σ) / minimum(σ))
    push!(sig1, minimum(σ))
end

ata_table([ NN conds sig1], ["N", "cond(A)", "σ1"] )

### Example 2

In this example the observations are of the form 
$$
    f \mapsto K^{-1} \sum_{k=1}^K f(x_k)
$$
where the $x_k$ are iid samples (below, $K = 10$ and the $x_k$ are Chebyshev-distributed). The idea is that we are no longer able to measure point values, but can only measure "groups of point values", or "averages of point values". The initial intuition is that averages cannot be inverted, and hence this might lead to an ill-conditioned parameter estimation problem.

In [None]:
"""
    obs_sumvals(k, frand = () -> 2*rand()-1)

Return an observation that AVERAGES its argument over `k` random points,
```
   g -> (1/k) * sum_{j=1}^k g(x_j)
```
The points `x_j = frand()` are drawn once, when the observation is created, and
reused on every call. By default, `frand` draws uniformly from [-1, 1). `g` may be
vector-valued (e.g. a whole basis), in which case the vectors are averaged.
Throws an `ArgumentError` if `k < 1`.
"""
function obs_sumvals(k, frand = () -> 2*rand()-1)
    k >= 1 || throw(ArgumentError("obs_sumvals needs k ≥ 1 sample points, got k = $k"))
    X = [ frand() for _ = 1:k ]
    return g -> sum(g.(X)) / length(X)
end


In [None]:
# Example 2: observations are averages of f over 10 random points drawn from
# the Chebyshev (arcsine) distribution via `randc`.
f(x) = 1 / (1 + 10 * x^2)
# rate = 0.3*(exp.(- asinh(1/sqrt(10)) * NN))
randc() = cos(pi * rand())

NN = 5:5:50
MM = ceil.(Int, 2 * NN.^(3/2))
# left untyped (Any) on purpose: `[NN conds sig1]` is then a Matrix{Any},
# so the N column keeps its integer values in the table
conds = [] 
sig1 = []   # smallest singular value of A


for (N, M) in zip(NN, MM)
    fbasis = x -> chebbasis(x, N)
    train = [ obs_sumvals(10, randc) for _=1:M ]
    A, Y = lsqsys(f, fbasis, train)
    # one SVD serves both columns: cond(A) = σ_max / σ_min in the 2-norm
    σ = svdvals(A)
    push!(conds, maximum(σ) / minimum(σ))
    push!(sig1, minimum(σ))
end

ata_table([ NN conds sig1], ["N", "cond(A)", "σ1"] )

### Example 3

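In the next example we explore what happens if we only observe the function $f$ on a subdomain $[-a, a]$ (by default $a = 0.5$). In principle this still fully determines the polynomial: a nonzero polynomial of degree at most $N-1$ has at most $N-1$ roots, so its values at $N$ distinct points identify it uniquely. The question is how well-conditioned this identification is.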

In [None]:
"""
    obs_subdom(a=0.5)

Return a point-evaluation observation
```
   g -> g(x_k)
```
where the single point ``x_k`` is drawn uniformly from ``[-a, a)`` (for `a > 0`)
when the observation is created, and is reused on every call.
"""
function obs_subdom(a=0.5)
    ξ = a * (2 * rand() - 1)
    return g -> g(ξ)
end



In [None]:

# Example 3: point values of f at random points of the subdomain [-0.5, 0.5].
NN = 5:10:50
MM = ceil.(Int, 2 * NN.^(4/3))
# left untyped (Any) on purpose: `[NN conds sig1]` is then a Matrix{Any},
# so the N column keeps its integer values in the table
conds = [] 
sig1 = []   # smallest singular value of A

for (N, M) in zip(NN, MM)
    fbasis = x -> chebbasis(x, N)
    train = [ obs_subdom() for _= 1:M ]
    A, Y = lsqsys(f, fbasis, train)
    # one SVD serves both columns: cond(A) = σ_max / σ_min in the 2-norm
    σ = svdvals(A)
    push!(conds, maximum(σ) / minimum(σ))
    push!(sig1, minimum(σ))
end

ata_table([ NN conds sig1], ["N", "cond(A)", "σ1"] )

### Regularisation

In a way, Example 1 is the most natural of the three, so we use it to explore regularisation.

In [None]:
# Regularisation test problem: target f(x) = 1/(1 + 4x²), N = 15 Chebyshev
# basis functions, M = 2N moment observations (as in Example 1).
f(x) = 1 / (1 + 4 * x^2)

N = 15
M = 2 * N

fbasis = x -> chebbasis(x, N)
train = [ obs_moments(m-1) for m= 1:M ]
# alternative observation models (Examples 2 and 3):
# train = [ obs_sumvals(10, randc) for m= 1:M ]
# train = [ obs_subdom() for _=1:M ]
A, Y = lsqsys(f, fbasis, train)
@show extrema(svdvals(A));

In [None]:
# Reference ("best") approximation: least-squares fit in the same N-term
# Chebyshev basis, but from point values at 100 Chebyshev-distributed samples.
X = chebsamples(100)
pbest = lsqfit(X, f.(X), N, chebbasis)
# max-norm error on the 1000-point grid cos(πt), t ∈ [0, 1]
xe = cos.(π * range(0, 1, length=1000))
besterr = norm(f.(xe) - pbest.(xe), Inf)
@show besterr;

In [None]:
# Tikhonov-regularised fits over a grid of regularisation parameters α and noise
# levels δ. For each pair we solve
#     min_θ ‖A θ - (Y + δ ξ)‖² + α² ‖θ‖²,   ξ ~ N(0, I),
# via the stacked least-squares system [A; α I] θ = [Y + δ ξ; 0], and record
# the max-norm error of the resulting polynomial on the grid xe.
Alpha = reverse((0.1).^(-1:0.33:7))
Delta = reverse((0.1).^(-1:0.33:7))
errs = zeros(length(Alpha), length(Delta))   # errs[iα, iδ]

Id = Matrix(I, N, N)   # loop-invariant: build the identity block once
fe = f.(xe)            # loop-invariant: target values on the error grid

for (ia, α) in enumerate(Alpha), (id, δ) in enumerate(Delta)
    Ar = [A; α * Id]
    Yr = [Y + δ * randn(length(Y)); zeros(N)]
    Θ = Ar \ Yr
    p = x -> dot(Θ, fbasis(x))
    errs[ia, id] = norm(fe - p.(xe), Inf)
end

In [None]:
# heatmap(x, y, z) expects z[i, j] to belong to (x[j], y[i]), but errs is
# indexed errs[iα, iδ]; transpose it so that α really is on the x-axis and δ on
# the y-axis (both grids have the same length, so no size error would flag this).
Plots.heatmap(Alpha, Delta, permutedims(log.(errs)), 
                xscale = :log10, yscale = :log10, size = (400,350),
                xlabel = L"\alpha", ylabel = L"\delta")

In [None]:
# Error vs. α for four noise levels δ (solid lines), each with a dashed
# reference level 0.05·sqrt(δ/σ1), where σ1 is the smallest singular value of A.
# The red line is the error `besterr` of the reference fit from point values.
plt = plot(; xscale = :log10, 
             yscale = :log10, size = (400, 300), 
             xlabel = L"\alpha", ylabel = "error") 
σ1 = minimum(svdvals(A))
for (ip, id) in enumerate([1, 5, 10, 15])
    δ = Delta[id]
    plot!(plt, Alpha, errs[:, id], c=ip, label = "δ = $(round(δ, digits=7))", lw=2)
    hline!(plt, [0.05*sqrt(δ/σ1)], c=ip, ls=:dash, label = "")
end
# add to `plt` explicitly (like the calls above) instead of the implicit current plot
hline!(plt, [besterr], label = "best", c=:red, lw=3)
plt