In [3]:
import json
import re
import os
import math
import jieba
from tqdm import tqdm
from collections import Counter
from torchtext.data import get_tokenizer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset,DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import IPython.display as display

In [4]:
SOS_ID = 0
EOS_ID = 1
UNK_ID = 2
PAD_ID = 3
en_vocab_size, ch_vocab_size = 7184, 16251
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-6
d_model = 256
d_ff = 1024
num_heads = 8
num_layers = 6
num_epochs = 600
batch_size = 64

# 数据预处理

## 处理数据集

In [5]:
def is_contain_chinese(check_str):
    for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False

def is_contain_english(check_str):
    return bool(re.search(f'[a-zA-Z]+',check_str))

# with open('cmn-eng/cmn.txt','r') as f:
#     i=0
#     english = []
#     chinese = []
#     for line in f.readlines():
#         en_sen = ""
#         ch_sen = ""
#         is_eng =True
#         for word in line.split():
#             if is_eng:
#                 if is_contain_chinese(word):
#                     is_eng=False
#                     ch_sen += word + " "
#                 else:
#                     en_sen += word + " "
#             else:
#                 if is_contain_english(word):
#                     break
#                 else:
#                     ch_sen += word + " "

#         english.append(en_sen)
#         chinese.append(ch_sen)
    
#     tokenizer = get_tokenizer('basic_english')
#     with open('cmn-eng/english.txt', 'w') as ef:
#         for sen in english:
#             sen = tokenizer(sen)
#             for word in sen:
#                 ef.write(word + " ")
#             ef.write("\n")
#         ef.close()

#     with open('cmn-eng/chinese.txt', 'w') as cf:
#         for sen in chinese:
#             sen = list(jieba.cut(sen))
#             for word in sen:
#                 cf.write(word + " ")
#             cf.write("\n")
#         cf.close()
#     f.close()

In [6]:
def create_vocab(sentences, max_element=None):
    """Note that max_element includes special characters"""
    default_list = ['<sos>', '<eos>', '<unk>', '<pad>']
    char_set = Counter()
    for sentence in sentences:
        temp_set = Counter(sentence)
        char_set.update(temp_set)

    if max_element is None:
        return default_list + list(char_set.keys())
    else:
        max_element -= 4
        words_freq = char_set.most_common(max_element) # 出现频率最大的前n个 返回元组类型
        words, freq = zip(*words_freq)
        return default_list + list(words)
    
def save_vocab(vocab, name):
    with open(name, 'w') as f:
        for a in vocab:
            f.write(a + " ")
            
def load_vocab(name):
    with open(name, 'r') as f:
        a = f.read()
    return a.split()

def sentence_to_tensor(sentences, vocab):
    indexs = []
    for sentence in sentences:
        index = []
        for char in sentence:
            if char in vocab:
                index.append(vocab.index(char))
            else:
                index.append(UNK_ID)
        indexs.append(np.array(index))
    return indexs

def tensor_to_sentence(indexs, vocab):
    sentence = ""
    for index in indexs:
        sentence += vocab[index] + " "
    return sentence[:-1]

In [7]:
## 构造英文词典
# max_element = en_vocab_size 
# default_list = ['<sos>', '<eos>', '<unk>', '<pad>']
# char_set = Counter()
# with open('cmn-eng/english.txt','r') as ff:
#     for line in ff.readlines():
#         sentence = line.split()
#         temp_set = Counter(sentence)
#         char_set.update(temp_set)
# print(len(char_set))
# max_element -= 4
# words_freq = char_set.most_common(max_element) # 出现频率最大的前n个 返回元组类型
# words, freq = zip(*words_freq)
# en_vocab = default_list + list(words)
# save_vocab(en_vocab, 'cmn-eng/en_vocab.txt')

## 构造中文词典
# max_element = ch_vocab_size
# default_list = ['<sos>', '<eos>', '<unk>', '<pad>']
# char_set = Counter()
# with open('cmn-eng/chinese.txt','r') as ff:
#     for line in ff.readlines():
#         sentence = line.split()
#         temp_set = Counter(sentence)
#         char_set.update(temp_set)
# print(len(char_set))
# max_element -= 4
# words_freq = char_set.most_common(max_element) # 出现频率最大的前n个 返回元组类型
# words, freq = zip(*words_freq)
# zh_vocab = default_list + list(words)
# save_vocab(zh_vocab, 'cmn-eng/ch_vocab.txt')


