In [1]:
import torch
import torch.nn as nn

In [2]:
# nn.EmbeddingBag shares nn.Embedding's core constructor arguments
# (num_embeddings, embedding_dim, ...) and adds `mode`, which reduces each bag
# of looked-up embeddings to a single vector ('sum', 'mean' or 'max').
torch.manual_seed(0)  # fix the randomly initialised weights so Restart & Run All is reproducible
embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')

# With a 1D `input`, all bags are concatenated into one tensor and `offsets`
# gives the starting index of each bag (sequence) in `input`. Per the PyTorch
# docs, `offsets` is only used when `input` is 1D.
# Here the bags are input[0:4] = [1, 2, 4, 5] and input[4:] = [4, 3, 2, 9].
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
offsets = torch.tensor([0, 4], dtype=torch.long)

embedding_sum(input=input, offsets=offsets)  # one summed vector per bag -> shape (2, 3)

tensor([[-3.5266,  0.0578,  2.4299],
        [-1.5867,  1.5058, -0.0771]], grad_fn=<EmbeddingBagBackward0>)

In [3]:
# Equivalent to the cell above: with a 2D input of shape (num_bags, bag_size),
# each row is treated as one fixed-length bag, so no `offsets` are needed.
input_1 = torch.tensor(
    [
        [1, 2, 4, 5],  # bag 0
        [4, 3, 2, 9],  # bag 1
    ]
)
embedding_sum(input=input_1)

tensor([[-3.5266,  0.0578,  2.4299],
        [-1.5867,  1.5058, -0.0771]], grad_fn=<EmbeddingBagBackward0>)

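* With `mode="sum"`, EmbeddingBag is equivalent to `Embedding` followed by `torch.sum(dim=1)`;

* with `mode="mean"`, it is equivalent to `Embedding` followed by `torch.mean(dim=1)`;

* with `mode="max"`, it is equivalent to `Embedding` followed by `torch.max(dim=1)`, taking the `.values` of the result (`torch.max` with a `dim` returns a `(values, indices)` pair).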

In [7]:
# Compare the three EmbeddingBag reduction modes against a plain Embedding
# lookup on the same pretrained weights (2 rows, embedding size 3).
weight = torch.tensor([[1, 2.3, 3],
                       [4, 5.1, 6.3]])
input = torch.tensor([[1, 0]])  # a single bag containing rows 1 and 0

embeddingbag_mean = nn.EmbeddingBag.from_pretrained(weight, mode="mean")(input)  # "mean" is also the default mode
embeddingbag_sum = nn.EmbeddingBag.from_pretrained(weight, mode="sum")(input)
embeddingbag_max = nn.EmbeddingBag.from_pretrained(weight, mode="max")(input)
embedding = nn.Embedding.from_pretrained(weight)(input)  # per-index lookup, no reduction

# Each EmbeddingBag mode equals the Embedding lookup reduced over the bag dimension (dim=1).
assert torch.allclose(embeddingbag_sum, embedding.sum(dim=1))
assert torch.allclose(embeddingbag_mean, embedding.mean(dim=1))
assert torch.allclose(embeddingbag_max, embedding.max(dim=1).values)

print(embeddingbag_mean)
print(embeddingbag_mean.shape)
print(embeddingbag_sum)
print(embeddingbag_max)
print(embedding)
print(torch.sum(embedding, dim=1))
print(embedding.shape)

tensor([[2.5000, 3.7000, 4.6500]])
torch.Size([1, 3])
tensor([[5.0000, 7.4000, 9.3000]])
tensor([[4.0000, 5.1000, 6.3000]])
tensor([[[4.0000, 5.1000, 6.3000],
         [1.0000, 2.3000, 3.0000]]])
tensor([[5.0000, 7.4000, 9.3000]])
torch.Size([1, 2, 3])
