In [1]:
from bunch import Bunch
from collections import Counter

import json
import numpy as np
import tensorflow as tf

In [2]:
# Global hyper-parameters shared by every cell below.
# FIX: the 'num_heads' key was listed twice in the literal (the second entry
# silently overwrote the first), and 'warmup_steps' -- read by _get_noam_lr()
# when lr_decay_strategy == 'noam' -- was missing entirely.
args = Bunch({
    'source_max_len': 10,          # source sequences are truncated/padded to this length
    'target_max_len': 10,          # target sequences are truncated/padded to this length
    'min_freq': 50,                # characters must occur MORE than this often to enter the vocab
    'hidden_units': 128,           # model width; split evenly across attention heads
    'num_blocks': 2,               # number of encoder blocks and of decoder blocks
    'num_heads': 8,                # attention heads per multi-head attention layer
    'dropout_rate': 0.1,
    'batch_size': 64,
    'position_encoding': 'param',  # 'param' (learned) or 'non_param' (sinusoidal)
    'activation': 'relu',          # 'relu', 'elu' or 'lrelu'
    'tied_proj_weight': True,      # reuse the embedding table as the output projection
    'tied_embedding': False,       # share one embedding table between encoder and decoder
    'label_smoothing': False,
    'lr_decay_strategy': 'exp',    # 'exp' or 'noam'
    'warmup_steps': 4000,          # only used by the 'noam' learning-rate schedule
})

In [3]:
class DataLoader:
    """Reads parallel source/target text files and turns them into id matrices.

    Each file is treated as newline-separated sequences of characters. A
    character vocabulary is built per side, and `load()` returns fixed-width
    integer arrays ready to be fed to the model.
    """

    def __init__(self, source_path, target_path):
        """Read both files and build one char->id mapping for each side."""
        self.source_words = self.read_data(source_path)
        self.target_words = self.read_data(target_path)

        self.source_word2idx = self.build_index(self.source_words)
        self.target_word2idx = self.build_index(self.target_words, is_target=True)

    def read_data(self, path):
        """Return the whole UTF-8 file at `path` as a single string."""
        with open(path, 'r', encoding='utf-8') as handle:
            text = handle.read()
        return text

    def build_index(self, data, is_target=False):
        """Build a char->id dict from `data`.

        Only characters whose count is strictly greater than `args.min_freq`
        are kept, in order of first appearance. Special symbols occupy the
        lowest ids; the target side (and the source side when embeddings are
        tied) additionally gets '<start>' and '<end>'.
        """
        # Newlines only separate sequences, they are never vocabulary entries.
        char_counts = Counter(ch for ch in data if ch != '\n')
        kept = [ch for ch, n in char_counts.items() if n > args.min_freq]

        if is_target or args.tied_embedding:
            specials = ['<pad>', '<start>', '<end>', '<unk>']
        else:
            specials = ['<pad>', '<unk>']
        return {token: position for position, token in enumerate(specials + kept)}

    def pad(self, data, word2idx, max_len, is_target=False):
        """Encode every line of `data` as exactly `max_len` ids.

        Unknown characters map to '<unk>'. Source lines are cut or right-padded
        with '<pad>'. Target lines always end with '<end>' (replacing the last
        kept character when the line is too long) before any '<pad>' filler.

        Returns:
          A numpy array with one row per line of `data`.
        """
        rows = []
        for line in data.split('\n'):
            ids = [word2idx.get(ch, word2idx['<unk>']) for ch in line]
            if len(ids) < max_len:
                n_missing = max_len - len(ids)
                if is_target:
                    # '<end>' takes one of the missing slots, '<pad>' the rest.
                    ids = ids + [word2idx['<end>']] + [word2idx['<pad>']] * (n_missing - 1)
                else:
                    ids = ids + [word2idx['<pad>']] * n_missing
            elif is_target:
                ids = ids[:(max_len - 1)] + [word2idx['<end>']]
            else:
                ids = ids[:max_len]
            rows.append(ids)
        return np.array(rows)

    def load(self):
        """Return `(source_ids, target_ids)` padded to the configured lengths."""
        source_idx = self.pad(self.source_words, self.source_word2idx, args.source_max_len)
        target_idx = self.pad(
            self.target_words, self.target_word2idx, args.target_max_len, is_target=True)
        return source_idx, target_idx


