In [2]:
import os
import sys
from collections import Counter
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk

In [3]:
def load_data(file):
    """Read a tab-separated parallel corpus into tokenized English/Chinese sentences.

    Each non-blank line is expected to look like ``english<TAB>chinese``; any further
    tab-separated fields are ignored. English text is lower-cased and tokenized with
    ``nltk.word_tokenize``; Chinese text is split into single characters. Every
    sentence is wrapped in 'BOS' ... 'EOS' markers.

    Args:
        file: path of the UTF-8 encoded corpus file.

    Returns:
        (en, cn): two parallel lists of token lists.

    Raises:
        ValueError: if a non-blank line has no tab-separated Chinese field.
    """
    cn = []
    en = []
    with open(file, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            stripped = raw_line.strip()
            # Blank lines (e.g. a trailing newline at EOF) carry no sentence pair;
            # without this check fields[1] below raises IndexError.
            if not stripped:
                continue
            fields = stripped.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    f"{file}:{line_no}: expected 'english<TAB>chinese', got {raw_line!r}")
            en.append(['BOS'] + nltk.word_tokenize(fields[0].lower()) + ['EOS'])
            # Chinese is modelled at character level: one token per character.
            cn.append(['BOS'] + list(fields[1]) + ['EOS'])
    return en, cn

# Parallel corpora, one "english<TAB>chinese" pair per line (see load_data).
train_file, dev_file, test_file = 'nmt/train.txt', 'nmt/dev.txt', 'nmt/test.txt'

train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)
test_en, test_cn = load_data(test_file)

In [4]:
# Peek at the first five tokenized sentence pairs (English above, Chinese below).
print(train_en[:5], '------------------------------------------------------------------', train_cn[:5], sep='\n')

[['BOS', 'anyone', 'can', 'do', 'that', '.', 'EOS'], ['BOS', 'how', 'about', 'another', 'piece', 'of', 'cake', '?', 'EOS'], ['BOS', 'she', 'married', 'him', '.', 'EOS'], ['BOS', 'i', 'do', "n't", 'like', 'learning', 'irregular', 'verbs', '.', 'EOS'], ['BOS', 'it', "'s", 'a', 'whole', 'new', 'ball', 'game', 'for', 'me', '.', 'EOS']]
------------------------------------------------------------------
[['BOS', '任', '何', '人', '都', '可', '以', '做', '到', '。', 'EOS'], ['BOS', '要', '不', '要', '再', '來', '一', '塊', '蛋', '糕', '？', 'EOS'], ['BOS', '她', '嫁', '给', '了', '他', '。', 'EOS'], ['BOS', '我', '不', '喜', '欢', '学', '习', '不', '规', '则', '动', '词', '。', 'EOS'], ['BOS', '這', '對', '我', '來', '說', '是', '個', '全', '新', '的', '球', '類', '遊', '戲', '。', 'EOS']]


In [5]:
# Reserved ids: out-of-vocabulary tokens map to unk_id, padding to pad_id.
unk_id = 0
pad_id = 1

def build_dict(sentences, max_words = 50000, verbose = True):
    """Build a token -> id vocabulary from tokenized sentences.

    The ``max_words`` most frequent tokens get ids 2, 3, ... in descending frequency
    order (ties keep first-seen order). The reserved tokens 'UNK' and 'PAD' get
    ``unk_id`` and ``pad_id``.

    Args:
        sentences: iterable of token lists.
        max_words: maximum number of regular (non-reserved) tokens to keep.
        verbose: if True, print the vocabulary size (the notebook's original behaviour).

    Returns:
        (word_dict, total_words): total_words is the number of kept tokens plus the
        2 reserved ones, so ids span 0 .. total_words - 1.
    """
    # A literal 'UNK'/'PAD' token in the corpus would otherwise be given a regular id
    # that is overwritten below, leaving a hole in the id range and an inflated total.
    reserved = ('UNK', 'PAD')
    word_count = Counter(w for sentence in sentences for w in sentence if w not in reserved)
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    if verbose:
        print(total_words)

    # Ids 0 and 1 are taken by UNK/PAD, so regular tokens start at 2.
    word_dict = {w: index + 2 for index, (w, _) in enumerate(ls)}
    word_dict['UNK'] = unk_id
    word_dict['PAD'] = pad_id

    return word_dict, total_words

# Vocabularies come from the training split only; this must run while train_en /
# train_cn still hold string tokens (the encode cell later rebinds them to ids).
en_dict_id, en_total_words = build_dict(train_en)
cn_dict_id, cn_total_words = build_dict(train_cn)

