## 机器翻译作业

In [72]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


本次作业使用 seq2seq模型完成英文到中文的机器翻译任务，请补充缺失部分的代码

In [73]:
#导入常用软件包
import torch
import sys
from torch import nn, optim
import random
from torch.nn.functional import softmax
import torch.nn.functional as F
import re
import os

from torchtext.legacy import data
from torchtext.legacy.data import Iterator, BucketIterator
from collections import defaultdict
import string
import dill

from tqdm import tqdm

## 1.数据加载

数据：包含六个文件：
1. 训练集：train.seg.en.txt train.seg.zh.txt （11743条）
2. 验证集：dev.seg.en.txt dev.seg.zh.txt （2936条）
3. 测试集：test.seg.en.txt test.seg.zh.txt （5194条）

使用torchtext完成数据的加载，主要使用以下三个组件：

1. Field :主要包含以下数据预处理的配置信息，比如指定分词方法，是否转成小写，起始字符，结束字符，补全字符以及词典等等

2. Dataset :继承自pytorch的Dataset，用于加载数据，提供了TabularDataset可以指点路径，格式，Field信息就可以方便的完成数据加载。同时torchtext还提供预先构建的常用数据集的Dataset对象，可以直接加载使用，splits方法可以同时加载训练集，验证集和测试集。

3. Iterator : 主要是数据输出的模型的迭代器，可以支持batch定制

具体可以参照：https://pytorch.org/text/stable/index.html
关于torchtext的安装可以参照：https://github.com/pytorch/text

In [74]:
class Dataloader:
    def __init__(self, batch_size, device, eval=False):
        raw_data = self.read_data("/content/drive/MyDrive/machine_translation/data/", test=eval)
        ## 训练模式
        if not eval:
            train_data, dev_data = raw_data
            ##定义数据字段
            self.id_field = data.Field(sequential=False, use_vocab=False)
            self.en_field = data.Field(init_token='<sos>', eos_token='<eos>', lower=True, include_lengths=True)
            self.zh_field = data.Field(init_token='<sos>', eos_token='<eos>', lower=True)
            self.fields = [("id", self.id_field), ("en", self.en_field), ("zh", self.zh_field)]

            ##构建数据集
            train_dataset = data.Dataset([data.Example.fromlist([idx, item[0], item[1]], self.fields) for idx, item in enumerate(train_data)], self.fields)
            dev_dataset =  data.Dataset([data.Example.fromlist([idx, item[0], item[1]], self.fields) for idx, item in enumerate(dev_data)], self.fields)
            
            ##构建数据迭代器
            self.train_iterator= BucketIterator(train_dataset, batch_size=batch_size, device=device, sort_key=lambda x: len(x.en), sort_within_batch=True)
            self.dev_iterator= BucketIterator(dev_dataset, batch_size=batch_size, device=device, sort_key=lambda x: len(x.en), sort_within_batch=True)
            
            ##构建词典
            self.en_field.build_vocab(train_dataset, min_freq=2)
            self.zh_field.build_vocab(train_dataset, min_freq=2)
            
            ##存储字段
            dill.dump(self.en_field, open("/content/drive/MyDrive/machine_translation/model/EN.Field", "wb"))
            dill.dump(self.zh_field, open("/content/drive/MyDrive/machine_translation/model/ZH.Field", "wb"))

            print("en vocab size:", len(self.en_field.vocab.itos),"zh vocab size:", len(self.zh_field.vocab.itos))
        
        ## 测试模式  
        else:
            test_data = raw_data[-1]
            ##加载存储的字段
            self.id_field = data.Field(sequential=False, use_vocab=False)
            self.en_field = dill.load(open("/content/drive/MyDrive/machine_translation/model/EN.Field", "rb"))
            self.zh_field = dill.load(open("/content/drive/MyDrive/machine_translation/model/ZH.Field", "rb"))
            self.fields = [("id", self.id_field), ("en", self.en_field), ("zh", self.zh_field)]
            
            ##构建测试集 & 迭代器
            test_data = data.Dataset([data.Example.fromlist([idx, item[0], item[1]], self.fields) for idx, item in enumerate(test_data)], self.fields)
            self.test_iterator= BucketIterator(test_data, batch_size=batch_size, device=device, train = False, sort_key=lambda x: len(x.en), sort_within_batch = True)   
        
    ##从文件中读取数据
    def read_data(self, path, test=True, lang1='en', lang2 = 'zh'):
        data = []
        types = ['test'] if test else ['train', 'dev']
        # print(types)
        for type in types:
            sub_data = []
            with open(f"{path}/{type}.seg.{lang1}.txt", encoding='utf-8') as f1, open(f"{path}/{type}.seg.{lang2}.txt", encoding='utf-8') as f2:
                for src, trg in zip(f1, f2):
                    if len(src) > MAX_LEN and len(trg) > MAX_LEN:
                        continue
                    sub_data.append((src.strip(), trg.strip()))
            data.append(sub_data)

        return data


