## Transfomer 从零开始
基于Encoder-Decoder框架，面向Seq2Seq类别任务，采用Self- attention和Multi-head attention机制，同时使用Token embedding+Position embedding、残差连接、layer norm、Mask等tricks，实现了惊为天人的Transformer。

潜在直觉：在ED框架下，最不可忍受的推理的不可并行，仔细分析背后的过程，可以看到是因为隐藏状态需要一个一个输入。那么是否有一种方式可以直接得到所有的隐藏状态，并且用一种线性变化就可以得到预测后的输出，我觉得这个是transformer的解决的问题。
1. Attention解决了并行性，扩大了感受野的广度
2. Position embedding、mask完善了时序信息
3. 残差连接、layer norm是基本的网络结构

## 0x01 Embedding

In [2]:
import torch
import torch.nn as nn
import math
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size,embedding_size) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_size)
        self.embedding_size = embedding_size
    
    def forward(self,x):
        return self.embedding(x.long())*math.sqrt(self.embedding_size)

class PositionalEncoding(nn.Module):
    def __init__(self,dimen,dropout=0.1,max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        PE = torch.zeros(max_len,dimen)
        position = torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,dimen,2).float() * (-math.log(10000.0) / dimen))
        PE[:,0::2] = torch.sin(position * div_term)
        PE[:,1::2] = torch.cos(position * div_term)
        PE = PE.unsqueeze(0).transpose(0,1)
        self.register_buffer('PE',PE)
    
    def forward(self,x):
        x = x + self.PE[:x.size(0),:]
        return self.dropout(x)

x = torch.tensor([[1,2,3,4,5],[5,6,7,8,9]],dtype=torch.float)
# size of the x is [seq_len,batch_size],不用one-hot表示
x = x.reshape(5,2) 
print('The input size is ',x.shape)
vocab_size = 10
embedding_size = 512

# size of the x is [seq_len,batch_size,embedding_size]
token_embedding = TokenEmbedding(vocab_size,embedding_size)
x = token_embedding(x)
print('TokenEmbedding size is ',x)
print('token embedding size is ',x.shape)
pos_embedding = PositionalEncoding(embedding_size)
x = pos_embedding(x)
print('Position embedding size is',x.shape)
print('The size must be [seq_len,batchsize,embedding_size]')

The input size is  torch.Size([5, 2])
TokenEmbedding size is  tensor([[[ 23.3701, -43.0523, -41.0195,  ..., -19.0137,   7.7349, -10.3376],
         [  4.5575, -17.0709, -20.0782,  ...,  13.5287,  17.3372,  16.6027]],

        [[-12.3629,  -4.9283,   0.2701,  ...,  16.3192,   9.2782,  -9.4326],
         [-18.9739, -20.1240, -50.7473,  ...,   3.4778,  20.8694,  20.9355]],

        [[ 26.6863,  36.4356,  22.5699,  ..., -10.3694, -23.3673, -11.5537],
         [ 26.6863,  36.4356,  22.5699,  ..., -10.3694, -23.3673, -11.5537]],

        [[ -8.3441,   7.6558,  23.0942,  ...,   7.3344,  -0.1192, -10.6960],
         [ -3.8388, -37.2422,  13.3843,  ...,  28.1898, -30.3586,   8.6297]],

        [[  8.4035, -13.6004, -29.5773,  ...,   3.1634,  36.3206,  34.4238],
         [-26.4645,  31.2842,  25.9679,  ..., -21.6577,  34.6043,  -3.4274]]],
       grad_fn=<MulBackward0>)
token embedding size is  torch.Size([5, 2, 512])
Position embedding size is torch.Size([5, 2, 512])
The size must be [seq_len,b

## 0x02 Encoder layer 实现
主要是实现基于attention的encoder和decoder layer，并搭建好encoder和decoder

In [3]:
class TransformerEncoderLayerStractch(nn.Module):
    def __init__(self,dimen,nhead,dim_forward=2048,dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dimen,nhead,dropout=dropout)
        self.dropoutAttn = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dimen)
        self.linear1 = nn.Linear(dimen,dim_forward)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_forward,dimen)
        self.activation = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dimen)
    
    def forward(self,x,mask=None,src_key_padding_mask=None):
        '''
        x: [seq_len,batch_size,embedding_size]
        '''
        x2,_ = self.self_attn(x,x,x,attn_mask=mask,key_padding_mask=src_key_padding_mask)
        # x2: [seq_len,batch_size,embedding_size*nhead]
        x = x + self.dropoutAttn(x2)
        x = self.norm1(x)

        x2 = self.activation(self.linear1(x))
        x2 = self.linear2(self.dropout1(x2))
        x = x + self.dropout2(x2)

        x = self.norm2(x)
        return x 
print('Test the transformer encoder layer')
nhead = 8
layer = TransformerEncoderLayerStractch(embedding_size,nhead)
x = layer(x)

Test the transformer encoder layer