In [4]:
def layer_norm(inputs, epsilon=1e-8):
    """Normalize `inputs` over the last axis, then apply a learned scale and shift.

    Creates (or reuses, depending on the enclosing variable scope) two
    variables named 'gamma' and 'beta' whose shape is the size of the last
    axis of `inputs`.
    """
    feature_shape = inputs.get_shape()[-1:]
    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)

    gamma = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())

    # epsilon keeps the division finite when the variance is zero.
    standardized = (inputs - mean) / (tf.sqrt(variance + epsilon))
    return gamma * standardized + beta


def embed_seq(inputs, vocab_size=None, embed_dim=None, zero_pad=False, scale=False):
    """Look up embeddings for integer ids in `inputs`.

    Args:
      inputs: integer id tensor.
      vocab_size: number of rows of the 'lookup_table' variable.
      embed_dim: embedding width.
      zero_pad: if true, id 0 is looked up as an all-zero vector (row 0 of the
        variable is replaced by zeros for the lookup only).
      scale: if true, multiply the result by sqrt(embed_dim).
    """
    table = tf.get_variable('lookup_table', dtype=tf.float32, shape=[vocab_size, embed_dim])
    if zero_pad:
        zero_row = tf.zeros([1, embed_dim])
        table = tf.concat((zero_row, table[1:, :]), axis=0)

    embedded = tf.nn.embedding_lookup(table, inputs)
    return embedded * np.sqrt(embed_dim) if scale else embedded


def multihead_attn(queries, keys, q_masks, k_masks, num_units=None, num_heads=8,
        dropout_rate=args.dropout_rate, future_binding=False, reuse=False, activation=None):
    """Multi-head scaled dot-product attention with residual connection and layer norm.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q]
      keys: A 3d tensor with shape of [N, T_k, C_k]
      q_masks: [N, T_q] tensor, non-zero where the query position is a real token.
      k_masks: [N, T_k] tensor, non-zero where the key position is a real token.
      num_units: attention width C; defaults to the last dimension of `queries`.
        Must be divisible by `num_heads`, and must equal C_q because the
        output is added to `queries`.
      num_heads: number of attention heads.
      dropout_rate: dropout applied to the attention weights.
      future_binding: if true, a query at position t may only attend to keys
        at positions <= t (used for decoder self-attention).
      reuse: passed to the dense layers; dropout is only active when this is
        false, i.e. the flag doubles as "not training".
      activation: optional activation for the Q / K / V projections.

    Returns:
      A tensor of shape [N, T_q, C].
    """
    if num_units is None:
        # FIX: `as_list` is a method; the original subscripted the bound method
        # itself (`as_list[-1]`), which raises TypeError.
        num_units = queries.get_shape().as_list()[-1]
    T_q = queries.get_shape().as_list()[1]                                         # max time length of query
    T_k = keys.get_shape().as_list()[1]                                            # max time length of key

    Q = tf.layers.dense(queries, num_units, activation, reuse=reuse, name='Q')     # (N, T_q, C)
    # K and V share one projection of width 2*C which is then cut in half.
    K_V = tf.layers.dense(keys, 2*num_units, activation, reuse=reuse, name='K_V')
    K, V = tf.split(K_V, 2, -1)

    # Move the heads into the batch axis so one batched matmul serves all heads.
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)                         # (h*N, T_q, C/h)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)                         # (h*N, T_k, C/h)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)                         # (h*N, T_k, C/h)

    # Scaled Dot-Product
    align = tf.matmul(Q_, tf.transpose(K_, [0,2,1]))                               # (h*N, T_q, T_k)
    align = align / np.sqrt(K_.get_shape().as_list()[-1])                          # scale

    # Key Masking
    # FIX: use a large finite negative instead of -inf. exp() of it still
    # underflows to exactly 0 next to any real score, but a row in which every
    # key is masked now yields finite weights instead of NaN from softmax.
    paddings = tf.fill(tf.shape(align), -2.0 ** 32 + 1)                            # exp(-large) -> 0

    key_masks = k_masks                                                            # (N, T_k)
    key_masks = tf.tile(key_masks, [num_heads, 1])                                 # (h*N, T_k)
    key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, T_q, 1])                 # (h*N, T_q, T_k)
    align = tf.where(tf.equal(key_masks, 0), paddings, align)                      # (h*N, T_q, T_k)

    if future_binding:
        # Lower-triangular ones: entry (t, s) is kept only when s <= t.
        lower_tri = tf.ones([T_q, T_k])                                            # (T_q, T_k)
        lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()  # (T_q, T_k)
        masks = tf.tile(tf.expand_dims(lower_tri,0), [tf.shape(align)[0], 1, 1])   # (h*N, T_q, T_k)
        align = tf.where(tf.equal(masks, 0), paddings, align)                      # (h*N, T_q, T_k)

    # Softmax
    align = tf.nn.softmax(align)                                                   # (h*N, T_q, T_k)

    # Query Masking: zero the attention rows that belong to padded queries.
    query_masks = tf.to_float(q_masks)                                             # (N, T_q)
    query_masks = tf.tile(query_masks, [num_heads, 1])                             # (h*N, T_q)
    query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, T_k])            # (h*N, T_q, T_k)
    align *= query_masks                                                           # (h*N, T_q, T_k)

    align = tf.layers.dropout(align, dropout_rate, training=(not reuse))           # (h*N, T_q, T_k)

    # Weighted sum
    outputs = tf.matmul(align, V_)                                                 # (h*N, T_q, C/h)
    # Restore shape
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)              # (N, T_q, C)
    # Residual connection
    outputs += queries                                                             # (N, T_q, C)
    # Normalize
    outputs = layer_norm(outputs)                                                  # (N, T_q, C)
    return outputs


