In [None]:
from __future__ import unicode_literals, print_function, division
from io import open
import io
import unicodedata
import string
import re
import random
import os
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import itertools
import glob
plt.switch_backend('agg')
import matplotlib.ticker as ticker
from sacrebleu import corpus_bleu
import sacrebleu
import pdb
import pickle
from torch.optim import lr_scheduler

# Every model and batch tensor in this notebook is placed on this device:
# the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import pdb
import pickle as pkl

In [None]:
# ---- data / vocabulary ----
batch_size = 32                 # sentence pairs per mini-batch
words_to_load = 100000          # not referenced elsewhere in this notebook
# Reserved vocabulary indexes; they must agree with the dicts in Lang.__init__.
SOS_token, EOS_token, PAD_token, UNK_token = 0, 1, 2, 3
MAX_LENGTH = 35                 # filterPair keeps MAX_LENGTH-1 words, NMTDataset appends EOS
infrequent_count = 3            # training words seen <= this many times are dropped from the vocab

# ---- model ----
hidden_size = 300
n_layers = 1
dropout = 0.2

# ---- optimisation ----
LR_RATE = 0.0008
teacher_forcing_ratio = 0.9
EPOCH_NUM = 30
PRINT_FREQ = 500                # batches between validation / log lines

# ---- paths ----
add = '/scratch/wz1218'         # root directory holding the corpus and embedding files

__Preprocess Data__

In [None]:
class Lang:
    """Vocabulary for one language: word counts plus word<->index maps.

    Indexes 0-3 are reserved for <SOS>, <EOS>, <pad> and <unk>. Words are first
    counted with addSentence/addWord and only receive an index when buildVocab runs.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {"<SOS>": 0, "<EOS>": 1, "<pad>": 2, "<unk>": 3}
        self.word2count = {"<SOS>": 0, "<EOS>": 0, "<pad>": 0, "<unk>": 0}
        self.index2word = {0: "<SOS>", 1: "<EOS>", 2: "<pad>", 3: "<unk>"}
        self.n_words = 4  # Count SOS, EOS, pad and unk

    def addSentence(self, sentence):
        """Count every single-space-separated token of `sentence`."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Increment the occurrence count of `word` (indexes are assigned in buildVocab)."""
        self.word2count[word] = self.word2count.get(word, 0) + 1

    def buildVocab(self, count=None, train=False, in_out=False):
        """Assign consecutive indexes (from n_words upward) to the counted words.

        @param count: frequency threshold; when both `train` and `in_out` are true,
            words seen <= count times are removed from word2count and never indexed.
            Defaults to the notebook-level `infrequent_count`, read at call time.
        @param train: must be true (together with in_out) for the filter to apply
        @param in_out: must be true (together with train) for the filter to apply

        Words that already have an index (the four reserved tokens, or words indexed
        by an earlier call) keep it and are not given a second index.
        """
        if train and in_out:
            if count is None:
                count = infrequent_count
            rare = [word for word, freq in self.word2count.items() if freq <= count]
            for word in rare:
                del self.word2count[word]

        for word in self.word2count:
            if word in self.word2index:
                continue  # reserved token or already indexed: keep its original index
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1

In [None]:
def unicodeToAscii(s):
    """Strip diacritics: NFD-decompose `s` and drop every combining mark (category 'Mn')."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def normalizeString(s):
    """Normalise an English sentence for tokenisation.

    Lower-cases and de-accents `s`, puts a space before each of . ! ?, collapses
    every run of other non-letter characters into one space, and trims the result.
    The final strip matters because callers split on single spaces: a leading or
    trailing space (e.g. from "it's 5" -> "it s ") would otherwise produce empty
    '' tokens.
    """
    text = unicodeToAscii(s.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)        # "end." -> "end ."
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    return text.strip()

def normalizeZh(s):
    """Trim a pre-tokenised Chinese sentence and collapse whitespace runs to single spaces."""
    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+).
    return re.sub(r"\s+", " ", s.strip())

In [None]:
def filterPair(p, max_length=None):
    """Truncate each sentence of pair `p` to its first max_length - 1 space-separated tokens.

    One position is left for the EOS token that NMTDataset appends, so encoded
    sequences never exceed max_length.

    @param p: iterable of sentences (source, target)
    @param max_length: sequence budget; defaults to the notebook-level MAX_LENGTH
    @return: list of truncated sentences, same order as `p`
    """
    if max_length is None:
        max_length = MAX_LENGTH
    keep = max(max_length - 1, 0)
    return [' '.join(sentence.split(' ')[:keep]) for sentence in p]

def filterPairs(pairs):
    """Apply filterPair (length truncation) to every sentence pair, returning a new list."""
    truncated = []
    for pair in pairs:
        truncated.append(filterPair(pair))
    return truncated

In [None]:
def readLangs(dataset, lang1, lang2):
    """Load one IWSLT zh-en split and build a vocabulary for each side.

    Reads `<add>/iwslt-zh-en/<dataset>.tok.<lang1>` (source, cleaned with normalizeZh)
    and `<dataset>.tok.<lang2>` (target, cleaned with normalizeString). Line i of one
    file is paired with line i of the other, and filterPairs truncates both sides.

    When `dataset == 'train'`, the vocabularies drop words seen <= infrequent_count
    times. Other splits keep every word.

    @return: (input_lang, output_lang, pairs), where pairs is a list of [src, tgt] strings
    @raises ValueError: if the two files contain a different number of lines
    """
    source_path = add + '/iwslt-zh-en/{}.tok.{}'.format(dataset, lang1)
    target_path = add + '/iwslt-zh-en/{}.tok.{}'.format(dataset, lang2)

    # Context managers close the files as soon as they are read.
    with open(source_path, encoding='utf-8') as f:
        source_lines = f.read().strip().split('\n')
    with open(target_path, encoding='utf-8') as f:
        target_lines = f.read().strip().split('\n')
    if len(source_lines) != len(target_lines):
        raise ValueError('{}: {} has {} lines but {} has {} lines'.format(
            dataset, lang1, len(source_lines), lang2, len(target_lines)))

    pairs = filterPairs([[normalizeZh(src), normalizeString(tgt)]
                         for src, tgt in zip(source_lines, target_lines)])

    input_lang = Lang(lang1)
    output_lang = Lang(lang2)
    for src, tgt in pairs:
        input_lang.addSentence(src)
        output_lang.addSentence(tgt)

    # Frequency filtering only applies to the training split.
    is_train = dataset == 'train'
    input_lang.buildVocab(train=is_train, in_out=is_train)
    output_lang.buildVocab(train=is_train, in_out=is_train)
    return input_lang, output_lang, pairs

In [None]:
# Build (source vocab, target vocab, sentence pairs) for each IWSLT split.
# Only train_input_lang / train_output_lang are used later (embedding matrices,
# datasets, model sizes). The dev/test pairs are indexed with the training
# vocabularies, so the val_*/test_* Lang objects are built but never referenced.
train_input_lang, train_output_lang, train_pairs = readLangs('train', 'zh', 'en')
val_input_lang, val_output_lang, val_pairs = readLangs('dev', 'zh', 'en')
test_input_lang, test_output_lang, test_pairs = readLangs('test', 'zh', 'en')

__Embedding__

In [None]:
def load_embedding(fname):
    """Read a text word-vector file into a dict.

    The first line must hold two integers (vocabulary size and dimension); it is
    parsed so that a file without that header fails loudly, but the values are not
    otherwise used. Every following line is `word v1 v2 ...` and becomes
    `data[word] = [v1, v2, ...]` as floats. Undecodable bytes are ignored and blank
    lines are skipped.

    @param fname: path to the vector file
    @return: dict mapping word -> list of float
    """
    data = {}
    # `with` closes the file once parsing finishes.
    with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        n_words, dim = map(int, fin.readline().split())
        for line in fin:
            if not line.strip():
                # A blank line would otherwise add data[''] = [], which cannot be
                # broadcast into an embedding row later.
                continue
            tokens = line.rstrip().split(' ')
            data[tokens[0]] = [float(i) for i in tokens[1:]]
    return data

In [None]:
# Pretrained word vectors: text files with a header line followed by one
# "word v1 v2 ..." entry per line, as parsed by load_embedding.
# NOTE(review): the English vectors are also read from the "zh" directory;
# confirm this path is intended.
fname_zh = add + '/zh/zh.vec'
fname_eng = add + '/zh/fasttext300d.vec'
embedding_mat_zh = load_embedding(fname_zh)
embedding_mat_en = load_embedding(fname_eng)

In [None]:
def create_weight(index2word, embedding, emb_dim=300, pad_index=None):
    """Build an initial embedding matrix for a vocabulary from pretrained vectors.

    Row k receives embedding[index2word[k]] when that word has a pretrained vector.
    Otherwise it gets a standard-normal random vector. The padding row stays all zeros.

    @param index2word: dict index -> word; indexes are assumed to be 0..len-1
    @param embedding: dict word -> vector of length emb_dim
    @param emb_dim: embedding dimension (default 300)
    @param pad_index: row left at zero; defaults to the notebook-level PAD_token
    @return: (weight_matrix of shape (len(index2word), emb_dim),
              list of indexes without a pretrained vector,
              number of indexes that had one)
    """
    if pad_index is None:
        pad_index = PAD_token
    weight_matrix = np.zeros((len(index2word), emb_dim))
    words_not_found = []
    words_found = 0
    for idx, word in index2word.items():
        if idx == pad_index:
            continue  # padding embedding stays zero
        if word in embedding:
            weight_matrix[idx] = embedding[word]
            words_found += 1
        else:
            weight_matrix[idx] = np.random.normal(size=(emb_dim, ))
            words_not_found.append(idx)
    return weight_matrix, words_not_found, words_found

In [None]:
# Initial embedding matrices for the encoder (Chinese) and decoder (English),
# one row per training-vocabulary index. *_wnf lists the indexes that had no
# pretrained vector (randomly initialised); *_wf counts those that had one.
chinese_wm, chin_wnf, chin_wf = create_weight(train_input_lang.index2word, embedding_mat_zh)
english_wm, eng_wnf, eng_wf = create_weight(train_output_lang.index2word, embedding_mat_en)

__Data Loader__

In [None]:
class NMTDataset(Dataset):
    """
    Map-style dataset that turns (source, target) sentence pairs into index lists.
    Note that this class inherits torch.utils.data.Dataset
    """

    def __init__(self, input_lang, output_lang, pairs):
        """
        @param input_lang: source-side word -> index dict. Callers pass
            Lang.word2index, not a Lang object, despite the parameter name.
        @param output_lang: target-side word -> index dict, same convention
        @param pairs: list of [source_sentence, target_sentence] strings whose
            tokens are separated by single spaces
        """
        self.input_w2i = input_lang
        self.output_w2i = output_lang
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _to_indexes(sentence, word2index):
        # Out-of-vocabulary words map to UNK_token; every sequence ends with EOS_token.
        indexes = [word2index.get(word, UNK_token) for word in sentence.split(' ')]
        indexes.append(EOS_token)
        return indexes

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i].
        @return: [input_indexes, input_length, output_indexes, output_length];
            both lengths include the trailing EOS.
        """
        source, target = self.pairs[key][0], self.pairs[key][1]
        input_indexes = self._to_indexes(source, self.input_w2i)
        output_indexes = self._to_indexes(target, self.output_w2i)
        return [input_indexes, len(input_indexes), output_indexes, len(output_indexes)]

    
def NMTDataset_collate_func(batch):
    """
    Customized collate function for DataLoader.

    It right-pads every source and target sequence with PAD_token up to the fixed
    MAX_LENGTH, not to the longest sequence in the batch. Every batch tensor is
    therefore len(batch) x MAX_LENGTH.

    @param batch: list of [input_indexes, input_length, output_indexes, output_length]
        items, as produced by NMTDataset.__getitem__
    @return: [inputs, input_lengths, targets, target_lengths] as int64 tensors on `device`
    @raises ValueError: if a sequence is longer than MAX_LENGTH
    """
    def pad(indexes):
        if len(indexes) > MAX_LENGTH:
            raise ValueError('sequence of length {} exceeds MAX_LENGTH={}'.format(
                len(indexes), MAX_LENGTH))
        return list(indexes) + [PAD_token] * (MAX_LENGTH - len(indexes))

    inputs = [pad(datum[0]) for datum in batch]
    outputs = [pad(datum[2]) for datum in batch]
    input_lengths = [datum[1] for datum in batch]
    output_lengths = [datum[3] for datum in batch]

    # Build tensors straight from Python lists with an explicit int64 dtype.
    # torch.tensor(<tensor>) warns and copies, and numpy's default integer type is
    # platform dependent (int32 on Windows with numpy < 2), while NLLLoss targets
    # must be int64.
    return [torch.tensor(inputs, dtype=torch.long, device=device),
            torch.tensor(input_lengths, device=device),
            torch.tensor(outputs, dtype=torch.long, device=device),
            torch.tensor(output_lengths, device=device)]

