In [1]:
import os
import numpy as np
import pandas as pd
import pickle
import nltk
from nltk.tokenize import word_tokenize
from gensim.models import Word2Vec
import jieba

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

os.environ['CUDA_VISIBLE_DEVICES']='1'
tf.config.set_soft_device_placement(True)

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

Version:  2.5.0
Eager mode:  True
Hub version:  0.12.0
GPU is available


# Constants

In [2]:
folder_name = '20211116_wmt19_en_zh'

# Load data
Source: [data.statmt.org](http://data.statmt.org/news-commentary/v14/)

In [3]:
corpus = pd.read_csv(
    'news-commentary-v14.en-zh.tsv', 
    sep='\t', 
    error_bad_lines=False,
    skipfooter= 1,
    header = None,
)

  
Skipping line 5803: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 5804: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 12524: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 12525: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 12526: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python'

Skipping line 211623: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 211645: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 214161: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 214165: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Python's csv library has parsed all rows).
Skipping line 214184: '	' expected after '"'. Error could possibly be due to parsing errors in the skipped footer rows (the skipfooter keyword is only applied after Pyt

In [4]:
en = corpus[0].to_numpy(dtype = str)
zh = corpus[1].to_numpy(dtype = str)

In [5]:
print(en[:5])

['1929 or 1989?'
 'PARIS – As the economic crisis deepens and widens, the world has been searching for historical analogies to help us understand what has been happening.'
 'At the start of the crisis, many people likened it to 1982 or 1973, which was reassuring, because both dates refer to classical cyclical downturns.'
 'Today, the mood is much grimmer, with references to 1929 and 1931 beginning to abound, even if some governments continue to behave as if the crisis was more classical than exceptional.'
 'The tendency is either excessive restraint (Europe) or a diffusion of the effort (the United States).']


In [6]:
print(zh[:5])

['1929年还是1989年?' '巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正在发生的情况。'
 '一开始，很多人把这次危机比作1982年或1973年所发生的情况，这样得类比是令人宽心的，因为这两段时期意味着典型的周期性衰退。'
 '如今人们的心情却是沉重多了，许多人开始把这次危机与1929年和1931年相比，即使一些国家政府的表现仍然似乎把视目前的情况为是典型的而看见的衰退。'
 '目前的趋势是，要么是过度的克制（欧洲 ） ， 要么是努力的扩展（美国 ） 。']


# Tokenize

## English

### Convert the text into sequences

In [7]:
def seq_reduction(post):
    # Initialization
    token_stream = [] #control sequence length
    de_input  = []
    
    for p in post:
        # Initialize
        de_in =  []
        # Put BOS, EOS, and PAD
        de_in.extend(p)
        token_stream.extend(de_in)
        de_input.append(de_in)
    return token_stream, de_input

In [8]:
en_token_stream, en_input = seq_reduction(en)

In [9]:
print(en_input[:2])

[['1', '9', '2', '9', ' ', 'o', 'r', ' ', '1', '9', '8', '9', '?'], ['P', 'A', 'R', 'I', 'S', ' ', '–', ' ', 'A', 's', ' ', 't', 'h', 'e', ' ', 'e', 'c', 'o', 'n', 'o', 'm', 'i', 'c', ' ', 'c', 'r', 'i', 's', 'i', 's', ' ', 'd', 'e', 'e', 'p', 'e', 'n', 's', ' ', 'a', 'n', 'd', ' ', 'w', 'i', 'd', 'e', 'n', 's', ',', ' ', 't', 'h', 'e', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'h', 'a', 's', ' ', 'b', 'e', 'e', 'n', ' ', 's', 'e', 'a', 'r', 'c', 'h', 'i', 'n', 'g', ' ', 'f', 'o', 'r', ' ', 'h', 'i', 's', 't', 'o', 'r', 'i', 'c', 'a', 'l', ' ', 'a', 'n', 'a', 'l', 'o', 'g', 'i', 'e', 's', ' ', 't', 'o', ' ', 'h', 'e', 'l', 'p', ' ', 'u', 's', ' ', 'u', 'n', 'd', 'e', 'r', 's', 't', 'a', 'n', 'd', ' ', 'w', 'h', 'a', 't', ' ', 'h', 'a', 's', ' ', 'b', 'e', 'e', 'n', ' ', 'h', 'a', 'p', 'p', 'e', 'n', 'i', 'n', 'g', '.']]


In [10]:
print('num_of_pairs',len(en_input))
words=list(set(en_token_stream))
words.remove(' ')
words.append('<bos>')
words.append('<eos>')

en_word2idx={}
en_word2idx[' '] = 0
for i, word in enumerate(words):
    en_word2idx[word]=i+1
en_num_words = len(en_word2idx)
print(f"num_words:{en_num_words}")
en_idx2word = {v:k for k, v in en_word2idx.items()}

num_of_pairs 314487
num_words:199


In [11]:
en_seq=[[en_word2idx[t] for t in seq] for seq in en_input]

### Count the average sequence length.

In [12]:
en_seq_len = [len(s) for s in en_seq]
print(np.mean(en_seq_len))
print(np.std(en_seq_len))

134.79737159246645
71.46306078397407


In [13]:
en_seq_length = round(np.mean(en_seq_len)+np.std(en_seq_len))
print(f'Length of en: {en_seq_length}')

Length of en: 206


## Chinese

In [14]:
zh_token_stream, zh_input = seq_reduction(zh)

In [15]:
print(zh_input[:2])

[['1', '9', '2', '9', '年', '还', '是', '1', '9', '8', '9', '年', '?'], ['巴', '黎', '-', '随', '着', '经', '济', '危', '机', '不', '断', '加', '深', '和', '蔓', '延', '，', '整', '个', '世', '界', '一', '直', '在', '寻', '找', '历', '史', '上', '的', '类', '似', '事', '件', '希', '望', '有', '助', '于', '我', '们', '了', '解', '目', '前', '正', '在', '发', '生', '的', '情', '况', '。']]


In [16]:
print('num_of_pairs',len(zh_input))
words=list(set(zh_token_stream))
words.remove(' ')

zh_word2idx={}
zh_word2idx[' ']=0
for i, word in enumerate(words):
    zh_word2idx[word]=i+1
zh_num_words = len(zh_word2idx)
print(f"num_words:{zh_num_words}")
zh_idx2word = {v:k for k, v in zh_word2idx.items()}

num_of_pairs 314487
num_words:4716


In [17]:
zh_seq=[[zh_word2idx[t] for t in seq] for seq in zh_input]

### Count the average sequence length.

In [18]:
zh_seq_len = [len(s) for s in zh_seq]
print(np.mean(zh_seq_len))
print(np.std(zh_seq_len))

41.59374155370492
23.273498491636335


In [19]:
zh_seq_length = round(np.mean(zh_seq_len)+np.std(zh_seq_len))
print(f'Length of en: {zh_seq_length}')

Length of en: 65


# Remove long sentences

In [20]:
keep = [[zh, en] for zh, en in zip(zh_seq, en_seq) 
            if len(zh)<=zh_seq_length and len(zh) > 5 and len(en)<=en_seq_length
       ]
print(len(keep))

245599


# Add BOS and EOS into decoder

In [21]:
zh_reduce_seq = [pair[0] for pair in keep]
en_reduce_seq = [pair[1] for pair in keep]
bosIdx = en_word2idx['<bos>']
eosIdx = en_word2idx['<eos>']
for i, seq in enumerate(en_reduce_seq):
    en_reduce_seq[i] = [bosIdx]
    en_reduce_seq[i].extend(seq)
    en_reduce_seq[i].append(eosIdx)

# Padding

In [22]:
en_pad_seq = tf.keras.preprocessing.sequence.pad_sequences(
    en_reduce_seq,
    dtype='int32',
    padding='post',
)

In [23]:
print(en_pad_seq.shape)

(245599, 208)


In [24]:
zh_pad_seq = tf.keras.preprocessing.sequence.pad_sequences(
    zh_reduce_seq,
    dtype='int32',
    padding='post',
)

In [25]:
print(zh_pad_seq.shape)

(245599, 65)


# Distribute

In [26]:
num_pair = len(en_pad_seq)
encoder_train = zh_pad_seq[int(num_pair*0.1):]
decoder_train = en_pad_seq[int(num_pair*0.1):, :-1]
teacher_train = en_pad_seq[int(num_pair*0.1):, 1:]
encoder_vali  = zh_pad_seq[:int(num_pair*0.1)]
decoder_vali  = en_pad_seq[:int(num_pair*0.1), :-1]
teacher_vali  = en_pad_seq[:int(num_pair*0.1), 1:]

In [27]:
print(encoder_train.shape)
print(decoder_train.shape)
print(teacher_train.shape)

(221040, 65)
(221040, 207)
(221040, 207)


# Word2Vec

## 32 dim. English embedding 

In [30]:
bosIdx = en_word2idx['<bos>']
eosIdx = en_word2idx['<eos>']
tmp = []
for i, seq in enumerate(en_seq):
    tmp.append([])
    tmp[i] = [bosIdx]
    tmp[i].extend(seq)
    tmp[i].append(eosIdx)

model = Word2Vec(
    sentences=tmp, 
    vector_size=32, 
    window=5, 
    min_count=1, 
    workers=16,
    sg = 1,
    negative = 10,
)
model.build_vocab(tmp)
model.train(tmp, total_examples=model.corpus_count, epochs=10)

(92130499, 430209950)

In [31]:
model.wv[198]

array([-0.45359987,  1.9887191 ,  0.59196424,  0.04850034, -0.11136319,
        1.045833  ,  0.0325858 ,  0.921646  , -0.1453513 , -0.29940316,
       -0.13449323,  0.08392815, -0.19487256, -0.3319154 , -0.05449354,
        0.5141277 ,  0.86144334, -0.25267354,  0.58859575, -0.41372308,
        0.19209178, -0.13207915,  0.2859081 ,  0.71193624, -0.19737837,
       -0.5219248 , -0.44954288, -0.54669255,  0.6119101 ,  0.06045213,
        0.3184272 ,  0.14369203], dtype=float32)

In [32]:
en_emb32 = np.array([ model.wv[i] for i in range(en_num_words)])

## 32 dim. Chinese embedding

In [33]:
tmp = [list(e) for e in zh_seq]
model = Word2Vec(
    sentences=tmp, 
    vector_size=32, 
    window=5, 
    min_count=1, 
    workers=16,
    sg = 1,
    negative = 10,
)
model.build_vocab(tmp)
model.train(tmp, total_examples=model.corpus_count, epochs=10)

(109351746, 130806910)

In [34]:
zh_emb32 = np.array([ model.wv[i] for i in range(zh_num_words)])

In [35]:
print(zh_idx2word[4692])

垢


# Save preprocessed data

In [36]:
pickle.dump(encoder_train, open(f'{folder_name}/encoder_train.pkl','wb'))
pickle.dump(decoder_train, open(f'{folder_name}/decoder_train.pkl','wb'))
pickle.dump(teacher_train, open(f'{folder_name}/teacher_train.pkl','wb'))
pickle.dump(encoder_vali,  open(f'{folder_name}/encoder_vali.pkl','wb'))
pickle.dump(decoder_vali,  open(f'{folder_name}/decoder_vali.pkl','wb'))
pickle.dump(teacher_vali,  open(f'{folder_name}/teacher_vali.pkl','wb'))

pickle.dump(en_idx2word, open(f'{folder_name}/en_idx2word.pkl','wb'))
pickle.dump(en_word2idx, open(f'{folder_name}/en_word2idx.pkl','wb'))
pickle.dump(zh_idx2word, open(f'{folder_name}/zh_idx2word.pkl','wb'))
pickle.dump(zh_word2idx, open(f'{folder_name}/zh_word2idx.pkl','wb'))

pickle.dump(en_emb32, open(f'{folder_name}/en_emb32.pkl','wb'))
pickle.dump(zh_emb32, open(f'{folder_name}/zh_emb32.pkl','wb'))

# Load data

In [37]:
encoder_train = pickle.load(open(f'{folder_name}/encoder_train.pkl', 'rb'))
decoder_train = pickle.load(open(f'{folder_name}/decoder_train.pkl', 'rb'))
teacher_train = pickle.load(open(f'{folder_name}/teacher_train.pkl', 'rb'))
encoder_vali  = pickle.load(open(f'{folder_name}/encoder_vali.pkl', 'rb'))
decoder_vali  = pickle.load(open(f'{folder_name}/decoder_vali.pkl', 'rb'))
teacher_vali  = pickle.load(open(f'{folder_name}/teacher_vali.pkl', 'rb'))

en_idx2word   = pickle.load(open(f'{folder_name}/en_idx2word.pkl','rb'))
en_word2idx   = pickle.load(open(f'{folder_name}/en_word2idx.pkl','rb'))
zh_idx2word   = pickle.load(open(f'{folder_name}/zh_idx2word.pkl','rb'))
zh_word2idx   = pickle.load(open(f'{folder_name}/zh_word2idx.pkl','rb'))

en_emb32  = pickle.load(open(f'{folder_name}/en_emb32.pkl', 'rb'))
zh_emb32 = pickle.load(open(f'{folder_name}/zh_emb32.pkl', 'rb'))

In [38]:
print(zh_idx2word[3742])

拘


In [39]:
def seq2word(seq_tensor, idx2word):
    return np.array([[idx2word[i] for i in seq] for seq in seq_tensor])

In [40]:
seq2word(decoder_train[5:6], en_idx2word)

array([['<bos>', 'E', 'v', 'e', 'n', ' ', 'm', 'o', 'r', 'e', ' ', 't',
        'o', ' ', 't', 'h', 'e', ' ', 'p', 'o', 'i', 'n', 't', ',', ' ',
        'p', 'o', 'l', 'i', 'c', 'y', 'm', 'a', 'k', 'e', 'r', 's', ' ',
        'n', 'e', 'e', 'd', ' ', 't', 'o', ' ', 'a', 'd', 'd', 'r', 'e',
        's', 's', ' ', 't', 'h', 'e', ' ', 'f', 'i', 'n', 'a', 'n', 'c',
        'i', 'a', 'l', 'i', 'z', 'a', 't', 'i', 'o', 'n', ' ', 'o', 'f',
        ' ', 't', 'h', 'e', ' ', 'p', 'h', 'a', 'r', 'm', 'a', 'c', 'e',
        'u', 't', 'i', 'c', 'a', 'l', ' ', 'i', 'n', 'd', 'u', 's', 't',
        'r', 'y', ',', ' ', 'w', 'h', 'i', 'c', 'h', ' ', 'i', 's', ' ',
        'f', 'o', 'c', 'u', 's', 'e', 'd', ' ', 's', 'o', 'l', 'e', 'l',
        'y', ' ', 'o', 'n', ' ', 's', 'h', 'a', 'r', 'e', 'h', 'o', 'l',
        'd', 'e', 'r', ' ', 'v', 'a', 'l', 'u', 'e', ',', ' ', 'r', 'a',
        't', 'h', 'e', 'r', ' ', 't', 'h', 'a', 'n', ' ', 'o', 'n', ' ',
        'a', 'l', 'l', ' ', 's', 't', 'a', 'k', 'e',