# # 将分词后的句子 转换为词表索引序列保存
# en_vocab, ch_vocab = load_vocab('cmn-eng/en_vocab.txt'), load_vocab('cmn-eng/ch_vocab.txt')
# with open('cmn-eng/english.txt','r') as ff:
#     if os.path.exists('cmn-eng/en_tensor.txt'):
#         os.remove('cmn-eng/en_tensor.txt')
#     for line in ff.readlines():
#         if len(line) <= 1:
#             continue
#         tensor = sentence_to_tensor([line.split()], en_vocab)[0]
#         string = ''
#         for num in tensor:
#             string += str(num) + ' '
#         with open('cmn-eng/en_tensor.txt','a') as f:
#             f.write(string + '\n')
#             f.close()
# with open('cmn-eng/chinese.txt','r') as ff:
#     if os.path.exists('cmn-eng/ch_tensor.txt'):
#         os.remove('cmn-eng/ch_tensor.txt')
#     for line in tqdm(ff.readlines(), leave=False):
#         if len(line) <= 1:
#             continue
#         tensor = sentence_to_tensor([line.split()], ch_vocab)[0]
#         string = ''
#         for num in tensor:  
#             string += str(num) + ' '
#         with open('cmn-eng/ch_tensor.txt','a') as f:
#             f.write(string + '\n')
#             f.close()

## 数据集构造

In [8]:
class MyDataset(Dataset):
    def __init__(self, en_path, zh_path):
        super().__init__()
        self.en_tensor = []
        self.zh_tensor = []
        
        with open(en_path, 'r') as f:
            for line in f.readlines():
                self.en_tensor.append([ int(num) for num in line.split()])
            f.close()
        with open(zh_path, 'r') as f:
            for line in f.readlines():
                self.zh_tensor.append([ int(num) for num in line.split()])
            f.close()
            
    def __len__(self):
        return len(self.en_tensor)

    def __getitem__(self, index):
        x = np.concatenate(([SOS_ID], self.en_tensor[index], [EOS_ID]))
        x = torch.from_numpy(x)
        y = np.concatenate(([SOS_ID], self.zh_tensor[index], [EOS_ID]))
        y = torch.from_numpy(y)
        return x, y

def collate_fn(batch):
    x, y = zip(*batch)
    x_pad = pad_sequence(x, batch_first=True, padding_value=PAD_ID)
    y_pad = pad_sequence(y, batch_first=True, padding_value=PAD_ID)
    return x_pad, y_pad

# dataset = MyDataset("cmn-eng/en_tensor.txt", "cmn-eng/ch_tensor.txt")
# loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=False, collate_fn=collate_fn)
# en_vocab ,ch_vocab = load_vocab('cmn-eng/en_vocab.txt'), load_vocab('cmn-eng/ch_vocab.txt')
# for x, y in loader:
#     print(tensor_to_sentence(x[0],en_vocab))
#     print(tensor_to_sentence(y[0],ch_vocab))
#     break

# transformer

In [9]:
# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0, max_len=1000):
        super().__init__()
        assert d_model % 2 == 0
        self.dropout = nn.Dropout(dropout)
        P = torch.zeros((1, max_len, d_model)) # batch设为1，广播
        # (max_len, 1) / (d_model/2) => (max_len, d_model/2)
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model) 
        P[:, :, 0::2] = torch.sin(X)
        P[:, :, 1::2] = torch.cos(X)
        self.register_buffer('P', P, False) # 模型训练时不会更新 不将变量加入 state_dict
 
    def forward(self, x):
#         x *= x.shape[2] ** 0.5 嵌入层 的操作
        x = x + self.P[:, :x.shape[1], :].to(x.device)
        return self.dropout(x)

