In [43]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import random
import spacy
import datasets
import torchtext
import evaluate
import tqdm

#### 1. 设置随机种子 以便在随机操作时可以获得可重复结果

In [44]:
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

#### 2. 准备数据

In [48]:
dataset = datasets.load_dataset("bentrevett/multi30k")
train_data, valid_data, test_data = (
    dataset['train'],
    dataset['validation'],
    dataset['test'],
)   

In [49]:
train_data[0]

{'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

#### 3. tokenize

In [50]:
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [118]:
str = 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'
[token.text for token in de_nlp.tokenizer(str)]

['Zwei',
 'junge',
 'weiße',
 'Männer',
 'sind',
 'im',
 'Freien',
 'in',
 'der',
 'Nähe',
 'vieler',
 'Büsche',
 '.']

In [51]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    en_tokens = [token.text for token in en_nlp.tokenizer(example['en'])][:max_length]
    de_tokens = [token.text for token in de_nlp.tokenizer(example['de'])][:max_length]
    if lower:
        en_tokens = [token.lower() for token in en_tokens]
        de_tokens = [token.lower() for token in de_tokens]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    de_tokens = [sos_token] + de_tokens + [eos_token]
    return {'en_tokens' : en_tokens, 'de_tokens' : de_tokens}

In [52]:
max_length = 100
lower = True
sos_token = '<sos>'
eos_token = '<eos>'

fn_kwargs = {
    'en_nlp' : en_nlp,
    'de_nlp' : de_nlp,
    'max_length' : max_length,
    'lower' : lower,
    'sos_token' : sos_token,
    'eos_token' : eos_token
}
train_data = train_data.map(tokenize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [61]:
train_data[0]

{'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en_tokens': ['<sos>',
  'two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'],
 'de_tokens': ['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>'],
 'en_ids': [2, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 22, 3],
 'de_ids': [2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 3]}

#### 4. 建立vocabulary

In [62]:
min_freq = 2
unk_token = '<unk>'
pad_token = '<pad>'
special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token,
]
en_vocab = torchtext.vocab.build_vocab_from_iterator(
    train_data['en_tokens'],
    min_freq=min_freq,
    specials=special_tokens,
)
de_vocab = torchtext.vocab.build_vocab_from_iterator(
    train_data['de_tokens'],
    min_freq=min_freq,
    specials=special_tokens,
)

In [64]:
assert en_vocab.get_stoi()[pad_token] == de_vocab[pad_token]
assert en_vocab[unk_token] == de_vocab[unk_token]
unk_index = en_vocab[unk_token]
pad_index = en_vocab[pad_token]

In [65]:
en_vocab.set_default_index(unk_index)
de_vocab.set_default_index(unk_index)

#### 5. tokens--->indices

In [66]:
def numericalize_example(example, en_vocab, de_vocab):
    en_ids = en_vocab.lookup_indices(example['en_tokens'])
    de_ids = de_vocab.lookup_indices(example['de_tokens'])
    return {'en_ids' : en_ids, 'de_ids' : de_ids}

In [67]:
fn_kwargs = {"en_vocab" : en_vocab, "de_vocab" : de_vocab}
train_data = train_data.map(numericalize_example, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(numericalize_example, fn_kwargs=fn_kwargs)
test_data = test_data.map(numericalize_example, fn_kwargs=fn_kwargs)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1014 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [68]:
train_data[0]

{'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en_tokens': ['<sos>',
  'two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'],
 'de_tokens': ['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>'],
 'en_ids': [2, 16, 24, 15, 25, 778, 17, 57, 80, 202, 1312, 5, 3],
 'de_ids': [2, 18, 26, 253, 30, 84, 20, 88, 7, 15, 110, 7647, 3171, 4, 3]}

In [71]:
data_type = 'torch'
format_columns = ['en_ids', 'de_ids']
train_data = train_data.with_format(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)
valid_data = valid_data.with_format(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)
test_data = test_data.with_format(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)

#### 6. 创建dataloader

In [72]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example['en_ids'] for example in batch]
        batch_de_ids = [example['de_ids'] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, batch_first=True, padding_value=pad_index)
        batch_de_ids = nn.utils.rnn.pad_sequence(batch_de_ids, batch_first=True,padding_value=pad_index)
        return {"en_ids" : batch_en_ids, "de_ids" : batch_de_ids}
    return collate_fn

In [73]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
    return data_loader

#### 7. 建立模型

In [74]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_layers, 
                 n_heads, pf_dim, dropout, device, max_length=100):
        
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hidden_dim)
        # 用的是max_length，前面的dataloader在创建batch数据时是否需要统一每个batch中seq长度？
        self.pos_embedding = nn.Embedding(max_length, hidden_dim)
        self.layers = nn.ModuleList([
            EncoderLayer(hidden_dim, n_heads, pf_dim, dropout, device) 
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hidden_dim])).to(device)

    def forward(self, src, src_mask):
        # 关于src_mask的维度和作用？
        # src [batch_size x seq_length]
        batch_size = src.shape[0]
        src_len = src.shape[1]
        # 初始化pos，作为pos_embedding的输入，repeat重复
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        src = self.dropout(self.tok_embedding(src)) + self.pos_embedding(pos)
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

