# In [37]:
import jieba
import zhconv
import tensorflow as tf
import re


class Vocab:
    """Bidirectional mapping between tokens and integer indices.

    Index 0 is always ``'<unk>'``; ``reserved_tokens`` follow in the order
    given; the remaining corpus tokens are appended in descending order of
    frequency (ties keep first-seen order), skipping tokens whose frequency
    is below ``min_freq`` and tokens that are already present.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """Build the vocabulary.

        Args:
            tokens: a flat list of tokens, or a list of token lists (one per
                sentence). ``None`` or an empty list yields a vocabulary that
                holds only ``'<unk>'`` and ``reserved_tokens``.
            min_freq: tokens occurring fewer times than this are left out
                (they will map to the ``'<unk>'`` index).
            reserved_tokens: special tokens placed right after ``'<unk>'``
                regardless of their frequency.
        """
        # None sentinels instead of mutable ``[]`` defaults.
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else reserved_tokens
        counter = Vocab.count_corpus(tokens)
        # Most frequent first; sorted() is stable, so ties keep the order in
        # which the Counter first saw each token.
        self.__token_freqs = sorted(counter.items(),
                                    key=lambda x: x[1],
                                    reverse=True)
        # list(...) so a tuple of reserved tokens is accepted too.
        self.index_to_token = ['<unk>'] + list(reserved_tokens)
        self.token_to_index = {
            token: idx
            for idx, token in enumerate(self.index_to_token)
        }
        for token, freq in self.__token_freqs:
            if freq >= min_freq and token not in self.token_to_index:
                self.index_to_token.append(token)
                self.token_to_index[token] = len(self.index_to_token) - 1

    def __len__(self):
        """Return the number of entries, including ``'<unk>'`` and reserved tokens."""
        return len(self.index_to_token)

    def get_tokens(self, indicates):
        """Map an index, or a list/tuple of indices, back to text.

        A single index returns its token; a list/tuple is converted element by
        element (recursively) and concatenated with no separator.
        """
        if not isinstance(indicates, (list, tuple)):
            return self.index_to_token[indicates]
        return ''.join([self.get_tokens(index) for index in indicates])

    def __getitem__(self, tokens):
        """Map a token, or a list/tuple of tokens, to index/indices.

        Unknown tokens map to ``self.unk``.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_index.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    @property
    def unk(self):
        """Index of the ``'<unk>'`` token (always 0)."""
        return 0

    @property
    def token_freqs(self):
        """``(token, count)`` pairs sorted by descending count."""
        return self.__token_freqs

    @staticmethod
    def count_corpus(tokens):
        """Return a ``Counter`` of token frequencies.

        If the first element of ``tokens`` is a list, ``tokens`` is treated as
        a list of token lists and flattened first. An empty input yields an
        empty ``Counter``.
        """
        from collections import Counter
        # len() check guards the tokens[0] lookup on empty input.
        if len(tokens) > 0 and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        return Counter(tokens)


def truncate_and_pad(line, steps, padding_token):
    """Truncate or pad a token list so that every sequence has the same length.

    Args:
        line: list of tokens.
        steps: target length.
        padding_token: token appended to sequences shorter than ``steps``.

    Returns:
        A new list: the first ``steps`` tokens of ``line`` when it is longer
        than ``steps``, otherwise ``line`` followed by ``padding_token``
        repeated ``steps - len(line)`` times. ``line`` itself is not modified.
    """
    if len(line) > steps:
        return line[:steps]
    return line + [padding_token] * (steps - len(line))


def _tokenize_en(text):
    """Lowercase ``text``, keep only ASCII letters, and split it into words."""
    # split() (not split(' ')) so text without letters yields [] rather than [''].
    return re.sub(r'[^A-Za-z]+', ' ', text).lower().split()


def _tokenize_zh(text):
    """Convert ``text`` to simplified Chinese and segment it with jieba."""
    # strip() so surrounding whitespace (e.g. the line break) is not emitted
    # by jieba as a token of its own.
    return list(jieba.cut(zhconv.convert(text.strip(), 'zh-cn'), cut_all=False))


