Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace DomainErrors with NaNs on out-of-domain evaluations #64

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ y = itp(1.5) # At a single point
ys = itp.(xs) # At multiple points
```

Attempts to evaluate outside the interpolation range will throw a [`DomainError`](https://docs.julialang.org/en/v1/base/base/#Core.DomainError) (i.e., the interpolator will not perform extrapolation).

### Plot (with [Plots](https://github.com/JuliaPlots/Plots.jl))

```jl
Expand Down
65 changes: 44 additions & 21 deletions src/PCHIPInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,7 @@ end

i = searchsortedlast(xs, x)

if i < firstindex(xs)
throw(DomainError(x, "Below interpolation range"))
end

if i == lastindex(xs)
if x > @inbounds xs[i]
throw(DomainError(x, "Above interpolation range"))
end
if i == lastindex(xs) && x == @inbounds xs[i]
i -= 1 # Treat right endpoint as part of rightmost interval
end

Expand All @@ -98,14 +91,14 @@ end
imin = firstindex(xs)

if x < @inbounds xs[imin]
throw(DomainError(x, "Below interpolation range"))
return imin - 1
end

imax = lastindex(xs)
xmax = @inbounds xs[imax]

if x > xmax
throw(DomainError(x, "Above interpolation range"))
return imax
elseif x == xmax
return imax - 1 # Treat right endpoint as part of rightmost interval
end
Expand All @@ -131,19 +124,49 @@ end

@inline _x(::Interpolator, x) = x
@inline _x(itp::Interpolator, x, _) = _x(itp, x)
Base.@propagate_inbounds _x(itp::Interpolator, ::Val{:begin}, i) = itp.xs[i]
Base.@propagate_inbounds _x(itp::Interpolator, ::Val{:end}, i) = itp.xs[i+1]
@inline function _x(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i]
end
@inline function _x(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i+1]
end

@inline _evaluate(itp::Interpolator, ::Val{:begin}, i) = itp.ys[i]
@inline _evaluate(itp::Interpolator, ::Val{:end}, i) = itp.ys[i+1]
@inline function _evaluate(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i]
end
@inline function _evaluate(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i+1]
end

@inline _derivative(itp::Interpolator, ::Val{:begin}, i) = itp.ds[i]
@inline _derivative(itp::Interpolator, ::Val{:end}, i) = itp.ds[i+1]
@inline function _derivative(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i]
end
@inline function _derivative(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i+1]
end

@inline _ϕ(t) = 3t^2 - 2t^3
@inline _ψ(t) = t^3 - t^2

Base.@propagate_inbounds function _evaluate(itp::Interpolator, x, i)
function _evaluate(itp::Interpolator, x, i)
x1 = _x(itp, Val(:begin), i)
x2 = _x(itp, Val(:end), i)
h = x2 - x1
Expand All @@ -160,18 +183,18 @@ Base.@propagate_inbounds function _evaluate(itp::Interpolator, x, i)
+ d2*h * _ψ((x-x1)/h))
end

@inline _evaluate(itp::Interpolator, x) = @inbounds _evaluate(itp, x, _findinterval(itp, x))
@inline _evaluate(itp::Interpolator, x) = _evaluate(itp, x, _findinterval(itp, x))

@inline (itp::Interpolator)(x::Number) = _evaluate(itp, x)


Base.@propagate_inbounds function _integrate(itp::Interpolator, a, b, i)
@inline function _integrate(itp::Interpolator, a, b, i)
a_ = _x(itp, a, i)
b_ = _x(itp, b, i)
return (b_ - a_)/6*(_evaluate(itp, a, i) + 4*_evaluate(itp, (a_ + b_)/2, i) + _evaluate(itp, b, i)) # Simpson's rule
end

Base.@propagate_inbounds function _integrate(itp::Interpolator, a, b, i, j)
@inline function _integrate(itp::Interpolator, a, b, i, j)
if i == j
return _integrate(itp, a, b, i)
end
Expand All @@ -190,7 +213,7 @@ end
return -_integrate(itp, b, a)
end

return @inbounds _integrate(itp, a, b, _findinterval(itp, a), _findinterval(itp, b))
return _integrate(itp, a, b, _findinterval(itp, a), _findinterval(itp, b))
end

@inline integrate(itp::Interpolator, a::Number, b::Number) = _integrate(itp, a, b)
Expand Down
29 changes: 14 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,21 @@ end
@assert PCHIPInterpolation._is_strictly_increasing(xs)
for search in (x -> (@inferred PCHIPInterpolation._findinterval_base(xs, x)),
x -> (@inferred PCHIPInterpolation._findinterval_custom(xs, x)))
@test_throws DomainError search(0.0)
@test_throws DomainError search(1.0 - eps(1.0))
@test search(0.0) == 0
@test search(1.0 - eps(1.0)) == 0
@test search(1.0) == 1
@test search(1.0 + eps(1.0)) == 1
@test search(2.0 - eps(2.0)) == 1
@test search(2.0) == 2
@test search(2.0 + eps(2.0)) == 2
@test search(3.0 - eps(3.0)) == 2
@test search(3.0) == 2
@test_throws DomainError search(3.0 + eps(3.0))
@test_throws DomainError search(4.0)
@test search(NaN) in 1:2
@test search(3.0 + eps(3.0)) == 3
@test search(4.0) == 3
end
end
end

@testset "xs too short" begin
for xs in (1.0:1.0, collect(1.0:1.0), 1.0:0.0, collect(1.0:0.0))
@assert length(xs) < 2
Expand Down Expand Up @@ -257,11 +256,11 @@ end
itp = @inferred Interpolator([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0])
@test itp(1) == 4
@test itp(4) == 1
@test_throws DomainError itp(1 - 1e-6)
@test_throws DomainError itp(4 + 1e-6)
@test_throws DomainError integrate(itp, 1 - 1e-6, 4 + 1e-6)
@test_throws DomainError integrate(itp, 1 - 1e-6, 3)
@test_throws DomainError integrate(itp, 2, 4 + 1e-6)
@test isnan(itp(1 - 1e-6))
@test isnan(itp(4 + 1e-6))
@test isnan(integrate(itp, 1 - 1e-6, 4 + 1e-6))
@test isnan(integrate(itp, 1 - 1e-6, 3))
@test isnan(integrate(itp, 2, 4 + 1e-6))
end

@testset "NaN propagation" begin
Expand Down Expand Up @@ -313,10 +312,10 @@ end
@testset "out of domain" begin
oitp = Interpolator(oxs, oys)

@test_throws DomainError oitp(0.0 - eps(0.0))
@test_throws DomainError oitp(11.0 + eps(11.0))
@test_throws DomainError integrate(oitp, 0.0 - eps(0.0), 1.0)
@test_throws DomainError integrate(oitp, 0.0, 11.0 + eps(11.0))
@test isnan(oitp(0.0 - eps(0.0)))
@test isnan(oitp(11.0 + eps(11.0)))
@test isnan(integrate(oitp, 0.0 - eps(0.0), 1.0))
@test isnan(integrate(oitp, 0.0, 11.0 + eps(11.0)))
end

@testset "incompatible arguments" begin
Expand Down