In [1]:
import torch
from torch import nn
import math

In [2]:
# Fix the global torch RNG seed so the random input and the dropout mask drawn below are reproducible.
torch.manual_seed(6)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    The encoding table is computed once for ``max_len`` positions and stored
    as a non-trainable buffer ``pe`` of shape ``[max_len, 1, d_model]``:

        pe[pos, 0, 2i]     = sin(pos / 10000^(2i / d_model))
        pe[pos, 0, 2i + 1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: size of the last (embedding) dimension; may be even or odd.
        dropout: dropout probability applied to ``x + pe``.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # exp(-2i * ln(10000) / d_model) == 1 / 10000^(2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # last frequency is dropped for the cosine half; without the slice
        # the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # A buffer follows the module across .to()/.cuda() and is saved in
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]

        Returns dropout(x + pe[:seq_len]); the size-1 batch axis of ``pe``
        broadcasts over batch_size.
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


# Demo input: random integers in [0, 100), laid out as [seq_len=15, batch_size=32, d_model=300].
x = torch.randint(100, (15, 32, 300))
pe = PositionalEncoding(300)
# The module uses its default dropout (p=0.1) here, so y_1 contains zeroed entries
# and the surviving entries are scaled by 1 / 0.9.
y_1 = pe(x)

print('-' * 10)
# Uncomment to inspect the raw input:
# x.shape
# print(x)
print(y_1)