def pointwise_feedforward(inputs, num_units=[None, None], activation=None):
    """Position-wise feed-forward sub-layer with residual connection and layer norm.

    Args:
      inputs: 3d tensor; its last dimension must equal `num_units[1]` so the
        residual addition is valid.
      num_units: `[inner_width, output_width]` of the two kernel-size-1 convolutions.
      activation: activation of the inner convolution (the readout has none).
    """
    # Kernel size 1 means each time step is transformed independently.
    hidden = tf.layers.conv1d(inputs, num_units[0], kernel_size=1, activation=activation)
    readout = tf.layers.conv1d(hidden, num_units[1], kernel_size=1, activation=None)
    # Residual connection followed by normalization.
    return layer_norm(readout + inputs)


def learned_position_encoding(inputs, mask, embed_dim):
    """Return trainable position embeddings for `inputs`, zeroed where `mask` is 0.

    The embedding table has one row per time step, sized from the static
    length of axis 1 of `inputs`; the position ids themselves are built from
    the dynamic shape.
    """
    max_len = inputs.get_shape().as_list()[1]
    batch_size, seq_len = tf.shape(inputs)[0], tf.shape(inputs)[1]

    # Position ids 0..T-1 repeated for every example in the batch: (N, T_q)
    position_ids = tf.tile(tf.expand_dims(tf.range(seq_len), 0), [batch_size, 1])
    position_emb = embed_seq(position_ids, max_len, embed_dim, zero_pad=False, scale=False)

    return tf.expand_dims(tf.to_float(mask), -1) * position_emb


def sinusoidal_position_encoding(inputs, mask, num_units):
    """Return fixed sine/cosine position encodings for `inputs`, zeroed where `mask` is 0.

    The table is computed once in numpy: entry (pos, i) starts as
    pos / 10000^(2*i/num_units); even columns are then passed through sin and
    odd columns through cos.
    """
    max_len = inputs.get_shape().as_list()[1]
    batch_size = tf.shape(inputs)[0]
    position_ids = tf.tile(tf.expand_dims(tf.range(max_len), 0), [batch_size, 1])

    angle_table = np.array([
        [pos / np.power(10000, 2.*i/num_units) for i in range(num_units)]
        for pos in range(max_len)])
    angle_table[:, 0::2] = np.sin(angle_table[:, 0::2])  # even columns
    angle_table[:, 1::2] = np.cos(angle_table[:, 1::2])  # odd columns

    encodings = tf.nn.embedding_lookup(
        tf.convert_to_tensor(angle_table, tf.float32), position_ids)

    mask_3d = tf.expand_dims(tf.to_float(mask), -1)
    return mask_3d * encodings


