In [1]:
import torch

## 指标收缩

### 指标收缩的重要函数：einsum


*   其输入为字符串公式与相关张量；
*   公式包含箭头 "->"，箭头左侧为各个待收缩张量的指标，右侧为收缩所得张量的指标；
*   左侧各个张量的指标用逗号隔开，共有指标使用同一个字母表示；
*   当左侧出现的指标没有出现在右侧时，说明对该指标作求和运算；
*   例：$T_k = \sum_{ij} A_{ijk} B_{ik} C_{jk}$ 的代码为 `T = torch.einsum('ijk,ik,jk->k', A, B, C)`


#### 例子：向量内积

In [2]:
# Two independent random vectors of length 4, shared by the inner-product
# examples below (u is sampled first, then v).
u, v = torch.randn(4), torch.randn(4)

In [3]:
# Four equivalent ways to compute the inner product sum_i u_i * v_i.
z1 = torch.dot(u, v)
z2 = torch.inner(u, v)
z3 = torch.einsum("i,i->", u, v)
z4 = torch.sum(u * v)

# Try: with two *different* index letters nothing is shared, so both indices
# are summed independently: sum_a sum_b u_a * v_b.
t1 = torch.einsum("a,b->", u, v)

# All inner-product variants agree; the two-index sum is a different quantity.
assert all(z1.isclose(z) for z in (z2, z3, z4))
assert not z1.isclose(t1)

print(f"{z1=}, {t1=}")

# Same double sum via broadcasting: a (4, 1) column times a (1, 4) row gives
# the 4x4 table of all products u_a * v_b, which is then summed.
t2 = (u.unsqueeze(1) * v.unsqueeze(0)).sum()
assert t1.isclose(t2)
print(f"{t2=}")

z1=tensor(2.5590), t1=tensor(1.0885)
t2=tensor(1.0885)


$z_1 = z_2 = z_3 = z_4 = \sum_{i} u_i \cdot v_i$

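But $t_1 = t_2 = \sum_{a} \sum_{b} u_a \cdot v_b$, which is equivalent to taking the outer (Kronecker) product of $u$ and $v$ and then summing all of its entries, i.e. $\left(\sum_a u_a\right)\left(\sum_b v_b\right)$.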

#### 例子：矩阵乘积
矩阵乘：$X = PQ \leftrightarrow X_{ik} = \sum_j P_{ij} Q_{jk}$

In [4]:
# Random factors for the matrix-product example: P is 3x4 and Q is 4x5, so
# the shared dimension of size 4 is the one that gets summed over.
P, Q = torch.randn(3, 4), torch.randn(4, 5)

In [5]:
# Three equivalent spellings of the matrix product X_ik = sum_j P_ij Q_jk:
# the @ operator, torch.matmul, and an explicit einsum formula.
X1 = P @ Q
X2 = torch.matmul(P, Q)
X3 = torch.einsum("ij,jk->ik", P, Q)

# Every variant must agree with the operator form.
assert all(X1.allclose(other) for other in (X2, X3))

print(X3)

tensor([[-0.4166,  3.4424, -4.2510, -0.2182, -0.4939],
        [ 0.1289, -0.9257,  1.3034,  0.1052,  0.4339],
        [ 0.6358, -1.4259,  0.8684,  1.1713, -2.3189]])


#### 例子：张量收缩

In [6]:
# Sizes of the contracted (summed) indices ...
i, j, k = 2, 3, 4
# ... and of the free indices that survive in the result.
a, b, c = 5, 6, 7

# Three rank-3 tensors forming a ring: A shares j with B, B shares k with C,
# and C shares i with A.
A = torch.randn((i, a, j))
B = torch.randn((j, b, k))
C = torch.randn((k, c, i))

In [7]:
# Contract the shared indices i, j and k of A, B and C in one einsum call;
# only the free indices a, b and c remain, so the result has shape (a, b, c).
X = torch.einsum("iaj, jbk, kci->abc", A, B, C)
print(X.size())

torch.Size([5, 6, 7])


In [8]:
# Alternatively: reproduce the einsum "iaj, jbk, kci->abc" by hand with
# reshape / permute / matmul, to show what the contraction does step by step.

## To contract index j
# Flatten A to a (i*a, j) matrix and B to a (j, b*k) matrix; the matrix product
# sums over j, and the result is unflattened back to indices (i, a, b, k).
tmp_a = A.reshape(i * a, j)
tmp_b = B.reshape(j, b * k)
tmp_iabk = (tmp_a @ tmp_b).reshape(i, a, b, k)

## To contract index i and k
# Move the indices to be summed (k, i) to the end and merge them into one axis.
tmp_abki = tmp_iabk.permute(1, 2, 3, 0)
tmp_ab_ki = tmp_abki.reshape(a * b, k * i)
# C is stored as (k, c, i); reorder it to (k, i, c) so that its merged row
# index runs over (k, i) in exactly the same order as the columns above.
tmp_c_kic = C.permute(0, 2, 1)
tmp_c_ki_c = tmp_c_kic.reshape(k * i, c)
X1 = (tmp_ab_ki @ tmp_c_ki_c).reshape(a, b, c)

# einsum and the manual route may sum the float32 products in a different
# order, so entries of X that happen to be close to zero can differ by more
# than the default atol=1e-8 allows. An explicit absolute tolerance keeps this
# check from failing spuriously on the unseeded random inputs.
assert X1.shape == X.shape
assert X.allclose(X1, atol=1e-5)

# Try
# Merging C's indices in the order (i, k) instead of (k, i) still yields a
# matrix with a compatible shape (i * k == k * i), but its rows no longer line
# up with the columns of tmp_ab_ki, so the product is a different tensor.
tmp_c_ikc = C.permute(2, 0, 1)
tmp_c_ik_c = tmp_c_ikc.reshape(i * k, c)
Try_X = (tmp_ab_ki @ tmp_c_ik_c).reshape(a, b, c)
assert not X.allclose(Try_X, atol=1e-5)