Skip to content

Commit

Permalink
ENH: Add ext parameter to UnivariateSpline
Browse files Browse the repository at this point in the history
Pass parameter to fitpack.splev to control how to handle out of range
evaluations. Parameter ext can be passed via __init__ or __call__. Added
test to scipy/interpolate/tests/test_fitpack2.py. Tests if UnivariateSpline
call equals x ** 3 for out of range x and extrapolation modes 0 and 1.

See ticket 3557.

closes scipygh-3557
  • Loading branch information
jacobcvt12 committed Jul 7, 2014
1 parent 6c254b5 commit 04e049d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
30 changes: 26 additions & 4 deletions scipy/interpolate/fitpack2.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ class UnivariateSpline(object):
If None (default), s=len(w) which should be a good value if 1/w[i] is
an estimate of the standard deviation of y[i]. If 0, spline will
interpolate through all data points.
ext : int, optional
Controls the extrapolation mode for elements
not in the interval defined by the knot sequence
* if ext=0, return the extrapolated value.
* if ext=1, return 0
* if ext=2, raise a ValueError
The default value is 0.
See Also
--------
Expand Down Expand Up @@ -124,7 +133,7 @@ class UnivariateSpline(object):
"""

def __init__(self, x, y, w=None, bbox=[None]*2, k=3, s=None):
def __init__(self, x, y, w=None, bbox=[None]*2, k=3, s=None, ext=0):
"""
Input:
x,y - 1-d sequences of data points (x must be
Expand All @@ -142,8 +151,19 @@ def __init__(self, x, y, w=None, bbox=[None]*2, k=3, s=None):
Default s=len(w) which should be a good value
if 1/w[i] is an estimate of the standard
deviation of y[i].
ext - Controls the extrapolation mode for elements
not in the interval defined by the knot sequence
* if ext=0, return the extrapolated value.
* if ext=1, return 0
* if ext=2, raise a ValueError
The default value is 0.
"""
# _data == x,y,w,xb,xe,k,s,n,t,c,fp,fpint,nrdata,ier
self.ext = ext
if ext not in (0, 1, 2):
raise ValueError("unknown extrapolation mode")
data = dfitpack.fpcurf0(x,y,k,w=w,
xb=bbox[0],xe=bbox[1],s=s)
if data[-1] == 1:
Expand Down Expand Up @@ -228,7 +248,7 @@ def set_smoothing_factor(self, s):
self._data = data
self._reset_class()

def __call__(self, x, nu=0, ext=0):
def __call__(self, x, nu=0, ext=None):
"""
Evaluate spline (or its nu-th derivative) at positions x.
Expand All @@ -248,8 +268,8 @@ def __call__(self, x, nu=0, ext=0):
* if ext=1, return 0
* if ext=2, raise a ValueError
The default value is 0.
The default value is 0, passed from the initialization of
UnivariateSpline.
"""
x = np.asarray(x)
# empty input yields empty output
Expand All @@ -258,6 +278,8 @@ def __call__(self, x, nu=0, ext=0):
# if nu is None:
# return dfitpack.splev(*(self._eval_args+(x,)))
# return dfitpack.splder(nu=nu,*(self._eval_args+(x,)))
if ext is None:
ext = self.ext
return fitpack.splev(x, self._eval_args, der=nu, ext=ext)

def get_knots(self):
Expand Down
22 changes: 9 additions & 13 deletions scipy/interpolate/tests/test_fitpack2.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,15 @@ def test_resize_regression(self):
assert_allclose(spl([0.1, 0.5, 0.9, 0.99]), desired, atol=5e-4)

def test_out_of_range_regression(self):
"""Regression test for #3557."""
x_in_range = [-3., -2.33333333, -1.66666667, -1., -0.33333333,
0.33333333, 1., 1.66666667, 2.33333333, 3.]
x_out_of_range = [-3.66666667, -3., -2.33333333, -1.66666667, -1.,
-0.33333333, 0.33333333, 1., 1.66666667, 2.33333333,
3., 3.66666667]
y = [-0.03443365, 0.02969204, 0.16251805, 0.33279394, 1.07995049,
0.88203547, 0.31274008, -0.10215527, -0.0274716, -0.06661292]
desired = array([0., -0.21332128, 0.16590419, 0.41822373, 0.55728419,
0.59673238, 0.55021514, 0.43137928, 0.25387164,
0.03133904, -0.22257169, 0.])
spl = UnivariateSpline(x=x_in_range, y=y)
assert_allclose(spl(x_out_of_range, ext=1), desired, atol=5e-4)
# Test different extrapolation modes. See ticket 3557
x = np.arange(5, dtype=np.float)
y = x ** 3
spl = UnivariateSpline(x=x, y=y)
xp = linspace(-8, 13, 100)
xp_zeros = xp.copy()
xp_zeros[np.logical_or(xp_zeros < 0., xp_zeros > 4.)] = 0
assert_allclose(spl(xp, ext=0), xp**3, atol=1e-16)
assert_allclose(spl(xp, ext=1), xp_zeros**3, atol=1e-16)

def test_derivative_and_antiderivative(self):
# Thin wrappers to splder/splantider, so light smoke test only.
Expand Down

0 comments on commit 04e049d

Please sign in to comment.