/
linalg.py
176 lines (140 loc) · 5.93 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import operator
import numpy as np
import jax.numpy as jnp
from jax import lax, device_put
from jax.tree_util import tree_leaves, tree_map, tree_multimap, tree_structure
from jax.util import safe_map as map
def _vdot_real_part(x, y):
"""Vector dot-product guaranteed to have a real valued result."""
# all our uses of vdot() in CG are for computing an operator of the form
# `z^T M z` where `M` is positive definite and Hermitian, so the result is
# real valued:
# https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
result = vdot(x.real, y.real)
if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
result += vdot(x.imag, y.imag)
return result
# aliases for working with pytrees
def _vdot_tree(x, y):
return sum(tree_leaves(tree_multimap(_vdot_real_part, x, y)))
def _mul(scalar, tree):
return tree_map(partial(operator.mul, scalar), tree)
_add = partial(tree_multimap, operator.add)
_sub = partial(tree_multimap, operator.sub)
def _identity(x):
return x
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
# tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
bs = _vdot_tree(b, b)
atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method
def cond_fun(value):
x, r, gamma, p, k = value
rs = gamma if M is _identity else _vdot_tree(r, r)
return (rs > atol2) & (k < maxiter)
def body_fun(value):
x, r, gamma, p, k = value
Ap = A(p)
alpha = gamma / _vdot_tree(p, Ap)
x_ = _add(x, _mul(alpha, p))
r_ = _sub(r, _mul(alpha, Ap))
z_ = M(r_)
gamma_ = _vdot_tree(r_, z_)
beta_ = gamma_ / gamma
p_ = _add(z_, _mul(beta_, p))
return x_, r_, gamma_, p_, k + 1
r0 = _sub(b, A(x0))
p0 = z0 = M(r0)
gamma0 = _vdot_tree(r0, z0)
initial_value = (x0, r0, gamma0, p0, 0)
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
return x_final
def _shapes(pytree):
return map(jnp.shape, tree_leaves(pytree))
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to
numerical precision), but note that the interface is slightly different: you
need to supply the linear operator ``A`` as a function instead of a sparse
matrix or ``LinearOperator``.
Derivatives of ``cg`` are implemented via implicit differentiation with
another ``cg`` solve, rather than by differentiating *through* the solver.
They will be accurate only if both solves converge.
Parameters
----------
A : function
Function that calculates the matrix-vector product ``Ax`` when called
like ``A(x)``. ``A`` must represent a hermitian, positive definite
matrix, and must return array(s) with the same structure and shape as its
argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
Returns
-------
x : array or tree of arrays
The converged solution. Has the same structure as ``b``.
info : None
Placeholder for convergence information. In the future, JAX will report
the number of iterations when convergence is not achieved, like SciPy.
Other Parameters
----------------
x0 : array
Starting guess for the solution. Must have the same structure as ``b``.
tol, atol : float, optional
Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
maxiter : integer
Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved.
M : function
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance.
See also
--------
scipy.sparse.linalg.cg
jax.lax.custom_linear_solve
"""
if x0 is None:
x0 = tree_map(jnp.zeros_like, b)
b, x0 = device_put((b, x0))
if maxiter is None:
size = sum(bi.size for bi in tree_leaves(b))
maxiter = 10 * size # copied from scipy
if M is None:
M = _identity
if tree_structure(x0) != tree_structure(b):
raise ValueError(
'x0 and b must have matching tree structure: '
f'{tree_structure(x0)} vs {tree_structure(b)}')
if _shapes(x0) != _shapes(b):
raise ValueError(
'arrays in x0 and b must have matching shapes: '
f'{_shapes(x0)} vs {_shapes(b)}')
cg_solve = partial(
_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)
# real-valued positive-definite linear operators are symmetric
real_valued = lambda x: not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b)))
x = lax.custom_linear_solve(
A, b, solve=cg_solve, transpose_solve=cg_solve, symmetric=symmetric)
info = None # TODO(shoyer): return the real iteration count here
return x, info