In [10]:
# 前馈网络
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.layer1(x)
        x = self.dropout(F.relu(x))
        x = self.layer2(x)
        return x
    
# 残差连接 和 层规范化
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)
        
    def forward(self, x, y):
        return self.ln(x + self.dropout(y))

In [11]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, num_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.addnorm1 = AddNorm(d_model, dropout)
        self.addnorm2 = AddNorm(d_model, dropout)
    
    def forward(self, x, padding_mask):
        y = self.addnorm1(x, self.attention(x, x, x, key_padding_mask=padding_mask)[0])
        return self.addnorm2(y, self.ffn(y))
    
    
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, d_ff, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(f"block{i}", EncoderBlock(d_model, d_ff, num_heads, dropout))
    
    def forward(self, x, padding_mask=None):
        x = self.pe(self.embedding(x) * (self.d_model ** 0.5))
        for block in self.blocks:
            x = block(x, padding_mask)
        return x

In [12]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, num_heads, dropout=0.1):
        super().__init__()
        self.attention1= nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.attention2= nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.addnorm1 = AddNorm(d_model, dropout)
        self.addnorm2 = AddNorm(d_model, dropout)        
        self.addnorm3 = AddNorm(d_model, dropout)

    def forward(self, x, encoder_kv, attn_mask, padding_mask):
        y = self.addnorm1(x, self.attention1(x, x, x, attn_mask=attn_mask)[0])
        z = self.addnorm2(y, self.attention2(y, encoder_kv, encoder_kv, key_padding_mask=padding_mask)[0])
        return self.addnorm3(z, self.ffn(z))
    
    
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, d_ff, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(f"block{i}", DecoderBlock(d_model, d_ff, num_heads, dropout))
    
    def forward(self, x, encoder_kv, attn_mask=None, padding_mask=None):
        x = self.pe(self.embedding(x) * (self.d_model ** 0.5))
        for block in self.blocks:
            x = block(x, encoder_kv, attn_mask, padding_mask)
        return x

In [13]:
# 手搓的Transformer
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, d_ff, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(src_vocab_size, d_model, d_ff, num_heads, num_layers, dropout)
        self.decoder = TransformerDecoder(tgt_vocab_size, d_model, d_ff, num_heads, num_layers, dropout)
        self.dense = nn.Sequential(nn.Linear(d_model, d_model*4), nn.ReLU(), nn.Linear(d_model*4, tgt_vocab_size))
        self.dense = nn.Linear(d_model, tgt_vocab_size)


    def get_padding_mask(self, src):
        return src == PAD_ID
    
    def get_attn_mask(self, src, tgt):
        return (1 - torch.tril(torch.ones((tgt.shape[1], src.shape[1]))).to(src.device)) == 1
    
    def forward(self, x, y):
        src_padding_mask = self.get_padding_mask(x)
        encoder_kv = self.encoder(x, src_padding_mask)
        
        tgt_attn_mask = self.get_attn_mask(y, y)
        res = self.decoder(y, encoder_kv, tgt_attn_mask, src_padding_mask)
        return F.log_softmax(self.dense(res), dim=-1)
        # return torch.matmul(res, self.decoder.embedding.weight.transpose(0, 1))

In [14]:
# 调库的Transformer
class TF(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_src = nn.Embedding(en_vocab_size, d_model).to(device)
        self.embedding_tgt = nn.Embedding(ch_vocab_size, d_model).to(device)
        self.pe = PositionalEncoding(d_model)
        self.model = nn.Transformer(d_model=d_model, nhead=num_heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dim_feedforward=d_ff, batch_first=True)
        self.dense = nn.Linear(d_model, ch_vocab_size)
    
    def forward(self, src, tgt, tgt_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask):
        output = self.model(src=self.pe(self.embedding_src(src)), tgt=self.pe(self.embedding_tgt(tgt)), tgt_mask=tgt_mask,
                       src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        return F.log_softmax(self.dense(output), dim=2)

# train

In [15]:
class LabelSmoothing(nn.Module):
    """标签平滑处理"""
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))
    
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

