ref: https://space.bilibili.com/272226120

In [95]:
import torch
import torch.nn as nn 
from torch.nn import TransformerEncoderLayer
import math

In [85]:
class NaiveTransformerLayer(nn.Module):
    """Minimal single-head Transformer encoder layer for one unbatched sequence.

    The layer consumes a 2-D tensor of shape ``(n, 768)`` (``n`` tokens) and
    returns a tensor of the same shape. It applies single-head self-attention
    and a GELU feed-forward network, each wrapped in a residual connection
    followed by LayerNorm. There is no batch dimension, dropout, multi-head
    split or masking.
    """

    def __init__(self):
        super().__init__()
        hidden = 768
        self.dim = hidden
        # Bias-free query / key / value projections.
        self.Wq = nn.Linear(hidden, hidden, bias=False)
        self.Wk = nn.Linear(hidden, hidden, bias=False)
        self.Wv = nn.Linear(hidden, hidden, bias=False)
        # LayerNorm applied after the attention residual.
        self.lm = nn.LayerNorm(hidden)
        # Position-wise feed-forward network: hidden -> 4 * hidden -> hidden.
        self.fnn1 = nn.Linear(hidden, 4 * hidden)
        self.fnn2 = nn.Linear(4 * hidden, hidden)
        self.act = nn.GELU()
        # LayerNorm applied after the feed-forward residual.
        self.lm_ffn = nn.LayerNorm(hidden)

    def SelfAttention(self, x):
        """Apply single-head scaled dot-product self-attention.

        Args:
            x: tensor of shape ``(n, d)``; ``torch.mm`` requires it to be 2-D.

        Returns:
            ``LayerNorm(x + softmax(Q K^T / sqrt(d)) V)`` with shape ``(n, d)``.
        """
        queries, keys, values = self.Wq(x), self.Wk(x), self.Wv(x)
        scores = queries.mm(keys.transpose(0, 1)) / math.sqrt(self.dim)
        weights = torch.softmax(scores, dim=-1)
        context = weights.mm(values)
        return self.lm(x + context)

    def FFN(self, x):
        """Apply the feed-forward sub-layer with residual connection and LayerNorm."""
        hidden_state = self.act(self.fnn1(x))
        projected = self.fnn2(hidden_state)
        return self.lm_ffn(projected + x)

    def forward(self, x):
        """Map an ``(n, d)`` input to an ``(n, d)`` output: attention, then FFN."""
        return self.FFN(self.SelfAttention(x))


In [86]:
# Smoke test: push one random 4-token sequence through the unbatched layer
# and show that the output shape matches the input shape.
sample1 = torch.rand(4, 768)
net = NaiveTransformerLayer()
output1 = net(sample1)
sample1.shape, output1.shape

(torch.Size([4, 768]), torch.Size([4, 768]))

In [87]:
## Limitations of NaiveTransformerLayer:
    ## 1. no batch dimension
    ## 2. no dropout
    ## 3. single head only (no multi-head attention)
    ## 4. no attention mask / padding mask

