In [1]:
#|default_exp tensor_decomposition
#|export
import torch
from tqdm.auto import tqdm
from typing import List

# 1.8: 张量分解

## 单秩分解



将 K 阶张量分解为 K 个向量的直积
$$
T = \zeta \prod_{\otimes k=0}^{K-1} \boldsymbol{v}^{(k)}
$$

* 绝大多数张量不存在严格的单秩分解
* $\zeta$ 是常系数
* 最优单秩近似问题：
    * 向量限制为单位向量，保证数值稳定性；向量长度可以提取到 $\zeta$ 里
    $$
    \min_{\zeta,\ \|\boldsymbol{v}^{(k)}\|=1} \left\|T - \zeta \prod_{\otimes k=0}^{K-1} \boldsymbol{v}^{(k)}\right\|
    $$


Note:
* 没有纠缠（entanglement）的多体量子态（即直积态）恰好可以被严格地单秩分解

![tensor-rank-1-decomposition](images/tensor-rank-1-decomposition.png)

TODO:
* 看书 1.6
* Rank 1 分解的代码



## 最优单秩近似的迭代解法

固定其余向量，考虑某个维度 $m$：把 $T$ 除第 $m$ 个指标以外的所有指标，分别与对应向量的复共轭进行收缩。最优解满足下面的自洽方程，即收缩结果正比于 $\boldsymbol{v}^{(m)}$，比例系数就是 $\zeta$；因此把收缩结果归一化，就得到 $\boldsymbol{v}^{(m)}$ 的更新值：

$$\sum_{\{s_k,\,k\neq m\}} T_{s_0s_1\cdots s_{K-1}} \prod_{k\neq m} v_{s_k}^{(k)*} = \zeta\, v_{s_m}^{(m)}$$

图例（以三阶张量为例）
![tensor-decomposition-rank-1-example](images/rank-1-tensor-decomposition-iter-algorithm.png)

伪代码：
```
初始化：
* T 已知，它的阶数为 K
* zeta 可以初始化为 1
* 随机生成 K 个单位向量 v_0, v_1, ..., v_{K-1}


for _ in range(num_iterations):
    for i in range(K):
        vs = [conj(v[x]) for x in range(K) if x != i]
        scaled_vi = contract(T, vs)
        vi = scaled_vi / norm(scaled_vi)
        v[i] = vi

    zeta_new = contract(T, [conj(u) for u in v])
    if |zeta_new - zeta| < eps: # Or the diff of vs is small
        break
    zeta = zeta_new
```



In [2]:
#|export
def outer_product(vectors: List[torch.Tensor]) -> torch.Tensor:
    """Outer (direct) product of a list of 1-D tensors.

    ``out[i0, i1, ...] == vectors[0][i0] * vectors[1][i1] * ...``; axis ``k`` of
    the result has length ``vectors[k].shape[0]``. No complex conjugation is applied.

    Args:
        vectors: one or more 1-D tensors. A single vector is a valid (trivial)
            outer product and is returned unchanged in value.

    Returns:
        Tensor of shape ``tuple(v.shape[0] for v in vectors)``.

    Raises:
        AssertionError: if ``vectors`` is empty or any element is not 1-D.
    """
    for v in vectors:
        assert v.dim() == 1, f"expected 1-D tensors, got shape {tuple(v.shape)}"
    num_vectors = len(vectors)
    # One vector must be accepted: contracting an order-2 tensor in the rank-1
    # iteration leaves exactly one vector to combine.
    assert num_vectors >= 1, "need at least one vector"
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    if num_vectors <= len(alphabet):
        # One einsum subscript per vector, e.g. "a,b,c->abc".
        subscripts = alphabet[:num_vectors]
        return torch.einsum(f"{','.join(subscripts)}->{subscripts}", *vectors)
    # Out of einsum subscripts: view vector i with shape (1, ..., n_i, ..., 1)
    # and let broadcasting multiplication build the product.
    result = None
    for i, v in enumerate(vectors):
        shape = [1] * num_vectors
        shape[i] = v.shape[0]
        reshaped = v.reshape(shape)
        result = reshaped if result is None else result * reshaped
    return result


In [3]:
# Check that broadcasting against `outer_product` reproduces the einsum contractions
# used by the rank-1 iteration: all modes for zeta, all-but-one mode for each vector.
# A local generator keeps the check reproducible without touching the global RNG.
gen = torch.Generator().manual_seed(0)
T = torch.randn(2, 3, 4, dtype=torch.complex64, generator=gen)
v0 = torch.randn(2, dtype=torch.complex64, generator=gen)
v1 = torch.randn(3, dtype=torch.complex64, generator=gen)
v2 = torch.randn(4, dtype=torch.complex64, generator=gen)

# zeta: contract every mode
zeta_ref = torch.einsum("abc,a,b,c->", T, v0, v1, v2)
zeta = (T * outer_product([v0, v1, v2])).sum()
assert zeta.isclose(zeta_ref)

