In [1]:
import math
import os
import random
import sys
from collections import Counter

import jieba
import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def load_data(file):
    """Read a tab-separated English/Chinese parallel corpus.

    Each non-blank line is expected to look like ``english<TAB>chinese``.
    The English side is lower-cased and tokenized with ``nltk.word_tokenize``;
    the Chinese side is segmented with ``jieba.cut``. Both token lists are
    wrapped in "BOS" / "EOS" markers.

    Args:
        file: path of the corpus file, read as UTF-8.

    Returns:
        (en, cn): two parallel lists of token lists.
    """
    en = []
    cn = []
    # Explicit UTF-8: the platform default encoding may not decode Chinese text.
    with open(file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sentence pair.
                continue
            fields = line.split("\t")
            en.append(["BOS"] + nltk.word_tokenize(fields[0].lower()) + ["EOS"])
            # jieba.cut returns a generator; materialise it as a list of tokens.
            cn.append(["BOS"] + list(jieba.cut(fields[1])) + ["EOS"])
    return en, cn

In [6]:
def build_dict(sentences, max_words=50000):
    """Build a token -> index vocabulary from tokenized sentences.

    The ``max_words`` most frequent tokens receive indices 2, 3, ... in
    descending frequency order; index 0 is reserved for "UNK" and index 1
    for "PAD".

    Returns:
        (word_dict, total_words) where ``total_words`` is the number of kept
        tokens plus the two reserved entries.
    """
    unk_idx, pad_idx = 0, 1
    counts = Counter(token for sentence in sentences for token in sentence)
    most_frequent = counts.most_common(max_words)

    word_dict = {}
    for position, (token, _freq) in enumerate(most_frequent, start=2):
        word_dict[token] = position
    word_dict["UNK"] = unk_idx
    word_dict["PAD"] = pad_idx

    return word_dict, len(most_frequent) + 2

In [13]:
def encode(en_sens, cn_sens, en_dict, cn_dict, sort_by_len=True):
    """Map token sentences to vocabulary indices.

    Tokens missing from a vocabulary become 0 (the "UNK" index). When
    ``sort_by_len`` is true, both corpora are reordered together by
    ascending length of the English sentence (stable sort), keeping pairs
    aligned.

    Returns:
        (encoded_en, encoded_cn) as lists of index lists.
    """
    def to_ids(sentences, vocab):
        return [[vocab.get(token, 0) for token in sentence] for sentence in sentences]

    en_ids = to_ids(en_sens, en_dict)
    cn_ids = to_ids(cn_sens, cn_dict)

    if not sort_by_len:
        return en_ids, cn_ids

    order = sorted(range(len(en_ids)), key=lambda i: len(en_ids[i]))
    return [en_ids[i] for i in order], [cn_ids[i] for i in order]

In [34]:
def get_mini_batches(n, sz, shuffle=True):
    """Split ``range(n)`` into consecutive index batches of size ``sz``.

    The final batch holds the remainder when ``n`` is not a multiple of
    ``sz``. When ``shuffle`` is true, the order of the batches (not their
    contents) is shuffled in place with NumPy's global RNG.

    Returns:
        list of 1-D ``np.ndarray`` index arrays.
    """
    batches = []
    for start in range(0, n, sz):
        stop = min(start + sz, n)
        batches.append(np.arange(start, stop))
    if shuffle:
        np.random.shuffle(batches)
    return batches

In [68]:
def prepare_data(seqs, pad_idx=0):
    """Pad a batch of index sequences into a dense int32 matrix.

    Args:
        seqs: list of index sequences, possibly of different lengths.
        pad_idx: value written into padded positions. The default of 0 keeps
            the historical behaviour. Note that ``build_dict`` reserves 0 for
            "UNK" and 1 for "PAD", so pass ``pad_idx=1`` to pad with the PAD id.

    Returns:
        (x, x_lengths):
            ``x`` is an int32 array of shape ``(len(seqs), max_len)``, with row
            ``i`` holding ``seqs[i]`` followed by ``pad_idx``.
            ``x_lengths`` is an int32 array with the original lengths.
        An empty ``seqs`` yields a ``(0, 0)`` matrix and an empty length array.
    """
    lengths = [len(seq) for seq in seqs]
    x_lengths = np.array(lengths, dtype=np.int32)
    # default=0 avoids np.max raising on an empty batch.
    max_len = max(lengths, default=0)
    # Allocate directly as int32 instead of creating a float array and casting.
    x = np.full((len(lengths), max_len), pad_idx, dtype=np.int32)
    for i, seq in enumerate(seqs):
        x[i, :lengths[i]] = seq
    return x, x_lengths

In [65]:
def gen_examples(en_sens, cn_sens, minibatch_size, shuffle=True):
    """Group parallel index sentences into padded mini-batches.

    Args:
        en_sens: list of English index sequences.
        cn_sens: list of Chinese index sequences, aligned with ``en_sens``.
        minibatch_size: sentence pairs per batch (the last batch may be smaller).
        shuffle: forwarded to ``get_mini_batches``; when true the batch order is
            randomised. Pass False for a deterministic order (e.g. a dev set).

    Returns:
        list of ``(mb_x, mb_x_len, mb_y, mb_y_len)`` tuples, where each pair of
        entries is what ``prepare_data`` returns for that batch's sentences.

    Raises:
        ValueError: if the two corpora do not have the same number of sentences.
    """
    if len(en_sens) != len(cn_sens):
        raise ValueError(
            "en_sens and cn_sens must be aligned, got {} and {} sentences".format(
                len(en_sens), len(cn_sens)
            )
        )
    all_ex = []
    for minibatch in get_mini_batches(len(en_sens), minibatch_size, shuffle=shuffle):
        mb_x, mb_x_len = prepare_data([en_sens[t] for t in minibatch])
        mb_y, mb_y_len = prepare_data([cn_sens[t] for t in minibatch])
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex

In [3]:
# Each file holds one "English<TAB>Chinese" sentence pair per line.
train_file, dev_file = (
    "data/nmt/en-cn/train.txt",
    "data/nmt/en-cn/dev.txt",
)

# Tokenized (English, Chinese) sentences for each split.
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.556 seconds.
Prefix dict has been built successfully.


In [7]:
# Vocabularies are built from the training split only.
en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)