In [88]:
## add batch , drop out
class Batch_NaiveTransformerLayer(nn.Module):
    """Single-head Transformer encoder layer with batch support and dropout.

    Inputs have shape ``(..., n, dim)``: any number of leading batch
    dimensions, ``n`` tokens, ``dim`` features. The output has the same shape.
    Both sub-layers use a residual connection followed by LayerNorm.

    Args:
        dim: model / embedding size (default 768).
        att_drop_prob: dropout rate applied to the attention weights.
        state_drop_prob: dropout rate applied to each sub-layer output before
            it is added to the residual.
    """

    def __init__(self, dim=768, att_drop_prob=0.1, state_drop_prob=0.1):
        super().__init__()
        self.dim = dim
        # Bias-free query / key / value projections.
        self.Wq = nn.Linear(self.dim, self.dim, bias=False)
        self.Wk = nn.Linear(self.dim, self.dim, bias=False)
        self.Wv = nn.Linear(self.dim, self.dim, bias=False)
        self.lm = nn.LayerNorm(self.dim)
        # Position-wise feed-forward network: dim -> 4 * dim -> dim.
        self.fnn1 = nn.Linear(self.dim, self.dim * 4, bias=True)
        self.fnn2 = nn.Linear(self.dim * 4, self.dim, bias=True)
        self.act = nn.GELU()
        self.lm_ffn = nn.LayerNorm(self.dim)
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def SelfAttention(self, x):
        """Apply single-head scaled dot-product self-attention.

        Args:
            x: tensor of shape ``(..., n, dim)``.

        Returns:
            Tensor of shape ``(..., n, dim)``:
            ``LayerNorm(x + dropout(dropout(softmax(Q K^T / sqrt(dim))) V))``.
        """
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)
        # matmul broadcasts over every leading dimension, so unbatched (n, dim)
        # and multiply-batched inputs work as well as the usual (b, n, dim).
        attention_score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim)
        attention_score = torch.softmax(attention_score, dim=-1)
        attention_score = self.att_drop(attention_score)
        O = torch.matmul(attention_score, V)
        O = self.state_drop(O)
        return self.lm(x + O)

    def FFN(self, x):
        """Apply the GELU feed-forward sub-layer with dropout, residual and LayerNorm."""
        t1 = self.act(self.fnn1(x))
        t2 = self.state_drop(self.fnn2(t1))
        return self.lm_ffn(t2 + x)

    def forward(self, x):
        """Map ``(..., n, dim)`` to ``(..., n, dim)``: self-attention, then FFN."""
        x = self.SelfAttention(x)
        x = self.FFN(x)
        return x


In [89]:
# Smoke test: a batch of 32 random 4-token sequences goes through the batched
# layer; the output shape must equal the input shape.
sample1 = torch.rand(32, 4, 768)
net = Batch_NaiveTransformerLayer()
output1 = net(sample1)
sample1.shape, output1.shape

(torch.Size([32, 4, 768]), torch.Size([32, 4, 768]))

