-
-
Notifications
You must be signed in to change notification settings - Fork 9.9k
/
linalg.py
446 lines (358 loc) · 17.4 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
from __future__ import annotations
from ._dtypes import _floating_dtypes, _numeric_dtypes
from ._manipulation_functions import reshape
from ._array_object import Array
from ..core.numeric import normalize_axis_tuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ._typing import Literal, Optional, Sequence, Tuple, Union
from typing import NamedTuple
import numpy.linalg
import numpy as np
class EighResult(NamedTuple):
    """Result of :func:`eigh`, unpackable as ``(eigenvalues, eigenvectors)``."""

    # First output of np.linalg.eigh.
    eigenvalues: Array
    # Second output of np.linalg.eigh.
    eigenvectors: Array
class QRResult(NamedTuple):
    """Result of :func:`qr`, unpackable as ``(Q, R)``."""

    # First factor returned by np.linalg.qr.
    Q: Array
    # Second factor returned by np.linalg.qr.
    R: Array
class SlogdetResult(NamedTuple):
    """Result of :func:`slogdet`, unpackable as ``(sign, logabsdet)``."""

    # First output of np.linalg.slogdet.
    sign: Array
    # Second output of np.linalg.slogdet.
    logabsdet: Array
class SVDResult(NamedTuple):
    """Result of :func:`svd`, unpackable as ``(U, S, Vh)``."""

    # First output of np.linalg.svd.
    U: Array
    # Second output of np.linalg.svd.
    S: Array
    # Third output of np.linalg.svd.
    Vh: Array
# Note: np.linalg.cholesky has no ``upper`` keyword; this wrapper adds one.
def cholesky(x: Array, /, *, upper: bool = False) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.

    See its docstring for more information.

    When ``upper`` is true, the matrix transpose (``.mT``) of the factor
    computed by np.linalg.cholesky is returned instead of the factor itself.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.cholesky is not limited to floating-point dtypes; the
    # array API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in cholesky')
    lower = Array._new(np.linalg.cholesky(x._array))
    return lower.mT if upper else lower
# Note: cross lives in the top-level numpy namespace (np.cross), not in
# np.linalg.
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.

    See its docstring for more information.

    Stricter than np.cross: both inputs must be numeric, have identical
    shapes (no broadcasting), be at least 1-D, and have size 3 along ``axis``.

    Raises TypeError for non-numeric dtypes and ValueError for the shape
    violations above.
    """
    if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
        raise TypeError('Only numeric dtypes are allowed in cross')
    # Note: np.cross broadcasts its inputs; here the shapes must match exactly.
    if x1.shape != x2.shape:
        raise ValueError('x1 and x2 must have the same shape')
    if not x1.ndim:
        raise ValueError('cross() requires arrays of dimension at least 1')
    # Note: np.cross also allows a dimension of 2; only 3 is accepted here.
    if x1.shape[axis] != 3:
        raise ValueError('cross() dimension must equal 3')
    product = np.cross(x1._array, x2._array, axis=axis)
    return Array._new(product)
def det(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.

    See its docstring for more information.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.det has no floating-point-only restriction; the array
    # API version adds one.
    if x.dtype in _floating_dtypes:
        return Array._new(np.linalg.det(x._array))
    raise TypeError('Only floating-point dtypes are allowed in det')
# Note: diagonal lives in the top-level numpy namespace (np.diagonal), not in
# np.linalg.
def diagonal(x: Array, /, *, offset: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.

    See its docstring for more information.

    The diagonal selected by ``offset`` is always taken over the last two
    axes of ``x``.
    """
    # Note: np.diagonal uses the first two axes by default, so the last two
    # are passed explicitly.
    diag = np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)
    return Array._new(diag)
def eigh(x: Array, /) -> EighResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.

    See its docstring for more information.

    Returns an :class:`EighResult` named tuple ``(eigenvalues, eigenvectors)``.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.eigh is not limited to floating-point dtypes; the array
    # API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in eigh')
    # Note: np.linalg.eigh's outputs are re-wrapped into a named tuple here.
    outputs = np.linalg.eigh(x._array)
    return EighResult(*(Array._new(out) for out in outputs))