def label_smoothing(inputs, epsilon=0.1):
    """Blend `inputs` with a uniform distribution over its last axis.

    Every entry is scaled by (1 - epsilon) and then epsilon / C is added,
    where C is the static size of the last axis.
    """
    num_classes = inputs.get_shape().as_list()[-1]
    uniform_share = epsilon / num_classes
    return ((1 - epsilon) * inputs) + uniform_share

In [5]:
def forward_pass(sources, targets, params, reuse=False):
    """Build the encoder-decoder graph and return the output logits.

    Args:
      sources: integer id tensor of source sequences; id 0 is treated as
        padding (masks are derived with tf.sign).
      targets: integer id tensor of target sequences; it is shifted right by
        one step and prefixed with the start symbol to form the decoder input.
      params: dict providing 'source_vocab_size', 'target_vocab_size',
        'start_symbol' and 'activation'.
      reuse: reuse the variables of a previous call. The flag also switches
        dropout off (`training=(not reuse)`), so reuse=True acts as
        inference mode.

    Returns:
      Logits whose last dimension is params['target_vocab_size'].
    """
    with tf.variable_scope('forward_pass', reuse=reuse):
        # Selects learned or sinusoidal position encoding from the global args.
        pos_enc = _get_position_encoder()

        # ENCODER
        # 1 for real tokens, 0 for padding (relies on pad id 0 and non-negative ids).
        en_masks = tf.sign(sources)   

        with tf.variable_scope('encoder_embedding', reuse=reuse):
            encoded = embed_seq(
                sources, params['source_vocab_size'], args.hidden_units, zero_pad=True, scale=True)
        
        with tf.variable_scope('encoder_position_encoding', reuse=reuse):
            encoded += pos_enc(sources, en_masks, args.hidden_units)
        
        with tf.variable_scope('encoder_dropout', reuse=reuse):
            encoded = tf.layers.dropout(encoded, args.dropout_rate, training=(not reuse))

        # Each encoder block: self-attention followed by a feed-forward sub-layer.
        for i in range(args.num_blocks):
            with tf.variable_scope('encoder_attn_%d'%i, reuse=reuse):
                encoded = multihead_attn(queries=encoded, keys=encoded, q_masks=en_masks, k_masks=en_masks,
                    num_units=args.hidden_units, num_heads=args.num_heads, dropout_rate=args.dropout_rate,
                    future_binding=False, reuse=reuse, activation=None)
            
            with tf.variable_scope('encoder_feedforward_%d'%i, reuse=reuse):
                encoded = pointwise_feedforward(encoded, num_units=[4*args.hidden_units, args.hidden_units],
                    activation=params['activation'])

        # DECODER
        # Teacher forcing input: <start> followed by the targets minus their last step.
        decoder_inputs = _shift_right(targets, params['start_symbol'])
        de_masks = tf.sign(decoder_inputs)
            
        if args.tied_embedding:
            # Re-opens the encoder's embedding scope so both sides share one table.
            # NOTE(review): the table is requested here with target_vocab_size
            # rows but was created with source_vocab_size rows -- presumably the
            # two sizes must match for this branch to work; confirm.
            with tf.variable_scope('encoder_embedding', reuse=True):
                decoded = embed_seq(decoder_inputs, params['target_vocab_size'], args.hidden_units,
                    zero_pad=True, scale=True)
        else:
            with tf.variable_scope('decoder_embedding', reuse=reuse):
                decoded = embed_seq(
                    decoder_inputs, params['target_vocab_size'], args.hidden_units, zero_pad=True, scale=True)
        
        with tf.variable_scope('decoder_position_encoding', reuse=reuse):
            decoded += pos_enc(decoder_inputs, de_masks, args.hidden_units)
                
        with tf.variable_scope('decoder_dropout', reuse=reuse):
            decoded = tf.layers.dropout(decoded, args.dropout_rate, training=(not reuse))

        # Each decoder block: causal self-attention, attention over the encoder
        # output, then a feed-forward sub-layer.
        for i in range(args.num_blocks):
            with tf.variable_scope('decoder_self_attn_%d'%i, reuse=reuse):
                decoded = multihead_attn(queries=decoded, keys=decoded, q_masks=de_masks, k_masks=de_masks,
                    num_units=args.hidden_units, num_heads=args.num_heads, dropout_rate=args.dropout_rate,
                    future_binding=True, reuse=reuse, activation=None)
            
            with tf.variable_scope('decoder_attn_%d'%i, reuse=reuse):
                decoded = multihead_attn(queries=decoded, keys=encoded, q_masks=de_masks, k_masks=en_masks,
                    num_units=args.hidden_units, num_heads=args.num_heads, dropout_rate=args.dropout_rate,
                    future_binding=False, reuse=reuse, activation=None)
            
            with tf.variable_scope('decoder_feedforward_%d'%i, reuse=reuse):
                decoded = pointwise_feedforward(decoded, num_units=[4*args.hidden_units, args.hidden_units],
                    activation=params['activation'])
        
        # OUTPUT LAYER    
        if args.tied_proj_weight:
            # Project with the transpose of the embedding table plus a separate bias.
            b = tf.get_variable('bias', [params['target_vocab_size']], tf.float32)
            _scope = 'encoder_embedding' if args.tied_embedding else 'decoder_embedding'
            with tf.variable_scope(_scope, reuse=True):
                # The raw variable is fetched here, not the zero-padded copy used for lookups.
                shared_w = tf.get_variable('lookup_table')
            # Flatten to 2d for xw_plus_b, then restore the batch/time axes.
            decoded = tf.reshape(decoded, [-1, args.hidden_units])
            logits = tf.nn.xw_plus_b(decoded, tf.transpose(shared_w), b)
            logits = tf.reshape(logits, [tf.shape(sources)[0], -1, params['target_vocab_size']])
        else:
            with tf.variable_scope('output_layer', reuse=reuse):
                logits = tf.layers.dense(decoded, params['target_vocab_size'], reuse=reuse)
        return logits


