In [1]:
import torch
import torch.nn as nn
import math


class DecoderLayer(nn.Module):
    '''
    One post-LayerNorm Transformer decoder layer.

    Three sub-layers, each followed by dropout, a residual connection and its
    own LayerNorm:
      1. masked (causal + decoder key-padding) multi-head self-attention,
      2. multi-head cross-attention from decoder queries to encoder keys/values
         (encoder key-padding mask),
      3. position-wise feed-forward network (D -> 4D -> D with GELU).

    :param hidden_dim: model width D shared by decoder input and encoder output.
    :param num_heads: number of attention heads (default 12).
    :param att_drop_prob: dropout applied to attention probabilities (default 0.1).
    :param state_drop_prob: dropout applied to each sub-layer output before the
        residual add (default 0.5).
    :raises ValueError: if num_heads < 1 or hidden_dim < num_heads (which would
        give zero-width heads).
    '''

    def __init__(self, hidden_dim, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5):
        super().__init__()
        if num_heads < 1 or hidden_dim < num_heads:
            raise ValueError(
                f"need 1 <= num_heads <= hidden_dim, got num_heads={num_heads}, "
                f"hidden_dim={hidden_dim}"
            )
        self.dim = hidden_dim

        # multi-head SA
        self.num_heads = num_heads
        # If hidden_dim is not divisible by num_heads, the heads span
        # num_heads * dim_per_head < dim features; W_reshape_back_* maps back to dim.
        self.dim_per_head = self.dim // self.num_heads
        inner_dim = self.num_heads * self.dim_per_head
        self.Wq = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wk = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wv = nn.Linear(self.dim, inner_dim, bias=False)
        self.layerNorm_SA = nn.LayerNorm(self.dim)
        self.W_reshape_back_SA = nn.Linear(inner_dim, self.dim)

        # multi-head CA
        self.Wq2 = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wk2 = nn.Linear(self.dim, inner_dim, bias=False)
        self.Wv2 = nn.Linear(self.dim, inner_dim, bias=False)
        self.layerNorm_CA = nn.LayerNorm(self.dim)
        self.W_reshape_back_CA = nn.Linear(inner_dim, self.dim)

        # FFN
        self.ffn1 = nn.Linear(self.dim, self.dim * 4)
        self.ffn2 = nn.Linear(self.dim * 4, self.dim)
        self.act = nn.GELU()
        self.layerNorm_ffn = nn.LayerNorm(self.dim)

        # dropout
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

    def _compute_SA_pad_mask(self, attention_mask):
        '''
        Broadcast the decoder key-padding mask over heads and query positions.

        :param attention_mask: (N,L1); nonzero/True = real token, 0/False = padding.
        :return: (N,#heads,L1,L1) int32 0/1 mask on attention_mask's device.
        '''
        # Normalise to 0/1 so bool/float masks work and the later bitwise AND
        # cannot misfire on values such as 2 (2 & 1 == 0).
        keep = (attention_mask != 0).to(torch.int32)
        n, length = keep.shape
        return keep[:, None, None, :].expand(n, self.num_heads, length, length)

    def _compute_att_subsequence_mask(self, x):
        '''
        Causal mask: query position i may attend to key positions j <= i.

        :param x: (N,L1,D)
        :return: (N,#heads,L1,L1) int32 lower-triangular 0/1 mask on x's device.
        '''
        length = x.size(1)
        causal = torch.tril(torch.ones((length, length), dtype=torch.int32, device=x.device))
        return causal.expand(x.size(0), self.num_heads, length, length)

    def _compute_SA_mask_logits(self, x, attention_mask):
        '''
        Additive self-attention mask combining causality and decoder padding.

        :param x: (N,L1,D)
        :param attention_mask: (N,L1)
        :return: (N,#heads,L1,L1) float; 0 where attention is allowed, -10000 where blocked.
        '''
        att_pad_mask = self._compute_SA_pad_mask(attention_mask)
        att_subseq_mask = self._compute_att_subsequence_mask(x)
        mask = att_pad_mask & att_subseq_mask
        # Large finite negative (not -inf) so a fully blocked row (e.g. when
        # position 0 is padding) still gives a finite softmax instead of NaN.
        return (1.0 - mask) * -10000.0

    def _compute_CA_pad_mask(self, x, enc_att_mask):
        '''
        Broadcast the encoder key-padding mask over heads and decoder positions.

        :param x: decoder input: (N,L1,D)
        :param enc_att_mask: (N,L2); nonzero/True = real token, 0/False = padding.
        :return: (N,#heads,L1,L2) int32 0/1 mask on enc_att_mask's device.
        '''
        keep = (enc_att_mask != 0).to(torch.int32)
        return keep[:, None, None, :].expand(x.size(0), self.num_heads, x.size(1), keep.size(1))

    def _compute_CA_mask_logits(self, x, enc_att_mask):
        '''
        Additive cross-attention mask from encoder padding.

        :param x: decoder input: (N,L1,D)
        :param enc_att_mask: (N,L2)
        :return: (N,#heads,L1,L2) float; 0 where attention is allowed, -10000 where blocked.
        '''
        mask = self._compute_CA_pad_mask(x, enc_att_mask)
        return (1.0 - mask) * -10000.0

    def _split_heads(self, t):
        '''Reshape (N,L,#heads*dph) -> (N,#heads,L,dph).'''
        return t.view(t.size(0), t.size(1), self.num_heads, self.dim_per_head).transpose(1, 2)

    def _attend(self, Q, K, V, mask_logits):
        '''
        Scaled dot-product attention over all heads, with heads merged back.

        :param Q: (N,#heads,Lq,dph)
        :param K: (N,#heads,Lk,dph)
        :param V: (N,#heads,Lk,dph)
        :param mask_logits: additive mask broadcastable to (N,#heads,Lq,Lk).
        :return: (N,Lq,#heads*dph)
        '''
        # Scale by sqrt(d_k), where d_k is the per-head dimension.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim_per_head)
        # Cast so a float32 mask does not up-cast half/bfloat16 scores.
        scores = scores + mask_logits.to(scores.dtype)
        probs = self.att_drop(torch.softmax(scores, dim=-1))
        O = torch.matmul(probs, V)  # (N,#heads,Lq,dph)
        return O.transpose(1, 2).reshape(Q.size(0), Q.size(2), -1)

    def MultiHeadSelfAttention(self, x, attention_mask):
        '''
        Masked multi-head self-attention sub-layer (dropout + residual + LayerNorm).

        :param x: (N,L1,D)
        :param attention_mask: (N,L1) decoder padding mask, nonzero = keep.
        :return: (N,L1,D)
        '''
        Q = self._split_heads(self.Wq(x))  # (N,#heads,L1,dph)
        K = self._split_heads(self.Wk(x))
        V = self._split_heads(self.Wv(x))
        O = self._attend(Q, K, V, self._compute_SA_mask_logits(x, attention_mask))  # (N,L1,#heads*dph)
        O = self.W_reshape_back_SA(O)  # (N,L1,D)
        O = self.state_drop(O)
        return self.layerNorm_SA(x + O)

    def MultiHeadCrossAttention(self, x1, x2, enc_att_mask):
        '''
        Multi-head cross-attention sub-layer (dropout + residual + LayerNorm).

        :param x1: decoder input: (N,L1,D)
        :param x2: encoder output: (N,L2,D)
        :param enc_att_mask: (N,L2) encoder padding mask, nonzero = keep.
        :return: (N,L1,D)
        '''
        Q = self._split_heads(self.Wq2(x1))  # (N,#heads,L1,dph)
        K = self._split_heads(self.Wk2(x2))  # (N,#heads,L2,dph)
        V = self._split_heads(self.Wv2(x2))  # (N,#heads,L2,dph)
        O = self._attend(Q, K, V, self._compute_CA_mask_logits(x1, enc_att_mask))  # (N,L1,#heads*dph)
        O = self.W_reshape_back_CA(O)  # (N,L1,D)
        # Same post-processing as the self-attention sub-layer, but with the
        # cross-attention's own LayerNorm.
        O = self.state_drop(O)
        return self.layerNorm_CA(x1 + O)

    def FFN(self, x):
        '''
        Position-wise feed-forward sub-layer (dropout + residual + LayerNorm).

        :param x: (N,L1,D)
        :return: (N,L1,D)
        '''
        tmp1 = self.act(self.ffn1(x))
        tmp2 = self.ffn2(tmp1)
        tmp2 = self.state_drop(tmp2)
        output = self.layerNorm_ffn(x + tmp2)
        return output

    def forward(self, x1, x2, dec_att_mask, enc_att_mask):
        '''
        :param x1: decoder input: (N,L1,D)
        :param x2: encoder output: (N,L2,D)
        :param dec_att_mask: (N,L1) decoder padding mask, nonzero = keep.
        :param enc_att_mask: (N,L2) encoder padding mask, nonzero = keep.
        :return: (N,L1,D)
        '''
        x1 = self.MultiHeadSelfAttention(x1, dec_att_mask)
        tmp = self.MultiHeadCrossAttention(x1, x2, enc_att_mask)
        output = self.FFN(tmp)
        return output



In [3]:
# Smoke test: batch of 2, decoder length 2, encoder length 3, width 768.
decoder_layer = DecoderLayer(hidden_dim=768)
x1 = torch.ones((2, 2, 768))  # decoder input  (N, L1, D)
x2 = torch.ones((2, 3, 768))  # encoder output (N, L2, D)

# 1 = attend to this key position, 0 = padding (masked out)
enc_att_mask = torch.tensor([[1, 1, 0],
                             [1, 1, 1]])
dec_att_mask = torch.tensor([[1, 1],
                             [1, 0]])

output = decoder_layer(x1, x2, dec_att_mask=dec_att_mask, enc_att_mask=enc_att_mask)
print(output, output.shape)

tensor([[[-0.0347,  2.0425, -0.2685,  ..., -1.8946, -0.0547, -0.0576],
         [ 0.7092,  2.0814, -0.6385,  ..., -1.4035,  0.4594, -1.4810]],

        [[ 0.9878,  0.2651, -0.0735,  ..., -2.1452,  0.4137, -0.0737],
         [ 0.4149,  0.0423, -0.7192,  ..., -0.1651,  1.1540,  0.0324]]],
       grad_fn=<NativeLayerNormBackward0>) torch.Size([2, 2, 768])
