-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
linalg.py
2362 lines (1988 loc) · 88.6 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, Literal, Optional, TypeVar, Union, overload
import warnings
import numpy as np
import jax
from jax import lax
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src import dtypes
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import control_flow
from jax._src.lax import eigh as lax_eigh
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike
# Short module-level alias for ``xla_client.ops``.
xops = xla_client.ops
TFun = TypeVar('TFun', bound=Callable[..., Any])

# traceables

# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
def _warn_on_positional_kwargs(f: TFun) -> TFun:
  """Decorator used for backward compatibility of keyword-only arguments.

  Some functions were changed to mark their keyword arguments as keyword-only.
  This decorator allows existing code to keep working temporarily, while issuing
  a warning if a now keyword-only parameter is passed positionally.

  Args:
    f: The function to wrap. Every parameter of ``f`` must be either
      positional-or-keyword or keyword-only.

  Returns:
    A wrapper carrying the metadata of ``f``. Surplus positional arguments are
    mapped, in declaration order, onto the keyword-only parameters of ``f``
    and a ``DeprecationWarning`` is issued for each of them.

  The wrapper raises ``TypeError`` if a required positional argument is
  missing, if too many arguments are given, or if a keyword-only parameter is
  supplied both positionally and by keyword.
  """
  sig = inspect.signature(f)
  params = sig.parameters
  pos_names = [name for name, p in params.items()
               if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
  kwarg_names = [name for name, p in params.items()
                 if p.kind == inspect.Parameter.KEYWORD_ONLY]
  # This decorator assumes that all arguments to `f` are either
  # positional-or-keyword or keyword-only.
  assert len(pos_names) + len(kwarg_names) == len(params)

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    # A positional-or-keyword parameter that was not passed positionally is
    # only "missing" if it was not passed by keyword either and has no default.
    for name in pos_names[len(args):]:
      if (name not in kwargs and
          params[name].default is inspect.Parameter.empty):
        raise TypeError(
            f"{f.__name__} missing required positional argument: {name}")
    pos_args = args[:len(pos_names)]
    extra_kwargs = args[len(pos_names):]
    if len(extra_kwargs) > len(kwarg_names):
      raise TypeError(f"{f.__name__} takes at most {len(params)} "
                      f"arguments but {len(args)} were given.")
    for name, value in zip(kwarg_names, extra_kwargs):
      if name in kwargs:
        raise TypeError(f"{f.__name__} got multiple values for argument: "
                        f"{name}")
      # stacklevel=2 attributes the warning to the caller of the wrapped
      # function rather than to this shim.
      warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
                    "argument. Support for passing it positionally will be "
                    "removed in an upcoming JAX release.",
                    DeprecationWarning, stacklevel=2)
      kwargs[name] = value
    return f(*pos_args, **kwargs)

  return cast(TFun, wrapped)
@_warn_on_positional_kwargs
def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array:
  r"""Cholesky decomposition.

  Computes, for every matrix :math:`A` in the batch, the factorization

  .. math::
    A = L . L^H

  where :math:`L` is lower triangular. Each :math:`A` must be square,
  positive-definite, and Hermitian (complex case) or symmetric (real case).

  Args:
    x: A batch of square Hermitian (symmetric if real) positive-definite
      matrices with shape ``[..., n, n]``.
    symmetrize_input: If ``True``, ``x`` is first replaced by
      :math:`\frac{1}{2}(x + x^H)`. If ``False``, only the lower triangle of
      ``x`` is used; the upper triangle is ignored and not accessed.

  Returns:
    The Cholesky factor, with the same dtype as ``x`` and shape
    ``[..., n, n]``. If the decomposition fails, the result is a matrix full
    of NaNs. The behavior on failure may change in the future.
  """
  matrix = symmetrize(x) if symmetrize_input else x
  factor = cholesky_p.bind(matrix)
  # Keep only the lower triangle of the primitive's output.
  return jnp.tril(factor)
@_warn_on_positional_kwargs
def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
        compute_right_eigenvectors: bool = True) -> list[Array]:
  """Eigendecomposition of a general matrix.

  Nonsymmetric eigendecomposition is at present only implemented on CPU.

  Args:
    x: A batch of square matrices with shape ``[..., n, n]``.
    compute_left_eigenvectors: Whether the left eigenvectors are returned.
    compute_right_eigenvectors: Whether the right eigenvectors are returned.

  Returns:
    The outputs of the ``eig`` primitive: the eigenvalues, followed by the
    left eigenvectors (if requested), followed by the right eigenvectors (if
    requested).
  """
  options = dict(compute_left_eigenvectors=compute_left_eigenvectors,
                 compute_right_eigenvectors=compute_right_eigenvectors)
  return eig_p.bind(x, **options)
@_warn_on_positional_kwargs
def eigh(x: Array, *, lower: bool = True, symmetrize_input: bool = True,
         sort_eigenvalues: bool = True) -> tuple[Array, Array]:
  r"""Eigendecomposition of a Hermitian matrix.

  Computes the eigenvectors and eigenvalues of a complex Hermitian or real
  symmetric square matrix.

  Args:
    x: A batch of square complex Hermitian or real symmetric matrices with
      shape ``[..., n, n]``.
    lower: Only relevant when ``symmetrize_input`` is ``False``: selects the
      triangle of the input matrix that is used. Only that triangle is
      accessed; the other triangle is ignored.
    symmetrize_input: If ``True``, the matrix is symmetrized before the
      eigendecomposition by computing :math:`\frac{1}{2}(x + x^H)`.
    sort_eigenvalues: If ``True``, the eigenvalues will be sorted in ascending
      order. If ``False`` the eigenvalues are returned in an
      implementation-defined order.

  Returns:
    A tuple ``(v, w)``.

    ``v`` is an array with the same dtype as ``x`` such that ``v[..., :, i]``
    is the normalized eigenvector corresponding to eigenvalue ``w[..., i]``.

    ``w`` is an array with the same dtype as ``x`` (or its real counterpart if
    complex) with shape ``[..., n]`` containing the eigenvalues of ``x``, each
    repeated according to its multiplicity (in ascending order when
    ``sort_eigenvalues`` is ``True``).
  """
  matrix = symmetrize(x) if symmetrize_input else x
  eigenvectors, eigenvalues = eigh_p.bind(
      matrix, lower=lower, sort_eigenvalues=sort_eigenvalues)
  return eigenvectors, eigenvalues