In [23]:
model = Transformer(en_vocab_size, ch_vocab_size, d_model, d_ff, num_heads, num_layers)
# model = TF()
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
if os.path.exists('transformer.pth'):
    print('load model')
    model.load_state_dict(torch.load('transformer.pth',map_location=device))
model = model.to(device)

load model


In [None]:
# opt = optim.Adam(model.parameters(), lr=lr)
# Loss = nn.CrossEntropyLoss(ignore_index=PAD_ID)
opt= NoamOpt(d_model, 2, (int)(num_epochs * (29476 / batch_size) * 0.1), torch.optim.Adam(model.parameters(), lr=0, betas=(0.9,0.98), eps=1e-9))
opt._step = (int)(num_epochs * (29476 / batch_size))
Loss = LabelSmoothing(ch_vocab_size, padding_idx = PAD_ID, smoothing= 0.0)
min_loss =  1 
plt.ion()
for epoch in range(num_epochs):
    torch.cuda.empty_cache()
    model = model.to(device)
    model.train()
    loop = tqdm(loader, leave=False)
    loss_sum = 0
    token_sum = 0
    correct = 0
    correct_sum = 0
    for idx, (x, y) in enumerate(loop):
        src, tgt = x.to(device), y[:, :-1].to(device)
        
        # src_key_padding_mask = memory_key_padding_mask = (src == PAD_ID).to(device)
        # tgt_key_padding_mask = (tgt == PAD_ID).to(device)
        # tgt_mask = (model.model.generate_square_subsequent_mask(tgt.shape[1]) == -math.inf).to(device)

        # y_hat = model(src=src, tgt=tgt, tgt_mask=tgt_mask,
        #             src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)

        y_hat = model(src, tgt)
        
        tokens = (y[:, 1:] != PAD_ID).sum()
        loss = Loss(y_hat.reshape(-1, y_hat.size(-1)), y[:, 1:].long().reshape(-1).to(device)) / tokens
        if math.isnan(loss.item()):
            raise Exception
        
        # opt.zero_grad()
        opt.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        opt.step()
        
        loss_sum += loss.item() * tokens
        token_sum +=  tokens
        correct += ((y[:, 1:].to(device) == y_hat.argmax(-1)) * (y[:, 1:].to(device) != PAD_ID)).sum()
        correct_sum += (y[:, 1:].to(device) != PAD_ID).sum()
    
    plt.subplot(3, 1 ,1)
    plt.plot(epoch+1, (loss_sum/token_sum).cpu(), '.', color='red')
    plt.subplot(3, 1,2)
    plt.plot(epoch+1, (correct/correct_sum).cpu(), '.', color='blue')
    plt.subplot(3, 1,3)
    plt.plot(opt._step, opt._rate, '.', color='green')
    plt.tight_layout()
    display.clear_output(wait=True)
    display.display(plt.gcf())
    
    print(f"epoch {epoch + 1} loss {loss_sum / token_sum}  acc:{correct / correct_sum}")
    if loss_sum / token_sum < min_loss:
        min_loss = loss_sum / token_sum
        # torch.save(model.state_dict(), 'transformer.pth')
        print("save model")

# evacuate

In [18]:
valid_dataset = MyDataset("cmn-eng/en_tensor_valid.txt", "cmn-eng/ch_tensor_valid.txt")
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=True, drop_last=False, collate_fn=collate_fn)
en_vocab, ch_vocab = load_vocab('cmn-eng/en_vocab.txt'), load_vocab('cmn-eng/ch_vocab.txt')
model = Transformer(en_vocab_size, ch_vocab_size, d_model, d_ff, num_heads, num_layers)
if os.path.exists('transformer.pth'):
    print('load model')
    model.load_state_dict(torch.load('transformer.pth',map_location=device))
model = model.to(device)

load model


In [19]:
import collections

def bleu(pred_seq, label_seq, k): #@save
    """计算BLEU"""
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

