In [3]:
import torch
import torch.nn as nn
import math


class Encoder(nn.Module):
    '''
    One post-LayerNorm Transformer encoder layer.

    A multi-head self-attention sub-layer is followed by a position-wise
    feed-forward sub-layer. Each sub-layer's output goes through dropout, a
    residual connection and LayerNorm.

    :param hidden_dim: model width D (size of the last input dimension).
    :param num_heads: number of attention heads. Each head has
        hidden_dim // num_heads dimensions. If hidden_dim is not divisible by
        num_heads, Q/K/V are projected to num_heads * (hidden_dim // num_heads)
        features and W_reshape_back maps the result back to hidden_dim.
    :param att_drop_prob: dropout probability applied to the attention weights.
    :param state_drop_prob: dropout probability applied to each sub-layer
        output before the residual connection.
    :raises ValueError: if num_heads < 1 or hidden_dim < num_heads (heads
        would have zero dimensions).
    '''

    def __init__(self, hidden_dim, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5):
        super().__init__()
        if num_heads < 1 or hidden_dim < num_heads:
            raise ValueError(
                f"need num_heads >= 1 and hidden_dim >= num_heads, "
                f"got hidden_dim={hidden_dim}, num_heads={num_heads}"
            )
        self.dim = hidden_dim

        # multi-head self-attention
        self.num_heads = num_heads
        self.dim_per_head = self.dim // self.num_heads  # 64 for the default 768 / 12
        inner_dim = self.num_heads * self.dim_per_head
        # Submodules are created in a fixed order so that state_dict keys and
        # seeded random initialization stay stable.
        self.Wq = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wk = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wv = nn.Linear(self.dim, inner_dim, bias=False)
        self.layerNorm_SA = nn.LayerNorm(self.dim)
        self.W_reshape_back = nn.Linear(inner_dim, self.dim)

        # feed-forward sub-layer (D -> 4D -> D)
        self.ffn1 = nn.Linear(self.dim, self.dim * 4)
        self.ffn2 = nn.Linear(self.dim * 4, self.dim)
        self.act = nn.GELU()
        self.layerNorm_ffn = nn.LayerNorm(self.dim)

        # dropout
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def compute_mask_logits(self, attention_mask):
        '''
        Convert a padding mask into additive attention logits.

        :param attention_mask: (N,L); 1 for normal tokens, 0 for padding tokens.
        :return: (N,#heads,L,L). For a 0/1 mask, the entry is 0 where the key
            position is a normal token and -10000 where it is padding.
        '''
        batch, seq_len = attention_mask.size(0), attention_mask.size(1)
        # Allocate on the mask's device so the result can be added to
        # attention logits living on the same (possibly CUDA) device.
        mask_logits = torch.zeros(
            batch, self.num_heads, seq_len, seq_len, device=attention_mask.device
        )
        # Broadcast the per-key mask over heads and query positions.
        mask_logits = mask_logits + attention_mask[:, None, None, :]
        # 1 -> 0 (keep), 0 -> -10000 (suppressed by the softmax).
        mask_logits = (1.0 - mask_logits) * -10000.
        return mask_logits

    def MultiHeadSelfAttention(self, x, attention_mask=None):
        '''
        Self-attention sub-layer: multi-head attention, output projection,
        dropout, residual connection and LayerNorm.

        :param x: (N,L,D)
        :param attention_mask: (N,L), or None to let every position attend to
            every other position.
            1: normal token
            0: padding token
        :return: (N,L,D)
        '''
        # Project and split into heads:
        # (N,L,#heads*dph) -> (N,L,#heads,dph) -> (N,#heads,L,dph)
        new_size = x.size()[:-1] + (self.num_heads, self.dim_per_head)
        Q = self.Wq(x).view(*new_size).permute(0, 2, 1, 3)
        K = self.Wk(x).view(*new_size).permute(0, 2, 1, 3)
        V = self.Wv(x).view(*new_size).permute(0, 2, 1, 3)

        # Scaled dot-product attention. Scale by the square root of the
        # per-head dimension (the length of the vectors being dotted), not of
        # the full model width.
        attention_logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim_per_head)
        if attention_mask is not None:
            # Masking: before the softmax, add a large negative number to the
            # logits of padding keys, so their attention probability becomes
            # (numerically) zero.
            attention_logits += self.compute_mask_logits(attention_mask)
        attention_score = torch.softmax(attention_logits, dim=-1)
        attention_score = self.att_drop(attention_score)

        O = torch.matmul(attention_score, V)  # (N,#heads,L,dph)
        O = O.permute(0, 2, 1, 3)  # (N,L,#heads,dph)
        O = O.contiguous().view(x.size(0), x.size(1), -1)  # (N,L,#heads*dph)
        O = self.W_reshape_back(O)  # (N,L,D)
        O = self.state_drop(O)
        return self.layerNorm_SA(x + O)

    def FFN(self, x):
        '''
        Position-wise feed-forward sub-layer: Linear(D->4D), GELU,
        Linear(4D->D), dropout, residual connection and LayerNorm.

        :param x: (N,L,D)
        :return: (N,L,D)
        '''
        hidden = self.act(self.ffn1(x))
        output = self.state_drop(self.ffn2(hidden))
        return self.layerNorm_ffn(x + output)

    def forward(self, x, attention_mask=None):
        '''
        Run the self-attention sub-layer and then the feed-forward sub-layer.

        :param x: shape (N,L,D). N is the batch size, L is the length of the
            sequence, D is the dimension of the word embeddings.
        :param attention_mask: (N,L) with 1 for normal tokens and 0 for padding
            tokens, or None for no masking.
        :return: shape (N,L,D)
        '''
        x = self.MultiHeadSelfAttention(x, attention_mask)
        x = self.FFN(x)
        return x

In [4]:
# Smoke test: build one encoder layer with hidden size 768 and run it on a
# batch of 2 all-zero sequences of length 256. The mask is all ones, so no
# position is treated as padding.
# NOTE: the module is left in training mode (no .eval() call), so dropout is
# active; the printed values also depend on the random weight initialization.
layer = Encoder(768)
x = torch.zeros(2, 256, 768)
mask = torch.ones(2, 256)
print(layer(x, mask))  # output tensor of shape (2, 256, 768)

tensor([[[ 0.1296,  0.2674,  0.3411,  ...,  0.2935,  1.8878,  0.0240],
         [-0.8150, -0.4709, -0.0282,  ..., -0.1764, -0.0612,  0.0415],
         [ 0.0180,  0.1580, -0.0664,  ..., -0.0574,  2.3035,  0.0326],
         ...,
         [-0.4061, -0.0716, -0.0716,  ..., -0.4664,  0.2013,  0.2898],
         [ 0.1330, -0.7335, -0.1249,  ...,  0.0414,  0.6644,  0.3684],
         [ 0.1406,  0.1403, -0.1040,  ..., -0.6165, -0.1040,  0.0038]],

        [[-0.1078, -0.1078, -0.0817,  ..., -0.1078,  0.6668, -0.2443],
         [-0.3618, -0.0478,  0.3415,  ..., -0.1585,  1.7119, -0.1204],
         [-0.7267, -0.0633, -0.0876,  ...,  0.2774,  0.0530,  0.1673],
         ...,
         [-0.0402, -0.2177,  1.0245,  ...,  0.3415, -0.1628, -0.0402],
         [-0.8310, -0.0564,  0.2999,  ...,  0.3284, -0.0564, -0.0564],
         [-0.8092, -0.0789, -0.0469,  ..., -0.0789, -0.0789, -0.0940]]],
       grad_fn=<NativeLayerNormBackward0>)
