# Iteratively Reweighted Least Squares

In [None]:
using LinearAlgebra, Printf, Plots

In [None]:
"""
    wlsq(A, y, w)

Solve the weighted least-squares problem `min_x Σᵢ w[i] * (y[i] - (A*x)[i])^2`
by scaling the rows of `A` and `y` with `sqrt.(w)` and solving the scaled
system with a QR factorisation.
"""
function wlsq(A, y, w)
    sqrtw = Diagonal(sqrt.(w))      # row scaling √wᵢ
    Aw, yw = sqrtw * A, sqrtw * y
    return qr(Aw) \ yw
end

"""
    irlsq(A, y; tol=1e-5, maxnit=100, γ=1.0, γmin=1e-6, verbose=true)

Approximate the minimax (ℓ∞) solution of `A * x ≈ y` with a Lawson-type
iteratively reweighted least-squares scheme.

Each iteration solves a weighted least-squares problem with the current weights
`w` (via `wlsq`). It then multiplies every weight by `|rᵢ|^γ + 1e-15`, where
`r = y - A*x`; the `1e-15` keeps weights from becoming exactly zero. Finally it
renormalises the weights so they sum to one.

# Keywords
- `tol`: stop once the max-residual changes by at most `tol` relative to the
  current max-residual between two consecutive iterations. With `tol = 0` the
  loop only stops early if the residual repeats exactly.
- `maxnit`: maximum number of iterations.
- `γ`: exponent applied to the residual magnitudes in the weight update.
- `γmin`: currently unused; accepted for backward compatibility.
- `verbose`: print one table row per iteration (max-residual, weight range).

# Returns
`(x, w, res)`:
- `x`: the last least-squares solution.
- `w`: the weights after the final update.
- `res`: `norm(y - A*x, Inf)` for the returned `x`. It is `1e300` and `x` is
  `zeros(N)` when `maxnit == 0`.

# Throws
`DimensionMismatch` if `size(A, 1) != length(y)`.
"""
function irlsq(A, y; tol=1e-5, maxnit = 100, γ = 1.0, γmin = 1e-6, verbose=true)
    M, N = size(A)
    M == length(y) ||
        throw(DimensionMismatch("A has $M rows but y has length $(length(y))"))
    w = ones(M) / M
    res = 1e300
    x = zeros(N)
    verbose  && @printf("  n   | ||f-p||_inf |  extrema(w) \n")
    verbose  && @printf("------|-------------|---------------------\n")
    for nit = 1:maxnit
        x = wlsq(A, y, w)

        r = y - A * x                  # residual, reused by the weight update
        resnew = norm(r, Inf)
        verbose  && @printf(" %4d |   %.2e  |  %.2e  %.2e \n", nit, resnew, extrema(w)...)

        # Lawson-type update: put more weight where the residual is large.
        w = w .* (abs.(r) .^ γ .+ 1e-15)
        w ./= sum(w)

        # Stop once the max-residual has stagnated (relative change ≤ tol).
        converged = abs(res - resnew) <= tol * resnew
        res = resnew
        converged && break
    end
    return x, w, res
end

In [None]:
using FFTW
# First, implement the fast Chebyshev transform and helpers for evaluating Chebyshev bases.

"""
    revchebnodes(N)

Return the `N+1` Chebyshev extreme points `cos(jπ/N)`, `j = 0:N`, in decreasing
order (from `1` down to `-1`).
"""
revchebnodes(N) = map(j -> cos(j * π / N), 0:N)

"""
    fct(F)

Fast Chebyshev transform. Given samples `F[j+1] = p(cos(jπ/N))`, `j = 0:N`
(i.e. values at `revchebnodes(N)`), return the coefficients `c` of the
degree-`N` interpolant `p(x) = Σₖ c[k+1] * T_k(x)`.

The samples are extended evenly to length `2N`, so that one inverse FFT computes
the cosine sums. Interior coefficients are doubled; the first and last are not.
Only the real part is kept, so `F` is assumed to be real.

A single sample (`N = 0`) is the constant interpolant and is returned as
`[float(real(F[1]))]`.

# Throws
`ArgumentError` if `F` is empty.
"""
function fct(F)
    isempty(F) && throw(ArgumentError("fct needs at least one sample"))
    N = length(F)-1
    # Degree 0: the general path would emit Ĝ[1] twice (Ĝ[1] === Ĝ[N+1]).
    N == 0 && return [float(real(F[1]))]
    G = [F; F[N:-1:2]]                 # even extension, length 2N
    Ĝ = real.(ifft(G))
    return [Ĝ[1]; 2 * Ĝ[2:N]; Ĝ[N+1]]
end

