## Seq2seq Model with Attention for Chinese-English Machine Translation

Some references on seq2seq:
* PyTorch tutorial, "NLP From Scratch: Translation with a Sequence to Sequence Network and Attention": <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html>

![Seq2seq Model](https://pytorch.org/tutorials/_images/seq2seq.png)

Some tricky details:
* English uses three types of dashes:
    * The hyphen (-)
    * The en dash (–)
    * The em dash (—)
    * See [Wikipedia](https://en.wikipedia.org/wiki/Dash) or the [English Language Help Desk](http://site.uit.no/english/punctuation/hyphen/) for more details

In [24]:
"""
Preprocessing
"""
import re
import os
import sys
import random
import pickle
import jieba
from collections import Counter

class Lang:
    """Vocabulary built from the sentences of one language ("zh" or "en").

    Usage: call ``addSentence`` once per (already normalized) sentence, then
    call ``processIndex`` once to build the word <-> index mappings.

    Attributes:
        name: language code; "zh" selects jieba segmentation in
            ``addSentence``, any other value splits on whitespace.
        word2index / index2word: mappings filled by ``processIndex``; indices
            are assigned in descending order of token frequency.
        word2count: Counter of token frequencies (filled by ``processIndex``),
            including the "SOS" / "EOS" markers.
        tmp_word_lst: every token seen so far, in order; the input to
            ``processIndex``.
        n_sentences: number of sentences passed to ``addSentence``.
        n_words / n_word: vocabulary size after ``processIndex``. Both names
            are kept in sync (``n_word`` is the name existing code reads).
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {}
        self.word2count = Counter()
        self.tmp_word_lst = []
        self.n_sentences = 0
        self.n_words = 0
        self.n_word = 0

    @staticmethod
    def normalizeString(s,lang):
        """Clean a raw sentence so it can be tokenized.

        For ``lang == "zh"``:
          * drop HTML numeric entities (``&#NNN;``) and U+FFFD characters,
          * map the full-width punctuation 。！？ to ASCII . ! ?,
          * replace every character that is not an ASCII letter/digit, a CJK
            ideograph in U+4E00..U+9FA5, or one of . ! ? with a single space
            (runs of spaces are not collapsed here).
        For any other ``lang`` (treated as English):
          * drop HTML numeric entities,
          * insert a space before each . ! ?,
          * collapse every run of other characters (whitespace included) into
            one space.
        Both branches lowercase and strip the result.
        """
        if lang == "zh":
            s = re.sub(r"&#[0-9]+;",r"",s) # so dirty!
            s = re.sub(r"�",r"",s)
            # Test if is Chinese
            # https://cloud.tencent.com/developer/article/1499958
            punc_pair = [("。","."),("！","!"),("？","?")]
            for zh_punc,en_punc in punc_pair:
                s = s.replace(zh_punc,en_punc)
            s = re.sub(u"[^a-zA-Z0-9\u4e00-\u9fa5.!?]",u" ",s)
            s = s.lower().strip()
        else: # lang == "en"
            s = re.sub(r"&#[0-9]+;",r"",s)
            s = re.sub(r"([.!?])",r" \1",s) # add a space between these punctuations
            s = re.sub(r"[^a-zA-Z0-9.!?]+",r" ",s) # remove most of the punctuations
            s = s.lower().strip()
        return s

    def addSentence(self,sentence):
        """Record one sentence: bump ``n_sentences`` and append its tokens.

        "zh" sentences are segmented with jieba in precise mode
        (``cut_all=False``); every other language is split on whitespace.
        """
        self.n_sentences += 1
        if self.name == "zh":
            cut_lst = jieba.lcut(sentence, cut_all=False)
            # Drop whitespace-only tokens that come back from jieba for the
            # spaces left in the sentence (any whitespace, not just " ").
            self.tmp_word_lst.extend(w for w in cut_lst if w.strip())
        else:
            self.tmp_word_lst.extend(sentence.split())

    def processIndex(self):
        """Count tokens and assign indices by descending frequency.

        "SOS" and "EOS" are added with a count equal to ``n_sentences`` (one
        of each per sentence). Ties keep first-seen order, because ``sorted`` is
        stable. Prints the vocabulary size and the first ten entries.
        """
        self.word2count = Counter(self.tmp_word_lst)  # {word: count}
        self.word2count["SOS"] = self.n_sentences  # start-of-sentence marker
        self.word2count["EOS"] = self.n_sentences  # end-of-sentence marker
        sorted_vocab = sorted(self.word2count, key=self.word2count.get, reverse=True)

        self.index2word = dict(enumerate(sorted_vocab))
        self.word2index = {w: k for k, w in self.index2word.items()}
        # Keep both attribute names consistent; previously only ``n_word`` was
        # updated and ``n_words`` stayed 0.
        self.n_words = self.n_word = len(self.index2word)
        print('Vocabulary size', self.n_word)
        print(list(self.index2word.items())[:10])

def preprocess(mode="train",size=10000):
    """Load (or build and cache) the vocabularies and sentence pairs of a split.

    Source file in Chinese, target file in English, e.g.::

        巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正在发生的情况。
        PARIS – As the economic crisis deepens and widens, the world has been searching for historical analogies to help us understand what has been happening.

    Args:
        mode: split name. Reads ``dataset_{size}/{mode}_source_{n}.txt`` and
            ``dataset_{size}/{mode}_target_{n}.txt``, where ``n`` is 8000 for
            "train" and 1000 otherwise (both times 10 when ``size == 100000``).
        size: dataset size tag used in the directory and cache file names.

    Returns:
        ``(src_lang, dst_lang, pairs)``: the Chinese ``Lang``, the English
        ``Lang``, and a list of ``[normalized_zh, normalized_en]`` pairs.

    Results are pickled under ``data/``. Train caches keep their original names
    (``zh-lang-{size}.pkl``, ``en-lang-{size}.pkl``, ``pairs-{size}.pkl``).
    Other modes use ``{mode}-{size}`` in place of ``{size}``, so that a
    non-train call neither returns nor overwrites the train cache.
    """
    data_path = "data"
    tag = size if mode == "train" else "{}-{}".format(mode, size)
    zh_lang_file = "{}/zh-lang-{}.pkl".format(data_path, tag)
    en_lang_file = "{}/en-lang-{}.pkl".format(data_path, tag)
    pairs_file = "{}/pairs-{}.pkl".format(data_path, tag)

    if all(os.path.isfile(f) for f in (zh_lang_file, en_lang_file, pairs_file)):
        # NOTE: pickle.load can execute arbitrary code; only load caches that
        # this script wrote itself.
        with open(zh_lang_file, "rb") as f:
            src_lang = pickle.load(f)
        with open(en_lang_file, "rb") as f:
            dst_lang = pickle.load(f)
        with open(pairs_file, "rb") as f:
            pairs = pickle.load(f)
        return src_lang, dst_lang, pairs

    src_lang = Lang("zh")
    dst_lang = Lang("en")
    pairs = []

    path = "dataset_{}".format(size)
    set_size = 8000 if mode == "train" else 1000
    if size == 100000:
        set_size *= 10
    src_name = "{}/{}_source_{}.txt".format(path, mode, set_size)
    dst_name = "{}/{}_target_{}.txt".format(path, mode, set_size)

    print("Reading data...")
    with open(src_name, "r", encoding="utf-8") as src_file, \
            open(dst_name, "r", encoding="utf-8") as dst_file:
        for i, (src_line, dst_line) in enumerate(zip(src_file, dst_file), 1):
            # Strip only the line terminator: splitlines()[0] would also cut a
            # sentence at embedded separators such as U+2028 or \x0c.
            src = src_line.rstrip("\r\n")
            dst = dst_line.rstrip("\r\n")
            norm_src = Lang.normalizeString(src, "zh")
            norm_dst = Lang.normalizeString(dst, "en")
            src_lang.addSentence(norm_src)
            dst_lang.addSentence(norm_dst)
            if i % 1000 == 0:
                print("Done {}/{}".format(i, set_size))
            pairs.append([norm_src, norm_dst])

    src_lang.processIndex()
    dst_lang.processIndex()

    os.makedirs(data_path, exist_ok=True)
    with open(zh_lang_file, "wb") as f:
        pickle.dump(src_lang, f)
    with open(en_lang_file, "wb") as f:
        pickle.dump(dst_lang, f)
    with open(pairs_file, "wb") as f:
        pickle.dump(pairs, f)
    print("Dumped to file!")
    return src_lang, dst_lang, pairs

# Build (or load from the pickle cache) the training vocabularies and sentence
# pairs, then print one random [zh, en] pair as a quick sanity check.
src_lang, dst_lang, pairs = preprocess(mode="train", size=10000)
print(random.choice(pairs))

['1  这意味着就业岗位创造不振. 在此次和前两次复苏中 就业增长的反弹都比gdp弱 而且滞后于gdp.', 'in both this recovery and the previous two the rebound in employment growth has been weaker and later than the rebound in gdp growth .']
