In [1]:
# Reference article: https://zhuanlan.zhihu.com/p/559824020
import torch

# Fix the global RNG so every random tensor created below (and any values that
# get printed) is reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

print(f'CUDA available: {torch.cuda.is_available()}')
print(f'torch version: {torch.__version__}')

CUDA available: True
torch version: 2.6.0+cu126


In [15]:
'''tensordot
Contracts (sums over) several dimensions of two tensors at once; with `dims=0`
nothing is contracted and the result is the outer (tensor) product, i.e. the
number of dimensions is expanded.
'''
# 1. The simplest case = tensor (outer) product: X^{abcde} * Y^{fghi} = Z^{abcdefghi}
x = torch.rand((3, 4, 5, 6, 7))
y = torch.rand((8, 9, 10, 11))
assert torch.tensordot(x, y, dims=0).shape == (3, 4, 5, 6, 7, 8, 9, 10, 11)

# 2. Contract one dimension: X^{abcd e} * Y^{e fghi} = Z^{abcdfghi}
a = torch.rand((7, 8, 9, 10, 11))
assert torch.tensordot(x, a, dims=1).shape == (3, 4, 5, 6, 8, 9, 10, 11)

# 3. Contract two dimensions: X^{abc de} * Y^{de fg} = Z^{abcfg}
b = torch.rand((6, 7, 10, 11))
assert torch.tensordot(x, b, dims=2).shape == (3, 4, 5, 10, 11)

# 4. Contract three dimensions: X^{ab cde} * Y^{cde f} = Z^{abf}
c = torch.rand((5, 6, 7, 32))
assert torch.tensordot(x, c, dims=3).shape == (3, 4, 32)

# 5. General rule for an integer `dims=k` (0 <= k <= x.ndim): the LAST k dims of
#    the first operand are contracted against the FIRST k dims of the second one,
#    and the result keeps the remaining dims of both operands, in order.
extra_dims = (2, 3)
for k in range(x.ndim + 1):
    kept = x.ndim - k
    other = torch.rand(tuple(x.shape[kept:]) + extra_dims)
    expected = tuple(x.shape[:kept]) + extra_dims
    assert torch.tensordot(x, other, dims=k).shape == expected

# 6. Contract explicitly chosen dimensions: with dims=(g_dims, h_dims),
#    g.shape[g_dims[i]] is paired with h.shape[h_dims[i]] (sizes must match);
#    the result is the remaining dims of g followed by the remaining dims of h.
g = torch.rand((3, 10, 4, 12, 11, 5))
h = torch.rand((12, 2, 6, 11, 10, 7, 8))

g_dims = [3, 4, 1]   # sizes 12, 11, 10
h_dims = [0, 3, 4]   # sizes 12, 11, 10
assert torch.tensordot(g, h, dims=(g_dims, h_dims)).shape == (3, 4, 5, 2, 6, 7, 8)

g_dims = [3, 4]      # sizes 12, 11
h_dims = [0, 3]      # sizes 12, 11
assert torch.tensordot(g, h, dims=(g_dims, h_dims)).shape == (3, 10, 4, 5, 2, 6, 10, 7, 8)

g_dims = [3]         # size 12
h_dims = [0]         # size 12
assert torch.tensordot(g, h, dims=(g_dims, h_dims)).shape == (3, 10, 4, 11, 5, 2, 6, 11, 10, 7, 8)


In [10]:
'''multiply (__mul__ / *)
The most natural cases:
- 1. element-wise product of two tensors with the same shape
- 2. a scalar multiplied into every element of a tensor

Cases that rely on broadcasting:
- 3. shapes are aligned starting from the trailing dimension
- 4. a size-1 dimension is expanded to match the other operand, then the
     operands are multiplied element-wise
- 5. the explicit loop that broadcasting is equivalent to
'''
# 1. Element-wise product: both operands have exactly the same shape.
x = torch.randn(6, 4, 13, 5)
a = torch.randn(6, 4, 13, 5)
assert torch.mul(x, a).shape == (6, 4, 13, 5)

