diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 47b359561ec..420b6c55d56 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,10 @@ New Features - Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`). By `András Gunyhó <https://github.com/mgunyho>`_. +- Add an ``errors`` option to :py:meth:`Dataset.curvefit` that allows + returning NaN for the parameters and covariances of failed fits, rather than + failing the whole series of fits (:issue:`6317`, :pull:`7891`). + By `Dominik Stańczak <https://github.com/StanczakDominik>`_ and `András Gunyhó <https://github.com/mgunyho>`_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index bb6a9b131ab..9635d678c36 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6162,6 +6162,7 @@ def curvefit( p0: dict[str, float | DataArray] | None = None, bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, ) -> Dataset: """ @@ -6206,6 +6207,10 @@ def curvefit( this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize.curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. 
@@ -6312,6 +6317,7 @@ def curvefit( p0=p0, bounds=bounds, param_names=param_names, + errors=errors, kwargs=kwargs, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0f271e9d7e4..f4ba9d4f9fe 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8631,6 +8631,7 @@ def curvefit( p0: dict[str, float | DataArray] | None = None, bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None, param_names: Sequence[str] | None = None, + errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, ) -> T_Dataset: """ @@ -8675,6 +8676,10 @@ def curvefit( this will be automatically determined by arguments of `func`. `param_names` should be manually supplied when fitting a function that takes a variable number of parameters. + errors : {"raise", "ignore"}, default: "raise" + If 'raise', any errors from the `scipy.optimize.curve_fit` optimization will + raise an exception. If 'ignore', the coefficients and covariances for the + coordinates where the fitting failed will be NaN. **kwargs : optional Additional keyword arguments to passed to scipy curve_fit. @@ -8757,6 +8762,9 @@ def curvefit( f"dimensions {preserved_dims}." 
) + if errors not in ["raise", "ignore"]: + raise ValueError('errors must be either "raise" or "ignore"') + # Broadcast all coords with each other coords_ = broadcast(*coords_) coords_ = [ @@ -8793,7 +8801,15 @@ def _wrapper(Y, *args, **kwargs): pcov = np.full([n_params, n_params], np.nan) return popt, pcov x = np.squeeze(x) - popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) + + try: + popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs) + except RuntimeError: + if errors == "raise": + raise + popt = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + return popt, pcov result = type(self)() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 804d1517f85..cee5afa56a4 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4584,6 +4584,48 @@ def sine(t, a, f, p): bounds={"a": (0, DataArray([1], coords={"foo": [1]}))}, ) + @requires_scipy + @pytest.mark.parametrize("use_dask", [True, False]) + def test_curvefit_ignore_errors(self, use_dask: bool) -> None: + if use_dask and not has_dask: + pytest.skip("requires dask") + + # nonsense function to make the optimization fail + def line(x, a, b): + if a > 10: + return 0 + return a * x + b + + da = DataArray( + [[1, 3, 5], [0, 20, 40]], + coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]}, + ) + + if use_dask: + da = da.chunk({"i": 1}) + + expected = DataArray( + [[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]} + ) + + with pytest.raises(RuntimeError, match="calls to function has reached maxfev"): + da.curvefit( + coords="x", + func=line, + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() # have to compute to raise the error + + fit = da.curvefit( + coords="x", + func=line, + errors="ignore", + # limit maximum number of calls so the optimization fails + kwargs=dict(maxfev=5), + ).compute() + + assert_allclose(fit.curvefit_coefficients, 
expected) + class TestReduce: @pytest.fixture(autouse=True)