"""
    cheb_basis(x, N)

Return `[T_0(x), T_1(x), …, T_N(x)]`, the first `N+1` Chebyshev polynomials of
the first kind evaluated at `x`. They are computed with the three-term recurrence
`T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x)` into a vector of element type `typeof(x)`.

`N = 0` returns `[one(x)]`.

# Throws
`ArgumentError` if `N < 0`.
"""
function cheb_basis(x::T, N) where {T}
    N >= 0 || throw(ArgumentError("degree N must be non-negative, got $N"))
    B = zeros(T, N+1)
    B[1] = one(T)
    N == 0 && return B                 # no room for T_1 when N == 0
    B[2] = x
    for k = 2:N
        B[k+1] = 2 * x * B[k] - B[k-1]
    end
    return B
end

"""
    eval_chebpoly(F̃, x)

Evaluate the Chebyshev series `Σₖ F̃[k+1] * T_k(x)` at the point `x`. The sum is
computed with `dot`, so complex coefficients would be conjugated.
"""
function eval_chebpoly(F̃, x)
    degree = length(F̃) - 1
    basis = cheb_basis(x, degree)
    return dot(F̃, basis)
end

In [None]:
# Example 1: fit f(x) = |sin(x)|^3 with the Nb Chebyshev polynomials T_0..T_{Nb-1},
# sampled at Nx Chebyshev points, using 20 IRLSQ iterations.
Nx = 1_000
Nb = 10
f = x -> abs(sin(x))^3

X = revchebnodes(Nx-1)          # Nx points, from 1 down to -1
y = f.(X)
A = zeros(Nx, Nb)               # design matrix: A[n, k] = T_{k-1}(X[n])
for (n, t) in enumerate(X)
    A[n, :] = cheb_basis(t, Nb-1)
end

x, w, res = irlsq(A, y; maxnit = 20);

In [None]:
# Three stacked panels for the fit above:
#   P0: the target f on the sample points,
#   P1: the final IRLSQ weights w,
#   P2: the error y - A*x with horizontal lines at ±max|error|
#       (a near-minimax fit is expected to touch both lines repeatedly).
P0 = plot(X, f.(X), lw = 2, label = "f")
P1 = plot(X, w, lw=2, label = "w")
P2 = plot(X, y - A * x, lw=2,  label = "err")
emax = norm(y - A * x, Inf)
plot!(P2, [-1,1], [emax, emax], lw=2, c=2, label="±max-err")
plot!(P2, [-1,1], [-emax, -emax], lw=2, c=2, label="")
plot(P0, P1, P2, layout = grid(3,1), size = (500, 600))

In [None]:
# Example 2: Fermi-Dirac function 1/(1 + exp(βx)) with β = 100, a smoothed step
# from 1 (x < 0) to 0 (x > 0). It is fitted with Nb = 30 Chebyshev polynomials on
# Nx = 10_000 Chebyshev points, using 20 IRLSQ iterations.
Nx = 10_000
Nb = 30
β = 100
# alternative smooth test function: f = x -> sin( (1+x) * π/4 )
f = x ->  1/(1 + exp(β * x))

X = revchebnodes(Nx-1)          # Nx points, from 1 down to -1
y = f.(X)
A = zeros(Nx, Nb)               # design matrix: A[n, k] = T_{k-1}(X[n])
for (n, t) in enumerate(X)
    A[n, :] = cheb_basis(t, Nb-1)
end

x, w, res = irlsq(A, y; maxnit=20);

In [None]:
# Same three-panel view for the Fermi-Dirac fit:
#   P0: the target f, P1: the final IRLSQ weights w,
#   P2: the error y - A*x with horizontal lines at ±max|error|.
P0 = plot(X, f.(X), lw = 2, label = "f")
P1 = plot(X, w, lw=2, label = "w")
P2 = plot(X, y - A * x, lw=2,  label = "err")
emax = norm(y - A * x, Inf)
plot!(P2, [-1,1], [emax, emax], lw=2, c=2, label="±max-err")
plot!(P2, [-1,1], [-emax, -emax], lw=2, c=2, label="")
plot(P0, P1, P2, layout = grid(3,1), size=(500, 600))

In [None]:
using Remez, LaTeXStrings

"""
    chebnodes(N)

Return the `N+1` Chebyshev extreme points `cos(jπ/N)` in increasing order
(`j = N, N-1, …, 0`, i.e. from `-1` up to `1`).
"""
function chebnodes(N)
    js = reverse(0:N)
    return [cos(j * π / N) for j in js]
end
"""
    bary(f::Function, N, x)

Interpolate `f` at the `N+1` ascending Chebyshev points `chebnodes(N)` and
evaluate the interpolant at `x` (a scalar or an array) by barycentric formula.
"""
function bary(f::Function, N, x)
    nodes = chebnodes(N)
    fvals = f.(nodes)
    return bary(fvals, x; X = nodes)
