Skip to content

Commit

Permalink
Merge pull request #9 from lxvm/main
Browse files Browse the repository at this point in the history
fix for Float32 precision
  • Loading branch information
macd committed Aug 22, 2023
2 parents 3bdd842 + 30042da commit 4ff462c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
42 changes: 21 additions & 21 deletions src/aaa.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# AAA algorithm from the paper "The AAA Alorithm for Rational Approximation"
# by Y. Nakatsukasa, O. Sete, and L.N. Trefethen, SIAM Journal on Scientific
# by Y. Nakatsukasa, O. Sete, and L.N. Trefethen, SIAM Journal on Scientific
# Computing, 2018
using Printf

Expand All @@ -11,7 +11,7 @@ struct AAAapprox{T <: AbstractArray} <: BRInterp
end

# In this version zz can be a scalar or a vector. BUT: do not make the mistake
# of broadcasting (ie a.(zz)) when zz is a vector. Although it gives correct
# of broadcasting (ie a.(zz)) when zz is a vector. Although it gives correct
# results, it is much slower than just a(zz)
(a::AAAapprox)(zz) = reval(zz, a.x, a.f, a.w)

Expand All @@ -33,13 +33,13 @@ function compute_weights(m, J, A::S) where {T, S <: AbstractMatrix{T}}
s = G.S
mm = findall(==(minimum(s)), s) # Treat case of multiple min sing val
nm = length(mm)
w = G.V[:, mm] * (ones(T, nm) ./ sqrt(nm)) # Aim for non-sparse wt vector
w = G.V[:, mm] * (ones(T, nm) ./ sqrt(T(nm))) # Aim for non-sparse wt vector
elseif length(J) >= 1
V = nullspace(A[J, :]) # Fewer rows than columns
nm = size(V, 2)
w = V * ones(T, nm) ./ sqrt(nm) # Aim for non-sparse wt vector
w = V * ones(T, nm) ./ sqrt(T(nm)) # Aim for non-sparse wt vector
else
w = ones(T, m) ./ sqrt(m) # No rows at all
w = ones(T, m) ./ sqrt(T(m)) # No rows at all
end
return w
end
Expand All @@ -64,7 +64,7 @@ end
Note 1: Changes from matlab version:
switched order of Z and F in function signature
added verbose and clean boolean flags
pol, res, zer = vectors of poles, residues, zeros are now only
pol, res, zer = vectors of poles, residues, zeros are now only
calculated on demand by calling prz(z::AAAapprox)
Note 2: This does (more or less) work with BigFloats. Caveats: since prz
Expand Down Expand Up @@ -133,21 +133,21 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
# Don't use the zero weights when calculating the approximation at the
# support points
i0 = findall(!=(T(0)), w)
N = C[:, i0] * (w[i0] .* f[i0]) # numerator
N = C[:, i0] * (w[i0] .* f[i0]) # numerator
D = C[:, i0] * w[i0]

# Use the rational approximation at the remaining non support
# Use the rational approximation at the remaining non support
# points so we can measure the error.
R .= F
R[J] .= N[J] ./ D[J]
R[J] .= N[J] ./ D[J]

err = norm(F - R, Inf)
verbose && println("Iteration: ", m, " err: ", err)
push!(errvec, err) # max error at sample points
err <= abstol && break # stop if converged
end

# If we've gone to max iters, then it is possible that the best
# If we've gone to max iters, then it is possible that the best
# approximation is at a smaller vector size. If so, truncate the
# approximation which will give us a better approximation that is
# faster to compute. Note that we must truncate A, reset J, and then
Expand Down Expand Up @@ -214,7 +214,7 @@ function prz(z, f, w)
N = (T(1) ./ (pol .- transpose(z))) * (f .* w)
D = -((T(1) ./ (pol .- transpose(z))) .^ 2) * w
res = N ./ D

E = [T(0) transpose(w .* f); ones(T, m) diagm(z)];
sz = schur(E, B)
zer = sz.values[isfinite.(sz.values)]
Expand All @@ -227,7 +227,7 @@ end
# to be broadcasted over zz. (and in fact, you should not do so)
function reval(zz, z, f, w)
# evaluate r at zz
zv = size(zz) == () ? [zz] : vec(zz)
zv = size(zz) == () ? [zz] : vec(zz)
CC = 1.0 ./ (zv .- transpose(z)) # Cauchy matrix
r = (CC * (w .* f)) ./ (CC * w) # AAA approx as vector
r[isinf.(zv)] .= sum(f .* w) ./ sum(w)
Expand Down Expand Up @@ -268,7 +268,7 @@ function cleanup!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
zdistances[j] = norm(pol[j] .- Z, -Inf)
end
ii = findall(abs.(res) ./ zdistances .< cleanup_tol * geometric_mean_of_absF)

ni = length(ii)
ni == 0 && return
sn = ni == 1 ? "" : "s"
Expand All @@ -280,7 +280,7 @@ function cleanup!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
_, jj = findmin(azp)
deleteat!(z, jj) # remove nearest support points
deleteat!(f, jj)
end
end

# Remove support points z from sample set:
@inbounds for zs in z
Expand Down Expand Up @@ -359,7 +359,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
pol, res, zer = prz(z, f, w)
FT = typeof(abs(T(0)))
cleanup_tol = FT(cleanup_tol)

niter = 0
while true
niter = niter + 1
Expand Down Expand Up @@ -405,7 +405,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
unique!(ii)

ni = length(ii)
if ni == 0
if ni == 0
# Nothing to do.
break
else
Expand All @@ -423,7 +423,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
deleteat!(z, jj)
deleteat!(f, jj)
end

# Remove support points z from sample set:
@inbounds for zs in z
idx = findfirst(==(zs), Z)
Expand All @@ -432,17 +432,17 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
end
m = length(z)
M = length(Z)

# Build Loewner matrix:
SF = spdiagm(M, M, 0 => F)
Sf = diagm(f)
C = 1 ./ (Z .- transpose(z)) # Cauchy matrix.
A = SF * C - C * Sf # Loewner matrix.

# Solve least-squares problem to obtain weights:
G = svd(A)
@views w = G.V[:, m]

# Compute poles, residues and zeros for next round.
pol, res, zer = prz(z, f, w)
end # End of while loop
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include("test_deriv.jl")
@test test_aaa_maxiters() # 2 doublets
@test test_aaa_truncation()
@test test_aaa_complex()
@test test_aaa_float32()
end
@testset "FH_rational_interpolation" begin
@test test_fh_runge()
Expand Down
12 changes: 12 additions & 0 deletions test/test_aaa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,15 @@ function test_aaa_complex()
zz = complex.(xx, xx)
return norm(sin.(zz) - g(zz), Inf) < 1e-12
end

function test_aaa_float32()
n = 100
z = range(-Float32(1), Float32(1), length=n)
f = 1 ./ (z .^ 2 .+ 1)
try
aaa(z, f, tol=sqrt(eps(one(Float32))))
return true
catch
return false
end
end

0 comments on commit 4ff462c

Please sign in to comment.