# 4.2: 矩阵乘积态的中心正交形式（Central Orthogonal Form）

## Code

In [1]:
# |default_exp mps.modules
# |export
import torch
from typing import List, Tuple, Literal
from tensor_network.mps.functional import gen_random_mps_tensors, MPSType

### One-Step Orthogonalization

In [2]:
# |export mps.functional
from typing import Literal, Tuple


def orthogonalize_left2right_step(
    mps_tensors: List[torch.Tensor],
    local_tensor_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_locals: bool = False,
) -> List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor]:
    """
    One step of orthogonalization from left to right, which will make the local tensor isometric and the right one to it transformed.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        local_tensor_idx: int, the index of the local tensor to be orthogonalized.
        mode: Literal["svd", "qr"], the mode of orthogonalization.
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_locals: bool, whether to return the local tensors. If True, only the local and the right one will be returned.

    Returns:
        List[torch.Tensor], the list of tensors after one step of orthogonalization from left to right.
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 0 <= local_tensor_idx < length - 1, "local_tensor_idx must be in [0, length - 2]"
    mode = mode.lower()
    assert mode in ["svd", "qr"], "mode must be either 'svd' or 'qr'"
    local_tensor = mps_tensors[local_tensor_idx]
    shape = local_tensor.shape  # (virtual_dim, physical_dim, virtual_dim)
    if truncate_dim is not None:
        virtual_dim = shape[2]
        assert virtual_dim > truncate_dim > 0, (
            "truncate_dim must be positive and less than virtual_dim"
        )
        assert mode == "svd", "mode must be 'svd' when truncate_dim is provided"
        need_truncate = True
    else:
        need_truncate = False

    view_matrix = local_tensor.view(-1, shape[2])

    if mode == "svd":
        u, lm, v = torch.linalg.svd(view_matrix, full_matrices=False)
        if need_truncate:
            u = u[:, :truncate_dim]
            lm = lm[:truncate_dim]  # (truncate_dim)
            v = v[:truncate_dim, :]  # (truncate_dim, virtual_dim)
            r = lm.unsqueeze(1) * v  # (truncate_dim, virtual_dim)
        else:
            r = lm.unsqueeze(1) * v  # (virtual_dim, virtual_dim)
    else:
        u, r = torch.linalg.qr(view_matrix)

    new_local_tensor = u.reshape(
        shape[0], shape[1], -1
    )  # (virtual_dim, physical_dim, virtual_dim or truncate_dim)
    local_tensor_right = mps_tensors[
        local_tensor_idx + 1
    ]  # (virtual_dim, physical_dim, virtual_dim)
    new_local_tensor_right = torch.einsum("ab,bcd->acd", r, local_tensor_right)
    if return_locals:
        return new_local_tensor, new_local_tensor_right
    else:
        return (
            mps_tensors[:local_tensor_idx]
            + [new_local_tensor, new_local_tensor_right]
            + mps_tensors[local_tensor_idx + 2 :]
        )


def orthogonalize_right2left_step(
    mps_tensors: List[torch.Tensor],
    local_tensor_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_locals: bool = False,
) -> List[torch.Tensor] | Tuple[torch.Tensor, torch.Tensor]:
    """
    One step of orthogonalization from right to left, which will make the local tensor isometric and the left one to it transformed.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        local_tensor_idx: int, the index of the local tensor to be orthogonalized.
        mode: Literal["svd", "qr"], the mode of orthogonalization.
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_locals: bool, whether to return the local tensors. If True, only the local and the left one will be returned.

    Returns:
        List[torch.Tensor], the list of tensors after one step of orthogonalization from right to left.
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 1 <= local_tensor_idx < length, "local_tensor_idx must be in [1, length - 1]"
    mode = mode.lower()
    assert mode in ["svd", "qr"], "mode must be either 'svd' or 'qr'"
    local_tensor = mps_tensors[local_tensor_idx]
    shape = local_tensor.shape  # (virtual_dim, physical_dim, virtual_dim)
    if truncate_dim is not None:
        virtual_dim = shape[0]
        assert virtual_dim > truncate_dim > 0, (
            "truncate_dim must be positive and less than virtual_dim"
        )
        assert mode == "svd", "mode must be 'svd' when truncate_dim is provided"
        need_truncate = True
    else:
        need_truncate = False

    view_matrix = local_tensor.view(
        shape[0], -1
    ).t()  # (virtual_dim, virtual_dim * physical_dim) -> (virtual_dim * physical_dim, virtual_dim)
    if mode == "svd":
        u, lm, v = torch.linalg.svd(view_matrix, full_matrices=False)
        if need_truncate:
            u = u[:, :truncate_dim]
            lm = lm[:truncate_dim]  # (truncate_dim)
            v = v[:truncate_dim, :]  # (truncate_dim, virtual_dim)
            r = lm.unsqueeze(1) * v  # (truncate_dim, virtual_dim)
        else:
            r = lm.unsqueeze(1) * v  # (virtual_dim, virtual_dim)
    else:
        u, r = torch.linalg.qr(view_matrix)

    new_local_tensor = u.t().reshape(
        -1, shape[1], shape[2]
    )  # (virtual_dim or truncate_dim, physical_dim, virtual_dim)
    local_tensor_left = mps_tensors[
        local_tensor_idx - 1
    ]  # (virtual_dim, physical_dim, virtual_dim)
    new_local_tensor_left = torch.einsum("abc,dc->abd", local_tensor_left, r)
    if return_locals:
        return new_local_tensor_left, new_local_tensor
    else:
        return (
            mps_tensors[: local_tensor_idx - 1]
            + [new_local_tensor_left, new_local_tensor]
            + mps_tensors[local_tensor_idx + 1 :]
        )

