Skip to content

Commit

Permalink
Properly restrict the input dtypes for the array_api trace, svdvals, …
Browse files Browse the repository at this point in the history
…and vecdot
  • Loading branch information
asmeurer committed Mar 29, 2022
1 parent f306e94 commit f375d71
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions numpy/array_api/linalg.py
Expand Up @@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
if x.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in svdvals')
return Array._new(np.linalg.svd(x._array, compute_uv=False))

# Note: tensordot is the numpy top-level namespace but not in np.linalg
Expand All @@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
See its docstring for more information.
"""
if x.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in trace')
# Note: trace always operates on the last two axes, whereas np.trace
# operates on the first two axes by default
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))

# Note: vecdot is not in NumPy
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError('Only numeric dtypes are allowed in vecdot')
return tensordot(x1, x2, axes=((axis,), (axis,)))


Expand Down

0 comments on commit f375d71

Please sign in to comment.