# "Attention is all you need" 复现
基于pytorch库的内容，先实现transformer的基本训练、评估和测试的全流程

之后再参考他人的实现，学习如何不完全依赖 PyTorch 自带的 Transformer 模块来实现语言翻译

## 1. 数据处理代码优化 PreprocessData
数据处理过程：
- 1. load_corpus_generator, 原始语料库加载函数
    - 支持语句长度限制
- 2. TokenizerTrain class, 分词器训练类，用于训练本地语料库
    - 逻辑稍后整理
- 3. TokenizerLoader，加载训练好的分词器
    - 逻辑...
- 4. TranslationDataset class，数据集处理类
- 5. collate_fn 函数，对数据进行批处理，包括填充、堆叠等操作（当前实现不做排序）

In [None]:
import torch 
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.nn.utils.rnn import pad_sequence

# Paths below are specific to the author's machine; adjust before running elsewhere.
_PROJECT_DIR = "/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch"

config = {
    'source-file': f"{_PROJECT_DIR}/europarl-v7.de-en.en",
    'target-file': f"{_PROJECT_DIR}/europarl-v7.de-en.de",
    'source-tokenizer-file': f"{_PROJECT_DIR}/en_tokenizer.json",
    'target-tokenizer-file': f"{_PROJECT_DIR}/de_tokenizer.json",
    # Key was misspelled 'special-tokne'; renamed to match the later config cell.
    'special-tokens': ["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
    'vocab-size': 30000,
    'min-length': 5,      # minimum whitespace word count per sentence
    'max-length': 128,    # maximum word count, also used as the token padding length
    'batch-size': 64,
    'sample-ratio': 0.1,  # fraction of the dataset drawn by the sampler
    'num-workers': 4,
}

# 2. 分词训练器
class TokenizerTrain:
    """Train a BPE tokenizer on an iterable corpus and save it as JSON.

    :param vocab_size: target vocabulary size passed to the BPE trainer.
    :param special_tokens: special tokens registered with the trainer,
        e.g. ``["[PAD]", "[UNK]", "[BOS]", "[EOS]"]``.
    """

    def __init__(self, vocab_size, special_tokens):
        self.vocab_size = vocab_size
        # Fixed: previously referenced the undefined name ``sepcial_tokens``.
        self.special_tokens = special_tokens

    def train_and_save(self, corpus, output_path, language_name):
        """Train on ``corpus`` and write ``{output_path}/{language_name}_tokenizer.json``.

        :param corpus: iterable of sentences (strings).
        :param output_path: directory the JSON file is written to.
        :param language_name: file-name prefix, e.g. ``"en"``.
        :return: path of the saved tokenizer file.
        """
        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        # Fixed: ``BPE.Trainer`` does not exist; the trainer class is ``BpeTrainer``.
        trainer = BpeTrainer(special_tokens=self.special_tokens, vocab_size=self.vocab_size)
        tokenizer.train_from_iterator(corpus, trainer=trainer)
        save_path = f"{output_path}/{language_name}_tokenizer.json"
        tokenizer.save(save_path)
        return save_path

# 4. 数据集类
class TranslationDataset(Dataset):
    """Parallel-corpus dataset yielding tokenized (source, target) sentence pairs.

    :param src_lines_path: source-language text file, one sentence per line.
    :param tgt_lines_path: target-language text file, line-aligned with the source file.
    :param src_transformer_file: path to the source tokenizer JSON.
    :param tgt_transformer_file: path to the target tokenizer JSON.
    :param max_length: token length every encoded sequence is padded/truncated to.

    The word-count filter bounds are read from the module-level ``config``
    (``'min-length'`` / ``'max-length'``).
    """

    def __init__(self, src_lines_path, tgt_lines_path, src_transformer_file, tgt_transformer_file, max_length):
        # Filter both files jointly, line by line. Filtering each file on its own
        # (as before) drops different line numbers on each side, so after the first
        # rejected line src_lines[i] and tgt_lines[i] are no longer translations.
        self.src_lines, self.tgt_lines = self._load_aligned_pairs(
            src_lines_path, tgt_lines_path, config['min-length'], config['max-length']
        )
        self.src_transformer_tokenizer = self.TokenizerLoader(src_transformer_file)
        self.tgt_transformer_tokenizer = self.TokenizerLoader(tgt_transformer_file)
        self.max_length = max_length

    def __len__(self):
        return len(self.src_lines)

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` for the source and ``labels`` for the target.

        NOTE(review): BOS/EOS are only present if the tokenizer JSON defines a
        post-processor that adds them; labels are not shifted here.
        """
        src_encoding = self._encode(self.src_transformer_tokenizer, self.src_lines[idx])
        tgt_encoding = self._encode(self.tgt_transformer_tokenizer, self.tgt_lines[idx])
        return {
            "input_ids": src_encoding['input_ids'].squeeze(0),
            "attention_mask": src_encoding['attention_mask'].squeeze(0),
            "labels": tgt_encoding['input_ids'].squeeze(0),
        }

    def _encode(self, tokenizer, text):
        """Tokenize one sentence, padded/truncated to ``self.max_length`` (batch dim of 1)."""
        return tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    @staticmethod
    def _within_bounds(line, min_length, max_length):
        """True if ``line`` is non-empty and its whitespace word count lies in [min, max]."""
        return bool(line) and min_length <= len(line.split()) <= max_length

    def _load_aligned_pairs(self, src_path, tgt_path, min_length, max_length):
        """Read two line-aligned files and keep only pairs where BOTH sides pass the filter.

        Reading stops at the end of the shorter file.
        :return: (src_lines, tgt_lines), equal-length lists of stripped sentences.
        """
        src_lines, tgt_lines = [], []
        with open(src_path, 'r', encoding='utf-8') as src_f, \
                open(tgt_path, 'r', encoding='utf-8') as tgt_f:
            for src_line, tgt_line in zip(src_f, tgt_f):
                src_line, tgt_line = src_line.strip(), tgt_line.strip()
                if (self._within_bounds(src_line, min_length, max_length)
                        and self._within_bounds(tgt_line, min_length, max_length)):
                    src_lines.append(src_line)
                    tgt_lines.append(tgt_line)
        return src_lines, tgt_lines

    def load_corpus_generator(self, file_path, min_length=5, max_length=128):
        """Yield stripped, non-empty lines of one file whose word count is within bounds."""
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if self._within_bounds(line, min_length, max_length):
                    yield line

    def TokenizerLoader(self, tokenizer_path):
        """Load a ``tokenizers`` JSON file and wrap it as a ``PreTrainedTokenizerFast``."""
        tokenizer = Tokenizer.from_file(tokenizer_path)
        return PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token="[BOS]",
            eos_token="[EOS]",
            pad_token="[PAD]",
            unk_token="[UNK]",
        )

# Custom collate_fn for the DataLoader.
def collate_fn(batch):
    """Stack a list of sample dicts into batch-first tensors.

    Every field is right-padded to the longest sample in the batch (a no-op when
    all samples already share one length). ``input_ids`` and ``attention_mask``
    are padded with 0; ``labels`` with -100, a commonly used "ignore" index.
    """
    def _pad(key, fill):
        column = [sample[key] for sample in batch]
        return pad_sequence(column, batch_first=True, padding_value=fill)

    return {
        'input_ids': _pad('input_ids', 0),
        'attention_mask': _pad('attention_mask', 0),
        'labels': _pad('labels', -100),
    }




# Smoke test: build the dataset and a DataLoader over a random subset of it.
translation_dataset = TranslationDataset(
    src_lines_path=config['source-file'],
    tgt_lines_path=config['target-file'],
    src_transformer_file=config['source-tokenizer-file'],
    tgt_transformer_file=config['target-tokenizer-file'],
    max_length=config['max-length'],
)

# Draw config["sample-ratio"] of all indices, without replacement.
indices = np.random.choice(
    len(translation_dataset),
    size=int(len(translation_dataset) * config["sample-ratio"]),
    replace=False,
)
sampler = SubsetRandomSampler(indices)

sampled_loader = DataLoader(
    translation_dataset,
    batch_size=config['batch-size'],
    sampler=sampler,
    num_workers=config['num-workers'],
    collate_fn=collate_fn,
)

## 经过通义优化后的代码

In [26]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, trainers, pre_tokenizers
from torch.nn.utils.rnn import pad_sequence


class TokenizerTrainer:
    """Train/save BPE tokenizers and build the tab-separated parallel-corpus file.

    :param vocab_size: target vocabulary size for the BPE trainer.
    :param special_tokens: special tokens; index 1 is used as the unknown token,
        so the list is expected in the ``["[PAD]", "[UNK]", ...]`` order.
    """

    def __init__(self, vocab_size, special_tokens):
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens

    def train_and_save(self, corpus_generator, output_path, language_name):
        """Train a BPE tokenizer and save it as ``{output_path}/{language_name}_tokenizer.json``.

        :param corpus_generator: iterable yielding one sentence per item.
        :param output_path: output directory.
        :param language_name: language name used in the file name.
        :return: path of the saved tokenizer file.
        """
        tokenizer = Tokenizer(models.BPE(unk_token=self.special_tokens[1]))
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()  # split on whitespace/punctuation first
        trainer = trainers.BpeTrainer(special_tokens=self.special_tokens, vocab_size=self.vocab_size)
        tokenizer.train_from_iterator(corpus_generator, trainer=trainer)
        save_path = f"{output_path}/{language_name}_tokenizer.json"
        tokenizer.save(save_path)
        return save_path

    def create_pair_file(self, source_file, target_file, pair_file):
        """Write ``source<TAB>target`` lines for every pair where both sides are non-empty.

        Reading stops at the end of the shorter input file.
        :param source_file: source-language file path.
        :param target_file: target-language file path (line-aligned with the source).
        :param pair_file: output pair-file path.
        :return: number of pairs written.
        """
        written = 0
        with open(source_file, 'r', encoding='utf-8') as src_f, \
                open(target_file, 'r', encoding='utf-8') as tgt_f, \
                open(pair_file, 'w', encoding='utf-8') as pair_f:
            for src_line, tgt_line in zip(src_f, tgt_f):
                # A tab inside a sentence would create a third column, and the
                # 2-column reader would later drop the pair; fold tabs to spaces.
                src_line = src_line.strip().replace('\t', ' ')
                tgt_line = tgt_line.strip().replace('\t', ' ')
                if src_line and tgt_line:
                    pair_f.write(f"{src_line}\t{tgt_line}\n")
                    written += 1
        return written


class TranslationDataset(Dataset):
    """Dataset of (source, target) sentence pairs read from a tab-separated pair file.

    Required ``config`` keys: ``'pair-file'``, ``'min-length'``, ``'max-length'``,
    ``'source-tokenizer-file'``, ``'target-tokenizer-file'``.
    ``'max-length'`` has two roles: the maximum word count used when filtering
    pairs, and the token length every encoded sequence is padded/truncated to.
    """

    def __init__(self, config):
        """Load and length-filter the sentence pairs, then load both tokenizers.

        :param config: configuration dict with the keys listed on the class.
        """
        self.src_lines, self.tgt_lines = self.load_pairs(
            config['pair-file'], config['min-length'], config['max-length']
        )
        assert len(self.src_lines) == len(self.tgt_lines), "源语言和目标语言的句子数量不匹配！"

        self.src_tokenizer, self.tgt_tokenizer = (
            self.load_tokenizer(config['source-tokenizer-file']),
            self.load_tokenizer(config['target-tokenizer-file']),
        )
        self.max_length = config['max-length']

    def __len__(self):
        return len(self.src_lines)

    def _encode(self, tokenizer, text):
        """Tokenize one sentence to a fixed ``self.max_length`` (returns a batch of 1)."""
        return tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return source ids/mask and target ids (as ``labels``) for pair ``idx``."""
        src = self._encode(self.src_tokenizer, self.src_lines[idx])
        tgt = self._encode(self.tgt_tokenizer, self.tgt_lines[idx])
        return {
            "input_ids": src['input_ids'].squeeze(0),
            "attention_mask": src['attention_mask'].squeeze(0),
            "labels": tgt['input_ids'].squeeze(0),
        }

    def load_corpus(self, file_path, min_length, max_length):
        """Yield stripped, non-empty lines whose word count lies in [min_length, max_length].

        :param file_path: text file, one sentence per line.
        """
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw in handle:
                text = raw.strip()
                if not text:
                    continue
                if min_length <= len(text.split()) <= max_length:
                    yield text

    def load_tokenizer(self, tokenizer_path):
        """Load a ``tokenizers`` JSON file wrapped as a ``PreTrainedTokenizerFast``."""
        special = dict(bos_token="[BOS]", eos_token="[EOS]", pad_token="[PAD]", unk_token="[UNK]")
        return PreTrainedTokenizerFast(tokenizer_object=Tokenizer.from_file(tokenizer_path), **special)

    def load_pairs(self, pair_file, min_length, max_length):
        """Read ``source<TAB>target`` lines and keep pairs where both sides pass the length filter.

        Lines that do not split into exactly two tab-separated fields are skipped.
        :return: (src_lines, tgt_lines), equal-length lists.
        """
        def within_bounds(sentence):
            return min_length <= len(sentence.split()) <= max_length

        src_lines, tgt_lines = [], []
        with open(pair_file, 'r', encoding='utf-8') as handle:
            for raw in handle:
                fields = raw.strip().split('\t')
                if len(fields) == 2 and within_bounds(fields[0]) and within_bounds(fields[1]):
                    src_lines.append(fields[0])
                    tgt_lines.append(fields[1])
        return src_lines, tgt_lines

def collate_fn(batch):
    """Pad and stack sample dicts into batch-first tensors.

    Padding values: 0 for ``input_ids`` and ``attention_mask``; -100 (a commonly
    used ignore index) for ``labels``. Samples that already share one length are
    simply stacked.
    """
    pad_values = (('input_ids', 0), ('attention_mask', 0), ('labels', -100))
    return {
        key: pad_sequence([item[key] for item in batch], batch_first=True, padding_value=value)
        for key, value in pad_values
    }

# Configuration dict. Paths are specific to the author's machine.
_PROJECT_DIR = "/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch"

config = {
    'pair-file': f"{_PROJECT_DIR}/en-de-pair.txt",
    'source-file': f"{_PROJECT_DIR}/europarl-v7.de-en.en",
    'target-file': f"{_PROJECT_DIR}/europarl-v7.de-en.de",
    'source-tokenizer-file': f"{_PROJECT_DIR}/en_tokenizer.json",
    'target-tokenizer-file': f"{_PROJECT_DIR}/de_tokenizer.json",
    'special-tokens': ["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
    'vocab-size': 30000,
    'min-length': 5,
    'max-length': 128,
    'batch-size': 64,
    'sample-ratio': 0.1,
    'num-workers': 4,
}

# Smoke test: build the pair file, the dataset and a sampled DataLoader.
if __name__ == "__main__":
    tokenizerTrainer = TokenizerTrainer(config['vocab-size'], config['special-tokens'])
    # Fixed: write the pair file to config['pair-file'], the path that
    # TranslationDataset reads. The previous relative 'en-de-pair.txt' only
    # matched when the working directory happened to be the project directory.
    tokenizerTrainer.create_pair_file(
        source_file=config['source-file'],
        target_file=config['target-file'],
        pair_file=config['pair-file'],
    )

    translation_dataset = TranslationDataset(config)

    # Sample config["sample-ratio"] of the dataset without replacement.
    indices = np.random.choice(len(translation_dataset), int(len(translation_dataset) * config["sample-ratio"]), replace=False)
    sampler = SubsetRandomSampler(indices)

    sampled_loader = DataLoader(
        translation_dataset,
        batch_size=config['batch-size'],
        sampler=sampler,
        num_workers=config['num-workers'],
        collate_fn=collate_fn
    )

In [27]:
# Sanity check: print only the first batch (ids, mask, labels) and stop iterating.
for batch in sampled_loader:
    print(batch)
    break

{'input_ids': tensor([[  11, 1934,   12,  ...,    0,    0,    0],
        [  44,  465,  842,  ...,    0,    0,    0],
        [ 838,  317,  323,  ...,    0,    0,    0],
        ...,
        [1100,  340,  317,  ...,    0,    0,    0],
        [1754,  422, 2414,  ...,    0,    0,    0],
        [ 584,  317,   67,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[ 6744,   766, 13879,  ...,     0,     0,     0],
        [ 2726,  1757,  1055,  ...,     0,     0,     0],
        [  622,   371,   377,  ...,     0,     0,     0],
        ...,
        [ 1018,   556,  7595,  ...,     0,     0,     0],
        [  557,   416,  3325,  ...,     0,     0,     0],
        [ 1005,   371, 29905,  ...,     0,     0,     0]])}


## 模型基础模块

来源自：https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer

其中部分代码直接复制，部分做了较大的修改和优化

对 torch 张量之间的运算（如矩阵乘法、转置、广播）还不熟悉

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / temperature) @ v.

    Shapes (as produced by MultiHeadAttention in this file):
        q:    (batch_size, num_heads, seq_len_q, d_k)
        k:    (batch_size, num_heads, seq_len_k, d_k)
        v:    (batch_size, num_heads, seq_len_k, d_v)
        mask: broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k);
              positions where mask == 0 receive (numerically) zero attention.
    Returns:
        output: (batch_size, num_heads, seq_len_q, d_v)
        attn:   (batch_size, num_heads, seq_len_q, seq_len_k), weights after dropout.

    :param temperature: divisor applied to q, typically sqrt(d_k).
    :param attn_dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # k.transpose(2, 3): (batch, heads, d_k, seq_len_k) -> scores (batch, heads, seq_len_q, seq_len_k).
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # Use the most negative finite value of the scores' dtype instead of a
            # hard-coded -1e9, which is out of range for float16 and makes
            # masked_fill raise under mixed precision.
            attn = attn.masked_fill(mask == 0, torch.finfo(attn.dtype).min)

        attn = self.dropout(F.softmax(attn, dim=-1))
        # (batch, heads, seq_len_q, seq_len_k) @ (batch, heads, seq_len_k, d_v)
        output = torch.matmul(attn, v)
        return output, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention sub-layer with a residual connection and post-LayerNorm.

    :param n_head: number of attention heads.
    :param d_model: input/output feature size (usually the embedding size).
    :param d_k: per-head query/key size (commonly d_model // n_head).
    :param d_v: per-head value size.
    :param dropout: dropout applied after the output projection. The attention
        weights use ScaledDotProductAttention's own default dropout.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        # Fixed: previously read the undefined name ``d_K``.
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        # Projects the concatenated heads back to d_model.
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        :param q: (batch, len_q, d_model); ``k``: (batch, len_k, d_model); ``v``: (batch, len_v, d_model).
        :param mask: optional, broadcastable to (batch, len_q, len_k); 0 marks masked positions.
        :return: (output of shape (batch, len_q, d_model),
                  attention weights of shape (batch, n_head, len_q, len_k)).
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Project, then split heads: b x len x n_head x d
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Move heads before the sequence axis: b x n_head x len x d
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head axis for broadcasting
        q, attn = self.attention(q, k, v, mask=mask)

        # Back to b x len_q x n_head x d_v, then concatenate heads: b x len_q x (n_head * d_v)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q = self.layer_norm(q + residual)

        return q, attn

# Position-wise feed-forward sub-layer (applied identically at every position).
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP with a residual connection and post-LayerNorm.

    Structure: Linear(d_in -> d_hid) -> ReLU -> Linear(d_hid -> d_in) -> Dropout,
    then ``LayerNorm(output + input)``.

    :param d_in: input/output feature size (d_model).
    :param d_hid: hidden size of the inner layer (d_inner).
    :param dropout: dropout probability on the second layer's output.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # Registration order kept (w_1, w_2, layer_norm, dropout) so parameter order is unchanged.
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        projected = self.dropout(self.w_2(hidden))
        return self.layer_norm(projected + x)

# Encoder layer: self-attention sub-layer followed by a feed-forward sub-layer.
class EncoderLayer(nn.Module):
    """One Transformer encoder layer: MultiHeadAttention (self) + PositionwiseFeedForward.

    Each sub-layer in this file applies its own residual connection and LayerNorm.
    ``forward`` returns (layer output, self-attention weights).
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, self_attn_mask=None):
        # Queries, keys and values all come from the same input sequence.
        attended, attn_weights = self.self_attention(
            enc_input, enc_input, enc_input, mask=self_attn_mask
        )
        return self.feed_forward(attended), attn_weights

# Decoder layer: masked self-attention -> encoder-decoder (cross) attention -> feed-forward.
class DecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Sub-layers, in order (each adds its own residual connection and LayerNorm):
      1. self-attention over the decoder input, restricted by ``slf_attn_mask``;
      2. cross-attention: queries from step 1, keys/values from ``enc_output``,
         restricted by ``dec_enc_attn_mask``;
      3. position-wise feed-forward.

    ``forward`` returns (output, self-attention weights, cross-attention weights).
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        hidden, self_weights = self.slf_attn(dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # Each target position (query) looks up relevant source positions:
        # the encoder output supplies both keys and values.
        hidden, cross_weights = self.enc_attn(hidden, enc_output, enc_output, mask=dec_enc_attn_mask)
        return self.feed_forward(hidden), self_weights, cross_weights

# ---- Model assembly: mask helpers, positional encoding, encoder/decoder ----

def get_pad_mask(seq, pad_idx):
    """Return a boolean mask that is True at non-padding positions.

    For ``seq`` of shape (batch, seq_len) the result has shape (batch, 1, seq_len),
    so it broadcasts over the query axis in attention.
    """
    not_pad = seq.ne(pad_idx)
    return not_pad.unsqueeze(-2)

# Causal ("no peeking ahead") mask; e.g. for length 3:
#   [[T, F, F],
#    [T, T, F],
#    [T, T, T]]
def get_subsequent_mask(seq):
    """Return a (1, seq_len, seq_len) bool mask, True where key position <= query position.

    :param seq: 2-D tensor (batch, seq_len); only its length and device are used.
    """
    _, length = seq.size()
    ones = torch.ones((1, length, length), device=seq.device)
    return torch.tril(ones).bool()

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq_len, d_hid) input.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_hid))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_hid))

    :param d_in: embedding size; also used as ``d_hid`` when that is omitted.
    :param d_hid: encoding size (must equal the input's last dimension).
    :param n_position: maximum supported sequence length.
    """

    def __init__(self, d_in, d_hid=None, n_position=200):
        super().__init__()
        # Encoder/Decoder call PositionalEncoding(d_word_vec, n_position=...) without
        # d_hid; previously that raised TypeError. Default d_hid to d_in.
        if d_hid is None:
            d_hid = d_in
        # A buffer (not a parameter): saved in state_dict, follows .to(device), never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)  # [n_position, 1]
        div_term = torch.pow(10000.0, torch.arange(0, d_hid, 2, dtype=torch.float32) / d_hid)
        angle_rads = position / div_term  # [n_position, ceil(d_hid / 2)]

        sinusoid_table = torch.zeros((n_position, d_hid))
        sinusoid_table[:, 0::2] = torch.sin(angle_rads)  # even dimensions
        # Odd dimensions: for odd d_hid there is one fewer odd column than even column.
        sinusoid_table[:, 1::2] = torch.cos(angle_rads[:, : d_hid // 2])

        return sinusoid_table.unsqueeze(0)  # [1, n_position, d_hid]

    def forward(self, x):
        # pos_table does not require grad, so a slice suffices (no clone/detach needed).
        return x + self.pos_table[:, : x.size(1)]

class Encoder(nn.Module):
    """Transformer encoder.

    Pipeline: token embedding -> optional sqrt(d_model) scaling -> sinusoidal
    positional encoding -> dropout -> LayerNorm -> ``n_layers`` EncoderLayers.

    ``forward`` returns the final hidden states, or (hidden states, list of
    per-layer self-attention weights) when ``return_attn=True``.
    """

    def __init__(self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v, d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False):
        super().__init__()

        # padding_idx: the PAD row of the embedding does not contribute to the gradient.
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        # d_hid passed explicitly so this works with PositionalEncoding(d_in, d_hid, n_position).
        self.position_enc = PositionalEncoding(d_word_vec, d_word_vec, n_position=n_position)
        # Fixed: was stored as a bare float and then called like a module.
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attn=False):
        enc_slf_attn_list = []

        # Fixed typo: the embedding attribute is ``src_word_emb``.
        enc_output = self.src_word_emb(src_seq)
        if self.scale_emb:
            # Multiply embeddings by sqrt(d_model); chosen by Transformer(scale_emb_or_prj='emb').
            enc_output = enc_output * self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        # NOTE(review): LayerNorm on the embeddings is not in the original paper's
        # post-LN layout; it appears to follow the referenced repository — confirm.
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, self_attn_mask=src_mask)
            # Fixed: previously appended to the undefined name ``enc_self_attn_list``.
            if return_attn:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attn:
            return enc_output, enc_slf_attn_list
        return enc_output


class Decoder(nn.Module):
    """Transformer decoder.

    Pipeline: target embedding -> optional sqrt(d_model) scaling -> sinusoidal
    positional encoding -> dropout -> LayerNorm -> ``n_layers`` DecoderLayers,
    each attending to ``enc_output``.

    ``forward`` returns the final hidden states, or (hidden states, self-attention
    list, cross-attention list) when ``return_attn=True``.

    :param n_head: number of attention heads. It was previously used but never
        accepted, so construction failed. It is added last with a default to keep
        positional calls working.
    """

    def __init__(self, n_tgt_vocab, d_word_vec, n_layers, d_k, d_v, d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False, n_head=8):
        super().__init__()
        self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec, padding_idx=pad_idx)
        # d_hid passed explicitly so this works with PositionalEncoding(d_in, d_hid, n_position).
        self.position_enc = PositionalEncoding(d_word_vec, d_word_vec, n_position=n_position)
        # Fixed: was stored as a bare float and then called like a module.
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, tgt_seq, tgt_mask, enc_output, src_mask, return_attn=False):
        dec_slf_attn_list = []
        dec_enc_attn_list = []

        dec_output = self.tgt_word_emb(tgt_seq)
        # Fixed typo: the attribute is ``scale_emb`` (was ``scale_emd``).
        if self.scale_emb:
            dec_output = dec_output * self.d_model ** 0.5
        dec_output = self.dropout(self.position_enc(dec_output))
        dec_output = self.layer_norm(dec_output)

        for dec_layer in self.layer_stack:
            # tgt_mask restricts self-attention (padding + causal); src_mask hides source padding.
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(dec_output, enc_output, slf_attn_mask=tgt_mask, dec_enc_attn_mask=src_mask)
            if return_attn:
                dec_slf_attn_list.append(dec_slf_attn)
                dec_enc_attn_list.append(dec_enc_attn)

        if return_attn:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output
    
class Transformer(nn.Module):
    """Encoder-decoder Transformer returning flattened target-vocabulary logits.

    Weight-sharing options:
      * ``tgt_emb_prj_weight_sharing``: the output projection ``tgt_word_prj``
        (d_model -> n_tgt_vocab, weight shape (n_tgt_vocab, d_model)) reuses the
        decoder embedding matrix (n_tgt_vocab, d_word_vec). The shapes match because
        d_model == d_word_vec is asserted. One matrix then maps ids -> vectors and
        vectors -> logits.
      * ``emb_src_tgt_weight_sharing``: the encoder embedding reuses the decoder
        embedding matrix as well.
      * ``scale_emb_or_prj``: takes effect only when the projection is shared.
        'emb' multiplies embeddings by sqrt(d_model), 'prj' multiplies logits by
        d_model ** -0.5, and 'none' does neither.
    """

    def __init__(
        self,
        n_src_vocab,
        n_tgt_vocab,
        src_pad_idx,
        tgt_pad_idx,
        d_word_vec=512,
        d_model=512,
        d_inner=2048,
        n_layers=6,
        n_head=8,
        d_k=64,
        d_v=64,
        dropout=0.1,
        n_position=200,
        tgt_emb_prj_weight_sharing=True,
        emb_src_tgt_weight_sharing=True,
        scale_emb_or_prj='prj'):

        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        assert scale_emb_or_prj in ['emb', 'prj', 'none']
        scale_emb = (scale_emb_or_prj == 'emb') if tgt_emb_prj_weight_sharing else False
        # Fixed: previously referenced the undefined name ``trg_emb_prj_weight_sharing``.
        self.scale_prj = (scale_emb_or_prj == 'prj') if tgt_emb_prj_weight_sharing else False
        self.d_model = d_model

        self.encoder = Encoder(
            n_src_vocab=n_src_vocab,
            n_position=n_position,
            d_word_vec=d_word_vec,
            d_model=d_model,
            d_inner=d_inner,
            n_layers=n_layers,
            n_head=n_head,
            d_k=d_k,
            d_v=d_v,
            pad_idx=src_pad_idx,
            dropout=dropout,
            scale_emb=scale_emb
        )

        self.decoder = Decoder(
            n_tgt_vocab=n_tgt_vocab,
            n_position=n_position,
            d_word_vec=d_word_vec,
            d_model=d_model,
            d_inner=d_inner,
            n_layers=n_layers,
            n_head=n_head,
            d_k=d_k,
            d_v=d_v,
            pad_idx=tgt_pad_idx,
            dropout=dropout,
            scale_emb=scale_emb
        )

        # Maps decoder outputs to target-vocabulary logits.
        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)

        # Xavier-initialise every weight matrix (biases / LayerNorm vectors are 1-D and skipped).
        for p in self.parameters():
            # Fixed: ``p.dim`` is a method; ``p.dim > 1`` raised TypeError.
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        assert d_model == d_word_vec, \
        'To facilitate the residual connections, \
         the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight between target word embedding & last dense layer.
            self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight

        if emb_src_tgt_weight_sharing:
            # NOTE(review): only meaningful when source and target ids come from one
            # shared vocabulary; if n_src_vocab > n_tgt_vocab, large source ids fall
            # outside the shared matrix.
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, tgt_seq):
        """Return logits of shape (batch * tgt_len, n_tgt_vocab).

        :param src_seq: source ids, (batch, src_len).
        :param tgt_seq: decoder input ids, (batch, tgt_len).
        """
        src_mask = get_pad_mask(src_seq, self.src_pad_idx)  # (batch, 1, src_len)
        # Padding mask & causal mask -> (batch, tgt_len, tgt_len).
        tgt_mask = get_pad_mask(tgt_seq, self.tgt_pad_idx) & get_subsequent_mask(tgt_seq)

        # Fixed: Encoder/Decoder return a bare tensor when return_attn=False, so the
        # previous ``enc_output, *_ = ...`` unpacked that tensor along the batch axis.
        enc_output = self.encoder(src_seq, src_mask)
        dec_output = self.decoder(tgt_seq, tgt_mask, enc_output, src_mask)

        seq_logit = self.tgt_word_prj(dec_output)

        if self.scale_prj:
            seq_logit = seq_logit * self.d_model ** -0.5

        # Flatten to (batch * tgt_len, vocab) for token-level loss computation.
        return seq_logit.view(-1, seq_logit.size(2))


### 问题
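90%的内容都理解了，但对于投影层权重共享（weight sharing）之类的部分还不太理解。（权重共享指：输出投影层 `tgt_word_prj` 直接复用目标语言词嵌入的权重矩阵，两者形状均为 [n_tgt_vocab, d_model]，因此要求 d_model == d_word_vec；源/目标嵌入共享同理。）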

## 优化模块


In [35]:
import numpy as np

class ScheduledOptim:
    """Wrap an optimizer with the warm-up / inverse-square-root learning-rate schedule.

        lr(step) = lr_mul * d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)

    The rate grows linearly for ``n_warmup_steps`` steps, then decays ~ 1/sqrt(step).
    """

    def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.lr_mul = lr_mul
        self.d_model = d_model
        # Fixed: previously read the undefined name ``n_warm_steps``.
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0  # scheduled steps taken so far

    def step_and_update_lr(self):
        """Set the learning rate for the next step, then step the inner optimizer."""
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        """Zero out the gradients with the inner optimizer."""
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        d_model = self.d_model
        n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
        # n_steps >= 1 here because _update_learning_rate increments it first.
        return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))

    def _update_learning_rate(self):
        """Advance the step counter and write the new lr into every param group."""
        self.n_steps += 1
        lr = self.lr_mul * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

## 训练函数模块

In [None]:
import argparse
import math
import time
import dill as pickle
from tqdm import tqdm
import numpy as np
import random
import os

import torch
import torch.nn.functional as F
import torch.optim as optim

def call_loss(pred, gold, tgt_pad_idx, smoothing=False):
    """Token-level cross-entropy summed over non-padding targets.

    :param pred: logits of shape (N, vocab_size), N = batch_size * seq_len.
    :param gold: target ids with N elements in total (flattened here).
    :param tgt_pad_idx: target id excluded from the loss.
    :param smoothing: if True, apply label smoothing with eps = 0.1: the gold class
        gets 1 - eps and every other class gets eps / (vocab_size - 1).
    :return: scalar tensor holding the summed (not averaged) loss.
    """
    gold = gold.contiguous().view(-1)
    if not smoothing:
        return F.cross_entropy(pred, gold, ignore_index=tgt_pad_idx, reduction='sum')

    eps = 0.1
    n_class = pred.size(1)
    # Smoothed target distribution built from a one-hot encoding of gold.
    target_dist = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    target_dist = target_dist * (1 - eps) + (1 - target_dist) * eps / (n_class - 1)
    log_prb = F.log_softmax(pred, dim=1)

    per_token = -(target_dist * log_prb).sum(dim=1)
    return per_token.masked_select(gold.ne(tgt_pad_idx)).sum()

def call_performance(pred, gold, tgt_pad_idx, smoothing=False):
    """Compute the loss plus token-accuracy statistics, ignoring padding targets.

    :param pred: logits of shape (N, vocab_size).
    :param gold: target ids with N elements.
    :param tgt_pad_idx: padding id excluded from loss and accuracy.
    :param smoothing: forwarded to ``call_loss`` (label smoothing).
    :return: (loss tensor, number of correct non-pad tokens, number of non-pad tokens).
    """
    loss = call_loss(pred, gold, tgt_pad_idx, smoothing=smoothing)

    _, predicted = pred.max(1)
    flat_gold = gold.contiguous().view(-1)
    keep = flat_gold.ne(tgt_pad_idx)
    n_correct = predicted.eq(flat_gold).masked_select(keep).sum().item()
    n_word = keep.sum().item()
    return loss, n_correct, n_word

def patch_src(src, pad_idx):
    """Swap the first two axes of ``src``: (seq_len, batch) -> (batch, seq_len).

    ``pad_idx`` is unused; it is kept for symmetry with ``patch_tgt``.
    NOTE(review): the DataLoader built earlier in this notebook yields batch-first
    dict batches, for which this transpose would be wrong — confirm the layout.
    """
    return src.transpose(1, 0)

# Teacher forcing: the decoder is fed tokens 0..T-2 and trained to predict tokens 1..T-1.
def patch_tgt(tgt, pad_idx):
    """Split a target batch into decoder input and next-token gold labels.

    :param tgt: target ids, transposed here from (seq_len, batch) to (batch, seq_len).
    :param pad_idx: unused; kept for signature symmetry with ``patch_src``.
    :return: (tgt, gold)
        tgt  -- decoder input without the last token, shape (batch, seq_len - 1);
        gold -- the sequence starting from the second token, flattened to
                shape (batch * (seq_len - 1),).
    """
    tgt = tgt.transpose(0, 1)
    # Fixed: gold previously used tgt[:, :-1], i.e. the decoder input itself, so the
    # model was trained to copy its input instead of predicting the next token.
    tgt, gold = tgt[:, :-1], tgt[:, 1:].contiguous().view(-1)
    return tgt, gold

def train_epoch(model, training_data, optimizer, opt, device, smoothing):
    """Run one training epoch.

    :param model: seq2seq model called as ``model(src_seq, tgt_seq)`` that returns flattened logits.
    :param training_data: iterable of batches exposing ``.src`` and ``.tgt`` attributes.
    :param optimizer: ScheduledOptim-like wrapper (``zero_grad`` / ``step_and_update_lr``).
    :param opt: options object providing ``src_pad_idx`` and ``tgt_pad_idx``.
    :param device: device the tensors are moved to.
    :param smoothing: whether to use label smoothing in the loss.
    :return: (loss per non-pad target token, token accuracy).
    """
    model.train()
    loss_sum = 0
    words_seen = 0
    words_right = 0
    for batch in tqdm(training_data, mininterval=2, desc='  - (Training)   ', leave=False):
        # Prepare data: batch-first source, shifted decoder input and gold labels.
        src_seq = patch_src(batch.src, opt.src_pad_idx).to(device)
        dec_input, gold = patch_tgt(batch.tgt, opt.tgt_pad_idx)
        tgt_seq = dec_input.to(device)
        gold = gold.to(device)

        optimizer.zero_grad()
        pred = model(src_seq, tgt_seq)

        loss, n_correct, n_word = call_performance(pred, gold, opt.tgt_pad_idx, smoothing=smoothing)
        loss.backward()
        optimizer.step_and_update_lr()

        words_seen += n_word
        words_right += n_correct
        loss_sum += loss.item()

    return loss_sum / words_seen, words_right / words_seen

def eval_epoch(model, validation_data, device, opt):
    """Evaluate ``model`` on ``validation_data`` without gradient tracking.

    :param model: seq2seq model called as ``model(src_seq, tgt_seq)``.
    :param validation_data: iterable of batches exposing ``.src`` and ``.tgt``.
    :param device: device the tensors are moved to.
    :param opt: options object providing ``src_pad_idx`` and ``tgt_pad_idx``.
    :return: (loss per non-pad target token, token accuracy); no label smoothing.
    """
    model.eval()
    total_loss, n_word_total, n_word_correct = 0, 0, 0
    desc = '  - (Validation) '
    with torch.no_grad():
        for batch in tqdm(validation_data, mininterval=2, desc=desc, leave=False):
            src_seq = patch_src(batch.src, opt.src_pad_idx).to(device)
            tgt_seq, gold = map(lambda x: x.to(device), patch_tgt(batch.tgt, opt.tgt_pad_idx))

            # Fixed: previously passed the undefined name ``trg_seq``.
            pred = model(src_seq, tgt_seq)
            loss, n_correct, n_word = call_performance(pred, gold, opt.tgt_pad_idx, smoothing=False)

            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()
    loss_per_word = total_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy

def train(model, training_data, validation_data, optimizer, device, opt):
    """Train ``model`` for ``opt.epoch`` epochs, validating, checkpointing and logging each one.

    Args:
        model: the Transformer to train.
        training_data: iterable of training batches, consumed by ``train_epoch``.
        validation_data: iterable of validation batches, consumed by ``eval_epoch``.
        optimizer: optimizer wrapper passed to ``train_epoch``. The current
            learning rate is read from ``optimizer._optimizer.param_groups[0]['lr']``.
        device: torch device the batches are moved to.
        opt: options namespace. Reads ``epoch``, ``label_smoothing``,
            ``save_mode`` ('all' or 'best'), ``output_dir`` and ``use_tb``.
            Optionally reads ``train_log_dir`` / ``valid_log_dir``, the CSV log
            file paths, which default to ``<output_dir>/train.log`` and
            ``<output_dir>/valid.log``.

    Side effects:
        Writes CSV logs and checkpoints (under ``output_dir``) and, when
        ``opt.use_tb`` is set, TensorBoard event files in ``output_dir``.
    """
    import math
    import os
    import time

    # Optional TensorBoard curves: perplexity, accuracy, learning rate.
    tb_writer = None
    if opt.use_tb:
        print("[Info] Use Tensorboard")
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter(log_dir=opt.output_dir)

    # Per-epoch CSV logs; fall back to files inside output_dir when the paths are not set on opt.
    log_train_file = getattr(opt, 'train_log_dir', None) or os.path.join(opt.output_dir, 'train.log')
    log_valid_file = getattr(opt, 'valid_log_dir', None) or os.path.join(opt.output_dir, 'valid.log')

    print('[Info] Training performance will be written to file: {} and {}'.format(log_train_file, log_valid_file))

    # Truncate both logs and write the CSV headers.
    with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
        log_tf.write('epoch,loss,ppl,accuracy\n')
        log_vf.write('epoch,loss,ppl,accuracy\n')

    def print_performance(header, ppl, accu, start_time, lr):
        """Print one line of epoch metrics plus elapsed minutes since ``start_time``."""
        print('  - {header:12} ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, lr: {lr:8.5f}, '
              'elapse: {elapse:3.3f} min'.format(
                  header=f"({header})", ppl=ppl,
                  accu=100*accu, elapse=(time.time()-start_time)/60, lr=lr))

    valid_losses = []  # validation-loss history, used by save_mode == 'best'
    try:
        for epoch_i in range(opt.epoch):
            print('[ Epoch', epoch_i, ']')

            start = time.time()
            train_loss, train_accu = train_epoch(
                model,
                training_data,
                optimizer,
                opt,
                device,
                smoothing=opt.label_smoothing
            )
            # Cap the exponent so math.exp cannot overflow on a diverged loss.
            train_ppl = math.exp(min(train_loss, 100))
            lr = optimizer._optimizer.param_groups[0]['lr']
            print_performance('Training', train_ppl, train_accu, start, lr)

            start = time.time()
            valid_loss, valid_accu = eval_epoch(
                model,
                validation_data,
                device,
                opt)
            valid_ppl = math.exp(min(valid_loss, 100))
            print_performance('Validation', valid_ppl, valid_accu, start, lr)
            valid_losses.append(valid_loss)

            checkpoint = {'epoch': epoch_i, 'settings': opt, 'model': model.state_dict()}

            if opt.save_mode == 'all':
                # One checkpoint per epoch, named by validation accuracy.
                model_name = 'model_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
                torch.save(checkpoint, os.path.join(opt.output_dir, model_name))
            elif opt.save_mode == 'best':
                model_name = 'model.chkpt'
                # Overwrite the single checkpoint only when validation loss is the lowest so far.
                if valid_loss <= min(valid_losses):
                    torch.save(checkpoint, os.path.join(opt.output_dir, model_name))
                    print('    - [Info] The checkpoint file has been updated.')

            with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=train_loss,
                    ppl=train_ppl, accu=100*train_accu))
                log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=valid_loss,
                    ppl=valid_ppl, accu=100*valid_accu))

            if tb_writer is not None:
                tb_writer.add_scalars('ppl', {'train': train_ppl, 'val': valid_ppl}, epoch_i)
                tb_writer.add_scalars('accuracy', {'train': train_accu*100, 'val': valid_accu*100}, epoch_i)
                tb_writer.add_scalar('learning_rate', lr, epoch_i)
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        if tb_writer is not None:
            tb_writer.close()

def main():
    ''' 
    Parse command-line options, build the data loaders and the Transformer, then train it.

    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -embs_share_weight -proj_share_weight -label_smoothing -output_dir output -b 256 -warmup 128000

    Raises:
        SystemExit: via ``parser.error`` when ``-output_dir`` is missing, or when
            neither ``-train_path``/``-val_path`` nor ``-data_pkl`` is given.
    '''
    import argparse
    import os
    import random

    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl', default=None)     # all-in-1 data pickle or bpe field

    parser.add_argument('-train_path', default=None)   # bpe encoded data
    parser.add_argument('-val_path', default=None)     # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup','--n_warmup_steps', type=int, default=4000)
    parser.add_argument('-lr_mul', type=float, default=2.0)
    parser.add_argument('-seed', type=int, default=None)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-scale_emb_or_prj', type=str, default='prj')

    parser.add_argument('-output_dir', type=str, default=None)
    # CSV files for per-epoch train/validation metrics (file paths despite the
    # "_dir" suffix); default to train.log / valid.log inside output_dir.
    parser.add_argument('-train_log_dir', type=str, default=None)
    parser.add_argument('-valid_log_dir', type=str, default=None)
    parser.add_argument('-use_tb', action='store_true')
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    # https://pytorch.org/docs/stable/notes/randomness.html
    # For reproducibility
    if opt.seed is not None:
        torch.manual_seed(opt.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        np.random.seed(opt.seed)
        random.seed(opt.seed)

    if not opt.output_dir:
        parser.error('No experiment result will be saved: -output_dir is required.')

    os.makedirs(opt.output_dir, exist_ok=True)
    if opt.train_log_dir is None:
        opt.train_log_dir = os.path.join(opt.output_dir, 'train.log')
    if opt.valid_log_dir is None:
        opt.valid_log_dir = os.path.join(opt.output_dir, 'valid.log')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may be not enough.\n'\
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'\
              'Using smaller batch w/o longer warmup may cause '\
              'the warmup stage ends with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    #========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        parser.error('Provide either both -train_path and -val_path, or -data_pkl.')

    # The epoch loops read `opt.tgt_pad_idx`, while the model below is built from
    # `opt.trg_*`. Mirror whichever spelling the data loaders set onto the other.
    for trg_name, tgt_name in (('trg_pad_idx', 'tgt_pad_idx'),
                               ('trg_vocab_size', 'tgt_vocab_size')):
        if hasattr(opt, tgt_name) and not hasattr(opt, trg_name):
            setattr(opt, trg_name, getattr(opt, tgt_name))
        elif hasattr(opt, trg_name) and not hasattr(opt, tgt_name):
            setattr(opt, tgt_name, getattr(opt, trg_name))

    print(opt)

    transformer = Transformer(
        opt.src_vocab_size,
        opt.trg_vocab_size,
        src_pad_idx=opt.src_pad_idx,
        trg_pad_idx=opt.trg_pad_idx,
        trg_emb_prj_weight_sharing=opt.proj_share_weight,
        emb_src_trg_weight_sharing=opt.embs_share_weight,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        dropout=opt.dropout,
        scale_emb_or_prj=opt.scale_emb_or_prj).to(device)

    # Adam hyper-parameters as in "Attention Is All You Need"; the wrapper applies the LR schedule.
    optimizer = ScheduledOptim(
        torch.optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
        opt.lr_mul, opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)



In [45]:
# Toy sanity check of call_performance: three target positions, vocabulary of 3.
# NOTE(review): the rows of `pred` look like probabilities; confirm whether
# call_performance expects raw scores (logits) instead.
pred = torch.tensor(
    [
        [0.9, 0.1, 0.0],  # position 0, gold id 0
        [0.1, 0.8, 0.1],  # position 1, gold id 1
        [0.2, 0.2, 0.6],  # position 2, gold id 2 (equal to tgt_pad_idx below)
    ]
)
gold = torch.tensor([0, 1, 2])
tgt_pad_idx, smoothing = 2, True

# The last expression is displayed by the notebook; returns (loss, n_correct, n_word).
call_performance(pred, gold, tgt_pad_idx, smoothing)

(tensor(1.4631), 2, 2)