def _label_smoothing_sequence_loss(logits, targets, weights, label_depth):
    """Weighted mean cross-entropy against label-smoothed one-hot targets.

    Averages over every position with non-zero weight, mirroring the default
    averaging of tf.contrib.seq2seq.sequence_loss used in the other branch.
    """
    smoothed = label_smoothing(tf.one_hot(targets, depth=label_depth))
    crossent = tf.nn.softmax_cross_entropy_with_logits(labels=smoothed, logits=logits)
    # The small constant guards against a batch whose weights are all zero.
    return tf.reduce_sum(crossent * weights) / (tf.reduce_sum(weights) + 1e-12)


def _model_fn_train(features, mode, params, logits):
    """Build the loss, learning-rate schedule and train op for TRAIN mode.

    Args:
      features: dict holding the 'target' id tensor (id 0 is padding).
      mode: the tf.estimator mode key, forwarded to the EstimatorSpec.
      params: dict providing 'target_vocab_size'.
      logits: output of forward_pass() for the same batch.

    Returns:
      A tf.estimator.EstimatorSpec with loss, train op and a hook that logs
      the learning rate every 100 iterations.

    Raises:
      ValueError: if args.lr_decay_strategy is neither 'noam' nor 'exp'.
    """
    with tf.name_scope('backward'):
        targets = features['target']
        # Padding positions (id 0) get weight 0 and so do not contribute to the loss.
        masks = tf.to_float(tf.not_equal(targets, 0))

        if args.label_smoothing:
            # FIX: the original called `label_smoothing_sequence_loss`, which is
            # not defined anywhere in this notebook (NameError when enabled).
            loss_op = _label_smoothing_sequence_loss(
                logits=logits, targets=targets, weights=masks, label_depth=params['target_vocab_size'])
        else:
            loss_op = tf.contrib.seq2seq.sequence_loss(
                logits=logits, targets=targets, weights=masks)

        if args.lr_decay_strategy == 'noam':
            step_num = tf.train.get_global_step() + 1   # prevents zero global step
            lr = _get_noam_lr(step_num)
        elif args.lr_decay_strategy == 'exp':
            # Start at 1e-3 and decay by a factor of 0.1 per 100000 steps.
            lr = tf.train.exponential_decay(1e-3, tf.train.get_global_step(), 100000, 0.1)
        else:
            raise ValueError("lr decay strategy must be one of 'noam' and 'exp'")
        log_hook = tf.train.LoggingTensorHook({'lr': lr}, every_n_iter=100)

        train_op = tf.train.AdamOptimizer(lr).minimize(loss_op, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss_op, train_op=train_op, training_hooks=[log_hook])


