Skip to content

Commit

Permalink
Add errors option to curvefit (pydata#7891)
Browse files Browse the repository at this point in the history
* Add allow_failures flag to Dataset.curve_fit

* Reword docstring

* Add allow_failures flag also to DataArray

* Add unit test for curvefit with allow_failures

* Update whats-new

Co-authored-by: Dominik Stańczak <stanczakdominik@gmail.com>

* Add PR to whats-new

* Update docstring

* Rename allow_failures to errors to be consistent with other methods

* Compute array so test passes also with dask

* Check error message

* Update whats-new

---------

Co-authored-by: Dominik Stańczak <stanczakdominik@gmail.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
3 people authored and dstansby committed Jun 28, 2023
1 parent 2697150 commit 0a8bf98
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Expand Up @@ -25,6 +25,10 @@ New Features

- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`).
By `András Gunyhó <https://github.com/mgunyho>`_.
- Add an ``errors`` option to :py:meth:`Dataset.curve_fit` that allows
returning NaN for the parameters and covariances of failed fits, rather than
failing the whole series of fits (:issue:`6317`, :pull:`7891`).
By `Dominik Stańczak <https://github.com/StanczakDominik>`_ and `András Gunyhó <https://github.com/mgunyho>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
6 changes: 6 additions & 0 deletions xarray/core/dataarray.py
Expand Up @@ -6162,6 +6162,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> Dataset:
"""
Expand Down Expand Up @@ -6206,6 +6207,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.
Expand Down Expand Up @@ -6312,6 +6317,7 @@ def curvefit(
p0=p0,
bounds=bounds,
param_names=param_names,
errors=errors,
kwargs=kwargs,
)

Expand Down
18 changes: 17 additions & 1 deletion xarray/core/dataset.py
Expand Up @@ -8631,6 +8631,7 @@ def curvefit(
p0: dict[str, float | DataArray] | None = None,
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
param_names: Sequence[str] | None = None,
errors: ErrorOptions = "raise",
kwargs: dict[str, Any] | None = None,
) -> T_Dataset:
"""
Expand Down Expand Up @@ -8675,6 +8676,10 @@ def curvefit(
this will be automatically determined by arguments of `func`. `param_names`
should be manually supplied when fitting a function that takes a variable
number of parameters.
errors : {"raise", "ignore"}, default: "raise"
If 'raise', any errors from the `scipy.optimize_curve_fit` optimization will
raise an exception. If 'ignore', the coefficients and covariances for the
coordinates where the fitting failed will be NaN.
**kwargs : optional
Additional keyword arguments to passed to scipy curve_fit.
Expand Down Expand Up @@ -8757,6 +8762,9 @@ def curvefit(
f"dimensions {preserved_dims}."
)

if errors not in ["raise", "ignore"]:
raise ValueError('errors must be either "raise" or "ignore"')

# Broadcast all coords with each other
coords_ = broadcast(*coords_)
coords_ = [
Expand Down Expand Up @@ -8793,7 +8801,15 @@ def _wrapper(Y, *args, **kwargs):
pcov = np.full([n_params, n_params], np.nan)
return popt, pcov
x = np.squeeze(x)
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)

try:
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
except RuntimeError:
if errors == "raise":
raise
popt = np.full([n_params], np.nan)
pcov = np.full([n_params, n_params], np.nan)

return popt, pcov

result = type(self)()
Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_dataarray.py
Expand Up @@ -4584,6 +4584,48 @@ def sine(t, a, f, p):
bounds={"a": (0, DataArray([1], coords={"foo": [1]}))},
)

@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_curvefit_ignore_errors(self, use_dask: bool) -> None:
if use_dask and not has_dask:
pytest.skip("requires dask")

# nonsense function to make the optimization fail
def line(x, a, b):
if a > 10:
return 0
return a * x + b

da = DataArray(
[[1, 3, 5], [0, 20, 40]],
coords={"i": [1, 2], "x": [0.0, 1.0, 2.0]},
)

if use_dask:
da = da.chunk({"i": 1})

expected = DataArray(
[[2, 1], [np.nan, np.nan]], coords={"i": [1, 2], "param": ["a", "b"]}
)

with pytest.raises(RuntimeError, match="calls to function has reached maxfev"):
da.curvefit(
coords="x",
func=line,
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
).compute() # have to compute to raise the error

fit = da.curvefit(
coords="x",
func=line,
errors="ignore",
# limit maximum number of calls so the optimization fails
kwargs=dict(maxfev=5),
).compute()

assert_allclose(fit.curvefit_coefficients, expected)


class TestReduce:
@pytest.fixture(autouse=True)
Expand Down

0 comments on commit 0a8bf98

Please sign in to comment.