From 2c58ab2ba592155f0fdaeb38592d55656bc1637a Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 12 Jun 2023 15:44:48 +0200 Subject: [PATCH] MAINT: Use `errstate` rather than `extobj` for linalg ufunc calls This may slow things down very slightly, but we should be able to recover this easily and the `extobj=` API is a bit strange. --- numpy/linalg/linalg.py | 91 ++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 47 deletions(-) diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index b838b9397024..c39de052c2fd 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -24,7 +24,7 @@ array, asarray, zeros, empty, empty_like, intc, single, double, csingle, cdouble, inexact, complexfloating, newaxis, all, Inf, dot, add, multiply, sqrt, sum, isfinite, - finfo, errstate, geterrobj, moveaxis, amin, amax, prod, abs, + finfo, errstate, moveaxis, amin, amax, prod, abs, atleast_2d, intp, asanyarray, object_, matmul, swapaxes, divide, count_nonzero, isnan, sign, argsort, sort, reciprocal @@ -94,20 +94,6 @@ class LinAlgError(ValueError): """ -def _determine_error_states(): - errobj = geterrobj() - bufsize = errobj[0] - - with errstate(invalid='call', over='ignore', - divide='ignore', under='ignore'): - invalid_call_errmask = geterrobj()[1] - - return [bufsize, invalid_call_errmask, None] - -# Dealing with errors in _umath_linalg -_linalg_error_extobj = _determine_error_states() -del _determine_error_states - def _raise_linalgerror_singular(err, flag): raise LinAlgError("Singular matrix") @@ -127,10 +113,6 @@ def _raise_linalgerror_qr(err, flag): raise LinAlgError("Incorrect argument found while performing " "QR factorization") -def get_linalg_error_extobj(callback): - extobj = list(_linalg_error_extobj) # make a copy - extobj[2] = callback - return extobj def _makearray(a): new = asarray(a) @@ -405,8 +387,9 @@ def solve(a, b): gufunc = _umath_linalg.solve signature = 'DD->D' if isComplexType(t) else 'dd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_singular) - r = gufunc(a, b, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_singular, invalid='call', + over='ignore', divide='ignore', under='ignore'): + r = gufunc(a, b, signature=signature) return wrap(r.astype(result_t, copy=False)) @@ -557,8 +540,9 @@ def inv(a): t, result_t = _commonType(a) signature = 'D->D' if isComplexType(t) else 'd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_singular) - ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_singular, invalid='call', + over='ignore', divide='ignore', under='ignore'): + ainv = _umath_linalg.inv(a, signature=signature) return wrap(ainv.astype(result_t, copy=False)) @@ -769,14 +753,15 @@ def cholesky(a): [ 0.+2.j, 1.+0.j]]) """ - extobj = get_linalg_error_extobj(_raise_linalgerror_nonposdef) gufunc = _umath_linalg.cholesky_lo a, wrap = _makearray(a) _assert_stacked_2d(a) _assert_stacked_square(a) t, result_t = _commonType(a) signature = 'D->D' if isComplexType(t) else 'd->d' - r = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_nonposdef, invalid='call', + over='ignore', divide='ignore', under='ignore'): + r = gufunc(a, signature=signature) return wrap(r.astype(result_t, copy=False)) @@ -948,8 +933,9 @@ def qr(a, mode='reduced'): gufunc = _umath_linalg.qr_r_raw_n signature = 'D->D' if isComplexType(t) else 'd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_qr) - tau = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_qr, invalid='call', + over='ignore', divide='ignore', under='ignore'): + tau = gufunc(a, signature=signature) # handle modes that don't return q if mode == 'r': @@ -979,8 +965,9 @@ def qr(a, mode='reduced'): gufunc = _umath_linalg.qr_reduced signature = 'DD->D' if isComplexType(t) else 'dd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_qr) - q = gufunc(a, tau, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_qr, invalid='call', + over='ignore', divide='ignore', under='ignore'): + q = gufunc(a, tau, signature=signature) r = triu(a[..., :mc, :]) q = q.astype(result_t, copy=False) @@ -1068,10 +1055,11 @@ def eigvals(a): _assert_finite(a) t, result_t = _commonType(a) - extobj = get_linalg_error_extobj( - _raise_linalgerror_eigenvalues_nonconvergence) signature = 'D->D' if isComplexType(t) else 'd->D' - w = _umath_linalg.eigvals(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w = _umath_linalg.eigvals(a, signature=signature) if not isComplexType(t): if all(w.imag == 0): @@ -1166,8 +1154,6 @@ def eigvalsh(a, UPLO='L'): if UPLO not in ('L', 'U'): raise ValueError("UPLO argument must be 'L' or 'U'") - extobj = get_linalg_error_extobj( - _raise_linalgerror_eigenvalues_nonconvergence) if UPLO == 'L': gufunc = _umath_linalg.eigvalsh_lo else: @@ -1178,7 +1164,10 @@ def eigvalsh(a, UPLO='L'): _assert_stacked_square(a) t, result_t = _commonType(a) signature = 'D->d' if isComplexType(t) else 'd->d' - w = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w = gufunc(a, signature=signature) return w.astype(_realType(result_t), copy=False) def _convertarray(a): @@ -1329,10 +1318,11 @@ def eig(a): _assert_finite(a) t, result_t = _commonType(a) - extobj = get_linalg_error_extobj( - _raise_linalgerror_eigenvalues_nonconvergence) signature = 'D->DD' if isComplexType(t) else 'd->DD' - w, vt = _umath_linalg.eig(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w, vt = _umath_linalg.eig(a, signature=signature) if not isComplexType(t) and all(w.imag == 0.0): w = w.real @@ -1476,15 +1466,16 @@ def eigh(a, UPLO='L'): _assert_stacked_square(a) t, result_t = _commonType(a) - extobj = get_linalg_error_extobj( - _raise_linalgerror_eigenvalues_nonconvergence) if UPLO == 'L': gufunc = _umath_linalg.eigh_lo else: gufunc = _umath_linalg.eigh_up signature = 'D->dD' if isComplexType(t) else 'd->dd' - w, vt = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w, vt = gufunc(a, signature=signature) w = w.astype(_realType(result_t), copy=False) vt = vt.astype(result_t, copy=False) return EighResult(w, wrap(vt)) @@ -1662,8 +1653,6 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): _assert_stacked_2d(a) t, result_t = _commonType(a) - extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence) - m, n = a.shape[-2:] if compute_uv: if full_matrices: @@ -1678,7 +1667,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): gufunc = _umath_linalg.svd_n_s signature = 'D->DdD' if isComplexType(t) else 'd->ddd' - u, s, vh = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_svd_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + u, s, vh = gufunc(a, signature=signature) u = u.astype(result_t, copy=False) s = s.astype(_realType(result_t), copy=False) vh = vh.astype(result_t, copy=False) @@ -1690,7 +1682,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): gufunc = _umath_linalg.svd_n signature = 'D->d' if isComplexType(t) else 'd->d' - s = gufunc(a, signature=signature, extobj=extobj) + with errstate(call=_raise_linalgerror_svd_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + s = gufunc(a, signature=signature) s = s.astype(_realType(result_t), copy=False) return s @@ -2319,11 +2314,13 @@ def lstsq(a, b, rcond="warn"): gufunc = _umath_linalg.lstsq_n signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid' - extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) if n_rhs == 0: # lapack can't handle n_rhs = 0 - so allocate the array one larger in that axis b = zeros(b.shape[:-2] + (m, n_rhs + 1), dtype=b.dtype) - x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj) + + with errstate(call=_raise_linalgerror_lstsq, invalid='call', + over='ignore', divide='ignore', under='ignore'): + x, resids, rank, s = gufunc(a, b, rcond, signature=signature) if m == 0: x[...] = 0 if n_rhs == 0: