In [1]:
import torch

In [2]:
# Element-wise addition implemented with an explicit Python-level loop
def sum_with_for(x, y):
    """Add ``x`` and ``y`` pair by pair in pure Python.

    This is the deliberately slow baseline that the notebook times against
    the vectorized ``x + y``. Pairs are formed with ``zip``, so iteration
    stops at the end of the shorter input.

    Returns:
        A new ``torch.Tensor`` built from the list of per-pair sums.
    """
    pairwise_sums = [lhs + rhs for lhs, rhs in zip(x, y)]
    return torch.tensor(pairwise_sums)

In [3]:
# Two random 1-D operands (100 elements each) used as inputs for the timing comparison.
x, y = torch.randn(100), torch.randn(100)

In [4]:
# Time both implementations with 100 loops per run.
%timeit -n 100 sum_with_for(x, y)  # Python for loop
%timeit -n 100 (x + y) # vectorized computation

In [16]:
# Broadcasting: a is (3, 2) and b is (2, 3, 1). Shapes are aligned from the
# trailing dimension, so the sum has shape (2, 3, 2).
a, b = torch.ones(3, 2), torch.zeros(2, 3, 1)

(a + b).shape

In [21]:
# Integer tensors of shape (3, 2) and (2, 3, 1) with distinct values, so the
# effect of broadcasting is visible in the element values of the sum.
a = torch.tensor(
    [
        [1, 2],
        [3, 4],
        [5, 6],
    ]
)
b = torch.tensor(
    [
        [[-1], [-2], [-3]],
        [[-4], [-5], [-6]],
    ]
)
a.shape, b.shape

In [24]:
# Print the broadcast sum first, then each operand, each starting on its own line.
print(a + b, a, b, sep="\n")

In [15]:
# Compare the storage footprint of expand versus repeat.
# NOTE(review): recent PyTorch releases deprecate Tensor.storage() in favor of
# Tensor.untyped_storage() - confirm against the installed version.
a = torch.ones(1, 3)
print(a.storage().size())

# expand allocates no extra memory: it only returns a new view of a
b = a.expand(3, 3)
print(b.storage().size())

# repeat copies the original data: tiling the (1, 3) tensor 3 x 3 times
# produces a (3, 9) tensor backed by its own storage
c = a.repeat(3, 3)
print(c.storage().size())

In [25]:
c  # copied outright by repeat, which increases the memory wasted

In [29]:
# Manual broadcasting: three equivalent ways to lift a from (3, 2) to (1, 3, 2)
# before expanding both operands to the common shape (2, 3, 2).
a = torch.ones(3, 2)
b = torch.zeros(2, 3, 1)
b_expanded = b.expand(2, 3, 2)

# 1. unsqueeze + expand
sum_unsqueeze = a.unsqueeze(0).expand(2, 3, 2) + b_expanded

# 2. view + expand
sum_view = a.view(1, 3, 2).expand(2, 3, 2) + b_expanded

# 3. None indexing + expand (recommended)
sum_none = a[None, :, :].expand(2, 3, 2) + b_expanded

# All three variants are executed (rather than left as commented-out code) and
# must agree with each other and with PyTorch's automatic broadcasting.
assert torch.equal(sum_unsqueeze, sum_view)
assert torch.equal(sum_view, sum_none)
assert torch.equal(sum_none, a + b)

sum_none

In [32]:
# Values 0..23 as floats, arranged as a (2, 3, 4) tensor for the indexing examples.
# torch.arange with a floating-point end yields the default floating dtype, so the
# result matches the legacy torch.Tensor(list) constructor without first building
# a Python list.
a = torch.arange(24.0).view(2, 3, 4)
a

In [33]:
# Pick the single element at position [0, 1, 2]
# Equivalent to a[(0, 1, 2)]
a[0, 1, 2]

In [34]:
# Take every element along the third dimension
# Equivalent to a[(1, 1)], a[(1, 1, )], a[1, 1]
a[1, 1, :]

In [35]:
# Slicing with : and ...
a = torch.rand(64, 3, 224, 224)
print(a[:, :, 0:224:4, :].shape)  # dims 1, 2 and 4 are taken whole; dim 3 takes every 4th index from 0 up to 223
# Omitting start and end covers the whole dimension
print(a[:, :, ::4, :].shape)  # dims 1, 2 and 4 are taken whole; dim 3 takes every 4th index from start to end

In [36]:
# Use ... in place of one or more dimensions; it is best used only once per index
a[..., ::4, :].shape  # the first two dimensions are taken whole: ... replaces ":, :,"

In [38]:
a[..., ::4, ...].shape  # if the last dimension is also written as ..., dimension matching becomes ambiguous and goes wrong

In [42]:
# Adding dimensions with None
x = torch.randn(3, 224, 224)
x.shape

In [43]:
print(x.unsqueeze(0).shape)  # unsqueeze inserts a new dimension at position 0
print(x[None, ...].shape)  # None at index 0 inserts the dimension directly (... stands for all remaining dimensions)

In [50]:
# Random tensor of shape (3, 3, 3); the cell displays its shape
x = torch.randn(3, 3, 3)
x.shape

In [51]:
# Goal: turn x into [1, 3, 1, 3, 1, 3] by inserting size-1 axes one at a time.
# The per-step shapes below assume x enters this cell as a (3, 3, 3) tensor.
x = (
    x.unsqueeze(0)  # [1, 3, 3, 3]
    .unsqueeze(2)  # [1, 3, 1, 3, 3]
    .unsqueeze(4)  # [1, 3, 1, 3, 1, 3]
)
x.shape

In [52]:
# Build a [1, 3, 1, 3, 1, 3] tensor in a single step: every None in the index
# inserts a size-1 axis, every ":" keeps one of the original (3, 3, 3) axes.
x = torch.randn(3, 3, 3)[None, :, None, :, None, :]
x.shape

In [53]:
# 综合使用 None 和广播机制

In [66]:
# Assume batch_size is 16 and features is 256; naming the two sizes keeps the
# arange length and the view shape in sync.
batch_size, features = 16, 256
a = torch.arange(batch_size * features).view(batch_size, features)
a.shape

In [71]:
# b gains a new axis at position 1; c is b with its last two axes swapped.
# The shape notes assume a is the (16, 256) tensor from the preceding cell.
b = a.unsqueeze(1)  # b.shape = [16, 1, 256]
c = b.transpose(1, 2)  # c.shape = [16, 256, 1]
# Matrix products in both orders; one shape is printed per line.
print((b @ c).shape, (c @ b).shape, sep="\n")

In [72]:
# Insert a size-1 axis with None: in the middle for b, at the end for c.
b, c = a[:, None, :], a[:, :, None]
# One shape is printed per line.
print(b.shape, c.shape, sep="\n")

In [73]:
# Element-wise product of b and c in both operand orders; one shape per line.
print((b * c).shape, (c * b).shape, sep="\n")

In [76]:
# Supplement: element-wise computation
a = torch.arange(16 * 256).view(16, 256)
a.shape

In [77]:
# Multiply a by itself with the * operator and display the result's shape
(a * a).shape