From 018b9ede738368ade0b79fbe80da69c6a30d3381 Mon Sep 17 00:00:00 2001 From: mganahl Date: Wed, 24 Mar 2021 13:07:39 +0100 Subject: [PATCH 1/7] fix apply_twoside_gate in base_mps.py to use QR if no truncation is requested --- tensornetwork/matrixproductstates/base_mps.py | 147 ++++++++++++------ 1 file changed, 98 insertions(+), 49 deletions(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index 26def495f..5209cc4f4 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -19,10 +19,11 @@ from functools import partial from tensornetwork.backends.decorators import jit import warnings -from tensornetwork.ncon_interface import ncon from tensornetwork.backend_contextmanager import get_default_backend from tensornetwork.backends.abstract_backend import AbstractBackend from typing import Any, List, Optional, Text, Type, Union, Dict, Sequence +import tensornetwork.ncon_interface as ncon + Tensor = Any @@ -82,10 +83,10 @@ def __init__(self, else: self.backend = backend_factory.get_backend(backend) + # the dtype is deduced from the tensor object. self.tensors = [self.backend.convert_to_tensor(t) for t in tensors] - if not all( - [self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]): + if not all(t.dtype == self.tensors[0].dtype for t in self.tensors): raise TypeError('not all dtypes in BaseMPS.tensors are the same') self.connector_matrix = connector_matrix @@ -96,18 +97,21 @@ def __init__(self, ######################################################################## @partial(jit, backend=self.backend, static_argnums=(1,)) def svd(tensor, max_singular_values=None): - return self.backend.svd(tensor=tensor, pivot_axis=2, - max_singular_values=max_singular_values) + return self.backend.svd( + tensor=tensor, pivot_axis=2, max_singular_values=max_singular_values) + self.svd = svd @partial(jit, backend=self.backend) def qr(tensor): return self.backend.qr(tensor, 2) + self.qr = qr @partial(jit, backend=self.backend) def rq(tensor): return self.backend.rq(tensor, 1) + self.rq = rq self.norm = self.backend.jit(self.backend.norm) @@ -116,12 +120,12 @@ def rq(tensor): ######################################################################## def left_transfer_operator(self, A, l, Abar): - return ncon([A, l, Abar], [[1, 2, -1], [1, 3], [3, 2, -2]], - backend=self.backend.name) + return ncon.ncon([A, l, Abar], [[1, 2, -1], [1, 3], [3, 2, -2]], + backend=self.backend.name) def right_transfer_operator(self, B, r, Bbar): - return ncon([B, r, Bbar], [[-1, 2, 1], [1, 3], [-2, 2, 3]], - backend=self.backend.name) + return ncon.ncon([B, r, Bbar], [[-1, 2, 1], [1, 3], [-2, 2, 3]], + backend=self.backend.name) def __len__(self) -> int: return len(self.tensors) @@ -159,9 +163,9 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: for n in range(self.center_position, site): Q, R = self.qr(self.tensors[n]) self.tensors[n] = Q - self.tensors[n + 1] = ncon([R, self.tensors[n + 1]], - [[-1, 1], [1, -2, -3]], - backend=self.backend.name) + self.tensors[n + 1] = ncon.ncon([R, self.tensors[n + 1]], + [[-1, 1], [1, -2, -3]], + backend=self.backend.name) Z = self.norm(self.tensors[n + 1]) # for an mps with > O(10) sites one needs to normalize to avoid # over or underflow errors; this takes care of the normalization @@ -178,9 +182,9 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: # for an mps with > O(10) sites one needs to normalize to avoid # over or underflow errors; this takes care of the normalization self.tensors[n] = Q #Q is a right-isometric tensor of rank 3 - self.tensors[n - 1] = ncon([self.tensors[n - 1], R], - [[-1, -2, 1], [1, -3]], - backend=self.backend.name) + self.tensors[n - 1] = ncon.ncon([self.tensors[n - 1], R], + [[-1, -2, 1], [1, -3]], + backend=self.backend.name) Z = self.norm(self.tensors[n - 1]) if normalize: self.tensors[n - 1] /= Z @@ -191,9 +195,8 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: @property def dtype(self) -> Type[np.number]: - if not all( - [self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]): - raise TypeError('not all dtype in BaseMPS.tensors are the same') + if not all(t.dtype == self.tensors[0].dtype for t in self.tensors): + raise TypeError('not all dtype in BaseMPS.tensors are the same') return self.tensors[0].dtype @@ -439,7 +442,9 @@ def apply_two_site_gate(self, site1: int, site2: int, max_singular_values: Optional[int] = None, - max_truncation_err: Optional[float] = None) -> Tensor: + max_truncation_err: Optional[float] = None, + center_position: Optional[int] = None, + relative: bool = False) -> Tensor: """Apply a two-site gate to an MPS. This routine will in general destroy any canonical form of the state. If a canonical form is needed, the user can restore it using `FiniteMPS.position`. @@ -450,6 +455,15 @@ def apply_two_site_gate(self, site2: The second site where the gate acts. max_singular_values: The maximum number of singular values to keep. max_truncation_err: The maximum allowed truncation error. + center_position: An optional value to choose the MPS tensor at + `center_position` to be isometric after the application of the gate. + Defaults to `site1`. If the MPS is canonical (i.e. + `BaseMPS.center_position != None`), and if the orthogonality center + coincides with either `site1` or `site2`, the orthogonality center will + be shifted to `center_position` (`site1` by default). If the + orthogonality center does not coincide with `(site1, site2)` then + `MPS.center_position` is set to `None`. + relative: Multiply `max_truncation_err` with the largest singular value. Returns: `Tensor`: A scalar tensor containing the truncated weight of the @@ -473,6 +487,10 @@ def apply_two_site_gate(self, "neighbor gates are currently" "supported".format(site2, site1)) + if center_position is not None and center_position not in (site1, site2): + raise ValueError(f"center_position = {center_position} not " + f"in {(site1, site2)} ") + if (max_singular_values or max_truncation_err) and self.center_position not in (site1, site2): raise ValueError( @@ -481,28 +499,59 @@ def apply_two_site_gate(self, 'is applied at the center position of the MPS'.format( self.center_position, site1, site2)) - gate_node = Node(gate, backend=self.backend) - node1 = Node(self.tensors[site1], backend=self.backend) - node2 = Node(self.tensors[site2], backend=self.backend) - node1[2] ^ node2[0] - gate_node[2] ^ node1[1] - gate_node[3] ^ node2[1] - left_edges = [node1[0], gate_node[0]] - right_edges = [gate_node[1], node2[2]] - result = node1 @ node2 @ gate_node - U, S, V, tw = split_node_full_svd( - result, - left_edges=left_edges, - right_edges=right_edges, - max_singular_values=max_singular_values, - max_truncation_err=max_truncation_err, - left_name=node1.name, - right_name=node2.name) - V.reorder_edges([S[1]] + right_edges) - left_edges = left_edges + [S[1]] - res = contract_between(U, S, name=U.name).reorder_edges(left_edges) - self.tensors[site1] = res.tensor - self.tensors[site2] = V.tensor + use_svd = (max_truncation_err is not None) or (max_singular_values + is not None) + gate = self.backend.convert_to_tensor(gate) + tensor = ncon.ncon([self.tensors[site1], self.tensors[site2], gate], + [[-1, 1, 2], [2, 3, -4], [-2, -3, 1, 3]], + backend=self.backend) + + def set_center_position(site): + if self.center_position is not None: + if self.center_position in (site1, site2): + assert site in (site1, site2) + self.center_position = site + else: + self.center_position = None + + if center_position is None: + center_position = site1 + + if use_svd: + U, S, V, tw = self.backend.svd( + tensor, + pivot_axis=2, + max_singular_values=max_singular_values, + max_truncation_error=max_truncation_err, + relative=relative) + if center_position == site2: + left_tensor = U + right_tensor = ncon.ncon([self.backend.diagflat(S), V], + [[-1, 1], [1, -2, -3]], + backend=self.backend) + set_center_position(site2) + else: + left_tensor = ncon.ncon([U, self.backend.diagflat(S)], + [[-1, -2, 1], [1, -3]], + backend=self.backend) + right_tensor = V + set_center_position(site1) + + else: + tw = self.backend.zeros(1, dtype=self.dtype) + if center_position == site2: + R, Q = self.backend.rq(tensor, pivot_axis=2) + left_tensor = R + right_tensor = Q + set_center_position(site2) + else: + Q, R = self.backend.qr(tensor, pivot_axis=2) + left_tensor = Q + right_tensor = R + set_center_position(site1) + + self.tensors[site1] = left_tensor + self.tensors[site2] = right_tensor return tw def apply_one_site_gate(self, gate: Tensor, site: int) -> None: @@ -519,9 +568,9 @@ def apply_one_site_gate(self, gate: Tensor, site: int) -> None: if site < 0 or site >= len(self): raise ValueError('site = {} is not between 0 <= site < N={}'.format( site, len(self))) - self.tensors[site] = ncon([gate, self.tensors[site]], - [[-2, 1], [-1, 1, -3]], - backend=self.backend.name) + self.tensors[site] = ncon.ncon([gate, self.tensors[site]], + [[-2, 1], [-1, 1, -3]], + backend=self.backend.name) def check_orthonormality(self, which: Text, site: int) -> Tensor: """Check orthonormality of tensor at site `site`. @@ -555,8 +604,8 @@ def check_orthonormality(self, which: Text, site: int) -> Tensor: M=self.backend.sparse_shape(result)[1], dtype=self.dtype) return self.backend.sqrt( - ncon([tmp, self.backend.conj(tmp)], [[1, 2], [1, 2]], - backend=self.backend)) + ncon.ncon([tmp, self.backend.conj(tmp)], [[1, 2], [1, 2]], + backend=self.backend)) # pylint: disable=inconsistent-return-statements def check_canonical(self) -> Any: @@ -601,9 +650,9 @@ def get_tensor(self, site: int) -> Tensor: 'index `site` has to be larger than 0 (found `site`={}).'.format( site)) if (site == len(self) - 1) and (self.connector_matrix is not None): - return ncon([self.tensors[site], self.connector_matrix], - [[-1, -2, 1], [1, -3]], - backend=self.backend.name) + return ncon.ncon([self.tensors[site], self.connector_matrix], + [[-1, -2, 1], [1, -3]], + backend=self.backend.name) return self.tensors[site] def canonicalize(self, *args, **kwargs) -> np.number: From d7f4769a034645cf285f17b841b11c25bb580be1 Mon Sep 17 00:00:00 2001 From: mganahl Date: Wed, 24 Mar 2021 13:19:58 +0100 Subject: [PATCH 2/7] fix indentation --- tensornetwork/matrixproductstates/base_mps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index 5209cc4f4..1051ffe63 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -196,7 +196,7 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: @property def dtype(self) -> Type[np.number]: if not all(t.dtype == self.tensors[0].dtype for t in self.tensors): - raise TypeError('not all dtype in BaseMPS.tensors are the same') + raise TypeError('not all dtype in BaseMPS.tensors are the same') return self.tensors[0].dtype From 0841178a8f15a93c52764d3d52054d2b827f261e Mon Sep 17 00:00:00 2001 From: mganahl Date: Wed, 24 Mar 2021 14:09:00 +0100 Subject: [PATCH 3/7] fiux test --- tensornetwork/backends/jax/jax_backend_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 89cfbc9b0..82faf4873 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -894,7 +894,7 @@ def matvec(vec): init /= jnp.linalg.norm(init) ncv = 20 - numeig = 4 + numeig = 3 which = 'SA' tol = 1E-10 maxiter = 30 From f9fd3c11ba3bf3859f4f2b76d2dc2c038c768917 Mon Sep 17 00:00:00 2001 From: mganahl Date: Thu, 25 Mar 2021 10:01:43 +0100 Subject: [PATCH 4/7] add D and max_truncation_err arguments to position to allow optional MPS truncation --- tensornetwork/matrixproductstates/base_mps.py | 83 +++++++++++++++---- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index 26def495f..a8b84924c 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -94,10 +94,18 @@ def __init__(self, ######################################################################## ########## define functions for jitted operations ########## ######################################################################## - @partial(jit, backend=self.backend, static_argnums=(1,)) - def svd(tensor, max_singular_values=None): - return self.backend.svd(tensor=tensor, pivot_axis=2, - max_singular_values=max_singular_values) + @partial(jit, backend=self.backend, static_argnums=(1, 2, 3)) + def svd(tensor, + pivot_axis=2, + max_singular_values=None, + max_truncation_error=None): + return self.backend.svd( + tensor=tensor, + pivot_axis=pivot_axis, + max_singular_values=max_singular_values, + max_truncation_error=max_truncation_error, + relative=True) + self.svd = svd @partial(jit, backend=self.backend) @@ -126,12 +134,17 @@ def right_transfer_operator(self, B, r, Bbar): def __len__(self) -> int: return len(self.tensors) - def position(self, site: int, normalize: Optional[bool] = True) -> np.number: + def position(self, site: int, normalize: Optional[bool] = True, + D: Optional[int] = None, + max_truncation_err: Optional[float] = None) -> np.number: """Shift `center_position` to `site`. Args: site: The site to which FiniteMPS.center_position should be shifted normalize: If `True`, normalize matrices when shifting. + D: If not `None`, truncate the MPS bond dimensions to `D`. + max_truncation_err: if not `None`, truncate each bond dimension, + but keeping the truncation error below `max_truncation_err`. Returns: `Tensor`: The norm of the tensor at `FiniteMPS.center_position` Raises: @@ -141,11 +154,15 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: raise ValueError( "BaseMPS.center_position is `None`, cannot shift `center_position`." "Reset `center_position` manually or use `canonicalize`") - + if max_truncation_err is not None and max_truncation_err >= 1.0: + raise ValueError("max_truncation_err should be 0 <= max_truncation_er < 1," + f" found max_truncation_err = {max_truncation_err}") #`site` has to be between 0 and len(mps) - 1 if site >= len(self.tensors) or site < 0: raise ValueError('site = {} not between values' ' 0 < site < N = {}'.format(site, len(self))) + + #nothing to do if site == self.center_position: Z = self.norm(self.tensors[self.center_position]) @@ -153,13 +170,25 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: self.tensors[self.center_position] /= Z return Z - #shift center_position to the right using QR decomposition + #shift center_position to the right using QR or SV decomposition if site > self.center_position: n = self.center_position for n in range(self.center_position, site): - Q, R = self.qr(self.tensors[n]) - self.tensors[n] = Q - self.tensors[n + 1] = ncon([R, self.tensors[n + 1]], + use_svd = (D is not None and D < self.bond_dimension(n + 1) + ) or max_truncation_err is not None + if not use_svd: + isometry, rest = self.qr(self.tensors[n]) + else: + isometry, S, V, tw = self.svd( + self.tensors[n], + pivot_axis=2, + max_singular_values=D, + max_truncation_error=max_truncation_err) + rest = ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]], + backend=self.backend) + + self.tensors[n] = isometry + self.tensors[n + 1] = ncon([rest, self.tensors[n + 1]], [[-1, 1], [1, -2, -3]], backend=self.backend.name) Z = self.norm(self.tensors[n + 1]) @@ -170,18 +199,29 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: self.center_position = site - #shift center_position to the left using RQ decomposition + #shift center_position to the left using RQ or SV decomposition else: for n in reversed(range(site + 1, self.center_position + 1)): - - R, Q = self.rq(self.tensors[n]) - # for an mps with > O(10) sites one needs to normalize to avoid - # over or underflow errors; this takes care of the normalization - self.tensors[n] = Q #Q is a right-isometric tensor of rank 3 - self.tensors[n - 1] = ncon([self.tensors[n - 1], R], + use_svd = (D is not None and D < self.bond_dimension(n) + ) or max_truncation_err is not None + if not use_svd: + rest, isometry = self.rq(self.tensors[n]) + else: + U, S, isometry, tw = self.svd( + self.tensors[n], + pivot_axis=1, + max_singular_values=D, + max_truncation_error=max_truncation_err) + rest = ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]], + backend=self.backend) + + self.tensors[n] = isometry #a right-isometric tensor of rank 3 + self.tensors[n - 1] = ncon([self.tensors[n - 1], rest], [[-1, -2, 1], [1, -3]], backend=self.backend.name) Z = self.norm(self.tensors[n - 1]) + # for an mps with > O(10) sites one needs to normalize to avoid + # over or underflow errors; this takes care of the normalization if normalize: self.tensors[n - 1] /= Z @@ -200,6 +240,15 @@ def dtype(self) -> Type[np.number]: def save(self, path: str): raise NotImplementedError() + def bond_dimension(self, bond) -> List: + """The bond dimension of `bond`""" + if bond > len(self): + raise IndexError(f"bond {bond} out of bounds for" + f" an MPS of length {len(self)}") + if bond < len(self): + return self.tensors[bond].shape[0] + return self.tensors[bond].shape[2] + @property def bond_dimensions(self) -> List: """A list of bond dimensions of `BaseMPS`""" From 8e9090387c28ff8a4865df878150b1a7f91a50b3 Mon Sep 17 00:00:00 2001 From: mganahl Date: Thu, 25 Mar 2021 10:01:43 +0100 Subject: [PATCH 5/7] add D and max_truncation_err arguments to position to allow optional MPS truncation --- tensornetwork/matrixproductstates/base_mps.py | 83 +++++++++++++++---- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index 26def495f..a8b84924c 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -94,10 +94,18 @@ def __init__(self, ######################################################################## ########## define functions for jitted operations ########## ######################################################################## - @partial(jit, backend=self.backend, static_argnums=(1,)) - def svd(tensor, max_singular_values=None): - return self.backend.svd(tensor=tensor, pivot_axis=2, - max_singular_values=max_singular_values) + @partial(jit, backend=self.backend, static_argnums=(1, 2, 3)) + def svd(tensor, + pivot_axis=2, + max_singular_values=None, + max_truncation_error=None): + return self.backend.svd( + tensor=tensor, + pivot_axis=pivot_axis, + max_singular_values=max_singular_values, + max_truncation_error=max_truncation_error, + relative=True) + self.svd = svd @partial(jit, backend=self.backend) @@ -126,12 +134,17 @@ def right_transfer_operator(self, B, r, Bbar): def __len__(self) -> int: return len(self.tensors) - def position(self, site: int, normalize: Optional[bool] = True) -> np.number: + def position(self, site: int, normalize: Optional[bool] = True, + D: Optional[int] = None, + max_truncation_err: Optional[float] = None) -> np.number: """Shift `center_position` to `site`. Args: site: The site to which FiniteMPS.center_position should be shifted normalize: If `True`, normalize matrices when shifting. + D: If not `None`, truncate the MPS bond dimensions to `D`. + max_truncation_err: if not `None`, truncate each bond dimension, + but keeping the truncation error below `max_truncation_err`. Returns: `Tensor`: The norm of the tensor at `FiniteMPS.center_position` Raises: @@ -141,11 +154,15 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: raise ValueError( "BaseMPS.center_position is `None`, cannot shift `center_position`." "Reset `center_position` manually or use `canonicalize`") - + if max_truncation_err is not None and max_truncation_err >= 1.0: + raise ValueError("max_truncation_err should be 0 <= max_truncation_er < 1," + f" found max_truncation_err = {max_truncation_err}") #`site` has to be between 0 and len(mps) - 1 if site >= len(self.tensors) or site < 0: raise ValueError('site = {} not between values' ' 0 < site < N = {}'.format(site, len(self))) + + #nothing to do if site == self.center_position: Z = self.norm(self.tensors[self.center_position]) @@ -153,13 +170,25 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: self.tensors[self.center_position] /= Z return Z - #shift center_position to the right using QR decomposition + #shift center_position to the right using QR or SV decomposition if site > self.center_position: n = self.center_position for n in range(self.center_position, site): - Q, R = self.qr(self.tensors[n]) - self.tensors[n] = Q - self.tensors[n + 1] = ncon([R, self.tensors[n + 1]], + use_svd = (D is not None and D < self.bond_dimension(n + 1) + ) or max_truncation_err is not None + if not use_svd: + isometry, rest = self.qr(self.tensors[n]) + else: + isometry, S, V, tw = self.svd( + self.tensors[n], + pivot_axis=2, + max_singular_values=D, + max_truncation_error=max_truncation_err) + rest = ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]], + backend=self.backend) + + self.tensors[n] = isometry + self.tensors[n + 1] = ncon([rest, self.tensors[n + 1]], [[-1, 1], [1, -2, -3]], backend=self.backend.name) Z = self.norm(self.tensors[n + 1]) @@ -170,18 +199,29 @@ def position(self, site: int, normalize: Optional[bool] = True) -> np.number: self.center_position = site - #shift center_position to the left using RQ decomposition + #shift center_position to the left using RQ or SV decomposition else: for n in reversed(range(site + 1, self.center_position + 1)): - - R, Q = self.rq(self.tensors[n]) - # for an mps with > O(10) sites one needs to normalize to avoid - # over or underflow errors; this takes care of the normalization - self.tensors[n] = Q #Q is a right-isometric tensor of rank 3 - self.tensors[n - 1] = ncon([self.tensors[n - 1], R], + use_svd = (D is not None and D < self.bond_dimension(n) + ) or max_truncation_err is not None + if not use_svd: + rest, isometry = self.rq(self.tensors[n]) + else: + U, S, isometry, tw = self.svd( + self.tensors[n], + pivot_axis=1, + max_singular_values=D, + max_truncation_error=max_truncation_err) + rest = ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]], + backend=self.backend) + + self.tensors[n] = isometry #a right-isometric tensor of rank 3 + self.tensors[n - 1] = ncon([self.tensors[n - 1], rest], [[-1, -2, 1], [1, -3]], backend=self.backend.name) Z = self.norm(self.tensors[n - 1]) + # for an mps with > O(10) sites one needs to normalize to avoid + # over or underflow errors; this takes care of the normalization if normalize: self.tensors[n - 1] /= Z @@ -200,6 +240,15 @@ def dtype(self) -> Type[np.number]: def save(self, path: str): raise NotImplementedError() + def bond_dimension(self, bond) -> List: + """The bond dimension of `bond`""" + if bond > len(self): + raise IndexError(f"bond {bond} out of bounds for" + f" an MPS of length {len(self)}") + if bond < len(self): + return self.tensors[bond].shape[0] + return self.tensors[bond].shape[2] + @property def bond_dimensions(self) -> List: """A list of bond dimensions of `BaseMPS`""" From 50fe965224b572213da8b34c050269a8c43ec463 Mon Sep 17 00:00:00 2001 From: mganahl Date: Thu, 25 Mar 2021 10:28:59 +0100 Subject: [PATCH 6/7] minor bug fixes, tests added --- tensornetwork/matrixproductstates/base_mps.py | 22 +++++++------------ .../matrixproductstates/base_mps_test.py | 15 +++++++++++++ tensornetwork/matrixproductstates/dmrg.py | 6 ++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index a8b84924c..7b7de076a 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -85,7 +85,7 @@ def __init__(self, # the dtype is deduced from the tensor object. self.tensors = [self.backend.convert_to_tensor(t) for t in tensors] if not all( - [self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]): + (self.tensors[0].dtype == tensor.dtype for tensor in self.tensors)): raise TypeError('not all dtypes in BaseMPS.tensors are the same') self.connector_matrix = connector_matrix @@ -155,8 +155,8 @@ def position(self, site: int, normalize: Optional[bool] = True, "BaseMPS.center_position is `None`, cannot shift `center_position`." "Reset `center_position` manually or use `canonicalize`") if max_truncation_err is not None and max_truncation_err >= 1.0: - raise ValueError("max_truncation_err should be 0 <= max_truncation_er < 1," - f" found max_truncation_err = {max_truncation_err}") + raise ValueError("max_truncation_err should be 0 <= max_truncation_er" + f" < 1, found max_truncation_err = {max_truncation_err}") #`site` has to be between 0 and len(mps) - 1 if site >= len(self.tensors) or site < 0: raise ValueError('site = {} not between values' @@ -179,11 +179,8 @@ def position(self, site: int, normalize: Optional[bool] = True, if not use_svd: isometry, rest = self.qr(self.tensors[n]) else: - isometry, S, V, tw = self.svd( - self.tensors[n], - pivot_axis=2, - max_singular_values=D, - max_truncation_error=max_truncation_err) + isometry, S, V, _ = self.svd(self.tensors[n], 2, D, + max_truncation_err) rest = ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]], backend=self.backend) @@ -207,11 +204,8 @@ def position(self, site: int, normalize: Optional[bool] = True, if not use_svd: rest, isometry = self.rq(self.tensors[n]) else: - U, S, isometry, tw = self.svd( - self.tensors[n], - pivot_axis=1, - max_singular_values=D, - max_truncation_error=max_truncation_err) + U, S, isometry, _ = self.svd(self.tensors[n], 1, D, + max_truncation_err) rest = ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]], backend=self.backend) @@ -232,7 +226,7 @@ def position(self, site: int, normalize: Optional[bool] = True, @property def dtype(self) -> Type[np.number]: if not all( - [self.tensors[0].dtype == tensor.dtype for tensor in self.tensors]): + (self.tensors[0].dtype == tensor.dtype for tensor in self.tensors)): raise TypeError('not all dtype in BaseMPS.tensors are the same') return self.tensors[0].dtype diff --git a/tensornetwork/matrixproductstates/base_mps_test.py b/tensornetwork/matrixproductstates/base_mps_test.py index 7a328adca..9228b11b0 100644 --- a/tensornetwork/matrixproductstates/base_mps_test.py +++ b/tensornetwork/matrixproductstates/base_mps_test.py @@ -187,6 +187,12 @@ def test_position_raises_error(backend): " `None`, cannot shift `center_position`." "Reset `center_position` manually or use `canonicalize`"): mps.position(1) + mps = BaseMPS(tensors, center_position=0, backend=backend) + with pytest.raises( + ValueError, + match="max_truncation_err"): + mps.position(1, max_truncation_err=1.1) + def test_position_no_normalization(backend): @@ -233,6 +239,15 @@ def test_position_no_shift_no_normalization(backend): Z = mps.position(int(N / 2), normalize=False) np.testing.assert_allclose(Z, 5.656854) +def test_position_truncation(backend): + D, d, N = 10, 2, 10 + tensors = [np.ones((1, d, D))] + [np.ones((D, d, D)) for _ in range(N - 2) + ] + [np.ones((D, d, 1))] + mps = BaseMPS(tensors, center_position=0, backend=backend) + mps.position(N-1) + mps.position(0, D=5) + assert np.all(np.array(mps.bond_dimensions) <= 5) + def test_different_dtypes_raises_error(): D, d = 4, 2 diff --git a/tensornetwork/matrixproductstates/dmrg.py b/tensornetwork/matrixproductstates/dmrg.py index ca3d26738..647e3339a 100644 --- a/tensornetwork/matrixproductstates/dmrg.py +++ b/tensornetwork/matrixproductstates/dmrg.py @@ -297,8 +297,7 @@ def _optimize_2s_local(self, energy = energies[0] local_ground_state /= self.backend.norm(local_ground_state) - u, s, vh, _ = self.mps.svd(local_ground_state, - max_bond_dim) + u, s, vh, _ = self.mps.svd(local_ground_state, 2, max_bond_dim, None) s = self.backend.diagflat(s) self.mps.tensors[site] = u if site < len(self.mps.tensors) - 1: @@ -329,8 +328,7 @@ def _optimize_2s_local(self, energy = energies[0] local_ground_state /= self.backend.norm(local_ground_state) - u, s, vh, _ = self.mps.svd(local_ground_state, - max_bond_dim) + u, s, vh, _ = self.mps.svd(local_ground_state, 2, max_bond_dim, None) s = self.backend.diagflat(s) self.mps.tensors[site] = vh if site > 0: From 7a6138e02202d895704fdd0648ac7d3cc51687eb Mon Sep 17 00:00:00 2001 From: mganahl Date: Fri, 26 Mar 2021 08:27:23 +0100 Subject: [PATCH 7/7] fix ncon call --- tensornetwork/matrixproductstates/base_mps.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tensornetwork/matrixproductstates/base_mps.py b/tensornetwork/matrixproductstates/base_mps.py index 862629b9a..efa1c1a4d 100644 --- a/tensornetwork/matrixproductstates/base_mps.py +++ b/tensornetwork/matrixproductstates/base_mps.py @@ -184,13 +184,13 @@ def position(self, site: int, normalize: Optional[bool] = True, else: isometry, S, V, _ = self.svd(self.tensors[n], 2, D, max_truncation_err) - rest = ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]], - backend=self.backend) + rest = ncon.ncon([self.backend.diagflat(S), V], [[-1, 1], [1, -2]], + backend=self.backend) self.tensors[n] = isometry - self.tensors[n + 1] = ncon([rest, self.tensors[n + 1]], - [[-1, 1], [1, -2, -3]], - backend=self.backend.name) + self.tensors[n + 1] = ncon.ncon([rest, self.tensors[n + 1]], + [[-1, 1], [1, -2, -3]], + backend=self.backend.name) Z = self.norm(self.tensors[n + 1]) # for an mps with > O(10) sites one needs to normalize to avoid # over or underflow errors; this takes care of the normalization @@ -209,13 +209,13 @@ def position(self, site: int, normalize: Optional[bool] = True, else: U, S, isometry, _ = self.svd(self.tensors[n], 1, D, max_truncation_err) - rest = ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]], - backend=self.backend) + rest = ncon.ncon([U, self.backend.diagflat(S)], [[-1, 1], [1, -2]], + backend=self.backend) self.tensors[n] = isometry #a right-isometric tensor of rank 3 - self.tensors[n - 1] = ncon([self.tensors[n - 1], rest], - [[-1, -2, 1], [1, -3]], - backend=self.backend.name) + self.tensors[n - 1] = ncon.ncon([self.tensors[n - 1], rest], + [[-1, -2, 1], [1, -3]], + backend=self.backend.name) Z = self.norm(self.tensors[n - 1]) # for an mps with > O(10) sites one needs to normalize to avoid # over or underflow errors; this takes care of the normalization