# Seq2Seq模型实现文本翻译

## 概述

本案例使用MindSpore实现带注意力机制的Seq2Seq模型（双向GRU编码器 + GRU解码器），在Multi30K数据集上完成德语到英语的文本翻译，并使用BLEU得分评估翻译效果。

In [1]:
# Download URLs for the train / valid / test archives of the Multi30K dataset.
urls = dict(
    train='http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
    valid='http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
    test='http://www.quest.dcs.shef.ac.uk/wmt17_files_mmt/mmt_task1_test2016.tar.gz',
)

## 数据准备

### 词典

In [2]:
class Vocab:
    """Word <-> index mapping built from a word-frequency dict.

    Special tokens take the first indices (in the order given), followed by every
    word of ``word_count_dict`` whose count is at least ``min_freq``, in the
    dict's iteration order. Indices are always contiguous ``0 .. len(self) - 1``.

    Args:
        word_count_dict: mapping ``word -> count``.
        min_freq: minimum count for a word to be kept.
        special_tokens: tokens placed at the front of the vocabulary; must include
            ``'<unk>'``, ``'<pad>'``, ``'<bos>'`` and ``'<eos>'``.
    """

    def __init__(self, word_count_dict, min_freq=1, special_tokens=('<unk>', '<pad>', '<bos>', '<eos>')):
        # Immutable default avoids the shared-mutable-default pitfall; any
        # iterable of tokens (e.g. a list) is still accepted.
        self.word2idx = {}
        for tok in special_tokens:
            if tok not in self.word2idx:
                self.word2idx[tok] = len(self.word2idx)

        for word, count in word_count_dict.items():
            # Skip words already present (e.g. a special token that also occurs
            # in the corpus): re-assigning them would leave a hole in the index
            # range and break the idx2word inverse mapping.
            if count >= min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

        self.bos_idx = self.word2idx['<bos>']
        self.eos_idx = self.word2idx['<eos>']
        self.pad_idx = self.word2idx['<pad>']
        self.unk_idx = self.word2idx['<unk>']

    def _word2idx(self, word):
        """Return the index of ``word``, or ``unk_idx`` if it is out of vocabulary."""
        return self.word2idx.get(word, self.unk_idx)

    def _idx2word(self, idx):
        """Return the word for ``idx``.

        Raises:
            ValueError: if ``idx`` is not a valid index.
        """
        if idx not in self.idx2word:
            raise ValueError('input index is not in vocabulary.')
        return self.idx2word[idx]

    def encode(self, word_or_list):
        """Map a word, or a list/tuple of words, to index / list of indices."""
        if isinstance(word_or_list, (list, tuple)):
            return [self._word2idx(i) for i in word_or_list]
        return self._word2idx(word_or_list)

    def decode(self, idx_or_list):
        """Map an index, or a list/tuple of indices, to word / list of words."""
        if isinstance(idx_or_list, (list, tuple)):
            return [self._idx2word(i) for i in idx_or_list]
        return self._idx2word(idx_or_list)

    def __len__(self):
        return len(self.word2idx)

In [3]:
# Toy counts: 'c' (count 1) falls below min_freq=2, so the vocabulary holds
# the 4 default special tokens plus 'a', 'b' and 'd'.
word_count = dict(a=20, b=10, c=1, d=2)

vocab = Vocab(word_count, min_freq=2)
len(vocab)

7

### Multi30K数据集

> pip install spacy

> python -m spacy download de_core_news_sm

> python -m spacy download en_core_web_sm

In [4]:
import re
import six
import string
import tarfile
import spacy
from functools import partial

class Multi30K():
    """Multi30K dataset loader.

    Reads a German/English parallel corpus from a tar archive, tokenizes both
    sides with spaCy (lower-cased tokens), and exposes the result as an indexable
    sequence of ``(de_tokens, en_tokens)`` pairs.
    """
    def __init__(self, path):
        self.data = self._load(path)

    def _load(self, path):
        def tokenize(text, spacy_lang):
            # Strip trailing whitespace/newline, then lower-case every token.
            text = text.rstrip()
            return [tok.text.lower() for tok in spacy_lang.tokenizer(text)]

        tokenize_de = partial(tokenize, spacy_lang=spacy.load('de_core_news_sm'))
        tokenize_en = partial(tokenize, spacy_lang=spacy.load('en_core_web_sm'))

        # Context manager guarantees the archive handle is closed.
        with tarfile.open(path) as tarf:
            # Members are keyed by file extension, so the German and English
            # files are found as 'de' and 'en'.
            members = {i.name.split('.')[-1]: i for i in tarf.getmembers()}
            # NOTE(review): the last line of each file is dropped; presumably it
            # is empty/incomplete in these archives — confirm.
            de = tarf.extractfile(members['de']).readlines()[:-1]
            en = tarf.extractfile(members['en']).readlines()[:-1]

        de = [tokenize_de(i.decode()) for i in de]
        en = [tokenize_en(i.decode()) for i in en]

        return list(zip(de, en))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

