In [1]:
#|default_exp tensor_decomposition
#|export
import torch
from tqdm.auto import tqdm
from typing import List

# 1.8: 张量分解

## 单秩分解



将 K 阶张量分解为 K 个向量的直积
$$
T = \zeta \prod_{\otimes k=0}^{K-1} \boldsymbol{v}^{(k)}
$$

* 绝大多数张量不存在严格的单秩分解
* $\zeta$ 是常系数
* 最优单秩近似问题：
    * 向量限制为单位向量，保证数值稳定性；向量长度可以提取到 $\zeta$ 里
    $$
    \min_{\zeta, \{|\boldsymbol{v}^{(k)}|=1\}} \left|T - \zeta \prod_{\otimes k=0}^{K-1} \boldsymbol{v}^{(k)}\right|
    $$


Note:
* 没有 entanglement 的多量子态就可以被单秩分解

![tensor-rank-1-decomposition](images/tensor-rank-1-decomposition.png)


### 最优单秩近似的迭代解法

对于某个维度 $m$，固定其余向量，将张量 $T$ 除第 $m$ 个指标以外的所有指标分别与对应向量的复共轭进行收缩；收缩结果等于 $\zeta \boldsymbol{v}^{(m)}$，对其归一化即得到第 $m$ 个向量的更新值，其模长即为 $\zeta$ 的估计

$$\sum_{\{s_k,k\neq m\}} T_{s_0s_1...s_{K-1}} \prod_{k\neq m} v_{s_k}^{(k)*} = \zeta v_{s_m}^{(m)}$$

图例（以三阶张量为例）
![tensor-decomposition-rank-1-example](images/rank-1-tensor-decomposition-iter-algorithm.png)

伪代码：
```
初始化：
* T 已知，它的阶数为 K
* zeta 可以初始化为 1
* 随机生成 K 个单位向量 v_0, v_1, ..., v_{K-1}


for _ in range(num_iterations):
    for i in range(K):
        vs = [conj(v[x]) for x in range(K) if x != i]
        scaled_vi = contract(T, vs)  # 收缩除第 i 个指标以外的所有指标
        vi = scaled_vi / norm(scaled_vi)
        v[i] = vi

    zeta_new = contract(T, [conj(v[x]) for x in range(K)])
    if |zeta_new - zeta| < eps: # Or the diff of vs is small
        break
    zeta = zeta_new
```



In [2]:
#|export
def outer_product(vectors: List[torch.Tensor]) -> torch.Tensor:
    """Return the outer (tensor) product of a list of 1-D tensors.

    The result has one axis per input vector, in input order, so that
    ``result[i0, i1, ...] == vectors[0][i0] * vectors[1][i1] * ...``.

    Args:
        vectors: Non-empty list of 1-D tensors. A single vector is accepted;
            its outer product is the vector itself (shape ``(n,)``).

    Returns:
        Tensor of shape ``(len(vectors[0]), len(vectors[1]), ...)``.

    Raises:
        AssertionError: If ``vectors`` is empty or an element is not 1-D.
    """
    num_vectors = len(vectors)
    assert num_vectors >= 1, "outer_product needs at least one vector"
    for v in vectors:
        assert v.dim() == 1, "outer_product only accepts 1-D tensors"

    # Each einsum axis needs its own letter; this branch uses the 26 lowercase ones.
    if num_vectors <= 26:
        labels = "abcdefghijklmnopqrstuvwxyz"[:num_vectors]
        input_string = ",".join(labels)
        return torch.einsum(f"{input_string} -> {labels}", *vectors)

    # More axes than letters: put every vector on its own axis and let
    # broadcasting multiply them together.
    result = None
    for i, v in enumerate(vectors):
        shape = [1] * num_vectors
        shape[i] = v.shape[0]
        factor = v.reshape(shape)
        result = factor if result is None else result * factor
    return result


In [3]:
# Sanity check: an einsum contraction must agree with "multiply by the explicit
# outer product, then sum over the contracted axes".
T = torch.randn(2, 3, 4, dtype=torch.complex64)
v0, v1, v2 = (torch.randn(n, dtype=torch.complex64) for n in (2, 3, 4))