----------
tensor([[[ 33.3333,  68.8889,   3.3333,  ...,   6.6667,  72.2222,  92.2222],
         [  0.0000,  62.2222, 105.5556,  ..., 108.8889,  18.8889,  26.6667],
         [ 36.6667,  40.0000,  42.2222,  ...,  17.7778, 106.6667,  25.5556],
         ...,
         [ 30.0000,  88.8889,  57.7778,  ..., 101.1111,  43.3333,  75.5556],
         [ 84.4445,  23.3333,  10.0000,  ...,   0.0000,  41.1111,  86.6667],
         [ 54.4444,  81.1111,  71.1111,  ...,  76.6667,  13.3333,  88.8889]],

        [[ 12.0461,   7.2670,  96.4531,  ...,   0.0000,   0.0000,  13.3333],
         [ 35.3794, 108.3781,   8.6754,  ...,   7.7778,  85.5557,  26.6667],
         [  5.3794,   0.0000,  40.8976,  ...,  18.8889,   0.0000,  10.0000],
         ...,
         [ 57.6016, 101.7115,  68.6754,  ...,  17.7778,  85.5557,   2.2222],
         [ 57.6016,   0.0000,  80.8976,  ...,   0.0000,  63.3335,  43.3333],
         [ 45.3794,  81.7115,   6.4531,  ...,  26.6667,  72.2223,   7.7778]],

        [[ 87.6770,  30.6487,  77

In [24]:
# Re-seed the global torch RNG so the dropout mask drawn in this cell is reproducible.
torch.manual_seed(6)


class PositionalEncoding2(nn.Module):
    """Sinusoidal positional encoding written with an explicit power divisor.

    Same table as the exp/log formulation, but the divisor is computed
    directly as ``10000^(2i / d_model)`` and positions are divided by it:

        pe[pos, 0, 2i]     = sin(pos / 10000^(2i / d_model))
        pe[pos, 0, 2i + 1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: size of the last (embedding) dimension; may be even or odd.
        dropout: dropout probability applied to ``x + pe``.
        max_len: number of positions to precompute.
        verbose: when True (the default), print the shape of the ``pe``
            buffer at construction time and on every forward pass. Pass
            False to silence the trace.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, verbose=True):
        super(PositionalEncoding2, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.verbose = verbose

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # Direct form of the divisor; mathematically equal to
        # 1 / exp(2i * (-ln(10000) / d_model)), differing only by float rounding.
        div_term = torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        angles = position / div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # last frequency is dropped for the cosine half; without the slice
        # the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        if self.verbose:
            print(pe.shape)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]

        Returns dropout(x + pe[:seq_len]); the size-1 batch axis of ``pe``
        broadcasts over batch_size.
        '''
        if self.verbose:
            print(self.pe.shape)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


# Demo sizes: 10 positions (used both as the table length and the sequence length),
# batch of 32, embedding width 512.
max_len = 10
batch_size = 32
d_model = 512
PE2 = PositionalEncoding2(d_model, max_len=max_len)
# All-zero input, so apart from dropout the output is just the positional table
# broadcast over the batch. NOTE: this rebinds `x` from the earlier cell.
x = torch.zeros(max_len, batch_size, d_model)
y = PE2(x)
y.shape

torch.Size([10, 1, 512])
torch.Size([10, 1, 512])


torch.Size([10, 32, 512])

In [4]:
# NOTE(review): `y_2` is not assigned in any cell visible here (the cell above stores its
# result in `y`), so this appears to rely on leftover kernel state and would raise
# NameError after Restart & Run All — confirm where `y_2` is meant to come from.
y_2 - y_1

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.3333e+01,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -4.0000e+01,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  5.5556e+01,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  9.8889e+01,
           8.3333e+01,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  9.3934e+01,  0.0000e+00,  ...,  0.0000e+00,
           7.2222e+01,  0.0000e+00],
         ...,
         [-5.7602e+01,  0

In [5]:
# Compare the two ways of computing the frequency divisor 10000^(2i / d_model).
d_model = 100

# Direct power form, as used by PositionalEncoding2.
div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)
print(div_term.shape)
# Reciprocal of the exp/log form used by PositionalEncoding.
div_term2 = 1 / torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
print(div_term2.shape)

# The two are mathematically equal; non-zero entries are float32 rounding differences.
print(div_term - div_term2)

torch.Size([50])
torch.Size([50])
tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.3842e-07,
        -2.3842e-07,  0.0000e+00,  0.0000e+00, -4.7684e-07,  0.0000e+00,
        -4.7684e-07,  0.0000e+00,  0.0000e+00, -1.9073e-06,  0.0000e+00,
        -9.5367e-07, -3.8147e-06,  0.0000e+00, -1.9073e-06, -3.8147e-06,
        -7.6294e-06, -1.5259e-05, -1.1444e-05, -7.6294e-06,  0.0000e+00,
         0.0000e+00, -3.8147e-05, -3.0518e-05,  0.0000e+00, -7.6294e-05,
        -3.0518e-05,  0.0000e+00, -1.2207e-04, -6.1035e-05,  0.0000e+00,
        -1.8311e-04, -1.2207e-04,  6.1035e-05, -2.4414e-04, -8.5449e-04,
        -7.3242e-04, -4.8828e-04, -1.2207e-03, -9.7656e-04, -4.8828e-04,
        -2.1973e-03, -1.4648e-03, -3.4180e-03, -4.8828e-04,  1.9531e-03])


# Mask

由于 Encoder 和 Decoder 中都需要进行 padding mask 操作，下面的 `get_attn_pad_mask` 函数无法预先确定参数中 seq_len 的值：如果是在 Encoder 中调用，seq_len 就等于 src_len；如果是在 Decoder 中调用，seq_len 既可能等于 src_len，也可能等于 tgt_len（因为 Decoder 有两次 mask）

这个函数最核心的一句代码是 seq_k.data.eq(0)，这句的作用是返回一个大小和 seq_k 一样的 tensor，只不过里面的值只有 True 和 False。如果 seq_k 某个位置的值等于 0，那么对应位置就是 True，否则即为 False。举个例子，输入为 seq_data = [1, 2, 3, 4, 0]，seq_data.data.eq(0) 就会返回 [False, False, False, False, True]

剩下的代码主要是扩展维度，强烈建议读者打印出来，看看最终返回的数据是什么样子

In [6]:
import numpy as np


def get_attn_pad_mask(seq_q, seq_k):
    '''
    Build a boolean mask that hides PAD positions of the keys from every query.

    seq_q: [batch_size, len_q]  (only its length is used)
    seq_k: [batch_size, len_k]  (token ids; id 0 is treated as PAD)
    len_q and len_k may differ (src_len vs. tgt_len).

    returns: bool tensor [batch_size, len_q, len_k], True where the key is PAD
    '''
    _, len_q = seq_q.shape
    batch_size, len_k = seq_k.shape
    # True wherever the key token is PAD (id 0) -> [batch_size, len_k]
    is_pad = seq_k.data.eq(0)
    # Insert a query axis and repeat it (as a view) for each of the len_q queries.
    return is_pad[:, None, :].expand(batch_size, len_q, len_k)

In [7]:
# Random token ids in [0, 100), laid out as [batch_size=32, seq_len=15].
seq_k = torch.randint(100, (32, 15))
# True where the id equals 0; unsqueeze(1) inserts a size-1 axis -> [32, 1, 15].
seq_k.eq(0).unsqueeze(1)
# Variant going through .data, as written in get_attn_pad_mask:
# seq_k.data.eq(0).unsqueeze(1)

tensor([[[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False,  True, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False]],

        [[False, False, False, False, False, False, False, False, False, False,
          False, False, 

In [8]:
def get_attn_subsequence_mask(seq):
    '''
    Build the "no looking ahead" mask over target positions.

    seq: [batch_size, tgt_len]

    returns: uint8 tensor [batch_size, tgt_len, tgt_len] on the same device
    as seq; entry [b, i, j] is 1 when j > i (a later position, to be masked)
    and 0 otherwise.
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Build the mask directly in torch on seq's device. This avoids a float64
    # NumPy temporary and a result pinned to the CPU, so the mask can be
    # combined with tensors on seq's device without a manual transfer.
    ones = torch.ones(attn_shape, dtype=torch.uint8, device=seq.device)
    # triu works on the last two dims of a batch; diagonal=1 keeps only the
    # entries strictly above the main diagonal.
    subsequence_mask = torch.triu(ones, diagonal=1)  # Upper triangular matrix
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

In [9]:
from torch import nn

# Evaluate the name so the notebook displays where PyTorch's built-in TransformerDecoder class lives.
nn.TransformerDecoder

torch.nn.modules.transformer.TransformerDecoder