def load_datasets(steps=20, batch_size=50, path='./en_zh.trans.txt', min_freq=2):
    """Preprocess the English-Chinese parallel corpus into a tf.data.Dataset.

    Each line of ``path`` holds an English sentence and a Chinese sentence
    separated by a tab; extra tab-separated fields are ignored and lines with
    fewer than two fields are skipped. English is reduced to lowercase ASCII
    words; Chinese is converted to simplified characters and segmented with
    jieba. Both sides get a trailing ``'<eos>'`` and are truncated/padded with
    ``'<pad>'`` to exactly ``steps`` tokens.

    Args:
        steps: fixed sequence length after truncation/padding.
        batch_size: number of examples per batch.
        path: UTF-8 encoded corpus file.
        min_freq: minimum token frequency for inclusion in either vocabulary.

    Returns:
        ``(train_iter, en_vocab, zh_vocab)``. ``train_iter`` yields shuffled
        batches of ``(en, en_len, zh, zh_len)``: index matrices with ``steps``
        columns (up to ``batch_size`` rows) and, per row, the number of
        positions that are not ``'<pad>'``.

    Raises:
        ValueError: if the file contains no usable sentence pairs.
    """
    reserved = ['<pad>', '<bos>', '<eos>']
    en, zh = [], []
    # Explicit encoding: the corpus contains Chinese text, so the platform
    # default codec must not be relied on.
    with open(path, 'r', encoding='utf-8') as data_file:
        for line in data_file:
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) < 2:
                continue  # blank or malformed line
            en.append(truncate_and_pad(_tokenize_en(fields[0]) + ['<eos>'],
                                       steps, '<pad>'))
            zh.append(truncate_and_pad(_tokenize_zh(fields[1]) + ['<eos>'],
                                       steps, '<pad>'))
    if not en:
        raise ValueError(f'no tab-separated sentence pairs found in {path!r}')

    en_vocab = Vocab(en, min_freq=min_freq, reserved_tokens=reserved)
    zh_vocab = Vocab(zh, min_freq=min_freq, reserved_tokens=reserved)
    # NOTE(review): indices are kept as float32 for compatibility with
    # existing callers; an integer dtype would be more natural for embeddings.
    en = tf.constant([en_vocab[line] for line in en], dtype='float32')
    zh = tf.constant([zh_vocab[line] for line in zh], dtype='float32')
    # Valid length = count of non-'<pad>' positions in each row.
    en_len = tf.reduce_sum(tf.cast(en != en_vocab['<pad>'], dtype='float32'), axis=1)
    zh_len = tf.reduce_sum(tf.cast(zh != zh_vocab['<pad>'], dtype='float32'), axis=1)
    dataset = tf.data.Dataset.from_tensor_slices((en, en_len, zh, zh_len))
    train_iter = dataset.shuffle(buffer_size=len(en)).batch(batch_size=batch_size)
    return train_iter, en_vocab, zh_vocab

# Build the shuffled training batches and both vocabularies when this cell runs.
train_iter, en_vocab, zh_vocab = load_datasets(steps=20, batch_size=50)

# In [None]:
class Seq2SeqEncoder(tf.keras.layers.Layer):
    """RNN encoder for sequence-to-sequence learning.

    Embeds token indices and runs them through ``layers`` stacked GRU cells,
    returning the per-step outputs and the final states.
    """

    def __init__(self, vocab_size, embed_size, hiddens, layers, dropout=0., **kwargs):
        """
        Args:
            vocab_size: number of token indices accepted by the embedding.
            embed_size: embedding dimension.
            hiddens: number of units in each GRU cell.
            layers: number of stacked GRU cells.
            dropout: input dropout rate passed to every GRUCell.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_size)
        # StackedRNNCells is only a cell container; the RNN wrapper iterates it
        # over time and is what accepts return_sequences / return_state.
        self.rnn_net = tf.keras.layers.RNN(
            tf.keras.layers.StackedRNNCells([
                tf.keras.layers.GRUCell(hiddens, dropout=dropout) for _ in range(layers)
            ]),
            return_sequences=True,
            return_state=True)

    def call(self, x, *args, **kwargs):
        """Encode a batch of token indices.

        Extra positional/keyword arguments are forwarded to the RNN layer.

        Returns:
            ``(outputs, state)``: the RNN's per-time-step outputs and the list
            of remaining values it returns (the final states of the stacked
            cells).
        """
        x = self.embedding(x)
        y = self.rnn_net(x, *args, **kwargs)
        # With return_state=True the RNN returns [outputs, *final_states].
        state = y[1:]
        return y[0], state

class Seq2SeqDecoder(tf.keras.layers.Layer):
    """RNN decoder for sequence-to-sequence learning.

    At every step the embedded input token is concatenated with the last
    entry of the incoming state (used as a context vector), fed through
    ``layers`` stacked GRU cells, and projected to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, embed_size, hiddens, layers, dropout=0., **kwargs):
        """
        Args:
            vocab_size: size of the target vocabulary (embedding input and
                output projection width).
            embed_size: embedding dimension.
            hiddens: number of units in each GRU cell.
            layers: number of stacked GRU cells.
            dropout: input dropout rate passed to every GRUCell.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_size)
        # StackedRNNCells must be wrapped in an RNN layer to run over time.
        self.rnn_net = tf.keras.layers.RNN(
            tf.keras.layers.StackedRNNCells([
                tf.keras.layers.GRUCell(hiddens, dropout=dropout) for _ in range(layers)
            ]),
            return_sequences=True,
            return_state=True)
        # Not ``self.output``: that name is a read-only property of keras Layer.
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, x, state, **kwargs):
        """Decode a batch of token indices.

        Args:
            x: token indices; assumed (batch, time) — the embedded result is
                concatenated with a per-step context along axis 2.
            state: list of per-cell states used as the RNN's initial state;
                its last entry is broadcast over time as the context vector.
            **kwargs: forwarded to the RNN layer.

        Returns:
            ``(logits, state)``: the dense projection of every time step and
            the RNN's final states.
        """
        x = self.embedding(x)
        # tf.shape keeps this valid when the time dimension is not static.
        context = tf.repeat(tf.expand_dims(state[-1], axis=1),
                            repeats=tf.shape(x)[1], axis=1)
        x_ctx = tf.concat((x, context), axis=2)
        # Feed the context-augmented input; state seeds the stacked cells.
        y = self.rnn_net(x_ctx, initial_state=state, **kwargs)
        state = y[1:]
        return self.dense(y[0]), state