# zeta: contract every axis of T
zeta_ref = torch.einsum("abc, a, b, c ->", T, v0, v1, v2)
outer_v0v1v2 = outer_product([v0, v1, v2])
zeta = torch.sum(T * outer_v0v1v2)
assert torch.isclose(zeta, zeta_ref)
# v0: contract axes 1 and 2 (v0 is overwritten with the contracted result)
v0_ref = torch.einsum("abc,b,c -> a", T, v1, v2)
outer_v1v2 = outer_product([v1, v2])[None, :, :]
v0 = torch.sum(T * outer_v1v2, dim=(1, 2))
assert torch.allclose(v0, v0_ref)
# v1: contract axes 0 and 2, using the updated v0
v1_ref = torch.einsum("abc,a,c -> b", T, v0, v2)
outer_v0v2 = outer_product([v0, v2])[:, None, :]
v1 = torch.sum(T * outer_v0v2, dim=(0, 2))
assert torch.allclose(v1, v1_ref)
# v2: contract axes 0 and 1, using the updated v0 and v1
v2_ref = torch.einsum("abc,a,b -> c", T, v0, v1)
outer_v0v1 = outer_product([v0, v1])[:, :, None]
v2 = torch.sum(T * outer_v0v1, dim=(0, 1))
assert torch.allclose(v2, v2_ref)

#### 实现

In [76]:
#|export
def rank1_tc(x, v=None, it_time=10000, tol=1e-14):
    """Reference rank-1 approximation of ``x`` by alternating contractions.

    Each sweep updates one vector at a time: ``x`` is contracted with the
    conjugates of all other vectors, and the normalised result replaces the
    current vector. The norm of the contraction is the running estimate of
    the scalar coefficient.

    Args:
        x: Tensor to approximate.
        v: Optional initial vectors, one per axis of ``x``. They are
            normalised into new tensors; the caller's list and tensors are
            left untouched. Drawn with ``randn`` when omitted.
        it_time: Maximum number of sweeps.
        tol: A sweep stops the iteration when both the mean change of the
            vectors and the mean change of the norm are below this value.

    Returns:
        ``(v, norm1)``: the list of unit vectors and the norm of the last
        contraction (the integer ``1`` if no update was performed).
    """
    # From: https://github.com/ranshiju/Python-for-Tensor-Network-Tutorial/blob/4c89b0766159d3495122ec39339e7bd019f10fdf/Library/MathFun.py#L231
    order = x.ndimension()

    # Initialise the vectors
    if v is None:
        v = [torch.randn(x.shape[n], device=x.device, dtype=x.dtype) for n in range(order)]
    assert len(v) == order, "need exactly one vector per axis of x"

    # Normalise out of place so caller-supplied tensors are not modified
    v = [vec / vec.norm() for vec in v]

    norm1 = 1
    err = torch.ones(order, device=x.device, dtype=torch.float64)
    err_norm = torch.ones(order, device=x.device, dtype=torch.float64)
    for _ in tqdm(range(it_time)):
        for n in range(order):
            # tensordot returns a new tensor, so x itself is never modified.
            x1 = x
            # The conjugate comes from the fixed-point equation in the notes
            # (sum T * prod v* = zeta * v): minimising |x - zeta * outer(v)|
            # over v[n] projects x onto the conjugates of the other vectors.
            # For real dtypes conj() changes nothing.
            for m in range(n):
                # leading axes are consumed front to back
                x1 = torch.tensordot(x1, v[m].conj(), [[0], [0]])
            for m in range(order - 1, n, -1):
                # trailing axes are consumed back to front
                x1 = torch.tensordot(x1, v[m].conj(), [[-1], [0]])
            norm = x1.norm()
            v1 = x1 / norm
            err[n] = (v[n] - v1).norm()
            err_norm[n] = (norm - norm1).norm()
            v[n] = v1
            norm1 = norm
        if err.sum() / order < tol and err_norm.sum() / order < tol:
            break
    return v, norm1


#|export
def _conj_outer(vecs: List[torch.Tensor]) -> torch.Tensor:
    """Outer product of the complex conjugates of ``vecs`` (non-empty list)."""
    conj_vecs = [v.conj() for v in vecs]
    # A single vector is its own outer product; outer_product expects >= 2.
    return conj_vecs[0] if len(conj_vecs) == 1 else outer_product(conj_vecs)


def _contract_all_but(tensor: torch.Tensor, vecs: List[torch.Tensor], idx: int) -> torch.Tensor:
    """Contract every axis of ``tensor`` except ``idx`` with the conjugate of its vector."""
    others = vecs[:idx] + vecs[idx + 1:]
    if not others:
        # Order-1 tensor: there is nothing to contract.
        return tensor.clone()
    outer = _conj_outer(others).unsqueeze(idx)
    sum_indices = tuple(d for d in range(tensor.dim()) if d != idx)
    return (tensor * outer).sum(sum_indices)