# Reverse lookups (id -> token) used to decode id sequences back into text.
id_dict_en = dict(zip(en_dict_id.values(), en_dict_id.keys()))
id_dict_cn = dict(zip(cn_dict_id.values(), cn_dict_id.keys()))

5493
3195


In [6]:
# Inspect the vocabularies: their sizes, the first 10 entries (most frequent tokens)
# and the last 10 entries ('UNK' and 'PAD' were inserted last, so they show at the tail).
def _print_head_tail(mapping, k=10):
    """Print the first k and the last k items of ``mapping``, separated by a blank line."""
    items = list(mapping.items())
    print(items[:k])
    print()
    print(items[-k:])


separator = "---" * 20

print(en_total_words)
_print_head_tail(en_dict_id)
print(separator)
print(cn_total_words)
_print_head_tail(cn_dict_id)  # Chinese vocabulary
print(separator)
_print_head_tail(id_dict_en)  # reversed mapping: id -> token
print(separator)
_print_head_tail(id_dict_cn)  # reversed mapping: id -> token

5493
[('BOS', 2), ('EOS', 3), ('.', 4), ('i', 5), ('the', 6), ('to', 7), ('you', 8), ('a', 9), ('is', 10), ('?', 11)]

[('opposition', 5485), ('springs', 5486), ('schoolroom', 5487), ('absence', 5488), ('fonder', 5489), ('field', 5490), ('educational', 5491), ('foster', 5492), ('UNK', 0), ('PAD', 1)]
------------------------------------------------------------
3195
[('BOS', 2), ('EOS', 3), ('。', 4), ('我', 5), ('的', 6), ('了', 7), ('你', 8), ('他', 9), ('是', 10), ('一', 11)]

[('鷹', 3187), ('鸚', 3188), ('鵡', 3189), ('寵', 3190), ('鳴', 3191), ('缓', 3192), ('黨', 3193), ('釘', 3194), ('UNK', 0), ('PAD', 1)]
------------------------------------------------------------
[(2, 'BOS'), (3, 'EOS'), (4, '.'), (5, 'i'), (6, 'the'), (7, 'to'), (8, 'you'), (9, 'a'), (10, 'is'), (11, '?')]

[(5485, 'opposition'), (5486, 'springs'), (5487, 'schoolroom'), (5488, 'absence'), (5489, 'fonder'), (5490, 'field'), (5491, 'educational'), (5492, 'foster'), (0, 'UNK'), (1, 'PAD')]
-------------------------------------

In [7]:
# Convert every token into its integer id.
def encode(en_sentences, cn_sentences, en_dict_id, cn_dict_id, sort_by_len = True, unk_index = 0):
    """Map tokenized sentence pairs to id sequences, optionally sorted by source length.

    Args:
        en_sentences: list of English token lists.
        cn_sentences: list of Chinese token lists, parallel to ``en_sentences``.
        en_dict_id: English token -> id vocabulary.
        cn_dict_id: Chinese token -> id vocabulary.
        sort_by_len: if True, reorder the pairs by increasing English length
            (stable sort). Contiguous minibatches then contain similarly sized
            sentences, which reduces padding.
        unk_index: id used for tokens missing from a vocabulary (0, i.e. unk_id,
            by default).

    Returns:
        (out_en_sentences, out_cn_sentences): parallel lists of id lists.

    Raises:
        ValueError: if the two sentence lists have different lengths.
    """
    if len(en_sentences) != len(cn_sentences):
        raise ValueError(
            f"parallel corpora differ in size: {len(en_sentences)} English vs "
            f"{len(cn_sentences)} Chinese sentences")

    out_en_sentences = [[en_dict_id.get(w, unk_index) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict_id.get(w, unk_index) for w in sent] for sent in cn_sentences]

    if sort_by_len:
        # Indices of the pairs ordered from shortest to longest English sentence.
        sorted_index = sorted(range(len(out_en_sentences)), key=lambda i: len(out_en_sentences[i]))
        # Apply the same permutation to both sides so the pairs stay aligned.
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

# encode() rebinds train_* / dev_* from token lists to id lists. Re-running this cell
# on already-encoded data would look ints up in the token -> id dicts (which have
# string keys) and turn every id into UNK (0). Only encode while the sentences still
# hold string tokens, so the cell is safe to re-run.
def _still_tokens(sentences):
    """Return True when ``sentences`` is non-empty and its first sentence starts with a str token."""
    return bool(sentences) and bool(sentences[0]) and isinstance(sentences[0][0], str)


if _still_tokens(train_en):
    train_en, train_cn = encode(train_en, train_cn, en_dict_id, cn_dict_id)
