diff --git a/verde/spline.py b/verde/spline.py
index 14f03bf7e..df19e0d45 100644
--- a/verde/spline.py
+++ b/verde/spline.py
@@ -92,6 +92,11 @@ class SplineCV(BaseGridder):
         If True, will use :func:`dask.delayed` to dispatch computations and
         allow mod:`dask` to execute the grid search in parallel (see note
         above).
+    scoring : None, str, or callable
+        The scoring function (or name of a function) used for cross-validation.
+        Must be known to scikit-learn. See the description of *scoring* in
+        :func:`sklearn.model_selection.cross_val_score` for details. If None,
+        will fall back to the :meth:`verde.Spline.score` method.
 
     Attributes
     ----------
@@ -134,6 +139,7 @@ def __init__(
         cv=None,
         client=None,
         delayed=False,
+        scoring=None,
     ):
         super().__init__()
         self.dampings = dampings
@@ -143,6 +149,7 @@ def __init__(
         self.cv = cv
         self.client = client
         self.delayed = delayed
+        self.scoring = scoring
         if engine != "auto":
             warnings.warn(
                 "The 'engine' parameter of 'verde.SplineCV' is "
@@ -229,6 +236,7 @@ def fit(self, coordinates, data, weights=None):
                 weights=weights,
                 cv=self.cv,
                 delayed=self.delayed,
+                scoring=self.scoring,
             )
             scores.append(dispatch(np.mean, delayed=self.delayed)(score))
         best = dispatch(np.argmax, delayed=self.delayed)(scores)
diff --git a/verde/tests/test_spline.py b/verde/tests/test_spline.py
index a02396ed8..9e00b6c66 100644
--- a/verde/tests/test_spline.py
+++ b/verde/tests/test_spline.py
@@ -15,6 +15,7 @@
 from dask.distributed import Client
 from sklearn.model_selection import ShuffleSplit
 
+from ..model_selection import cross_val_score
 from ..spline import Spline, SplineCV
 from ..synthetic import CheckerBoard
 from .utils import requires_numba
@@ -99,6 +100,25 @@ def test_spline():
     )
 
 
+def test_spline_cv_scoring():
+    "Check scoring parameter works with SplineCV"
+    region = (100, 500, -800, -700)
+    synth = CheckerBoard(region=region)
+    data = synth.scatter(size=1500, random_state=1)
+    coords = (data.easting, data.northing)
+    # Compare SplineCV to results from Spline with cross_val_score
+    for score in ["r2", "neg_root_mean_squared_error"]:
+        spline = Spline(damping=None, mindist=1e-5)
+        score_spline = np.mean(
+            cross_val_score(spline, coords, data.scalars, scoring=score)
+        )
+        # Limit SplineCV to a single parameter set equal to Spline's defaults
+        spline_cv = SplineCV(mindists=[1e-5], dampings=[None], scoring=score)
+        spline_cv.fit(coords, data.scalars)
+        score_spline_cv = spline_cv.scores_[0]
+        npt.assert_allclose(score_spline, score_spline_cv, rtol=1e-5)
+
+
 def test_spline_weights():
     "Use weights to ignore an outlier"
     data = CheckerBoard().scatter(size=2000, random_state=1)