In [4]:
import copy
def get_copy(module,N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
                         
class TransformerEncoderStractch(nn.Module):
    def __init__(self,encoder_layer,num_layers,norm=None):
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        self.norm = norm
        self.num_layers = num_layers
    
    def forward(self,x,mask=None,src_key_padding_mask=None):
        output = x
        for layer in self.layers:
            output = layer(output,mask=mask,src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output 

encoder = TransformerEncoderStractch(layer,6)
x = encoder(x)
print('Test the transformer encoder')
loss = nn.MSELoss()
print(loss(x,encoder(x)))

Test the transformer encoder
tensor(0.5778, grad_fn=<MseLossBackward0>)


## 0x03 Decoder实现

In [5]:
class TransformerDecoderLayerStractch(nn.Module):
    def __init__(self,dimen,nhead,dim_forward=2048,dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dimen,nhead,dropout=dropout)
        self.dropoutAttn = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dimen)
        self.multihead_attn = nn.MultiheadAttention(dimen,nhead,dropout=dropout)
        self.dropoutAttn2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dimen)
        self.linear1 = nn.Linear(dimen,dim_forward)
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_forward,dimen)
        self.dropout2 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(dimen)
        self.activation = nn.ReLU()
    
    def forward(self,tgt,memory,tgt_mask=None,memory_mask=None,tgt_key_padding_mask=None,memory_key_padding_mask=None):
        ''' 
        tgt = [tgt_len,batch_size,embedding_size]
        memory = [seq_len,batch_size,embedding_size]
        '''
        tgt2,_ = self.self_attn(tgt,tgt,tgt,attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropoutAttn(tgt2)
        tgt = self.norm1(tgt)
        # tgt = [tgt_len,batch_size,embedding_size]

        tgt2,_ = self.multihead_attn(tgt,memory,memory,attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropoutAttn2(tgt2)
        tgt = self.norm2(tgt)
        # tgt = [tgt_len,batch_size,embedding_size]

        tgt2 = self.activation(self.linear1(tgt))
        tgt2 = self.linear2(self.dropout1(tgt2))
        tgt = tgt + self.dropout2(tgt2)

        tgt = self.norm3(tgt)
        # tgt = [tgt_len,batch_size,embedding_size]
        return tgt
print('Test the transformer decoder layer')
decoder_layer = TransformerDecoderLayerStractch(embedding_size,nhead)
x = decoder_layer(x,layer(x))
print(loss(x,decoder_layer(x,layer(x))))

Test the transformer decoder layer


tensor(0.2137, grad_fn=<MseLossBackward0>)


In [6]:
class TransformerDecoderStractch(nn.Module):
    def __init__(self,decoder_layer,num_layers,norm=None):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.norm = norm
        self.num_layers = num_layers
    
    def forward(self,tgt,memory,tgt_mask=None,memory_mask=None,tgt_key_padding_mask=None,memory_key_padding_mask=None):
        output = tgt
        for layer in self.layers:
            output = layer(output,memory,tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output
encoder = TransformerEncoderStractch(layer,6)
decoder = TransformerDecoderStractch(decoder_layer,6)

In [7]:
# Bulid the transformer
class MyTransformer(nn.Module):
    def __init__(self,dimen=512,nhead=8,num_encoder_layers=6,num_decoder_layers=6,dim_forward=2048,dropout=0.1):
        super().__init__()
        encoder_layer = TransformerEncoderLayerStractch(dimen,nhead,dim_forward,dropout)
        encoder_norm = nn.LayerNorm(dimen)
        self.encoder = TransformerEncoderStractch(encoder_layer,num_encoder_layers,encoder_norm)

        decoder_layer = TransformerDecoderLayerStractch(dimen,nhead,dim_forward,dropout)
        decoder_norm = nn.LayerNorm(dimen)
        self.decoder = TransformerDecoderStractch(decoder_layer,num_decoder_layers,decoder_norm)

        self._reset_parameters()
        self.dimen = dimen
        self.nhead = nhead
    
    def forward(self,src,tgt,src_mask=None,tgt_mask=None,memory_mask=None,src_key_padding_mask=None,tgt_key_padding_mask=None,memory_key_padding_mask=None):
        memory = self.encoder(src,mask=src_mask,src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt,memory,tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask)
        return output
    
    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def generate_square_subsequent_mask(self,sz):
        mask = (torch.triu(torch.ones(sz,sz)) == 1).transpose(0,1)
        mask = mask.float().masked_fill(mask == 0,float('-inf')).masked_fill(mask == 1,float(0.0))
        return mask


In [8]:
seq_len = 12
batch_size = 2
dimen = 128
tag_len = 10
nhead = 8
input = torch.randn(seq_len,batch_size,dimen)
target = torch.randn(tag_len,batch_size,dimen)
print('The input size is ',input.shape)

model = MyTransformer(dimen=dimen,nhead=nhead,num_decoder_layers=6,num_encoder_layers=6,dim_forward=2048,dropout=0.1)

tgt_mask = model.generate_square_subsequent_mask(tag_len)
out = model(input,target,tgt_mask=tgt_mask)
print('Test the transformer')
print(loss(out,model(input,target,tgt_mask=tgt_mask)))

The input size is  torch.Size([12, 2, 128])
Test the transformer
tensor(0.4967, grad_fn=<MseLossBackward0>)


## 细品Torch的Multi head attention


In [9]:
layer = nn.MultiheadAttention(embed_dim=dimen,
                              num_heads=nhead,
                              dropout=0.1)
input = torch.randn(seq_len,batch_size,dimen)
print('input size is',input.shape)
for out in layer(input,input,input):
    print(out.shape)


input size is torch.Size([12, 2, 128])
torch.Size([12, 2, 128])
torch.Size([2, 12, 12])