if _still_tokens(dev_en):
    dev_en, dev_cn = encode(dev_en, dev_cn, en_dict_id, cn_dict_id)

In [8]:
# Show the first encoded sentence pairs, then decode pair #100 back into tokens to
# check that the token <-> id round trip is consistent.
print(train_en[:10])
print('=====================================')
print(train_cn[:10])
print('====================================')
sample_idx = 100
sample_en_tokens = [id_dict_en[i] for i in train_en[sample_idx]]
print(sample_en_tokens)
sample_cn_tokens = [id_dict_cn[i] for i in train_cn[sample_idx]]
print(sample_cn_tokens)
print(' '.join(sample_en_tokens))
print(' '.join(sample_cn_tokens))

[[2, 475, 4, 3], [2, 1318, 126, 3], [2, 1707, 126, 3], [2, 254, 126, 3], [2, 1318, 126, 3], [2, 130, 11, 3], [2, 2045, 126, 3], [2, 693, 126, 3], [2, 2266, 126, 3], [2, 1707, 126, 3]]
[[2, 8, 87, 441, 6, 4, 3], [2, 119, 1368, 221, 3], [2, 982, 2028, 8, 4, 3], [2, 239, 239, 221, 3], [2, 151, 190, 221, 3], [2, 8, 546, 162, 14, 3], [2, 141, 488, 6, 221, 3], [2, 18, 489, 221, 3], [2, 189, 158, 221, 3], [2, 2110, 60, 221, 3]]
['BOS', 'get', 'out', '!', 'EOS']
['BOS', '滾', '出', '去', '！', 'EOS']
BOS get out ! EOS
BOS 滾 出 去 ！ EOS


In [9]:
# n: dataset size; minibatch_size: number of examples per batch.
def get_minibatches(n, minibatch_size, shuffle = True):
    """Partition the indices 0..n-1 into contiguous chunks of ``minibatch_size``.

    Only the order of the chunks is shuffled (using the global ``np.random`` state);
    the indices inside each chunk stay consecutive. The last chunk may be shorter.

    Returns:
        list of 1-D integer numpy arrays, one per minibatch.
    """
    batch_starts = np.arange(0, n, minibatch_size)
    if shuffle:
        np.random.shuffle(batch_starts)
    return [np.arange(start, min(start + minibatch_size, n)) for start in batch_starts]

In [10]:
# Demo: 100 examples in chunks of 15 -> six full chunks plus a final chunk of 10, in shuffled order.
get_minibatches(n=100, minibatch_size=15)

[array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]),
 array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]),
 array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]),
 array([45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59])]

In [13]:
def pad_data(seqs, pad_value=0):
    """Right-pad variable-length id sequences into a 2-D int32 matrix.

    Args:
        seqs: list of token-id sequences (e.g. lists of ints).
        pad_value: id written into the padded positions. Defaults to 0 to keep the
            historical behaviour. Note that in this notebook 0 is ``unk_id`` while
            ``pad_id`` is 1, so pass ``pad_value=pad_id`` when padded positions must
            be distinguishable from unknown tokens.

    Returns:
        (x, x_lengths): ``x`` has shape (len(seqs), longest sequence length), with each
        row holding one sequence followed by ``pad_value``. ``x_lengths`` is an int32
        array of the original sequence lengths.
    """
    lengths = [len(seq) for seq in seqs]
    if not lengths:
        # np.max([]) would raise; an empty batch yields empty arrays instead.
        return np.zeros((0, 0), dtype='int32'), np.zeros(0, dtype='int32')
    max_len = max(lengths)
    x = np.full((len(seqs), max_len), pad_value, dtype='int32')
    for i, seq in enumerate(seqs):
        # e.g. [2, 3, 7] padded to length 5 -> [2, 3, 7, pad_value, pad_value]
        x[i, :lengths[i]] = seq
    x_lengths = np.array(lengths, dtype='int32')
    return x, x_lengths
        
        
        
def gen_examples(en_sentences, cn_sentences, batch_size):
    """Group parallel id sentences into padded minibatches.

    Batches are the index chunks produced by ``get_minibatches`` (so their order is
    shuffled), and each side is padded with ``pad_data``.

    Returns:
        list of (en_ids, en_lengths, cn_ids, cn_lengths) tuples, one per batch.
    """
    batches = []
    for index_batch in get_minibatches(len(en_sentences), batch_size):
        en_batch = [en_sentences[k] for k in index_batch]
        cn_batch = [cn_sentences[k] for k in index_batch]
        en_padded, en_lengths = pad_data(en_batch)
        cn_padded, cn_lengths = pad_data(cn_batch)
        batches.append((en_padded, en_lengths, cn_padded, cn_lengths))
    return batches