# Reverse maps (index -> token) for turning id sequences back into text.
inv_en_dict = dict(zip(en_dict.values(), en_dict.keys()))
inv_cn_dict = dict(zip(cn_dict.values(), cn_dict.keys()))

In [14]:
# Replace tokens with vocabulary indices (unknown tokens -> 0, the "UNK" id)
# and sort each split by English sentence length.
# NOTE(review): this re-binds train_*/dev_* in place. Re-running the cell on the
# already-encoded lists silently maps every id to UNK, so re-run from the
# load_data cell instead.
train_en, train_cn = encode(train_en, train_cn, en_dict=en_dict, cn_dict=cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict=en_dict, cn_dict=cn_dict)

In [16]:
k = 10000
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))

BOS 他来 这里 的 目的 是 什么 ？ EOS
BOS for what purpose did he come here ? EOS


In [69]:
# Seed every RNG used from here on: np.random (batch shuffling inside
# gen_examples), random (the shuffle below) and torch (parameter init and
# torch.randn in later cells). This makes Restart & Run All reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
# gen_examples already shuffles the batch order via np.random; this second
# shuffle is redundant but harmless.
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)

### Seq2seq model without attention

In [84]:
class PlainEncoder(nn.Module):
    """GRU encoder over padded, variable-length batches of token ids.

    Args:
        vocab_size: number of rows in the embedding table.
        hidden_size: embedding size and GRU hidden size (the two are tied).
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Encode a padded batch.

        Args:
            x: integer token ids of shape ``(batch, seq_len)`` (batch first).
            lengths: 1-D tensor with the true (unpadded) length of each row;
                every length must be >= 1.

        Returns:
            out: ``(batch, max(lengths), hidden_size)`` GRU outputs in the
                caller's batch order. Positions past a row's length are zero.
            hid: ``(1, batch, hidden_size)`` final hidden state of the last GRU
                layer, in the caller's batch order.
        """
        # pack_padded_sequence (with its default enforce_sorted=True) requires
        # rows in decreasing length order.
        sorted_len, sorted_idx = lengths.sort(dim=0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        # Lengths are passed as a CPU int64 tensor, as pack_padded_sequence expects.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu(), batch_first=True
        )
        packed_out, hid = self.rnn(packed_embedded)
        # Convert the packed output back to a zero-padded tensor.
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Undo the length sort so results line up with the caller's batch order.
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]

