Skip to content

Commit

Permalink
MAINT: Use errstate rather than extobj for linalg ufunc calls
Browse files Browse the repository at this point in the history
This may slow things down very slightly, but we should be able to
recover this easily and the `extobj=` API is a bit strange.
  • Loading branch information
seberg committed Jun 13, 2023
1 parent 3c47251 commit 2c58ab2
Showing 1 changed file with 44 additions and 47 deletions.
91 changes: 44 additions & 47 deletions numpy/linalg/linalg.py
Expand Up @@ -24,7 +24,7 @@
array, asarray, zeros, empty, empty_like, intc, single, double,
csingle, cdouble, inexact, complexfloating, newaxis, all, Inf, dot,
add, multiply, sqrt, sum, isfinite,
finfo, errstate, geterrobj, moveaxis, amin, amax, prod, abs,
finfo, errstate, moveaxis, amin, amax, prod, abs,
atleast_2d, intp, asanyarray, object_, matmul,
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
reciprocal
Expand Down Expand Up @@ -94,20 +94,6 @@ class LinAlgError(ValueError):
"""


def _determine_error_states():
errobj = geterrobj()
bufsize = errobj[0]

with errstate(invalid='call', over='ignore',
divide='ignore', under='ignore'):
invalid_call_errmask = geterrobj()[1]

return [bufsize, invalid_call_errmask, None]

# Dealing with errors in _umath_linalg
_linalg_error_extobj = _determine_error_states()
del _determine_error_states

def _raise_linalgerror_singular(err, flag):
raise LinAlgError("Singular matrix")

Expand All @@ -127,10 +113,6 @@ def _raise_linalgerror_qr(err, flag):
raise LinAlgError("Incorrect argument found while performing "
"QR factorization")

def get_linalg_error_extobj(callback):
extobj = list(_linalg_error_extobj) # make a copy
extobj[2] = callback
return extobj

def _makearray(a):
new = asarray(a)
Expand Down Expand Up @@ -405,8 +387,9 @@ def solve(a, b):
gufunc = _umath_linalg.solve

signature = 'DD->D' if isComplexType(t) else 'dd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
r = gufunc(a, b, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_singular, invalid='call',
over='ignore', divide='ignore', under='ignore'):
r = gufunc(a, b, signature=signature)

return wrap(r.astype(result_t, copy=False))

Expand Down Expand Up @@ -557,8 +540,9 @@ def inv(a):
t, result_t = _commonType(a)

signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_singular, invalid='call',
over='ignore', divide='ignore', under='ignore'):
ainv = _umath_linalg.inv(a, signature=signature)
return wrap(ainv.astype(result_t, copy=False))


Expand Down Expand Up @@ -769,14 +753,15 @@ def cholesky(a):
[ 0.+2.j, 1.+0.j]])
"""
extobj = get_linalg_error_extobj(_raise_linalgerror_nonposdef)
gufunc = _umath_linalg.cholesky_lo
a, wrap = _makearray(a)
_assert_stacked_2d(a)
_assert_stacked_square(a)
t, result_t = _commonType(a)
signature = 'D->D' if isComplexType(t) else 'd->d'
r = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_nonposdef, invalid='call',
over='ignore', divide='ignore', under='ignore'):
r = gufunc(a, signature=signature)
return wrap(r.astype(result_t, copy=False))


Expand Down Expand Up @@ -948,8 +933,9 @@ def qr(a, mode='reduced'):
gufunc = _umath_linalg.qr_r_raw_n

signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
tau = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_qr, invalid='call',
over='ignore', divide='ignore', under='ignore'):
tau = gufunc(a, signature=signature)

# handle modes that don't return q
if mode == 'r':
Expand Down Expand Up @@ -979,8 +965,9 @@ def qr(a, mode='reduced'):
gufunc = _umath_linalg.qr_reduced

