Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Lanczos solver to exponentiate matrix #971

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion tensornetwork/backends/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,47 @@ def eigsh_lanczos(self,
"""
raise NotImplementedError(
"Backend '{}' has not implemented eighs_lanczos.".format(self.name))

def expm_lanczos(self,
A: Callable,
initial_state: Tensor,
args: Optional[List[Tensor]] = None,
dtype: Optional[Type[np.number]] = None,# pylint: disable=no-member
num_krylov_vecs: int = 20,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False,dt: Optional[np.number] =1) -> Tuple[Tensor, List]:
"""
Lanczos method for computing the exponential of dt*`A` and applying it
to initial_state.
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
initial_state: An initial vector for the Lanczos algorithm.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
dtype: The dtype of the input `A`. If both no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
tol: The desired precision of the normalized vector
u@np.diag(np.exp(dt*eigvals))@np.conj(u.transpose())@np.array([1,0,0,0..]),
where eigvals are then eigenvalues of the tridiagonal matrix and u the eigen vectors
delta: Stopping criterion for Lanczos iteration.
If a Krylov vector :math: `x_n` has an L2 norm
:math:`\\lVert x_n\\rVert < delta`, the iteration
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations
to check convergence.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
dt: Prefactor of 'A' in the exponential, can be real or complex
Returns:
final_state=exp(dt*A)initial_state
"""
raise NotImplementedError(
"Backend '{}' has not implemented eighs_lanczos.".format(self.name))
def gmres(self,
A_mv: Callable,
b: Tensor,
Expand Down
109 changes: 109 additions & 0 deletions tensornetwork/backends/numpy/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,116 @@ def eigsh_lanczos(self,
state += vec * u[n1, n2]
eigenvectors.append(state / np.linalg.norm(state))
return eigvals[0:numeig], eigenvectors

def expm_lanczos(self,
A: Callable,
initial_state: Optional[Tensor],
args: Optional[List[Tensor]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 20,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False,dt:Optional[np.number]=1) -> Tuple[Tensor, List]:
"""
Lanczos method for computing the exponential of dt*`A` and applying it
to initial_state.
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
initial_state: An initial vector for the Lanczos algorithm.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
dtype: The dtype of the input `A`. If both no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
tol: The desired precision of the normalized vector
u@np.diag(np.exp(dt*eigvals))@np.conj(u.transpose())@np.array([1,0,0,0..]),
where eigvals are then eigenvalues of the tridiagonal matrix and u the eigen vectors
delta: Stopping criterion for Lanczos iteration.
If a Krylov vector :math: `x_n` has an L2 norm
:math:`\\lVert x_n\\rVert < delta`, the iteration
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations
to check convergence.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
dt: Prefactor of 'A' in the exponential, can be real or complex
Returns:
final_state=exp(dt*A)initial_state
"""
if args is None:
args = []

if not isinstance(initial_state, np.ndarray):
raise TypeError("Expected a `np.ndarray`. Got {}".format(
type(initial_state)))

vector_n = initial_state
Z = self.norm(vector_n)
vector_n /= Z
norms_vector_n = []
diag_elements = []
krylov_vecs = []
first = True
c_old =self.zeros((1), dtype=initial_state.dtype)
c_old+=1e-08 # ensuring some initial value
for it in range(num_krylov_vecs):

#normalize the current vector:
norm_vector_n = self.norm(vector_n)
if abs(norm_vector_n) < delta:
break
norms_vector_n.append(norm_vector_n)
vector_n = vector_n / norms_vector_n[-1]
#store the Lanczos vector for later
if reorthogonalize:
for v in krylov_vecs:
vector_n -= np.dot(np.ravel(np.conj(v)), np.ravel(vector_n)) * v
krylov_vecs.append(vector_n)
A_vector_n = A(vector_n, *args)
diag_elements.append(
np.dot(np.ravel(np.conj(vector_n)), np.ravel(A_vector_n)))

if ((it > 0) and (it % ndiag) == 0):
#diagonalize the effective Hamiltonian
A_tridiag = np.diag(diag_elements) + np.diag(
norms_vector_n[1:], 1) + np.diag(np.conj(norms_vector_n[1:]), -1)

eigvals, u = np.linalg.eigh(A_tridiag)
matrix_exp=u@np.diag(np.exp(dt*eigvals))@np.conj(u.transpose())
c_new=matrix_exp[:,0].reshape(-1,1)
c_new/=self.norm(c_new)
c_old_m=np.zeros(c_new.shape,dtype=initial_state.dtype)
c_old_m[0:c_old.size]=c_old
c_old_m/=self.norm(c_old_m)
if not first:
if np.linalg.norm(c_old_m - c_new) < tol:
break
first = False
c_old= c_new
if it > 0:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
A_vector_n -= (krylov_vecs[-2] * norms_vector_n[-1])
else:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
vector_n = A_vector_n

A_tridiag = np.diag(diag_elements) + np.diag(
norms_vector_n[1:], 1) + np.diag(np.conj(norms_vector_n[1:]), -1)

eigvals, u = np.linalg.eigh(A_tridiag)
matrix_exp=u@np.diag(np.exp(dt*eigvals))@np.conj(u.transpose())


state = self.zeros(initial_state.shape, initial_state.dtype)

for n1, vec in enumerate(krylov_vecs):
state += vec * matrix_exp[n1, 0]

return np.linalg.norm(initial_state)*state
def addition(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 + tensor2

Expand Down
74 changes: 74 additions & 0 deletions tensornetwork/backends/numpy/numpy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,80 @@ def test_eigsh_lanczos_raises():
TypeError, match="Expected a `np.ndarray`. Got <class 'list'>"):
backend.eigsh_lanczos(lambda x: x, initial_state=[1, 2, 3])

def test_expm_small_number_krylov_vectors():
backend = numpy_backend.NumPyBackend()
dt=0.1
init = np.array([1, 1], dtype=np.float64)
H = np.array([[1, 2], [3, 4]], dtype=np.float64)

def mv(x, mat):
return np.dot(mat, x)

state= backend.expm_lanczos(mv, init,[H], num_krylov_vecs=1)
a0=init.transpose()@H@init
res=np.exp(a0)*init
np.testing.assert_allclose(state, res)


@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
def test_expm_lanczos_1(dtype):
backend = numpy_backend.NumPyBackend()
D = 16
np.random.seed(10)
init = backend.randn((D,), dtype=dtype, seed=10)
tmp = backend.randn((D, D), dtype=dtype, seed=10)
H = tmp + backend.transpose(backend.conj(tmp), (1, 0))
def mv(x, mat):
return np.dot(mat, x)
state=backend.expm_lanczos(mv, init, [H])

w,v = np.linalg.eigh(H)
res=v@np.diag(np.exp(w))@np.conj(v).transpose()@init


np.testing.assert_allclose(state, res)






@pytest.mark.parametrize("dtype", [ np.complex128])
@pytest.mark.parametrize("dt", [0.1j, -0.1j])
def test_eigsh_lanczos_reorthogonalize_different_dt(dtype, dt):
backend = numpy_backend.NumPyBackend()
D = 24
np.random.seed(10)
init = backend.randn((D,), dtype=dtype, seed=10)
tmp = backend.randn((D, D), dtype=dtype, seed=10)
H = tmp + backend.transpose(backend.conj(tmp), (1, 0))

def mv(x, mat):
return np.dot(mat, x)
state= backend.expm_lanczos(
mv,init,
[H],
dtype=dtype,
num_krylov_vecs=D,
reorthogonalize=True,
ndiag=1,
tol=10**(-12),
delta=10**(-12), dt=dt)
w,v = np.linalg.eigh(H)
res=v@np.diag(np.exp(dt*w))@np.conj(v).transpose()@init
np.testing.assert_allclose(state, res)
test_eigsh_lanczos_reorthogonalize_different_dt(np.complex128, -0.1)
def test_expm_lanczos_raises():
backend = numpy_backend.NumPyBackend()

with pytest.raises(
TypeError, match="Expected a `np.ndarray`. Got <class 'list'>"):
backend.expm_lanczos(lambda x: x, initial_state=[1, 2, 3])
# to do, implement test on dtypes of dt, initial state and matrix
#with pytest.raises(
# TypeError, match="Expected a `np.ndarray`. Got <class 'list'>"):
# backend.expm_lanczos(lambda x: x, initial_state=[1, 2, 3])


def test_gmres_raises():
backend = numpy_backend.NumPyBackend()
Expand Down
121 changes: 121 additions & 0 deletions tensornetwork/backends/pytorch/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,128 @@ def eigsh_lanczos(self,
state += vec * u[n1, n2]
eigenvectors.append(state / torchlib.norm(state))
return eigvals[0:numeig], eigenvectors
def expm_lanczos(self,
A: Callable,
initial_state: Optional[Tensor],
args: Optional[List[Tensor]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 20,
tol: float = 1E-8,
delta: float = 1E-8,
ndiag: int = 20,
reorthogonalize: bool = False,dt:Optional[np.number]=1) -> Tuple[Tensor, List]:
"""
Lanczos method for computing the exponential of dt*`A` and applying it
to initial_state.
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
initial_state: An initial vector for the Lanczos algorithm.
arsg: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
dtype: The dtype of the input `A`. If both no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
tol: The desired precision of the normalized vector
u@np.diag(np.exp(dt*eigvals))@np.conj(u.transpose())@np.array([1,0,0,0..]),
where eigvals are then eigenvalues of the tridiagonal matrix and u the eigen vectors
delta: Stopping criterion for Lanczos iteration.
If a Krylov vector :math: `x_n` has an L2 norm
:math:`\\lVert x_n\\rVert < delta`, the iteration
is stopped. It means that an (approximate) invariant subspace has
been found.
ndiag: The tridiagonal Operator is diagonalized every `ndiag` iterations
to check convergence.
reorthogonalize: If `True`, Krylov vectors are kept orthogonal by
explicit orthogonalization (more costly than `reorthogonalize=False`)
dt: Prefactor of 'A' in the exponential, can be real or complex
Returns:
final_state=exp(dt*A)initial_state
"""
if args is None:
args = []


if not isinstance(initial_state, torchlib.Tensor):
raise TypeError("Expected a `torch.Tensor`. Got {}".format(
type(initial_state)))

initial_state = self.convert_to_tensor(initial_state)
vector_n = initial_state
Z = self.norm(vector_n)
vector_n /= Z
norms_vector_n = []
diag_elements = []
krylov_vecs = []
first = True
c_old =self.zeros((1), dtype=initial_state.dtype)
c_old+=1e-08 # ensuring some initial value
for it in range(num_krylov_vecs):

#normalize the current vector:
norm_vector_n = torchlib.norm(vector_n)
if abs(norm_vector_n) < delta:
break
norms_vector_n.append(norm_vector_n)
vector_n = vector_n / norms_vector_n[-1]
#store the Lanczos vector for later
if reorthogonalize:
for v in krylov_vecs:

vector_n -= (torchlib.conj(v).contiguous().view(-1).dot(
vector_n.contiguous().view(-1))) * torchlib.reshape(
v, vector_n.shape)
krylov_vecs.append(vector_n)
A_vector_n = A(vector_n, *args)

diag_elements.append(torchlib.conj(vector_n).contiguous().view(-1).dot(
A_vector_n.contiguous().view(-1)))

if ((it > 0) and (it % ndiag) == 0):
#diagonalize the effective Hamiltonian

A_tridiag = torchlib.diag(
torchlib.tensor(diag_elements)) + torchlib.diag(
torchlib.tensor(norms_vector_n[1:]), 1) + torchlib.diag(
torchlib.conj(torchlib.tensor(norms_vector_n[1:])), -1)

eigvals, u = torchlib.linalg.eigh(A_tridiag)
matrix_exp=u@torchlib.diag(torchlib.exp(dt*eigvals))@torchlib.conj(torchlib.transpose(u,1,0))
c_new=matrix_exp[:,0].reshape(-1,1)
c_new/=self.norm(c_new)
c_old_m=self.zeros(c_new.shape,dtype=initial_state.dtype)

c_old_m[0:c_old.shape[0]]=c_old.reshape(-1,1)
c_old_m/=self.norm(c_old_m)
c_old=c_old.reshape(-1,1)
if not first:
if torchlib.linalg.norm(c_old_m - c_new) < tol:
break
first = False
c_old= c_new
if it > 0:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
A_vector_n -= (krylov_vecs[-2] * norms_vector_n[-1])
else:
A_vector_n -= (krylov_vecs[-1] * diag_elements[-1])
vector_n = A_vector_n

A_tridiag = torchlib.diag(
torchlib.tensor(diag_elements)) + torchlib.diag(
torchlib.tensor(norms_vector_n[1:]), 1) + torchlib.diag(
torchlib.conj(torchlib.tensor(norms_vector_n[1:])), -1)

eigvals, u = torchlib.linalg.eigh(A_tridiag)
matrix_exp=u@torchlib.diag(torchlib.exp(dt*eigvals))@torchlib.conj(torchlib.transpose(u,1,0))


state = self.zeros(initial_state.shape, initial_state.dtype)

for n1, vec in enumerate(krylov_vecs):
state += vec * matrix_exp[n1, 0]

return torchlib.linalg.norm(initial_state)*state
def addition(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
return tensor1 + tensor2

Expand Down
Loading