# Batching is random: get_minibatches shuffles chunk order with np.random and the
# batch list is shuffled again with the stdlib random module. Seed both (plus torch,
# which the model cells draw from for weight init and dropout) so that Restart & Run
# All reproduces the same batches.
SEED = 1234
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)

In [20]:
# Shapes of the first batch's components: (en ids, en lengths, cn ids, cn lengths).
first_batch = train_data[0]
for component_idx in range(4):
    print(first_batch[component_idx].shape)
print(first_batch)

(64, 9)
(64,)
(64, 16)
(64,)
(array([[   2,    5,   56,   73,    8,  150,  311,    4,    3],
       [   2,   32,  261,   10,  213,  368, 1747,    4,    3],
       [   2,   12,   93,   35,  365,    9,  925,    4,    3],
       [   2,   77,   10,  698,   15,  125, 1048,    4,    3],
       [   2,  643,   10,   66,   26,   32,  477,    4,    3],
       [   2,   51, 1400,  515,   62,   21, 4376,    4,    3],
       [   2,   31,    5,   42,  608,   10,  357,    4,    3],
       [   2,  505,   10, 1900,   33,   25,  492,    4,    3],
       [   2,   18,  542,   23, 2553,   21,  261,    4,    3],
       [   2,   29,   84,   33,    6,  200,   44,    4,    3],
       [   2,   29,  109,  124,  812,   18,   57,    4,    3],
       [   2,  709,    8,   67,  111,  154,   57,    4,    3],
       [   2,   19, 1378,    7,   22,    9, 1140,    4,    3],
       [   2,    5,   79,  279,  274,   21,  314,    4,    3],
       [   2,   80,   14,    8,  113,   16,   65,   11,    3],
       [   2,   16,   27,

In [24]:
class PlainEncoder(nn.Module):
    """Single-layer GRU encoder for padded batches of token ids.

    Embeds the ids, runs the GRU over a packed (length-aware) sequence and returns
    the per-step outputs and the final hidden state, both in the caller's batch order.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # Named ``rnn`` because forward() calls self.rnn (same naming as PlainDecoder).
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Encode a batch of source sentences.

        Args:
            x: tensor of token ids indexed batch-first, presumably (batch, seq_len)
                padded to max(lengths) — TODO confirm against the caller.
            lengths: 1-D tensor with the true length of each sentence in ``x``.

        Returns:
            out: GRU outputs of shape (batch, max(lengths), hidden_size), zero at
                padded steps.
            hid: final hidden state of the last layer, shape (1, batch, hidden_size).
        """
        # The last real hidden state of each sentence is needed, and sentences have
        # different lengths. pack_padded_sequence (enforce_sorted=True by default)
        # requires the batch sorted by decreasing length, so sort here and undo it later.
        sorted_len, sorted_id = lengths.sort(0, descending=True)
        x_sorted = x[sorted_id.long()]
        embedded = self.dropout(self.embed(x_sorted))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        # Invert the sort permutation so rows line up with the caller's batch order.
        _, original_id = sorted_id.sort(0, descending=False)
        out = out[original_id.long()].contiguous()
        hid = hid[:, original_id.long()].contiguous()

        # hid[[-1]] keeps the leading layer dimension: (1, batch, hidden_size).
        return out, hid[[-1]]

In [25]:
class PlainDecoder(nn.Module):
    """Single-layer GRU decoder producing per-step log-probabilities over the target vocabulary."""

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):
        """Run the decoder over (teacher-forced) target ids.

        Args:
            y: tensor of target token ids indexed batch-first, presumably
                (batch, seq_len) padded to max(y_lengths) — TODO confirm against the caller.
            y_lengths: 1-D tensor with the true length of each target sentence.
            hid: initial hidden state, shape (1, batch, hidden_size), in the same
                batch order as ``y``.

        Returns:
            output: log-probabilities, shape (batch, max(y_lengths), vocab_size).
            hid: final hidden state, shape (1, batch, hidden_size), in the caller's order.
        """
        # Packing requires decreasing lengths; permute y and the initial state together.
        sorted_len, sorted_id = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_id.long()]
        hid = hid[:, sorted_id.long()]
        y_sorted = self.dropout(self.embed(y_sorted))

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        # Undo the sort so outputs and hidden states match the caller's batch order.
        _, original_id = sorted_id.sort(0, descending=False)
        output_seq = unpacked[original_id.long()].contiguous()
        hid = hid[:, original_id.long()].contiguous()

        output = F.log_softmax(self.out(output_seq), -1)
        return output, hid