In [None]:
# create pytorch dataloaders
def _make_nmt_loader(pairs, shuffle):
    """Wrap `pairs` in an NMTDataset and a DataLoader.

    The dataset is indexed with the *training* vocabularies, so dev/test words unseen
    in training become <unk>. drop_last=True keeps every batch at exactly
    `batch_size` sentences.

    @return: (dataset, loader)
    """
    dataset = NMTDataset(train_input_lang.word2index, train_output_lang.word2index, pairs)
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         collate_fn=NMTDataset_collate_func,
                                         shuffle=shuffle,
                                         drop_last=True)
    return dataset, loader


train_dataset, train_loader = _make_nmt_loader(train_pairs, shuffle=True)
# Dev/test are not shuffled. With drop_last the trailing partial batch is skipped,
# so shuffling would score a different subset of sentences on every evaluation.
# That would make BLEU values, and the best-checkpoint choice, incomparable across calls.
val_dataset, val_loader = _make_nmt_loader(val_pairs, shuffle=False)
test_dataset, test_loader = _make_nmt_loader(test_pairs, shuffle=False)

__Encoder__

In [None]:
class EncoderRNN(nn.Module):
    """Embedding + LSTM encoder whose embedding is initialised from a pretrained matrix.

    @param vi_wm: numpy array (vocab_size x embed_dim) copied into a trainable
        embedding (freeze=False)
    @param input_size: vocabulary size; unused (the size comes from vi_wm), kept
        for call compatibility
    @param hidden_size: LSTM hidden size per direction
    @param direction: 2 for a bidirectional LSTM, anything else for unidirectional.
        NOTE(review): train()/evaluate() concatenate hidden states 0 and 1 as
        forward/backward, so they expect direction=2.
    @param layer: number of stacked LSTM layers
    @param dropout_p: dropout on the embeddings, and between LSTM layers when layer > 1
    """
    def __init__(self, vi_wm, input_size, hidden_size, direction, layer, dropout_p):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.direction = direction
        self.layer = layer
        self.num_directions = 2 if direction == 2 else 1
        embed_mat = torch.from_numpy(vi_wm).float()
        n, embed_dim = embed_mat.shape
        self.embedding = nn.Embedding.from_pretrained(embed_mat, freeze = False)
        # nn.LSTM applies `dropout` only between stacked layers. A non-zero value
        # with a single layer has no effect and triggers a UserWarning.
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers=layer, batch_first=True,
                            dropout=dropout_p if layer > 1 else 0.0,
                            bidirectional=self.num_directions == 2)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, input_len, hidden):
        """Encode a padded, batch-first batch of token indexes (B x T).

        @param input_len: per-example lengths; currently unused, so padding
            positions are run through the LSTM as well (no packing)
        @param hidden: (h0, c0) as returned by initHidden
        @return: (output B x T x num_directions*hidden_size, (h_n, c_n))
        """
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self, batch_size):
        """Zero (h0, c0), each (layer*num_directions) x batch_size x hidden_size,
        created on the same device as the embedding weights."""
        shape = (self.layer * self.num_directions, batch_size, self.hidden_size)
        weight_device = self.embedding.weight.device
        return (torch.zeros(shape, device=weight_device),
                torch.zeros(shape, device=weight_device))