# 2. Scalar times tensor: every element is multiplied by the scalar
#    (doubling is exact in floating point, so it equals x + x bit for bit).
assert torch.mul(x, 2.0).shape == x.shape
assert torch.equal(torch.mul(x, 2.0), x + x)

# 3. Broadcasting aligns the shapes starting from the LAST dimension.
x = torch.randn(6, 4, 13, 5)
a = torch.randn(      13, 5)
assert torch.mul(a, x).shape == (6, 4, 13, 5)

# 4. Where the sizes differ and one of them is 1, that dimension is
#    broadcast (treated as repeated) to match the other operand.
x = torch.rand((3, 1, 2, 5))
a = torch.rand((3, 6, 2, 1))
assert torch.mul(x, a).shape == (3, 6, 2, 5)

# 5. Explicit double loop equivalent to broadcasting (m, 1) * (1, n) -> (m, n).
a = torch.rand((3, 1))
b = torch.rand((1, 6))
broadcast_mul = a * b
# flatten() (unlike squeeze()) always yields a 1-D tensor, so .tolist() stays
# iterable even when m or n is 1.
products = [i * j for i in a.flatten().tolist() for j in b.flatten().tolist()]
loop_mul = torch.as_tensor(products).reshape(a.numel(), b.numel())
# The product of two float32 values is exact as a Python float, so rounding it
# back to float32 gives the same result as the float32 multiply: exact equality
# (not just allclose) is expected.
matches = torch.equal(broadcast_mul, loop_mul)
assert matches
matches

True

In [20]:
'''matmul
Under broadcasting:
- 1. For tensors with more than two dims, every dim before the last two is a
     batch dim: it does not take part in the contraction, it only has to be
     broadcast-compatible between the two operands (equal, or one of them 1).
- 2. The last dim of the left operand (`shape[-1]`) must equal the
     second-to-last dim of the right operand (`shape[-2]`).
- 3. If the right operand is a 1-D vector, its only (last) dim is used instead.

Cases:
- 1. matrix @ matrix: left.shape[-1] == right.shape[-2]
- 2. vector @ matrix: left.shape[-1] == right.shape[-2]
- 3. matrix @ vector: left.shape[-1] == right.shape[-1]
- 4. vector @ vector: left.shape[-1] == right.shape[-1]

Caveat on matmul broadcasting (claim from the linked article, not verified here):
matmul expands broadcast batch dims by copying data, so a matmul that has to
broadcast can be slower than the equivalent einsum. The article recommends
broadcasting the operands explicitly yourself to avoid matmul's internal broadcast.
'''
# 1. vector @ matrix: a 1 is prepended to the vector's shape for the multiply
#    and removed from the result again.
vec = torch.rand(3)
mat = torch.rand((3, 5))
assert vec.matmul(mat).shape == (5,)

# 2. Batch dims broadcast: (2, 1) with (5,) -> (2, 5); the matrix dims
#    contract as (3, 4) @ (4, 6) -> (3, 6).
batched_left = torch.rand((2, 1, 3, 4))
batched_right = torch.rand((5, 4, 6))
assert torch.matmul(batched_left, batched_right).shape == (2, 5, 3, 6)

In [None]:
# Continues the previous cell: reuses `vec` (shape (3,)) and `mat` (shape (3, 5)).

# 3. matrix @ vector: (5, 3) @ (3,) -> (5,)
mat_vec = mat.t().matmul(vec)
assert mat_vec.shape == (5,)
# A^T v and v A hold the same values; allclose because the two calls may
# accumulate the sums in a different order.
assert torch.allclose(mat_vec, vec.matmul(mat))

# 1. matrix @ matrix: lift the vector to an explicit row (1, 3) or column (3, 1).
row = vec.unsqueeze(0)
assert row.shape == (1, 3)
assert row.matmul(mat).shape == (1, 5)

col = vec.unsqueeze(1)
assert col.shape == (3, 1)
assert mat.t().matmul(col).shape == (5, 1)
assert torch.allclose(mat.t().matmul(col).squeeze(1), mat_vec)

# 4. vector @ vector: the dot product, returned as a 0-dim tensor.
arr = torch.randn(3)
dot = vec.matmul(arr)
assert dot.shape == ()
assert torch.allclose(dot, torch.dot(vec, arr))