def _model_fn_predict(features, mode, params):
    """Greedy decoding inside the graph, one target position per loop iteration.

    The loop runs args.target_max_len times. On iteration i the whole model is
    re-run on the tokens decoded so far and the argmax at position i is taken
    as the next token.

    Args:
      features: dict with 'source' ids and an initial 'target' tensor that is
        used as the starting value of both loop buffers.
      mode: the tf.estimator mode key, forwarded to the EstimatorSpec.
      params: forwarded unchanged to forward_pass().

    Returns:
      A tf.estimator.EstimatorSpec whose predictions are the decoded ids.
    """
    def cond(i, x, temp):
        # Stop once every target position has been filled.
        return i < args.target_max_len

    def body(i, x, temp):
        # reuse=True: share the variables built by the caller and disable dropout.
        logits = forward_pass(features['source'], x, params, reuse=True)
        # Most likely token at position i for every example, shape (N, 1).
        ids = tf.argmax(logits, -1)[:, i]
        ids = tf.expand_dims(ids, -1)

        # `temp` is a shift register: drop its first column, append the new
        # token, so after i+1 iterations its last i+1 columns hold the decoded
        # tokens in order.
        temp = tf.concat([temp[:, 1:], ids], -1)

        # Rotate so the decoded tokens sit at the front, followed by what is
        # left of the initial buffer.
        x = tf.concat([temp[:, -(i+1):], temp[:, :-(i+1)]], -1)
        # Re-assert the fixed width that the dynamic slicing above obscures.
        x = tf.reshape(x, [tf.shape(temp)[0], args.target_max_len])
        i += 1
        return i, x, temp

    _, res, _ = tf.while_loop(cond, body, [tf.constant(0), features['target'], features['target']])
    
    return tf.estimator.EstimatorSpec(mode=mode, predictions=res)


def tf_estimator_model_fn(features, labels, mode, params):
    """model_fn for tf.estimator.Estimator; supports TRAIN and PREDICT.

    Args:
      features: dict with 'source' and 'target' id tensors.
      labels: unused (targets travel inside `features`).
      mode: a tf.estimator.ModeKeys value.
      params: dict forwarded to forward_pass() and the per-mode builders.

    Returns:
      The tf.estimator.EstimatorSpec for the requested mode.

    Raises:
      ValueError: for any mode other than TRAIN or PREDICT.
    """
    # This first call creates the model variables. It is needed in PREDICT mode
    # as well, because the decoding loop only ever calls forward_pass() with
    # reuse=True.
    logits = forward_pass(features['source'], features['target'], params)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # FIX: the original built a second, discarded forward pass here, which
        # only added unused ops to the training graph.
        return _model_fn_train(features, mode, params, logits)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return _model_fn_predict(features, mode, params)
    # FIX: fail explicitly instead of silently returning None.
    raise ValueError("unsupported mode %r: only TRAIN and PREDICT are implemented" % (mode,))


def _shift_right(targets, start_symbol):
    """Prepend a column of `start_symbol` to `targets` and drop its last column.

    The result has the same shape as `targets`; the start column is cast to int64.
    """
    batch_size = tf.shape(targets)[0]
    start_column = tf.cast(tf.fill([batch_size, 1], start_symbol), tf.int64)
    all_but_last = targets[:, :-1]
    return tf.concat([start_column, all_but_last], axis=-1)


