Skip to content

Commit

Permalink
Merge pull request #64 from gerlero/nan
Browse files Browse the repository at this point in the history
Replace DomainErrors with NaNs on out-of-domain evaluations
  • Loading branch information
gerlero authored Dec 14, 2023
2 parents 71072fa + cfe83f8 commit 9af84ee
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 38 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ y = itp(1.5) # At a single point
ys = itp.(xs) # At multiple points
```

Attempts to evaluate outside the interpolation range will throw a [`DomainError`](https://docs.julialang.org/en/v1/base/base/#Core.DomainError) (i.e., the interpolator will not perform extrapolation).

### Plot (with [Plots](https://github.com/JuliaPlots/Plots.jl))

```jl
Expand Down
65 changes: 44 additions & 21 deletions src/PCHIPInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,7 @@ end

i = searchsortedlast(xs, x)

if i < firstindex(xs)
throw(DomainError(x, "Below interpolation range"))
end

if i == lastindex(xs)
if x > @inbounds xs[i]
throw(DomainError(x, "Above interpolation range"))
end
if i == lastindex(xs) && x == @inbounds xs[i]
i -= 1 # Treat right endpoint as part of rightmost interval
end

Expand All @@ -98,14 +91,14 @@ end
imin = firstindex(xs)

if x < @inbounds xs[imin]
throw(DomainError(x, "Below interpolation range"))
return imin - 1
end

imax = lastindex(xs)
xmax = @inbounds xs[imax]

if x > xmax
throw(DomainError(x, "Above interpolation range"))
return imax
elseif x == xmax
return imax - 1 # Treat right endpoint as part of rightmost interval
end
Expand All @@ -131,19 +124,49 @@ end

@inline _x(::Interpolator, x) = x
@inline _x(itp::Interpolator, x, _) = _x(itp, x)
Base.@propagate_inbounds _x(itp::Interpolator, ::Val{:begin}, i) = itp.xs[i]
Base.@propagate_inbounds _x(itp::Interpolator, ::Val{:end}, i) = itp.xs[i+1]
@inline function _x(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i]
end
@inline function _x(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.xs) || i >= lastindex(itp.xs)
return float(eltype(itp.xs))(NaN)
end
return @inbounds itp.xs[i+1]
end

@inline _evaluate(itp::Interpolator, ::Val{:begin}, i) = itp.ys[i]
@inline _evaluate(itp::Interpolator, ::Val{:end}, i) = itp.ys[i+1]
@inline function _evaluate(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i]
end
@inline function _evaluate(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ys) || i >= lastindex(itp.ys)
return float(eltype(itp.ys))(NaN)
end
return @inbounds itp.ys[i+1]
end

@inline _derivative(itp::Interpolator, ::Val{:begin}, i) = itp.ds[i]
@inline _derivative(itp::Interpolator, ::Val{:end}, i) = itp.ds[i+1]
@inline function _derivative(itp::Interpolator, ::Val{:begin}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i]
end
@inline function _derivative(itp::Interpolator, ::Val{:end}, i)
if i < firstindex(itp.ds) || i >= lastindex(itp.ds)
return float(eltype(itp.ds))(NaN)
end
return @inbounds itp.ds[i+1]
end

@inline (t) = 3t^2 - 2t^3
@inline (t) = t^3 - t^2

Base.@propagate_inbounds function _evaluate(itp::Interpolator, x, i)
function _evaluate(itp::Interpolator, x, i)
x1 = _x(itp, Val(:begin), i)
x2 = _x(itp, Val(:end), i)
h = x2 - x1
Expand All @@ -160,18 +183,18 @@ Base.@propagate_inbounds function _evaluate(itp::Interpolator, x, i)
+ d2*h * ((x-x1)/h))
end

@inline _evaluate(itp::Interpolator, x) = @inbounds _evaluate(itp, x, _findinterval(itp, x))
@inline _evaluate(itp::Interpolator, x) = _evaluate(itp, x, _findinterval(itp, x))

@inline (itp::Interpolator)(x::Number) = _evaluate(itp, x)