def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array:
  """Converts the pivots (row swaps) returned by LU to a permutation.

  A permutation is built, instead of applying `pivots` directly to the rows of
  a matrix, because lax loops aren't differentiable.

  Args:
    pivots: an int32 array of shape (..., k) of row swaps to perform
    permutation_size: the size of the output permutation. Has to be >= k.

  Returns:
    An int32 array of shape (..., permutation_size).
  """
  # The primitive takes the size as a plain Python int parameter.
  size = int(permutation_size)
  return lu_pivots_to_permutation_p.bind(pivots, permutation_size=size)
def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
  """LU decomposition with partial pivoting.

  Computes the matrix decomposition:

  .. math::
    P.A = L.U

  where :math:`P` is a permutation of the rows of :math:`A`, :math:`L` is a
  lower-triangular matrix with unit-diagonal elements, and :math:`U` is an
  upper-triangular matrix.

  Args:
    x: A batch of matrices with shape ``[..., m, n]``.

  Returns:
    A tuple ``(lu, pivots, permutation)``.

    ``lu`` is a batch of matrices with the same shape and dtype as ``x``. Its
    lower triangle holds :math:`L` and its upper triangle holds :math:`U`; the
    (unit) diagonal elements of :math:`L` are not represented explicitly.

    ``pivots`` is an int32 array with shape ``[..., min(m, n)]`` representing a
    sequence of row swaps that should be performed on :math:`A`.

    ``permutation`` is an alternative representation of the sequence of row
    swaps as a permutation, represented as an int32 array with shape
    ``[..., m]``.
  """
  outputs = lu_p.bind(x)
  packed_factors, pivots, permutation = outputs
  return packed_factors, pivots, permutation
@_warn_on_positional_kwargs
def qr(x: ArrayLike, *, full_matrices: bool = True) -> tuple[Array, Array]:
  """QR decomposition.

  Computes the QR decomposition

  .. math::
    A = Q . R

  of matrices :math:`A`, such that :math:`Q` is a unitary (orthogonal) matrix,
  and :math:`R` is an upper-triangular matrix.

  Args:
    x: A batch of matrices with shape ``[..., m, n]``.
    full_matrices: Determines if full or reduced matrices are returned; see
      below.

  Returns:
    A pair of arrays ``(q, r)``.

    Array ``q`` is a unitary (orthogonal) matrix, with shape ``[..., m, m]`` if
    ``full_matrices=True``, or ``[..., m, min(m, n)]`` if
    ``full_matrices=False``.

    Array ``r`` is an upper-triangular matrix with shape ``[..., m, n]`` if
    ``full_matrices=True``, or ``[..., min(m, n), n]`` if
    ``full_matrices=False``.
  """
  factors = qr_p.bind(x, full_matrices=full_matrices)
  orthogonal, upper = factors
  return orthogonal, upper
@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[True]) -> tuple[Array, Array, Array]: ...

@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: Literal[False]) -> Array: ...

@overload
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]: ...

# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
@_warn_on_positional_kwargs
def svd(x: ArrayLike, *, full_matrices: bool = True, compute_uv: bool = True) -> Union[Array, tuple[Array, Array, Array]]:
  """Singular value decomposition.

  Args:
    x: The input passed to the ``svd`` primitive.
    full_matrices: Forwarded unchanged to the ``svd`` primitive.
    compute_uv: Whether the singular vectors are returned in addition to the
      singular values.

  Returns:
    The singular values alone if ``compute_uv`` is ``False``. Otherwise a
    triple containing the left singular vectors, the singular values and the
    adjoint of the right singular vectors.
  """
  outputs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv)
  if not compute_uv:
    (singular_values,) = outputs
    return singular_values
  # The primitive yields the singular values first; the public API puts the
  # left singular vectors first.
  singular_values, left, right_adjoint = outputs
  return left, singular_values, right_adjoint
@_warn_on_positional_kwargs
def triangular_solve(a: ArrayLike, b: ArrayLike, *,
                     left_side: bool = False, lower: bool = False,
                     transpose_a: bool = False, conjugate_a: bool = False,
                     unit_diagonal: bool = False) -> Array:
  r"""Triangular solve.

  Solves either the matrix equation

  .. math::
    \mathit{op}(A) . X = B

  if ``left_side`` is ``True`` or

  .. math::
    X . \mathit{op}(A) = B

  if ``left_side`` is ``False``.

  ``A`` must be a lower or upper triangular square matrix, and
  :math:`\mathit{op}(A)` may transpose :math:`A` (if ``transpose_a`` is
  ``True``) and/or take its complex conjugate (if ``conjugate_a`` is ``True``).

  Args:
    a: A batch of matrices with shape ``[..., m, m]``.
    b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
      ``True`` or shape ``[..., n, m]`` otherwise. An array with one dimension
      fewer than ``a`` is treated as a batch of vectors.
    left_side: describes which of the two matrix equations to solve; see above.
    lower: describes which triangle of ``a`` should be used. The other triangle
      is ignored.
    transpose_a: if ``True``, the value of ``a`` is transposed.
    conjugate_a: if ``True``, the complex conjugate of ``a`` is used in the
      solve. Has no effect if ``a`` is real.
    unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be unit
      (all 1s) and not accessed.

  Returns:
    A batch of matrices the same shape and dtype as ``b``.
  """
  if conjugate_a:
    # Conjugation is only meaningful for complex inputs.
    conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)

  # When ``b`` has one dimension fewer than ``a``, temporarily add a unit
  # dimension so the primitive sees a matrix, and strip it again afterwards.
  b_is_vector = jnp.ndim(b) == jnp.ndim(a) - 1
  if b_is_vector:
    unit_axis = -1 if left_side else -2
    b = jnp.expand_dims(b, unit_axis)

  solution = triangular_solve_p.bind(
      a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)

  if not b_is_vector:
    return solution
  return solution[..., 0] if left_side else solution[..., 0, :]
# utilities
@partial(vectorize, signature='(n,m),(m)->(n)')
def _matvec_multiply(a: Array, b: Array) -> Array:
  """Matrix-vector product ``a @ b`` computed at the highest precision.

  The ``vectorize`` decorator applies this core ``(n,m),(m)->(n)`` operation
  across any additional leading dimensions of the arguments.
  """
  highest = lax.Precision.HIGHEST
  return lax.dot(a, b, precision=highest)