### 数据迭代器

In [5]:
import mindspore

class Iterator():
    """Batching iterator over a parallel dataset that yields MindSpore tensors.

    Calling the instance returns a generator of ``(src_idx, src_len, trg_idx)``
    int32 tensors, one triple per batch. Within a batch, pairs are sorted by
    source length (descending). Both sides are wrapped in ``<bos>``/``<eos>`` and
    padded or truncated to ``max_len``.

    Args:
        dataset: indexable, sliceable sequence of ``(src_tokens, trg_tokens)`` pairs.
        de_vocab: vocabulary used to encode source tokens.
        en_vocab: vocabulary used to encode target tokens.
        batch_size: number of pairs per batch.
        max_len: length of every padded sequence, including ``<bos>``/``<eos>``.
        drop_reminder: if True, drop the final incomplete batch (name kept for
            compatibility; means "remainder").
    """
    def __init__(self, dataset, de_vocab, en_vocab, batch_size, max_len=32, drop_reminder=False):
        self.dataset = dataset
        self.de_vocab = de_vocab
        self.en_vocab = en_vocab

        self.batch_size = batch_size
        self.max_len = max_len
        self.drop_reminder = drop_reminder

        full_batches, remainder = divmod(len(self.dataset), batch_size)
        # Count a trailing partial batch only when it is non-empty; an empty
        # last batch would fail to unpack in _encode_and_pad.
        self.len = full_batches + (1 if remainder and not drop_reminder else 0)

    @staticmethod
    def _pad(idx_list, vocab, max_len):
        """Wrap each id list in <bos>/<eos>, truncate/pad to max_len; also return lengths."""
        idx_pad_list, idx_len = [], []
        for i in idx_list:
            if len(i) > max_len - 2:
                idx_pad_list.append(
                    [vocab.bos_idx] + i[:max_len-2] + [vocab.eos_idx]
                )
                idx_len.append(max_len)
            else:
                idx_pad_list.append(
                    [vocab.bos_idx] + i + [vocab.eos_idx] + [vocab.pad_idx] * (max_len - len(i) - 2)
                )
                idx_len.append(len(i) + 2)
        return idx_pad_list, idx_len

    @staticmethod
    def _sort_by_length(src, trg):
        """Sort (src, trg) pairs by source length, longest first; return them unzipped."""
        data = sorted(zip(src, trg), key=lambda t: len(t[0]), reverse=True)
        return zip(*data)

    def _encode_and_pad(self, batch_data, max_len):
        """Encode a batch with this iterator's vocabularies, sort it and pad it."""
        src_data, trg_data = zip(*batch_data)
        src_idx = [self.de_vocab.encode(i) for i in src_data]
        trg_idx = [self.en_vocab.encode(i) for i in trg_data]

        src_idx, trg_idx = self._sort_by_length(src_idx, trg_idx)
        src_idx_pad, src_len = self._pad(src_idx, self.de_vocab, max_len)
        trg_idx_pad, _ = self._pad(trg_idx, self.en_vocab, max_len)

        return src_idx_pad, src_len, trg_idx_pad

    def __call__(self):
        for i in range(self.len):
            start = i * self.batch_size
            # The final slice is naturally shorter when a partial batch is kept.
            batch_data = self.dataset[start: start + self.batch_size]

            src_idx, src_len, trg_idx = self._encode_and_pad(batch_data, self.max_len)
            yield mindspore.Tensor(src_idx, mindspore.int32), \
                mindspore.Tensor(src_len, mindspore.int32), \
                mindspore.Tensor(trg_idx, mindspore.int32)

    def __len__(self):
        return self.len

### 数据下载模块

In [6]:
import os
import logging
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path

# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'

def http_get(url: str, temp_file: IO):
    """Stream ``url`` into ``temp_file`` while showing a tqdm progress bar.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never written (and later cached) as the dataset.
    """
    with requests.get(url, stream=True) as req:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(unit='B', total=total) as progress:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # ignore empty chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)