# vector for mode 0: contract modes 1 and 2
v0_ref = torch.einsum("abc,b,c->a", T, v1, v2)
v0_alt = (T * outer_product([v1, v2]).unsqueeze(0)).sum((1, 2))
assert v0_alt.allclose(v0_ref)

# vector for mode 1: contract modes 0 and 2
v1_ref = torch.einsum("abc,a,c->b", T, v0, v2)
v1_alt = (T * outer_product([v0, v2]).unsqueeze(1)).sum((0, 2))
assert v1_alt.allclose(v1_ref)

# vector for mode 2: contract modes 0 and 1
v2_ref = torch.einsum("abc,a,b->c", T, v0, v1)
v2_alt = (T * outer_product([v0, v1]).unsqueeze(2)).sum((0, 1))
assert v2_alt.allclose(v2_ref)

In [4]:
#|export
def _contract_except(tensor: torch.Tensor, vectors: List[torch.Tensor], skip: int) -> torch.Tensor:
    """Contract every mode of ``tensor`` except ``skip`` with the conjugate of its vector.

    ``vectors[k]`` is paired with mode ``k``. With ``skip=-1`` every mode is
    contracted and a 0-d tensor is returned; otherwise the result is 1-D with the
    length of mode ``skip``.
    """
    result = tensor
    # Walk modes from last to first: removing axis k never shifts the positions of
    # the lower axes that are still waiting to be contracted.
    for k in reversed(range(tensor.dim())):
        if k == skip:
            continue
        result = torch.tensordot(result, vectors[k].conj(), dims=([k], [0]))
    return result


def rank1_decomposition(tensor: torch.Tensor, num_iter: int = 1000, eps: float = 1e-10) -> List[torch.Tensor]:
    """Best rank-1 approximation of ``tensor`` by alternating (power-style) iteration.

    Returns unit vectors ``v[0], ..., v[K-1]`` (``K = tensor.dim()``) such that
    ``zeta * outer(v[0], ..., v[K-1])`` approximates ``tensor``, where
    ``zeta = sum(tensor * outer(conj(v[0]), ..., conj(v[K-1])))``.
    Each sweep replaces ``v[m]`` by the normalised contraction of ``tensor`` with
    the conjugates of all other vectors.

    Args:
        tensor: real or complex floating-point tensor of any order.
        num_iter: maximum number of sweeps over all modes.
        eps: stop once zeta changes by less than ``eps`` between sweeps. The
            threshold is never set below the dtype's relative resolution times
            ``|zeta|``, so float32 inputs can still converge with a tiny ``eps``.

    Returns:
        List of ``K`` unit vectors on ``tensor.device``; entry ``k`` has length
        ``tensor.shape[k]``. The start vectors come from the global torch RNG,
        so results depend on its state. A 0-d tensor yields an empty list.
    """
    order = tensor.dim()
    if order == 0:
        return []
    vecs = [torch.randn(d, dtype=tensor.dtype, device=tensor.device) for d in tensor.shape]
    vecs = [v / v.norm() for v in vecs]
    # Relative precision of the working real dtype (float32 for complex64, ...).
    resolution = torch.finfo(vecs[0].norm().dtype).eps
    zeta = None  # no previous value yet, so the first sweep can never "converge"
    for _ in tqdm(range(num_iter)):
        for idx in range(order):
            vi = _contract_except(tensor, vecs, idx)
            norm = vi.norm()
            # A zero contraction has no direction; keep the previous unit vector
            # rather than dividing by zero and propagating NaNs.
            if norm > 0:
                vecs[idx] = vi / norm

        zeta_new = _contract_except(tensor, vecs, -1)
        if zeta is not None:
            tol = max(eps, resolution * abs(zeta_new.item()))
            if (zeta_new - zeta).abs().item() < tol:
                break
        zeta = zeta_new

    return vecs

In [5]:
torch.manual_seed(0)  # rank1_decomposition draws its start vectors from the global RNG
a = torch.randn(2, 3, 4, 5, dtype=torch.float32)
decompositions = rank1_decomposition(a)
# Every factor is a unit vector whose length matches the corresponding mode of `a`.
assert [v.shape[0] for v in decompositions] == list(a.shape)
assert all(torch.isclose(v.norm(), torch.tensor(1.0)) for v in decompositions)

# Sanity check on an exactly rank-1 tensor: the iteration must reproduce it, with
# any sign flips of the factors absorbed into zeta.
factors = [torch.randn(d, dtype=torch.float64) for d in (2, 3, 4, 5)]
factors = [f / f.norm() for f in factors]
rank1_tensor = 2.5 * outer_product(factors)
recovered = rank1_decomposition(rank1_tensor)
zeta_rank1 = (rank1_tensor * outer_product([v.conj() for v in recovered])).sum()
assert torch.allclose(zeta_rank1 * outer_product(recovered), rank1_tensor)
assert torch.isclose(zeta_rank1.abs(), torch.tensor(2.5, dtype=torch.float64))
# TODO: to compare with the reference implementation
# TODO: to test a complex tensor

  0%|          | 0/1000 [00:00<?, ?it/s]