## Seq2seq Model with Attention for Chinese-English Machine Translation

Some references on seq2seq:
* Pytorch, *seq2seq translation tutorial*, <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>
* Practical Pytorch, *Batched seq2seq*, <https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb>

![Seq2seq Model](https://pytorch.org/tutorials/_images/seq2seq.png)

Some tricky things:
* Three types of dashes in English:
    * The Hyphen (-)
    * The En-dash (–)
    * The Em-dash (—)
    * Please refer to [Wikipedia](https://en.wikipedia.org/wiki/Dash) or [English Language Help Desk](http://site.uit.no/english/punctuation/hyphen/) for more details

In [None]:
import re
import os
import sys
import time
import random
import logging, pseudologger
import pickle
import jieba
from collections import Counter
from argparse import Namespace

# Global configuration shared by every cell below.
flags = Namespace()

# Filesystem locations and logging switch
flags.checkpoint_path = 'checkpoint'
flags.log_flag = True
flags.log_path = "log"
flags.data_path = "data"

# Data shape
flags.seq_size = 32
flags.batch_size = 32

# Model and optimisation hyper-parameters
flags.embedding_size = 512  # embedding dimension
flags.lstm_size = 512  # hidden dimension
flags.gradients_norm = 5  # gradient clipping
flags.top_k = 5
flags.num_epochs = 40
flags.learning_rate = 0.01

# Make sure every output directory exists before anything writes to it.
# makedirs(exist_ok=True) also creates missing parents and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
for path in [flags.checkpoint_path, flags.log_path, flags.data_path]:
    os.makedirs(path, exist_ok=True)

if flags.log_flag:
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    log_file = os.path.abspath("{}/seq2seq-01.log".format(flags.log_path))
    # logging.getLogger returns the same object on every call, so re-running
    # this cell would otherwise attach one more FileHandler each time and
    # every record would be written to the log file repeatedly.
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_file)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
else:
    # Stand-in logger used when file logging is switched off.
    logger = pseudologger.PseudoLogger()

# Record the full configuration at the top of the run.
logger.info(str(flags))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Run on the GPU whenever CUDA is usable, otherwise stay on the CPU.
# (To force CPU execution use: device = torch.device('cpu'))
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class Lang:
    """
    Vocabulary of one language ("zh" or "en").

    Expected call order: addSentence() for every training sentence, then
    processIndex() exactly once; only afterwards can getSentenceIndex(),
    padIndex() and getSentenceFromIndex() be used.
    """
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {}
        self.tmp_word_lst = [] # every token seen so far, consumed by processIndex
        self.n_sentences = 0
        self.n_words = 0 # vocabulary size, filled in by processIndex

    @staticmethod
    def normalizeString(s,lang):
        """
        Clean one raw sentence.

        For "zh": numeric HTML entities and the replacement character are
        removed, full-width punctuation is mapped to its ASCII counterpart and
        everything except letters, digits, CJK characters and ,.!? becomes a
        space. For any other value of lang (treated as "en"): entities are
        removed, ,.!? are split off as separate tokens and all other
        punctuation becomes a space. In both cases whitespace is collapsed
        and the result is lower-cased and stripped.
        """
        s = re.sub(r"&#[0-9]+;",r"",s) # drop HTML numeric entities
        if lang == "zh":
            s = re.sub(r"�",r"",s)
            # Test if is Chinese
            # https://cloud.tencent.com/developer/article/1499958
            punc_pair = [("。","."),("！","!"),("？","?"),("，",",")]
            for zh_punc,en_punc in punc_pair:
                s = s.replace(zh_punc,en_punc)
            s = re.sub(u"[^a-zA-Z0-9\u4e00-\u9fa5,.!?]",u" ",s)
        else: # lang == "en"
            s = re.sub(r"([,.!?])",r" \1",s) # add a space between these punctuations
            s = re.sub(r"[^a-zA-Z0-9,.!?]+",r" ",s) # remove most of the punctuations
        s = re.sub(r"\s+", r" ", s)
        return s.lower().strip()

    def _tokenize(self,sentence):
        """
        Split a normalized sentence into word tokens
        """
        if self.name == "zh": # need to use tools to split words
            cut_lst = jieba.lcut(sentence,cut_all=False) # precisely cut
            return [word for word in cut_lst if word != " "] # remove all the white spaces
        return sentence.split() # self.name == "en"

    def addSentence(self,sentence):
        """
        Count one more sentence and remember its tokens for processIndex
        """
        self.n_sentences += 1
        self.tmp_word_lst += self._tokenize(sentence)

    def getSentenceIndex(self,sentence,max_len,pad=True):
        """
        Do after processIndex

        Returns the token indices of the sentence followed by <EOS>. Words that
        are not in the vocabulary are mapped to the <PAD> index. With pad=True
        the result goes through padIndex, so it is either exactly max_len long
        or [] when the sentence (including <EOS>) does not fit.
        """
        pad_index = self.word2index["<PAD>"]
        res_lst = [self.word2index.get(word,pad_index) for word in self._tokenize(sentence)]
        res_lst.append(self.word2index["<EOS>"])
        return self.padIndex(res_lst,max_len) if pad else res_lst

    def padIndex(self,lst,max_len):
        """
        Do after processIndex

        Returns a new list of length max_len (lst followed by <PAD> indices),
        or [] when lst is longer than max_len. The argument is never modified.
        """
        missing = max_len - len(lst)
        if missing < 0:
            return []
        if missing == 0:
            return list(lst)
        return lst + [self.word2index["<PAD>"]] * missing

    def getSentenceFromIndex(self,index_lst):
        """
        Call after processIndex

        Chinese words are joined without a separator, English ones with spaces.
        """
        separator = "" if self.name == "zh" else " "
        return separator.join([self.index2word[index] for index in index_lst])

    def processIndex(self):
        """
        Do after all the addSentence

        Builds word2count, index2word, word2index and n_words. Indices follow
        descending frequency, except that <PAD> is always index 0.
        """
        self.word2count = Counter(self.tmp_word_lst) # {word: count}
        del self.tmp_word_lst # delete temporary word list
        self.word2count["<PAD>"] = self.n_sentences * 50 # add padding mark
        self.word2count["<BOS>"] = self.n_sentences # add begin of sentence (BOS) mark
        self.word2count["<EOS>"] = self.n_sentences # add end of sentence (EOS) mark
        # sort based on counts, but only remain the word strings
        sorted_vocab = sorted(self.word2count, key=self.word2count.get, reverse=True)
        # The training loss is created with ignore_index=0 to skip padding, so
        # <PAD> has to be index 0 no matter how its synthetic count compares
        # with the count of a very frequent real token.
        sorted_vocab.remove("<PAD>")
        sorted_vocab.insert(0, "<PAD>")

        # make embedding based on the occurance frequency of the words
        self.index2word = {k: w for k, w in enumerate(sorted_vocab)}
        self.word2index = {w: k for k, w in self.index2word.items()}
        self.n_words = len(self.index2word)
        print('Vocabulary size of {}'.format(self.name), self.n_words)
        print(list(self.index2word.items())[:10])

In [None]:
def preprocess(mode="train",size=10000):
    """
    Load (or build and cache) the vocabularies and normalized sentence pairs.

    Source file in Chinese, target file in English

    Eg:
    巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正在发生的情况。
    PARIS – As the economic crisis deepens and widens, the world has been searching for historical analogies to help us understand what has been happening.

    Args:
        mode: "train" builds the vocabularies and caches everything as pickle
            files under flags.data_path; any other value only reads and
            normalizes the sentence pairs.
        size: dataset size tag, selects the "dataset_{size}" directory and is
            part of the cache file names.

    Returns:
        (src_lang, dst_lang, pairs) where pairs is a list of
        [normalized_chinese, normalized_english]. For mode != "train" the two
        Lang objects are returned empty (processIndex is not called on them).
    """
    data_path = flags.data_path
    zh_lang_file = "{}/zh-lang-{}-{}.pkl".format(data_path,mode,size)
    en_lang_file = "{}/en-lang-{}-{}.pkl".format(data_path,mode,size)
    pairs_file = "{}/pairs-{}-{}.pkl".format(data_path,mode,size)
    cache_files = (zh_lang_file, en_lang_file, pairs_file)

    if mode == "train" and all(os.path.isfile(name) for name in cache_files):
        # NOTE: unpickling executes arbitrary code. These files are caches
        # written by this very function; never point data_path at files from
        # an untrusted source.
        with open(zh_lang_file,"rb") as cache:
            src_lang = pickle.load(cache)
        with open(en_lang_file,"rb") as cache:
            dst_lang = pickle.load(cache)
        with open(pairs_file,"rb") as cache:
            pairs = pickle.load(cache)
        print('Vocabulary size of {}'.format(src_lang.name), src_lang.n_words)
        print('Vocabulary size of {}'.format(dst_lang.name), dst_lang.n_words)
        return src_lang, dst_lang, pairs

    src_lang = Lang("zh")
    dst_lang = Lang("en")
    pairs = []
    path = "dataset_{}".format(size)
    set_size = 8000 if mode == "train" else 1000
    set_size = set_size * 10 if size == 100000 else set_size
    src_name = "{}/{}_source_{}.txt".format(path,mode,set_size)
    dst_name = "{}/{}_target_{}.txt".format(path,mode,set_size)

    print("Reading data...")
    # the context manager closes both corpus files even if normalization fails
    with open(src_name,"r",encoding="utf-8") as src_file, \
            open(dst_name,"r",encoding="utf-8") as dst_file:
        for i,(src_line,dst_line) in enumerate(zip(src_file,dst_file),1):
            # Only strip the line terminator. splitlines()[0] would also cut
            # the sentence at the first embedded form feed, U+2028, etc.
            norm_src = Lang.normalizeString(src_line.rstrip("\n"),"zh")
            norm_dst = Lang.normalizeString(dst_line.rstrip("\n"),"en")
            if mode == "train":
                src_lang.addSentence(norm_src)
                dst_lang.addSentence(norm_dst)
            if i % 1000 == 0:
                print("Done {}/{}".format(i,set_size))
            pairs.append([norm_src,norm_dst])

    if mode != "train":
        return src_lang, dst_lang, pairs

    src_lang.processIndex()
    dst_lang.processIndex()

    for obj, name in zip((src_lang, dst_lang, pairs), cache_files):
        with open(name,"wb") as cache:
            pickle.dump(obj,cache)
    print("Dumped to file!")
    return src_lang, dst_lang, pairs

In [None]:
from torch.utils import data

class TextDataset(data.Dataset):
    """
    My own text dataset
    ref: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel

    Every sample is a pair (source indices, target indices), both padded to
    max_seq_len. Pairs in which either side does not fit into max_seq_len are
    dropped, and the total is cut down to a multiple of batch_size so that
    every batch is full.
    """
    def __init__(self,mode="train",dataset_size=10000,max_seq_len=32,batch_size=32):
        # NOTE(review): for mode != "train" preprocess returns Lang objects
        # without a built index, so getSentenceIndex below presumably fails
        # with a KeyError - confirm before using this class for test data.
        self.src_lang, self.dst_lang, self.pairs = preprocess(mode,dataset_size)
        print("Read {} sentence pairs".format(len(self.pairs)))
        # filter long sentence pairs
        # self.pairs = self.filter_pairs()
        # print("Filtered to {} pairs".format(len(self.pairs)))

        # need to pad the sentences for easy generating Dataloader
        self.index_pairs = []
        for src, dst in self.pairs:
            src_index = self.src_lang.getSentenceIndex(src,max_seq_len)
            dst_index = self.dst_lang.getSentenceIndex(dst,max_seq_len)
            if len(src_index) == 0 or len(dst_index) == 0:
                continue # at least one side is longer than max_seq_len
            self.index_pairs.append([src_index,dst_index])
        print("Further trimmed to {} pairs".format(len(self.index_pairs)))

        # The index pairs are dumped for inspection but never read back: the
        # file name does not encode max_seq_len, so a cached file could hold
        # sequences padded to a different length.
        index_pairs_file_name = "{}/index_pairs-{}-{}.pkl".format(flags.data_path,mode,dataset_size)
        with open(index_pairs_file_name,"wb") as dump_file:
            pickle.dump(self.index_pairs,dump_file)
        print("Dumped index pairs!")

        if self.index_pairs:
            index_array = np.array(self.index_pairs) # (n_pairs, 2, max_seq_len)
            self.in_text = index_array[:,0].reshape(-1,max_seq_len)
            self.out_text = index_array[:,1].reshape(-1,max_seq_len)
        else:
            # nothing survived the trimming: keep well-formed empty arrays
            # instead of failing on the 2-D indexing of an empty array
            self.in_text = np.empty((0,max_seq_len),dtype=np.int64)
            self.out_text = np.empty((0,max_seq_len),dtype=np.int64)

        # drop the tail so that the sample count is a multiple of batch_size
        num_pairs = len(self.in_text) // batch_size * batch_size
        self.in_text = self.in_text[:num_pairs]
        self.out_text = self.out_text[:num_pairs]
        print("Use {} pairs to {}".format(num_pairs,mode))
        print("In_text shape: {}\t Out_text shape: {}".format(self.in_text.shape,self.out_text.shape))
        print("Done generating {}_{} dataset!".format(mode,dataset_size))

    def filter_pairs(self):
        """
        Return the pairs whose character lengths lie within MIN_LENGTH / MAX_LENGTH
        """
        self.MIN_LENGTH = {"zh":1,"en":3}
        self.MAX_LENGTH = {"zh":60,"en":150}
        return [pair for pair in self.pairs
                if self.MIN_LENGTH["zh"] <= len(pair[0]) <= self.MAX_LENGTH["zh"]
                and self.MIN_LENGTH["en"] <= len(pair[1]) <= self.MAX_LENGTH["en"]]

    def __len__(self):
        """
        Return the total number of samples
        """
        return len(self.in_text)

    def __getitem__(self, idx):
        """
        Generate one sample of the data
        """
        x = self.in_text[idx]
        y = self.out_text[idx]
        return x, y

## RNN (LSTM / GRU)
* Reference
    * Animated RNN (LSTM & GRU), <https://towardsdatascience.com/animated-rnn-lstm-and-gru-ef124d06cf45>
    * Pytorch LSTM, <https://pytorch.org/docs/stable/nn.html#lstm>

![RNN](https://miro.medium.com/max/1516/1*yBXV9o5q7L_CvY7quJt3WQ.png)

## Encoder

![Encoder network](https://pytorch.org/tutorials/_images/encoder-network.png)

In [None]:
class EncoderRNN(nn.Module):
    """
    Sequence encoder: token indices -> embedding -> single-layer LSTM.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # nn.Embedding(vocab_size, vector_size) is a randomly initialised
        # lookup table (it can be viewed as a linear transformation):
        #   * vocab_size  - number of distinct words in the train/val/test data
        #   * vector_size - dimension of each word vector
        # It takes a LongTensor of indices of arbitrary shape (*) and returns
        # a tensor of shape (*, vector_size).
        # The embedding width is chosen equal to the LSTM hidden width.
        self.embedding = nn.Embedding(num_embeddings=input_size,
                                      embedding_dim=hidden_size)
        # batch_first=True -> tensors are laid out (batch_size, seq_len, features)
        self.lstm = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            batch_first=True)

    def forward(self, x, prev_state):
        """
        x: (batch_size, seq_len)
            every position along seq_len is one time step
        prev_state: (h_0, c_0), each (1, batch_size, hidden_size)
            the leading 1 is num_layers * num_directions

        Returns (output, (h_t, c_t)):
        output: (batch_size, seq_len, hidden_size)
        h_t, c_t: (1, batch_size, hidden_size)

        Variable-length sequences could be handled with PyTorch's
        pack_padded_sequence, see:
        * https://discuss.pytorch.org/t/understanding-lstm-input/31110/3
        * https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.PackedSequence
        * https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
        * https://pytorch.org/docs/stable/notes/faq.html#pack-rnn-unpack-with-data-parallelism
        * https://gist.github.com/HarshTrivedi/f4e7293e941b17d19058f6fb90ab0fec
        """
        embedded_tokens = self.embedding(x)
        return self.lstm(embedded_tokens, prev_state)

    def initHidden(self,batch_size):
        """
        Zero (h_0, c_0) pair for a batch of the given size
        """
        state_shape = (1, batch_size, self.hidden_size)
        h_0 = torch.zeros(state_shape, device=device)
        c_0 = torch.zeros(state_shape, device=device)
        return (h_0, c_0)

## Decoder
![Decoder network](https://pytorch.org/tutorials/_images/decoder-network.png)

In [None]:
class DecoderRNN(nn.Module):
    """
    Plain decoder: embedding -> ReLU -> single-layer LSTM -> linear projection
    onto the output vocabulary.
    """
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size

        # embedding width == LSTM hidden width
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=hidden_size)
        self.lstm = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size,
                                out_features=output_size)

    def forward(self, x, prev_state):
        """
        x: (batch_size, seq_len)
        prev_state: (h, c), each (1, batch_size, hidden_size)

        Returns (output, (h_t, c_t)):
        output: (batch_size, seq_len, output_size), unnormalised scores
            (no softmax is applied here)
        h_t, c_t: (1, batch_size, hidden_size)
        """
        activated = F.relu(self.embedding(x))
        lstm_out, state = self.lstm(activated, prev_state)
        return self.linear(lstm_out), state

    def initHidden(self,batch_size):
        """
        Zero (h_0, c_0) pair for a batch of the given size
        """
        state_shape = (1, batch_size, self.hidden_size)
        h_0 = torch.zeros(state_shape, device=device)
        c_0 = torch.zeros(state_shape, device=device)
        return (h_0, c_0)

## Decoder with Attention

![Decoder with Attention](https://i.imgur.com/1152PYf.png)
![pytorch decoder with attention](https://pytorch.org/tutorials/_images/attention-decoder-network.png)

In [None]:
class AttnDecoderRNN(nn.Module):
    """
    GRU decoder with attention over a fixed number (max_length) of encoder
    positions. It consumes a single token per forward call.
    """
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=flags.seq_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        doubled = self.hidden_size * 2 # width of two concatenated hidden-size vectors
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # scores one weight per encoder position from [embedded token ; hidden]
        self.attn = nn.Linear(doubled, self.max_length)
        # folds [embedded token ; attention context] back to hidden_size
        self.attn_combine = nn.Linear(doubled, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """
        input: index of the current token (its embedding is viewed as (1, 1, -1))
        hidden: previous GRU state; initHidden gives (1, 1, hidden_size)
        encoder_outputs: presumably (max_length, hidden_size) - TODO confirm

        Returns (log-probabilities over the output vocabulary,
                 new hidden state, attention weights)
        """
        token_vec = self.dropout(self.embedding(input).view(1, 1, -1))

        # attention distribution over the encoder positions
        attn_scores = self.attn(torch.cat((token_vec[0], hidden[0]), 1))
        attn_weights = F.softmax(attn_scores, dim=1)
        # weighted sum of the encoder outputs
        context = torch.bmm(attn_weights.unsqueeze(0),
                            encoder_outputs.unsqueeze(0))

        merged = torch.cat((token_vec[0], context[0]), 1)
        gru_input = F.relu(self.attn_combine(merged).unsqueeze(0))
        output, hidden = self.gru(gru_input, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        """
        Zero initial hidden state for one sequence
        """
        return torch.zeros((1, 1, self.hidden_size), device=device)

In [None]:
def train(encoder,decoder,train_loader):
    """
    Core training function

    Trains encoder and decoder jointly with teacher forcing for
    flags.num_epochs epochs, using one Adam optimizer per network.

    Args:
        encoder: network called as encoder(x, state) that also provides
            initHidden(batch_size)
        decoder: network called as decoder(tokens, state) that exposes output_size
        train_loader: iterable of (x, y) batches of token indices,
            each of shape (batch_size, seq_size)

    Returns:
        list with the loss value of every iteration

    Side effects:
        prints and logs the progress every 50 iterations, saves both networks
        to flags.checkpoint_path every 100 iterations and once more at the end.
    """

    criterion = nn.CrossEntropyLoss(ignore_index=0) # ignore padding
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=flags.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=flags.learning_rate)

    iteration = 0
    losses = []
    total_iterations = flags.num_epochs * len(train_loader)

    start_time = time.time()
    for e in range(flags.num_epochs):
        for x, y in train_loader:
            iteration += 1
            encoder.train()
            decoder.train()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # as_tensor avoids the copy (and the warning) torch.tensor()
            # produces when the loader already yields tensors
            x = torch.as_tensor(x, dtype=torch.int64, device=device) # (batch_size, seq_size)
            y = torch.as_tensor(y, dtype=torch.int64, device=device) # (batch_size, seq_size)
            batch_size = x.shape[0] # taken from the data, so a smaller last batch works too

            # Every batch starts from a fresh zero state. The loader shuffles
            # the sentences, so the final state of the previous batch belongs
            # to unrelated sentences, and evaluation() starts from zeros too.
            encoder_state = encoder.initHidden(batch_size)
            _, (encoder_ht, encoder_ct) = encoder(x, encoder_state)

            # start token
            # NOTE(review): 0 is also the ignore_index used for padding above,
            # so start token and padding share one index. evaluation() feeds
            # the same value - change both together if <BOS> gets its own index.
            decoder_input = torch.zeros((batch_size, 1), dtype=torch.int64, device=device)
            decoder_ht, decoder_ct = encoder_ht, encoder_ct # use last hidden state from encoder

            # run through decoder one time step at a time
            max_dst_len = y.shape[1]
            step_outputs = [] # kept on the model's device, no CPU round trip
            for t in range(max_dst_len):
                decoder_output, (decoder_ht, decoder_ct) = decoder(decoder_input, (decoder_ht, decoder_ct))
                step_outputs.append(decoder_output) # (batch_size, 1, output_size)
                # teaching forcing: next input is the current target
                decoder_input = y[:,t].reshape(batch_size,1) # remember to reshape

            # loss calculation
            # CrossEntropyLoss wants the class dimension second: (batch, classes, seq)
            all_decoder_outputs = torch.cat(step_outputs, dim=1) # (batch_size, seq_size, output_size)
            loss = criterion(all_decoder_outputs.transpose(1,2), y)

            loss_value = loss.item()

            loss.backward()

            # avoid gradient explosion
            _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), flags.gradients_norm)
            _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), flags.gradients_norm)

            # update parameters with optimizers
            encoder_optimizer.step()
            decoder_optimizer.step()

            losses.append(loss_value)

            if iteration % 50 == 0:
                percent = iteration / total_iterations
                time_since = time.time() - start_time
                time_remaining = time_since / percent - time_since
                print('Epoch: {}/{}'.format(e+1, flags.num_epochs),
                      'Iteration: {}'.format(iteration),
                      'Time: {:.2f}m (- {:.2f}m)'.format(time_since/60, time_remaining/60),
                      'Loss: {}'.format(loss_value))
                logger.info('Epoch: {}/{} Iteration: {} Loss: {}'.format(e+1, flags.num_epochs, iteration, loss_value))

            if iteration % 100 == 0:
                torch.save(encoder,
                           '{}/encoder-{}.pth'.format(flags.checkpoint_path,iteration))
                torch.save(decoder,
                           '{}/decoder-{}.pth'.format(flags.checkpoint_path,iteration))

    print("Time:{}s".format(time.time()-start_time))
    torch.save(encoder,'{}/encoder-final.pth'.format(flags.checkpoint_path))
    torch.save(decoder,'{}/decoder-final.pth'.format(flags.checkpoint_path))
    return losses

In [None]:
# Padded training pairs plus a loader that shuffles them into full batches.
train_set = TextDataset(
    mode="train",
    dataset_size=10000,
    max_seq_len=flags.seq_size,
    batch_size=flags.batch_size,
)
train_loader = data.DataLoader(
    train_set,
    batch_size=flags.batch_size,
    shuffle=True,
)

In [None]:
# The vocabulary sizes of the training set fix the embedding / projection
# sizes of the two networks; both share the same hidden width.
encoder = EncoderRNN(input_size=train_set.src_lang.n_words,
                     hidden_size=flags.lstm_size)
encoder = encoder.to(device)
decoder = DecoderRNN(hidden_size=flags.lstm_size,
                     output_size=train_set.dst_lang.n_words)
decoder = decoder.to(device)
losses = train(encoder, decoder, train_loader)

In [None]:
def evaluation(in_text, out_text, src_lang, dst_lang, encoder, decoder):
    """
    Greedily translate one sentence and print it next to the reference.

    Args:
        in_text: list of source token indices
        out_text: list of reference target token indices (only printed)
        src_lang: vocabulary used to print in_text
        dst_lang: vocabulary used to print out_text and to decode the output
        encoder, decoder: trained networks; both are switched to eval mode

    Returns:
        the decoded words joined by spaces; decoding stops after "<EOS>"
        (which is included) or after flags.seq_size steps.
    """
    encoder.eval() # set in evaluation mode
    decoder.eval()

    x = torch.tensor(in_text, dtype=torch.int64, device=device).reshape(1,-1)

    decoded_words = []
    # inference only: no autograd graph is needed
    with torch.no_grad():
        # encoder
        _, (encoder_ht, encoder_ct) = encoder(x, encoder.initHidden(1))

        # start token, the same value train() feeds at the first step
        decoder_input = torch.zeros((1,1), dtype=torch.int64, device=device)
        decoder_ht, decoder_ct = encoder_ht, encoder_ct # use last hidden state from encoder

        # decoder
        # run through decoder one time step at a time
        for t in range(flags.seq_size):
            decoder_output, (decoder_ht, decoder_ct) = decoder(decoder_input, (decoder_ht, decoder_ct))
            # choose top word from output
            _, top_index = decoder_output.topk(1)
            ni = top_index[0][0].item()
            word = dst_lang.index2word[ni]
            decoded_words.append(word)
            if word == "<EOS>":
                break
            # greedy decoding: the prediction becomes the next input
            decoder_input = torch.tensor([[ni]], dtype=torch.int64, device=device)

    res_words = " ".join(decoded_words)
    print("< {}".format(src_lang.getSentenceFromIndex(in_text)))
    print("= {}".format(dst_lang.getSentenceFromIndex(out_text)))
    print("> {}".format(res_words))
    return res_words

In [None]:
# test_set = TextDataset("test",10000,max_seq_len=flags.seq_size,batch_size=flags.batch_size)
# test_loader = data.DataLoader(dataset=test_set,batch_size=1,shuffle=False)
# Translate the first five test pairs with the vocabularies learnt on the
# training set (the test split itself does not build a vocabulary).
src_lang, dst_lang = train_set.src_lang, train_set.dst_lang
_, _, test_set = preprocess("test",10000)
for in_text, out_text in test_set[:5]:
    # no padding here: one sentence is decoded at a time
    in_text = src_lang.getSentenceIndex(in_text, 0, pad=False)
    out_text = dst_lang.getSentenceIndex(out_text, 0, pad=False)
    print(in_text,out_text)
    evaluation(in_text, out_text, src_lang, dst_lang, encoder, decoder)