diff --git a/src/aaa.jl b/src/aaa.jl
index e7867e9..28d98d6 100644
--- a/src/aaa.jl
+++ b/src/aaa.jl
@@ -1,5 +1,5 @@
 # AAA algorithm from the paper "The AAA Alorithm for Rational Approximation"
-# by Y. Nakatsukasa, O. Sete, and L.N. Trefethen, SIAM Journal on Scientific
+# by Y. Nakatsukasa, O. Sete, and L.N. Trefethen, SIAM Journal on Scientific
 # Computing, 2018
 
 using Printf
@@ -11,7 +11,7 @@ struct AAAapprox{T <: AbstractArray} <: BRInterp
 end
 
 # In this version zz can be a scalar or a vector. BUT: do not make the mistake
-# of broadcasting (ie a.(zz)) when zz is a vector. Although it gives correct
+# of broadcasting (ie a.(zz)) when zz is a vector. Although it gives correct
 # results, it is much slower than just a(zz)
 (a::AAAapprox)(zz) = reval(zz, a.x, a.f, a.w)
 
@@ -33,13 +33,13 @@ function compute_weights(m, J, A::S) where {T, S <: AbstractMatrix{T}}
         s = G.S
         mm = findall(==(minimum(s)), s)  # Treat case of multiple min sing val
         nm = length(mm)
-        w = G.V[:, mm] * (ones(T, nm) ./ sqrt(nm))     # Aim for non-sparse wt vector
+        w = G.V[:, mm] * (ones(T, nm) ./ sqrt(T(nm)))  # Aim for non-sparse wt vector
     elseif length(J) >= 1
         V = nullspace(A[J, :])              # Fewer rows than columns
         nm = size(V, 2)
-        w = V * ones(T, nm) ./ sqrt(nm)     # Aim for non-sparse wt vector
+        w = V * ones(T, nm) ./ sqrt(T(nm))  # Aim for non-sparse wt vector
     else
-        w = ones(T, m) ./ sqrt(m)           # No rows at all
+        w = ones(T, m) ./ sqrt(T(m))        # No rows at all
     end
     return w
 end
@@ -64,7 +64,7 @@
 Note 1: Changes from matlab version:
     switched order of Z and F in function signature
     added verbose and clean boolean flags
-    pol, res, zer = vectors of poles, residues, zeros are now only
+    pol, res, zer = vectors of poles, residues, zeros are now only
     calculated on demand by calling prz(z::AAAapprox)
 
 Note 2: This does (more or less) work with BigFloats. Caveats: since prz
@@ -133,21 +133,21 @@ function aaa(Z::AbstractVector{U}, F::AbstractVector{S}; tol=1e-13, mmax=100,
         # Don't use the zero weights when calculating the approximation at the
         # support points
         i0 = findall(!=(T(0)), w)
-        N = C[:, i0] * (w[i0] .* f[i0])  # numerator
+        N = C[:, i0] * (w[i0] .* f[i0])  # numerator
         D = C[:, i0] * w[i0]
 
-        # Use the rational approximation at the remaining non support
+        # Use the rational approximation at the remaining non support
         # points so we can measure the error.
         R .= F
-        R[J] .= N[J] ./ D[J]
-
+        R[J] .= N[J] ./ D[J]
+
         err = norm(F - R, Inf)
         verbose && println("Iteration: ", m, " err: ", err)
         push!(errvec, err)  # max error at sample points
         err <= abstol && break  # stop if converged
     end
 
-    # If we've gone to max iters, then it is possible that the best
+    # If we've gone to max iters, then it is possible that the best
     # approximation is at a smaller vector size. If so, truncate the
     # approximation which will give us a better approximation that is
     # faster to compute. Note that we must truncate A, reset J, and then
@@ -214,7 +214,7 @@ function prz(z, f, w)
     N = (T(1) ./ (pol .- transpose(z))) * (f .* w)
     D = -((T(1) ./ (pol .- transpose(z))) .^ 2) * w
     res = N ./ D
-
+
     E = [T(0) transpose(w .* f); ones(T, m) diagm(z)];
     sz = schur(E, B)
     zer = sz.values[isfinite.(sz.values)]
@@ -227,7 +227,7 @@
 # to be broadcasted over zz. (and in fact, you should not do so)
 function reval(zz, z, f, w)
     # evaluate r at zz
-    zv = size(zz) == () ? [zz] : vec(zz)
+    zv = size(zz) == () ? [zz] : vec(zz)
     CC = 1.0 ./ (zv .- transpose(z))  # Cauchy matrix
     r = (CC * (w .* f)) ./ (CC * w)   # AAA approx as vector
     r[isinf.(zv)] .= sum(f .* w) ./ sum(w)
@@ -268,7 +268,7 @@ function cleanup!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
         zdistances[j] = norm(pol[j] .- Z, -Inf)
     end
     ii = findall(abs.(res) ./ zdistances .< cleanup_tol * geometric_mean_of_absF)
-
+
     ni = length(ii)
     ni == 0 && return
     sn = ni == 1 ? "" : "s"
@@ -280,7 +280,7 @@ function cleanup!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
         _, jj = findmin(azp)
         deleteat!(z, jj)  # remove nearest support points
         deleteat!(f, jj)
-    end
+    end
 
     # Remove support points z from sample set:
     @inbounds for zs in z
@@ -359,7 +359,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
     pol, res, zer = prz(z, f, w)
     FT = typeof(abs(T(0)))
     cleanup_tol = FT(cleanup_tol)
-
+
     niter = 0
     while true
         niter = niter + 1
@@ -405,7 +405,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
         unique!(ii)
         ni = length(ii)
 
-        if ni == 0
+        if ni == 0
             # Nothing to do.
             break
         else
@@ -423,7 +423,7 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
                 deleteat!(z, jj)
                 deleteat!(f, jj)
             end
-
+
             # Remove support points z from sample set:
             @inbounds for zs in z
                 idx = findfirst(==(zs), Z)
@@ -432,17 +432,17 @@ function cleanup2!(r, Zp::AbstractVector{T}, Fp::AbstractVector{T};
         end
         m = length(z)
         M = length(Z)
-
+
         # Build Loewner matrix:
         SF = spdiagm(M, M, 0 => F)
         Sf = diagm(f)
         C = 1 ./ (Z .- transpose(z))  # Cauchy matrix.
         A = SF * C - C * Sf           # Loewner matrix.
-
+
         # Solve least-squares problem to obtain weights:
        G = svd(A)
        @views w = G.V[:, m]
-
+
        # Compute poles, residues and zeros for next round.
        pol, res, zer = prz(z, f, w)
    end  # End of while loop
diff --git a/test/runtests.jl b/test/runtests.jl
index 59b27a9..93fd9d7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -31,6 +31,7 @@ include("test_deriv.jl")
     @test test_aaa_maxiters()  # 2 doublets
     @test test_aaa_truncation()
     @test test_aaa_complex()
+    @test test_aaa_float32()
 end
 @testset "FH_rational_interpolation" begin
     @test test_fh_runge()
diff --git a/test/test_aaa.jl b/test/test_aaa.jl
index 8e8c283..e62b986 100644
--- a/test/test_aaa.jl
+++ b/test/test_aaa.jl
@@ -260,3 +260,15 @@ function test_aaa_complex()
     zz = complex.(xx, xx)
     return norm(sin.(zz) - g(zz), Inf) < 1e-12
 end
+
+function test_aaa_float32()
+    n = 100
+    z = range(-Float32(1), Float32(1), length=n)
+    f = 1 ./ (z .^ 2 .+ 1)
+    try
+        aaa(z, f, tol=sqrt(eps(one(Float32))))
+        return true
+    catch
+        return false
+    end
+end
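The `sqrt(T(nm))` / `sqrt(T(m))` edits in `compute_weights` are the substantive part of this patch (most other hunks only strip trailing whitespace): in Julia, `sqrt` of an `Int` returns a `Float64`, so dividing a `Float32` or `BigFloat` weight vector by `sqrt(m)` silently promotes the weights to `Float64`. Converting the count to the working element type `T` first keeps the computation in that precision, which is what the new `test_aaa_float32` test exercises. Below is a minimal standalone sketch of the promotion behaviour; it is an illustration only and not part of the patch:

```julia
# Dividing by sqrt(m::Int) promotes Float32 (or BigFloat) weights to Float64;
# converting the count to the working type T first preserves the precision.
T, m = Float32, 5

w_promoted = ones(T, m) ./ sqrt(m)     # sqrt(5) isa Float64  -> weights become Float64
w_stable   = ones(T, m) ./ sqrt(T(m))  # sqrt(5.0f0) isa Float32 -> weights stay Float32

@show eltype(w_promoted)  # Float64
@show eltype(w_stable)    # Float32
```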