In [75]:
class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attention = MultiHeadAttentionLayer(hidden_dim, n_heads, dropout, device)
        self.self_attn_layer_norm = nn.LayerNorm(hidden_dim)
        self.ff_layer_norm = nn.LayerNorm(hidden_dim)
        self.positionwise_feedforward = PositionwiseFeedForwardLayer(hidden_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        # src [batch_size x src_length x hidden_dim]
        _src, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        _src = self.positionwise_feedforward(src)
        src = self.ff_layer_norm(self.dropout(_src) + src)
        return src

In [76]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        # Q [batch_size, query_len, hid_dim]
        # 划分多头
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Q [batch_size, query_len, n_heads, head_dim] --> [batch_size, n_heads, query_len, head_dim]
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # energy [batch_size, n_heads, query_len, key_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # 掩码为0的对应位置替换为-inf
        if mask is not None:
            energy = energy.masked_fill(mask==0, -1e10)
        attention = torch.nn.functional.softmax(energy, dim=-1)
        x = torch.matmul(self.dropout(attention), V).permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention

In [77]:
class PositionwiseFeedForwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(hid_dim, pf_dim)
        self.fc2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # x [batch_size, q_len, hid_dim]
        x = self.dropout(torch.nn.functional.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

In [78]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, 
                 n_heads, pf_dim, dropout, device, max_length=100):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device) 
                                    for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.device = device
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        # trg [batch_size, trg_len]
        # enc_src [batch_size, src_len, hid_dim]
        # mask ?
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        # trg [batch_size, trg_len, hid_dim]
        trg = self.dropout(self.tok_embedding(trg)) + self.pos_embedding(pos)
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        # trg [batch_size, trg_len, hid_dim]
        # attention [batch_size, n_heads, trg_len, src_len]
        output = self.fc_out(trg)
        return output, attention

In [79]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedForwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
    def forward(self, trg, enc_src, trg_mask, src_mask):
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        _trg = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(_trg))
        return trg, attention

In [80]:
class Seq2seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
    def make_src_mask(self, src):
        # src [batch_size, src_len]
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask [batch_size, 1, 1, src_len]
        return src_mask
    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        # trg_pad_mask [batch_size, 1, 1, trg_len]
        # 创建一个下三角的mask 维度是[trg_len, trg_len]
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).bool()
        trg_mask = trg_pad_mask & trg_sub_mask
        # trg_mask [batch_size, 1, trg_len, trg_len]
        return trg_mask
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

#### 8. 训练

In [81]:
input_dim = len(de_vocab)
output_dim = len(en_vocab)
hid_dim = 256
enc_layers = 3
dec_layers = 3
enc_heads = 8
dec_heads = 8
enc_pf_dim = 512
dec_pf_dim = 512
enc_dropout = 0.1
dec_dropout = 0.1
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
encoder = Encoder(input_dim, hid_dim, enc_layers, enc_heads, enc_pf_dim, enc_dropout, device)
decoder = Decoder(output_dim, hid_dim, dec_layers, dec_heads, dec_pf_dim, dec_dropout, device)

In [82]:
model = Seq2seq(encoder, decoder, pad_index, pad_index, device)

In [83]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {count_parameters(model)} trainable parameters.")

The model has 9038341 trainable parameters.


In [84]:
def initialize_weights(model):
    for name, param in model.named_parameters():
        if hasattr(param, 'weight') and param.weight.dim() > 1:
            nn.init.xavier_uniform_(param.data)