#### Test

In [3]:
from tensor_network import setup_ref_code_import
from Library.MatrixProductState import MPS_basic
from copy import deepcopy

length = 8

para = {"length": length, "d": 3, "chi": 4, "dtype": torch.complex128}

# Compare each single orthogonalization step with the reference MPS_basic implementation,
# for both sweep directions, first without truncation and then with truncation to 2.
# Each case: (our step function, reference method name, sites to test, truncate_dim).
step_cases = [
    (orthogonalize_left2right_step, "orthogonalize_left2right", range(length - 1), None),
    (orthogonalize_left2right_step, "orthogonalize_left2right", range(length - 1), 2),
    (orthogonalize_right2left_step, "orthogonalize_right2left", range(1, length), None),
    (orthogonalize_right2left_step, "orthogonalize_right2left", range(1, length), 2),
]

for step_fn, ref_method_name, sites, trunc in step_cases:
    for i in sites:
        psi = MPS_basic(para=para)
        mps_tensors = deepcopy(psi.tensors)
        ref_method = getattr(psi, ref_method_name)
        if trunc is None:
            ref_method(i, "svd")
        else:
            ref_method(i, "svd", trunc)
        ours = step_fn(mps_tensors, i, "svd", truncate_dim=trunc)
        ref = psi.tensors
        assert len(ours) == len(ref)
        for got, want in zip(ours, ref):
            assert torch.allclose(got, want)

From setup_ref_code_import:
  Added reference_code_path='/Users/zhiqiu/offline_code/personal/tensor_network/reference_code' to sys.path.
  You can import the reference code now.


### Orthogonalization