def _get_position_encoder():
    """Return the position-encoding function selected by args.position_encoding.

    Raises:
      ValueError: if the setting is neither 'param' nor 'non_param'.
    """
    kind = args.position_encoding
    if kind == 'non_param':
        return sinusoidal_position_encoding
    if kind == 'param':
        return learned_position_encoding
    raise ValueError("position encoding has to be either 'param' or 'non_param'")


def _get_noam_lr(step_num):
    """Learning rate with linear warm-up followed by inverse-square-root decay.

    Computes hidden_units^-0.5 * min(step^-0.5, step * warmup_steps^-1.5).

    Args:
      step_num: scalar tensor holding the (non-zero) training step.

    Returns:
      A float scalar tensor with the learning rate for that step.
    """
    # FIX: the notebook's config never defines 'warmup_steps', so reading
    # args.warmup_steps raised AttributeError whenever the 'noam' strategy was
    # selected. Fall back to 4000 when the setting is absent.
    warmup_steps = getattr(args, 'warmup_steps', 4000)
    step = tf.to_float(step_num)
    return tf.rsqrt(tf.to_float(args.hidden_units)) * tf.minimum(
        tf.rsqrt(step),
        step * tf.convert_to_tensor(float(warmup_steps) ** (-1.5)))

In [6]:
def greedy_decode(test_words, tf_estimator, dl):
    """Decode each string in `test_words` with the estimator and print the result.

    Args:
      test_words: iterable of source strings.
      tf_estimator: an estimator whose predict() yields one row of target ids
        per input.
      dl: object exposing `source_word2idx` and `target_word2idx` dicts.

    Prints one "<input> -> <output>" line per test string, with the '<end>'
    marker removed from the output.
    """
    pad_id = dl.source_word2idx['<pad>']
    unk_id = dl.source_word2idx['<unk>']

    test_indices = []
    for test_word in test_words:
        # FIX: clip over-long inputs so every row has exactly source_max_len
        # entries (the original produced ragged rows), and map characters that
        # are missing from the vocabulary to '<unk>' instead of raising KeyError.
        clipped = test_word[:args.source_max_len]
        test_idx = [dl.source_word2idx.get(c, unk_id) for c in clipped]
        test_idx += [pad_id] * (args.source_max_len - len(test_idx))
        test_indices.append(test_idx)
    test_indices = np.atleast_2d(test_indices)

    # Placeholder targets: the prediction loop overwrites them step by step.
    zeros = np.zeros([len(test_words), args.target_max_len], np.int64)

    pred_ids = tf_estimator.predict(tf.estimator.inputs.numpy_input_fn(
        x={'source':test_indices, 'target':zeros}, batch_size=len(test_words), shuffle=False))
    pred_ids = list(pred_ids)

    target_idx2word = {i: w for w, i in dl.target_word2idx.items()}
    for i, test_word in enumerate(test_words):
        # `token_id` rather than `id`, to avoid shadowing the builtin.
        ans = ''.join([target_idx2word[token_id] for token_id in pred_ids[i]])
        print(test_word, '->', ans.replace('<end>', ''))


def prepare_params(dl):
    """Assemble the `params` dict handed to the estimator's model_fn.

    Resolves args.activation to the matching tf.nn function and copies the
    vocabulary sizes and the '<start>' id from `dl`.

    Raises:
      ValueError: if args.activation is not 'relu', 'elu' or 'lrelu'.
    """
    # Setting name -> attribute of tf.nn; resolved lazily so only the selected
    # function is ever looked up.
    activation_attr = {'relu': 'relu', 'elu': 'elu', 'lrelu': 'leaky_relu'}
    if args.activation not in activation_attr:
        raise ValueError("acitivation fn has to be 'relu' or 'elu' or 'lrelu'")
    activation = getattr(tf.nn, activation_attr[args.activation])

    return {
        'source_vocab_size': len(dl.source_word2idx),
        'target_vocab_size': len(dl.target_word2idx),
        'start_symbol': dl.target_word2idx['<start>'],
        'activation': activation,
    }


