# 4.2: Central Orthogonal Form of Matrix Product States

## Theoretical Background

Center orthogonalization rests on the gauge degrees of freedom of matrix product states (MPS).

In short: MPSs built from different local tensors can represent the same global tensor.

![gauge_transformation_example](./images/gauge_transformation_example.png)

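> Inserting a unitary $U$ and its Hermitian conjugate $U^\dagger$ into a virtual bond does not change the global tensor, since $UU^\dagger = I$.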
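
A minimal PyTorch sketch of this gauge freedom on a two-site open chain. The sizes, the seed, and the helper `two_site_global` are illustrative assumptions, not part of the `tensor_network` package: absorbing a random unitary $U$ into one tensor and $U^\dagger$ into its neighbour changes both local tensors but not their contraction.

```python
import torch

torch.manual_seed(0)
D, d = 4, 2  # virtual (bond) and physical dimensions, chosen arbitrarily
dtype = torch.complex128

# Two neighbouring local tensors, index order (left virtual, physical, right virtual)
A1 = torch.randn(1, d, D, dtype=dtype)
A2 = torch.randn(D, d, 1, dtype=dtype)

# A random unitary: the Q factor of a QR decomposition of a random square matrix
U, _ = torch.linalg.qr(torch.randn(D, D, dtype=dtype))

# Gauge transformation: A1 -> A1 U and A2 -> U^dagger A2 on the shared bond
B1 = torch.einsum("asb,bc->asc", A1, U)
B2 = torch.einsum("cb,bsd->csd", U.conj().T, A2)


def two_site_global(X, Y):
    # Contract the shared virtual index; the boundary indices (size 1) are summed away
    return torch.einsum("asb,btc->st", X, Y)


print(torch.allclose(two_site_global(A1, A2), two_site_global(B1, B2)))  # True
print(torch.allclose(A1, B1))  # False: the local tensors themselves did change
```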

The central orthogonal form imposes orthogonality and normalization constraints on the local tensors, which eliminates (most of) this gauge freedom.

## Constraints of the Central Orthogonal Form

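Here $A_{[0,1]}$ denotes the matrix obtained by grouping indices 0 and 1 of $A$ (left virtual and physical) into the row index, and $A_{[0]}$ the matrix with only index 0 (left virtual) as the row index.

* Tensors to the left of the orthogonality center satisfy the left-to-right orthogonality condition: $(A_{[0,1]}^{(n)})^\dagger A_{[0,1]}^{(n)} = I \quad (n < n_c)$
* Tensors to the right of the orthogonality center satisfy the right-to-left orthogonality condition: $A_{[0]}^{(n)}(A_{[0]}^{(n)})^\dagger = I \quad (n > n_c)$
* For a normalized MPS, the tensor at the orthogonality center (called the center tensor) satisfies the normalization condition $\|A^{(n_c)}\| = 1$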

![central-orthogonal-form](./images/central_orthogonal_form_illustration.png)
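* Arrows on the virtual indices mark the orthogonality direction
* Conditions (b) and (c) in the figure are also called isometric conditions, meaning the corresponding tensors are isometries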
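
A minimal sketch of how the two isometric conditions can be checked numerically, and how tensors satisfying them can be produced from a reduced QR decomposition of the matching matricization. The names `is_left_isometric` / `is_right_isometric` and the tolerance are illustrative assumptions; `MPS.check_orthogonality` in the Code section performs a comparable check.

```python
import torch


def is_left_isometric(A, tol=1e-10):
    # A has shape (left, physical, right); A_[0,1] has shape (left * physical, right)
    M = A.reshape(-1, A.shape[2])
    eye = torch.eye(M.shape[1], dtype=A.dtype)
    return torch.allclose(M.conj().T @ M, eye, atol=tol)


def is_right_isometric(A, tol=1e-10):
    # A_[0] has shape (left, physical * right)
    M = A.reshape(A.shape[0], -1)
    eye = torch.eye(M.shape[0], dtype=A.dtype)
    return torch.allclose(M @ M.conj().T, eye, atol=tol)


torch.manual_seed(0)
Dl, d, Dr = 3, 2, 4
dtype = torch.complex128
T = torch.randn(Dl, d, Dr, dtype=dtype)

# Left isometry: Q factor of A_[0,1] (orthonormal columns)
Q, _ = torch.linalg.qr(T.reshape(Dl * d, Dr))
A_left = Q.reshape(Dl, d, -1)

# Right isometry: QR of the transpose of A_[0]; Q^T has orthonormal rows
Q2, _ = torch.linalg.qr(T.reshape(Dl, d * Dr).T)
A_right = Q2.T.reshape(-1, d, Dr)

print(is_left_isometric(A_left), is_right_isometric(A_right))  # True True
print(is_left_isometric(T))  # False for a generic random tensor
```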

## Center Orthogonalization

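To bring a given MPS into central orthogonal form, decompose the tensors to the left of the orthogonality center one at a time from left to right, and the tensors to its right one at a time from right to left, so that every tensor on either side of the center satisfies the corresponding orthogonality condition.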

![central-orthogonalization-left-to-right](./images/central-orgonalization-left-to-right.png)

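Left-to-right decomposition step ($\tilde{A}^{(n)}$, reshaped back into a third-order tensor, becomes the new left-isometric $A^{(n)}$, and $SV$ is absorbed into the next tensor):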

$A_{[0,1]}^{(n)} \stackrel{\text{SVD}}{\longrightarrow} \tilde{A}^{(n)}SV$

$SVA_{[0]}^{(n+1)} \stackrel{\text{contract, reshape}}{\longrightarrow} A^{(n+1)}$
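
The two formulas above as a minimal PyTorch sketch, using random complex tensors with index order (left virtual, physical, right virtual); all names and sizes are illustrative assumptions. Note that `torch.linalg.svd` returns the right factor already in the form that multiplies from the left (`Vh`), which plays the role of $V$ in the formula.

```python
import torch

torch.manual_seed(0)
D, d = 4, 3
dtype = torch.complex128
A_n = torch.randn(D, d, D, dtype=dtype)    # A^(n)
A_np1 = torch.randn(D, d, D, dtype=dtype)  # A^(n+1)

# Step 1: SVD of the matricization A_[0,1]^(n), shape (left * physical, right)
U, S, Vh = torch.linalg.svd(A_n.reshape(D * d, D), full_matrices=False)
A_n_new = U.reshape(D, d, -1)  # \tilde{A}^(n), left-isometric

# Step 2: contract S V into A^(n+1) over the shared virtual index
SV = S.unsqueeze(1) * Vh  # equivalent to diag(S) @ Vh
A_np1_new = torch.einsum("ab,bsc->asc", SV, A_np1)

# The two-site contraction is unchanged, i.e. this is a gauge transformation
before = torch.einsum("asb,btc->astc", A_n, A_np1)
after = torch.einsum("asb,btc->astc", A_n_new, A_np1_new)
print(torch.allclose(before, after))  # True
```

`einsum` performs the contraction and the reshape of the second formula in one call; an explicit `SV @ A_np1.reshape(D, -1)` followed by `reshape(-1, d, D)` would give the same tensor.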


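Properties:
* The contraction of the new $A^{(n)}$ and $A^{(n+1)}$ equals the contraction of the original $A^{(n)}$ and $A^{(n+1)}$;
* Hence, without truncation, center orthogonalization leaves the global tensor unchanged and is a **gauge transformation**;
* QR decomposition can be used instead of SVD: $A_{[0,1]}^{(n)} = QR$, with $Q$ as the new $A^{(n)}$ and $R$ absorbed into $A^{(n+1)}$. Truncating the virtual dimension requires the singular values, so the code below only accepts `truncate_dim` with `mode="svd"`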
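
Putting the steps together, a minimal sketch of a full center orthogonalization using QR (no truncation). It follows the same sweep order as `MPS.center_orthogonalization_` in the Code section (left part from site 0 up to the center, then right part from the last site down to the center), but `random_open_mps`, `global_tensor`, and `center_orthogonalize` are hypothetical helpers written for this illustration, not the package's functions.

```python
import torch


def random_open_mps(length, d, D, dtype=torch.complex128):
    # Open boundary conditions: the outermost virtual dimensions are 1
    dims = [1] + [D] * (length - 1) + [1]
    return [torch.randn(dims[n], d, dims[n + 1], dtype=dtype) for n in range(length)]


def global_tensor(tensors):
    # Contract every virtual bond; memory grows as d**length, so short chains only
    psi = tensors[0]
    for A in tensors[1:]:
        psi = torch.tensordot(psi, A, dims=([psi.dim() - 1], [0]))
    return psi.squeeze(0).squeeze(-1)


def center_orthogonalize(tensors, center):
    tensors = list(tensors)  # shallow copy; the input list is not modified
    # Left of the center: QR of A_[0,1], keep Q, absorb R into the right neighbour
    for n in range(center):
        Dl, d, Dr = tensors[n].shape
        Q, R = torch.linalg.qr(tensors[n].reshape(Dl * d, Dr))
        tensors[n] = Q.reshape(Dl, d, -1)
        tensors[n + 1] = torch.einsum("ab,bsc->asc", R, tensors[n + 1])
    # Right of the center: QR of the transpose of A_[0], so A_[0] = R^T Q^T;
    # keep Q^T (orthonormal rows) and absorb R^T into the left neighbour
    for n in range(len(tensors) - 1, center, -1):
        Dl, d, Dr = tensors[n].shape
        Q, R = torch.linalg.qr(tensors[n].reshape(Dl, d * Dr).T)
        tensors[n] = Q.T.reshape(-1, d, Dr)
        tensors[n - 1] = torch.einsum("asb,cb->asc", tensors[n - 1], R)
    return tensors


torch.manual_seed(0)
mps = random_open_mps(length=6, d=2, D=3)
center = 2
new = center_orthogonalize(mps, center)
print(torch.allclose(global_tensor(mps), global_tensor(new)))  # True
```

Reduced QR lets the bond dimensions near the open boundaries shrink automatically (here the first bond drops from 3 to 2), which is harmless because those extra dimensions carried no information.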

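## Properties of a Center-Orthogonal MPS
* The norm of the global tensor equals the norm of the center tensor
* The entanglement of the MPS can be obtained from the singular value decomposition of the center tensor
    * Goal: bipartition the state at the virtual index immediately to the left of the **orthogonality center** and compute its bipartite entanglement spectrum.
    * The general approach is to matricize the global tensor: reshape all physical indices to the left of the orthogonality center into the row index and all remaining indices into the column index. The singular spectrum of this matrix is the entanglement spectrum of that bipartition.
    * In central orthogonal form, this singular spectrum equals the singular spectrum of the matricized center tensor $A^{(n_c)}_{[0]}$ (apart from additional zero singular values of the larger global matrix).
    * The bipartition at the virtual index immediately to the right of the **orthogonality center** is handled analogously, using $A^{(n_c)}_{[0,1]}$
    ![center-tensor-singular-spectrum-example](./images/center_tensor_singular_spectrum_example.png)
* Computing reduced density matrices can be accelerated
    * Definition: $\hat{\rho}^{(n)} = \mathrm{Tr}_{/n}\,|\varphi\rangle\langle\varphi|$, where the trace runs over every physical index except that of site $n$
    * An isometric tensor contracted with its own conjugate gives the identity, so these contractions can be skipped, which improves efficiency
        * When the untraced index sits on the orthogonality center, the reduced density matrix of the MPS equals the reduced density matrix of the center tensor (panel a below)
        * When the untraced index is not on the orthogonality center, the tensors between that index and the orthogonality center still have to be contracted to obtain the reduced density matrix (panel b below)
        ![mps_cof_reduced_density_matrix](./images/mps_cof_reduced_density_matrix_illustration.png)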
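
The three properties above can be checked on a small example. This minimal sketch builds a center-orthogonal MPS directly (QR-generated isometries on both sides of an arbitrary center tensor) instead of orthogonalizing a given one; the bond dimensions, the center index, and all variable names are illustrative assumptions. The density matrix is left unnormalized (its trace is the squared norm), matching the convention of `one_body_reduced_density_matrix` in the Code section.

```python
import torch

torch.manual_seed(0)
d, c = 2, 2                # physical dimension and center index (illustrative)
dims = [1, 2, 4, 4, 2, 1]  # open-boundary bond dimensions of a 5-site chain
L = len(dims) - 1
dtype = torch.complex128

# Isometries left/right of the center, arbitrary tensor at the center
tensors = []
for n in range(L):
    Dl, Dr = dims[n], dims[n + 1]
    if n < c:  # left-isometric: columns of A_[0,1] orthonormal
        Q, _ = torch.linalg.qr(torch.randn(Dl * d, Dr, dtype=dtype))
        tensors.append(Q.reshape(Dl, d, Dr))
    elif n > c:  # right-isometric: rows of A_[0] orthonormal
        Q, _ = torch.linalg.qr(torch.randn(d * Dr, Dl, dtype=dtype))
        tensors.append(Q.T.reshape(Dl, d, Dr))
    else:
        tensors.append(torch.randn(Dl, d, Dr, dtype=dtype))
Ac = tensors[c]

# Global tensor (only feasible for short chains)
psi = tensors[0]
for A in tensors[1:]:
    psi = torch.tensordot(psi, A, dims=([psi.dim() - 1], [0]))
psi = psi.squeeze(0).squeeze(-1)  # shape (d,) * L

# (1) The norm of the global tensor equals the norm of the center tensor
print(torch.allclose(psi.norm(), Ac.norm()))  # True

# (2) Entanglement spectrum across the bond left of the center, from A^(c)_[0]
s_center = torch.linalg.svdvals(Ac.reshape(Ac.shape[0], -1))
s_global = torch.linalg.svdvals(psi.reshape(d**c, -1))
print(torch.allclose(s_center, s_global[: s_center.shape[0]]))  # True

# (3) One-site reduced density matrix at the center from the center tensor alone
rho_center = torch.einsum("asb,atb->st", Ac, Ac.conj())
psi_c = psi.movedim(c, 0).reshape(d, -1)  # reference: trace out all other sites
rho_global = psi_c @ psi_c.conj().T
print(torch.allclose(rho_center, rho_global))  # True
```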


## Code

In [1]:
# |default_exp mps.modules
# |export
import torch
from typing import List, Tuple, Literal
from tensor_network.mps.functional import gen_random_mps_tensors, MPSType

### Orthogonalization One Step

In [2]:
# |export mps.functional
import torch
from typing import List, Literal, Tuple


def orthogonalize_left2right_step(
    mps_tensors: List[torch.Tensor],
    local_tensor_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_locals: bool = False,
) -> List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor]:
    """
    One step of orthogonalization from left to right: makes the local tensor left-isometric and absorbs the remaining factor (S @ Vh for SVD, R for QR) into its right neighbour.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        local_tensor_idx: int, the index of the local tensor to be orthogonalized.
        mode: Literal["svd", "qr"], the mode of orthogonalization.
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_locals: bool, whether to return the local tensors. If True, only the local and the right one will be returned.

    Returns:
        List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor], the full list of tensors after one left-to-right step, or (new local tensor, new right neighbour) if return_locals is True.
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 0 <= local_tensor_idx < length - 1, "local_tensor_idx must be in [0, length - 2]"
    mode = mode.lower()
    assert mode in ["svd", "qr"], "mode must be either 'svd' or 'qr'"
    local_tensor = mps_tensors[local_tensor_idx]
    shape = local_tensor.shape  # (virtual_dim, physical_dim, virtual_dim)
    if truncate_dim is not None:
        virtual_dim = shape[2]
        assert truncate_dim > 0, "truncate_dim must be positive"
        truncate_dim = min(truncate_dim, virtual_dim)
        assert mode == "svd", "mode must be 'svd' when truncate_dim is provided"
        need_truncate = True
    else:
        need_truncate = False

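    # reshape instead of view: view fails on non-contiguous inputs
    view_matrix = local_tensor.reshape(-1, shape[2])  # (virtual_dim * physical_dim, virtual_dim)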

    if mode == "svd":
        u, lm, v = torch.linalg.svd(view_matrix, full_matrices=False)
        if need_truncate:
            u = u[:, :truncate_dim]
            lm = lm[:truncate_dim]  # (truncate_dim)
            v = v[:truncate_dim, :]  # (truncate_dim, virtual_dim)
            r = lm.unsqueeze(1) * v  # (truncate_dim, virtual_dim)
        else:
            r = lm.unsqueeze(1) * v  # (virtual_dim, virtual_dim)
    else:
        u, r = torch.linalg.qr(view_matrix)

    new_local_tensor = u.reshape(
        shape[0], shape[1], -1
    )  # (virtual_dim, physical_dim, virtual_dim or truncate_dim)
    local_tensor_right = mps_tensors[
        local_tensor_idx + 1
    ]  # (virtual_dim, physical_dim, virtual_dim)
    new_local_tensor_right = torch.einsum("ab,bcd->acd", r, local_tensor_right)
    if return_locals:
        return new_local_tensor, new_local_tensor_right
    else:
        return (
            mps_tensors[:local_tensor_idx]
            + [new_local_tensor, new_local_tensor_right]
            + mps_tensors[local_tensor_idx + 2 :]
        )


def orthogonalize_right2left_step(
    mps_tensors: List[torch.Tensor],
    local_tensor_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_locals: bool = False,
) -> List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor]:
    """
    One step of orthogonalization from right to left: makes the local tensor right-isometric and absorbs the remaining factor into its left neighbour.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        local_tensor_idx: int, the index of the local tensor to be orthogonalized.
        mode: Literal["svd", "qr"], the mode of orthogonalization.
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_locals: bool, whether to return the local tensors. If True, only the local and the left one will be returned.

    Returns:
        List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor], the full list of tensors after one right-to-left step, or (new left neighbour, new local tensor) if return_locals is True.
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 1 <= local_tensor_idx < length, "local_tensor_idx must be in [1, length - 1]"
    mode = mode.lower()
    assert mode in ["svd", "qr"], "mode must be either 'svd' or 'qr'"
    local_tensor = mps_tensors[local_tensor_idx]
    shape = local_tensor.shape  # (virtual_dim, physical_dim, virtual_dim)
    if truncate_dim is not None:
        virtual_dim = shape[0]
        assert truncate_dim > 0, "truncate_dim must be positive"
        truncate_dim = min(truncate_dim, virtual_dim)
        assert mode == "svd", "mode must be 'svd' when truncate_dim is provided"
        need_truncate = True
    else:
        need_truncate = False

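    # plain transpose (not conjugate transpose): decomposing M^T = U X (X = S @ Vh for SVD, R for QR)
    # gives M = X^T U^T, where U^T has orthonormal rows, i.e. the new local tensor is right-isometric
    view_matrix = local_tensor.reshape(
        shape[0], -1
    ).t()  # (virtual_dim, physical_dim * virtual_dim) -> (physical_dim * virtual_dim, virtual_dim)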
    if mode == "svd":
        u, lm, v = torch.linalg.svd(view_matrix, full_matrices=False)
        if need_truncate:
            u = u[:, :truncate_dim]
            lm = lm[:truncate_dim]  # (truncate_dim)
            v = v[:truncate_dim, :]  # (truncate_dim, virtual_dim)
            r = lm.unsqueeze(1) * v  # (truncate_dim, virtual_dim)
        else:
            r = lm.unsqueeze(1) * v  # (virtual_dim, virtual_dim)
    else:
        u, r = torch.linalg.qr(view_matrix)

    new_local_tensor = u.t().reshape(
        -1, shape[1], shape[2]
    )  # (virtual_dim or truncate_dim, physical_dim, virtual_dim)
    local_tensor_left = mps_tensors[
        local_tensor_idx - 1
    ]  # (virtual_dim, physical_dim, virtual_dim)
    new_local_tensor_left = torch.einsum("abc,dc->abd", local_tensor_left, r)
    if return_locals:
        return new_local_tensor_left, new_local_tensor
    else:
        return (
            mps_tensors[: local_tensor_idx - 1]
            + [new_local_tensor_left, new_local_tensor]
            + mps_tensors[local_tensor_idx + 1 :]
        )

#### Test

In [3]:
from tensor_network import setup_ref_code_import
from Library.MatrixProductState import MPS_basic
from copy import deepcopy

length = 8

para = {"length": length, "d": 3, "chi": 4, "dtype": torch.complex128}

for i in range(length - 1):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_left2right(i, "svd")
    orthogonalized_mps_tensors = orthogonalize_left2right_step(mps_tensors, i, "svd")
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

# with truncate_dim
for i in range(length - 1):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_left2right(i, "svd", 2)
    orthogonalized_mps_tensors = orthogonalize_left2right_step(
        mps_tensors, i, "svd", truncate_dim=2
    )
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])


for i in range(1, length):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_right2left(i, "svd")
    orthogonalized_mps_tensors = orthogonalize_right2left_step(mps_tensors, i, "svd")
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

# with truncate_dim
for i in range(1, length):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_right2left(i, "svd", 2)
    orthogonalized_mps_tensors = orthogonalize_right2left_step(
        mps_tensors, i, "svd", truncate_dim=2
    )
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

### Orthogonalization

In [4]:
# |export mps.functional
def orthogonalize_arange(
    mps_tensors: List[torch.Tensor],
    start_idx: int,
    end_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_changed: bool = False,
) -> List[torch.Tensor] | Tuple[List[torch.Tensor], List[int]]:
    """
    Perform orthogonalization on the range of tensors.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        start_idx: int, the start index of the range
        end_idx: int, the end index of the range
        mode: Literal["svd", "qr"], the mode of orthogonalization
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_changed: bool, whether to also return the sorted indices of the tensors changed by the orthogonalization.

    Returns:
        List[torch.Tensor], the list of tensors after orthogonalization; if return_changed is True, the tuple (tensors, changed_indices).
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 0 <= start_idx < length and 0 <= end_idx < length, (
        "start_idx and end_idx must be in [0, length - 1]"
    )
    mps_tensors = [m for m in mps_tensors]
    changed_indices = set()
    if start_idx < end_idx:
        for idx in range(start_idx, end_idx, 1):
            local, local_right = orthogonalize_left2right_step(
                mps_tensors, idx, mode, truncate_dim, return_locals=True
            )
            mps_tensors[idx] = local
            mps_tensors[idx + 1] = local_right
            changed_indices.add(idx)
            changed_indices.add(idx + 1)
    elif start_idx > end_idx:
        for idx in range(start_idx, end_idx, -1):
            local_left, local = orthogonalize_right2left_step(
                mps_tensors, idx, mode, truncate_dim, return_locals=True
            )
            mps_tensors[idx - 1] = local_left
            mps_tensors[idx] = local
            changed_indices.add(idx - 1)
            changed_indices.add(idx)
    else:
        # do nothing when start_idx == end_idx
        pass

    if return_changed:
        changed_indices = list(changed_indices)
        changed_indices.sort()
        return mps_tensors, changed_indices
    else:
        return mps_tensors

#### Test

In [5]:
from random import randint

num_trials = 20

length = 8

para = {"length": length, "d": 3, "chi": 4, "dtype": torch.complex128}

i = 0
while i < num_trials:
    start_idx = randint(0, length - 1)
    end_idx = randint(0, length - 1)
    if start_idx == end_idx:
        continue
    i += 1
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_n1_n2(start_idx, end_idx, "svd", -1, False)
    orthogonalized_mps_tensors = orthogonalize_arange(mps_tensors, start_idx, end_idx, "svd")
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])


# with truncate_dim = 2
i = 0
while i < num_trials:
    start_idx = randint(0, length - 1)
    end_idx = randint(0, length - 1)
    if start_idx == end_idx:
        continue
    i += 1
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    psi.orthogonalize_n1_n2(start_idx, end_idx, "svd", 2, False)
    orthogonalized_mps_tensors = orthogonalize_arange(mps_tensors, start_idx, end_idx, "svd", 2)
    orthogonalized_mps_tensors_ref = psi.tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

### MPS Module and Center Orthogonalization 

In [None]:
# |export
from tensor_network.mps.functional import (
    orthogonalize_arange,
    calc_global_tensor_by_tensordot,
    calculate_mps_norm_factors,
    calc_inner_product,
    tt_decomposition,
)
import sys
from einops import einsum


class MPS:
    def __init__(
        self,
        *,
        mps_tensors: List[torch.Tensor] | None = None,
        length: int | None = None,
        physical_dim: int | None = None,
        virtual_dim: int | None = None,
        mps_type: MPSType | None = None,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
    ) -> None:
        if mps_tensors is None:
            assert (
                length is not None
                and physical_dim is not None
                and virtual_dim is not None
                and mps_type is not None
                and dtype is not None
                and device is not None
                and requires_grad is not None
            ), (
                f"mps_tensors is None, so all arguments must be provided, but got {mps_tensors=}, {length=}, {physical_dim=}, {virtual_dim=}, {mps_type=}, {dtype=}, {device=}, {requires_grad=}"
            )
            mps_tensors = gen_random_mps_tensors(
                length, physical_dim, virtual_dim, mps_type, dtype, device
            )
            for i in range(len(mps_tensors)):
                mps_tensors[i].requires_grad = requires_grad
            self._length: int = length
            self._physical_dim: int = physical_dim
            self._virtual_dim: int = virtual_dim
            self._mps_type: MPSType = mps_type
            self._dtype: torch.dtype = dtype
            self._device: torch.device = device
        else:
            # TODO: check whether mps_tensors is valid (not urgent)
            self._length: int = len(mps_tensors)
            self._physical_dim: int = mps_tensors[0].shape[1]
            self._virtual_dim: int = mps_tensors[0].shape[2]
            self._mps_type: MPSType = (
                MPSType.Open if mps_tensors[0].shape[0] == 1 else MPSType.Periodic
            )
            self._dtype: torch.dtype = mps_tensors[0].dtype
            self._device: torch.device = mps_tensors[0].device
            requires_grad = mps_tensors[0].requires_grad if requires_grad is None else requires_grad

        self._requires_grad: bool = requires_grad
        self._mps: List[torch.Tensor] = mps_tensors
        self._center: int | None = None

    def center_orthogonalization_(
        self, center: int, mode: Literal["svd", "qr"], truncate_dim: int | None = None
    ):
        """
        Perform center orthogonalization on the MPS. This is an in-place operation.

        Args:
            center: int, the center of the MPS.
            mode: Literal["svd", "qr"], the mode of orthogonalization.
            truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        """
        assert -self.length <= center < self.length, "center out of range"
        if center < 0:
            center = self.length + center
        if self._center is None:
            new_local_tensors = orthogonalize_arange(self._mps, 0, center, mode, truncate_dim)
            new_local_tensors = orthogonalize_arange(
                new_local_tensors, self.length - 1, center, mode, truncate_dim
            )
            for i in range(self.length):
                self._mps[i] = new_local_tensors[i]
        elif self.center != center:
            new_local_tensors, changed_indices = orthogonalize_arange(
                self._mps, self.center, center, mode, truncate_dim, return_changed=True
            )
            for changed_idx in changed_indices:
                self._mps[changed_idx] = new_local_tensors[changed_idx]
        else:
            # when self.center == center
            pass
        self._center = center

    def force_set_local_tensor_(self, i: int, value: torch.Tensor):
        """
        Force set the local tensor at index i to the given value with checking the shape and dtype.

        Args:
            i: int, the index of the local tensor to be set.
            value: torch.Tensor, the value to be set.
        """
        value = value.to(dtype=self._dtype, device=self._device)
        value.requires_grad = self._requires_grad
        self._mps[i] = value

    def __getitem__(self, i: int) -> torch.Tensor:
        return self._mps[i]

    def __setitem__(self, i: int, value: torch.Tensor):
        local_tensor_shape = self[i].shape
        local_tensor_dtype = self[i].dtype
        assert value.shape == local_tensor_shape, (
            f"value shape must match local tensor shape {local_tensor_shape}, but got {value.shape}"
        )
        assert value.dtype == local_tensor_dtype, (
            f"value dtype must match local tensor dtype {local_tensor_dtype}, but got {value.dtype}"
        )
        self.force_set_local_tensor_(i, value)

    def global_tensor(self) -> torch.Tensor:
        """
        Calculate the global tensor of the MPS.
        """
        # use tensordot to contract the mps tensors, because it's faster than calc_global_tensor_by_contract
        if self.length > 15:
            print(
                "Warning: Calculating the global tensor of an MPS with length > 15; this may use up all the memory",
                file=sys.stderr,
            )
        return calc_global_tensor_by_tensordot(self._mps)

    def norm_factors(self) -> torch.Tensor:
        """
        Calculate the norm factors of the MPS.
        """
        return calculate_mps_norm_factors(self._mps).real

    def norm(self, *, _efficient_mode: bool = True) -> torch.Tensor:
        """
        Calculate the norm of the MPS.
        """
        if _efficient_mode and self.center is not None:
            return self._mps[self.center].norm()
        else:
            norm_factors = self.norm_factors()
            # use sqrt inside the product to avoid overflow
            return torch.prod(norm_factors.sqrt())

    def normalize_(self, *, _efficient_mode: bool = True):
        """
        Normalize the MPS in-place.
        """
        if _efficient_mode and self.center is not None:
            self._mps[self.center] /= self.norm()
        else:
            norm_factors = 1 / self.norm_factors().sqrt()
            for i in range(self.length):
                self._mps[i] *= norm_factors[i]

    def inner_product(self, other: "MPS", return_product_factors: bool = False) -> torch.Tensor:
        """
        Calculate the inner product of two MPS. These two MPS must have the same length.
        """
        assert isinstance(other, MPS), "other must be a MPS"
        assert self.length == other.length, "length of two MPS must be the same"
        product_factors = calc_inner_product(self._mps, other._mps)
        if return_product_factors:
            return product_factors
        else:
            return torch.prod(product_factors)

    def check_orthogonality(
        self, *, check_mode: Literal["print", "assert"] = "print", tolerance: float = 1e-6
    ):
        """
        Check the orthogonality of the MPS.
        """
        assert check_mode.lower() in ["print", "assert"], (
            "check_mode must be either 'print' or 'assert'"
        )
        print_mode = check_mode.lower() == "print"
        if self.center is None:
            print("center is None, so no orthogonality check can be performed")
        else:
            identity = torch.eye(
                2, dtype=self._dtype, device=self._device
            )  # cache for the identity matrix
            for i in range(self.length):
                if i == self.center:
                    if print_mode:
                        print(f"Local Tensor {i}: Center")
                else:
                    local_tensor = self._mps[i]
                    if i < self.center:
                        product = torch.einsum("abc,abd->cd", local_tensor.conj(), local_tensor)
                    else:
                        product = torch.einsum("xab,yab->xy", local_tensor, local_tensor.conj())

                    assert product.shape[0] == product.shape[1]

                    if identity.shape[0] != product.shape[0]:
                        identity = torch.eye(
                            product.shape[0], dtype=self._dtype, device=self._device
                        )

                    diff_norm = (product - identity).norm(p=1).item()

                    if print_mode:
                        print(f"Local Tensor {i}: {diff_norm}")
                    else:
                        assert diff_norm < tolerance, (
                            f"Local Tensor {i} is not orthogonal, {diff_norm=}"
                        )

    def to_(self, dtype: torch.dtype | None = None, device: torch.device | None = None) -> "MPS":
        if dtype is not None and self._dtype != dtype:
            for i in range(self.length):
                self._mps[i] = self._mps[i].to(dtype=dtype)
            self._dtype = dtype
        if device is not None and self._device != device:
            for i in range(self.length):
                self._mps[i] = self._mps[i].to(device=device)
            self._device = device
        return self

    def one_body_reduced_density_matrix(
        self, *, idx: int, inplace_mutation: bool = False
    ) -> torch.Tensor:
        assert 0 <= idx < self.length, "idx must be in [0, length - 1]"
        if self.center is None:  # TODO: optimize this branch
            # maybe we can just use einsum here, need some benchmarking
            if inplace_mutation:
                self.center_orthogonalization_(idx, "qr")
                return self.one_body_reduced_density_matrix(idx=idx)
            else:
                # do center orthogonalization out of place
                local_tensors = self.local_tensors
                length = len(local_tensors)
                center = idx
                mode = "qr"
                local_tensors = orthogonalize_arange(local_tensors, 0, center, mode)
                local_tensors = orthogonalize_arange(local_tensors, length - 1, center, mode)
                center_tensor = local_tensors[center]
        else:
            if self.center == idx:
                center_tensor = self.center_tensor
            else:  # TODO: optimize this branch
                if inplace_mutation:
                    self.center_orthogonalization_(idx, "qr")
                    return self.one_body_reduced_density_matrix(idx=idx)
                else:
                    # moving center out of place
                    local_tensors = self.local_tensors
                    new_center = idx
                    local_tensors = orthogonalize_arange(
                        local_tensors, self.center, new_center, mode="qr"
                    )
                    center_tensor = local_tensors[new_center]

        return einsum(
            center_tensor,
            center_tensor.conj(),
            "left mid right, left mid_conj right -> mid mid_conj",
        )

    @property
    def center_tensor(self) -> torch.Tensor | None:
        if self.center is None:
            return None
        else:
            return self._mps[self.center]

    @property
    def local_tensors(self) -> List[torch.Tensor]:
        return [i for i in self._mps]

    @property
    def length(self) -> int:
        return self._length

    @property
    def physical_dim(self) -> int:
        return self._physical_dim

    @property
    def virtual_dim(self) -> int:
        return self._virtual_dim

    @property
    def mps_type(self) -> MPSType:
        return self._mps_type

    @property
    def center(self) -> int | None:
        return self._center

    @staticmethod
    def from_state_tensor(
        state_tensor: torch.Tensor, max_rank: int | None = None, use_svd: bool = False
    ) -> "MPS":
        local_tensors, _ = tt_decomposition(state_tensor, max_rank=max_rank, use_svd=use_svd)
        mps = MPS(mps_tensors=local_tensors)
        mps._center = len(local_tensors) - 1
        return mps

#### Test: Central Orthogonalization

In [7]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128

para = {"length": length, "d": physical_dim, "chi": virtual_dim, "dtype": dtype}

for i in range(length):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    mps = MPS(mps_tensors=mps_tensors, requires_grad=False)
    global_tensor = mps.global_tensor()
    psi.center_orthogonalization(i, "svd")
    mps.center_orthogonalization_(i, "svd")
    new_global_tensor = mps.global_tensor()
    assert torch.allclose(global_tensor, new_global_tensor)
    mps.check_orthogonality(check_mode="assert")
    orthogonalized_mps_tensors_ref = psi.tensors
    orthogonalized_mps_tensors = mps.local_tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

# with truncate_dim
for i in range(length):
    psi = MPS_basic(para=para)
    mps_tensors = deepcopy(psi.tensors)
    mps = MPS(mps_tensors=mps_tensors, requires_grad=False)
    psi.center_orthogonalization(i, "svd", 2)
    mps.center_orthogonalization_(i, "svd", 2)
    mps.check_orthogonality(check_mode="assert")
    orthogonalized_mps_tensors_ref = psi.tensors
    orthogonalized_mps_tensors = mps.local_tensors
    assert len(orthogonalized_mps_tensors) == len(orthogonalized_mps_tensors_ref)
    for j in range(len(orthogonalized_mps_tensors)):
        assert torch.allclose(orthogonalized_mps_tensors[j], orthogonalized_mps_tensors_ref[j])

#### Test: Norm of MPS = Norm of Center of MPS

In [8]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128
device = torch.device("cpu")

for i in range(length):
    mps = MPS(
        length=length,
        physical_dim=physical_dim,
        virtual_dim=virtual_dim,
        dtype=dtype,
        device=device,
        mps_type=MPSType.Open,
        requires_grad=False,
    )
    mps.center_orthogonalization_(i, "svd")
    global_tensor = mps.global_tensor()
    norm_ref = global_tensor.norm()
    norm_mps_efficient = mps.norm(_efficient_mode=True)
    norm_mps_normal = mps.norm(_efficient_mode=False)
    assert torch.allclose(norm_mps_efficient, norm_ref)
    assert torch.allclose(norm_mps_normal, norm_ref)

for i in range(length):
    mps = MPS(
        length=length,
        physical_dim=physical_dim,
        virtual_dim=virtual_dim,
        dtype=dtype,
        device=device,
        mps_type=MPSType.Open,
        requires_grad=False,
    )
    mps.center_orthogonalization_(i, "svd")
    mps.normalize_(_efficient_mode=True)
    global_tensor = mps.global_tensor()
    norm = global_tensor.norm()
    assert torch.allclose(norm, torch.ones_like(norm))

#### Test: Center Tensor Singular Spectrum Property

In [9]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128
device = torch.device("cpu")
center = 3

mps = MPS(
    length=length,
    physical_dim=physical_dim,
    virtual_dim=virtual_dim,
    dtype=dtype,
    device=device,
    mps_type=MPSType.Open,
    requires_grad=False,
)

mps.center_orthogonalization_(center, "svd")

global_tensor = mps.global_tensor()
center_tensor = mps.center_tensor

center_tensor_mat_left = center_tensor.view(center_tensor.shape[0], -1)
left_singular_values = torch.linalg.svdvals(center_tensor_mat_left)
center_tensor_mat_right = center_tensor.view(-1, center_tensor.shape[-1])
right_singular_values = torch.linalg.svdvals(center_tensor_mat_right)

left_shape = global_tensor.shape[:center]
left_dim = torch.prod(torch.tensor(left_shape))
right_shape = global_tensor.shape[center + 1 :]
right_dim = torch.prod(torch.tensor(right_shape))

global_tensor_mat_left = global_tensor.view(left_dim, -1)
global_tensor_mat_right = global_tensor.view(-1, right_dim)

ref_left_singular_values = torch.linalg.svdvals(global_tensor_mat_left)
ref_right_singular_values = torch.linalg.svdvals(global_tensor_mat_right)

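# The global-tensor matricizations have min(rows, cols) singular values, but only the leading ones
# (as many as the center tensor has) are nonzero; the rest vanish up to numerical error.
# Pad the center-tensor spectra with zeros (same dtype) so the spectra can be compared directly.
left_singular_values = torch.cat(
    [
        left_singular_values,
        torch.zeros(
            ref_left_singular_values.shape[0] - left_singular_values.shape[0],
            dtype=left_singular_values.dtype,
        ),
    ]
)
right_singular_values = torch.cat(
    [
        right_singular_values,
        torch.zeros(
            ref_right_singular_values.shape[0] - right_singular_values.shape[0],
            dtype=right_singular_values.dtype,
        ),
    ]
)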

assert torch.allclose(left_singular_values, ref_left_singular_values)
assert torch.allclose(right_singular_values, ref_right_singular_values)

#### Test: Reduced Density Matrix

In [10]:
from tensor_network.quantum_state.functional import calc_reduced_density_matrix
from Library.QuantumState import TensorPureState

In [11]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128
device = torch.device("cpu")

for center in range(length):
    for i in range(length):
        for inplace_mutation in [True, False]:
            psi = MPS_basic(
                para={"length": length, "d": physical_dim, "chi": virtual_dim, "dtype": dtype}
            )
            mps_tensors = deepcopy(psi.tensors)
            mps = MPS(mps_tensors=mps_tensors, requires_grad=False)
            mps.center_orthogonalization_(center, "qr")
            psi.center_orthogonalization(center, "svd")

            global_tensor = mps.global_tensor()
            global_tensor_ref = psi.full_tensor()
            assert torch.allclose(global_tensor_ref, global_tensor)

            tensor_pure_state = TensorPureState(tensor=global_tensor)

            reduced_density_matrix_tps = tensor_pure_state.reduced_density_matrix(i)
            reduced_density_matrix_psi = psi.one_body_RDM(i)
            reduced_density_matrix_mps = mps.one_body_reduced_density_matrix(
                idx=i, inplace_mutation=inplace_mutation
            )
            reduced_density_matrix_mine = calc_reduced_density_matrix(global_tensor, i)

            assert torch.allclose(reduced_density_matrix_mps, reduced_density_matrix_mine)
            assert torch.allclose(reduced_density_matrix_tps, reduced_density_matrix_mine)

            normalized_reduced_density_matrix = reduced_density_matrix_mps / torch.trace(
                reduced_density_matrix_mps
            )
            # transpose here because the reference code has a minor bug
            assert torch.allclose(reduced_density_matrix_psi.T, normalized_reduced_density_matrix)

In [12]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128
device = torch.device("cpu")
center = 3

para = {"length": length, "d": physical_dim, "chi": virtual_dim, "dtype": dtype}

psi = MPS_basic(para=para)
print(f"psi.center: {psi.center}")
reduced_density_matrix_psi_before = psi.one_body_RDM(center)
psi.center_orthogonalization(center, "svd")
print(f"psi.center: {psi.center}")
reduced_density_matrix_psi_after = psi.one_body_RDM(center)


# FIXME: this will fail. fix the bug in the reference code
# assert torch.allclose(reduced_density_matrix_psi_before, reduced_density_matrix_psi_after), f"reduced_density_matrix_psi_before: {reduced_density_matrix_psi_before}\n\nreduced_density_matrix_psi_after: {reduced_density_matrix_psi_after}"

psi.center: -1
psi.center: 3