def download(file_name: str, url: str):
    """Download ``url`` into ``cache_dir`` as ``file_name`` and return the path.

    Nothing is fetched if the file is already cached. The download goes to a
    temporary file and is copied to a ``.part`` file that is atomically renamed
    into place. An interrupted run therefore never leaves a truncated file that
    later looks like a valid cache entry.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, file_name)
    if not os.path.exists(cache_path):
        partial_path = cache_path + '.part'
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            logging.info(f"copying {temp_file.name} to cache at {cache_path}")
            with open(partial_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
        os.replace(partial_path, cache_path)
    return cache_path

In [7]:
def download_dataset(urls):
    """Fetch (or reuse cached) train/valid/test archives and load each as Multi30K.

    Returns:
        (train, valid, test) tuple of Multi30K datasets.
    """
    splits = (('train.tar.gz', 'train'), ('valid.tar.gz', 'valid'), ('test.tar.gz', 'test'))
    # Download every archive first, then parse them in the same order.
    paths = [download(file_name, urls[key]) for file_name, key in splits]
    return tuple(Multi30K(p) for p in paths)

In [8]:
# Download (or reuse the cached) archives, then tokenize all three splits.
train_dataset, valid_dataset, test_dataset = download_dataset(urls)

In [9]:
# Show the first (German tokens, English tokens) pair of the test split.
for de, en in test_dataset:
    print(de, en)
    break

['ein', 'mann', 'mit', 'einem', 'orangefarbenen', 'hut', ',', 'der', 'etwas', 'anstarrt', '.'] ['a', 'man', 'in', 'an', 'orange', 'hat', 'starring', 'at', 'something', '.']


### 词典构建

In [10]:
from collections import Counter, OrderedDict

def build_vocab(dataset, min_freq=2):
    """Build German and English vocabularies from a parallel dataset.

    Words are ordered by descending frequency (ties keep first-occurrence
    order). Words seen fewer than ``min_freq`` times are dropped.

    Args:
        dataset: iterable of ``(de_tokens, en_tokens)`` pairs.
        min_freq: minimum count for a word to enter either vocabulary.

    Returns:
        ``(de_vocab, en_vocab)`` tuple of Vocab.
    """
    de_counter, en_counter = Counter(), Counter()
    for de, en in dataset:
        de_counter.update(de)
        en_counter.update(en)

    # most_common() sorts by count descending with a stable sort, so ties keep
    # insertion (first-occurrence) order.
    de_count_dict = OrderedDict(de_counter.most_common())
    en_count_dict = OrderedDict(en_counter.most_common())

    return Vocab(de_count_dict, min_freq=min_freq), Vocab(en_count_dict, min_freq=min_freq)

In [11]:
# Vocabularies are built from the training split only.
de_vocab, en_vocab = build_vocab(train_dataset)
len(de_vocab)

7853

## 模型构建

### Encoder

In [12]:
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp

class Encoder(nn.Cell):
    """Bidirectional GRU encoder.

    Embeds source token ids and runs them through a bidirectional GRU. The last
    forward and backward hidden states are concatenated and projected with
    ``fc`` + tanh to give the decoder's initial hidden state.

    Args:
        input_dim: source vocabulary size.
        emb_dim: embedding size.
        enc_hid_dim: GRU hidden size per direction.
        dec_hid_dim: decoder hidden size (output size of ``fc``).
        dropout: drop probability applied to the embeddings.

    Note: reads the module-level ``compute_dtype`` when constructed.
    """
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True).to_float(compute_dtype)
        self.fc = nn.Dense(enc_hid_dim * 2, dec_hid_dim).to_float(compute_dtype)

        # NOTE(review): the argument is the keep probability (1 - dropout), as in
        # the older nn.Dropout(keep_prob) API; newer MindSpore versions take the
        # drop probability instead — confirm against the installed version.
        self.dropout = nn.Dropout(1-dropout)
        
    def construct(self, src, src_len):
        #src = [src len, batch size]
        #src_len = [batch size]
        embedded = self.dropout(self.embedding(src))
        #embedded = [src len, batch size, emb dim]
                        
        # seq_length lets the GRU stop at each sequence's true length.
        outputs, hidden = self.rnn(embedded, seq_length=src_len)
                                 
        #outputs = [src len, batch size, hid dim * num directions]
        #hidden = [n layers * num directions, batch size, hid dim]
        
        #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        #outputs are always from the last layer
        
        #hidden [-2, :, : ] is the last of the forwards RNN 
        #hidden [-1, :, : ] is the last of the backwards RNN
        
        #initial decoder hidden is final hidden state of the forwards and backwards 
        #  encoder RNNs fed through a linear layer
        hidden = ops.tanh(self.fc(mnp.concatenate((hidden[-2,:,:], hidden[-1,:,:]), axis = 1)))
        
        #outputs = [src len, batch size, enc hid dim * 2]
        #hidden = [batch size, dec hid dim]
        # print(hidden)
        return outputs, hidden

### Attention

In [13]:
class Attention(nn.Cell):
    """Additive attention: score = v^T tanh(W [hidden; encoder_output]).

    Returns, for each batch element, a softmax distribution over source
    positions. Positions where ``mask`` is 0 get a score of -1e10, which gives
    them (numerically) zero weight.

    Args:
        enc_hid_dim: encoder GRU hidden size per direction (encoder outputs have
            size ``enc_hid_dim * 2``).
        dec_hid_dim: decoder hidden size.

    Note: reads the module-level ``compute_dtype`` when constructed.
    """
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        self.attn = nn.Dense((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim).to_float(compute_dtype)
        self.v = nn.Dense(dec_hid_dim, 1, has_bias = False).to_float(compute_dtype)
        
    def construct(self, hidden, encoder_outputs, mask):
        
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        #mask = [batch size, src len]; 0 marks padded source positions
        
        src_len = encoder_outputs.shape[0]
        
        #repeat decoder hidden state src_len times
        hidden = mnp.tile(hidden.expand_dims(1), (1, src_len, 1))
  
        encoder_outputs = encoder_outputs.transpose(1, 0, 2)
        
        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        energy = ops.tanh(self.attn(mnp.concatenate((hidden, encoder_outputs), axis = 2))) 
        
        #energy = [batch size, src len, dec hid dim]

        attention = self.v(energy).squeeze(2)
        
        #attention = [batch size, src len]
        
        # a very negative score makes padded positions vanish after softmax
        attention = attention.masked_fill(mask == 0, -1e10)
        
        return ops.Softmax(1)(attention)

### Decoder

In [14]:
class Decoder(nn.Cell):
    """Single-step GRU decoder with attention over the encoder outputs.

    Each call consumes one target token per batch element, attends over the
    encoder outputs, updates the GRU hidden state and returns scores over the
    target vocabulary.

    Args:
        output_dim: target vocabulary size.
        emb_dim: target embedding size.
        enc_hid_dim: encoder GRU hidden size per direction (encoder outputs have
            size ``enc_hid_dim * 2``).
        dec_hid_dim: decoder GRU hidden size.
        dropout: drop probability applied to the embeddings.
        attention: cell called as ``attention(hidden, encoder_outputs, mask)``.

    Note: reads the module-level ``compute_dtype`` when constructed.
    """
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # GRU input = embedded token concatenated with the attention context vector.
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim).to_float(compute_dtype)
        # Output layer sees GRU output, context vector and embedding together.
        self.fc_out = nn.Dense((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim).to_float(compute_dtype)
        # NOTE(review): keep probability (1 - dropout), as in the older
        # nn.Dropout(keep_prob) API — confirm against the installed MindSpore.
        self.dropout = nn.Dropout(1-dropout)
        
    def construct(self, inputs, hidden, encoder_outputs, mask):
        # One decoding step. Returns (prediction [batch size, output dim],
        # new hidden [batch size, dec hid dim], attention weights [batch size, src len]).
             
        #input = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        #mask = [batch size, src len]
        
        inputs = inputs.expand_dims(0)
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(inputs))
        
        #embedded = [1, batch size, emb dim]
        a = self.attention(hidden, encoder_outputs, mask)
                
        #a = [batch size, src len]
        
        a = a.expand_dims(1)
        
        #a = [batch size, 1, src len]
        
        encoder_outputs = encoder_outputs.transpose(1, 0, 2)
        
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        # context vector: attention-weighted sum of encoder outputs
        weighted = ops.BatchMatMul()(a, encoder_outputs)
        
        #weighted = [batch size, 1, enc hid dim * 2]
        
        weighted = weighted.transpose(1, 0, 2)
        
        #weighted = [1, batch size, enc hid dim * 2]
        
        rnn_input = mnp.concatenate((embedded, weighted), axis = 2)
        
        #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]
            
        output, hidden = self.rnn(rnn_input, hidden.expand_dims(0))
        
        #output = [seq len, batch size, dec hid dim * n directions]
        #hidden = [n layers * n directions, batch size, dec hid dim]
        
        #seq len, n layers and n directions will always be 1 in this decoder, therefore:
        #output = [1, batch size, dec hid dim]
        #hidden = [1, batch size, dec hid dim]
        #this also means that output == hidden
        
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        prediction = self.fc_out(mnp.concatenate((output, weighted, embedded), axis = 1))
        
        #prediction = [batch size, output dim]
        
        return prediction, hidden.squeeze(0), a.squeeze(1)

### Seq2Seq

In [15]:
import random

class Seq2Seq(nn.Cell):
    """Encoder-decoder model that decodes the target one token at a time.

    Args:
        encoder: cell returning ``(encoder_outputs, hidden)`` for ``(src, src_len)``.
        decoder: cell called as ``decoder(inputs, hidden, encoder_outputs, mask)``.
        src_pad_idx: source padding index, used to build the attention mask.
        teacher_forcing_ration: probability of feeding the ground-truth token
            instead of the model's own prediction while training (the misspelled
            parameter name is kept for compatibility).

    Note: ``construct`` casts its result to the module-level ``dtype``.
    """
    def __init__(self, encoder, decoder, src_pad_idx, teacher_forcing_ration):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.teacher_forcing_ratio = teacher_forcing_ration
        # Samples uniform random values for the teacher-forcing decision.
        self.random = ops.UniformReal()
        
    def create_mask(self, src):
        # 1 for real tokens, 0 for padding; transposed to [batch size, src len].
        mask = (src != self.src_pad_idx).astype(mindspore.int32).swapaxes(1, 0)
        return mask
        
    def construct(self, src, src_len, trg, trg_len=None):
        #src = [src len, batch size]
        #src_len = [batch size]
        #trg = [trg len, batch size]
        #trg_len = number of target positions to decode; at inference it is passed
        #  explicitly because trg then holds only the first (<bos>) row
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        if trg_len is None:
            trg_len = trg.shape[0]
        
        #tensor to store decoder outputs
        outputs = []
        
        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src, src_len)
                
        #first input to the decoder is the <sos> tokens
        inputs = trg[0]
        
        mask = self.create_mask(src)

        #mask = [batch size, src len]
                
        for t in range(1, trg_len):
            
            #insert input token embedding, previous hidden state, all encoder hidden states 
            #  and mask
            #receive output tensor (predictions) and new hidden state
            output, hidden, _ = self.decoder(inputs, hidden, encoder_outputs, mask)
            # print(output)
            #place predictions in a tensor holding predictions for each token
            outputs.append(output)
            
            #get the highest predicted token from our predictions
            top1 = output.argmax(1) 

            if self.training:
                #decide if we are going to use teacher forcing or not
                #one random sample per step: the decision is shared by the whole batch
                teacher_force = self.random((1,)) < self.teacher_forcing_ratio
                # teacher_force = random.random() < self.teacher_forcing_ratio
                #if teacher forcing, use actual next token as next input
                #if not, use predicted token
                inputs = trg[t] if teacher_force else top1
            else:
                inputs = top1
        
        #outputs = [trg len - 1, batch size, output dim]
        outputs = mnp.stack(outputs, axis=0)
            
        return outputs.astype(dtype)

## CrossEntropy损失函数

In [16]:
class CrossEntropy(nn.Cell):
    """Cross-entropy loss on logits with class weights, ignore index and label smoothing.

    Args:
        weight: optional per-class weight Tensor of shape [num classes].
        ignore_index: target value whose positions contribute nothing to the
            loss. It does not need to be a valid class index (default -100).
        reduction: 'mean', 'sum' or 'none' ('none' returns per-sample losses).
        label_smoothing: smoothing factor in [0.0, 1.0].

    Raises:
        ValueError: on an out-of-range ``label_smoothing`` or unknown ``reduction``.
    """
    reduction_list = ['sum', 'mean', 'none']
    def __init__(self, weight=None, ignore_index:int=-100, reduction:str='mean', label_smoothing:float=0.0):        
        super().__init__()
        if label_smoothing > 1.0 or label_smoothing < 0.0:
            raise ValueError(f'label_smoothing value must in range [0.0, 1.0], '
                             f'but get {label_smoothing}')
        
        if reduction not in self.reduction_list:
            raise ValueError(f'Unsupported reduction {reduction}')
        
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def construct(self, input, target):
        # input = [N, num classes] logits; target = [N] class indices
        return self._nll_loss(ops.LogSoftmax(1)(input), target, -1, self.weight, self.ignore_index, self.reduction, self.label_smoothing)

    def _nll_loss(self, input, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
        # Label-smoothed negative log-likelihood on log-probabilities `input`.
        if target.ndim == input.ndim - 1:
            target = target.expand_dims(target_dim)
        if ignore_index is not None:
            non_pad_mask = ops.equal(target, ignore_index)
            # Replace ignored targets with class 0 before any gather:
            # ignore_index (e.g. the default -100) need not be a valid index.
            # These positions are zeroed out below.
            target = ops.select(non_pad_mask, ops.zeros_like(target), target)
        nll_loss = -ops.gather_d(input, target_dim, target)
        smooth_loss = -input.sum(axis=target_dim, keepdims=True)
        if weight is not None:
            loss_weights = ops.gather(weight, target, 0)
            nll_loss = nll_loss * loss_weights
        else:
            loss_weights = ops.ones_like(nll_loss)
        if ignore_index is not None:
            nll_loss = nll_loss.masked_fill(non_pad_mask, 0.)
            loss_weights = loss_weights.masked_fill(non_pad_mask, 0.)
            smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)
        # Drop the singleton class axis in every case so that reduction='none'
        # always yields per-sample losses of shape [N].
        nll_loss = nll_loss.squeeze(target_dim)
        smooth_loss = smooth_loss.squeeze(target_dim)

        if reduction == 'sum':
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        if reduction == 'mean':
            # Normalise both terms by the same (non-ignored, weighted) count so
            # padding does not dilute the smoothing term.
            nll_loss = nll_loss.sum() / loss_weights.sum()
            smooth_loss = smooth_loss.sum() / loss_weights.sum()

        eps_i = label_smoothing / input.shape[target_dim]
        loss = (1. - label_smoothing) * nll_loss + eps_i * smooth_loss

        return loss

## 整图训练封装

### 自定义Loss求解

In [17]:
class Seq2SeqWithLoss(nn.Cell):
    """Wrap a Seq2Seq network and a loss cell; ``construct`` returns the loss.

    Decoder outputs start at target position 1 (position 0 is the <bos> input fed
    to the decoder), so they are compared with ``trg[1:]``. Both are flattened
    over time and batch before the loss is applied.
    """
    def __init__(self, network, loss):
        super().__init__()
        self.network = network
        self.loss = loss
        
    def construct(self, src, src_len, trg):
        output = self.network(src, src_len, trg)
        #output = [trg len - 1, batch size, output dim]
        output_dim = output.shape[-1]
        output = output.view(-1, output_dim)
        #targets aligned with the outputs: skip the <bos> row, flatten to [(trg len - 1) * batch size]
        trg = trg[1:].view(-1)
        loss = self.loss(output, trg)
        return loss

### 自定义梯度裁剪

In [18]:
from mindspore import Tensor

def clip_by_norm(clip, grad):
    """Clip one gradient tensor so that its norm does not exceed ``clip``.

    TrainOneStepCell maps this over every gradient with HyperMap, so each tensor
    is clipped individually (not by a global norm).
    """
    return nn.ClipByNorm()(grad, clip)

class TrainOneStepCell(nn.TrainOneStepCell):
    """One optimisation step with per-tensor gradient clipping.

    ``construct`` does the following:
      - computes the loss of ``network``;
      - computes gradients w.r.t. ``self.weights`` (set up by nn.TrainOneStepCell);
      - clips every gradient tensor with ``clip_by_norm``;
      - applies ``grad_reducer``;
      - lets ``optimizer`` update the weights.
    It returns the loss computed before the update.

    Args:
        network: cell returning the loss (here a Seq2SeqWithLoss).
        optimizer: optimizer holding the trainable weights.
        sens: gradient scale passed to the backward pass.
        clip: per-tensor norm threshold for ``clip_by_norm``.
    """
    def __init__(self, network, optimizer, sens=1.0, clip=1.0):
        super(TrainOneStepCell, self).__init__(network, optimizer, sens)
        self.hyper_map = ops.HyperMap()
        self.clip = Tensor(clip, mindspore.float32)
        # sens as a float32 scalar Tensor, passed explicitly to self.grad
        self.sens = Tensor(sens, mindspore.float32)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        grads = self.grad(self.network, weights)(*inputs, self.sens)
        # clip each gradient tensor by norm
        grads = self.hyper_map(ops.partial(clip_by_norm, self.clip), grads)
        grads = self.grad_reducer(grads)
        self.optimizer(grads)
        return loss

## 模型训练

In [19]:
# Vocabulary sizes and padding indices come from the vocabularies built above.
input_dim, output_dim = len(de_vocab), len(en_vocab)
src_pad_idx, trg_pad_idx = de_vocab.pad_idx, en_vocab.pad_idx

# Embedding / hidden sizes and dropout rates, shared by encoder and decoder.
enc_emb_dim = dec_emb_dim = 256
enc_hid_dim = dec_hid_dim = 512
enc_dropout = dec_dropout = 0.5

# `compute_dtype` is read by the layer constructors (`to_float`) and `dtype` by
# Seq2Seq.construct, so both must exist before the model is built.
compute_dtype = dtype = mindspore.float32

attn = Attention(enc_hid_dim, dec_hid_dim)
encoder = Encoder(input_dim, enc_emb_dim, enc_hid_dim, dec_hid_dim, enc_dropout)
decoder = Decoder(output_dim, dec_emb_dim, enc_hid_dim, dec_hid_dim, dec_dropout, attn)

# Teacher forcing ratio 0.5 (only used while the model is in training mode).
model = Seq2Seq(encoder, decoder, src_pad_idx, 0.5)

In [20]:
# from mindspore.common.initializer import Normal, Zero, initializer

# def init_weights(m):
#     for param in m.trainable_params():
#         if 'weight' in param.name:
#             param.set_data(initializer(Normal(), param.shape))
#         else:
#             param.set_data(initializer(Zero(), param.shape))

# init_weights(model)

In [21]:
# Target padding positions are excluded from the loss.
loss = CrossEntropy(ignore_index=trg_pad_idx)
model_with_loss = Seq2SeqWithLoss(model, loss)
optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)
# Forward + backward + per-tensor gradient clipping (default clip=1.0) + Adam update.
trainer = TrainOneStepCell(model_with_loss, optimizer)

In [22]:
def train_one_step(model, iterator, epoch=0):
    """Run one training epoch of ``model`` over every batch of ``iterator``.

    ``model`` is a training cell (e.g. TrainOneStepCell) called as
    ``model(src, src_len, trg)`` that performs an update and returns the batch
    loss. Progress and the running mean loss are shown with tqdm.

    Returns:
        Mean batch loss over the epoch, or ``float('nan')`` if ``iterator``
        yields no batches.
    """
    model.set_train(True)
    loss_total = 0
    step_total = 0
    with tqdm(total=len(iterator)) as t:
        t.set_description(f'Epoch {epoch:d}')
        for src, src_len, trg in iterator():
            # Iterator yields [batch size, seq len]; the model expects [seq len, batch size].
            src = src.swapaxes(0, 1)
            trg = trg.swapaxes(0, 1)
            loss = model(src, src_len, trg)
            loss_total += loss.asnumpy()
            step_total += 1
            t.set_postfix(loss=loss_total / step_total)
            t.update(1)
    if step_total == 0:
        return float('nan')
    return loss_total / step_total

In [23]:
def evaluate(model, iterator):
    """Return the mean batch loss of ``model`` over ``iterator`` in eval mode.

    ``model`` is called as ``model(src, src_len, trg)`` and must return the batch
    loss (e.g. Seq2SeqWithLoss). Progress is shown with tqdm.

    Returns:
        Mean batch loss, or ``float('nan')`` when ``iterator`` yields no batches
        (previously this raised ZeroDivisionError).
    """
    model.set_train(False)
    loss_total = 0
    step_total = 0
    with tqdm(total=len(iterator)) as t:
        for src, src_len, trg in iterator():
            # Iterator yields [batch size, seq len]; the model expects [seq len, batch size].
            src = src.swapaxes(0, 1)
            trg = trg.swapaxes(0, 1)
            loss = model(src, src_len, trg)
            loss_total += loss.asnumpy()
            step_total += 1
            t.set_postfix(loss=loss_total / step_total)
            t.update(1)
    if step_total == 0:
        return float('nan')
    return loss_total / step_total

In [24]:
# Training drops the last incomplete batch; validation and test keep every example.
train_iterator = Iterator(train_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=True)
valid_iterator = Iterator(valid_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=False)
test_iterator = Iterator(test_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=False)

In [25]:
from mindspore import save_checkpoint, context
# context.set_context(mode=context.PYNATIVE_MODE)
num_epochs = 10
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'seq2seq.ckpt')

# Train each epoch, then keep only the checkpoint with the lowest validation loss.
for i in range(num_epochs):
    train_one_step(trainer, train_iterator, i)
    valid_loss = evaluate(model_with_loss, valid_iterator)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # Save the bare Seq2Seq `model` so its parameters load back into `model` later.
        save_checkpoint(model, ckpt_file_name)

Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 226/226 [01:04<00:00,  3.50it/s, loss=4.9]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:11<00:00,  1.41s/it, loss=4.72]
Epoch 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 226/226 [00:44<00:00,  5.09it/s, loss=3.87]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 13.82it/s, loss=4]
Epoch 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 226/226 [00:44<00:00,  5.07it/s, loss=3.19]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 14.13it/s, loss=3.68]
Epoch 3: 100%|██████████████████████████████████████████████████████████████

## 模型推理

In [26]:
def translate_sentence(sentence, de_vocab, en_vocab, model, max_len=32):
    """Greedily translate one German sentence into a list of English tokens.

    Args:
        sentence: raw German string (tokenized with spaCy) or a sequence of tokens.
        de_vocab: source vocabulary.
        en_vocab: target vocabulary.
        model: Seq2Seq cell, called as ``model(src, src_len, trg, max_len)``.
        max_len: length the source is padded/truncated to (including
            ``<bos>``/``<eos>``); also passed to the model as the decoding length.

    Returns:
        Predicted English tokens, cut before the first ``<eos>``. All predicted
        tokens are returned if no ``<eos>`` is produced.
    """
    model.set_train(False)
    if isinstance(sentence, str):
        # Same spaCy model and tokenizer-only processing as the Multi30K loader
        # (the bare 'de' shortcut name is not available in spaCy v3).
        spacy_lang = spacy.load('de_core_news_sm')
        tokens = [token.text.lower() for token in spacy_lang.tokenizer(sentence.rstrip())]
    else:
        tokens = [token.lower() for token in sentence]

    if len(tokens) > max_len - 2:
        src_len = max_len
        tokens = ['<bos>'] + tokens[:max_len - 2] + ['<eos>']
    else:
        src_len = len(tokens) + 2
        tokens = ['<bos>'] + tokens + ['<eos>'] + ['<pad>'] * (max_len - src_len)

    # Shape the inputs as [seq len, batch size=1].
    src = de_vocab.encode(tokens)
    src = mindspore.Tensor(src, mindspore.int32).expand_dims(1)
    src_len = mindspore.Tensor([src_len], mindspore.int32)
    trg = mindspore.Tensor([en_vocab.bos_idx], mindspore.int32).expand_dims(1)

    outputs = model(src, src_len, trg, max_len)
    # outputs = [max_len - 1, 1, output dim]; pick the best token at every step.
    trg_indexes = outputs.argmax(2).asnumpy().reshape(-1).tolist()
    if en_vocab.eos_idx in trg_indexes:
        trg_indexes = trg_indexes[:trg_indexes.index(en_vocab.eos_idx)]
    trg_tokens = en_vocab.decode(trg_indexes)

    return trg_tokens

In [27]:
from mindspore import load_checkpoint, load_param_into_net

# Restore the best checkpoint saved during training into the Seq2Seq model.
param_dict = load_checkpoint(ckpt_file_name)
load_param_into_net(model, param_dict)

[]

In [28]:
# test_loss = evaluate(model_with_loss, test_iterator)

In [29]:
example_idx = 0

# Each dataset item is a (German tokens, English tokens) pair.
src, trg = test_dataset[example_idx][0], test_dataset[example_idx][1]

print(f'src = {src}')
print(f'trg = {trg}')

src = ['ein', 'mann', 'mit', 'einem', 'orangefarbenen', 'hut', ',', 'der', 'etwas', 'anstarrt', '.']
trg = ['a', 'man', 'in', 'an', 'orange', 'hat', 'starring', 'at', 'something', '.']


In [30]:
# Greedy translation of the selected test example with the restored model.
translation = translate_sentence(src, de_vocab, en_vocab, model)

print(f'predicted trg = {translation}')

predicted trg = ['a', 'man', 'in', 'an', 'orange', 'hat', ',', 'is', '<unk>', '.']


### BLEU得分

> pip install nltk

In [31]:
from nltk.translate.bleu_score import corpus_bleu

def calculate_bleu(dataset, de_vocab, en_vocab, model, max_len=50):
    """Return NLTK corpus-level BLEU of greedy translations over ``dataset``.

    Each item of ``dataset`` supplies the source tokens at index 0 and the
    single reference translation at index 1.
    """
    references, hypotheses = [], []
    for example in dataset:
        source, reference = example[0], example[1]
        hypotheses.append(translate_sentence(source, de_vocab, en_vocab, model, max_len))
        # corpus_bleu takes a list of reference translations per sentence.
        references.append([reference])
    return corpus_bleu(references, hypotheses)

In [32]:
# Evaluate on the test split; the score is reported as a percentage.
bleu_score = calculate_bleu(test_dataset, de_vocab, en_vocab, model)

print(f'BLEU score = {bleu_score*100:.2f}')

BLEU score = 30.40
