In [1]:
import torch
import torch.nn as nn
import numpy
import torch.nn.functional as F

# word embedding 以序列建模为例子
# 构建序列，序列的字符以其在词表中的索引的形式表示
# 考虑source sentence 和 target sentence


batch_size = 2
# 单词长度
max_src_num_words = 8
max_tgt_num_words = 8

# 序列最大长度
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5
model_dim = 8   # 512
# src_len = torch.randint(2, 5, (batch_size,))
# tgt_len = torch.randint(2, 5, (batch_size,))

# 原序列长度分别为2和4
# 目标序列长度为4和3
src_len = torch.Tensor([2, 4]).to(torch.int32)
tgt_len = torch.Tensor([4, 3]).to(torch.int32)


# step1: 生成原序列和目标序列
# 单词索引构成的句子(同时使用F.pad)
# 把每一个句子变成二维然后在第0维cat
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_src_num_words, (L,)), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_tgt_num_words, (L,)), (0, max(tgt_len)-L)), 0) for L in tgt_len])

#step2: 构造embedding(+1是因为padding的0不在字符表中)
src_embedding_table = nn.Embedding(max_src_num_words+1, model_dim)
tgt_embedding_table = nn.Embedding(max_tgt_num_words+1, model_dim)

src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# step3: 构建position embedding
post_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(post_mat/i_mat)
pe_embedding_table[:, 1::2] = torch.cos(post_mat/i_mat)
# print(pe_embedding_table)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
# print(pe_embedding.weight)

# 输出【【0，1，2，3】【0，1，2，3】】
src_pos = torch.cat([torch.unsqueeze(torch.arange(max(src_len)),0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max(tgt_len)),0) for _ in tgt_len]).to(torch.int32)
# print(src_pos)
# print(tgt_pos)
src_pe_emnedding = pe_embedding(src_pos)
tgt_pe_emnedding = pe_embedding(tgt_pos)
# print(src_pe_emnedding)
# print(tgt_pe_emnedding)

# step4: 构建encoder的self-attention mask (构建原因：不知道句子的具体长度，对于padding的部分需要进行掩码处理)
# mask 的 shape [batchsize, max_src_len, max_src_len]  值为1或-inf

valid_encoder_pos =torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0) for L in src_len ])
valid_encoder_pos = torch.unsqueeze(valid_encoder_pos,2)
print(valid_encoder_pos)
print(valid_encoder_pos.shape)

valid_encoder_pos_matrix =torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1,2))
print(valid_encoder_pos_matrix)

tensor([[[1.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.]]])
torch.Size([2, 4, 1])
tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])


由于第一个句子长度为2，所以第一个单词与第一第二个单词能建立联系，所以矩阵为1，1，0，0。其他行同理

In [2]:
print(src_len)

tensor([2, 4], dtype=torch.int32)


In [3]:
# 得到需要被mask的位置
invalid_encoder_pos_matrix = 1- valid_encoder_pos_matrix
mask_ecodner_selfattention_matrix = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_ecodner_selfattention_matrix)

tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]])


In [4]:
# 模拟得到的score
score = torch.randn(batch_size, max(src_len), max(src_len))
print(score)

tensor([[[-0.2359,  0.2242,  0.3788, -0.1779],
         [ 0.1949, -1.7529,  0.0588, -0.6717],
         [-0.0459,  1.8927, -0.2315, -0.4335],
         [-0.2519,  1.6211, -1.6212, -0.1008]],

        [[ 1.0812,  0.2933,  1.1187,  0.1584],
         [ 1.4733, -0.4585,  0.0456, -0.1305],
         [-1.2709, -0.5012,  0.9085,  0.2950],
         [-0.9198,  0.0420,  1.0917, -0.3686]]])


In [5]:
mask_score = score.masked_fill(mask_ecodner_selfattention_matrix, 1e-9)
print(mask_score)
prob = F.softmax(mask_score, -1)
print(prob)

tensor([[[-2.3590e-01,  2.2422e-01,  1.0000e-09,  1.0000e-09],
         [ 1.9492e-01, -1.7529e+00,  1.0000e-09,  1.0000e-09],
         [ 1.0000e-09,  1.0000e-09,  1.0000e-09,  1.0000e-09],
         [ 1.0000e-09,  1.0000e-09,  1.0000e-09,  1.0000e-09]],

        [[ 1.0812e+00,  2.9331e-01,  1.1187e+00,  1.5838e-01],
         [ 1.4733e+00, -4.5854e-01,  4.5614e-02, -1.3045e-01],
         [-1.2709e+00, -5.0124e-01,  9.0853e-01,  2.9504e-01],
         [-9.1982e-01,  4.2036e-02,  1.0917e+00, -3.6862e-01]]])