model.apply(initialize_weights)

Seq2seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(7853, 256)
    (pos_embedding): Embedding(100, 256)
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (self_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (positionwise_feedforward): PositionwiseFeedForwardLayer(
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
 

In [85]:
learning_rate = 0.0005
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [86]:
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

In [87]:
def train_fn(model, data_loader, optimizer, criterion, clip, device):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(data_loader):
        src = batch['de_ids'].to(device)
        trg = batch['en_ids'].to(device)
        optimizer.zero_grad()
        # 不自己输入<sos>
        output, _ = model(src, trg[:, :-1])
        output_dim = output.shape[-1]
        # output [batch_size, trg_len-1, output_dim]
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [88]:
def evaluate_fn(model, data_loader, optimizer, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            src = batch['de_ids'].to(device)
            trg = batch['en_ids'].to(device)
            output, _ = model(src, trg[:, :-1])
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [89]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    return int(elapsed_time / 60), int(elapsed_time - elapsed_time // 60 * 60)

In [90]:
batch_size = 128

train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)
test_data_loader = get_data_loader(test_data, batch_size, pad_index)

In [91]:
import time
import math

In [92]:
n_epochs = 10
clip = 1.0
best_valid_loss = float('inf')
for epoch in tqdm.tqdm(range(n_epochs)):
    start_time = time.time()
    train_loss = train_fn(model, train_data_loader, optimizer, criterion, clip, device)
    valid_loss = evaluate_fn(model, valid_data_loader, optimizer, criterion, device)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')
    print(f"Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s")
    print(f"\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:.3f} | Valid PPL: {math.exp(valid_loss):7.3f}")

 10%|█████████▏                                                                                  | 1/10 [07:43<1:09:33, 463.70s/it]

Epoch: 01 | Time: 7m 43s
	Train Loss: 4.089 | Train PPL:  59.679
	Valid Loss: 2.997 | Valid PPL:  20.016


 20%|██████████████████▊                                                                           | 2/10 [14:59<59:39, 447.48s/it]

Epoch: 02 | Time: 7m 15s
	Train Loss: 2.864 | Train PPL:  17.527
	Valid Loss: 2.449 | Valid PPL:  11.578


 30%|████████████████████████████▏                                                                 | 3/10 [22:40<52:55, 453.69s/it]

Epoch: 03 | Time: 7m 40s
	Train Loss: 2.404 | Train PPL:  11.071
	Valid Loss: 2.168 | Valid PPL:   8.743


 40%|█████████████████████████████████████▌                                                        | 4/10 [30:17<45:29, 454.97s/it]

Epoch: 04 | Time: 7m 36s
	Train Loss: 2.094 | Train PPL:   8.113
	Valid Loss: 1.981 | Valid PPL:   7.249


 50%|███████████████████████████████████████████████                                               | 5/10 [37:57<38:04, 456.83s/it]

Epoch: 05 | Time: 7m 40s
	Train Loss: 1.861 | Train PPL:   6.432
	Valid Loss: 1.861 | Valid PPL:   6.429


 60%|████████████████████████████████████████████████████████▍                                     | 6/10 [45:35<30:28, 457.08s/it]

Epoch: 06 | Time: 7m 37s
	Train Loss: 1.677 | Train PPL:   5.349
	Valid Loss: 1.791 | Valid PPL:   5.997


 70%|█████████████████████████████████████████████████████████████████▊                            | 7/10 [53:10<22:49, 456.50s/it]

Epoch: 07 | Time: 7m 35s
	Train Loss: 1.524 | Train PPL:   4.593
	Valid Loss: 1.743 | Valid PPL:   5.716


 80%|█████████████████████████████████████████████████████████████████████████▌                  | 8/10 [1:00:48<15:13, 456.73s/it]

Epoch: 08 | Time: 7m 37s
	Train Loss: 1.392 | Train PPL:   4.025
	Valid Loss: 1.719 | Valid PPL:   5.579


 90%|██████████████████████████████████████████████████████████████████████████████████▊         | 9/10 [1:08:11<07:32, 452.57s/it]

Epoch: 09 | Time: 7m 23s
	Train Loss: 1.279 | Train PPL:   3.593
	Valid Loss: 1.700 | Valid PPL:   5.472


100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [1:15:41<00:00, 454.11s/it]

Epoch: 10 | Time: 7m 29s
	Train Loss: 1.180 | Train PPL:   3.254
	Valid Loss: 1.693 | Valid PPL:   5.434





In [104]:
model.load_state_dict(torch.load('tut6-model.pt'))

test_loss = evaluate_fn(model, test_data_loader, optimizer, criterion, device)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

| Test Loss: 1.765 | Test PPL:   5.843 |


In [131]:
def translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab,
    de_vocab,
    lower,
    sos_token,
    eos_token,
    device,
    max_length=100,
):
    model.eval()
    with torch.no_grad():
        # 若sentence是字符串，转化为token
        if isinstance(sentence, type(str)):
            tokens = [token.text for token in de_nlp.tokenizer(sentence)]
        else:
            tokens = [token for token in sentence]
        if lower:
            tokens = [token.lower() for token in tokens]
        #print(tokens)
        tokens = [sos_token] + tokens + [eos_token]
        ids = de_vocab.lookup_indices(tokens)
        #print(tokens, ids)
        # 加上一个维度，batch dimension
        src = torch.LongTensor(ids).unsqueeze(0).to(device)
        src_mask = model.make_src_mask(src)
        with torch.no_grad():
            enc_src = model.encoder(src, src_mask)
        trg_ids = [en_vocab.get_stoi()[sos_token]]
        # 循环传入输入，一个接一个地生成单词
        for i in range(max_length):
            trg_tensor = torch.LongTensor(trg_ids).unsqueeze(0).to(device)
            trg_mask = model.make_trg_mask(trg_tensor)
            with torch.no_grad():
                output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            # argmax(2)在vacab_size这个维度上找出最大值的索引
            # 取出句子中最后一个单词的idx
            pred_token = output.argmax(2)[:,-1].item()
            trg_ids.append(pred_token)
            if pred_token == en_vocab.get_stoi()[eos_token]:
                break
        trg_tokens = en_vocab.lookup_tokens(trg_ids)
        return trg_tokens[1:], attention

In [132]:
sentence = test_data[0]["de"]
expected_translation = test_data[0]["en"]
sentence, expected_translation

('Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.',
 'A man in an orange hat starring at something.')

In [133]:
translation = translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab, 
    de_vocab,
    lower,
    sos_token,
    eos_token,
    device
)
translation[0]

['a', 'man', 'in', 'an', 'orange', 'hat', 'welding', 'something', '.', '<eos>']

#### 9. 计算bleu

In [134]:
translations = [
    translate_sentence(
        example["de"],
        model,
        en_nlp,
        de_nlp,
        en_vocab,
        de_vocab,
        lower,
        sos_token,
        eos_token,
        device,
    ) for example in tqdm.tqdm(test_data)
]

100%|██████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:02<00:00,  2.37it/s]