end
"""
    bary(F::Vector, x; X = chebnodes(length(F)-1))

Evaluate, at the point(s) `x`, the polynomial that interpolates the values `F`
at the nodes `X`. By default `X` is the ascending Chebyshev points, with
`N = length(F) - 1`.

The barycentric formula `p/q` is used, with weights `(-1)^n` halved at the two
endpoints (the weights for Chebyshev extreme points). `x` may be a scalar or an
array; the result follows the shape of `x`.

When `x` coincides exactly with a node `X[k]`, the raw formula gives
`Inf/Inf = NaN`. Such points return the data value `F[k]` instead.
"""
function bary(F::Vector, x; X = chebnodes(length(F)-1))
    N = length(F)-1
    p = 0.5 * ( F[1] ./ (x .- X[1]) + (-1)^N * F[N+1] ./(x .- X[N+1]) )
    q = 0.5 * (1.0 ./ (x .- X[1]) + (-1)^N ./ (x .- X[N+1]))
    for n = 1:N-1
        p += (-1)^n * F[n+1] ./ (x .- X[n+1])
        q += (-1)^n ./ (x .- X[n+1])
    end
    # Only a NaN can come from an exact node hit; look the node up in that case.
    function atnode(xi, vi)
        isnan(vi) || return vi
        k = findfirst(==(xi), X)
        return k === nothing ? vi : oftype(vi, F[k])
    end
    return atnode.(x, p ./ q)
end
"""
    errgrid(Np)

Return `Np` equispaced points on `[-1 + 0.0123, 1 - 0.00321]` for measuring
sup-norm errors. The small asymmetric offsets are presumably chosen so that grid
points do not land exactly on ±1 or on interpolation nodes (TODO confirm).
"""
function errgrid(Np)
    lo = -1 + 0.0123
    hi = 1 - 0.00321
    return range(lo, stop = hi, length = Np)
end

"""
    firlsq(f, Nx, Nb)

Sample `f` at the `Nx` points `revchebnodes(Nx-1)` and fit it by IRLSQ with the
`Nb` Chebyshev polynomials `T_0, …, T_{Nb-1}` (degree `Nb-1`). Runs
`irlsq(…; maxnit=100, verbose=false)` and returns its `(x, w, res)` tuple.
"""
function firlsq(f, Nx, Nb)
    nodes = revchebnodes(Nx - 1)
    A = zeros(Nx, Nb)
    for (row, t) in zip(eachrow(A), nodes)
        row .= cheb_basis(t, Nb - 1)
    end
    return irlsq(A, f.(nodes); maxnit = 100, verbose = false)
end

In [None]:
# Chebyshev interpolation vs Remez vs IRLSQ for the Fermi-Dirac example
# ---------------------------------------------------------------------
# All three errors are plotted against the polynomial degree N:
#   cheb : interpolant through the N+1 Chebyshev points (bary), error on errgrid;
#   Remez: ratfn_minimax(f, (-1, 1), N, 0), whose 3rd output is used as the error;
#   IRLSQ: fit with N+1 basis functions T_0..T_N (degree N); 3rd output = max residual.
β = 100
f = x -> 1/(1+exp(β*x))

NN = 10:10:100
Nr = 5:5:60

xerr = errgrid(10_000)
err = [ norm(f.(xerr) - bary(f, N, xerr), Inf)  for N in NN ]
errremez = [ ratfn_minimax(f, (-1, 1), N, 0)[3] for N in Nr ]
# firlsq(f, Nx, Nb) fits with Nb basis functions (degree Nb-1): pass N+1 for degree N.
errirlsq = [ firlsq(f, N * 100, N + 1)[3]       for N in Nr ]
P = plot(xaxis = (L"N (degree)",), 
         yaxis = (:log, L"\| f - I_{NM} f\|_{L^\infty}"))
plot!(NN, err, lw=2, m=:o, label = "cheb")
plot!(Nr, errremez, lw=2, m=:o, label = "Remez")
plot!(Nr, errirlsq, lw=2, m=:o, label = "IRLSQ")

In [None]:
# Remez vs IRLSQ for the Fermi-Dirac example: table of max errors by degree N
# ---------------------------------------------------------------------------
β = 100
f = x -> 1/(1+exp(β*x))
Nr = 20:20:200
@printf("   N  |  remez    irlsq \n")
@printf("------|--------------------\n")
for N in Nr 
    # Best effort: if a method throws (e.g. Remez not converging for large N),
    # report NaN for that entry instead of aborting the table. A user interrupt
    # is still propagated so the long loop can be stopped.
    err_remez = try 
        ratfn_minimax(f, (-1, 1), N, 0)[3]
    catch e
        e isa InterruptException && rethrow()
        NaN 
    end
    # firlsq(f, Nx, Nb) fits with Nb basis functions (degree Nb-1): pass N+1 for degree N.
    err_irlsq = try
        firlsq(f, N * 100, N + 1)[3] 
    catch e
        e isa InterruptException && rethrow()
        NaN
    end
    @printf(" %4d | %.2e  %.2e \n", N, err_remez, err_irlsq)
end
