In [1]:
import torch
from torchtext import data
from torch.utils.data import Dataset

In [2]:
import random
import numpy as np
import unicodedata
import re
import jieba
import tqdm.auto
from tqdm.auto import trange

In [3]:
from utils_batch import *

In [4]:
from models.decoder_batch import DecoderRNN
from models.encoder_batch import EncoderRNN

In [5]:
def mask_create(batch_size, maxlen, output_size, pad_token, device, train=True):
    if train:
        mask = torch.ones((batch_size, maxlen, output_size), device=device)
        mask[:, :, pad_token] = 0
    else:
        mask = torch.ones((batch_size, 1, output_size), device=device)
        mask[:, :, pad_token] = 0
    return mask

In [164]:
def train(input_tensor, target_tensor, mask_train, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, device):
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
        
    batch_size = input_tensor.size(0)

    loss = 0 
    encoder_outputs, encoder_hidden = encoder(input_tensor)
    
    decoder_inputs = torch.zeros_like(target_tensor, dtype=torch.long, device=device)
    decoder_inputs[:, 0] = SOS_token
    decoder_inputs[:, 1:] = target_tensor[:, :-1]
    decoder_input = decoder_inputs[:, 0].view(-1, 1)
    decoder_hidden = encoder_hidden
    
    # always use_teacher_forcing:
    if True:
        decoder_outputs, decoder_hidden = decoder(decoder_inputs, decoder_hidden)
        decoder_outputs.masked_fill(mask_train == 0, -1e10)
        decoder_outputs = torch.nn.LogSoftmax(dim=-1)(decoder_outputs)
        #print(decoder_outputs.shape)
        loss = criterion(decoder_outputs.permute(0, 2, 1), target_tensor)
    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            loss += criterion(decoder_output.view(-1, output_size), target_tensor[:, di])
            #if decoder_input.item() == EOS_token:
            #    break
    
    loss.backward()
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return loss.item()

In [165]:
def train_epoch(train_loader, epoch, device):
    print(f"\nTrain Epoch: {epoch}.")
    loss_train, steps = 0.0, 0
    for input_sample, target_sample in train_loader:
        input_sample, target_sample = input_sample.to(device), target_sample.to(device)
        loss = train(input_sample, target_sample, mask_train, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, device)
        loss_train += loss
        steps += 1
    loss_train_avg = loss_train / steps
    template_print = f"Epoch(train): {epoch} Loss_train: {loss_train_avg:.3f}."
    print(template_print)
    return loss_train_avg

In [120]:
output_lang.index2word[0]

'SOS'

In [166]:
def evaluate(input_tensor, target_tensor, mask_eval, encoder, decoder, device, num_pairs_to_show):
    with torch.no_grad():        
        batch_size, max_length = input_tensor.shape[0], input_tensor.shape[1]
        _, encoder_hidden = encoder(input_tensor)
        ### first input of the sequence is SOS
        decoder_input = torch.zeros((batch_size, 1), dtype=torch.int)
        decoder_input = decoder_input.to(device)
        decoder_hidden = encoder_hidden
        decoded_index_list = []
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            decoder_output.masked_fill(mask_eval == 0, -1e5)
            decoder_output = torch.nn.LogSoftmax(dim=-1)(decoder_output)
            decoder_output = decoder_output.squeeze(1)
            #print(decoder_output.shape)
            topv, topi = decoder_output.data.topk(1)
            decoded_index_list.append(topi.squeeze(-1).cpu().numpy())
            decoder_input = topi.detach()
        decoded_index_array = np.array(decoded_index_list).transpose()
        bleu_accum = 0
        decoded_pairs = []
        for i in range(batch_size):
            input_word_list = []
            output_word_list = []
            target_word_list = []
            decoded_index = decoded_index_array[i].tolist()
            for word_index in input_tensor[i].tolist():
                if word_index == EOS_token:
                    break
                else:
                    input_word_list.append(input_lang.index2word[word_index])
            for word_index in decoded_index:
                if word_index == EOS_token:
                    break
                else:
                    output_word_list.append(output_lang.index2word[word_index])
            for word_index in target_tensor[i].tolist():
                if word_index == EOS_token:
                    break
                else:
                    target_word_list.append(output_lang.index2word[word_index])
            max_n = min(4, len(output_word_list), len(target_word_list))
            if max_n == 0:
                bleu = 0
            else:
                weights = (1.0 / max_n, ) * max_n
                try:
                    bleu = data.metrics.bleu_score([output_word_list], [[target_word_list]], max_n=max_n, weights=weights)
                    decoded_pairs.append([input_word_list, output_word_list, target_word_list])
                except IndexError:
                    bleu = 0
                    print(f'decoded_words: {output_word_list}.')
                    print(f'sentence_list: {target_word_list}.')
                    print(f'max_n: {max_n}.')
            bleu_accum += bleu
        bleu_avg = bleu_accum / batch_size
        if len(decoded_pairs) <= num_pairs_to_show:
            decoded_pairs_sampled = decoded_pairs
        else:
            decoded_pairs_sampled = random.sample(decoded_pairs, num_pairs_to_show)
        return bleu_avg, decoded_pairs_sampled

