In [1]:
import mlx.core as mx

## 指标收缩

### 指标收缩的重要函数：einsum


*   其输入为字符串公式与相关张量；
*   公式包含箭头 "->"，箭头左侧为各个待收缩张量的指标，右侧为收缩所得张量的指标；
*   左侧各个张量的指标用逗号隔开，共有指标使用同一个字母表示；
*   当左侧出现的指标没有出现在右侧时，说明对该指标作求和运算；
*   例：$T_k = \sum_{ij} A_{ijk} B_{ik} C_{jk}$ 的代码为 `T = mx.einsum('ijk,ik,jk->k', A, B, C)`


#### 例子：向量内积

In [2]:
# Two independent random vectors of length 4, drawn from mx.random.normal
# with its default parameters.
u = mx.random.normal(shape=(4,))
v = mx.random.normal(shape=(4,))

In [3]:
# Three equivalent spellings of the inner product sum_i u_i * v_i.
z1 = mx.inner(u, v)
z2 = mx.einsum("i,i->", u, v)
z3 = mx.sum(u * v)

# Counter-example: with two different labels no index is shared, so both
# indices are summed independently, giving sum_a sum_b u_a * v_b.
t1 = mx.einsum("a,b->", u, v)

assert mx.isclose(z1, z2)
assert mx.isclose(z1, z3)
assert not mx.isclose(t1, z1)

print(f"{z1=}, {t1=}")

# The same quantity as t1, written out explicitly: broadcast a column against
# a row to form the table of all products u_a * v_b, then add up every entry.
t2 = mx.sum(u.reshape(-1, 1) * v.reshape(1, -1))
assert mx.isclose(t1, t2)
print(f"{t2=}")

z1=array(0.693031, dtype=float32), t1=array(0.457336, dtype=float32)
t2=array(0.457336, dtype=float32)


$z_1 = z_2 = z_3 = \sum_{i} u_i \cdot v_i$

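but $t_1 = t_2 = \sum_{a} \sum_{b} u_a \cdot v_b$, which is equivalent to forming the outer (Kronecker) product of $u$ and $v$ and then summing all of its entries, i.e. $t_1 = \left(\sum_a u_a\right)\left(\sum_b v_b\right)$.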

#### 例子：矩阵乘积
矩阵乘：$X = PQ \leftrightarrow X_{ik} = \sum_j P_{ij} Q_{jk}$

In [4]:
# Random factors with compatible shapes: P is 3x4 and Q is 4x5.
P = mx.random.normal(shape=(3, 4))
Q = mx.random.normal(shape=(4, 5))

In [5]:
# The same matrix product written three ways.
X1 = P @ Q                         # operator form
X2 = mx.matmul(P, Q)               # function form
X3 = mx.einsum("ij,jk->ik", P, Q)  # index form: the shared label j is summed

# Every variant must agree with the operator form.
assert all(mx.allclose(X1, other) for other in (X2, X3))

print(X3)

array([[3.78842, -0.113603, -1.42321, 0.64539, 0.751379],
       [3.61685, 0.662649, -0.277227, 2.08043, 1.43522],
       [-4.03933, 2.16959, 1.04694, -0.194527, -1.40938]], dtype=float32)


#### 例子：张量收缩

In [9]:
# Extents of the labels i, j, k: each of them appears in two of the tensors below.
i, j, k = 2, 3, 4
# Extents of the labels a, b, c: each of them appears in exactly one tensor.
a, b, c = 5, 6, 7

# Three rank-3 tensors arranged in a ring: A shares j with B, B shares k with C,
# and C shares i with A.
A = mx.random.normal(shape=(i, a, j))
B = mx.random.normal(shape=(j, b, k))
C = mx.random.normal(shape=(k, c, i))

In [10]:
# Sum over the shared labels i, j, k in a single call; only a, b, c are listed
# on the right of "->", so the result has shape (a, b, c).
X = mx.einsum(
    "iaj, jbk, kci->abc",
    A,
    B,
    C,
)
print(X.shape)

(5, 6, 7)


In [11]:
# Alternatively: reproduce X = einsum("iaj, jbk, kci->abc", A, B, C) using only
# reshape / transpose / matmul.
#
# The two routes add up float32 products in a different order, so their results
# differ by rounding error. allclose's defaults (rtol=1e-5, atol=1e-8) have an
# almost-zero absolute tolerance and can reject entries of X that happen to be
# close to zero, so explicit tolerances are passed to the comparisons below.

## To contract index j
tmp_a = A.reshape(i * a, j)  # (i, a, j) -> (i*a, j)
tmp_b = B.reshape(j, b * k)  # (j, b, k) -> (j, b*k)
tmp_iabk = (tmp_a @ tmp_b).reshape(i, a, b, k)

## To contract index i and k
# Move i to the end so that the summed pair (k, i) is contiguous, then flatten.
tmp_abki = tmp_iabk.transpose(1, 2, 3, 0)
tmp_ab_ki = tmp_abki.reshape(a * b, k * i)
# Bring C from (k, c, i) to (k, i, c) so its flattened row index runs over
# (k, i) in the same order as the columns of tmp_ab_ki.
tmp_c_kic = C.transpose(0, 2, 1)
tmp_c_ki_c = tmp_c_kic.reshape(k * i, c)
X1 = (tmp_ab_ki @ tmp_c_ki_c).reshape(a, b, c)

assert mx.allclose(X, X1, rtol=1e-4, atol=1e-4), f"{mx.linalg.norm(X - X1)}"

# Try: flatten C's summed pair as (i, k) instead of (k, i). The shapes still
# match, but the rows of C no longer line up with the columns of tmp_ab_ki,
# so the product is a different tensor.
tmp_c_ikc = C.transpose(2, 0, 1)
tmp_c_ik_c = tmp_c_ikc.reshape(i * k, c)
Try_X = (tmp_ab_ki @ tmp_c_ik_c).reshape(a, b, c)
assert not mx.allclose(X, Try_X, rtol=1e-4, atol=1e-4)