In [5]:
import torch
import torch.nn as nn
import math


class EncoderLayer(nn.Module):
    """One Transformer encoder layer: multi-head self-attention followed by an FFN.

    Both sub-layers use the post-LayerNorm residual form
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    :param hidden_dim: model width D of the input/output token representations.
    :param num_heads: number of attention heads (default 12, as in the original code).
    :param att_drop_prob: dropout probability applied to the attention weights.
    :param state_drop_prob: dropout probability applied to each sub-layer output
        before the residual connection (default kept from the original code).
    :raises ValueError: if ``num_heads`` is not positive or ``hidden_dim`` is
        smaller than ``num_heads`` (each head would have zero width).
    """

    def __init__(self, hidden_dim, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5):
        super().__init__()
        if num_heads <= 0 or hidden_dim < num_heads:
            raise ValueError(
                "hidden_dim (%d) must be >= num_heads (%d) and num_heads must be positive"
                % (hidden_dim, num_heads)
            )
        self.dim = hidden_dim

        # multi-head SA
        # NOTE: submodules are created in the same order as before so that a
        # fixed random seed yields the same parameter initialisation.
        self.num_heads = num_heads
        self.dim_per_head = self.dim // self.num_heads  # e.g. 768 // 12 = 64
        inner_dim = self.num_heads * self.dim_per_head
        self.Wq = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wk = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wv = nn.Linear(self.dim, inner_dim, bias=False)
        self.layerNorm_SA = nn.LayerNorm(self.dim)
        # Projects the concatenated heads back to the model width.
        self.W_reshape_back = nn.Linear(inner_dim, self.dim)

        # FFN layer
        self.ffn1 = nn.Linear(self.dim, self.dim * 4)
        self.ffn2 = nn.Linear(self.dim * 4, self.dim)
        self.act = nn.GELU()
        self.layerNorm_ffn = nn.LayerNorm(self.dim)

        # dropout
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def compute_mask_logits(self, attention_mask):
        """Turn a padding mask into an additive bias for the attention logits.

        :param attention_mask: (N, L); 1 marks a normal token, 0 a padding token.
        :return: (N, #heads, L, L) tensor on the same device as ``attention_mask``;
            0 where the key position is a normal token and -10000 where it is
            padding. The result is a broadcast (expanded) view, not a fresh
            allocation, so it must not be modified in place.
        """
        if not torch.is_floating_point(attention_mask):
            # Integer / bool masks are promoted to the default float dtype,
            # matching what adding them to a float tensor would produce.
            attention_mask = attention_mask.to(torch.get_default_dtype())
        seq_len = attention_mask.size(1)
        # Built from the mask itself, so the bias lives on the mask's device.
        key_bias = (1.0 - attention_mask[:, None, None, :]) * -10000.
        return key_bias.expand(-1, self.num_heads, seq_len, -1)

    def MultiHeadSelfAttention(self, x, attention_mask):
        """Multi-head self-attention sub-layer with residual connection and LayerNorm.

        :param x: (N, L, D)
        :param attention_mask: (N, L)
            1: normal token
            0: padding token
        :return: (N, L, D)
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Q, K, V: (N, L, #heads * dph) -> (N, L, #heads, dph) -> (N, #heads, L, dph)
        head_shape = (batch_size, seq_len, self.num_heads, self.dim_per_head)
        Q = self.Wq(x).view(head_shape).permute(0, 2, 1, 3)
        K = self.Wk(x).view(head_shape).permute(0, 2, 1, 3)
        V = self.Wv(x).view(head_shape).permute(0, 2, 1, 3)

        # Scaled dot-product attention: scale by sqrt of the per-head key
        # dimension (not the full model width).
        attention_logits = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.dim_per_head)

        # Attention mask: after computing the logits and before the softmax,
        # subtract a very large positive number from the logits of padding
        # keys (equivalently, add a very large negative number, the
        # "mask logits") so that their probability after softmax is ~0.
        mask_logits = self.compute_mask_logits(attention_mask)
        attention_logits = attention_logits + mask_logits.to(
            device=attention_logits.device, dtype=attention_logits.dtype
        )

        attention_score = torch.softmax(attention_logits, dim=-1)
        attention_score = self.att_drop(attention_score)

        O = torch.matmul(attention_score, V)  # (N, #heads, L, dph)
        O = O.permute(0, 2, 1, 3)  # (N, L, #heads, dph)
        O = O.contiguous().view(batch_size, seq_len, -1)  # (N, L, #heads * dph)
        O = self.W_reshape_back(O)  # (N, L, D)
        O = self.state_drop(O)
        O = self.layerNorm_SA(x + O)
        return O

    def FFN(self, x):
        """Position-wise feed-forward sub-layer with residual connection and LayerNorm.

        :param x: (N, L, D)
        :return: (N, L, D)
        """
        tmp1 = self.act(self.ffn1(x))  # (N, L, 4 * D)
        tmp2 = self.ffn2(tmp1)  # (N, L, D)
        output = self.state_drop(tmp2)
        output = self.layerNorm_ffn(x + output)
        return output

    def forward(self, x, attention_mask):
        """Apply self-attention and then the feed-forward network.

        :param x: shape (N, L, D); N is the batch size, L is the sequence
            length, D is the dimension of the word embeddings
        :param attention_mask: shape (N, L); 1 for normal tokens, 0 for padding
        :return: shape (N, L, D)
        """
        x = self.MultiHeadSelfAttention(x, attention_mask)
        x = self.FFN(x)

        return x

In [6]:
# Smoke test: push an all-zero batch (2 sequences of 256 tokens, width 768)
# with no padding through a freshly initialised encoder layer and print the
# result. The layer is left in training mode, so dropout is active.
layer = EncoderLayer(hidden_dim=768)
x = torch.zeros((2, 256, 768))
mask = torch.ones((2, 256))
encoded = layer(x, mask)
print(encoded)

tensor([[[-1.3027e-01, -2.9429e-01,  9.9336e-03,  ..., -1.8524e+00,
           1.7145e-01, -1.3801e+00],
         [-1.6916e-01,  8.8574e-03,  8.8574e-03,  ...,  8.8574e-03,
           2.3673e-02,  8.8574e-03],
         [ 1.3826e+00,  3.3214e-01, -2.2303e-03,  ..., -2.8271e-01,
           4.3603e-03, -1.3091e+00],
         ...,
         [ 1.7313e+00, -2.6134e-01,  6.2458e-01,  ..., -1.4259e+00,
           1.7066e-01,  5.4035e-01],
         [ 1.7454e+00, -2.9977e-02, -5.7963e-01,  ...,  1.2735e-01,
          -2.0289e-01, -2.9977e-02],
         [ 1.5966e+00, -2.8424e-02,  3.8446e-01,  ..., -1.4197e+00,
          -1.8418e-02, -1.3079e+00]],

        [[-1.2251e-03,  1.3571e-01,  2.5282e-01,  ...,  1.3117e-01,
           1.5548e-01, -1.2251e-03],
         [ 1.9450e+00,  5.9076e-01,  2.8833e-01,  ...,  5.6149e-01,
           1.8793e-01,  7.3232e-02],
         [ 2.1371e-02, -3.0977e-01,  2.6564e-01,  ..., -5.0757e-01,
           4.3785e-01,  1.6523e-01],
         ...,
         [ 9.8950e-02,  3