def _check_solve_shapes(a: Array, b: Array):
  """Validates the operand shapes of a linear solve.

  Args:
    a: Expected to have shape ``[..., m, m]``.
    b: Expected to have shape ``[..., m, k]`` or ``[..., m]``.

  Raises:
    ValueError: if the shapes do not match the pattern above.
  """
  # The checks short-circuit in this order so that the shape indexing below is
  # only performed once the ranks are known to be compatible.
  shapes_ok = (
      a.ndim >= 2
      and b.ndim in (a.ndim, a.ndim - 1)
      and a.shape[-1] == a.shape[-2] == b.shape[a.ndim - 2]
  )
  if shapes_ok:
    return
  raise ValueError(
      "The arguments to solve must have shapes a=[..., m, m] and "
      f"b=[..., m, k] or b=[..., m]; got a={a.shape} and b={b.shape}")
def _solve(a: Array, b: Array) -> Array:
  """Solves ``a @ x = b`` through an LU factorization of ``a``.

  Args:
    a: Coefficient matrices with shape ``[..., m, m]``.
    b: Right-hand sides with shape ``[..., m, k]`` or ``[..., m]``.

  Returns:
    The result of ``lax.custom_linear_solve`` applied to ``b`` (mapped over the
    trailing axis when ``b`` is a batch of matrices).

  Raises:
    ValueError: if the shapes are rejected by ``_check_solve_shapes``.
  """
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  a_reference = a.shape[:-1] + (1,)
  broadcast_shape = tuple(dim_a if dim_b == 1 else dim_b
                          for dim_a, dim_b in zip(a_reference, b.shape))
  b = jnp.broadcast_to(b, broadcast_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))

  def matvec(x):
    return _matvec_multiply(a, x)

  def solve(_, x):
    return lu_solve(lu_, permutation, x, trans=0)

  def transpose_solve(_, x):
    return lu_solve(lu_, permutation, x, trans=1)

  custom_solve = partial(lax.custom_linear_solve, matvec, solve=solve,
                         transpose_solve=transpose_solve)

  b_is_vector = a.ndim == b.ndim + 1
  if b_is_vector:
    # b.shape == [..., m]
    return custom_solve(b)
  # b.shape == [..., m, k]: solve column by column.
  return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def _T(x: Array) -> Array:
  """Transposes the last two axes of ``x``."""
  return jnp.swapaxes(x, -1, -2)


def _H(x: Array) -> Array:
  """Conjugate transpose of ``x`` over its last two axes."""
  transposed = _T(x)
  return ufuncs.conj(transposed)


def symmetrize(x: Array) -> Array:
  """Returns ``(x + x^H) / 2`` computed over the last two axes of ``x``."""
  doubled = x + _H(x)
  return doubled / 2
# primitives
# Single/double precision real and complex NumPy dtypes.
# NOTE(review): judging by the name, these are the dtypes handled by the CPU
# LAPACK kernels; the users of this set are outside this section -- confirm.
_cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),
                     np.dtype(np.complex64), np.dtype(np.complex128)}
# Cholesky decomposition
def _cholesky_jvp_rule(primals, tangents):
  """Forward-mode differentiation rule for ``cholesky_p``.

  Args:
    primals: A 1-tuple holding the input matrix.
    tangents: A 1-tuple holding the tangent of the input matrix.

  Returns:
    The pair ``(L, L_dot)``: the lower-triangular Cholesky factor and its
    tangent.
  """
  (x,) = primals
  (sigma_dot,) = tangents
  L = jnp.tril(cholesky_p.bind(x))

  # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf
  def phi(X):
    # Lower triangle of X with its diagonal halved: off-diagonal entries are
    # divided by 1 and diagonal entries by 2.
    lower_part = jnp.tril(X)
    divisor = lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype)
    return lower_part / lax.expand_dims(divisor, range(lower_part.ndim - 2))

  # Solve with L on the right first, then with L on the left.
  right_solved = triangular_solve(L, sigma_dot, left_side=False,
                                  transpose_a=True, conjugate_a=True,
                                  lower=True)
  both_solved = triangular_solve(L, right_solved, left_side=True,
                                 transpose_a=False, lower=True)
  L_dot = lax.batch_matmul(L, phi(both_solved),
                           precision=lax.Precision.HIGHEST)
  return L, L_dot
def _cholesky_batching_rule(batched_args, batch_dims):
  """Batching rule for ``cholesky_p``.

  Moves the batch axis to the front and calls ``cholesky`` on the result; the
  output is batched along axis 0.
  """
  (x,) = batched_args
  (bd,) = batch_dims
  batch_first = batching.moveaxis(x, bd, 0)
  return cholesky(batch_first), 0
# The Cholesky primitive: a standard unary op accepting float or complex
# dtypes, with its JVP and batching rules registered below.
cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
def _cholesky_lowering(ctx, x):
  """Default lowering of ``cholesky_p`` to an HLO ``CholeskyOp`` (lower=True)."""
  lower_attr = ir.BoolAttr.get(True)
  cholesky_op = hlo.CholeskyOp(x, lower=lower_attr)
  return cholesky_op.results

mlir.register_lowering(cholesky_p, _cholesky_lowering)
def _cholesky_cpu_lowering(ctx, operand):
  """CPU lowering of ``cholesky_p`` via ``lapack.potrf_hlo``.

  The factorization is requested with ``lower=True``. The per-batch-element
  ``info`` status returned by LAPACK is compared against zero, and the result
  is selected between the LAPACK output and the value produced by
  ``_nan_like_hlo`` according to that comparison.
  """
  operand_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  batch_dims = operand_aval.shape[:-2]
  if jaxlib_version < (0, 4, 13):
    # Older jaxlib: the potrf wrapper takes no dynamic shape values, so only
    # constant shapes are supported.
    if not is_constant_shape(operand_aval.shape):
      raise NotImplementedError(
        "Shape polymorphism for native serialization for cholesky on CPU is "
        f"not implemented; b/261671778; {operand_aval.shape}")
    result, info = lapack.potrf_hlo(operand_aval.dtype, operand, lower=True) # type: ignore
  else:
    op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
    result, info = lapack.potrf_hlo(operand_aval.dtype, operand, lower=True,
                                    a_shape_vals=op_shape_vals)

  # ``ok`` has the batch shape and is true where ``info == 0``.
  ok = mlir.compare_hlo(
      info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
      "EQ", "SIGNED")
  # Give ``ok`` two trailing unit dimensions so it can be broadcast against
  # the matrix dimensions of the result.
  select_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
  return [_broadcasting_select_hlo(
      ctx,
      mlir.broadcast_in_dim(ctx, ok,
                            select_aval,
                            broadcast_dimensions=range(len(batch_dims))),
      select_aval,
      result, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)]

