Skip to content

Commit

Permalink
Fix aaa truncation code to correctly recalculate the weights at the n…
Browse files Browse the repository at this point in the history
…ew size and add unit test for that. Update README.md with note about ForwardDiff not working correctly at a support point.

Signed-off-by: Don MacMillen <don@macmillen.net>
  • Loading branch information
macd committed Jul 22, 2023
1 parent d345916 commit 841aa58
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 28 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ julia> df3 = x -> -cos(x) + 2exp(x)
julia> df3(1.23)
6.508221345454844
```


NB: ForwardDiff does not play well with BaryRational because when we interpolate at
a support point, we just return the initial function value there. ForwardDiff recognizes
this as a constant and returns derivative of a constant, which is zero. There is
special handling in the algorithm of [3] for calculating the derivatives at support points
and that is implemented here.

The AAA algorithm is adaptive in the subset of support points that it
chooses to use.

Expand Down
65 changes: 38 additions & 27 deletions src/aaa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,26 @@ function aaa(Z::AbstractVector{T}, F::S; tol=1e-13, mmax=100, verbose=false,
aaa(Z, F.(Z), tol=tol, mmax=mmax, verbose=verbose, clean=clean)
end

function compute_weights(m, J, C::S, A::S) where {T, S <: AbstractMatrix{T}}
if length(J) >= m # The usual tall-skinny case
# Notice that A[J, :] selects only the non-support points and it is
# those points that we will use for a least squares fit.
G = svd(A[J, :]) # Reduced SVD (the default)
s = G.S
mm = findall(==(minimum(s)), s) # Treat case of multiple min sing val
nm = length(mm)
w = G.V[:, mm] * (ones(T, nm) ./ sqrt(nm)) # Aim for non-sparse wt vector
elseif length(J) >= 1
V = nullspace(A[J, :]) # Fewer rows than columns
nm = size(V, 2)
w = V * ones(T, nm) ./ sqrt(nm) # Aim for non-sparse wt vector
else
w = ones(T, m) ./ sqrt(m) # No rows at all
end
return w
end


"""aaa rational approximation of data F on set Z
r = aaa(Z, F; tol, mmax, verbose, clean)
Expand Down Expand Up @@ -71,13 +91,13 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,

M = length(Z) # number of sample points
mmax = min(M, mmax) # max number of support points

abstol = tol * norm(F, Inf)
verbose && println("\nabstol: ", abstol)

F, Z = promote(F, Z)
T = promote_type(S, U)

J = [1:M;]
z = T[] # support points
f = T[] # function values at support points
Expand All @@ -88,8 +108,10 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
errvec = T[]
R = fill(mean(F), size(F))
m = 1
jtrunc = Int[]
@inbounds for outer m in 1:mmax
j = argmax(abs.(F .- R)) # select next support point
push!(jtrunc, j) # save index incase we need to truncate later
push!(z, Z[j])
push!(f, F[j])
deleteat!(J, findfirst(isequal(j), J)) # update index vector
Expand All @@ -100,21 +122,7 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
# Loewner matrix
A = hcat(A, (F .- f[end]) .* C[:, end])

# Compute weights:
if length(J) >= m # The usual tall-skinny case
# Notice that A[J, :] selects only the non-support points
G = svd(A[J, :]) # Reduced SVD (the default)
s = G.S
mm = findall(==(minimum(s)), s) # Treat case of multiple min sing val
nm = length(mm)
w = G.V[:, mm] * (ones(T, nm) ./ sqrt(nm)) # Aim for non-sparse wt vector
elseif length(J) >= 1
V = nullspace(A[J, :]) # Fewer rows than columns
nm = size(V, 2)
w = V * ones(T, nm) ./ sqrt(nm) # Aim for non-sparse wt vector
else
w = ones(T, m) ./ sqrt(m) # No rows at all
end
w = compute_weights(m, J, C, A)

# Don't use the zero weights when calculating the approximation at the
# support points
Expand All @@ -136,11 +144,14 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
# If we've gone to max iters, then it is possible that the best
# approximation is at a smaller vector size. If so, truncate the
# approximation which will give us a better approximation that is
# faster to compute.
# faster to compute. Note that we must truncate C and A, reset J, and then
# recompute the weights for this smaller size.
if m == mmax
verbose && println("Hit max iters. Truncating approximation.")
idx = argmin(i -> real(errvec[i]), eachindex(errvec))
for v in (z, f, w, errvec)
verbose && println("Hit max iters. Truncating approximation at $idx.")
J = deleteat!([1:M;], sort(jtrunc))
w = @views compute_weights(idx, J, C[:,1:idx], A[:, 1:idx])
for v in (z, f, errvec)
deleteat!(v, idx+1:mmax)
end
end
Expand All @@ -150,7 +161,7 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
for v in (z, f, w, errvec)
deleteat!(v, izero)
end

# We must sort if we plan on using bary rather than reval, _but_ this
# will not work when z is complex
if do_sort
Expand All @@ -168,7 +179,7 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
ii = findall(abs.(res) .< 1e-13) # find negligible residues
length(ii) != 0 && cleanup!(r, pol, res, zer, Z, F)
end

return r
end

Expand All @@ -185,10 +196,10 @@ function prz(r::AAAapprox)
pol, _ = eigen(E, B)
pol = pol[isfinite.(pol)]
dz = T(1//100000) * exp.(2im*pi*[1:4;]/4)

# residues
res = r(pol .+ transpose(dz)) * dz ./ 4

E = [0 transpose(w .* f); ones(T, m) diagm(z)]
zer, _ = eigen(E, B)
zer = zer[isfinite.(zer)]
Expand All @@ -214,7 +225,7 @@ function reval(zz, z, f, w)
CC = 1.0 ./ (zv .- transpose(z)) # Cauchy matrix
r = (CC * (w .* f)) ./ (CC * w) # AAA approx as vector
r[isinf.(zv)] .= sum(f .* w) ./ sum(w)

ii = findall(isnan.(r)) # find values NaN = Inf/Inf if any
@inbounds for j in ii
# Wow, linear search, but only if a NaN happens
Expand All @@ -229,14 +240,14 @@ end

# Only calculate the updated z, f, and w
# FIXME: Change the hardcoded tolerance in this function
function cleanup!(r, pol, res, zer, Z, F)
function cleanup!(r, pol, res, zer, Z, F; verbose=false)
z, f, w = r.x, r.f, r.w
m = length(z)
M = length(Z)
ii = findall(abs.(res) .< 1e-13) # find negligible residues
ni = length(ii)
ni == 0 && return
println("$ni Froissart doublets. Number of residues = ", length(res))
verbose && println("$ni Froissart doublets. Number of residues = ", length(res))

# For each spurious pole find and remove closest support point:
@inbounds for j = 1:ni
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ include("test_deriv.jl")
@test test_aaa_airy_prime()
@test test_2nd_derivative()
@test test_runge_derivs()
@test test_truncation()
end
end
11 changes: 11 additions & 0 deletions test/test_deriv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,14 @@ function test_runge_derivs(tol=1e-10)

return err1 < tol && err2 < tol && err3 < 1e-5
end

function test_truncation()
xbig = BigFloat.([-1//1:1//100:1//1;])
fbig = sin.(xbig);
# The min error appears at m=25, so this will test the truncation
sf = aaa(xbig, fbig, mmax=30, clean=false, tol=BigFloat(1/10^40));
# This also tests at support points
xtest = BigFloat.([-1//1:1//1000:1//1;])
error = norm(sin.(xtest) - sf.(xtest), Inf)
return error < BigFloat(1//10^30)
end

0 comments on commit 841aa58

Please sign in to comment.