In [20]:
for x, y in valid_loader:
    target = torch.tensor([[SOS_ID]] * y.shape[0])
    model.eval()
    for i in range(y.shape[1]): 
        src, tgt = x[0:1].to(device), target[0:1].to(device)
        # src_key_padding_mask = memory_ke y_padding_mask = (src == PAD_ID).to(device)
        # tgt_key_padding_mask = (tgt == PAD_ID).to(device) 
        # tgt_mask = model.model.generate_square_subsequent_mask(tgt.shape[1]).to(device)     
        
        # y_hat = model(src, tgt, tgt_mask=tgt_mask,
        #               src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        y_hat = model(src, tgt)

        out = torch.argmax(y_hat[:, -1, :], dim=-1)
        if out == EOS_ID:
            break
        target = torch.concat([target[0:1].to(device), out.unsqueeze(1).to(device)], dim=1)
        
    
    origin = tensor_to_sentence(y[0, 1: -1], ch_vocab)
    trans =  tensor_to_sentence(target[0, 1:], ch_vocab)
    print(tensor_to_sentence(x[0, 1: -1], en_vocab))
    print(origin)
    print(trans)    
    print(bleu(trans, origin, k=1))  
    
    break

the man is loading the moving truck on his own .
這個 男人 獨自 把 東西 搬 上 搬家 卡車 上 。
完全 是 自己 的 电话 。
0.17742397566167217


In [21]:
ans = []
def beam_search(model, src, tgt=None, max_len=1, k=2, pr=1):
    model.eval()
    if tgt is None:
        tgt = torch.tensor([SOS_ID]).to(device)
    
    if tgt.shape[0] >= max_len:
        ans.append([tensor_to_sentence(tgt[1:], ch_vocab), (int)(pr)])
        return
    
    y_hat = model(src.unsqueeze(0), tgt.unsqueeze(0))
    vals, idxs = torch.sort(y_hat[0, -1, :], descending=True)
    for i in range(k):
        idx = idxs[i]
        if idx == EOS_ID:
            ans.append([tensor_to_sentence(tgt[1:], ch_vocab), (int)(pr)])
            continue
        beam_search(model, src, torch.cat([tgt, idx.unsqueeze(0)], dim=0), max_len, k, vals[i]*pr)

for x, y in valid_loader:
    x, y = x[0].to(device), y[0].to(device)
    beam_search(model, x, tgt=None, max_len=(y!=PAD_ID).sum()-1, k=2, pr=1)
    
    print(tensor_to_sentence(x, en_vocab))
    print(tensor_to_sentence(y, ch_vocab))   
    
    ans = sorted(ans , key=lambda ans:ans[1], reverse=True)
    for i in range(min(6, len(ans))):
        print(ans[i][1], ":", ans[i][0])
    break


<sos> it ' s hard to predict what the weather will be like tomorrow . <eos>
<sos> 很难说 明天 的 天气 将会 怎样 。 <eos>
73866 : 天氣 并 一起 曾 。 。
42938 : 天氣 并 一起 。 。 。
6295 : 天氣 怎麼樣 。 幾個 天氣 ？
3049 : 天氣 怎麼樣 。 幾個 天气 了
2283 : 天氣 怎麼樣 。 幾個 天氣 。
2256 : 明天 的 天氣 他 喜歡 。


In [22]:
def translate_english2chinese(model, english, en_vocab, ch_vocab):
    assert is_contain_chinese(english) == False , "该句子包含中文"
    tokenizer = get_tokenizer('basic_english')
    english = tokenizer(english)
    x = torch.tensor(sentence_to_tensor([english], en_vocab)).to(device)
    y = torch.tensor([[SOS_ID]]).to(device)
    model = model.to(device)
    model.eval()
    while True:
        src, tgt = x[0:1].to(device), y[0:1].to(device)
        y_hat = model(src, tgt)
        out = torch.argmax(y_hat[:, -1, :], dim=-1)
        y = torch.concat([y[0:1].to(device), out.unsqueeze(1).to(device)], dim=1)
        if out == EOS_ID:
            break

    return tensor_to_sentence(y[0], ch_vocab)

translate_english2chinese(model, "my name is Tom .", en_vocab, ch_vocab)

  x = torch.tensor(sentence_to_tensor([english], en_vocab)).to(device)


'<sos> 湯姆 是 我 的 名字 。 <eos>'