In [4]:
# |export mps.functional
def orthogonalize_arange(
    mps_tensors: List[torch.Tensor],
    start_idx: int,
    end_idx: int,
    mode: Literal["svd", "qr"],
    truncate_dim: int | None = None,
    return_changed: bool = False,
) -> List[torch.Tensor] | Tuple[List[torch.Tensor], List[int]]:
    """
    Perform orthogonalization on the range of tensors.

    Args:
        mps_tensors: List[torch.Tensor], MPS tensors
        start_idx: int, the start index of the range
        end_idx: int, the end index of the range
        mode: Literal["svd", "qr"], the mode of orthogonalization
        truncate_dim: int | None, the dimension to be truncated. If None, no truncation will be performed.
        return_changed: bool, whether to return the changed tensors. If True, changed tensors' indices will be returned as well.

    Returns:
        List[torch.Tensor], the list of tensors after orthogonalization
    """
    length = len(mps_tensors)
    assert length > 1, "mps_tensors must have at least 2 tensors"
    assert 0 <= start_idx < length and 0 <= end_idx < length, (
        "start_idx and end_idx must be in [0, length - 1]"
    )
    mps_tensors = [m for m in mps_tensors]
    changed_indices = set()
    if start_idx < end_idx:
        for idx in range(start_idx, end_idx, 1):
            local, local_right = orthogonalize_left2right_step(
                mps_tensors, idx, mode, truncate_dim, return_locals=True
            )
            mps_tensors[idx] = local
            mps_tensors[idx + 1] = local_right
            changed_indices.add(idx)
            changed_indices.add(idx + 1)
    elif start_idx > end_idx:
        for idx in range(start_idx, end_idx, -1):
            local_left, local = orthogonalize_right2left_step(
                mps_tensors, idx, mode, truncate_dim, return_locals=True
            )
            mps_tensors[idx - 1] = local_left
            mps_tensors[idx] = local
            changed_indices.add(idx - 1)
            changed_indices.add(idx)
    else:
        # do nothing when start_idx == end_idx
        pass

    if return_changed:
        changed_indices = list(changed_indices)
        changed_indices.sort()
        return mps_tensors, changed_indices
    else:
        return mps_tensors

#### Test

In [None]:
from random import randint

num_trials = 20

length = 8

para = {"length": length, "d": 3, "chi": 4, "dtype": torch.complex128}

# Random sweeps (start_idx != end_idx) are checked against MPS_basic.orthogonalize_n1_n2.
# The reference encodes "no truncation" as dc=-1; ours uses truncate_dim=None.
for ref_dc, trunc in ((-1, None), (2, 2)):
    trials_done = 0
    while trials_done < num_trials:
        start_idx = randint(0, length - 1)
        end_idx = randint(0, length - 1)
        if start_idx == end_idx:
            continue
        trials_done += 1
        psi = MPS_basic(para=para)
        mps_tensors = deepcopy(psi.tensors)
        psi.orthogonalize_n1_n2(start_idx, end_idx, "svd", ref_dc, False)
        ours = orthogonalize_arange(mps_tensors, start_idx, end_idx, "svd", trunc)
        ref = psi.tensors
        assert len(ours) == len(ref)
        for got, want in zip(ours, ref):
            assert torch.allclose(got, want)

### MPS Module and Center Orthogonalization 

In [6]:
# |export
from tensor_network.mps.functional import (
    orthogonalize_arange,
    calc_global_tensor_by_tensordot,
    calculate_mps_norm_factors,
    calc_inner_product,
)
import sys