def main():
    """Load the dialog corpus, then train and sample from the model twice.

    Each of the two rounds runs one pass of training over the data and then
    prints greedy decodings for a fixed set of test sentences.
    """
    dl = DataLoader(
        source_path='../temp/dialog_source.txt',
        target_path='../temp/dialog_target.txt')
    sources, targets = dl.load()
    print('Source Vocab Size:', len(dl.source_word2idx))
    print('Target Vocab Size:', len(dl.target_word2idx))

    model_params = prepare_params(dl)
    tf_estimator = tf.estimator.Estimator(tf_estimator_model_fn, params=model_params)

    sample_inputs = ['你是谁', '你喜欢我吗', '给我唱一首歌', '我帅吗']
    for _ in range(2):
        train_input_fn = tf.estimator.inputs.numpy_input_fn(
            x={'source': sources, 'target': targets},
            batch_size=args.batch_size,
            shuffle=True)
        tf_estimator.train(train_input_fn)
        greedy_decode(sample_inputs, tf_estimator, dl)


if __name__ == '__main__':
    main()

Source Vocab Size: 2022
Target Vocab Size: 2481
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8632gsu2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1172422e8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
IN

INFO:tensorflow:loss = 3.2124953, step = 4501 (16.284 sec)
INFO:tensorflow:lr = 0.0009015712 (16.284 sec)
INFO:tensorflow:global_step/sec: 6.05663
INFO:tensorflow:loss = 2.7760828, step = 4601 (16.511 sec)
INFO:tensorflow:lr = 0.0008994976 (16.511 sec)
INFO:tensorflow:global_step/sec: 6.26572
INFO:tensorflow:loss = 3.0912976, step = 4701 (15.960 sec)
INFO:tensorflow:lr = 0.00089742884 (15.960 sec)
INFO:tensorflow:global_step/sec: 6.2609
INFO:tensorflow:loss = 3.0268857, step = 4801 (15.972 sec)
INFO:tensorflow:lr = 0.0008953648 (15.972 sec)
INFO:tensorflow:global_step/sec: 6.56058
INFO:tensorflow:loss = 3.069369, step = 4901 (15.242 sec)
INFO:tensorflow:lr = 0.0008933055 (15.242 sec)
INFO:tensorflow:global_step/sec: 6.54205
INFO:tensorflow:loss = 2.7635202, step = 5001 (15.286 sec)
INFO:tensorflow:lr = 0.000891251 (15.286 sec)
INFO:tensorflow:global_step/sec: 6.4039
INFO:tensorflow:loss = 3.2860072, step = 5101 (15.616 sec)
INFO:tensorflow:lr = 0.0008892011 (15.616 sec)
INFO:tensorflow

INFO:tensorflow:lr = 0.00080552686 (16.570 sec)
INFO:tensorflow:global_step/sec: 6.13069
INFO:tensorflow:loss = 3.0732715, step = 9493 (16.312 sec)
INFO:tensorflow:lr = 0.0008036742 (16.312 sec)
INFO:tensorflow:global_step/sec: 6.38792
INFO:tensorflow:loss = 2.81342, step = 9593 (15.654 sec)
INFO:tensorflow:lr = 0.0008018258 (15.654 sec)
INFO:tensorflow:global_step/sec: 6.06807
INFO:tensorflow:loss = 3.0305955, step = 9693 (16.479 sec)
INFO:tensorflow:lr = 0.00079998164 (16.480 sec)
INFO:tensorflow:global_step/sec: 6.19815
INFO:tensorflow:loss = 3.2607522, step = 9793 (16.134 sec)
INFO:tensorflow:lr = 0.00079814176 (16.133 sec)
INFO:tensorflow:global_step/sec: 6.15417
INFO:tensorflow:loss = 2.7144303, step = 9893 (16.249 sec)
INFO:tensorflow:lr = 0.00079630606 (16.249 sec)
INFO:tensorflow:global_step/sec: 6.46069
INFO:tensorflow:loss = 3.2435334, step = 9993 (15.478 sec)
INFO:tensorflow:lr = 0.0007944746 (15.478 sec)
INFO:tensorflow:global_step/sec: 6.33906
INFO:tensorflow:loss = 3.196