In [None]:
class PlainDecoder(nn.Module):
    """GRU decoder that scores target tokens, optionally conditioned on a start state.

    Args:
        vocab_size: size of the target vocabulary (embedding rows and output classes).
        hidden_size: embedding size and GRU hidden size (the two are tied).
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, lengths, hid=None):
        """Run the decoder over a padded batch of target ids.

        Args:
            x: integer token ids of shape ``(batch, seq_len)`` (batch first).
            lengths: 1-D tensor with the true length of each row (each >= 1).
            hid: optional initial hidden state of shape ``(1, batch, hidden_size)``
                in the caller's batch order, e.g. the ``hid`` returned by
                ``PlainEncoder``. ``None`` starts from zeros.

        Returns:
            output: ``(batch, max(lengths), vocab_size)`` log-probabilities over
                the vocabulary, in the caller's batch order. Positions past a
                row's length hold the projection of zero padding and should be
                masked out of any loss.
            hid: ``(1, batch, hidden_size)`` final hidden state, in the caller's
                batch order.
        """
        # pack_padded_sequence (default enforce_sorted=True) needs decreasing lengths.
        sorted_len, sorted_idx = lengths.sort(dim=0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))
        if hid is not None:
            # The initial state must follow the same (length-sorted) order as the inputs.
            hid = hid[:, sorted_idx.long()].contiguous()

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu(), batch_first=True
        )
        packed_out, hid = self.rnn(packed_embedded, hid)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Undo the length sort so results line up with the caller's batch order.
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        # Project GRU states onto the target vocabulary.
        output = F.log_softmax(self.out(out), dim=-1)
        return output, hid[[-1]]

In [76]:
# Scratch cell: a random 3x4 toy tensor for checking how torch.sort treats `dim`.
# NOTE(review): this rebinds the notebook-global name `x`; consider removing these
# scratch cells from the final notebook.
x = torch.randn(3, 4)

In [81]:
# torch.sort with dim=0 sorts each column independently (ascending); `indices`
# gives the source row of every sorted value.
# NOTE(review): `orted` looks like a typo for a name such as `sorted_x` (plain
# `sorted` would shadow the builtin).
orted, indices = torch.sort(x, dim=0)

In [78]:
# Show the unsorted toy tensor for comparison with the sorted results.
x

tensor([[ 0.6728, -1.8939, -0.3373,  0.9705],
        [ 0.9647,  0.0331,  0.1752,  2.6695],
        [ 0.9066, -0.3199,  0.5944,  0.9399]])

In [82]:
# Column-wise (dim=0) sorted values.
orted

tensor([[ 0.6728, -1.8939, -0.3373,  0.9399],
        [ 0.9066, -0.3199,  0.1752,  0.9705],
        [ 0.9647,  0.0331,  0.5944,  2.6695]])

In [79]:
# NOTE(review): stale output from out-of-order execution. This ran as In[79],
# before the current sort cell (In[81]). The rows shown are each row of `x`
# sorted, i.e. a sort along the last dim, not dim=0.
orted

tensor([[-1.8939, -0.3373,  0.6728,  0.9705],
        [ 0.0331,  0.1752,  0.9647,  2.6695],
        [-0.3199,  0.5944,  0.9066,  0.9399]])

In [80]:
# NOTE(review): stale output (In[80], executed before In[81]). These indices match
# a per-row sort along the last dim, not the dim=0 call in the sort cell.
indices

tensor([[1, 2, 0, 3],
        [1, 2, 0, 3],
        [1, 2, 0, 3]])

In [83]:
# Source-row indices for the column-wise (dim=0) sort.
indices

tensor([[0, 0, 0, 2],
        [2, 2, 1, 0],
        [1, 1, 2, 1]])