signature = 'DD->D' if isComplexType(t) else 'dd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
q = gufunc(a, tau, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_qr, invalid='call',
over='ignore', divide='ignore', under='ignore'):
q = gufunc(a, tau, signature=signature)
r = triu(a[..., :mc, :])

q = q.astype(result_t, copy=False)
Expand Down Expand Up @@ -1068,10 +1055,11 @@ def eigvals(a):
_assert_finite(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
signature = 'D->D' if isComplexType(t) else 'd->D'
w = _umath_linalg.eigvals(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
w = _umath_linalg.eigvals(a, signature=signature)

if not isComplexType(t):
if all(w.imag == 0):
Expand Down Expand Up @@ -1166,8 +1154,6 @@ def eigvalsh(a, UPLO='L'):
if UPLO not in ('L', 'U'):
raise ValueError("UPLO argument must be 'L' or 'U'")

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
if UPLO == 'L':
gufunc = _umath_linalg.eigvalsh_lo
else:
Expand All @@ -1178,7 +1164,10 @@ def eigvalsh(a, UPLO='L'):
_assert_stacked_square(a)
t, result_t = _commonType(a)
signature = 'D->d' if isComplexType(t) else 'd->d'
w = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
w = gufunc(a, signature=signature)
return w.astype(_realType(result_t), copy=False)

def _convertarray(a):
Expand Down Expand Up @@ -1329,10 +1318,11 @@ def eig(a):
_assert_finite(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
signature = 'D->DD' if isComplexType(t) else 'd->DD'
w, vt = _umath_linalg.eig(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
w, vt = _umath_linalg.eig(a, signature=signature)

if not isComplexType(t) and all(w.imag == 0.0):
w = w.real
Expand Down Expand Up @@ -1476,15 +1466,16 @@ def eigh(a, UPLO='L'):
_assert_stacked_square(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
if UPLO == 'L':
gufunc = _umath_linalg.eigh_lo
else:
gufunc = _umath_linalg.eigh_up

signature = 'D->dD' if isComplexType(t) else 'd->dd'
w, vt = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_eigenvalues_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
w, vt = gufunc(a, signature=signature)
w = w.astype(_realType(result_t), copy=False)
vt = vt.astype(result_t, copy=False)
return EighResult(w, wrap(vt))
Expand Down Expand Up @@ -1662,8 +1653,6 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
_assert_stacked_2d(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(_raise_linalgerror_svd_nonconvergence)

m, n = a.shape[-2:]
if compute_uv:
if full_matrices:
Expand All @@ -1678,7 +1667,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
gufunc = _umath_linalg.svd_n_s

signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
u, s, vh = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_svd_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
u, s, vh = gufunc(a, signature=signature)
u = u.astype(result_t, copy=False)
s = s.astype(_realType(result_t), copy=False)
vh = vh.astype(result_t, copy=False)
Expand All @@ -1690,7 +1682,10 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
gufunc = _umath_linalg.svd_n

signature = 'D->d' if isComplexType(t) else 'd->d'
s = gufunc(a, signature=signature, extobj=extobj)
with errstate(call=_raise_linalgerror_svd_nonconvergence,
invalid='call', over='ignore', divide='ignore',
under='ignore'):
s = gufunc(a, signature=signature)
s = s.astype(_realType(result_t), copy=False)
return s

Expand Down Expand Up @@ -2319,11 +2314,13 @@ def lstsq(a, b, rcond="warn"):
gufunc = _umath_linalg.lstsq_n

signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid'
extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
if n_rhs == 0:
# lapack can't handle n_rhs = 0 - so allocate the array one larger in that axis
b = zeros(b.shape[:-2] + (m, n_rhs + 1), dtype=b.dtype)
x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)

with errstate(call=_raise_linalgerror_lstsq, invalid='call',
over='ignore', divide='ignore', under='ignore'):
x, resids, rank, s = gufunc(a, b, rcond, signature=signature)
if m == 0:
x[...] = 0
if n_rhs == 0:
Expand Down

0 comments on commit 2c58ab2

Please sign in to comment.