Skip to content

Commit

Permalink
add eigsh_lanczos to blocksparse backend (#688)
Browse files Browse the repository at this point in the history
* typing

* typing

* typing

* typing. lintingc

* linting

* typing

* wip

* transpose_data -> contiguous

* wip adding eigsh_lanczos

* __matmul__ support for tensors (just like numpy)

* extend __matmul__ to tensors

* restrict to vectors and matrices

* tests adjusted

* yapf

* wip

* wip

* wip

* testing added

* comment

* comment

* typing

* yapf

* yapf

* linting

* typinig changes

* typing

* remove newline

* comment

* comment

* comment

* add type

* linting

* typing

* typing

* typing

* linting

* remove parens

* space

* remove complex test

* remove complex casting

* remove parense

* add some parense for readability
  • Loading branch information
mganahl committed Jun 30, 2020
1 parent 7967a91 commit 892468e
Show file tree
Hide file tree
Showing 6 changed files with 370 additions and 31 deletions.
3 changes: 1 addition & 2 deletions tensornetwork/backends/numpy/numpy_backend.py
Expand Up @@ -397,8 +397,7 @@ def eigsh_lanczos(self,
norms_vector_n[1:], 1) + np.diag(np.conj(norms_vector_n[1:]), -1)
eigvals, u = np.linalg.eigh(A_tridiag)
eigenvectors = []
if np.iscomplexobj(A_tridiag):
eigvals = np.array(eigvals).astype(A_tridiag.dtype)
eigvals = np.array(eigvals).astype(A_tridiag.dtype)

for n2 in range(min(numeig, len(eigvals))):
state = self.zeros(initial_state.shape, initial_state.dtype)
Expand Down
146 changes: 140 additions & 6 deletions tensornetwork/backends/symmetric/symmetric_backend.py
Expand Up @@ -21,10 +21,10 @@
import numpy
Tensor = Any

#TODO (mganahl): implement sparse solvers
# TODO (mganahl): implement eigs


#pylint: disable=abstract-method
# pylint: disable=abstract-method
class SymmetricBackend(abstract_backend.AbstractBackend):
"""See base_backend.BaseBackend for documentation."""

Expand Down Expand Up @@ -126,19 +126,19 @@ def eye(self,
return self.bs.eye(N, M, dtype=dtype)

def ones(self,
shape: List[Index],
shape: Sequence[Index],
dtype: Optional[numpy.dtype] = None) -> Tensor:
dtype = dtype if dtype is not None else numpy.float64
return self.bs.ones(shape, dtype=dtype)

def zeros(self,
shape: List[Index],
shape: Sequence[Index],
dtype: Optional[numpy.dtype] = None) -> Tensor:
dtype = dtype if dtype is not None else numpy.float64
return self.bs.zeros(shape, dtype=dtype)

def randn(self,
shape: List[Index],
shape: Sequence[Index],
dtype: Optional[numpy.dtype] = None,
seed: Optional[int] = None) -> Tensor:

Expand All @@ -147,7 +147,7 @@ def randn(self,
return self.bs.randn(shape, dtype)

def random_uniform(self,
shape: List[Index],
shape: Sequence[Index],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[numpy.dtype] = None,
seed: Optional[int] = None) -> Tensor:
Expand All @@ -163,6 +163,140 @@ def conj(self, tensor: Tensor) -> Tensor:
def eigh(self, matrix: Tensor) -> Tuple[Tensor, Tensor]:
return self.bs.eigh(matrix)

def eigsh_lanczos(self,
A: Callable,
args: Optional[List[Tensor]] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple] = None,
dtype: Optional[Type[numpy.number]] = None,
num_krylov_vecs: int = 20,
numeig: int = 1,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False) -> Tuple[Tensor, List]:
"""
Lanczos method for finding the lowest eigenvector-eigenvalue pairs
of a linear operator `A`.
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the Lanczos algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If both no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The nummber of eigenvector-eigenvalue pairs to be computed.
If `numeig > 1`, `reorthogonalize` has to be `True`.
tol: The desired precision of the eigenvalus. Uses
`numpy.linalg.norm(eigvalsnew[0:numeig] - eigvalsold[0:numeig]) < tol`
as stopping criterion between two diagonalization steps of the
tridiagonal operator.
delta: Stopping criterion for Lanczos iteration.
If a Krylov vector :math: `x_n` has an L2 norm
:math:`\\lVert x_n\\rVert < delta`, the iteration
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations
to check convergence.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` lowest eigenvalues
eigvecs: A list of `numeig` lowest eigenvectors
"""
if args is None:
args = []

if num_krylov_vecs < numeig:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')

if numeig > 1 and not reorthogonalize:
raise ValueError(
"Got numeig = {} > 1 and `reorthogonalize = False`. "
"Use `reorthogonalize=True` for `numeig > 1`".format(numeig))
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = self.randn(shape, dtype)

if not isinstance(initial_state, BlockSparseTensor):
raise TypeError("Expected a `BlockSparseTensor`. Got {}".format(
type(initial_state)))

vector_n = initial_state
vector_n.contiguous() # bring into contiguous memory layout

Z = self.norm(vector_n)
vector_n /= Z
norms_vector_n = []
diag_elements = []
krylov_vecs = []
first = True
eigvalsold = []
for it in range(num_krylov_vecs):
# normalize the current vector:
norm_vector_n = self.norm(vector_n)
if abs(norm_vector_n) < delta:
# we found an invariant subspace, time to stop
break
norms_vector_n.append(norm_vector_n)
vector_n = vector_n / norms_vector_n[-1]
# store the Lanczos vector for later
if reorthogonalize:
# vector_n is always in contiguous memory layout at this point
for v in krylov_vecs:
v.contiguous() # make sure storage layouts are matching
# it's save to operate on the tensor data now (pybass some checks)
vector_n.data -= numpy.dot(numpy.conj(v.data), vector_n.data) * v.data
krylov_vecs.append(vector_n)
A_vector_n = A(vector_n, *args)
A_vector_n.contiguous() # contiguous memory layout

# operate on tensor-data for scalar products
# this can be potentially problematic if vector_n and A_vector_n
# have non-matching shapes due to an erroneous matvec.
# If this is the case though an error will be thrown at line 281
diag_elements.append(
numpy.dot(numpy.conj(vector_n.data), A_vector_n.data))

if (it > 0) and (it % ndiag == 0) and (len(diag_elements) >= numeig):
# diagonalize the effective Hamiltonian
A_tridiag = numpy.diag(diag_elements) + numpy.diag(
norms_vector_n[1:], 1) + numpy.diag(
numpy.conj(norms_vector_n[1:]), -1)
eigvals, u = numpy.linalg.eigh(A_tridiag)
if not first:
if numpy.linalg.norm(eigvals[0:numeig] - eigvalsold[0:numeig]) < tol:
break
first = False
eigvalsold = eigvals[0:numeig]
if it > 0:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
A_vector_n -= (krylov_vecs[-2] * norms_vector_n[-1])
else:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
vector_n = A_vector_n

A_tridiag = numpy.diag(diag_elements) + numpy.diag(
norms_vector_n[1:], 1) + numpy.diag(numpy.conj(norms_vector_n[1:]), -1)
eigvals, u = numpy.linalg.eigh(A_tridiag)
eigenvectors = []
eigvals = numpy.array(eigvals).astype(A_tridiag.dtype)

for n2 in range(min(numeig, len(eigvals))):
state = self.zeros(initial_state.sparse_shape, initial_state.dtype)
for n1, vec in enumerate(krylov_vecs):
state += vec * u[n1, n2]
eigenvectors.append(state / self.norm(state))
return eigvals[0:numeig], eigenvectors

def addition(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 + tensor2

Expand Down

0 comments on commit 892468e

Please sign in to comment.