def rank1_decomposition(tensor: torch.Tensor,
                        num_iter: int = 10000,
                        stop_criterion: str = "zeta",
                        eps: float = 1e-14) -> (List[torch.Tensor], torch.Tensor):
    """Iteratively compute a rank-1 approximation ``zeta * outer(v_0, ..., v_{K-1})``.

    Each sweep updates the vectors one at a time: the tensor is contracted
    with the conjugates of all other vectors and the normalised result
    replaces the current vector.

    Args:
        tensor: Tensor of order >= 1 to approximate (real or complex floating dtype).
        num_iter: Maximum number of sweeps.
        stop_criterion: ``"zeta"`` stops when the full contraction changes by
            less than ``eps`` between sweeps. ``"norms"`` stops when the mean
            change of the vectors and the mean change of the contraction
            norm within a sweep are both below ``eps``.
        eps: Convergence threshold.

    Returns:
        ``(vectors, zeta)``: the list of unit vectors (one per axis) and the
        scalar coefficient as a 0-dim real tensor.

    Raises:
        AssertionError: If ``stop_criterion`` is unknown or ``tensor`` has no axes.
    """
    assert stop_criterion in ["zeta", "norms"]
    device = tensor.device
    t_shape = tensor.shape
    k = len(t_shape)
    assert k >= 1, "tensor must have at least one axis"
    decomposed_vecs = [torch.randn(d, dtype=tensor.dtype, device=device) for d in t_shape]
    decomposed_vecs = [v / v.norm() for v in decomposed_vecs]
    # Norms are real even for complex input; use that dtype for zeta and the
    # bookkeeping buffers so zeta is a tensor on every path.
    real_dtype = decomposed_vecs[0].norm().dtype
    zeta = torch.ones((), dtype=real_dtype, device=device)

    if stop_criterion == "zeta":
        for _ in tqdm(range(num_iter)):
            for idx in range(k):
                vi = _contract_all_but(tensor, decomposed_vecs, idx)
                decomposed_vecs[idx] = vi / vi.norm()

            # calculate zeta: contract every axis with the conjugate vectors
            zeta_prev = zeta
            zeta = (tensor * _conj_outer(decomposed_vecs)).sum().real
            # zeta already holds the newest value, so the result matches the
            # returned vectors even when stopping here.
            if (zeta - zeta_prev).norm() < eps:
                break
    else:
        # FIXME (original author): adapted from the reference implementation;
        # this criterion takes many more iterations than "zeta".
        # NOTE(review): presumably the per-vector differences rarely fall below
        # eps=1e-14 for float32/complex64 inputs — confirm.
        v_norm_diffs = torch.ones(k, dtype=real_dtype, device=device)
        zeta_diffs = torch.ones(k, dtype=real_dtype, device=device)
        for _ in tqdm(range(num_iter)):
            for idx in range(k):
                vi = _contract_all_but(tensor, decomposed_vecs, idx)
                # calculate diffs
                vi_norm = vi.norm()
                vi = vi / vi_norm
                v_norm_diffs[idx] = (decomposed_vecs[idx] - vi).norm()
                zeta_diffs[idx] = (vi_norm - zeta).norm()
                decomposed_vecs[idx] = vi
                zeta = vi_norm

            if v_norm_diffs.sum() / k < eps and zeta_diffs.sum() / k < eps:
                break

    return decomposed_vecs, zeta


