In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import os
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils.data import Dataset
from torch.optim import lr_scheduler
import itertools
import glob

# Use the CUDA device when torch reports it as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Reserved vocabulary indices. They line up with the four entries that
# Lang.index2word is pre-populated with ("SOS", "EOS", "<pad>", "<unk>").
SOS_token = 0  # start-of-sentence entry
EOS_token = 1  # end-of-sentence entry
PAD_token = 2  # padding entry
UNK_token = 3  # unknown-word entry

# Original, misspelled name of the padding index; kept as an alias in case
# other cells still refer to it.
PAD_Itoken = PAD_token

class Lang:
    """Word <-> index vocabulary with per-word occurrence counts.

    Indices 0-3 are pre-assigned to the special entries "SOS", "EOS",
    "<pad>" and "<unk>"; every new word receives the next free index.
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled with ``name`` (stored as given)."""
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = dict(enumerate(("SOS", "EOS", "<pad>", "<unk>")))
        # The next free index starts right after the special entries.
        self.n_words = len(self.index2word)

    def addSentence(self, sentence):
        """Register every token of ``sentence``, splitting on single spaces."""
        for token in sentence.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Count ``word``, assigning it the next free index on first sight."""
        if word in self.word2index:
            self.word2count[word] += 1
            return

        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1

In [3]:
# Turn a Unicode string to plain ASCII, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is decomposed with NFD normalization and every code point whose
    Unicode category is ``Mn`` (nonspacing mark) is dropped. All other
    characters, including non-ASCII ones without such marks, are kept as-is.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    """Lowercase ``s``, strip accents, isolate ``.!?`` and drop other non-letters.

    Steps:
      1. lowercase, trim and remove combining marks (via ``unicodeToAscii``);
      2. put a space in front of every ``.``, ``!`` or ``?``;
      3. replace every run of characters other than ASCII letters and ``.!?``
         (spaces included) with a single space;
      4. trim the result.

    @param s: raw sentence
    @return: normalized sentence with tokens separated by single spaces and no
        leading or trailing space
    """
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Step 3 can leave a space at either end (e.g. when the sentence starts or
    # ends with a digit or a quote). Without this trim, splitting on ' ' would
    # produce empty-string tokens that end up in the vocabulary.
    return s.strip()

In [4]:
def readLangs(dataset, lang1, lang2, data_dir=None):
    """Load one split of a tokenized parallel corpus and build its vocabularies.

    Reads ``<data_dir>/<dataset>.tok.<lang1>`` and
    ``<data_dir>/<dataset>.tok.<lang2>`` as UTF-8 text, one sentence per line.
    Source sentences are kept verbatim; target sentences are passed through
    ``normalizeString``.

    @param dataset: split name used as the file-name prefix (e.g. 'train')
    @param lang1: file-name extension of the source side (e.g. 'zh')
    @param lang2: file-name extension of the target side (e.g. 'en')
    @param data_dir: directory containing the files; defaults to the
        ``iwslt-zh-en`` folder under the current working directory
    @return: ``(input_lang, output_lang, pairs)`` where ``pairs`` is a list of
        ``[source_sentence, normalized_target_sentence]`` and the two ``Lang``
        objects hold the vocabularies built from those pairs
    @raise ValueError: if the two files do not have the same number of lines
    """
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), 'iwslt-zh-en')

    def _read_lines(extension):
        # Read the whole file, closing the handle as soon as we are done.
        path = os.path.join(data_dir, '{}.tok.{}'.format(dataset, extension))
        with open(path, encoding='utf-8') as handle:
            return handle.read().strip().split('\n')

    source_lines = _read_lines(lang1)
    target_lines = _read_lines(lang2)

    # A parallel corpus must be line-aligned; fail loudly instead of raising a
    # bare IndexError or silently dropping trailing target lines.
    if len(source_lines) != len(target_lines):
        raise ValueError(
            "Parallel files for '{}' are not aligned: {} '{}' lines vs {} '{}' lines".format(
                dataset, len(source_lines), lang1, len(target_lines), lang2))

    pairs = [[source, normalizeString(target)]
             for source, target in zip(source_lines, target_lines)]

    # The raw line lists are stored as each Lang's `name`, exactly as before;
    # other code in this notebook may rely on that.
    input_lang = Lang(source_lines)
    output_lang = Lang(target_lines)

    for source, target in pairs:
        input_lang.addSentence(source)
        output_lang.addSentence(target)

    return input_lang, output_lang, pairs

In [17]:
# Load the three splits. Each call builds a fresh pair of vocabularies from that
# split's own sentences.
# NOTE(review): because the vocabularies are built independently, the same word
# can get a different index in train, dev and test.
train_input_lang, train_output_lang, train_pairs = readLangs(
    dataset='train', lang1='zh', lang2='en')
val_input_lang, val_output_lang, val_pairs = readLangs(
    dataset='dev', lang1='zh', lang2='en')
test_input_lang, test_output_lang, test_pairs = readLangs(
    dataset='test', lang1='zh', lang2='en')

In [36]:
class NMTDataset(Dataset):
    """
    Sentence-pair dataset that yields index-encoded source/target sentences.
    Note that this class inherits torch.utils.data.Dataset
    """

    def __init__(self, input_lang, output_lang, pairs):
        """
        @param input_lang: vocabulary object exposing ``word2index`` for the source side
        @param output_lang: vocabulary object exposing ``word2index`` for the target side
        @param pairs: list of ``[source_sentence, target_sentence]`` strings whose
            tokens are separated by single spaces

        No relationship between a vocabulary's ``name`` and ``pairs`` is required,
        so a vocabulary built on one split can be used to encode another split.
        """
        self.input_lang = input_lang
        self.output_lang = output_lang
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _encode(lang, sentence):
        """Map a space-separated sentence to vocabulary indices followed by EOS_token."""
        # Words missing from the vocabulary fall back to UNK_token rather than
        # raising KeyError.
        indexes = [lang.word2index.get(word, UNK_token) for word in sentence.split(' ')]
        indexes.append(EOS_token)
        return indexes

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]

        @return: ``[input_indexes, input_length, output_indexes, output_length]``;
            both lengths include the trailing EOS_token
        """
        pair = self.pairs[key]
        input_indexes = self._encode(self.input_lang, pair[0])
        output_indexes = self._encode(self.output_lang, pair[1])
        return [input_indexes, len(input_indexes), output_indexes, len(output_indexes)]

    