In [90]:
## add multihead_attention
class MH_NaiveTransformerLayer(nn.Module):
    """Batched multi-head Transformer encoder layer (post-LayerNorm, no masking).

    Inputs have shape ``(batch, n, dim)`` and the output has the same shape.

    Args:
        dim: model / embedding size (default 768).
        num_heads: number of attention heads (default 12); must divide ``dim``.
        att_drop_prob: dropout rate applied to the attention weights.
        state_drop_prob: dropout rate applied to each sub-layer output before
            the residual addition (default 0.5, kept for backward
            compatibility).

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim=768, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.per_head_size = self.dim // self.num_heads  # 64 with the defaults
        inner_size = self.num_heads * self.per_head_size
        self.Wq = nn.Linear(self.dim, inner_size, bias=False)
        self.Wk = nn.Linear(self.dim, inner_size, bias=False)
        self.Wv = nn.Linear(self.dim, inner_size, bias=False)
        # Output projection: maps the concatenated heads back to the model size.
        self.W = nn.Linear(inner_size, self.dim)
        self.lm = nn.LayerNorm(self.dim)
        self.fnn1 = nn.Linear(self.dim, self.dim * 4, bias=True)
        self.fnn2 = nn.Linear(self.dim * 4, self.dim, bias=True)
        self.act = nn.GELU()
        self.lm_ffn = nn.LayerNorm(self.dim)
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def SelfAttention(self, x):
        """Apply multi-head scaled dot-product self-attention.

        Shape flow: ``(B, N, H*S)`` -> view ``(B, N, H, S)`` -> permute
        ``(B, H, N, S)`` -> attention per head -> permute back and merge to
        ``(B, N, H*S)`` -> output projection -> ``(B, N, dim)``.

        Args:
            x: tensor of shape ``(B, N, dim)``.

        Returns:
            Tensor of shape ``(B, N, dim)`` after dropout, residual and LayerNorm.
        """
        # x.size() returns torch.Size, a tuple subclass, so `+` concatenates.
        new_size = x.size()[:-1] + (self.num_heads, self.per_head_size)  # B, N, H, S
        Q = self.Wq(x).view(*new_size).permute(0, 2, 1, 3)
        K = self.Wk(x).view(*new_size).permute(0, 2, 1, 3)
        V = self.Wv(x).view(*new_size).permute(0, 2, 1, 3)
        # Each head's dot product runs over per_head_size features, so the
        # scale is sqrt(per_head_size), not sqrt(dim).
        attention_score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(
            self.per_head_size
        )
        attention_score = torch.softmax(attention_score, dim=-1)
        attention_score = self.att_drop(attention_score)
        O = torch.matmul(attention_score, V)  # B, H, N, S
        # B, H, N, S -> B, N, H, S -> B, N, H*S
        O = O.permute(0, 2, 1, 3).contiguous().view(x.size(0), x.size(1), -1)
        O = self.state_drop(self.W(O))
        return self.lm(x + O)

    def FFN(self, x):
        """Apply the GELU feed-forward sub-layer with dropout, residual and LayerNorm."""
        t1 = self.act(self.fnn1(x))
        t2 = self.state_drop(self.fnn2(t1))
        return self.lm_ffn(t2 + x)

    def forward(self, x):
        """Map ``(B, N, dim)`` to ``(B, N, dim)``: multi-head attention, then FFN."""
        x = self.SelfAttention(x)
        x = self.FFN(x)
        return x


In [91]:
# Smoke test: a batch of 32 random 4-token sequences goes through the
# multi-head layer; the output shape must equal the input shape.
sample1 = torch.rand(32, 4, 768)
net = MH_NaiveTransformerLayer()
output1 = net(sample1)
sample1.shape, output1.shape

(torch.Size([32, 4, 768]), torch.Size([32, 4, 768]))

In [92]:
## add mask
## add multihead_attention
class Mask_TransformerLayer(nn.Module):
    """Batched multi-head Transformer encoder layer with a padding mask.

    Inputs have shape ``(batch, n, dim)`` and the output has the same shape.
    The attention mask has shape ``(batch, n)``: 1 marks a normal token and
    0 marks a masked (e.g. padding) token that no query may attend to.

    Args:
        dim: model / embedding size (default 768).
        num_heads: number of attention heads (default 12); must divide ``dim``.
        att_drop_prob: dropout rate applied to the attention weights.
        state_drop_prob: dropout rate applied to each sub-layer output before
            the residual addition (default 0.5, kept for backward
            compatibility).

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim=768, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.per_head_size = self.dim // self.num_heads  # 64 with the defaults
        inner_size = self.num_heads * self.per_head_size
        self.Wq = nn.Linear(self.dim, inner_size, bias=False)
        self.Wk = nn.Linear(self.dim, inner_size, bias=False)
        self.Wv = nn.Linear(self.dim, inner_size, bias=False)
        # Output projection: maps the concatenated heads back to the model size.
        self.W = nn.Linear(inner_size, self.dim)
        self.lm = nn.LayerNorm(self.dim)
        self.fnn1 = nn.Linear(self.dim, self.dim * 4, bias=True)
        self.fnn2 = nn.Linear(self.dim * 4, self.dim, bias=True)
        self.act = nn.GELU()
        self.lm_ffn = nn.LayerNorm(self.dim)
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def calc_mask_score(self, attention_mask):
        """Turn a 0/1 token mask into an additive attention bias.

        Args:
            attention_mask: tensor of shape ``(b, n)``; 1 = keep, 0 = masked.

        Returns:
            Tensor of shape ``(b, h, n, n)`` holding 0 in key columns that may
            be attended to and -10000 in masked key columns. It is a broadcast
            (expanded) view on the same device as ``attention_mask``, so no
            dense ``b*h*n*n`` buffer is allocated.
        """
        batch_size, seq_len = attention_mask.size(0), attention_mask.size(1)
        if not attention_mask.is_floating_point():
            # Bool / integer masks do not support `1.0 - mask` as a float op.
            attention_mask = attention_mask.to(torch.get_default_dtype())
        # Keys sit on the last axis, so the mask is broadcast over heads and
        # over query positions.
        mask_score = (1.0 - attention_mask[:, None, None, :]) * -10000.0
        return mask_score.expand(batch_size, self.num_heads, seq_len, seq_len)

    def SelfAttention(self, x, attention_mask=None):
        """Apply masked multi-head scaled dot-product self-attention.

        Shape flow: ``(B, N, H*S)`` -> view ``(B, N, H, S)`` -> permute
        ``(B, H, N, S)`` -> attention per head -> permute back and merge to
        ``(B, N, H*S)`` -> output projection -> ``(B, N, dim)``.

        Args:
            x: tensor of shape ``(B, N, dim)``.
            attention_mask: optional tensor of shape ``(B, N)``; 1 = normal
                token, 0 = masked token. ``None`` disables masking.

        Returns:
            Tensor of shape ``(B, N, dim)`` after dropout, residual and LayerNorm.
        """
        # x.size() returns torch.Size, a tuple subclass, so `+` concatenates.
        new_size = x.size()[:-1] + (self.num_heads, self.per_head_size)  # B, N, H, S
        Q = self.Wq(x).view(*new_size).permute(0, 2, 1, 3)
        K = self.Wk(x).view(*new_size).permute(0, 2, 1, 3)
        V = self.Wv(x).view(*new_size).permute(0, 2, 1, 3)
        # Each head's dot product runs over per_head_size features, so the
        # scale is sqrt(per_head_size), not sqrt(dim).
        attention_score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(
            self.per_head_size
        )
        if attention_mask is not None:
            # The bias must be added to the scores that go into the softmax;
            # masked keys then receive (numerically) zero attention weight.
            mask_score = self.calc_mask_score(attention_mask)
            attention_score = attention_score + mask_score.to(
                device=attention_score.device, dtype=attention_score.dtype
            )

        attention_score = torch.softmax(attention_score, dim=-1)
        attention_score = self.att_drop(attention_score)
        O = torch.matmul(attention_score, V)  # B, H, N, S
        # B, H, N, S -> B, N, H, S -> B, N, H*S
        O = O.permute(0, 2, 1, 3).contiguous().view(x.size(0), x.size(1), -1)
        O = self.state_drop(self.W(O))
        return self.lm(x + O)

    def FFN(self, x):
        """Apply the GELU feed-forward sub-layer with dropout, residual and LayerNorm."""
        t1 = self.act(self.fnn1(x))
        t2 = self.state_drop(self.fnn2(t1))
        return self.lm_ffn(t2 + x)

    def forward(self, x, attention_mask=None):
        """Map ``(B, N, dim)`` to ``(B, N, dim)``: masked attention, then FFN."""
        x = self.SelfAttention(x, attention_mask)
        x = self.FFN(x)
        return x


