In [8]:
import torch.nn as nn
import torch

In [9]:
# Demo configuration. Named constants replace the literals that used to be
# repeated across the constructor, the mask and the input tensors, so the
# shapes below cannot drift apart when one of them is changed.
SEED = 0  # fixes the random parameter initialisation and the random inputs
EMBED_DIM = 200  # E_q; must be divisible by NUM_HEADS
NUM_HEADS = 5  # number of attention heads
KEY_DIM = 100  # E_k
VALUE_DIM = 50  # E_v
BATCH_SIZE = 8  # N
# Number of source positions each target position may attend to. There is one
# entry per target position, so the list length is the target length L.
VALID_SOURCE_LENGTHS = [2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4]
TGT_LEN = len(VALID_SOURCE_LENGTHS)  # L = 14
SRC_LEN = 4  # S

# Seed before building the layer: both the layer's initial parameters and the
# torch.randn inputs below draw from the global generator.
torch.manual_seed(SEED)

multihead_attn = nn.MultiheadAttention(
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    kdim=KEY_DIM,  # default is kdim=None, i.e. kdim = embed_dim
    vdim=VALUE_DIM,  # default is vdim=None, i.e. vdim = embed_dim
    dropout=0.1,  # dropout probability on `attn_output_weights`
)

# mask.shape=(L, S): True where the source index is below the valid length.
mask = torch.arange(SRC_LEN)[None, :] < torch.tensor(VALID_SOURCE_LENGTHS)[:, None]
# mask.shape=(1, L, S)
mask = torch.unsqueeze(mask, 0)
# mask.shape=(N * num_heads, L, S); the same mask is repeated for every slice.
mask = torch.repeat_interleave(mask, BATCH_SIZE * NUM_HEADS, dim=0)
# attn_mask uses True for positions that are NOT allowed to attend, so the
# "valid position" mask built above has to be inverted.
mask = ~mask

query = torch.randn(TGT_LEN, BATCH_SIZE, EMBED_DIM)  # (L, N, E_q); L is the target sequence length
key = torch.randn(SRC_LEN, BATCH_SIZE, KEY_DIM)  # (S, N, E_k); S is the source sequence length
value = torch.randn(SRC_LEN, BATCH_SIZE, VALUE_DIM)  # (S, N, E_v)

In [10]:
# Run the attention layer on the prepared inputs.
# attn_mask may have shape (N * num_heads, L, S) or (L, S); a True entry means
# the corresponding position is not allowed to attend.
attn_output, attn_output_weights = multihead_attn(query, key, value, attn_mask=mask)

In [11]:
# IMPORTANT: PyTorch returns the attention weights averaged over the heads
# (TensorFlow returns the weights of every head).
# NOTE(review): newer PyTorch releases expose `average_attn_weights=False` on
# forward() to get per-head weights -- confirm against the installed version.
#
# Internal mechanism (presumably an excerpt of the PyTorch implementation in
# use when this notebook was written):
#     if need_weights:
#         # average attention weights over heads
#         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
#         return attn_output, attn_output_weights.sum(dim=1) / num_heads
#     else:
#         return attn_output, None
#
# attn_output.shape=(L, N, E_q)
attn_output.shape

torch.Size([14, 8, 200])

In [12]:
# The returned (head-averaged) weights have shape (N, L, S).
attn_output_weights.size()

torch.Size([8, 14, 4])

In [13]:
# Show the first (L, S) slice of the mask next to the first (L, S) slice of the
# returned weights: positions masked with True receive a weight of 0.
# NOTE(review): the rows need not sum to 1 -- presumably because the module is
# still in training mode, so dropout is applied to the weights; confirm.
print(mask[0], attn_output_weights[0], sep="\n")

tensor([[False, False,  True,  True],
        [False, False,  True,  True],
        [False, False,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False,  True],
        [False, False, False,  True],
        [False, False, False,  True],
        [False, False, False,  True],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])
tensor([[0.5695, 0.3725, 0.0000, 0.0000],
        [0.4529, 0.4285, 0.0000, 0.0000],
        [0.6418, 0.4693, 0.0000, 0.0000],
        [0.3648, 0.7463, 0.0000, 0.0000],
        [0.3560, 0.3851, 0.3700, 0.0000],
        [0.2858, 0.1156, 0.5682, 0.0000],
        [0.3019, 0.4454, 0.3176, 0.0000],
        [0.3328, 0.4049, 0.3734, 0.0000],
        [0.4084, 0.2755, 0.4273, 0.0000],
        [0.3450, 0.1566, 0.1948, 0.1429],
        [0.4138, 0.1709, 0.3063, 0.1878],
     

In [14]:
# List every learnable parameter of the attention layer together with its shape.
for param_name, param in multihead_attn.named_parameters():
    print(f"{param_name}.shape=", param.shape)

q_proj_weight.shape= torch.Size([200, 200])
k_proj_weight.shape= torch.Size([200, 100])
v_proj_weight.shape= torch.Size([200, 50])
in_proj_bias.shape= torch.Size([600])
out_proj.weight.shape= torch.Size([200, 200])
out_proj.bias.shape= torch.Size([200])