Base.@propagate_inbounds function _integrate(itp::Interpolator, a, b, i)
@inline function _integrate(itp::Interpolator, a, b, i)
a_ = _x(itp, a, i)
b_ = _x(itp, b, i)
return (b_ - a_)/6*(_evaluate(itp, a, i) + 4*_evaluate(itp, (a_ + b_)/2, i) + _evaluate(itp, b, i)) # Simpson's rule
end

Base.@propagate_inbounds function _integrate(itp::Interpolator, a, b, i, j)
@inline function _integrate(itp::Interpolator, a, b, i, j)
if i == j
return _integrate(itp, a, b, i)
end
Expand All @@ -190,7 +213,7 @@ end
return -_integrate(itp, b, a)
end

return @inbounds _integrate(itp, a, b, _findinterval(itp, a), _findinterval(itp, b))
return _integrate(itp, a, b, _findinterval(itp, a), _findinterval(itp, b))
end

@inline integrate(itp::Interpolator, a::Number, b::Number) = _integrate(itp, a, b)
Expand Down
29 changes: 14 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,21 @@ end
@assert PCHIPInterpolation._is_strictly_increasing(xs)
for search in (x -> (@inferred PCHIPInterpolation._findinterval_base(xs, x)),
x -> (@inferred PCHIPInterpolation._findinterval_custom(xs, x)))
@test_throws DomainError search(0.0)
@test_throws DomainError search(1.0 - eps(1.0))
@test search(0.0) == 0
@test search(1.0 - eps(1.0)) == 0
@test search(1.0) == 1
@test search(1.0 + eps(1.0)) == 1
@test search(2.0 - eps(2.0)) == 1
@test search(2.0) == 2
@test search(2.0 + eps(2.0)) == 2
@test search(3.0 - eps(3.0)) == 2
@test search(3.0) == 2
@test_throws DomainError search(3.0 + eps(3.0))
@test_throws DomainError search(4.0)
@test search(NaN) in 1:2
@test search(3.0 + eps(3.0)) == 3
@test search(4.0) == 3
end
end
end

@testset "xs too short" begin
for xs in (1.0:1.0, collect(1.0:1.0), 1.0:0.0, collect(1.0:0.0))
@assert length(xs) < 2
Expand Down Expand Up @@ -257,11 +256,11 @@ end
itp = @inferred Interpolator([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0])
@test itp(1) == 4
@test itp(4) == 1
@test_throws DomainError itp(1 - 1e-6)
@test_throws DomainError itp(4 + 1e-6)
@test_throws DomainError integrate(itp, 1 - 1e-6, 4 + 1e-6)
@test_throws DomainError integrate(itp, 1 - 1e-6, 3)
@test_throws DomainError integrate(itp, 2, 4 + 1e-6)
@test isnan(itp(1 - 1e-6))
@test isnan(itp(4 + 1e-6))
@test isnan(integrate(itp, 1 - 1e-6, 4 + 1e-6))
@test isnan(integrate(itp, 1 - 1e-6, 3))
@test isnan(integrate(itp, 2, 4 + 1e-6))
end

@testset "NaN propagation" begin
Expand Down Expand Up @@ -313,10 +312,10 @@ end
@testset "out of domain" begin
oitp = Interpolator(oxs, oys)

@test_throws DomainError oitp(0.0 - eps(0.0))
@test_throws DomainError oitp(11.0 + eps(11.0))
@test_throws DomainError integrate(oitp, 0.0 - eps(0.0), 1.0)
@test_throws DomainError integrate(oitp, 0.0, 11.0 + eps(11.0))
@test isnan(oitp(0.0 - eps(0.0)))
@test isnan(oitp(11.0 + eps(11.0)))
@test isnan(integrate(oitp, 0.0 - eps(0.0), 1.0))
@test isnan(integrate(oitp, 0.0, 11.0 + eps(11.0)))
end

@testset "incompatible arguments" begin
Expand Down

0 comments on commit 9af84ee

Please sign in to comment.