In [93]:
# Smoke test: 2 random sequences of 256 tokens (batch, N, embedding size) and
# an all-ones mask (no token masked) go through the masked layer.
sample1 = torch.rand(2, 256, 768)
mask = torch.ones(2, 256)
net = Mask_TransformerLayer()
output1 = net(sample1, mask)
sample1.shape, output1.shape


(torch.Size([2, 256, 768]), torch.Size([2, 256, 768]))

In [94]:
# Echo the output tensor computed in the previous cell.
output1

tensor([[[ 0.8729,  1.0437,  0.9460,  ...,  0.7100,  0.8173, -1.1307],
         [-2.5397, -0.3485, -1.0852,  ..., -1.0142,  0.6679,  1.7684],
         [-0.3519, -0.9706, -0.4831,  ..., -0.8022,  1.4280,  0.7184],
         ...,
         [ 1.4107,  0.8165,  0.2495,  ..., -0.0963, -0.3215, -0.2316],
         [-0.5850,  1.0678, -0.1695,  ...,  0.1663, -0.3970, -0.0920],
         [-0.9799, -0.1537, -0.6260,  ..., -1.2784,  0.4568, -0.5208]],

        [[ 0.2158,  0.4168, -0.0114,  ..., -0.9031,  0.5463, -0.3733],
         [-2.5492, -1.2238, -0.8947,  ...,  0.4318,  0.6822, -0.7082],
         [-0.8914, -0.8931, -0.3195,  ..., -1.1099, -0.1089, -1.0964],
         ...,
         [-0.5354, -1.4312,  0.6768,  ..., -0.3599, -0.5681, -0.5770],
         [-0.7852, -1.1449,  0.2709,  ...,  0.6154, -0.1276, -0.9554],
         [-0.1394, -1.2817, -1.6374,  ...,  0.9248,  1.3273, -0.0487]]],
       grad_fn=<NativeLayerNormBackward0>)