## 2.模型构建
使用seq2seq模型完成机器翻译模型的搭建,可以参考 pytorch tutorials https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

这里需要完成的是一个seq2seq+attention的基础模型，有余力的可以进行更多的尝试
相关的论文：
1. attention:https://arxiv.org/abs/1409.0473
2. copy:https://arxiv.org/abs/1603.06393
3. coverage:https://arxiv.org/abs/1601.04811

### 2.1 Encoder
这里为了简化代码，embedding_size和hidden_size使用相同大小，均为传入的hid_dim。RNN模块可以采用双向GRU实现。

In [75]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        
        ##构建embedding
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        ##构建rnn模块
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        
        ##构建dropout层
        self.dropout = nn.Dropout(dropout)
        
        
    def forward(self, src_info):
        src, src_len = src_info
        
        ## 补充embedding层（+dropout）代码
        embedded = self.dropout(self.embedding(src))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to('cpu'))
        
        ## 补充rnn层代码
        packed_outputs, hidden = self.rnn(packed_embedded)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs) 
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        
        return outputs, hidden

### 2.2 Attention模块
Attention 机制的 Encoder-Decoder 模型则是要从序列中学习到每一个元素的重要程度，然后按重要程度将元素合并。因此，注意力机制可以看作是 Encoder 和 Decoder 之间的接口，它向 Decoder 提供来自每个 Encoder 隐藏状态的信息。通过该设置，模型能够选择性地关注输入序列的有用部分，从而学习它们之间的“对齐”。这就表明，在 Encoder 将输入的序列元素进行编码时，得到的不在是一个固定的语义编码 C ，而是存在多个语义编码，且不同的语义编码由不同的序列元素以不同的权重参数组合而成。

在 Attention 机制下，语义编码 C 是各个元素按其重要程度加权求和得到的，即：

$C_i=\sum_{i=1}^{n}{a}_{ij}*{h_i}$

参数 𝑖 表示时刻， 𝑗 表示序列中的第 𝑗 个元素， 𝑇𝑥 表示序列的长度， 𝑓(⋅) 表示对元素 𝑥𝑗 的编码。𝑎𝑖𝑗 可以看作是一个概率，反映了元素 ℎ𝑗 对 𝐶𝑖 的重要性，可以使用 softmax 来表示：

${a}_{ij}= \frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})}$

where $e_{ij}=a(s_{i-1}, h_j)$


请实现Attention模块：
（也可以不使用单独的Attention模块，在decoder部分的代码中整合Attention机制）

In [76]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim*2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)


    def forward(self, hidden, encoder_outputs, mask):
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        #将decoder的隐状态重复src_len次
#         print("in attention, hidden shape:", hidden.shape)
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        ##请补充计算attention score的代码
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))
        attention = self.v(energy).squeeze(2)
        attention.masked_fill(mask == 0, -1e10)
        ##mask掉padding部分，计算softmax
        return F.softmax(attention, dim = 1)     

### 2.3 Decoder
请实现带attention机制的decoder：