def rank1_decomposition_gradient_based(tensor: torch.Tensor, num_iter: int = 1000, eps: float = 1e-14,
                                       lr: float = 0.1) -> (
        List[torch.Tensor], torch.Tensor):
    """Rank-1 approximation of ``tensor`` by minimising the residual norm with Adam.

    The vectors are optimised without a norm constraint; the scalar
    coefficient is the product of their norms, and unit vectors are returned.

    Args:
        tensor: Tensor of order >= 1 to approximate.
        num_iter: Maximum number of optimiser steps.
        eps: Stop once the product of the vector norms changes by less than
            this between two steps.
        lr: Adam learning rate.

    Returns:
        ``(vectors, zeta)``: the list of unit vectors (one per axis) and the
        product of the norms of the optimised vectors as a 0-dim tensor.
    """
    # Create everything on the input's device so non-CPU tensors work too.
    device = tensor.device
    decomposed_vecs = [torch.randn(d, dtype=tensor.dtype, device=device) for d in tensor.shape]
    assert len(decomposed_vecs) >= 1, "tensor must have at least one axis"
    decomposed_vecs = [v / v.norm() for v in decomposed_vecs]
    for dv in decomposed_vecs:
        dv.requires_grad_(True)

    def _norm_product() -> torch.Tensor:
        # stack keeps the norms on their device
        return torch.stack([v.norm() for v in decomposed_vecs]).prod()

    with torch.no_grad():
        zeta = _norm_product()  # the vectors start as unit vectors
    adam = torch.optim.Adam(decomposed_vecs, lr=lr)
    for _ in tqdm(range(num_iter)):
        # A single vector is its own outer product; outer_product expects >= 2.
        outer = decomposed_vecs[0] if len(decomposed_vecs) == 1 else outer_product(decomposed_vecs)
        loss = (tensor - outer).norm()
        loss.backward()
        adam.step()
        adam.zero_grad()
        with torch.no_grad():
            zeta_prev = zeta
            # Always keep the newest value so zeta matches the returned vectors.
            zeta = _norm_product()
            if (zeta - zeta_prev).norm() < eps:
                break

    with torch.no_grad():
        decomposed_vecs = [v / v.norm() for v in decomposed_vecs]
        return decomposed_vecs, zeta

#### 测试及对比

##### Iteration-based vs. Gradient-based Optimization

In [99]:
# Iterative solver vs. Adam-based solver on the same real-valued tensor
a = torch.randn(2, 3, 4, 5, dtype=torch.float32)
decompositions0, zeta0 = rank1_decomposition(a)
decompositions1, zeta1 = rank1_decomposition_gradient_based(a)
# The assertions below often fail: the gradient-based optimization is less
# robust, and it is hard to decide when to stop it because the change between
# steps also contains the gradient update.
assert zeta0.allclose(zeta1), f"{zeta0}, {zeta1}"
# Rebuild both rank-1 approximations and compare their residual norms
outer0 = outer_product(decompositions0) * zeta0
outer1 = outer_product(decompositions1) * zeta1
diff0, diff1 = ((a - approx).norm() for approx in (outer0, outer1))
assert diff0.isclose(diff1), f"{diff0=}, {diff1=}"

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

##### Test on Complex Tensor

In [100]:
# Complex tensor: two independent runs of the iterative solver should agree.
# This is sometimes flaky.
a = torch.randn(2, 3, 4, 5, dtype=torch.complex64)
decompositions0, zeta0 = rank1_decomposition(a)
decompositions1, zeta1 = rank1_decomposition(a)
assert zeta0.allclose(zeta1), f"{zeta0}, {zeta1}"
# Rebuild both rank-1 approximations and compare their residual norms
outer0 = outer_product(decompositions0) * zeta0
outer1 = outer_product(decompositions1) * zeta1
diff0, diff1 = ((a - approx).norm() for approx in (outer0, outer1))
assert diff0.isclose(diff1), f"{diff0=}, {diff1=}"

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

##### Test against Reference Implementation

In [97]:
# Own implementation ("zeta" stopping rule) vs. the reference implementation
a = torch.randn(2, 3, 4, 5, dtype=torch.complex64)
decompositions, zeta = rank1_decomposition(a, stop_criterion="zeta")
decompositions_ref, zeta_ref = rank1_tc(a)
assert zeta.isclose(zeta_ref), f"{zeta=}, {zeta_ref=}"
# Rebuild both rank-1 approximations and compare their residual norms
outer = outer_product(decompositions) * zeta
outer_ref = outer_product(decompositions_ref) * zeta_ref
diff, diff_ref = ((a - approx).norm() for approx in (outer, outer_ref))
assert diff.isclose(diff_ref), f"{diff=}, {diff_ref=}"

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

##### Test Reference Implementation

In [98]:
# Two independent runs of the reference implementation on one real tensor
# should leave residuals of the same size.
a = torch.randn(2, 3, 4, 5, dtype=torch.float32)
decompositions0, zeta0 = rank1_tc(a)
decompositions1, zeta1 = rank1_tc(a)
outer0 = outer_product(decompositions0) * zeta0
outer1 = outer_product(decompositions1) * zeta1
diff0, diff1 = ((a - approx).norm() for approx in (outer0, outer1))
assert diff0.isclose(diff1), f"{diff0=}, {diff1=}"

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]