# The CPU platform uses the LAPACK-based lowering instead of the default one.
mlir.register_lowering(
    cholesky_p, _cholesky_cpu_lowering, platform='cpu')
# Asymmetric eigendecomposition
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
  """Eager implementation of ``eig_p``: dispatches the primitive as-is."""
  params = dict(
      compute_left_eigenvectors=compute_left_eigenvectors,
      compute_right_eigenvectors=compute_right_eigenvectors,
  )
  return dispatch.apply_primitive(eig_p, operand, **params)
def eig_lower(*args, **kw):
  """Fallback lowering of ``eig_p``; always raises.

  Raises:
    NotImplementedError: unconditionally, regardless of the arguments.
  """
  message = (
      "Nonsymmetric eigendecomposition is only implemented on the CPU backend. "
      "If your matrix is symmetric or Hermitian, you should use eigh instead.")
  raise NotImplementedError(message)
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Abstract evaluation rule for ``eig_p``.

  Args:
    operand: A ``ShapedArray`` with shape ``[..., n, n]``.
    compute_left_eigenvectors: Whether left eigenvectors are produced.
    compute_right_eigenvectors: Whether right eigenvectors are produced.

  Returns:
    A tuple of abstract values: eigenvalues of shape ``[..., n]``, then left
    eigenvectors of shape ``[..., n, n]`` (if requested), then right
    eigenvectors of shape ``[..., n, n]`` (if requested). All outputs have a
    complex dtype.

  Raises:
    NotImplementedError: if ``operand`` is not a ``ShapedArray``.
    ValueError: if ``operand`` is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  batch_dims = operand.shape[:-2]
  n = operand.shape[-1]
  # 32-bit inputs give complex64 outputs, everything else complex128 (before
  # dtype canonicalization).
  is_32bit = dtypes.finfo(operand.dtype).bits == 32
  dtype = dtypes.canonicalize_dtype(np.complex64 if is_32bit else np.complex128)
  eigenvectors = operand.update(shape=batch_dims + (n, n), dtype=dtype)
  eigenvalues = operand.update(shape=batch_dims + (n,), dtype=dtype)

  outputs = (eigenvalues,)
  if compute_left_eigenvectors:
    outputs += (eigenvectors,)
  if compute_right_eigenvectors:
    outputs += (eigenvectors,)
  return outputs