class MPS:
    """
    Matrix product state (MPS) container with an optionally tracked orthogonality center.

    Local tensors are laid out as (left_virtual, physical, right_virtual). The constructor reads the
    physical dimension from axis 1 and the virtual dimension from axis 2 of the first tensor. It
    treats a first tensor whose left axis has size 1 as an open-boundary MPS.
    """

    def __init__(
        self,
        *,
        mps_tensors: List[torch.Tensor] | None = None,
        length: int | None = None,
        physical_dim: int | None = None,
        virtual_dim: int | None = None,
        mps_type: MPSType | None = None,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        requires_grad: bool | None = None,
    ) -> None:
        """
        Build an MPS from explicit local tensors, or generate a random one.

        Args:
            mps_tensors: explicit local tensors. When given, length, dimensions, boundary type,
                dtype and device are inferred from them. ``requires_grad`` defaults to the first
                tensor's flag. The list is copied, so the in-place methods of this class never
                rewrite the caller's list.
            length, physical_dim, virtual_dim, mps_type, dtype, device: all required when
                ``mps_tensors`` is None; forwarded to ``gen_random_mps_tensors``.
            requires_grad: required when ``mps_tensors`` is None; then set on every generated tensor.
        """
        if mps_tensors is None:
            assert (
                length is not None
                and physical_dim is not None
                and virtual_dim is not None
                and mps_type is not None
                and dtype is not None
                and device is not None
                and requires_grad is not None
            ), "mps_tensors is None, so all arguments must be provided"
            mps_tensors = gen_random_mps_tensors(
                length, physical_dim, virtual_dim, mps_type, dtype, device
            )
            for tensor in mps_tensors:
                tensor.requires_grad = requires_grad
            self._length: int = length
            self._physical_dim: int = physical_dim
            self._virtual_dim: int = virtual_dim
            self._mps_type: MPSType = mps_type
            self._dtype: torch.dtype = dtype
            self._device: torch.device = device
        else:
            # TODO: validate mps_tensors (rank 3, matching bond dimensions); not urgent.
            self._length: int = len(mps_tensors)
            self._physical_dim: int = mps_tensors[0].shape[1]
            self._virtual_dim: int = mps_tensors[0].shape[2]
            self._mps_type: MPSType = (
                MPSType.Open if mps_tensors[0].shape[0] == 1 else MPSType.Periodic
            )
            self._dtype: torch.dtype = mps_tensors[0].dtype
            self._device: torch.device = mps_tensors[0].device
            requires_grad = mps_tensors[0].requires_grad if requires_grad is None else requires_grad

        self._requires_grad: bool = requires_grad
        # Shallow copy: methods below replace list entries and must not touch the caller's list.
        self._mps: List[torch.Tensor] = list(mps_tensors)
        # Index of the orthogonality center, or None when the tensors are not known to be in
        # center-orthogonal form.
        self._center: int | None = None

    def center_orthogonalization_(
        self, center: int, mode: Literal["svd", "qr"], truncate_dim: int | None = None
    ):
        """
        Bring the MPS into center-orthogonal form around ``center``. This is an in-place operation.

        If no center is known yet, the method sweeps from both ends towards ``center``. Otherwise it
        only sweeps between the old and the new center.

        Args:
            center: int, the new center; negative values count from the end.
            mode: Literal["svd", "qr"], the factorization used by each step.
            truncate_dim: int | None, the dimension to truncate to. None means no truncation.
        """
        assert -self.length <= center < self.length, "center out of range"
        if center < 0:
            center = self.length + center
        if self._center is None:
            # Gauge unknown: make every tensor left of `center` left-isometric and every tensor
            # right of it right-isometric.
            new_local_tensors = orthogonalize_arange(self._mps, 0, center, mode, truncate_dim)
            new_local_tensors = orthogonalize_arange(
                new_local_tensors, self.length - 1, center, mode, truncate_dim
            )
            for i in range(self.length):
                self._mps[i] = new_local_tensors[i]
        elif self._center != center:
            # Already center-orthogonal: only the tensors between the two centers change.
            new_local_tensors, changed_indices = orthogonalize_arange(
                self._mps, self._center, center, mode, truncate_dim, return_changed=True
            )
            for changed_idx in changed_indices:
                self._mps[changed_idx] = new_local_tensors[changed_idx]
        self._center = center

    def force_set_local_tensor_(self, i: int, value: torch.Tensor):
        """
        Set the local tensor at index ``i`` without checking its shape or dtype.

        ``value`` is cast to this MPS's dtype and device, and its ``requires_grad`` is made to match
        the MPS. A tensor replaced at any index other than the orthogonality center may break the
        center-orthogonal form, so the cached center is cleared in that case.

        Args:
            i: int, the index of the local tensor to be set (negative values count from the end).
            value: torch.Tensor, the new local tensor.
        """
        value = value.to(dtype=self._dtype, device=self._device)
        if value.is_leaf:
            value.requires_grad = self._requires_grad
        elif not self._requires_grad:
            # requires_grad cannot be assigned on non-leaf tensors; detach to drop the graph instead.
            value = value.detach()
        self._mps[i] = value
        if self._center is not None and i % self._length != self._center:
            self._center = None

    def __getitem__(self, i: int) -> torch.Tensor:
        return self._mps[i]

    def __setitem__(self, i: int, value: torch.Tensor):
        """Set local tensor ``i``; ``value`` must match the current tensor's shape and dtype."""
        local_tensor_shape = self[i].shape
        local_tensor_dtype = self[i].dtype
        assert value.shape == local_tensor_shape, (
            f"value shape must match local tensor shape {local_tensor_shape}, but got {value.shape}"
        )
        assert value.dtype == local_tensor_dtype, (
            f"value dtype must match local tensor dtype {local_tensor_dtype}, but got {value.dtype}"
        )
        self.force_set_local_tensor_(i, value)

    def global_tensor(self) -> torch.Tensor:
        """
        Calculate the global tensor of the MPS by contracting all local tensors.
        """
        # tensordot is used because it is faster than calc_global_tensor_by_contract.
        if self.length > 15:
            print(
                "Warning: Calculating global tensor of MPS with length > 15, this may up all the memory",
                file=sys.stderr,
            )
        return calc_global_tensor_by_tensordot(self._mps)

    def norm_factors(self) -> torch.Tensor:
        """
        Calculate the norm factors of the MPS (see ``calculate_mps_norm_factors``).
        """
        return calculate_mps_norm_factors(self._mps)

    def norm(self) -> torch.Tensor:
        """
        Calculate the norm of the MPS as the product of the square roots of its norm factors.
        """
        norm_factors = self.norm_factors()
        # Take sqrt inside the product to avoid overflow.
        return torch.prod(norm_factors.sqrt())

    def normalize_(self):
        """
        Normalize the MPS in place by scaling tensor i by 1 / sqrt(norm_factors[i]).

        The orthogonality center is cleared afterwards. Scaling an isometric tensor by a factor other
        than 1 breaks its isometry, and nothing guarantees the per-site factors equal 1.
        """
        scale = 1 / self.norm_factors().sqrt()
        for i in range(self.length):
            # Out-of-place: `*=` raises on leaf tensors that require grad, and would also modify
            # tensors that were used to compute `scale`.
            self._mps[i] = self._mps[i] * scale[i]
        self._center = None

    def inner_product(self, other: "MPS", return_product_factors: bool = False) -> torch.Tensor:
        """
        Inner product of this MPS with ``other``, computed by ``calc_inner_product``.

        Args:
            other: MPS of the same length.
            return_product_factors: if True, return the individual factors from
                ``calc_inner_product`` instead of their product.
        """
        assert isinstance(other, MPS), "other must be a MPS"
        assert self.length == other.length, "length of two MPS must be the same"
        product_factors = calc_inner_product(self._mps, other._mps)
        if return_product_factors:
            return product_factors
        return torch.prod(product_factors)

    def to_(self, dtype: torch.dtype | None = None, device: torch.device | None = None) -> "MPS":
        """
        Convert all local tensors to ``dtype`` and/or move them to ``device`` in place.

        Returns:
            self, to allow chaining.
        """
        if dtype is not None and self._dtype != dtype:
            for i in range(self.length):
                self._mps[i] = self._mps[i].to(dtype=dtype)
            self._dtype = dtype
        if device is not None and self._device != device:
            for i in range(self.length):
                self._mps[i] = self._mps[i].to(device=device)
            self._device = device
        return self

    @property
    def local_tensors(self) -> List[torch.Tensor]:
        """A shallow copy of the list of local tensors."""
        return list(self._mps)

    @property
    def length(self) -> int:
        return self._length

    @property
    def physical_dim(self) -> int:
        return self._physical_dim

    @property
    def virtual_dim(self) -> int:
        # NOTE(review): set at construction only; not updated when truncation shrinks bonds.
        return self._virtual_dim

    @property
    def mps_type(self) -> MPSType:
        return self._mps_type

    @property
    def center(self) -> int | None:
        return self._center

#### Test

In [7]:
length = 8
physical_dim = 3
virtual_dim = 4
dtype = torch.complex128

para = {"length": length, "d": physical_dim, "chi": virtual_dim, "dtype": dtype}

# Center-orthogonalize to every site, first without and then with truncation to 2,
# and compare the resulting tensors with the reference MPS_basic implementation.
for trunc in (None, 2):
    for i in range(length):
        psi = MPS_basic(para=para)
        mps = MPS(mps_tensors=deepcopy(psi.tensors), requires_grad=False)
        if trunc is None:
            psi.center_orthogonalization(i, "svd")
        else:
            psi.center_orthogonalization(i, "svd", trunc)
        mps.center_orthogonalization_(i, "svd", trunc)
        ref = psi.tensors
        ours = mps.local_tensors
        assert len(ours) == len(ref)
        for got, want in zip(ours, ref):
            assert torch.allclose(got, want)