In [21]:
import torch
import torch.nn as nn
import math


class NaiveEncoder(nn.Module):
    """A minimal single-layer Transformer encoder block with post-LayerNorm.

    Computation:
        h = LayerNorm_SA(x + SelfAttention(x))
        y = LayerNorm_ffn(h + FFN(h))

    This block is intentionally simplified: it has no dropout, a single
    attention head, and no attention mask.
    """

    def __init__(self, hidden_dim, ffn_ratio=4):
        """
        :param hidden_dim: embedding size D (last dimension of input and output).
        :param ffn_ratio: expansion factor of the feed-forward hidden layer
            (hidden size = hidden_dim * ffn_ratio). Defaults to 4, the value
            previously hard-coded here.
        """
        super().__init__()
        self.dim = hidden_dim
        # Bias-free query/key/value projections, each D -> D (single head).
        self.Wq = nn.Linear(self.dim, self.dim, bias=False)
        self.Wk = nn.Linear(self.dim, self.dim, bias=False)
        self.Wv = nn.Linear(self.dim, self.dim, bias=False)
        self.layerNorm_SA = nn.LayerNorm(self.dim)

        # Position-wise feed-forward network: D -> D*ffn_ratio -> D.
        ffn_hidden = self.dim * ffn_ratio
        self.ffn1 = nn.Linear(self.dim, ffn_hidden)
        self.ffn2 = nn.Linear(ffn_hidden, self.dim)
        self.act = nn.GELU()
        self.layerNorm_ffn = nn.LayerNorm(self.dim)

    def SelfAttention(self, x):
        """Single-head scaled dot-product self-attention, then residual + LayerNorm.

        :param x: (..., L, D), typically (N, L, D); any number of leading batch dims.
        :return: tensor with the same shape as x.
        """
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)

        # (..., L, D) @ (..., D, L) -> (..., L, L). Transposing the last two axes
        # (rather than axes 1 and 2) keeps this correct for any number of batch dims.
        # Dividing by sqrt(D) is the standard 1/sqrt(d_k) attention scaling.
        attention_score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dim)
        # Normalize over the key axis so each query's weights sum to 1.
        attention_weights = torch.softmax(attention_score, dim=-1)
        O = torch.matmul(attention_weights, V)
        # Post-LN: residual connection first, then LayerNorm.
        return self.layerNorm_SA(x + O)

    def FFN(self, x):
        """Two-layer GELU feed-forward network, then residual + LayerNorm.

        :param x: (..., D)
        :return: tensor with the same shape as x.
        """
        hidden = self.act(self.ffn1(x))
        projected = self.ffn2(hidden)
        return self.layerNorm_ffn(x + projected)

    def forward(self, x):
        """
        :param x: shape (N, L, D): N is the batch size, L is the length of the
            sequence, D is the dimension of the word embeddings. Extra leading
            batch dimensions are also accepted.
        :return: shape (N, L, D), same as the input.
        """
        x = self.SelfAttention(x)
        x = self.FFN(x)
        return x

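## 缺点 (Limitations)
1. 没有 dropout：注意力权重和 FFN 输出都未使用 dropout (no dropout anywhere)
2. 没有 multi-head attention：只有单个注意力头 (single attention head only)
3. 没有 attention mask：无法屏蔽 padding 位置，也无法实现因果注意力 (no masking of padding or future positions)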

In [22]:
import numpy as np
# Input batch dimensions: N sequences, L tokens each, D-dimensional embeddings.
BATCH_SIZE, SEQ_LEN, EMBED_DIM = 10, 50, 200
# Seeded generator so Restart & Run All reproduces the same input tensor.
rng = np.random.default_rng(0)
# torch.from_numpy keeps float64; .float() converts to float32 to match nn.Linear weights.
X = torch.from_numpy(rng.standard_normal((BATCH_SIZE, SEQ_LEN, EMBED_DIM))).float()
X

tensor([[[-0.3159, -0.5830,  1.7410,  ...,  0.3811,  0.7921,  1.0346],
         [-0.5891, -0.5257, -0.2432,  ..., -1.1997,  0.5752, -1.3614],
         [ 0.2590,  2.6640, -1.0004,  ...,  0.5402,  1.5058,  0.6942],
         ...,
         [ 1.0217,  1.7838, -0.6132,  ...,  0.8607, -0.1479, -0.7465],
         [ 0.6748,  0.5542,  1.2293,  ...,  0.7653, -0.4950, -1.9648],
         [ 0.5786,  0.3714,  1.4466,  ..., -0.5294, -0.4966, -1.0818]],

        [[-2.9085,  0.5468, -0.6978,  ...,  0.1342, -0.9079,  2.0059],
         [-0.9772, -0.7716,  1.6000,  ...,  2.1822, -0.7635,  1.2609],
         [ 0.0752,  1.0175, -0.7128,  ...,  0.6540,  0.5588,  1.5919],
         ...,
         [-2.3571, -0.8133, -1.5915,  ...,  2.4165, -1.3593,  1.6513],
         [-0.4299, -0.4517, -0.5844,  ..., -0.7948,  0.1466,  0.2284],
         [-0.3546,  1.4316,  1.3139,  ..., -0.3075, -0.7003,  1.6643]],

        [[-1.5730, -0.5587,  1.5377,  ...,  0.0732, -0.7021,  0.5502],
         [ 0.7607, -0.1438,  1.5209,  ..., -0

In [23]:
# Seed torch's RNG so the Linear/LayerNorm parameter initialization is reproducible.
torch.manual_seed(0)
# hidden_dim must equal the embedding size D of the input (last axis of X).
naive_encoder = NaiveEncoder(hidden_dim=X.shape[-1])

In [26]:
# Inference only: skip building the autograd graph.
with torch.no_grad():
    output = naive_encoder(X)

# The encoder is shape-preserving: (N, L, D) -> (N, L, D).
assert output.shape == X.shape
output.shape, output

(torch.Size([10, 50, 200]),
 tensor([[[-3.1560e-01, -3.4546e-01,  1.3904e+00,  ...,  2.3509e-01,
            5.5878e-01,  1.0912e+00],
          [-1.9650e-01, -6.3841e-01, -2.0755e-01,  ..., -9.8208e-01,
            6.4689e-01, -1.7955e+00],
          [ 1.9568e-01,  2.4087e+00, -1.5424e+00,  ...,  5.4551e-01,
            1.1766e+00,  5.3722e-01],
          ...,
          [ 9.9189e-01,  1.5846e+00, -3.9716e-01,  ...,  1.0545e+00,
           -4.6478e-01, -8.6285e-01],
          [ 7.6432e-01,  2.7416e-01,  1.5233e+00,  ...,  7.8932e-01,
           -8.1846e-01, -1.9955e+00],
          [ 6.3419e-01,  3.0797e-01,  1.5341e+00,  ..., -3.8921e-01,
           -1.7943e-01, -1.3472e+00]],
 
         [[-2.9542e+00,  7.0059e-01, -9.7643e-01,  ...,  8.5056e-02,
           -9.4083e-01,  2.3823e+00],
          [-9.8874e-01, -1.0092e+00,  1.0984e+00,  ...,  2.3523e+00,
           -1.1913e+00,  1.0914e+00],
          [-7.3326e-02,  7.9250e-01, -9.4479e-01,  ...,  7.0494e-01,
            7.7909e-01,  1.82