def eigvalsh(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.

    See its docstring for more information.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.eigvalsh has no floating-point-only restriction; the
    # array API version adds one.
    if x.dtype in _floating_dtypes:
        eigenvalues = np.linalg.eigvalsh(x._array)
        return Array._new(eigenvalues)
    raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
def inv(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.

    See its docstring for more information.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.inv has no floating-point-only restriction; the array
    # API version adds one.
    if x.dtype in _floating_dtypes:
        inverse = np.linalg.inv(x._array)
        return Array._new(inverse)
    raise TypeError('Only floating-point dtypes are allowed in inv')
# Note: matmul lives in the top-level numpy namespace (np.matmul), not in
# np.linalg.
def matmul(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.

    See its docstring for more information.

    Raises TypeError unless both inputs have numeric dtypes.
    """
    # Note: np.matmul itself does not restrict inputs to numeric dtypes.
    both_numeric = x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes
    if not both_numeric:
        raise TypeError('Only numeric dtypes are allowed in matmul')
    return Array._new(np.matmul(x1._array, x2._array))
# Note: the array API splits np.linalg.norm into matrix_norm() and
# vector_norm(); this is the matrix half.
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
# literals.
def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.

    See its docstring for more information.

    The norm is always computed over the last two axes of ``x``.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.norm is not limited to floating-point dtypes; the array
    # API version is.
    if x.dtype in _floating_dtypes:
        # A 2-tuple axis makes np.linalg.norm compute a matrix norm.
        last_two_axes = (-2, -1)
        norms = np.linalg.norm(x._array, axis=last_two_axes, keepdims=keepdims, ord=ord)
        return Array._new(norms)
    raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
def matrix_power(x: Array, n: int, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.matrix_power <numpy.linalg.matrix_power>`.

    See its docstring for more information.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.matrix_power.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')

    # np.linalg.matrix_power already checks if n is an integer, so n is
    # passed through unvalidated here.
    return Array._new(np.linalg.matrix_power(x._array, n))
# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.matrix_rank <numpy.linalg.matrix_rank>`.

    See its docstring for more information.

    The rank of each matrix in ``x`` is the number of its singular values
    ``S`` that exceed a tolerance:

    * ``rtol=None``: ``S.max() * max(M, N) * eps`` (``eps`` of the dtype of ``S``);
    * otherwise: ``S.max() * rtol``, where ``rtol`` may be a Python float or
      an Array whose values are combined per matrix by broadcasting.

    Raises np.linalg.LinAlgError if ``x`` has fewer than two dimensions.
    """
    # Note: this is different from np.linalg.matrix_rank, which supports 1
    # dimensional arrays.
    if x.ndim < 2:
        raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
    S = np.linalg.svd(x._array, compute_uv=False)
    if rtol is None:
        tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
    else:
        if isinstance(rtol, Array):
            rtol = rtol._array
        # Note: this is different from np.linalg.matrix_rank, which does not multiply
        # the tolerance by the largest singular value.
        # The trailing newaxis lines rtol up with the keepdims=True max.
        tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
    return Array._new(np.count_nonzero(S > tol, axis=-1))
# Note: matrix_transpose is new in the array API spec. Unlike transpose, it
# only swaps the last two axes.
def matrix_transpose(x: Array, /) -> Array:
    """
    Transpose each matrix in ``x`` by swapping its last two axes.

    Raises ValueError if ``x`` has fewer than two dimensions.
    """
    if x.ndim >= 2:
        return Array._new(np.swapaxes(x._array, -2, -1))
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")
# Note: outer lives in the top-level numpy namespace (np.outer), not in
# np.linalg.
def outer(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.

    See its docstring for more information.

    Raises TypeError unless both inputs are numeric, and ValueError unless
    both inputs are exactly 1-dimensional.
    """
    # Note: np.outer does not restrict its inputs to numeric dtypes.
    if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
        raise TypeError('Only numeric dtypes are allowed in outer')
    # Note: np.outer does not require 1-D inputs; the array API does.
    if (x1.ndim, x2.ndim) != (1, 1):
        raise ValueError('The input arrays to outer must be 1-dimensional')
    product = np.outer(x1._array, x2._array)
    return Array._new(product)
# Note: the keyword argument name rtol is different from np.linalg.pinv
def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.

    See its docstring for more information.

    ``rtol`` is forwarded to np.linalg.pinv as ``rcond``. When it is None, it
    defaults to ``max(M, N) * eps`` for the dtype of ``x``. When it is an
    Array, its underlying NumPy array is passed, matching how
    ``matrix_rank`` treats ``rtol``.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.pinv.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in pinv')

    # Note: this is different from np.linalg.pinv, which does not multiply the
    # default tolerance by max(M, N).
    if rtol is None:
        rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
    elif isinstance(rtol, Array):
        # Unwrap so np.linalg.pinv receives a plain ndarray for rcond rather
        # than an array API wrapper object.
        rtol = rtol._array
    return Array._new(np.linalg.pinv(x._array, rcond=rtol))
def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.

    See its docstring for more information.

    Returns a :class:`QRResult` named tuple ``(Q, R)``.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.qr is not limited to floating-point dtypes; the array
    # API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in qr')
    # Note: np.linalg.qr's outputs are re-wrapped into a named tuple here.
    factors = np.linalg.qr(x._array, mode=mode)
    return QRResult(*(Array._new(factor) for factor in factors))
def slogdet(x: Array, /) -> SlogdetResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.

    See its docstring for more information.

    Returns a :class:`SlogdetResult` named tuple ``(sign, logabsdet)``.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.slogdet is not limited to floating-point dtypes; the
    # array API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in slogdet')
    # Note: np.linalg.slogdet's (sign, logabsdet) outputs are re-wrapped into
    # a named tuple here.
    parts = np.linalg.slogdet(x._array)
    return SlogdetResult(*(Array._new(part) for part in parts))
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous, cf.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To work around this, the code below is the code from np.linalg.solve except
# that it only calls solve1 in the exactly 1-D case. It is deliberately kept
# line-for-line close to np.linalg.solve so the two are easy to compare.
def _solve(a, b):
    """
    Solve ``a @ result = b`` on plain NumPy arrays using np.linalg internals.

    ``b`` is treated as a single vector only when ``b.ndim == 1``. Any other
    ``b`` is treated as a (stack of) matrices. The result is converted to the
    common result type of ``a`` and ``b`` and wrapped with ``b``'s wrapper
    from ``_makearray``.

    The private np.linalg helpers used below validate that ``a`` is a stack
    of square 2-D matrices and install an error handler for singular
    matrices (presumably raising np.linalg.LinAlgError; not visible here).
    """
    # Private helpers are imported lazily from numpy's own linalg module.
    from ..linalg.linalg import (_makearray, _assert_stacked_2d,
                                 _assert_stacked_square, _commonType,
                                 isComplexType, get_linalg_error_extobj,
                                 _raise_linalgerror_singular)
    from ..linalg import _umath_linalg

    # a's wrapper is discarded; the result is wrapped using b's wrapper.
    a, _ = _makearray(a)
    _assert_stacked_2d(a)
    _assert_stacked_square(a)
    b, wrap = _makearray(b)
    # t is the computation type, result_t the dtype of the returned array.
    t, result_t = _commonType(a, b)

    # This part is different from np.linalg.solve: only an exactly 1-D b
    # uses the matrix-vector gufunc.
    if b.ndim == 1:
        gufunc = _umath_linalg.solve1
    else:
        gufunc = _umath_linalg.solve

    # This does nothing currently but is left in because it will be relevant
    # when complex dtype support is added to the spec in 2022.
    signature = 'DD->D' if isComplexType(t) else 'dd->d'
    extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
    r = gufunc(a, b, signature=signature, extobj=extobj)

    return wrap(r.astype(result_t, copy=False))
def solve(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.

    See its docstring for more information.

    Unlike np.linalg.solve, ``x2`` is treated as a vector only when it is
    exactly 1-dimensional. Otherwise it is treated as a stack of matrices.

    Raises TypeError unless both inputs have floating-point dtypes.
    """
    # Note: np.linalg.solve is not limited to floating-point dtypes; the
    # array API version is.
    if not (x1.dtype in _floating_dtypes and x2.dtype in _floating_dtypes):
        raise TypeError('Only floating-point dtypes are allowed in solve')
    solution = _solve(x1._array, x2._array)
    return Array._new(solution)
def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.

    See its docstring for more information.

    Returns an :class:`SVDResult` named tuple ``(U, S, Vh)``.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.svd is not limited to floating-point dtypes; the array
    # API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svd')
    # Note: np.linalg.svd's outputs are re-wrapped into a named tuple here.
    factors = np.linalg.svd(x._array, full_matrices=full_matrices)
    return SVDResult(*(Array._new(factor) for factor in factors))
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# np.linalg.svd(compute_uv=False).
def svdvals(x: Array, /) -> Array:
    """
    Return the singular values of ``x`` (of each matrix, for a stack).

    Equivalent to ``np.linalg.svd(x, compute_uv=False)``; see that
    function's docstring for more information.

    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: the restriction to floating-point dtypes only is different from
    # np.linalg.svd.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svdvals')
    # With compute_uv=False np.linalg.svd yields only the singular values,
    # so the result is always a single Array (never a tuple).
    return Array._new(np.linalg.svd(x._array, compute_uv=False))
# Note: tensordot lives in the top-level numpy namespace (np.tensordot), not
# in np.linalg.
# Note: per the array API, axes must be an int or a tuple, unlike
# np.tensordot where it can be an array or array-like.
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.tensordot <numpy.tensordot>`.

    See its docstring for more information.

    ``axes`` is forwarded to np.tensordot unchanged; this wrapper does not
    itself verify that it is a tuple. Raises TypeError unless both inputs
    have numeric dtypes.
    """
    # Note: np.tensordot does not restrict its inputs to numeric dtypes.
    if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
        raise TypeError('Only numeric dtypes are allowed in tensordot')
    contracted = np.tensordot(x1._array, x2._array, axes=axes)
    return Array._new(contracted)
# Note: trace lives in the top-level numpy namespace (np.trace), not in
# np.linalg.
def trace(x: Array, /, *, offset: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.

    See its docstring for more information.

    The sum along the diagonal selected by ``offset`` is always taken over
    the last two axes. Raises TypeError if ``x`` is not numeric.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in trace')
    # Note: np.trace uses the first two axes by default, so the last two are
    # passed explicitly.
    total = np.trace(x._array, offset=offset, axis1=-2, axis2=-1)
    # np.asarray ensures Array._new receives an ndarray (presumably np.trace
    # returns a NumPy scalar for a single 2-D input).
    return Array._new(np.asarray(total))
# Note: vecdot has no direct NumPy counterpart.
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """
    Dot product of ``x1`` and ``x2`` along ``axis``, broadcasting the rest.

    ``axis`` is interpreted against the broadcast shape (the shorter shape
    left-padded with ones). Raises TypeError unless both inputs are numeric,
    and ValueError if the sizes along ``axis`` differ.
    """
    if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
        raise TypeError('Only numeric dtypes are allowed in vecdot')

    # Left-pad the shorter shape with ones, as broadcasting would, so that
    # ``axis`` names the same dimension in both inputs.
    ndim = max(x1.ndim, x2.ndim)
    padded1 = (1,) * (ndim - x1.ndim) + tuple(x1.shape)
    padded2 = (1,) * (ndim - x2.ndim) + tuple(x2.shape)
    if padded1[axis] != padded2[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    lhs, rhs = (np.moveaxis(arr, axis, -1)
                for arr in np.broadcast_arrays(x1._array, x2._array))
    # A batched (1, n) @ (n, 1) matmul gives each dot product as a 1x1 matrix.
    products = lhs[..., np.newaxis, :] @ rhs[..., np.newaxis]
    return Array._new(products[..., 0, 0])
# Note: the array API splits np.linalg.norm into matrix_norm() and
# vector_norm(); this is the vector half.
# The type for ord should be Optional[Union[int, float, Literal[np.inf,
# -np.inf]]] but Literal does not support floating-point literals.
def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.

    See its docstring for more information.

    A vector norm is always computed, even where np.linalg.norm would
    compute a matrix norm:

    * ``axis=None`` flattens ``x`` first;
    * a tuple ``axis`` (any number of axes) is handled by moving those axes
      to the front and collapsing them into a single leading axis.

    ``keepdims`` is implemented by reshaping the result, because the
    reshaping above rules out passing it through to np.linalg.norm.
    Raises TypeError if ``x`` does not have a floating-point dtype.
    """
    # Note: np.linalg.norm is not limited to floating-point dtypes; the array
    # API version is.
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in norm')

    data = x._array
    if axis is None:
        # Flattening also covers 0-D inputs.
        data = data.ravel()
        norm_axis = 0
    elif isinstance(axis, tuple):
        normalized = normalize_axis_tuple(axis, x.ndim)
        kept = tuple(i for i in range(data.ndim) if i not in normalized)
        collapsed = np.prod([data.shape[i] for i in axis], dtype=int)
        remaining = [data.shape[i] for i in kept]
        data = np.transpose(data, axis + kept).reshape((collapsed, *remaining))
        norm_axis = 0
    else:
        norm_axis = axis

    res = Array._new(np.linalg.norm(data, axis=norm_axis, ord=ord))
    if not keepdims:
        return res

    # Rebuild the input shape with every reduced axis set to length 1.
    shape = list(x.shape)
    reduced = range(x.ndim) if axis is None else axis
    for i in normalize_axis_tuple(reduced, x.ndim):
        shape[i] = 1
    return reshape(res, tuple(shape))
# Public names exported by this module (result named tuples stay private).
__all__ = [
    'cholesky',
    'cross',
    'det',
    'diagonal',
    'eigh',
    'eigvalsh',
    'inv',
    'matmul',
    'matrix_norm',
    'matrix_power',
    'matrix_rank',
    'matrix_transpose',
    'outer',
    'pinv',
    'qr',
    'slogdet',
    'solve',
    'svd',
    'svdvals',
    'tensordot',
    'trace',
    'vecdot',
    'vector_norm',
]