__Decoder With Attention__

In [None]:
class Attn(nn.Module):
    """Attention scorer used by AttnDecoderRNN.

    forward() always computes softmax-normalised dot-product scores between the
    decoder output and each encoder position, whatever `method` is.
    NOTE(review): `method` only decides which extra parameters __init__ creates
    ('general' -> self.attn, 'concat' -> self.attn and self.other), and forward never
    uses them. The 'general' Linear maps hidden_size -> hidden_size, while
    AttnDecoderRNN builds Attn with hidden_size but attends over 2*hidden_size
    features. Wiring it in would therefore need a shape change, and existing
    checkpoints would no longer match.

    @param method: 'general', 'concat', or any other string (no extra parameters)
    @param hidden_size: size used for the optional parameters above
    @param max_length: unused; kept for call compatibility
    """
    def __init__(self, method, hidden_size, max_length=None):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # Create a registered, initialised Parameter; module.to(device) moves it.
            # Calling .to(device) on a Parameter returns a plain, unregistered tensor
            # when the device changes, and torch.FloatTensor(...) is uninitialised memory.
            self.other = nn.Parameter(0.1 * torch.randn(batch_size, 1, hidden_size))

    def forward(self, hidden, encoder_outputs):
        """Dot-product attention weights.

        @param hidden: decoder output, 1 x B x H (sequence-first, as passed by
            AttnDecoderRNN)
        @param encoder_outputs: B x S x H (batch-first)
        @return: weights of shape 1 x 1 x B x 1 x S, summing to 1 over S
            (the caller squeezes the two leading singleton dimensions)
        """
        queries = hidden.transpose(0, 1)          # B x 1 x H
        keys = encoder_outputs.transpose(1, 2)    # B x H x S
        scores = torch.bmm(queries, keys)         # B x 1 x S
        return F.softmax(scores, dim=2).unsqueeze(0).unsqueeze(0)
    

