In [1]:
import torch
import torch.nn as nn
import math


class NaiveDecoderLayer(nn.Module):
    """Minimal single-head Transformer decoder layer (post-LayerNorm).

    The layer applies three sub-layers in order, each followed by a residual
    connection and its own ``LayerNorm``:

    1. self-attention over the decoder input,
    2. cross-attention from the decoder stream to the encoder output,
    3. a position-wise feed-forward network (``dim -> 4*dim -> dim``, GELU).

    Known limitations of this naive implementation: no dropout, a single
    attention head, and no attention / causal masking.
    """

    def __init__(self, hidden_dim):
        """Build the layer.

        :param hidden_dim: size ``D`` of the last (feature) dimension of the
            decoder input and of the encoder output.
        """
        super().__init__()
        self.dim = hidden_dim

        # Self-attention: bias-free Q/K/V projections plus the LayerNorm
        # applied after the residual connection.
        self.Wq = nn.Linear(self.dim, self.dim, bias=False)
        self.Wk = nn.Linear(self.dim, self.dim, bias=False)
        self.Wv = nn.Linear(self.dim, self.dim, bias=False)
        self.layerNorm_SA = nn.LayerNorm(self.dim)

        # Cross-attention: queries come from the decoder stream, keys/values
        # from the encoder output; it has its own LayerNorm.
        self.Wq2 = nn.Linear(self.dim, self.dim, bias=False)
        self.Wk2 = nn.Linear(self.dim, self.dim, bias=False)
        self.Wv2 = nn.Linear(self.dim, self.dim, bias=False)
        self.layerNorm_CA = nn.LayerNorm(self.dim)

        # Feed-forward network with a 4x hidden expansion.
        self.ffn1 = nn.Linear(self.dim, self.dim * 4)
        self.ffn2 = nn.Linear(self.dim * 4, self.dim)
        self.act = nn.GELU()
        self.layerNorm_ffn = nn.LayerNorm(self.dim)

    def _scaled_dot_product(self, Q, K, V):
        """Return ``softmax(Q @ K^T / sqrt(dim)) @ V`` (no masking)."""
        # Transposing the last two axes equals transpose(1, 2) for 3-D input.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim)
        # Normalise over the key axis so every query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

    def SelfAttention(self, x):
        '''
        Unmasked single-head self-attention with residual + LayerNorm.

        :param x: (N,L,D)
        :return: (N,L,D)
        '''
        attended = self._scaled_dot_product(self.Wq(x), self.Wk(x), self.Wv(x))
        return self.layerNorm_SA(x + attended)

    def CrossAttention(self, x1, x2):
        '''
        Single-head cross-attention with residual + LayerNorm.

        Queries are projected from ``x1``; keys and values from ``x2``. The
        sequence length of ``x2`` does not have to match that of ``x1``.

        :param x1: decoder input: (N,L,D)
        :param x2: encoder output: (N,S,D)
        :return: (N,L,D)
        '''
        attended = self._scaled_dot_product(
            self.Wq2(x1), self.Wk2(x2), self.Wv2(x2)
        )
        # Use the cross-attention LayerNorm here. Reusing layerNorm_SA would
        # leave layerNorm_CA untrained and tie two sub-layers to one norm.
        return self.layerNorm_CA(x1 + attended)

    def FFN(self, x):
        '''
        Position-wise feed-forward block with residual + LayerNorm.

        :param x: (N,L,D)
        :return: (N,L,D)
        '''
        hidden = self.act(self.ffn1(x))
        projected = self.ffn2(hidden)
        return self.layerNorm_ffn(x + projected)

    def forward(self, x1, x2):
        '''
        Run self-attention, cross-attention and the FFN in sequence.

        :param x1: decoder input: (N,L,D)
        :param x2: encoder output: (N,S,D)
        :return:   (N,L,D)
        '''
        x1 = self.SelfAttention(x1)
        attended = self.CrossAttention(x1, x2)
        return self.FFN(attended)

In [2]:
# Decoder layer whose feature width (hidden_dim) is 200.
naive_decoder_layer = NaiveDecoderLayer(hidden_dim=200)

In [3]:
import numpy as np
# Standard-normal samples of shape (10, 50, 200), converted to a torch tensor.
X1 = torch.Tensor(np.random.randn(10, 50, 200))
X1

