# Transformer 
对`attention`机制的简单讲解[link](https://easyai.tech/ai-definition/attention/)（顺带一提这好像是个关于AI的知识库）  
实现上参考了[link](https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1)  
其他讲解可以参考[blog](https://ifwind.github.io/)以及[CSDN](https://blog.csdn.net/zhaohongfei_358/article/details/126019181)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import math
import numpy as np

import random

### 单词表达 
- one hot representation ： 每个单词都是一个维度，如果一共有1000个单词， 那么每个单词都是`[0,0,...,1,0,...,0]`  
- distributed representation : 每个单词具有多个维度（分类）=》`word embedding`即为找到一个映射f，从输入空间到分类好的空间。最广泛的词嵌入方法有`word2vec`   

### Positional Encoding
- 传统的RNN：以序列形式逐个处理句子中的词，因此词语的顺序信息不会丢失  
- transformer：词语同时进入网络，因此顺序信息丢失，需要额外处理  
- 处理方式：特殊正余弦函数的映射，其中`max_len`是能处理的最长序列长度

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self,dim_model,dropout_p,max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p) # used for regularization and preventing the co-adaptation of neurons
        
        # Encoding from fomula
        pos_encoding = torch.zeros(max_len,dim_model)
        positions_list = torch.arange(0,max_len,dtype=torch.float).view(-1,1)
        division_term = torch.exp(torch.arange(0,dim_model,2).float() * (-math.log(10000.0)) / dim_model)# 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

### 关于 Mask 机制
参考[link](https://ifwind.github.io/2021/08/17/Transformer%E7%9B%B8%E5%85%B3%E2%80%94%E2%80%94%EF%BC%887%EF%BC%89Mask%E6%9C%BA%E5%88%B6/#self-attention%E4%B8%AD%E7%9A%84padding-mask)  

- sequence mask  
简单来说，由于训练的时候会把`target`整个直接喂进去，但训练的要求是一个一个给，因此将“答案”部分遮盖，防止模型作弊。参考[link](https://blog.csdn.net/zhaohongfei_358/article/details/125858248)  
- padding mask  
对不定长文本补齐

In [3]:
class Transformer(nn.Module):
    def __init__(
        self,
        num_tokens,# size of the dictionary of embeddings
        dim_model, # the number of expected features in the encoder/decoder inputs
        num_heads, # the number of heads in the multiheadattention models
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):

        super(Transformer,self).__init__()
        self.model_type = "Transformer"
        self.dim_model = dim_model
        
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        
        self.embedding = nn.Embedding(num_tokens,dim_model)
        
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        
        self.out = nn.Linear(dim_model,num_tokens)
        
    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)
        
        src = self.embedding(src) * math.sqrt(self.dim_model) ### why times sqrt??
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)
        
        # here my pytorch version(1.8.1) do not have batch_first parameter :(
        src = src.permute(1,0,2)
        tgt = tgt.permute(1,0,2)
        
        out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        
        return out
    
    def get_tgt_mask(self, size) -> torch.tensor:
        # Generates a squeare matrix where the each row allows one word more to be seen
        mask = torch.tril(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
        
        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]
        
        return mask
    
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

下面的创造数据集的部分没有细看，直接复制。

In [4]:
def generate_random_data(n):
    SOS_token = np.array([2])
    EOS_token = np.array([3])
    length = 8

    data = []

    # 1,1,1,1,1,1 -> 1,1,1,1,1
    for i in range(n // 3):
        X = np.concatenate((SOS_token, np.ones(length), EOS_token))
        y = np.concatenate((SOS_token, np.ones(length), EOS_token))
        data.append([X, y])

    # 0,0,0,0 -> 0,0,0,0
    for i in range(n // 3):
        X = np.concatenate((SOS_token, np.zeros(length), EOS_token))
        y = np.concatenate((SOS_token, np.zeros(length), EOS_token))
        data.append([X, y])

    # 1,0,1,0 -> 1,0,1,0,1
    for i in range(n // 3):
        X = np.zeros(length)
        start = random.randint(0, 1)

        X[start::2] = 1

        y = np.zeros(length)
        if X[-1] == 0:
            y[::2] = 1
        else:
            y[1::2] = 1

        X = np.concatenate((SOS_token, X, EOS_token))
        y = np.concatenate((SOS_token, y, EOS_token))

        data.append([X, y])

    np.random.shuffle(data)

    return data


def batchify_data(data, batch_size=16, padding=False, padding_token=-1):
    batches = []
    for idx in range(0, len(data), batch_size):
        # We make sure we dont get the last bit if its not batch_size size
        if idx + batch_size < len(data):
            # Here you would need to get the max length of the batch,
            # and normalize the length with the PAD token.
            if padding:
                max_batch_length = 0

                # Get longest sentence in batch
                for seq in data[idx : idx + batch_size]:
                    if len(seq) > max_batch_length:
                        max_batch_length = len(seq)

                # Append X padding tokens until it reaches the max length
                for seq_idx in range(batch_size):
                    remaining_length = max_bath_length - len(data[idx + seq_idx])
                    data[idx + seq_idx] += [padding_token] * remaining_length

            batches.append(np.array(data[idx : idx + batch_size]).astype(np.int64))

    print(f"{len(batches)} batches of size {batch_size}")

    return batches


train_data = generate_random_data(9000)
val_data = generate_random_data(3000)

train_dataloader = batchify_data(train_data)
val_dataloader = batchify_data(val_data)

562 batches of size 16
187 batches of size 16


In [5]:
# Training
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(
    num_tokens=4, dim_model=8, num_heads=2, num_encoder_layers=3, num_decoder_layers=3, dropout_p=0.1
).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def train_loop(model,opt,loss_fn,dataloader):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        X, y = batch[:, 0], batch[:, 1]
        X, y = torch.tensor(X).to(device), torch.tensor(y).to(device)

        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
        y_input = y[:,:-1]
        y_expected = y[:,1:]
        
        # Get mask to mask out the next words
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        # Standard training except we pass in y_input and tgt_mask
        pred = model(X, y_input, tgt_mask)

        # Permute pred to have batch size first again
        pred = pred.permute(1, 2, 0)      
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()
        
        total_loss += loss.detach().item()
        
    return total_loss / len(dataloader)

# Validation
def validation_loop(model, loss_fn, dataloader):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch[:, 0], batch[:, 1]
            X, y = torch.tensor(X, dtype=torch.long, device=device), torch.tensor(y, dtype=torch.long, device=device)

            # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
            y_input = y[:,:-1]
            y_expected = y[:,1:]
            
            # Get mask to mask out the next words
            sequence_length = y_input.size(1)
            tgt_mask = model.get_tgt_mask(sequence_length).to(device)

            # Standard training except we pass in y_input and src_mask
            pred = model(X, y_input, tgt_mask)
            
            # Permute pred to have batch size first again
            pred = pred.permute(1, 2, 0)      
            loss = loss_fn(pred, y_expected)
            total_loss += loss.detach().item()
        
    return total_loss / len(dataloader)

def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):

    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []
    
    print("Training and validating model")
    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)
        
        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_loss_list += [train_loss]
        
        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_loss_list += [validation_loss]
        
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()
        
    return train_loss_list, validation_loss_list
    
train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_dataloader, val_dataloader, 10)


Training and validating model
------------------------- Epoch 1 -------------------------
Training loss: 0.7350
Validation loss: 0.4675

------------------------- Epoch 2 -------------------------
Training loss: 0.4769
Validation loss: 0.4173

------------------------- Epoch 3 -------------------------
Training loss: 0.4340
Validation loss: 0.3943

------------------------- Epoch 4 -------------------------
Training loss: 0.4113
Validation loss: 0.3709

------------------------- Epoch 5 -------------------------
Training loss: 0.3902
Validation loss: 0.3455

------------------------- Epoch 6 -------------------------
Training loss: 0.3703
Validation loss: 0.3232

------------------------- Epoch 7 -------------------------
Training loss: 0.3546
Validation loss: 0.3036

------------------------- Epoch 8 -------------------------
Training loss: 0.3420
Validation loss: 0.2877

------------------------- Epoch 9 -------------------------
Training loss: 0.3284
Validation loss: 0.2811

-------