In [167]:
def evaluate_epoch(test_loader, epoch_id, device):
    print(f"\nTest Epoch: {epoch_id}.")
    bleu_accum, steps = 0.0, 0
    for input_sample, target_sample in test_loader:
        input_sample, target_sample = input_sample.to(device), target_sample.to(device)
        bleu_avg, decoded_pairs = evaluate(input_sample, target_sample, mask_eval, encoder, decoder, device, 10)
        bleu_accum += bleu_avg
        steps += 1
    bleu_avg_epoch = bleu_accum / steps
    template_print = f"Epoch(test): {epoch_id} Bleu_test: {bleu_avg_epoch:.3f}."
    print(template_print)
    for pair in decoded_pairs:
        input_pair, decoded_pair, target_pair = pair
        input_sentence = ' '.join(input_pair)
        decoded_sentence = ''.join(decoded_pair)
        target_sentence = ''.join(target_pair)
        print('----------------------------------')
        print('<', input_sentence)
        print('=', target_sentence)
        print('>', decoded_sentence)
    return bleu_avg_epoch 

In [162]:
bleu_avg, decoded_pairs = evaluate(input_sample, target_sample, mask_eval, encoder, decoder, device, 10)

In [None]:
evaluate_epoch(enzh_loader_test, 1, device)

In [131]:
def evaluateRandomly(encoder, decoder, data_valid, n=100):
    bleu_acm = 0.0
    for i in range(n):
        pair = random.choice(data_valid)
        print('>', pair[0])
        print('=', pair[1])
        output_words, bleu_i = evaluate(encoder, decoder, pair[0], pair[1])
        bleu_acm += bleu_i
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')
    bleu_avg = bleu_acm / n
    print(f"avg bleu is: {bleu_avg:.3f}.")

---------------------------
#### Training and Prediction

In [132]:
SOS_token = 0
EOS_token = 1
PAD_token = 2

In [133]:
file_path = './data/cmn.txt'
input_lang, output_lang, pairs = prepareDataMand(file_path, 'en', 'zh')

Reading Lines...
Read 29458 sentence pairs
Counting words...
Counted words:
en 7368
zh 16262


In [134]:
max_length_input = input_lang.max_length()
max_length_output = output_lang.max_length()

In [135]:
max_length_input

34

In [136]:
max_length_output

28

In [137]:
batch_size = 256
maxlen = 35
device = torch.device('cuda:4')

In [138]:
enzh_loader_train, enzh_loader_test, _ = enzh_loader(file_path, batch_size, maxlen)

Reading Lines...
Read 29458 sentence pairs
Counting words...
Counted words:
en 7368
zh 16262
26513 samples in training dataset.
2945 samples in valid dataset.


In [139]:
input_sample, target_sample = next(iter(enzh_loader_test))

In [140]:
input_sample, target_sample = input_sample.to(device), target_sample.to(device)

In [141]:
evaluate(input_sample, target_sample, mask_eval, encoder, decoder, device, 10);

In [142]:
hidden_size = 256
input_size = input_lang.n_words
output_size = output_lang.n_words

In [143]:
mask_train = mask_create(batch_size, maxlen, output_size, PAD_token, device, train=True)
mask_eval = mask_create(batch_size, maxlen, output_size, PAD_token, device, train=False)

In [144]:
encoder = EncoderRNN(input_size, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_size).to(device)

In [171]:
encoder_optimizer = torch.optim.SGD(encoder.parameters(), lr=4e-1)
decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=4e-1)

In [146]:
criterion = torch.nn.NLLLoss(ignore_index=PAD_token)

In [173]:
num_epochs = 350

In [148]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

In [None]:
loss_train_avg_list, bleu_score_list = [], []
#fig = plt.figure(figsize=(12, 4))
#ax1 = fig.add_subplot(1, 2, 1);
#ax2 = fig.add_subplot(1, 2, 2);
for i in trange(1, num_epochs+1):
    #display(fig)
    #x1 = np.arange(0, i-1);
    #x2 = np.arange(0, i-1);
    
    #ax1.set_xlim(0, i)
    #ax2.set_xlim(0, i)
    #ax1.set_xlabel('Epoch')
    #ax2.set_xlabel('Epoch')
    
    #ax1.cla()
    #ax1.plot(x1, loss_train_avg_list)
    
    #ax2.cla()
    #ax2.plot(x2, bleu_score_list)
    if i % 10 == 0:
        clear_output(wait=True)
    loss_train_avg = train_epoch(enzh_loader_train, i, device)
    bleu_avg_i = evaluate_epoch(enzh_loader_test, i, device)
    loss_train_avg_list.append(loss_train_avg)
    bleu_score_list.append(bleu_avg_i)

    #plt.pause(0.8)


Train Epoch: 260.
Epoch(train): 260 Loss_train: 0.120.

Test Epoch: 260.
Epoch(test): 260 Bleu_test: 0.045.
----------------------------------
< everyone is talking about tom 
= 人人都在說湯姆
> 所有人都非常感謝你
----------------------------------
< ive always trusted your judgment 
= 我一直都相信你的判断
> 我完全不负过你
----------------------------------
< please forgive me for being late 
= 請原諒我遲到
> 请我帮我
----------------------------------
< i dont like science 
= 我不喜欢科学
> 我不喜欢狗
----------------------------------
< tom said that he was curious 
= 汤姆说他很好奇
> 汤姆说他很好奇
----------------------------------
< tom didnt understand what the teacher said 
= 汤姆没明白老师说了什么
> 湯姆說沒有人說謊
----------------------------------
< the monkey fell from the tree 
= 猴子從樹上掉了下來
> 天空从的房子爬上
----------------------------------
< i have to win 
= 我必须赢
> 我需要一名醫生
----------------------------------
< i asked him for a favor 
= 我請他幫忙
> 我向他求助
----------------------------------
< my parents have gone to the airport to see my uncle off 
= 我父母去机场送我叔叔了
> 妈妈我妈