# "Attention is all you need" 复现
基于pytorch库的内容，先实现transformer的基本训练、评估和测试的全流程

之后再尝试模仿他人，学习如何不完全借用pytorch的transformer从而实现语言翻译

## 1. 数据处理代码优化 PreprocessData
数据处理过程：
- 1. load_corpus_generator, 原始语料库加载函数
    - 支持语句长度限制
- 2. TokenizerTrain class, 分词器训练类，用于训练本地语料库
    - 逻辑稍后整理
- 3. TokenizerLoader，加载训练好的分词器
    - 逻辑...
- 4. TranslationDataset class，数据集处理类
- 5. collate_fn 函数，对数据进行批处理，包括填充、堆叠、排序等操作

In [23]:
import torch 
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.nn.utils.rnn import pad_sequence

config = {
    'source-file':"/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch/europarl-v7.de-en.en",
    'target-file':"/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch/europarl-v7.de-en.de",
    'source-tokenizer-file':"/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch/en_tokenizer.json",
    'target-tokenizer-file':"/harddisk1/SZC-Project/NLP-learning/Transformer/Transformer-pytorch-from-scratch/de_tokenizer.json",
    'special-tokne':["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
    'vocab-size':30000,
    'min-length':5,
    'max-length':128,
    'batch-size':64,
    'sample-ratio':0.1,
    'num-workers':4
}

# 1. 加载符合要求的语料库
def load_corpus_generator(file_path, min_length=5, max_length=128):
    #output = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line and min_length <= len(line.split()) <= max_length:
                #output.append(line)
                yield line
    #return output

# 2. 分词训练器
class TokenizerTrain:
    def __init__(self, vocab_size, special_tokens):
        self.vocab_size = vocab_size
        self.special_tokens = sepcial_tokens

    def train_and_save(self, corpus, output_path, language_name):
        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = BPE.Trainer(special_tokens=self.special_tokens, vocab_size=self.vocab_size)
        tokenizer.train_from_iterator(corpus, trainer=trainer)
        tokenizer.save(f"{output_path}/{language_name}_tokenizer.json")

# 3. 加载分词器
def TokenizerLoader(tokenizer_path):
    tokenizer = Tokenizer.from_file(tokenizer_path)
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object = tokenizer,
        bos_token = "[BOS]",
        eos_token = "[EOS]",
        pad_token = "[PAD]",
        unk_token = "[UNK]"
    )
    return tokenizer

# 4. 数据集类
class TranslationDataset(Dataset):
    def __init__(self, src_lines, tgt_lines, src_transformer_file, tgt_transformer_file, max_length):
        self.src_generator = src_lines
        self.tgt_generator = tgt_lines
        self.src_lines = list(src_lines)
        self.tgt_lines = list(tgt_lines)
        self.src_transformer_tokenizer = TokenizerLoader(src_transformer_file)
        self.tgt_transformer_tokenizer = TokenizerLoader(tgt_transformer_file)
        self.max_length = max_length
    
    def __len__(self):
        return len(self.src_lines)
    
    def __getitem__(self, idx):
        src_line = self.src_lines[idx]
        tgt_line = self.tgt_lines[idx]

        src_encoding = self.src_transformer_tokenizer(
            src_line,
            padding = 'max_length',
            truncation = True,
            max_length = self.max_length,
            return_tensors = "pt"
        )

        tgt_encoding = self.tgt_transformer_tokenizer(
            tgt_line,
            padding = 'max_length',
            truncation = True,
            max_length = self.max_length,
            return_tensors = "pt"
        )

        return {
            "input_ids": src_encoding['input_ids'].squeeze(0),
            "attention_mask": src_encoding['attention_mask'].squeeze(0),
            "labels":tgt_encoding['input_ids'].squeeze(0)
        }

# 自定义 collate_fn
def collate_fn(batch):
    input_ids = pad_sequence([item['input_ids'] for item in batch], batch_first=True, padding_value=0)
    attention_mask = pad_sequence([item['attention_mask'] for item in batch], batch_first=True, padding_value=0)
    labels = pad_sequence([item['labels'] for item in batch], batch_first=True, padding_value=-100)  # -100 是常用的忽略索引值
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

# 测试用例
src_lines = load_corpus_generator(config['source-file'], config['min-length'], config['max-length'])
tgt_lines = load_corpus_generator(config['source-file'], config['min-length'], config['max-length'])
translation_dataset = TranslationDataset(src_lines, tgt_lines, config['source-tokenizer-file'], config['target-tokenizer-file'], config['max-length'])


indices = np.random.choice(len(translation_dataset), int(len(translation_dataset) * config["sample-ratio"]), replace=False)
sampler = SubsetRandomSampler(indices)

sampled_loader = DataLoader(
    translation_dataset,
    batch_size=config['batch-size'],
    sampler=sampler,
    num_workers=config['num-workers'],
    collate_fn=collate_fn
)

In [25]:
len(sampled_loader)

2912

In [24]:
for batch in sampled_loader:
    tmp = batch
    break

In [29]:
tmp['input_ids'].size()

torch.Size([64, 128])

# 2. 基于pytorch的Transformer模型搭建