def NMTDataset_collate_func(batch, pad_value=0):
    """
    Customized function for DataLoader that dynamically pads the batch so that all 
    data have the same length

    @param batch: list of ``[input_indexes, input_length, output_indexes, output_length]``
        items, as returned by ``NMTDataset.__getitem__``
    @param pad_value: index appended to sequences shorter than the longest one in
        the batch. Defaults to 0, the value this function has always used.
        NOTE(review): 0 is also SOS_token, while Lang reserves index 2 for
        "<pad>" -- confirm which value downstream masking expects.
    @return: ``[inputs, input_lengths, outputs, output_lengths]`` as int64 tensors
        created on the notebook-level ``device``; ``inputs`` and ``outputs`` are
        padded to the longest sequence in the batch
    """
    input_length_ls = [datum[1] for datum in batch]
    output_length_ls = [datum[3] for datum in batch]

    # Longest sequence on each side of this batch (0 for an empty batch).
    max_input = max(input_length_ls, default=0)
    max_output = max(output_length_ls, default=0)

    # Right-pad every sequence up to the batch maximum.
    input_ls = [list(datum[0]) + [pad_value] * (max_input - datum[1]) for datum in batch]
    output_ls = [list(datum[2]) + [pad_value] * (max_output - datum[3]) for datum in batch]

    # Build the tensors directly from the Python lists: wrapping an existing
    # tensor in torch.tensor() makes an extra copy and emits a UserWarning.
    return [torch.tensor(input_ls, dtype=torch.long, device=device),
            torch.tensor(input_length_ls, dtype=torch.long, device=device),
            torch.tensor(output_ls, dtype=torch.long, device=device),
            torch.tensor(output_length_ls, dtype=torch.long, device=device)]

In [37]:
# Wrap the train and validation splits in datasets, then in DataLoaders that pad
# each batch through NMTDataset_collate_func.
BATCH_SIZE = 32

train_dataset = NMTDataset(input_lang=train_input_lang,
                           output_lang=train_output_lang,
                           pairs=train_pairs)
# NOTE(review): the validation set is encoded with its own vocabularies, not the
# training ones -- confirm this is intended.
val_dataset = NMTDataset(input_lang=val_input_lang,
                         output_lang=val_output_lang,
                         pairs=val_pairs)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           collate_fn=NMTDataset_collate_func)
# NOTE(review): validation batches are shuffled as well.
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         collate_fn=NMTDataset_collate_func)

In [39]:
# Draw a single batch from the training loader to inspect the collated output:
# [inputs, input_lengths, outputs, output_lengths].
next(iter(train_loader))

[tensor([[   58,    32, 12752,  ...,     0,     0,     0],
         [   50,   309, 39721,  ...,     0,     0,     0],
         [  111,   284,   175,  ...,     0,     0,     0],
         ...,
         [  329,  1026,   218,  ...,     0,     0,     0],
         [ 6494,  1397,    32,  ...,     0,     0,     0],
         [11659, 64970,     6,  ...,     0,     0,     0]], device='cuda:0'),
 tensor([ 23,  47,  25,   9,  22,  15,  12,  28,  11,  45,  25,  27,   6,   7,
           8,  12,  30,  11,  19, 110,  21,  12,  31,  17,  40,  19,  22,  14,
           5,   7,  30,  32], device='cuda:0'),
 tensor([[   44,    38,   150,  ...,     0,     0,     0],
         [  179,   503,   179,  ...,     0,     0,     0],
         [    5,    76,   212,  ...,     0,     0,     0],
         ...,
         [ 1980,   422,  4117,  ...,     0,     0,     0],
         [  695,  5688,    38,  ...,     0,     0,     0],
         [10738,  3020,   998,  ...,     0,     0,     0]], device='cuda:0'),
 tensor([22, 43, 20,