Skip to content

Commit

Permalink
Add 'scoring' parameter for SplineCV (#380)
Browse files Browse the repository at this point in the history
Controls the scoring metric used during cross-validation.
The argument is passed to cross_val_score internally and defaults to
None.
  • Loading branch information
JamesSample committed Sep 7, 2022
1 parent 66748f9 commit 5d9885f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
8 changes: 8 additions & 0 deletions verde/spline.py
Expand Up @@ -92,6 +92,11 @@ class SplineCV(BaseGridder):
If True, will use :func:`dask.delayed` to dispatch computations and
allow :mod:`dask` to execute the grid search in parallel (see note
above).
scoring : None, str, or callable
The scoring function (or name of a function) used for cross-validation.
Must be known to scikit-learn. See the description of *scoring* in
:func:`sklearn.model_selection.cross_val_score` for details. If None,
will fall back to the :meth:`verde.Spline.score` method.
Attributes
----------
Expand Down Expand Up @@ -134,6 +139,7 @@ def __init__(
cv=None,
client=None,
delayed=False,
scoring=None,
):
super().__init__()
self.dampings = dampings
Expand All @@ -143,6 +149,7 @@ def __init__(
self.cv = cv
self.client = client
self.delayed = delayed
self.scoring = scoring
if engine != "auto":
warnings.warn(
"The 'engine' parameter of 'verde.SplineCV' is "
Expand Down Expand Up @@ -229,6 +236,7 @@ def fit(self, coordinates, data, weights=None):
weights=weights,
cv=self.cv,
delayed=self.delayed,
scoring=self.scoring,
)
scores.append(dispatch(np.mean, delayed=self.delayed)(score))
best = dispatch(np.argmax, delayed=self.delayed)(scores)
Expand Down
20 changes: 20 additions & 0 deletions verde/tests/test_spline.py
Expand Up @@ -15,6 +15,7 @@
from dask.distributed import Client
from sklearn.model_selection import ShuffleSplit

from ..model_selection import cross_val_score
from ..spline import Spline, SplineCV
from ..synthetic import CheckerBoard
from .utils import requires_numba
Expand Down Expand Up @@ -99,6 +100,25 @@ def test_spline():
)


def test_spline_cv_scoring():
    """
    Check that the *scoring* argument of SplineCV is honoured.

    For each scoring name, the mean of ``cross_val_score`` on a plain
    ``Spline`` is compared against the first entry of ``SplineCV.scores_``
    obtained from a one-point parameter grid using the same damping and
    mindist values.
    """
    synthetic = CheckerBoard(region=(100, 500, -800, -700))
    table = synthetic.scatter(size=1500, random_state=1)
    coordinates = (table.easting, table.northing)
    for metric in ("r2", "neg_root_mean_squared_error"):
        # Reference value: mean score of a plain Spline under this metric
        reference = cross_val_score(
            Spline(damping=None, mindist=1e-5),
            coordinates,
            table.scalars,
            scoring=metric,
        )
        expected = np.mean(reference)
        # Grid with a single (mindist, damping) pair matching the Spline
        # configured above, so both paths evaluate the same parameters
        gridder = SplineCV(mindists=[1e-5], dampings=[None], scoring=metric)
        gridder.fit(coordinates, table.scalars)
        npt.assert_allclose(expected, gridder.scores_[0], rtol=1e-5)


def test_spline_weights():
"Use weights to ignore an outlier"
data = CheckerBoard().scatter(size=2000, random_state=1)
Expand Down

0 comments on commit 5d9885f

Please sign in to comment.