# 4.3: Tensor-Train Decomposition and Truncation of the Virtual Dimension

Recap: the virtual (bond) indices control the parameter complexity of a matrix product state (MPS).
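To make the recap concrete, here is a small parameter count, a sketch assuming an open-boundary MPS whose bond dimensions are capped at `bond_dim` (χ) and never exceed what the left or right block can support. The helpers `mps_param_count` and `full_param_count` are hypothetical and not part of `tensor_network`.

```python
def mps_param_count(n_sites: int, phys_dim: int, bond_dim: int) -> int:
    """Parameters of an open-boundary MPS with bonds capped at bond_dim (hypothetical helper)."""
    # Bond k (k = 0..n_sites) cannot exceed d^k (left block) or d^(N-k) (right block);
    # bonds 0 and N are the trivial size-1 boundary bonds.
    chis = [
        min(phys_dim**k, phys_dim ** (n_sites - k), bond_dim) for k in range(n_sites + 1)
    ]
    # Site k stores a tensor of shape (chi_k, d, chi_{k+1}).
    return sum(chis[k] * phys_dim * chis[k + 1] for k in range(n_sites))


def full_param_count(n_sites: int, phys_dim: int) -> int:
    """Number of entries of the full state tensor phi_{s_0 ... s_{N-1}}."""
    return phys_dim**n_sites
```

With d = 2 and χ = 16, four sites give 40 MPS parameters against 16 full-tensor entries, so an exact MPS is not automatically smaller. For 100 sites the same cap gives 47784 parameters against 2^100: the saving comes entirely from keeping the virtual dimension small.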

Question: how can the virtual dimension be reduced optimally, so as to control the parameter complexity of the MPS?

4.2, central orthogonal form: a special form of the MPS that enables optimal truncation of the virtual dimension.
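Why truncating by singular values counts as "optimal": for a single matrix, keeping the largest singular values gives the best low-rank approximation in Frobenius norm (the Eckart-Young theorem), and the error equals the norm of the discarded singular values. In the central orthogonal form the tensors on either side of a bond are isometries, so the same statement carries over to that bond. A minimal torch check, with `truncated_svd_error` as a hypothetical helper name:

```python
from typing import Tuple

import torch


def truncated_svd_error(mat: torch.Tensor, rank: int) -> Tuple[float, float]:
    """Return (actual, predicted) Frobenius error of a rank-`rank` SVD truncation."""
    u, s, vh = torch.linalg.svd(mat, full_matrices=False)
    # torch returns singular values in descending order; keep the first `rank`
    approx = (u[:, :rank] * s[:rank]) @ vh[:rank, :]
    actual = torch.linalg.norm(mat - approx).item()
    # Eckart-Young: the optimal error is the norm of the discarded singular values
    predicted = torch.sqrt((s[rank:] ** 2).sum()).item()
    return actual, predicted
```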

4.3: how exactly is the truncation carried out? Central orthogonal form + TT decomposition.

## Algorithm overview

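* Take the 4th-order $(d\times d\times d\times d)$ tensor $\varphi_{s_0s_1s_2s_3}$ as an example: it is decomposed by 3 rounds of reshape + SVD;
* At each SVD, the dimension can be truncated according to the singular values, which controls the dimension of the virtual index (see Section 1.7);
* The resulting MPS is in central orthogonal form, with the orthogonality center on the rightmost tensor;
* If no truncation is needed, QR decomposition can be used instead of SVD.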
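The three reshape + SVD rounds listed above can be sketched as a short loop. This is a minimal illustration, not the notebook's `tt_decomposition`; the name `tt_svd_sketch` is hypothetical, and cores use the same (left, physical, right) index order as the function below.

```python
from typing import List, Optional

import torch


def tt_svd_sketch(phi: torch.Tensor, max_rank: Optional[int] = None) -> List[torch.Tensor]:
    """Left-to-right reshape + SVD sweep; returns cores of shape (left, d, right)."""
    dims = phi.shape
    cores: List[torch.Tensor] = []
    left = 1  # bond dimension carried over from the previous step
    rest = phi  # part of the tensor not yet decomposed
    for d in dims[:-1]:
        # rows: (left bond, current physical index); columns: all remaining indices
        mat = rest.reshape(left * d, -1)
        u, s, vh = torch.linalg.svd(mat, full_matrices=False)
        # optional truncation: keep only the largest singular values
        r = s.shape[0] if max_rank is None else min(max_rank, s.shape[0])
        # columns of u are orthonormal, so this core is left-orthogonal
        cores.append(u[:, :r].reshape(left, d, r))
        # absorb the kept singular values to the right (the center moves right)
        rest = s[:r, None] * vh[:r, :]
        left = r
    # the last core carries the norm of the state: it is the orthogonality center
    cores.append(rest.reshape(left, dims[-1], 1))
    return cores
```

For a random $2\times2\times2\times2$ tensor the untruncated bond dimensions come out as 2, 4, 2, i.e. $\min(d^k, d^{N-k})$ at cut $k$; with `max_rank=2` the middle bond is cut down to 2.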

![tt_decomposition_example](./images/tt_decomposition_example.png)
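When nothing is truncated, the SVD in each round can be replaced by a reduced QR decomposition: $Q$ still has orthonormal columns, so every core except the last is left-orthogonal. A minimal sketch, with `tt_qr` and `is_left_orthogonal` as hypothetical helper names:

```python
from typing import List

import torch


def tt_qr(phi: torch.Tensor) -> List[torch.Tensor]:
    """TT decomposition without truncation, using reduced QR instead of SVD."""
    dims = phi.shape
    cores: List[torch.Tensor] = []
    left, rest = 1, phi
    for d in dims[:-1]:
        # reduced QR: (m, n) -> q (m, k), r (k, n) with k = min(m, n)
        q, r = torch.linalg.qr(rest.reshape(left * d, -1))
        cores.append(q.reshape(left, d, q.shape[1]))
        rest, left = r, q.shape[1]
    # the remaining R factor becomes the last core (the orthogonality center)
    cores.append(rest.reshape(left, dims[-1], 1))
    return cores


def is_left_orthogonal(core: torch.Tensor, atol: float = 1e-10) -> bool:
    """Check sum_{a,s} conj(A[a,s,b]) * A[a,s,c] == delta_{bc}."""
    mat = core.reshape(-1, core.shape[-1])
    gram = mat.conj().T @ mat
    eye = torch.eye(gram.shape[0], dtype=gram.dtype)
    return torch.allclose(gram, eye, atol=atol)
```

QR is typically cheaper than SVD, but it does not expose singular values, so it gives no way to decide which directions to discard. That is why the notebook's `tt_decomposition` switches to SVD whenever `max_rank` is set.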

```python
import torch
```

```python
# |export mps.functional
from typing import List, Tuple

import torch

from tensor_network.utils import check_state_tensor


def tt_decomposition(
    state_tensor: torch.Tensor, *, max_rank: int | None = None, use_svd: bool = False
) -> Tuple[List[torch.Tensor], List[int]]:
    """
    Perform a tensor-train (TT) decomposition of a state tensor, sweeping left to right.

    Args:
        state_tensor: torch.Tensor, the state tensor to be decomposed (one index per site).
        max_rank: int | None, the maximum bond dimension kept at each SVD. If None, no
            truncation is performed. Setting it forces use_svd=True.
        use_svd: bool, whether to use SVD. If False (and max_rank is None), QR is used.

    Returns:
        Tuple[List[torch.Tensor], List[int]]: the local tensors and the kept bond dimensions.
        The local tensors are the MPS tensors, each of shape (left, physical, right); all but
        the last are left-orthogonal, so the orthogonality center is the rightmost tensor.
        The kept bond dimensions are recorded at every cut when SVD is used (empty for QR).
    """
    check_state_tensor(state_tensor)
    clip_rank = max_rank is not None
    if clip_rank:
        assert max_rank > 0, "max_rank must be greater than 0"
        use_svd = True  # truncation needs singular values, which QR does not provide

    shape = state_tensor.shape
    n_qubits = state_tensor.ndim
    left_dim = 1  # bond dimension to the left of the current site
    local_tensors = []
    remained_tensor = state_tensor
    clipped_ranks = []

    for i in range(n_qubits - 1):
        mid_dim = shape[i]  # physical dimension of site i
        # rows: (left bond, site i); columns: all remaining sites
        mat = remained_tensor.reshape(left_dim * mid_dim, -1)
        if use_svd:
            # reduced SVD: (m, n) -> q (m, k), s (k,), v (k, n) with k = min(m, n)
            q, s, v = torch.linalg.svd(mat, full_matrices=False)
            rank = min(max_rank, s.shape[0]) if clip_rank else s.shape[0]
            q = q[:, :rank]
            # absorb the kept singular values into the right factor
            remained_tensor = s[:rank].unsqueeze(1) * v[:rank, :]
            new_left_dim = rank
            clipped_ranks.append(rank)
        else:
            # reduced QR: (m, n) -> q (m, k) and r (k, n) with k = min(m, n)
            q, r = torch.linalg.qr(mat)
            remained_tensor = r
            new_left_dim = r.shape[0]

        # new_left_dim is the right bond dimension of the local tensor
        # (reshape rather than view: slices of q need not be contiguous)
        local_tensors.append(q.reshape(left_dim, mid_dim, new_left_dim))
        left_dim = new_left_dim

    # the last tensor absorbs the remaining factor and is the orthogonality center
    local_tensors.append(remained_tensor.reshape(left_dim, shape[-1], 1))
    return local_tensors, clipped_ranks
```
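To inspect the output of `tt_decomposition` without the `MPS` class (for example, to measure the error introduced by a finite `max_rank`), the local tensors can be contracted back into a full tensor. `contract_tt` below is a hypothetical helper, not part of `tensor_network`; it assumes the (left, physical, right) core layout used above.

```python
from typing import List

import torch


def contract_tt(cores: List[torch.Tensor]) -> torch.Tensor:
    """Contract (left, phys, right) cores back into a tensor with one index per site."""
    result = cores[0]  # shape (1, d_0, r_0)
    for core in cores[1:]:
        # join the trailing bond of `result` with the leading bond of `core`
        result = torch.tensordot(result, core, dims=([result.ndim - 1], [0]))
    # drop the size-1 boundary bonds at both ends
    return result.squeeze(0).squeeze(-1)
```

A possible use: `rel_err = (contract_tt(local_tensors) - a).norm() / a.norm()` after calling `tt_decomposition(a, max_rank=...)`.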

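```python
from tensor_network.mps.modules import MPS

# Round-trip test without truncation: 2 to 4 qubits, SVD and QR, 10 random states each
for n_qubits in range(2, 5):
    for use_svd in [True, False]:
        for _ in range(10):
            a = torch.randn(*([2] * n_qubits), dtype=torch.complex64)
            local_tensors, _ = tt_decomposition(a, use_svd=use_svd)
            mps = MPS(mps_tensors=local_tensors)
            # tt_decomposition leaves the orthogonality center on the last tensor
            mps._center = len(local_tensors) - 1
            global_tensor = mps.global_tensor()
            # verify the orthogonal form with respect to the declared center
            mps.check_orthogonality(check_mode="assert", tolerance=1e-5)
            # without truncation the decomposition must reproduce the input
            assert torch.allclose(global_tensor, a), f"{(global_tensor - a).norm()}"
```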
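The test above covers only the untruncated case. When `max_rank` is set, the error of a single left-to-right sweep can be tracked exactly: every core to the left of a cut is left-orthogonal, so the errors introduced at different cuts are mutually orthogonal, and the squared Frobenius error of the final MPS equals the sum of the squared discarded singular values over all cuts. The sketch below re-implements the SVD branch under the hypothetical name `tt_svd_with_error` (it does not call `tt_decomposition`) and checks this numerically.

```python
from typing import List, Tuple

import torch


def tt_svd_with_error(phi: torch.Tensor, max_rank: int) -> Tuple[List[torch.Tensor], List[float]]:
    """Truncated left-to-right TT-SVD that also returns the discarded weight at each cut."""
    dims = phi.shape
    cores: List[torch.Tensor] = []
    discarded: List[float] = []
    left, rest = 1, phi
    for d in dims[:-1]:
        u, s, vh = torch.linalg.svd(rest.reshape(left * d, -1), full_matrices=False)
        r = min(max_rank, s.shape[0])
        # squared Frobenius norm of the singular values dropped at this cut
        discarded.append((s[r:] ** 2).sum().item())
        cores.append(u[:, :r].reshape(left, d, r))
        # kept singular values move into the remainder, as in tt_decomposition
        rest, left = s[:r, None] * vh[:r, :], r
    cores.append(rest.reshape(left, dims[-1], 1))
    return cores, discarded
```

Summing `discarded` therefore gives the squared truncation error without contracting the MPS back into a full tensor.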