tensor([[[0.1955, 0.3096, 0.2475, 0.2475],
         [0.3586, 0.0511, 0.2951, 0.2951],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.3460, 0.1573, 0.3592, 0.1375],
         [0.6306, 0.0914, 0.1512, 0.1268],
         [0.0596, 0.1286, 0.5267, 0.2852],
         [0.0780, 0.2040, 0.5827, 0.1353]]])


In [9]:
# step5: intra_attention_mask
# Q*K^T  shape [batch_size, tgt_seq_len, src_seq_len]
valid_encoder_pos =torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0) for L in src_len ]), 2)
valid_decoder_pos =torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(tgt_len)-L)),0) for L in tgt_len ]), 2)
# 表明原序列和目标序列的相关性
valid_cross_pos =torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
print(valid_encoder_pos)
print(valid_decoder_pos)
print(valid_cross_pos)

tensor([[[1.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.]]])
tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [0.]]])
tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]])


In [23]:
invalid_cross_pos = 1-valid_cross_pos 
print(invalid_cross_pos)
mask_cross_attention = valid_cross_pos.to(torch.bool)
print(mask_cross_attention)

encoder_mask_score = score.masked_fill(mask_cross_attention, -1e9)
encode_prob = F.softmax(encoder_mask_score, -1)
print(encode_prob)

tensor([[[0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.],
         [0., 0., 1., 1.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 1., 1., 1.]]])
tensor([[[ True,  True, False, False],
         [ True,  True, False, False],
         [ True,  True, False, False],
         [ True,  True, False, False]],

        [[ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True],
         [False, False, False, False]]])
tensor([[[0.0000, 0.0000, 0.6357, 0.3643],
         [0.0000, 0.0000, 0.6749, 0.3251],
         [0.0000, 0.0000, 0.5503, 0.4497],
         [0.0000, 0.0000, 0.1794, 0.8206]],

        [[0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.0780, 0.2040, 0.5827, 0.1353]]])


In [24]:
# step6:decoder self-attention_mask构建
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones((L,L))),(0,max(tgt_len)-L,0,max(tgt_len)-L)),0)\
                                      for L in tgt_len])
print(valid_decoder_tri_matrix)

tensor([[[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 1.]],

        [[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [0., 0., 0., 0.]]])


In [25]:
invalid_decoder_tri_matrix = 1- valid_decoder_tri_matrix
invalid_decoder_tri_matrix = invalid_decoder_tri_matrix.to(torch.bool)
print(invalid_decoder_tri_matrix)

tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [ True,  True,  True,  True]]])


In [27]:
decoder_score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
decoder_mask_score = decoder_score.masked_fill(invalid_decoder_tri_matrix, -1e9)
decoder_mask_prob = F.softmax(decoder_mask_score,-1)
print(decoder_mask_score)
print(decoder_mask_prob)

tensor([[[ 9.3735e-02, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [ 5.9159e-01, -1.6989e+00, -1.0000e+09, -1.0000e+09],
         [-5.6981e-02, -5.7947e-01,  1.3748e+00, -1.0000e+09],
         [ 1.1662e-01,  1.0154e+00, -5.7117e-01,  4.5714e-01]],

        [[ 9.4912e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0390e+00,  2.1273e+00, -1.0000e+09, -1.0000e+09],
         [ 8.1318e-01,  7.0122e-04,  1.0762e+00, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.9081, 0.0919, 0.0000, 0.0000],
         [0.1730, 0.1026, 0.7244, 0.0000],
         [0.1864, 0.4579, 0.0937, 0.2620]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.0405, 0.9595, 0.0000, 0.0000],
         [0.3644, 0.1617, 0.4740, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])


In [None]:
# step7: 构建scaled self-attention
def scaled_dot_product_attention(Q,K,V，atten_mask):
    score = torch.bmm(Q,K.transpose(-2,-1))/torch.sqrt(model_dim)
    mask_score = score.masked_fill(atten_mask, -1e9)
    prob = F.softmax(mask_score,-1)
    context = torch.bmm(prob, V)
    return context