Merge branch 'master' into new_charge_encoding

mganahl committed Jul 8, 2020
2 parents 301b42a + bb440f1 commit 31d80c3
Showing 4 changed files with 466 additions and 121 deletions.
215 changes: 184 additions & 31 deletions tensornetwork/backends/jax/jax_backend.py
@@ -24,6 +24,7 @@
# pylint: disable=abstract-method

_CACHED_MATVECS = {}
_CACHED_FUNCTIONS = {}


class JaxBackend(abstract_backend.AbstractBackend):
@@ -243,15 +244,15 @@ def eigs(self,
which: Text = 'LR',
maxiter: int = 20) -> Tuple[Tensor, List]:
"""
Implicitly restarted Arnoldi method for finding the lowest
eigenvector-eigenvalue pairs of a linear operator `A`.
`A` is a function implementing the matrix-vector product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigs`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
@@ -265,7 +266,7 @@ def A(H,x):
res = eigs(A, [H],x) #jitting is triggered only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
@@ -278,7 +279,7 @@ def A(H,x):
x = jax.np.array(np.random.rand(10,10))
res = eigs(A, [H],x) #jitting is triggered at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
@@ -293,13 +294,13 @@ def A(H,x):
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
tol: The desired precision of the eigenvalues. For the jax backend
this currently has no effect, and the precision of eigenvalues is not
guaranteed. This feature may be added at a later point. To increase
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
which: Flag for targeting different types of eigenvalues. Currently
supported are `which = 'LR'` (largest real part) and `which = 'LM'`
(largest magnitude).
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
equivalent to a simple Arnoldi method.
Returns:
(eigvals, eigvecs)
@@ -326,11 +327,12 @@ def A(H,x):
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = libjax.tree_util.Partial(libjax.jit(A))
if "imp_arnoldi" not in _CACHED_FUNCTIONS:
imp_arnoldi = jitted_functions._implicitly_restarted_arnoldi(libjax)
_CACHED_FUNCTIONS["imp_arnoldi"] = imp_arnoldi
return _CACHED_FUNCTIONS["imp_arnoldi"](_CACHED_MATVECS[A], args,
initial_state, num_krylov_vecs,
numeig, which, tol, maxiter)
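The caching above can be illustrated with a small, self-contained sketch (names such as `_MATVEC_CACHE` and `get_cached_solver` are illustrative, not part of this diff): the jitted matvec and the generated solver are each built once and stored in module-level dictionaries, so later calls with the same `A` object reuse the compiled code instead of retriggering `jax.jit`.
```python
import jax

# Illustrative module-level caches mirroring _CACHED_MATVECS / _CACHED_FUNCTIONS.
_MATVEC_CACHE = {}
_SOLVER_CACHE = {}

def get_cached_matvec(A):
  # Jit `A` once per python object; later lookups reuse the compiled wrapper.
  if A not in _MATVEC_CACHE:
    _MATVEC_CACHE[A] = jax.tree_util.Partial(jax.jit(A))
  return _MATVEC_CACHE[A]

def get_cached_solver(name, factory):
  # Build the solver only on the first request for `name`.
  if name not in _SOLVER_CACHE:
    _SOLVER_CACHE[name] = factory()
  return _SOLVER_CACHE[name]
```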

def eigsh_lanczos(
self,
@@ -347,12 +349,12 @@ def eigsh_lanczos(
reorthogonalize: Optional[bool] = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a hermitian linear operator `A`. `A` is a function implementing
the matrix-vector product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigsh_lanczos`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
@@ -366,7 +368,7 @@ def A(H,x):
res = eigsh_lanczos(A, [H],x) #jitting is triggered only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
@@ -379,7 +381,7 @@ def A(H,x):
x = jax.np.array(np.random.rand(10,10))
res = eigsh_lanczos(A, [H],x) #jitting is triggered at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
@@ -395,7 +397,7 @@ def A(H,x):
numeig: The number of eigenvector-eigenvalue pairs to be computed.
If `numeig > 1`, `reorthogonalize` has to be `True`.
tol: The desired precision of the eigenvalues. For the jax backend
this currently has no effect, and the precision of eigenvalues is not
guaranteed. This feature may be added at a later point.
To increase precision the caller can increase `num_krylov_vecs`.
delta: Stopping criterion for Lanczos iteration.
@@ -404,7 +406,7 @@ def A(H,x):
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal operator is diagonalized every `ndiag` iterations
to check convergence. This currently has no effect for the jax backend,
but may be added at a later point.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
@@ -433,12 +435,163 @@ def A(H,x):
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = libjax.tree_util.Partial(A)

if "eigsh_lanczos" not in _CACHED_FUNCTIONS:
eigsh_lanczos = jitted_functions._generate_jitted_eigsh_lanczos(libjax)
_CACHED_FUNCTIONS["eigsh_lanczos"] = eigsh_lanczos
eigsh_lanczos = _CACHED_FUNCTIONS["eigsh_lanczos"]
return eigsh_lanczos(_CACHED_MATVECS[A], args, initial_state,
num_krylov_vecs, numeig, delta, reorthogonalize)
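A short sketch of why the matvec is wrapped in `libjax.tree_util.Partial` before being handed to the jitted routines: a plain Python function is not a valid argument to a jitted function, but a `Partial`-wrapped one is a pytree and passes through cleanly (the helper names here are assumptions for illustration, not part of this diff).
```python
import jax
import jax.numpy as jnp

def matvec(H, x):
  return jnp.dot(H, x)

@jax.jit
def apply_once(mv, H, x):
  # `mv` is accepted here because jax.tree_util.Partial registers it as a pytree.
  return mv(H, x)

wrapped = jax.tree_util.Partial(matvec)
H = jnp.eye(3)
x = jnp.ones(3)
print(apply_once(wrapped, H, x))  # [1. 1. 1.]
```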

def gmres(self,
A_mv: Callable,
b: Tensor,
A_args: Optional[List] = None,
A_kwargs: Optional[dict] = None,
x0: Optional[Tensor] = None,
tol: float = 1E-05,
atol: Optional[float] = None,
num_krylov_vectors: Optional[int] = None,
maxiter: Optional[int] = 1,
M: Optional[Callable] = None
) -> Tuple[Tensor, int]:
""" GMRES solves the linear system A @ x = b for x given a vector `b` and
a general (not necessarily symmetric/Hermitian) linear operator `A`.
As a Krylov method, GMRES does not require a concrete matrix representation
of the n by n `A`, but only a function
`vector1 = A_mv(vector0, *A_args, **A_kwargs)`
prescribing a one-to-one linear map from vector0 to vector1 (that is,
A must be square, and thus vector0 and vector1 the same size). If `A` is a
dense matrix, or if it is a symmetric/Hermitian operator, a different
linear solver will usually be preferable.
GMRES works by first constructing the Krylov basis
K = (x0, A_mv@x0, A_mv@A_mv@x0, ..., (A_mv^num_krylov_vectors)@x_0) and then
solving a certain dense linear system K @ q0 = q1 from whose solution x can
be approximated. For `num_krylov_vectors = n` the solution is provably exact
in infinite precision, but the expense is cubic in `num_krylov_vectors` so
one is typically interested in the `num_krylov_vectors << n` case.
The solution can in this case be repeatedly
improved, to a point, by restarting the Arnoldi iterations each time
`num_krylov_vectors` is reached. Unfortunately the optimal parameter choices
balancing expense and accuracy are difficult to predict in advance, so
applying this function requires a degree of experimentation.
In a tensor network code one is typically interested in A_mv implementing
some tensor contraction. This implementation thus allows `b` and `x0` to be
of whatever arbitrary, though identical, shape `b = A_mv(x0, ...)` expects.
Reshaping to and from a matrix problem is handled internally.
The Jax backend version of GMRES uses a homemade implementation that, for
now, is suboptimal for num_krylov_vecs ~ b.size.
For the same reason as described in eigsh_lanczos, the function A_mv
should be Jittable (or already Jitted) and, if at all possible, defined
only once at the global scope. A new compilation will be triggered each
time an A_mv with a new function signature is passed in, even if the
'new' function is identical to the old one (function identity is
undecidable).
Args:
A_mv : A function `v0 = A_mv(v, *A_args, **A_kwargs)` where `v0` and
`v` have the same shape.
b : The `b` in `A @ x = b`; it should be of the shape `A_mv`
operates on.
A_args : Positional arguments to `A_mv`, supplied to this interface
as a list.
Default: None.
A_kwargs : In the other backends, keyword arguments to `A_mv`, supplied
as a dictionary. However, the Jax backend does not support `A_mv`
accepting keyword arguments, since this causes problems with Jit.
Therefore, an error is thrown if A_kwargs is specified.
Default: None.
x0 : An optional guess solution. Zeros are used by default.
If `x0` is supplied, its shape and dtype must match those of
`b`, or an error will be thrown.
Default: zeros.
tol, atol: Solution tolerance to achieve,
norm(residual) <= max(tol*norm(b), atol).
Default: tol=1E-05
atol=tol
num_krylov_vectors : Size of the Krylov space to build at each restart.
Expense is cubic in this parameter. If supplied, it must be
an integer in 0 < num_krylov_vectors <= b.size.
Default: b.size.
maxiter : The Krylov space will be repeatedly rebuilt up to this many
times. Large values of this argument
should be used only with caution, since especially for nearly
symmetric matrices and small `num_krylov_vectors` convergence
might well freeze at a value significantly larger than `tol`.
Default: 1
M : Inverse of the preconditioner of A; see the docstring for
`scipy.sparse.linalg.gmres`. This is unsupported in the Jax
backend, and NotImplementedError will be raised if it is
supplied.
Default: None.
Raises:
ValueError: -if `x0` is supplied but its shape differs from that of `b`.
-if num_krylov_vectors is 0 or exceeds b.size.
-if tol or atol was negative.
NotImplementedError: - If M is supplied.
- If A_kwargs is supplied.
Returns:
x : The converged solution. It has the same shape as `b`.
info : 0 if convergence was achieved, the number of restarts otherwise.
"""

if x0 is not None and x0.shape != b.shape:
errstring = (f"If x0 is supplied, its shape, {x0.shape}, must match b's"
f", {b.shape}.")
raise ValueError(errstring)
if x0 is not None and x0.dtype != b.dtype:
errstring = (f"If x0 is supplied, its dtype, {x0.dtype}, must match b's"
f", {b.dtype}.")
raise ValueError(errstring)
if num_krylov_vectors is None:
num_krylov_vectors = b.size
if num_krylov_vectors <= 0 or num_krylov_vectors > b.size:
errstring = (f"num_krylov_vectors must be in "
f"0 < {num_krylov_vectors} <= {b.size}.")
raise ValueError(errstring)
if tol < 0:
raise ValueError(f"tol = {tol} must be positive.")
if atol is None:
atol = tol
if atol < 0:
raise ValueError(f"atol = {atol} must be positive.")

if M is not None:
raise NotImplementedError("M is not supported by the Jax backend.")
if A_kwargs is not None:
raise NotImplementedError("A_kwargs is not supported by the Jax backend.")

if A_args is None:
A_args = []

if x0 is None:
x0 = self.zeros(b.shape, b.dtype)

if A_mv not in _CACHED_MATVECS:
_CACHED_MATVECS[A_mv] = libjax.tree_util.Partial(A_mv)
if "gmres_f" not in _CACHED_FUNCTIONS:
_CACHED_FUNCTIONS["gmres_f"] = jitted_functions.gmres_wrapper(libjax)
gmres_f = _CACHED_FUNCTIONS["gmres_f"]
x, _, n_iter, converged = gmres_f(_CACHED_MATVECS[A_mv], A_args, b,
x0, tol, atol, num_krylov_vectors,
maxiter)
if converged:
info = 0
else:
info = n_iter
return x, info
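A hedged usage sketch of the new `gmres` entry point (the dense test matrix, sizes, and tolerances are illustrative assumptions; the call follows the signature and docstring above, with `A_mv` defined once at module scope so it is jitted only once):
```python
import numpy as np
import jax.numpy as jnp
from tensornetwork.backends.jax.jax_backend import JaxBackend

backend = JaxBackend()

def A_mv(x, mat):
  # Any jittable linear map with matching input/output shapes will do.
  return mat @ x

n = 10
mat = jnp.array(np.random.rand(n, n))
b = jnp.array(np.random.rand(n))
x, info = backend.gmres(A_mv, b, A_args=[mat], tol=1e-6,
                        num_krylov_vectors=n, maxiter=5)
# info == 0 signals convergence; otherwise it holds the number of restarts.
```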

def conj(self, tensor: Tensor) -> Tensor:
return jnp.conj(tensor)