tensor([[[-0.4192, -0.3046, -1.0121,  ...,  1.1269, -0.0948, -0.6938],
         [ 0.5220,  1.4163,  0.5279,  ..., -1.4432, -0.1654, -0.7548],
         [-0.9789, -1.1642,  1.0576,  ..., -0.2250, -1.0317, -0.0946],
         ...,
         [ 0.9443,  0.9501,  0.3716,  ...,  0.7471,  2.3221,  0.3581],
         [-1.2199,  0.0381,  2.2465,  ...,  1.8185, -0.5360, -0.1274],
         [ 0.3285, -1.8043, -1.8896,  ...,  0.9083,  0.7956,  0.6076]],

        [[-0.2989, -0.6812,  0.8181,  ...,  0.2203, -1.0936,  1.0009],
         [-1.6739, -0.6420, -1.6010,  ...,  1.1678,  0.1683,  0.2897],
         [ 1.8925,  0.0761, -0.2128,  ...,  1.2759,  0.3530, -0.0767],
         ...,
         [ 1.2836,  0.3583, -0.1468,  ..., -0.2630, -1.5336, -1.6907],
         [ 0.0603, -0.1419, -1.1213,  ...,  1.9134, -0.3812,  0.8961],
         [-0.1241,  1.7099, -0.3782,  ..., -0.6878,  0.2951, -0.6351]],

        [[-1.2473,  0.0931,  0.2344,  ...,  0.4220, -0.1690,  1.9427],
         [-0.3613,  2.2734, -0.0558,  ..., -0

In [4]:
# Standard-normal samples of shape (10, 50, 200), converted to a torch tensor.
X2 = torch.Tensor(np.random.randn(10, 50, 200))
X2

tensor([[[-1.4205,  1.1311, -1.4933,  ..., -0.1737, -1.1983, -1.3437],
         [ 0.4508,  0.1779,  0.0160,  ..., -0.4977, -0.0788, -0.3522],
         [ 0.0479,  0.2017, -2.6437,  ...,  0.1751, -1.4324,  1.1254],
         ...,
         [ 0.0070, -0.2205, -0.2593,  ...,  0.5204,  0.3335, -1.8348],
         [ 1.5172,  1.2976, -0.0281,  ..., -0.4320, -1.3770,  1.5681],
         [-0.0226,  1.0358,  0.1264,  ..., -0.6032,  0.5050, -0.8105]],

        [[ 0.1050, -0.7592, -0.7361,  ..., -0.6486, -0.8526,  0.7178],
         [-0.3437,  0.0080,  1.0824,  ...,  0.5045,  0.2328,  0.5675],
         [-1.9728,  0.7277,  1.2641,  ...,  0.2246,  0.4251,  1.4439],
         ...,
         [-1.0379,  0.6301,  0.5887,  ..., -2.6204,  0.1260, -0.8174],
         [-1.6262,  0.9781, -0.2662,  ...,  0.5593,  0.9278,  0.8802],
         [-0.6525,  0.3195,  2.3799,  ..., -0.1724,  1.8554, -1.3553]],

        [[ 0.5035,  0.7579, -0.6087,  ...,  1.4124,  1.0154,  0.0637],
         [-0.4996,  0.1251,  0.1392,  ..., -0

In [5]:
# Run the layer with X1 as the decoder input and X2 as the encoder output.
output = naive_decoder_layer(x1=X1, x2=X2)

# Display the result's shape together with its values.
(output.shape, output)

(torch.Size([10, 50, 200]),
 tensor([[[-0.3869, -0.5746, -0.5209,  ...,  1.4008, -0.2333, -1.2964],
          [ 0.6724,  1.2441,  0.7611,  ..., -1.4689, -0.2661, -1.1754],
          [-0.9262, -0.9403,  0.9262,  ..., -0.3938, -1.3850, -0.5305],
          ...,
          [ 1.3099,  0.8876,  0.6687,  ...,  1.1960,  2.1952,  0.0927],
          [-1.3134,  0.1008,  2.5786,  ...,  1.7610, -1.0919, -0.2164],
          [ 0.4003, -2.0306, -1.7047,  ...,  0.8807,  0.7576,  0.1548]],
 
         [[-0.2834, -0.3894,  0.7972,  ...,  0.7084, -1.2533,  0.8241],
          [-1.4472, -0.4665, -1.5695,  ...,  1.3613,  0.0084,  0.3219],
          [ 1.9283,  0.3119,  0.3057,  ...,  2.0656,  0.2988,  0.0633],
          ...,
          [ 1.7092,  0.3258, -0.4554,  ..., -0.0654, -1.6422, -2.1645],
          [ 0.1823, -0.1338, -0.8696,  ...,  2.0757, -0.4466,  0.6769],
          [-0.1520,  1.9974, -0.3151,  ..., -0.6242,  0.2950, -0.7862]],
 
         [[-1.3522, -0.1732,  0.2250,  ...,  0.5865, -0.5035,  1.8003],


# 缺点:
1. 没有 dropout
2. 没有 multi-head attention(只实现了单头注意力)
3. 没有 attention mask(padding mask)和 subsequent mask(causal mask),因此解码器的自注意力可以看到未来位置