In [77]:
class Decoder(nn.Module):
    def __init__(self, output_dim,emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        ## 继续补充
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):
        # input = [batch_size]
        # hidden = [batch_size, dec_hid_dim]
        # encoder_outputs = [src_sent_len, batch_size, enc_hid_dim * 2]
        input = input.unsqueeze(0)
        
        ##embedding层
        embedded = self.dropout(self.embedding(input))
        
        
        ##计算attention score
        a = self.attention(hidden, encoder_outputs,mask).unsqueeze(1)
        
        ##根据attention score计算weighted的context向量
        encoder_outputs = encoder_outputs.permute(1,0,2)
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1,0,2)
        rnn_input = torch.cat((embedded,weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        assert (output == hidden).all()
        
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
        
        return prediction, hidden.squeeze(0), a.squeeze(1)

## 2.4 Seq2seq模型

In [78]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device

    def forward(self, src_info, trg = None, teacher_forcing_ratio = 0.5):
        src, src_len = src_info
        batch_size = src.shape[1]
        max_len = trg.shape[0] if trg is not None else MAX_LEN
        trg_vocab_size = self.decoder.output_dim
        
        ##存储所有decoder输出的结果
        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)
        attn_scores = []
        
        ## encoder
        encoder_outputs, hidden = self.encoder(src_info)

        ##初始化decoder的输入是<sos>token
        input = trg[0, :] if trg is not None else src[0, :]
        
        #mask = [batch size, src len]
        mask = self.create_mask(src)

        
        ## decode过程，每个step decode出一个token 
        for t in range(1, max_len):
            ## 请补全decoder的代码，得到output, hidden, atten_score
            output, hidden, a = self.decoder(input, hidden, encoder_outputs,mask)
            attn_scores.append(a)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1
            
            
        return outputs, torch.cat(attn_scores, dim = 1).to(self.device)

    def create_mask(self, src):
        mask = (src != self.pad_idx).permute(1, 0)
        return mask

## 3.训练模块代码
包括BLEU的计算：BLEU是一种评价机器翻译的指标，NLTK中包含了计算BLEU值的工具。

In [79]:
## bleu计算
import jieba
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

def calculate_bleu(hypothesis, targets, cut=True, verbose=False):
    bleu_scores = []
    for sent, trg in zip(hypothesis, targets):
        trg = trg.strip().lower().split()
        sent = sent.strip()
        if cut:
            trg = list(jieba.cut(''.join(trg).replace("-", "")))
            sent = list(jieba.cut(''.join(sent).replace("-", "")))
 
        bleu = sentence_bleu([trg], sent, weights=(0.5, 0.5, 0., 0.),smoothing_function = SmoothingFunction().method1)
        if verbose:
            print(f"src:{sent.strip()}\ntrg:{trg}\npredict:{sent}\n{bleu}\n")
        bleu_scores.append(bleu)         
    return sum(bleu_scores) / len(bleu_scores)

请补充train_iter的代码:
（在train函数中将调用train_iter，或者直接将这部分整合到train函数中）

In [80]:
def train_iter(model, iterator, optimizer, criterion, clip, nl_field = None):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(tqdm(iterator)):
        src, src_len = batch.en
        trg = batch.zh
        ##补充训练部分的代码
        optimizer.zero_grad()
        output, attn = model(batch.en, trg)
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

请补充完成evaluate_iter的代码：
该函数用于评测模型在验证集上的效果。
生成时可以采用greedy search，即每次直接选择概率最大的词作为输出，直到出现终结符eos或达到最大句子长度，有余力的同学可以尝试用beam search生成。

In [81]:
def evaluate_iter(model, iterator, en_field, zh_field, criterion):
    model.eval()

    hypothesis, targets = [], []
    eval_loss = 0
    with torch.no_grad():
        for _, batch in enumerate(iterator):
            src, src_len = batch.en
            trg = batch.zh
            ## 得到decoder的输出结果并计算loss，同时获取batch_size
            output,attn = model(batch.en, trg, 0)
            batch_size = output.shape[-2]
            output_dim = output.shape[-1]

            eval_output = output[1:].view(-1, output_dim)
            eval_trg = trg[1:].view(-1)

            loss = criterion(eval_output, eval_trg)
            eval_loss += loss.item()
            
            ##解码每个句子，采用greedy search，每一步选择概率最大的词，可以改进为beam search
            for sent in range(batch_size):
                predicts = []  # 模型生成的文本
                grounds = []   # 参考文本（ground truth）
                
                ## 请补充生成部分的代码
                trg_tensor = output[:,sent,:]
                trg_indexes = []
                for i in range(1,output.shape[0]):
                  pred_token = trg_tensor[i,:].argmax().item()
                  if pred_token == zh_field.vocab.stoi[zh_field.eos_token]:
                    break
                  trg_indexes.append(pred_token)
                  
                
                grd_indexes = []
                for i in range(1,output.shape[0]):
                  pred_token = trg[i,sent]
                  if pred_token == zh_field.vocab.stoi[zh_field.eos_token]:
                    break
                  grd_indexes.append(pred_token)
                  
                  

                predicts = [zh_field.vocab.itos[i] for i in trg_indexes]
                grounds = [zh_field.vocab.itos[i] for i in grd_indexes]
                
                hypothesis.append(' '.join(predicts))
                targets.append(' '.join(grounds))
    
    # 根据模型生成的文本和参考文本计算BLEU值
    bleu = calculate_bleu(hypothesis, targets)  

    return bleu, eval_loss / len(iterator)

## 训练模型

In [82]:
def train(dataloader, model, model_output_path):
    print('Start training...')
    
    ## 补充训练部分的代码，optimizer，计算损失函数的criterion等
    
    optimizer = optim.Adam(model.parameters())
    TRG = dataloader.zh_field
    TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
    criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
    best_loss = 100000
    
    for epoch in range(N_EPOCHS):
        ## 调用train_iter函数
        train_loss = train_iter(model, dataloader.train_iterator, optimizer, criterion, CLIP)
        bleu, loss = evaluate_iter(model, dataloader.dev_iterator, dataloader.en_field, dataloader.zh_field, criterion)
        
        ## 每5个epoch保存一次模型
        if epoch % 5 == 0:
            torch.save(model.state_dict(),'model_'+str(epoch)+'.pt')
        
        ## 计算当前模型在验证集上的效果，如果损失比之前的更小，将当前模型保存为'model_best.pt'
        if loss < best_loss:
            best_bleu = bleu
            best_loss = loss
            torch.save(model.state_dict(), 'model_best.pt')
        
        print(f'Best BLEU: {best_bleu:.3f} | Best Loss:{best_loss:.3f} |  Epoch: {epoch:d} |  BLeu： {bleu:.3f} | Loss:{loss}', flush=True)


In [83]:
## 参数设定(可以自己修改)
MAX_LEN = 128
TRAIN_BATCH_SIZE = 64
INFERENCE_BATCH_SIZE = 64
HID_DIM = 256
DROPOUT=0.2
N_EPOCHS = 100
CLIP = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH='/content/'


In [84]:
##定义各个模块
dataloader = Dataloader(TRAIN_BATCH_SIZE, device)
attn = Attention(HID_DIM, HID_DIM)
INPUT_DIM = len(dataloader.en_field.vocab)
OUTPUT_DIM = len(dataloader.zh_field.vocab)
encoder = Encoder(INPUT_DIM, HID_DIM, HID_DIM, HID_DIM, DROPOUT)
decoder = Decoder(OUTPUT_DIM, HID_DIM, HID_DIM, HID_DIM, DROPOUT, attn)
model = Seq2Seq(encoder, decoder, device, dataloader.zh_field.vocab.stoi['<pad>']).to(device)

## 开始训练
train(dataloader, model, model_output_path= MODEL_PATH)

  0%|          | 0/176 [00:00<?, ?it/s]

en vocab size: 6182 zh vocab size: 4836
Start training...


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.043 | Best Loss:4.973 |  Epoch: 0 |  BLeu： 0.043 | Loss:4.973399324850603


100%|██████████| 176/176 [00:17<00:00, 10.23it/s]


Best BLEU: 0.069 | Best Loss:4.332 |  Epoch: 1 |  BLeu： 0.069 | Loss:4.331590787930922


100%|██████████| 176/176 [00:17<00:00, 10.20it/s]


Best BLEU: 0.078 | Best Loss:4.059 |  Epoch: 2 |  BLeu： 0.078 | Loss:4.059345974163576


100%|██████████| 176/176 [00:17<00:00, 10.31it/s]


Best BLEU: 0.079 | Best Loss:3.928 |  Epoch: 3 |  BLeu： 0.079 | Loss:3.9283003509044647


100%|██████████| 176/176 [00:17<00:00, 10.20it/s]


Best BLEU: 0.079 | Best Loss:3.928 |  Epoch: 4 |  BLeu： 0.075 | Loss:3.941565077413212


100%|██████████| 176/176 [00:17<00:00, 10.26it/s]


Best BLEU: 0.080 | Best Loss:3.908 |  Epoch: 5 |  BLeu： 0.080 | Loss:3.907629446549849


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 6 |  BLeu： 0.081 | Loss:3.898713697086681


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 7 |  BLeu： 0.082 | Loss:3.9417662132870066


100%|██████████| 176/176 [00:17<00:00, 10.06it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 8 |  BLeu： 0.083 | Loss:3.94968980550766


100%|██████████| 176/176 [00:17<00:00, 10.11it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 9 |  BLeu： 0.084 | Loss:3.98931607874957


100%|██████████| 176/176 [00:17<00:00, 10.28it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 10 |  BLeu： 0.085 | Loss:3.999457611279054


100%|██████████| 176/176 [00:17<00:00, 10.22it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 11 |  BLeu： 0.078 | Loss:4.096289496530186


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 12 |  BLeu： 0.085 | Loss:4.092717615040866


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 13 |  BLeu： 0.082 | Loss:4.179443573409861


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 14 |  BLeu： 0.083 | Loss:4.224167130210183


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 15 |  BLeu： 0.083 | Loss:4.2534102472391995


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 16 |  BLeu： 0.082 | Loss:4.323083014650778


100%|██████████| 176/176 [00:17<00:00, 10.16it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 17 |  BLeu： 0.083 | Loss:4.324066549539566


100%|██████████| 176/176 [00:17<00:00, 10.20it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 18 |  BLeu： 0.087 | Loss:4.361155427315018


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 19 |  BLeu： 0.087 | Loss:4.411430314183235


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 20 |  BLeu： 0.084 | Loss:4.438928472724828


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 21 |  BLeu： 0.083 | Loss:4.465086936950684


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 22 |  BLeu： 0.088 | Loss:4.493700897151774


100%|██████████| 176/176 [00:17<00:00, 10.10it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 23 |  BLeu： 0.085 | Loss:4.5440684448588975


100%|██████████| 176/176 [00:17<00:00, 10.16it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 24 |  BLeu： 0.086 | Loss:4.6338210729035465


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 25 |  BLeu： 0.087 | Loss:4.62757408618927


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 26 |  BLeu： 0.088 | Loss:4.6358950869603595


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 27 |  BLeu： 0.090 | Loss:4.709001698277214


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 28 |  BLeu： 0.086 | Loss:4.726533678444949


100%|██████████| 176/176 [00:17<00:00, 10.35it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 29 |  BLeu： 0.090 | Loss:4.78228677267378


100%|██████████| 176/176 [00:17<00:00, 10.10it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 30 |  BLeu： 0.086 | Loss:4.788808687166735


100%|██████████| 176/176 [00:17<00:00, 10.31it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 31 |  BLeu： 0.088 | Loss:4.7972861501303585


100%|██████████| 176/176 [00:17<00:00, 10.10it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 32 |  BLeu： 0.086 | Loss:4.868466240438548


100%|██████████| 176/176 [00:17<00:00, 10.22it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 33 |  BLeu： 0.086 | Loss:4.9152302958748555


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 34 |  BLeu： 0.087 | Loss:4.914696276187897


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 35 |  BLeu： 0.087 | Loss:4.905583739280701


100%|██████████| 176/176 [00:17<00:00, 10.18it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 36 |  BLeu： 0.087 | Loss:4.95905527472496


100%|██████████| 176/176 [00:17<00:00, 10.23it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 37 |  BLeu： 0.087 | Loss:5.025691178711978


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 38 |  BLeu： 0.085 | Loss:4.997301024469462


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 39 |  BLeu： 0.084 | Loss:5.059968821027062


100%|██████████| 176/176 [00:17<00:00, 10.22it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 40 |  BLeu： 0.086 | Loss:5.09977745197036


100%|██████████| 176/176 [00:17<00:00, 10.26it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 41 |  BLeu： 0.086 | Loss:5.124186821959236


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 42 |  BLeu： 0.086 | Loss:5.116248245943677


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 43 |  BLeu： 0.085 | Loss:5.138240061023018


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 44 |  BLeu： 0.085 | Loss:5.188825355334715


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 45 |  BLeu： 0.085 | Loss:5.22485243732279


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 46 |  BLeu： 0.087 | Loss:5.203251602974805


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 47 |  BLeu： 0.085 | Loss:5.260348501530561


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 48 |  BLeu： 0.087 | Loss:5.2714948220686475


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 49 |  BLeu： 0.090 | Loss:5.293866092508489


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 50 |  BLeu： 0.087 | Loss:5.307398308407176


100%|██████████| 176/176 [00:17<00:00, 10.09it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 51 |  BLeu： 0.084 | Loss:5.385423625057394


100%|██████████| 176/176 [00:17<00:00, 10.20it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 52 |  BLeu： 0.088 | Loss:5.32087154821916


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 53 |  BLeu： 0.087 | Loss:5.366894992915067


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 54 |  BLeu： 0.085 | Loss:5.388840613040057


100%|██████████| 176/176 [00:17<00:00, 10.08it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 55 |  BLeu： 0.091 | Loss:5.377936489202759


100%|██████████| 176/176 [00:17<00:00, 10.23it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 56 |  BLeu： 0.087 | Loss:5.371506945653395


100%|██████████| 176/176 [00:17<00:00, 10.08it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 57 |  BLeu： 0.089 | Loss:5.413949318907478


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 58 |  BLeu： 0.085 | Loss:5.453006134791807


100%|██████████| 176/176 [00:17<00:00, 10.05it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 59 |  BLeu： 0.088 | Loss:5.46539783206853


100%|██████████| 176/176 [00:17<00:00, 10.30it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 60 |  BLeu： 0.085 | Loss:5.460895557295192


100%|██████████| 176/176 [00:17<00:00, 10.31it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 61 |  BLeu： 0.088 | Loss:5.485906663266095


100%|██████████| 176/176 [00:17<00:00, 10.19it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 62 |  BLeu： 0.086 | Loss:5.497216205705296


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 63 |  BLeu： 0.089 | Loss:5.477291991764849


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 64 |  BLeu： 0.088 | Loss:5.577911154790358


100%|██████████| 176/176 [00:17<00:00, 10.22it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 65 |  BLeu： 0.085 | Loss:5.612963568080556


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 66 |  BLeu： 0.087 | Loss:5.565672932700678


100%|██████████| 176/176 [00:17<00:00, 10.29it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 67 |  BLeu： 0.088 | Loss:5.606389406052503


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 68 |  BLeu： 0.086 | Loss:5.600881172852083


100%|██████████| 176/176 [00:17<00:00, 10.34it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 69 |  BLeu： 0.087 | Loss:5.601971333677119


100%|██████████| 176/176 [00:17<00:00, 10.02it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 70 |  BLeu： 0.086 | Loss:5.677679568529129


100%|██████████| 176/176 [00:17<00:00, 10.08it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 71 |  BLeu： 0.087 | Loss:5.645767564123327


100%|██████████| 176/176 [00:17<00:00, 10.03it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 72 |  BLeu： 0.089 | Loss:5.690555317835375


100%|██████████| 176/176 [00:17<00:00, 10.16it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 73 |  BLeu： 0.089 | Loss:5.66342706842856


100%|██████████| 176/176 [00:17<00:00, 10.06it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 74 |  BLeu： 0.088 | Loss:5.708429406989705


100%|██████████| 176/176 [00:17<00:00, 10.10it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 75 |  BLeu： 0.085 | Loss:5.741124583916231


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 76 |  BLeu： 0.087 | Loss:5.716917875138196


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 77 |  BLeu： 0.088 | Loss:5.729238361120224


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 78 |  BLeu： 0.086 | Loss:5.758396630937403


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 79 |  BLeu： 0.088 | Loss:5.74781379645521


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 80 |  BLeu： 0.086 | Loss:5.806213439865545


100%|██████████| 176/176 [00:17<00:00, 10.09it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 81 |  BLeu： 0.088 | Loss:5.755197855559262


100%|██████████| 176/176 [00:17<00:00, 10.21it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 82 |  BLeu： 0.085 | Loss:5.763015419244766


100%|██████████| 176/176 [00:17<00:00, 10.18it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 83 |  BLeu： 0.087 | Loss:5.819564524022016


100%|██████████| 176/176 [00:17<00:00, 10.09it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 84 |  BLeu： 0.086 | Loss:5.881595134735107


100%|██████████| 176/176 [00:17<00:00, 10.23it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 85 |  BLeu： 0.088 | Loss:5.893751919269562


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 86 |  BLeu： 0.088 | Loss:5.886167496442795


100%|██████████| 176/176 [00:17<00:00, 10.05it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 87 |  BLeu： 0.086 | Loss:5.8675973388281735


100%|██████████| 176/176 [00:17<00:00, 10.13it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 88 |  BLeu： 0.089 | Loss:5.860954444516789


100%|██████████| 176/176 [00:17<00:00, 10.12it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 89 |  BLeu： 0.086 | Loss:5.9447847144170245


100%|██████████| 176/176 [00:17<00:00, 10.17it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 90 |  BLeu： 0.087 | Loss:5.94047831676223


100%|██████████| 176/176 [00:17<00:00, 10.14it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 91 |  BLeu： 0.089 | Loss:5.9397062117403205


100%|██████████| 176/176 [00:17<00:00, 10.22it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 92 |  BLeu： 0.088 | Loss:5.913985306566412


100%|██████████| 176/176 [00:17<00:00, 10.15it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 93 |  BLeu： 0.087 | Loss:5.90402583100579


100%|██████████| 176/176 [00:17<00:00, 10.05it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 94 |  BLeu： 0.086 | Loss:5.973330054770816


100%|██████████| 176/176 [00:17<00:00, 10.31it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 95 |  BLeu： 0.083 | Loss:5.97617980837822


100%|██████████| 176/176 [00:17<00:00, 10.07it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 96 |  BLeu： 0.088 | Loss:5.934599800543352


100%|██████████| 176/176 [00:17<00:00, 10.08it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 97 |  BLeu： 0.087 | Loss:6.0244188511913475


100%|██████████| 176/176 [00:17<00:00, 10.07it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 98 |  BLeu： 0.086 | Loss:6.019261395389384


100%|██████████| 176/176 [00:17<00:00, 10.01it/s]


Best BLEU: 0.081 | Best Loss:3.899 |  Epoch: 99 |  BLeu： 0.085 | Loss:6.032081181352789


## 推理部分
请补充完成推理（inference）部分的代码，用于在测试集上生成翻译后的文本（该部分与evaluate_iter比较类似）。

In [85]:
def inference(model, iterator, en_field, zh_field):
    model.eval()

    with torch.no_grad():
        predict_res = []
        for _, batch in enumerate(iterator):
            src, src_len = batch.en
            id, trg = batch.id, batch.zh
            
            ## 得到decoder的输出结果
            output,attn = model(batch.en, trg, 0)
            batch_size = output.shape[-2]
            output_dim = output.shape[-1]
            
            for sent in range(batch_size):
                if en_field is not None:
                    eos_index = [x.item() for x in src[:, sent]].index(en_field.vocab.stoi['<eos>'])
                    src_str = ' '.join([en_field.vocab.itos[x.item()] for x in src[1: eos_index, sent]])
                    sent_id = id[sent]
                predicts = []
                grounds = []
                
                ## 请补充生成部分的代码
                trg_tensor = output[:,sent,:]
                trg_indexes = []
                for i in range(output.shape[0]):
                  pred_token = trg_tensor[i,:].argmax().item()
                  if pred_token == zh_field.vocab.stoi[zh_field.eos_token]:
                    break
                  trg_indexes.append(pred_token)
                  
                grd_indexes = []
                for i in range(1,output.shape[0]):
                  pred_token = trg[i,sent]
                  if pred_token == zh_field.vocab.stoi[zh_field.eos_token]:
                    break
                  grd_indexes.append(pred_token)
                  

                predicts = [zh_field.vocab.itos[i] for i in trg_indexes]
                grounds = [zh_field.vocab.itos[i] for i in grd_indexes]
                
                predict_res.append((int(sent_id), src_str, ' '.join(predicts), " ".join(grounds)))

    predict_res = [(item[1],item[2], item[3] ) for item in sorted(predict_res, key=lambda x: x[0])]

    bleu = calculate_bleu([i[1] for i in predict_res], [i[2] for i in predict_res])
    return bleu, predict_res

读取效果最好的模型，在测试集上进行生成：

In [86]:
dataloader = Dataloader(INFERENCE_BATCH_SIZE, device, eval=True)
model_init_path = f"{MODEL_PATH}/model_best.pt"
model = Seq2Seq(encoder, decoder, device, dataloader.zh_field.vocab.stoi['<pad>']).to(device)
model.load_state_dict(torch.load(model_init_path))
bleu, predict_output = inference(model, dataloader.test_iterator, dataloader.en_field, dataloader.zh_field)
for item in predict_output:
    src, trg, golden = item
    print(f"src:{src}\ntrg:{trg}\ngolden:{golden}\n")
print(bleu)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
trg:<unk> 如何 找到 的 执行 而 不 执行 ？
golden:如何 在 不 阻塞 的 情况 下 检查 按键 ？

src:output not more than size bytes to str according to the format string format and the extra arguments
trg:<unk> 不 一定 要 size size 个 字节 字符串 的 大小 格式 字符串
golden:根据 格式 字符串 format 和 额外 参数 ， 输出 不 超过 size 字节 到 str

src:author: guido van rossum
trg:<unk> 作者 : guido van "
golden:作者 : guido van rossum

src:in certain <unk> cases, though, modules are built right in their installation directory, so this is <unk> a useful ability
trg:<unk> 在 <unk> 中 ， ， ， ， 模块 是 ， ， 因此 ， 因此 ， 因此 ， 因此 ， 因此 ， 因此 ， 所以 ， 所以 会 很 有用 的
golden:不过 在 某些 特殊 情况 下 ， 模块 是 在 其 安装 目录 中 被 构建 的 ， 因此 这 可能 会 是 个 有用 的 功能

src:it is the <unk> responsibility to ensure that all whitespace and special characters are quoted appropriately to avoid shell <unk> <unk>
trg:<unk> 它 被 设计 的 所有 特殊字符 ， 所有 特殊字符 的 所有 特殊字符 ， 并且 ， 并且 ， 并且 被 视为 的 的 ， 并且 需要 <unk> 避免 干扰 避免
golden:应用程序 要 负责 确保 正确 地 转义 所有 空白 字符 和 特殊字符 以 防止 shell 注入