In [136]:
bleu = evaluate.load('bleu')

In [143]:
predictions = [" ".join(translation[0][1:-1]) for translation in translations]
references = [[example['en']] for example in test_data]

In [144]:
predictions[0], references[0]

('man in an orange hat welding something .',
 ['A man in an orange hat starring at something.'])

In [145]:
def get_tokenizer_fn(nlp, lower):
    def tokenizer_fn(s):
        tokens = [token.text for token in nlp.tokenizer(s)]
        if lower:
            tokens = [token.lower() for token in tokens]
        return tokens
    return tokenizer_fn

In [146]:
tokenizer_fn = get_tokenizer_fn(en_nlp, lower)
tokenizer_fn(predictions[0]), tokenizer_fn(references[0][0])

(['man', 'in', 'an', 'orange', 'hat', 'welding', 'something', '.'],
 ['a', 'man', 'in', 'an', 'orange', 'hat', 'starring', 'at', 'something', '.'])

In [147]:
results = bleu.compute(
    predictions=predictions,
    references=references,
    tokenizer=tokenizer_fn
)

In [148]:
results

{'bleu': 0.3108307280672848,
 'precisions': [0.6392705069494657,
  0.3889228618852101,
  0.26064899014071025,
  0.1752937440457288],
 'brevity_penalty': 0.9520972141629594,
 'length_ratio': 0.9532087609128503,
 'translation_length': 12447,
 'reference_length': 13058}