def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """CPU lowering of ``eig_p`` via ``lapack.geev_hlo``.

  Returns a list holding the eigenvalues, followed by the left eigenvectors
  (if requested) and the right eigenvectors (if requested). Each output is
  selected between the LAPACK result and the value produced by
  ``_nan_like_hlo`` depending on whether LAPACK's ``info`` status is zero for
  that batch element.
  """
  operand_aval, = ctx.avals_in
  out_aval = ctx.avals_out[0]
  batch_dims = operand_aval.shape[:-2]
  if jaxlib_version < (0, 4, 13):
    # Older jaxlib: no dynamic-shape support in the geev wrapper.
    if any(not is_constant_shape(a.shape) for a in ctx.avals_in):
      raise NotImplementedError(
        "Shape polymorphism for eig is not implemented. "
        "Try upgrading jaxlib")
    w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,  # type: ignore
                                      jobvl=compute_left_eigenvectors,
                                      jobvr=compute_right_eigenvectors)
  else:
    # The helper used to evaluate the operand shape depends on the jaxlib
    # version.
    if jaxlib_version < (0, 4, 14):
      op_shape_vals = mlir.eval_dynamic_shape_as_vals(ctx, operand_aval.shape)
    else:
      op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
    w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
                                      input_shape_vals=op_shape_vals,
                                      jobvl=compute_left_eigenvectors,
                                      jobvr=compute_right_eigenvectors)

  # ``ok`` has the batch shape and is true where ``info == 0``.
  ok = mlir.compare_hlo(
      info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
      "EQ", "SIGNED")
  # Eigenvalues have one non-batch dimension, hence one trailing unit
  # dimension on the selection mask.
  select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
  w = _broadcasting_select_hlo(
      ctx,
      mlir.broadcast_in_dim(ctx, ok, select_w_aval,
                            broadcast_dimensions=range(len(batch_dims))),
      select_w_aval,
      w, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
  output = [w]

  if compute_left_eigenvectors:
    # The abstract value of this output sits at the position it will occupy
    # in ``output``.
    aval = ctx.avals_out[len(output)]
    select_vl_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
    vl = _broadcasting_select_hlo(
        ctx,
        mlir.broadcast_in_dim(ctx, ok, select_vl_aval,
                              broadcast_dimensions=range(len(batch_dims))),
        select_vl_aval,
        vl, aval, _nan_like_hlo(ctx, aval), aval)
    output.append(vl)

  if compute_right_eigenvectors:
    aval = ctx.avals_out[len(output)]
    select_vr_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
    vr = _broadcasting_select_hlo(
        ctx,
        mlir.broadcast_in_dim(ctx, ok, select_vr_aval,
                              broadcast_dimensions=range(len(batch_dims))),
        select_vr_aval,
        vr, aval, _nan_like_hlo(ctx, aval), aval)
    output.append(vr)

  return output
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Batching rule for ``eig_p``.

  Moves the batch axis to the front and re-binds the primitive; every output
  is batched along axis 0.
  """
  (x,) = batched_args
  (bd,) = batch_dims
  batch_first = batching.moveaxis(x, bd, 0)
  outputs = eig_p.bind(
      batch_first,
      compute_left_eigenvectors=compute_left_eigenvectors,
      compute_right_eigenvectors=compute_right_eigenvectors)
  # Eigenvalues plus one output per requested set of eigenvectors.
  num_outputs = 1 + compute_left_eigenvectors + compute_right_eigenvectors
  return outputs, (0,) * num_outputs
def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
                 compute_right_eigenvectors):
  """Forward-mode differentiation rule for ``eig_p`` (eigenvalues only).

  Returns:
    ``([w], [dw])`` where ``w`` are the eigenvalues and ``dw`` their tangents.

  Raises:
    NotImplementedError: if any eigenvectors are requested.
  """
  wants_eigenvectors = compute_left_eigenvectors or compute_right_eigenvectors
  if wants_eigenvectors:
    raise NotImplementedError(
        'The derivatives of eigenvectors are not implemented, only '
        'eigenvalues. See '
        'https://github.com/google/jax/issues/2748 for discussion.')

  # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in
  # https://arxiv.org/abs/1701.00392
  (a,) = primals
  (da,) = tangents
  eigenvalues, right_vecs = eig(a, compute_left_eigenvectors=False)
  # Row-wise sum of (V^{-1} dA) * V^T, i.e. the diagonal of V^{-1} dA V.
  solved = _solve(right_vecs, da.astype(right_vecs.dtype))
  eigenvalues_dot = reductions.sum(solved * _T(right_vecs), -1)
  return [eigenvalues], [eigenvalues_dot]
# Definition of the ``eig`` primitive and registration of its rules.
eig_p = Primitive('eig')
eig_p.multiple_results = True
eig_p.def_impl(eig_impl)
eig_p.def_abstract_eval(eig_abstract_eval)
# The default lowering (``eig_lower``) raises NotImplementedError; the CPU
# platform gets a dedicated lowering.
mlir.register_lowering(eig_p, eig_lower)
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
# Symmetric/Hermitian eigendecomposition
def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
                sort_eigenvalues: bool = True) -> tuple[Array, Array]:
  """Helper Jacobi eigendecomposition implemented by XLA.

  Used as a subroutine of QDWH-eig on TPU.

  Returns:
    The two outputs of the ``eigh_jacobi`` primitive, eigenvalues first and
    eigenvectors second.
  """
  outputs = eigh_jacobi_p.bind(x, lower=lower,
                               sort_eigenvalues=sort_eigenvalues)
  eigenvalues, eigenvectors = outputs
  return eigenvalues, eigenvectors
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
  """Eager implementation of ``eigh_jacobi_p``: dispatches the primitive."""
  outputs = dispatch.apply_primitive(
      eigh_jacobi_p, operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
  eigenvalues, eigenvectors = outputs
  return eigenvalues, eigenvectors
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
  """Abstract evaluation rule for ``eigh_jacobi_p``.

  Args:
    operand: Abstract value of the input; a ``ShapedArray`` is expected to
      have shape ``[..., n, n]``.
    lower: Unused here.
    sort_eigenvalues: Unused here.

  Returns:
    A pair ``(w, v)``. For a ``ShapedArray`` input, ``w`` has shape
    ``[..., n]`` and the real counterpart of the operand dtype, and ``v`` has
    shape ``[..., n, n]`` and the operand dtype. Any other abstract value is
    returned unchanged for both outputs.

  Raises:
    ValueError: if a ``ShapedArray`` operand is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    return operand, operand

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    # The message previously lacked the space after the comma.
    raise ValueError(
        "Argument to symmetric eigendecomposition must have shape "
        f"[..., n, n], got shape {operand.shape}")

  batch_dims = operand.shape[:-2]
  n = operand.shape[-1]
  w = operand.update(shape=batch_dims + (n,),
                     dtype=lax_internal._complex_basetype(operand.dtype))
  v = operand.update(shape=batch_dims + (n, n))
  return w, v
def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
  """Lowers ``eigh_jacobi_p`` to the ``"Eigh"`` custom call.

  Returns the eigenvalues first and the eigenvectors second, matching the
  output order of the primitive's abstract evaluation.
  """
  operand_aval, = ctx.avals_in
  if operand_aval.shape[-1] == 0:
    # Empty matrices: skip the custom call. The eigenvalues are the real part
    # of the operand reshaped to drop its last dimension, and the eigenvectors
    # are the operand itself.
    reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
    return [
        hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
        operand,
    ]

  eigvals_type = mlir.aval_to_ir_type(ctx.avals_out[0])
  eigvecs_type = mlir.aval_to_ir_type(ctx.avals_out[1])
  # The custom call produces eigenvectors first, eigenvalues second.
  result_types = [eigvecs_type, eigvals_type]

  # NOTE(review): the trailing "100,1e-6" fields look like an iteration limit
  # and a tolerance for the Jacobi solver -- confirm against the XLA "Eigh"
  # custom call.
  backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"

  if any(not is_constant_shape(aval_out.shape)
         for aval_out in ctx.avals_out):
    # Dynamic shapes: pass explicit result shapes, in the custom call's order.
    result_shapes = [
        mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
        # The custom call returns the results swapped
        for aval_out in list(reversed(ctx.avals_out))
    ]
  else:
    result_shapes = None
  op = mlir.custom_call(
      "Eigh",
      result_types,
      [operand],
      backend_config=backend_config,
      api_version=1,
      result_shapes=result_shapes,
  )
  # Swap back to (eigenvalues, eigenvectors).
  return op.results[1], op.results[0]
# Definition of the ``eigh_jacobi`` primitive (two results) and registration
# of its implementation, abstract evaluation and lowering rules.
eigh_jacobi_p = Primitive('eigh_jacobi')
eigh_jacobi_p.multiple_results = True
eigh_jacobi_p.def_impl(_eigh_jacobi_impl)
eigh_jacobi_p.def_abstract_eval(_eigh_jacobi_abstract_eval)
mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
def _eigh_impl(operand, *, lower, sort_eigenvalues):
  """Eager implementation of ``eigh_p``: dispatches the primitive."""
  outputs = dispatch.apply_primitive(
      eigh_p, operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
  eigenvectors, eigenvalues = outputs
  return eigenvectors, eigenvalues
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
  """Abstract evaluation rule for ``eigh_p``.

  Args:
    operand: Abstract value of the input; a ``ShapedArray`` is expected to
      have shape ``[..., n, n]``.
    lower: Unused here.
    sort_eigenvalues: Unused here.

  Returns:
    A pair ``(v, w)``. For a ``ShapedArray`` input, ``v`` has shape
    ``[..., n, n]`` and the operand dtype, and ``w`` has shape ``[..., n]`` and
    the real counterpart of the operand dtype. Any other abstract value is
    returned unchanged for both outputs.

  Raises:
    ValueError: if a ``ShapedArray`` operand is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    return operand, operand

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    # The message previously lacked the space after the comma.
    raise ValueError(
        "Argument to symmetric eigendecomposition must have shape "
        f"[..., n, n], got shape {operand.shape}")

  batch_dims = operand.shape[:-2]
  n = operand.shape[-1]
  v = operand.update(shape=batch_dims + (n, n))
  w = operand.update(shape=batch_dims + (n,),
                     dtype=lax_internal._complex_basetype(operand.dtype))
  return v, w
def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
                           sort_eigenvalues):
  """Lowering of ``eigh_p`` shared by the CPU and GPU platforms.

  Args:
    syevd_impl: The platform-specific ``syevd`` wrapper to call.
    ctx: The lowering rule context.
    operand: The IR value of the input matrices.
    lower: Which triangle of the input is used.
    sort_eigenvalues: Ignored; see the comment below.

  Returns:
    ``[v, w]``: eigenvectors and eigenvalues, each selected between the
    ``syevd`` result and the value produced by ``_nan_like_hlo`` depending on
    whether the ``info`` status is zero for that batch element.

  Raises:
    NotImplementedError: if the two matrix dimensions are not constant.
  """
  del sort_eigenvalues  # The CPU/GPU implementations always sort.
  operand_aval, = ctx.avals_in
  v_aval, w_aval = ctx.avals_out
  batch_dims = operand_aval.shape[:-2]

  # The eigh implementation on CPU and GPU uses lapack helper routines to
  # find the size of the workspace based on the non-batch dimensions.
  # Therefore, we cannot yet support dynamic non-batch dimensions.
  if not is_constant_shape(operand_aval.shape[-2:]):
    raise NotImplementedError(
        "Shape polymorphism for native lowering for eigh is implemented "
        f"only for the batch dimensions: {operand_aval.shape}")

  if jaxlib_version < (0, 4, 14):
    # Older jaxlib: the wrapper takes the batch size and the result shapes
    # explicitly.
    batch_size_num = math.prod(batch_dims) if batch_dims else 1
    batch_size = mlir.eval_dynamic_shape(ctx, (batch_size_num,))[0]
    if isinstance(batch_size, int):
      batch_size = mlir.ir_constant(np.int32(batch_size))
    v_shape: ir.Value = mlir.eval_dynamic_shape_as_tensor(ctx, v_aval.shape)
    w_shape: ir.Value = mlir.eval_dynamic_shape_as_tensor(ctx, w_aval.shape)
    info_shape: ir.Value = mlir.eval_dynamic_shape_as_tensor(ctx, batch_dims)
    v, w, info = syevd_impl(operand_aval.dtype, operand, batch_size,
                            v_shape, w_shape, info_shape,
                            lower=lower)
  else:
    op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
    v, w, info = syevd_impl(operand_aval.dtype, operand,
                            a_shape_vals=op_shape_vals, lower=lower)

  # ``ok`` has the batch shape and is true where ``info == 0``.
  zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
  ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
  select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
  v = _broadcasting_select_hlo(
      ctx,
      mlir.broadcast_in_dim(ctx, ok, select_v_aval,
                            broadcast_dimensions=range(len(batch_dims))),
      select_v_aval,
      v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval)
  select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
  w = _broadcasting_select_hlo(
      ctx,
      mlir.broadcast_in_dim(ctx, ok, select_w_aval,
                            broadcast_dimensions=range(len(batch_dims))),
      select_w_aval,
      w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval)
  return [v, w]
def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
  """Implementation of ``eigh_p`` used for the TPU lowering.

  Matrices of size up to ``termination_size`` are handled directly by
  ``eigh_jacobi``; larger ones go through ``lax_eigh.eigh``, one matrix at a
  time.

  Args:
    x: Input of shape ``[..., n, n]``.
    lower: Which triangle of each matrix is used.
    sort_eigenvalues: Forwarded to the underlying eigensolver.

  Returns:
    The pair ``(eig_vecs, eig_vals)``.

  Raises:
    NotImplementedError: if the matrix dimension is not constant.
  """
  *_, m, n = x.shape
  assert m == n, (m, n)

  termination_size = 256
  if not is_constant_dim(m):
    # TODO: maybe we can relax the check below for shape polymorphism?
    raise NotImplementedError(
        "Shape polymorphism for native lowering for eigh is implemented "
        f"only for the batch dimensions: {x.shape}")
  if m <= termination_size:
    eig_vals, eig_vecs = eigh_jacobi(x, lower=lower,
                                     sort_eigenvalues=sort_eigenvalues)
    return eig_vecs, eig_vals

  def eigh_qdwh(x):
    # Peel off leading batch dimensions one at a time.
    if len(x.shape) > 2:
      return control_flow.map(eigh_qdwh, x)

    # We should only look at elements from the lower/upper triangle. Reflects
    # that triangle into the other triangle to form a Hermitian matrix.
    if lower:
      mask = jnp.tri(n, k=0, dtype=bool)
    else:
      mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool))
    if dtypes.issubdtype(x.dtype, jnp.complexfloating):
      re = lax.select(mask, lax.real(x), _T(lax.real(x)))
      # The imaginary part uses a strict-triangle mask, which zeroes the
      # diagonal, and is reflected with a sign flip.
      if lower:
        im_mask = jnp.tri(n, k=-1, dtype=bool)
      else:
        im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool))
      im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x)))
      im = lax.select(mask, im, -_T(im))
      x = lax.complex(re, im)
    else:
      x = lax.select(mask, x, _T(x))

    return lax_eigh.eigh(x, sort_eigenvalues=sort_eigenvalues,
                         termination_size=termination_size)

  eig_vals, eig_vecs = eigh_qdwh(x)
  return eig_vecs, eig_vals
def _eigh_jvp_rule(primals, tangents, *, lower, sort_eigenvalues):
  """Forward-mode differentiation rule for ``eigh_p``.

  Returns:
    ``((v, w), (dv, dw))``: the eigenvectors and (real) eigenvalues together
    with their tangents.
  """
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  (a,) = primals
  (a_dot,) = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower,
                          sort_eigenvalues=sort_eigenvalues)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs: the
  # identity is added before taking the reciprocal (so the diagonal is 1
  # rather than 0) and subtracted again afterwards.
  eigenvalue_gaps = eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
  Fmat = ufuncs.reciprocal(eigenvalue_gaps) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  matmul = lax.dot if a.ndim == 2 else lax.batch_matmul
  dot = partial(matmul, precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v))
  dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues):
  """Batching rule for `eigh_p`.

  Moves the batch axis of the single operand to the front and binds the
  primitive on the result.

  Args:
    batched_args: one-element sequence holding the batched operand.
    batch_dims: one-element sequence holding the operand's batch axis.
    lower: forwarded unchanged to `eigh_p.bind`.
    sort_eigenvalues: forwarded unchanged to `eigh_p.bind`.

  Returns:
    `(outputs, (0, 0))`: the outputs of `eigh_p.bind`, both reported as
    batched along axis 0.
  """
  (operand,) = batched_args
  (operand_bdim,) = batch_dims
  operand = batching.moveaxis(operand, operand_bdim, 0)
  outputs = eigh_p.bind(
      operand, lower=lower, sort_eigenvalues=sort_eigenvalues)
  return outputs, (0, 0)
# Define the `eigh` primitive and attach its rules. The primitive object must
# exist before any of the registrations below can refer to it.
eigh_p = Primitive('eigh')
# The primitive yields several outputs (its JVP rule returns an
# eigenvector/eigenvalue pair as the primal result).
eigh_p.multiple_results = True
eigh_p.def_impl(_eigh_impl)
eigh_p.def_abstract_eval(_eigh_abstract_eval)
# Differentiation and vmap support.
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
# Platform-specific lowerings for `eigh`.

# CPU: shared CPU/GPU lowering specialised with the LAPACK `syevd` wrapper.
mlir.register_lowering(
    eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo),
    platform='cpu')

# GPU: only registered when `gpu_solver` is available (not None). CUDA and
# ROCm reuse the same lowering, each with its own `syevd` wrapper.
if gpu_solver is not None:
  mlir.register_lowering(
      eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd),
      platform='cuda')
  mlir.register_lowering(
      eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd),
      platform='rocm')

# TPU: registered unconditionally; the lowering is derived from the Python
# implementation `_eigh_tpu_impl` via `mlir.lower_fun`.
mlir.register_lowering(
    eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
    platform='tpu')
# Dtype rule for `triangular_solve`: each of the two operands may be floating
# point or complex.
# NOTE(review): the result dtype presumably follows the inputs via
# `_input_dtype` -- confirm against `naryop_dtype_rule`.
_triangular_solve_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
    'triangular_solve')
def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
  """Shape rule for `triangular_solve`: validate operands, return `b.shape`.

  Args:
    a: the triangular operand; must have `ndim >= 2` and square trailing
      dimensions.
    b: the right-hand side; must have `ndim >= 2` and the same leading (batch)
      dimensions as `a`.
    left_side: if true, the size of `a` must match `b.shape[-2]`; otherwise it
      must match `b.shape[-1]`.
    **unused_kwargs: remaining primitive parameters, ignored here.

  Returns:
    The shape of `b`, which is also the shape of the solution.

  Raises:
    TypeError: if any of the requirements above is violated.
  """
  if a.ndim < 2:
    raise TypeError(
        "triangular_solve requires a.ndim to be at least 2, got {}."
        .format(a.ndim))
  if b.ndim < 2:
    raise TypeError(
        "triangular_solve requires b.ndim to be at least 2, got {}."
        .format(b.ndim))

  a_shape, b_shape = a.shape, b.shape

  if a_shape[-1] != a_shape[-2]:
    raise TypeError(
        ("triangular_solve requires the last two dimensions of a to be equal "
         "in size, got a.shape of {}.").format(a_shape))

  # Everything in front of the trailing matrix dimensions is a batch shape.
  if a_shape[:-2] != b_shape[:-2]:
    raise TypeError(
        ("triangular_solve requires both arguments to have the same number "
         "of dimensions and equal batch dimensions, got {} and {}.")
        .format(a_shape, b_shape))

  # The dimension of `b` that is contracted against `a` depends on which side
  # `a` sits on.
  shared_extent = b_shape[-2] if left_side else b_shape[-1]
  if a_shape[-1] != shared_extent:
    raise TypeError(
        "Incompatible shapes for arguments to triangular_solve: {} and {}."
        .format(a_shape, b_shape))

  return b_shape
def _triangular_solve_jvp_rule_a(
    g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  """JVP of `triangular_solve` with respect to its first argument `a`.

  Args:
    g_a: tangent of `a`.
    ans: primal output of the solve (written X in the comments below).
    a: primal triangular operand.
    b: primal right-hand side; only its trailing two dimensions are read.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: parameters of
      the primal solve; the same values are used for the inner solves here.

  Returns:
    The tangent of the solution induced by `g_a`.
  """
  m, n = b.shape[-2:]

  # Keep only the triangle of the tangent selected by `lower`; with a unit
  # diagonal the diagonal itself is dropped as well.
  diag_offset = 1 if unit_diagonal else 0
  if lower:
    g_a = jnp.tril(g_a, k=-diag_offset)
  else:
    g_a = jnp.triu(g_a, k=diag_offset)
  g_a = lax.neg(g_a)
  # Apply to the tangent the same transpose/conjugate the solve applies to `a`.
  if transpose_a:
    g_a = jnp.swapaxes(g_a, -1, -2)
  if conjugate_a:
    g_a = ufuncs.conj(g_a)

  matmul_impl = lax.dot if g_a.ndim == 2 else lax.batch_matmul
  matmul = partial(matmul_impl, precision=lax.Precision.HIGHEST)
  # Solve against `a` with exactly the parameters of the primal solve.
  solve_with_a = partial(
      triangular_solve, a, left_side=left_side, lower=lower,
      transpose_a=transpose_a, conjugate_a=conjugate_a,
      unit_diagonal=unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2
  # FLOPs for matrix/vector inputs). Order these operations in whichever order
  # is cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return solve_with_a(matmul(g_a, ans))  # A^{-1} (∂A X)
    return matmul(solve_with_a(g_a), ans)  # (A^{-1} ∂A) X

  assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
  if m < n:
    return solve_with_a(matmul(ans, g_a))  # (X ∂A) A^{-1}
  return matmul(ans, solve_with_a(g_a))  # X (∂A A^{-1})
def _triangular_solve_transpose_rule(
    cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  """Transpose rule for `triangular_solve` with respect to `b`.

  Triangular solve is nonlinear in its first argument and linear in its second
  argument (analogous to `div`, but with the roles swapped), so only `b` can
  be the undefined primal here.

  Args:
    cotangent: cotangent of the solve's output, possibly a symbolic zero.
    a: the (defined) triangular operand.
    b: the undefined primal standing in for the right-hand side.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: parameters of
      the primal solve.

  Returns:
    `[None, cotangent_b]`: no cotangent for `a`, and for `b` either a symbolic
    zero or the result of a solve against `a` with `transpose_a` flipped.
  """
  assert not ad.is_undefined_primal(a) and ad.is_undefined_primal(b)

  # A symbolic zero stays symbolic; no solve is emitted.
  if type(cotangent) is ad_util.Zero:
    return [None, ad_util.Zero(b.aval)]

  # Same solve as the primal except that `transpose_a` is inverted.
  b_cotangent = triangular_solve(
      a, cotangent, left_side=left_side, lower=lower,
      transpose_a=not transpose_a, conjugate_a=conjugate_a,
      unit_diagonal=unit_diagonal)
  return [None, b_cotangent]
def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
                                    lower, transpose_a, conjugate_a,
                                    unit_diagonal):
  """Batching rule for `triangular_solve`.

  Args:
    batched_args: the pair `(a, b)` of operands.
    batch_dims: the pair of batch axes for `a` and `b`.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: parameters of
      the solve, forwarded unchanged.

  Returns:
    `(result, out_batch_dim)`.
  """
  a, b = batched_args
  a_bdim, b_bdim = batch_dims
  solve_params = dict(
      left_side=left_side, lower=lower, transpose_a=transpose_a,
      conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)

  if a_bdim is not batching.not_mapped:
    # `a` is batched: put a batch axis at the front of both operands
    # and let the solve treat it as an ordinary batch dimension.
    size = next(arg.shape[dim] for arg, dim in zip(batched_args, batch_dims)
                if dim is not None)
    a = batching.bdim_at_front(a, a_bdim, size)
    b = batching.bdim_at_front(b, b_bdim, size)
    return triangular_solve(a, b, **solve_params), 0

  # Only `b` is batched: fold the batch axis into the dimension of `b` that is
  # not contracted against `a`, do one solve with the unbatched `a`, and
  # unfold the result afterwards.
  if left_side:
    b = batching.moveaxis(b, b_bdim, -1)
    flat_shape = b.shape[:-2] + (b.shape[-2] * b.shape[-1],)
    out_bdim = b.ndim - 1
  else:
    b = batching.moveaxis(b, b_bdim, -2)
    flat_shape = b.shape[:-3] + (b.shape[-3] * b.shape[-2], b.shape[-1])
    out_bdim = b.ndim - 2
  flat_result = triangular_solve(a, b.reshape(flat_shape), **solve_params)
  return flat_result.reshape(b.shape), out_bdim
# Define the `triangular_solve` primitive from its shape and dtype rules.
triangular_solve_p = standard_primitive(
    _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
    'triangular_solve')
# JVP rules, one per argument. The rule for `b` re-runs the same solve (same
# keyword parameters) on the tangent `g_b` in place of `b`.
ad.defjvp2(triangular_solve_p,
           _triangular_solve_jvp_rule_a,
           lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
# Transposition and vmap support.
ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
def _triangular_solve_lowering(
    ctx, a, b, *, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  """Lower `triangular_solve` to an HLO `TriangularSolveOp`.

  Args:
    ctx: lowering rule context (not inspected by this rule).
    a: IR value of the triangular operand.
    b: IR value of the right-hand side.
    left_side, lower, unit_diagonal: passed to the op as boolean attributes.
    transpose_a, conjugate_a: combined into the op's transpose attribute, which
      is one of "NO_TRANSPOSE", "TRANSPOSE" or "ADJOINT".

  Returns:
    The results of the emitted `TriangularSolveOp`.
  """
  if conjugate_a and not transpose_a:
    # Conjugation without transposition has no transpose attribute of its own
    # here, so conjugate the operand explicitly and then request no transpose.
    # Take `.result` so that `a` remains an IR value, exactly as the CPU
    # lowering of this primitive does.
    a = chlo.ConjOp(a).result
    conjugate_a = False
  if not transpose_a:
    transpose = "NO_TRANSPOSE"
  else:
    transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
  return hlo.TriangularSolveOp(
      a, b, ir.BoolAttr.get(left_side),
      ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
      hlo.TransposeAttr.get(transpose)).results
# Registered without a `platform` argument; a CPU-specific lowering for this
# primitive is registered separately in this file.
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
def _triangular_solve_cpu_lower(
    ctx, a, b, *, left_side, lower, transpose_a,
    conjugate_a, unit_diagonal):
  """CPU lowering of `triangular_solve`.

  Uses the LAPACK `trsm` wrapper when `a` is a single matrix (rank 2) whose
  dtype is in `_cpu_lapack_types`; otherwise emits the HLO
  `TriangularSolveOp`.

  Args:
    ctx: lowering rule context; `ctx.avals_in` supplies the abstract values of
      `a` and `b`.
    a: IR value of the triangular operand.
    b: IR value of the right-hand side.
    left_side, lower, transpose_a, conjugate_a, unit_diagonal: parameters of
      the solve.

  Returns:
    The IR results of the lowering.
  """
  a_aval, b_aval = ctx.avals_in

  if conjugate_a and not transpose_a:
    # Express "conjugate only" by conjugating the operand up front.
    a = chlo.ConjOp(a).result
    conjugate_a = False

  lapack_eligible = (len(a_aval.shape) == 2 and
                     np.dtype(a_aval.dtype) in _cpu_lapack_types)
  if not lapack_eligible:
    # Fall back to the HLO implementation for unsupported types or batching.
    # TODO: Consider swapping XLA for LAPACK in batched case
    if not transpose_a:
      transpose = "NO_TRANSPOSE"
    elif conjugate_a:
      transpose = "ADJOINT"
    else:
      transpose = "TRANSPOSE"
    return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
                                 ir.BoolAttr.get(lower),
                                 ir.BoolAttr.get(unit_diagonal),
                                 hlo.TransposeAttr.get(transpose)).results

  # Scalar constant 1 in the dtype of `a`, passed as `alpha` to `trsm`.
  alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
  trsm_args = (a_aval.dtype, alpha, a, b, left_side, lower, transpose_a,
               conjugate_a, unit_diagonal)
  if jaxlib_version < (0, 4, 14):
    # jaxlib before 0.4.14 is called without `b_shape_vals`.
    return [lapack.trsm_hlo(*trsm_args)]  # type: ignore
  b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
  return [lapack.trsm_hlo(*trsm_args, b_shape_vals=b_shape_vals)]
# On CPU, use the dedicated lowering, which can call LAPACK `trsm` for
# unbatched inputs of supported dtypes.
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
                       platform='cpu')
# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.