In [None]:
class AttnDecoderRNN(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    Each step embeds the previous token and concatenates it with the previous
    attention context. It then runs the (sequence-first) LSTM, attends over
    encoder_outputs with the LSTM output, and predicts the next token from
    [rnn_output; context].

    The LSTM state size is 2*hidden_size so that it can be initialised from the
    concatenated forward/backward final states of a bidirectional encoder. Context
    vectors are also 2*hidden_size, so the output layer sees 4*hidden_size features.
    Forward requires an attention module; it is only created when attn_model != 'none'.
    """
    def __init__(self, english_wm, attn_model, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()

        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_p)

        embed_mat = torch.from_numpy(english_wm).float()
        n, embed_dim = embed_mat.shape
        self.embedding = nn.Embedding.from_pretrained(embed_mat, freeze = False)

        # LSTM-internal dropout only acts between stacked layers. With one layer it
        # is a no-op that PyTorch warns about, so pass it only when n_layers > 1.
        self.lstm = nn.LSTM(hidden_size * 2 + embed_dim, hidden_size * 2, n_layers,
                            bidirectional=False,
                            dropout=dropout_p if n_layers > 1 else 0.0)
        self.out = nn.Linear(hidden_size * 4, output_size)

        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)

    def forward(self, word_input, last_context, last_hidden, encoder_outputs):
        """Run one decoding step.

        @param word_input: previous tokens, B x 1 index tensor
        @param last_context: previous attention context, 1 x B x 2*hidden_size
        @param last_hidden: (h, c) LSTM state, each n_layers x B x 2*hidden_size
        @param encoder_outputs: B x S x 2*hidden_size
        @return: (log-probabilities B x output_size, context B x 1 x 2*hidden_size,
                  new (h, c), attention weights B x 1 x S)
        """
        batch = word_input.size(0)
        # The embedding is B x 1 x E. The LSTM is sequence-first, so reshape to
        # 1 x B x E (one time step). The leading 1 is the sequence length, not the
        # layer count, and the batch size comes from the input itself.
        word_embedded = self.embedding(word_input).view(1, batch, -1)
        word_embedded = self.dropout(word_embedded)

        rnn_input = torch.cat((word_embedded, last_context), 2)
        rnn_output, hidden = self.lstm(rnn_input, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs).squeeze(0).squeeze(0)
        context = attn_weights.bmm(encoder_outputs)  # B x 1 x N
        features = torch.cat((rnn_output.transpose(0, 1), context), 2)
        output = F.log_softmax(self.out(features), dim = 2).squeeze(1)

        return output, context, hidden, attn_weights

__Training__

In [None]:
import time
import math


def asMinutes(s):
    """Format a duration in seconds as '<minutes>m <seconds>s' (both truncated to integers)."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - 60 * whole_minutes
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return 'elapsed (- remaining)' for a task started at timestamp `since`.

    `percent` is the completed fraction (must be non-zero). The remaining time is
    extrapolated linearly from the elapsed time.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
def train(input, target, input_len, target_len, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH, teach_forcing_ratio=0.5, encoder_cnn = False):
    """Run one optimisation step of the seq2seq model on a single batch.

    @param input: source indexes, B x T
    @param target: padded target indexes, B x T'
    @param input_len: per-example source lengths (passed to the encoder)
    @param target_len: per-example target lengths; the decoder runs for max(target_len) steps
    @param encoder: EncoderRNN (bidirectional, single layer)
    @param decoder: AttnDecoderRNN
    @param encoder_optimizer: optimiser for the encoder, stepped once per call
    @param decoder_optimizer: optimiser for the decoder, stepped once per call
    @param criterion: loss on (B x V log-probs, B targets); summed over time steps
    @param max_length: unused, kept for call compatibility
    @param teach_forcing_ratio: probability that this whole batch is decoded with
        teacher forcing (gold previous token) instead of the model's own predictions
    @param encoder_cnn: CNN encoders are not supported (raises NotImplementedError)
    @return: summed loss divided by the longest target length in the batch
    """
    if encoder_cnn:
        # The attention decoder needs per-position encoder outputs, which the CNN
        # branch never produced.
        raise NotImplementedError('train() only supports the recurrent encoder')

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    n_batch = input.size(0)
    max_target_len = int(max(target_len))

    encoder_output, (enc_h, enc_c) = encoder(input, input_len, encoder.initHidden(n_batch))

    # The decoder state is 2*hidden_size wide: join the forward (index 0) and
    # backward (index 1) final states of the bidirectional encoder.
    decoder_hidden = (torch.cat([enc_h[0], enc_h[1]], dim=1).unsqueeze(0),
                      torch.cat([enc_c[0], enc_c[1]], dim=1).unsqueeze(0))
    decoder_context = torch.zeros((1, n_batch, decoder.hidden_size * 2), device=input.device)
    decoder_input = torch.full((n_batch, 1), SOS_token, dtype=torch.long, device=input.device)

    use_teacher_forcing = random.random() < teach_forcing_ratio
    loss = 0
    for di in range(max_target_len):
        decoder_output, decoder_context, decoder_hidden, _ = decoder(
            decoder_input, decoder_context, decoder_hidden, encoder_output)
        decoder_context = decoder_context.transpose(0, 1)  # B x 1 x N -> 1 x B x N
        loss += criterion(decoder_output, target[:, di])
        if use_teacher_forcing:
            decoder_input = target[:, di].unsqueeze(1)  # gold token, B x 1
        else:
            # Feed back the model's prediction. topk(1) already yields B x 1, which
            # also works for a batch of one.
            decoder_input = decoder_output.topk(1)[1].detach()
            if (decoder_input == EOS_token).all():
                break  # stop only once every sentence in the batch has emitted EOS

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / float(max_target_len)

In [None]:
def trainIters(loader, encoder, decoder, n_iters, encoder_cnn, print_every=1000, learning_rate=LR_RATE,
               save_path='/scratch/wz1218/save_model/bi-lstm-attn-infreq3_both.pt',
               log_path='/scratch/wz1218/save_model/bi-lstm-attn-infreq3_both.txt'):
    """Train for `n_iters` epochs over `loader`, validating every `print_every` batches.

    At each validation point:
    - BLEU on the global val_loader is computed with the modules in eval mode.
    - If it beats the best so far, a checkpoint goes to `save_path`.
    - A progress line is printed and appended to `log_path`.

    @return: (train_loss_hist, bleu_hist), one entry per validation point. Both lists
        are also pickled to the working directory at the end.
    """
    start = time.time()
    print_loss_total = 0

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    best_bleu = None
    train_loss_hist = []
    bleu_hist = []

    for epoch in range(1, n_iters + 1):
        for i, (input, input_len, target, target_len) in enumerate(loader):
            loss = train(input, target, input_len, target_len, encoder, decoder,
                         encoder_optimizer, decoder_optimizer, criterion,
                         teach_forcing_ratio=teacher_forcing_ratio, encoder_cnn = encoder_cnn)
            print_loss_total += loss

            if (i + 1) % print_every == 0:
                # Validate with dropout disabled, then resume training mode.
                encoder.eval()
                decoder.eval()
                try:
                    current_bleu = test(encoder, decoder, val_loader, encoder_cnn)
                finally:
                    encoder.train()
                    decoder.train()

                if best_bleu is None or current_bleu > best_bleu:
                    best_bleu = current_bleu
                    torch.save({
                                'epoch': epoch,
                                'encoder_state_dict': encoder.state_dict(),
                                'decoder_state_dict': decoder.state_dict(),
                                'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
                                'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
                                'train_loss': loss,
                                'best_BLEU': best_bleu  # BLEU of the weights saved here
                                }, save_path)

                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                train_loss_hist.append(print_loss_avg)
                bleu_hist.append(current_bleu)

                message = ('%s (Epoch: %d %d%%) | Train Loss: %.4f | Best Bleu: %.4f | Current Blue: %.4f'
                           % (timeSince(start, epoch / n_iters), epoch, epoch / n_iters * 100,
                              print_loss_avg, best_bleu, current_bleu))
                print(message)
                with open(log_path, 'a') as log_file:
                    log_file.write(message + '\n')
        # Each epoch starts a fresh averaging window. Losses after the last
        # validation point of an epoch are not logged.
        print_loss_total = 0

    # NOTE(review): these file names say "infreq1" while the checkpoint and log say
    # "infreq3_both"; confirm which run name is intended.
    with open('bi-lstm-attn-infreq1_loss.p', 'wb') as loss_file:
        pkl.dump(train_loss_hist, loss_file)
    with open('bi-lstm-attn-infreq1_bleu.p', 'wb') as bleu_file:
        pkl.dump(bleu_hist, bleu_file)
    return train_loss_hist, bleu_hist

__Test__

In [None]:
def evaluate(encoder, decoder, input, input_len, encoder_cnn, max_length=MAX_LENGTH):
    """
    Greedily decode a batch of source sentences.

    The batch is encoded, and the decoder starts from the concatenated
    forward/backward final encoder states. For `max_length` steps the arg-max token
    is fed back as the next input, with no early stop at EOS.
    The function runs under torch.no_grad() and does not change train/eval mode, so
    callers should put the modules in eval mode to disable dropout.

    @param encoder: the encoder network (bidirectional EncoderRNN)
    @param decoder: the decoder network (AttnDecoderRNN)
    @param input: source token indexes, B x T tensor
    @param input_len: per-example source lengths (passed through to the encoder)
    @param encoder_cnn: CNN encoders are not supported (raises NotImplementedError)
    @param max_length: number of decoding steps
    @output: numpy integer array of shape 1 x B x max_length; element [0][b] is the
        decoded index sequence for example b
    """
    if encoder_cnn:
        # The attention decoder needs per-position encoder outputs, which the CNN
        # branch never produced.
        raise NotImplementedError('evaluate() only supports the recurrent encoder')

    with torch.no_grad():
        n_batch = input.size(0)
        encoder_output, (enc_h, enc_c) = encoder(input, input_len, encoder.initHidden(n_batch))

        decoder_hidden = (torch.cat([enc_h[0], enc_h[1]], dim=1).unsqueeze(0),
                          torch.cat([enc_c[0], enc_c[1]], dim=1).unsqueeze(0))
        decoder_context = torch.zeros((1, n_batch, decoder.hidden_size * 2), device=input.device)
        decoder_input = torch.full((n_batch, 1), SOS_token, dtype=torch.long, device=input.device)

        decoded_steps = []
        for _ in range(max_length):
            decoder_output, decoder_context, decoder_hidden, _attn = decoder(
                decoder_input, decoder_context, decoder_hidden, encoder_output)
            decoder_context = decoder_context.transpose(0, 1)
            # topk(1) yields B x 1 arg-max indexes, the shape the next step expects.
            decoder_input = decoder_output.topk(1)[1]
            decoded_steps.append(decoder_input.cpu().numpy())
        # steps x B x 1 -> transpose -> 1 x B x steps
        return np.asarray(decoded_steps).T

In [None]:
def test(encoder, decoder, data_loader, encoder_cnn):
    """Corpus BLEU of greedy translations over every batch of `data_loader`.

    Both modules are switched to eval mode (dropout off) while scoring and restored
    to their previous mode afterwards. The first predicted sentence and the first
    reference sentence are printed as a sample.

    Predictions are cut at the first PAD or EOS; references at the first EOS.
    NOTE(review): references are rebuilt from target indexes, so out-of-vocabulary
    reference words appear as "<unk>", and a predicted "<unk>" counts as a match.
    BLEU against the raw reference text would be stricter.
    """
    def indexes_to_sentence(indexes, stop_tokens):
        words = []
        for token in indexes:
            token = int(token)
            if token in stop_tokens:
                break
            words.append(train_output_lang.index2word[token])
        return ' '.join(words)

    previous_modes = (encoder.training, decoder.training)
    encoder.eval()
    decoder.eval()
    try:
        candidate_corpus = []
        reference_corpus = []
        for batch_idx, (input, input_len, target, target_len) in enumerate(data_loader):
            decoded = evaluate(encoder, decoder, input, input_len, encoder_cnn)
            candidates = [indexes_to_sentence(decoded[0][b], (PAD_token, EOS_token))
                          for b in range(decoded.shape[1])]
            references = [indexes_to_sentence(sent, (EOS_token,))
                          for sent in target.tolist()]
            if batch_idx == 0 and candidates:
                print('predict: ' + candidates[0])
                print('target: ' + references[0])
            candidate_corpus.extend(candidates)
            reference_corpus.extend(references)
    finally:
        encoder.train(previous_modes[0])
        decoder.train(previous_modes[1])

    score = corpus_bleu(candidate_corpus, [reference_corpus], smooth='exp', smooth_floor=0.0, force=False).score
    return score

__Run__

In [None]:
# Build the encoder and the attention decoder from the pretrained embedding
# matrices, then train for EPOCH_NUM epochs.
# NOTE(review): attn_model is passed to Attn, but Attn.forward computes
# dot-product scores whatever this value is.
attn_model = 'general'
# Positional args after the vocab size: hidden_size, direction=2, layer=1, dropout.
encoder = EncoderRNN(chinese_wm, train_input_lang.n_words, hidden_size, 2, 1, dropout).to(device)
attn_decoder = AttnDecoderRNN(english_wm, attn_model, hidden_size, train_output_lang.n_words, n_layers, dropout_p=dropout).to(device)
trainIters(train_loader, encoder, attn_decoder, n_iters=EPOCH_NUM, encoder_cnn=False, print_every=PRINT_FREQ, learning_rate=LR_RATE)