Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.
211 changes: 151 additions & 60 deletions tensornetwork/matrixproductstates/base_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from functools import partial
from tensornetwork.backends.decorators import jit
import warnings
from tensornetwork.ncon_interface import ncon
from tensornetwork.backend_contextmanager import get_default_backend
from tensornetwork.backends.abstract_backend import AbstractBackend
from typing import Any, List, Optional, Text, Type, Union, Dict, Sequence
import tensornetwork.ncon_interface as ncon

Tensor = Any


Expand Down Expand Up @@ -82,10 +83,10 @@ def __init__(self,
else:
self.backend = backend_factory.get_backend(backend)


# the dtype is deduced from the tensor object.
self.tensors = [self.backend.convert_to_tensor(t) for t in tensors]
if not all(
[self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]):
if not all(t.dtype == self.tensors[0].dtype for t in self.tensors):
raise TypeError('not all dtypes in BaseMPS.tensors are the same')

self.connector_matrix = connector_matrix
Expand All @@ -94,20 +95,30 @@ def __init__(self,
########################################################################
########## define functions for jitted operations ##########
########################################################################
@partial(jit, backend=self.backend, static_argnums=(1,))
def svd(tensor, max_singular_values=None):
return self.backend.svd(tensor=tensor, pivot_axis=2,
max_singular_values=max_singular_values)
@partial(jit, backend=self.backend, static_argnums=(1, 2, 3))
def svd(tensor,
pivot_axis=2,
max_singular_values=None,
max_truncation_error=None):
return self.backend.svd(
tensor=tensor,
pivot_axis=pivot_axis,
max_singular_values=max_singular_values,
max_truncation_error=max_truncation_error,
relative=True)

self.svd = svd

@partial(jit, backend=self.backend)
def qr(tensor):
return self.backend.qr(tensor, 2)

self.qr = qr

@partial(jit, backend=self.backend)
def rq(tensor):
return self.backend.rq(tensor, 1)

self.rq = rq

self.norm = self.backend.jit(self.backend.norm)
Expand All @@ -116,22 +127,27 @@ def rq(tensor):
########################################################################

def left_transfer_operator(self, A, l, Abar):
return ncon([A, l, Abar], [[1, 2, -1], [1, 3], [3, 2, -2]],
backend=self.backend.name)
return ncon.ncon([A, l, Abar], [[1, 2, -1], [1, 3], [3, 2, -2]],
backend=self.backend.name)

def right_transfer_operator(self, B, r, Bbar):
return ncon([B, r, Bbar], [[-1, 2, 1], [1, 3], [-2, 2, 3]],
backend=self.backend.name)
return ncon.ncon([B, r, Bbar], [[-1, 2, 1], [1, 3], [-2, 2, 3]],
backend=self.backend.name)

def __len__(self) -> int:
return len(self.tensors)

def position(self, site: int, normalize: Optional[bool] = True) -> np.number:
def position(self, site: int, normalize: Optional[bool] = True,
D: Optional[int] = None,
max_truncation_err: Optional[float] = None) -> np.number:
"""Shift `center_position` to `site`.

Args:
site: The site to which FiniteMPS.center_position should be shifted
normalize: If `True`, normalize matrices when shifting.
D: If not `None`, truncate the MPS bond dimensions to `D`.
max_truncation_err: if not `None`, truncate each bond dimension,
but keeping the truncation error below `max_truncation_err`.
Returns:
`Tensor`: The norm of the tensor at `FiniteMPS.center_position`
Raises:
Expand All @@ -141,27 +157,40 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number:
raise ValueError(
"BaseMPS.center_position is `None`, cannot shift `center_position`."
"Reset `center_position` manually or use `canonicalize`")

if max_truncation_err is not None and max_truncation_err >= 1.0:
raise ValueError("max_truncation_err should be 0 <= max_truncation_er"
f" < 1, found max_truncation_err = {max_truncation_err}")
#`site` has to be between 0 and len(mps) - 1
if site >= len(self.tensors) or site < 0:
raise ValueError('site = {} not between values'
' 0 < site < N = {}'.format(site, len(self)))


#nothing to do
if site == self.center_position:
Z = self.norm(self.tensors[self.center_position])
if normalize:
self.tensors[self.center_position] /= Z
return Z

#shift center_position to the right using QR decomposition
#shift center_position to the right using QR or SV decomposition
if site > self.center_position:
n = self.center_position
for n in range(self.center_position, site):
Q, R = self.qr(self.tensors[n])
self.tensors[n] = Q
self.tensors[n + 1] = ncon([R, self.tensors[n + 1]],
[[-1, 1], [1, -2, -3]],
backend=self.backend.name)
use_svd = (D is not None and D < self.bond_dimension(n + 1)
) or max_truncation_err is not None
if not use_svd:
isometry, rest = self.qr(self.tensors[n])
else:
isometry, S, V, _ = self.svd(self.tensors[n], 2, D,
max_truncation_err)
rest = ncon.ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]],
backend=self.backend)

self.tensors[n] = isometry
self.tensors[n + 1] = ncon.ncon([rest, self.tensors[n + 1]],
[[-1, 1], [1, -2, -3]],
backend=self.backend.name)
Z = self.norm(self.tensors[n + 1])
# for an mps with > O(10) sites one needs to normalize to avoid
# over or underflow errors; this takes care of the normalization
Expand All @@ -170,18 +199,26 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number:

self.center_position = site

#shift center_position to the left using RQ decomposition
#shift center_position to the left using RQ or SV decomposition
else:
for n in reversed(range(site + 1, self.center_position + 1)):

R, Q = self.rq(self.tensors[n])
use_svd = (D is not None and D < self.bond_dimension(n)
) or max_truncation_err is not None
if not use_svd:
rest, isometry = self.rq(self.tensors[n])
else:
U, S, isometry, _ = self.svd(self.tensors[n], 1, D,
max_truncation_err)
rest = ncon.ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]],
backend=self.backend)

self.tensors[n] = isometry #a right-isometric tensor of rank 3
self.tensors[n - 1] = ncon.ncon([self.tensors[n - 1], rest],
[[-1, -2, 1], [1, -3]],
backend=self.backend.name)
Z = self.norm(self.tensors[n - 1])
# for an mps with > O(10) sites one needs to normalize to avoid
# over or underflow errors; this takes care of the normalization
self.tensors[n] = Q #Q is a right-isometric tensor of rank 3
self.tensors[n - 1] = ncon([self.tensors[n - 1], R],
[[-1, -2, 1], [1, -3]],
backend=self.backend.name)
Z = self.norm(self.tensors[n - 1])
if normalize:
self.tensors[n - 1] /= Z

Expand All @@ -191,15 +228,23 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number:

@property
def dtype(self) -> Type[np.number]:
if not all(
[self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]):
if not all(t.dtype == self.tensors[0].dtype for t in self.tensors):
raise TypeError('not all dtype in BaseMPS.tensors are the same')

return self.tensors[0].dtype

def save(self, path: str):
raise NotImplementedError()

def bond_dimension(self, bond) -> List:
"""The bond dimension of `bond`"""
if bond > len(self):
raise IndexError(f"bond {bond} out of bounds for"
f" an MPS of length {len(self)}")
if bond < len(self):
return self.tensors[bond].shape[0]
return self.tensors[bond].shape[2]

@property
def bond_dimensions(self) -> List:
"""A list of bond dimensions of `BaseMPS`"""
Expand Down Expand Up @@ -439,7 +484,9 @@ def apply_two_site_gate(self,
site1: int,
site2: int,
max_singular_values: Optional[int] = None,
max_truncation_err: Optional[float] = None) -> Tensor:
max_truncation_err: Optional[float] = None,
center_position: Optional[int] = None,
relative: bool = False) -> Tensor:
"""Apply a two-site gate to an MPS. This routine will in general destroy
any canonical form of the state. If a canonical form is needed, the user
can restore it using `FiniteMPS.position`.
Expand All @@ -450,6 +497,15 @@ def apply_two_site_gate(self,
site2: The second site where the gate acts.
max_singular_values: The maximum number of singular values to keep.
max_truncation_err: The maximum allowed truncation error.
center_position: An optional value to choose the MPS tensor at
`center_position` to be isometric after the application of the gate.
Defaults to `site1`. If the MPS is canonical (i.e.
`BaseMPS.center_position != None`), and if the orthogonality center
coincides with either `site1` or `site2`, the orthogonality center will
be shifted to `center_position` (`site1` by default). If the
orthogonality center does not coincide with `(site1, site2)` then
`MPS.center_position` is set to `None`.
relative: Multiply `max_truncation_err` with the largest singular value.

Returns:
`Tensor`: A scalar tensor containing the truncated weight of the
Expand All @@ -473,6 +529,10 @@ def apply_two_site_gate(self,
"neighbor gates are currently"
"supported".format(site2, site1))

if center_position is not None and center_position not in (site1, site2):
raise ValueError(f"center_position = {center_position} not "
f"in {(site1, site2)} ")

if (max_singular_values or
max_truncation_err) and self.center_position not in (site1, site2):
raise ValueError(
Expand All @@ -481,28 +541,59 @@ def apply_two_site_gate(self,
'is applied at the center position of the MPS'.format(
self.center_position, site1, site2))

gate_node = Node(gate, backend=self.backend)
node1 = Node(self.tensors[site1], backend=self.backend)
node2 = Node(self.tensors[site2], backend=self.backend)
node1[2] ^ node2[0]
gate_node[2] ^ node1[1]
gate_node[3] ^ node2[1]
left_edges = [node1[0], gate_node[0]]
right_edges = [gate_node[1], node2[2]]
result = node1 @ node2 @ gate_node
U, S, V, tw = split_node_full_svd(
result,
left_edges=left_edges,
right_edges=right_edges,
max_singular_values=max_singular_values,
max_truncation_err=max_truncation_err,
left_name=node1.name,
right_name=node2.name)
V.reorder_edges([S[1]] + right_edges)
left_edges = left_edges + [S[1]]
res = contract_between(U, S, name=U.name).reorder_edges(left_edges)
self.tensors[site1] = res.tensor
self.tensors[site2] = V.tensor
use_svd = (max_truncation_err is not None) or (max_singular_values
is not None)
gate = self.backend.convert_to_tensor(gate)
tensor = ncon.ncon([self.tensors[site1], self.tensors[site2], gate],
[[-1, 1, 2], [2, 3, -4], [-2, -3, 1, 3]],
backend=self.backend)

def set_center_position(site):
if self.center_position is not None:
if self.center_position in (site1, site2):
assert site in (site1, site2)
self.center_position = site
else:
self.center_position = None

if center_position is None:
center_position = site1

if use_svd:
U, S, V, tw = self.backend.svd(
tensor,
pivot_axis=2,
max_singular_values=max_singular_values,
max_truncation_error=max_truncation_err,
relative=relative)
if center_position == site2:
left_tensor = U
right_tensor = ncon.ncon([self.backend.diagflat(S), V],
[[-1, 1], [1, -2, -3]],
backend=self.backend)
set_center_position(site2)
else:
left_tensor = ncon.ncon([U, self.backend.diagflat(S)],
[[-1, -2, 1], [1, -3]],
backend=self.backend)
right_tensor = V
set_center_position(site1)

else:
tw = self.backend.zeros(1, dtype=self.dtype)
if center_position == site2:
R, Q = self.backend.rq(tensor, pivot_axis=2)
left_tensor = R
right_tensor = Q
set_center_position(site2)
else:
Q, R = self.backend.qr(tensor, pivot_axis=2)
left_tensor = Q
right_tensor = R
set_center_position(site1)

self.tensors[site1] = left_tensor
self.tensors[site2] = right_tensor
return tw

def apply_one_site_gate(self, gate: Tensor, site: int) -> None:
Expand All @@ -519,9 +610,9 @@ def apply_one_site_gate(self, gate: Tensor, site: int) -> None:
if site < 0 or site >= len(self):
raise ValueError('site = {} is not between 0 <= site < N={}'.format(
site, len(self)))
self.tensors[site] = ncon([gate, self.tensors[site]],
[[-2, 1], [-1, 1, -3]],
backend=self.backend.name)
self.tensors[site] = ncon.ncon([gate, self.tensors[site]],
[[-2, 1], [-1, 1, -3]],
backend=self.backend.name)

def check_orthonormality(self, which: Text, site: int) -> Tensor:
"""Check orthonormality of tensor at site `site`.
Expand Down Expand Up @@ -555,8 +646,8 @@ def check_orthonormality(self, which: Text, site: int) -> Tensor:
M=self.backend.sparse_shape(result)[1],
dtype=self.dtype)
return self.backend.sqrt(
ncon([tmp, self.backend.conj(tmp)], [[1, 2], [1, 2]],
backend=self.backend))
ncon.ncon([tmp, self.backend.conj(tmp)], [[1, 2], [1, 2]],
backend=self.backend))

# pylint: disable=inconsistent-return-statements
def check_canonical(self) -> Any:
Expand Down Expand Up @@ -601,9 +692,9 @@ def get_tensor(self, site: int) -> Tensor:
'index `site` has to be larger than 0 (found `site`={}).'.format(
site))
if (site == len(self) - 1) and (self.connector_matrix is not None):
return ncon([self.tensors[site], self.connector_matrix],
[[-1, -2, 1], [1, -3]],
backend=self.backend.name)
return ncon.ncon([self.tensors[site], self.connector_matrix],
[[-1, -2, 1], [1, -3]],
backend=self.backend.name)
return self.tensors[site]

def canonicalize(self, *args, **kwargs) -> np.number:
Expand Down
15 changes: 15 additions & 0 deletions tensornetwork/matrixproductstates/base_mps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def test_position_raises_error(backend):
" `None`, cannot shift `center_position`."
"Reset `center_position` manually or use `canonicalize`"):
mps.position(1)
mps = BaseMPS(tensors, center_position=0, backend=backend)
with pytest.raises(
ValueError,
match="max_truncation_err"):
mps.position(1, max_truncation_err=1.1)



def test_position_no_normalization(backend):
Expand Down Expand Up @@ -233,6 +239,15 @@ def test_position_no_shift_no_normalization(backend):
Z = mps.position(int(N / 2), normalize=False)
np.testing.assert_allclose(Z, 5.656854)

def test_position_truncation(backend):
D, d, N = 10, 2, 10
tensors = [np.ones((1, d, D))] + [np.ones((D, d, D)) for _ in range(N - 2)
] + [np.ones((D, d, 1))]
mps = BaseMPS(tensors, center_position=0, backend=backend)
mps.position(N-1)
mps.position(0, D=5)
assert np.all(np.array(mps.bond_dimensions) <= 5)


def test_different_dtypes_raises_error():
D, d = 4, 2
Expand Down
Loading