Author: [Zhedong Zheng](https://github.com/zhedongzheng)

![title](end2end_mn.png)

[End-To-End Memory Networks](https://arxiv.org/abs/1503.08895)

In [1]:
from bunch import Bunch
from copy import deepcopy

import tensorflow as tf
import numpy as np
import json, pprint

In [2]:
# Hyper-parameters read as module-level globals by the data loader and model code.
args = Bunch({
    'n_epochs': 20,         # num_epochs handed to numpy_input_fn when training
    'batch_size': 64,       # batch_size handed to numpy_input_fn
    'hidden_dim': 64,       # embedding width; also the Dense / GRU unit count
    'dropout_rate': 0.3,    # rate passed to tf.layers.dropout
    'n_hops': 2,            # number of hop_forward iterations run in forward()
})

In [3]:
class BaseDataLoader(object):
    """Holds the arrays, vocabulary and size parameters of one dataset split.

    ``__init__`` only creates the empty (all-``None``) containers.  A subclass
    is expected to fill them in and to set ``self.is_training``, which
    ``input_fn`` reads but this class never assigns.
    """

    def __init__(self):
        value_keys = ('inputs', 'questions', 'answers')
        length_keys = (
            'inputs_len', 'inputs_sent_len', 'questions_len', 'answers_len')
        param_keys = (
            'vocab_size', '<start>', '<end>',
            'max_input_len', 'max_sent_len', 'max_quest_len', 'max_answer_len')

        # Every slot starts out as None; dict.fromkeys keeps the key order.
        self.data = {
            'size': None,
            'val': dict.fromkeys(value_keys),
            'len': dict.fromkeys(length_keys),
        }
        self.vocab = dict.fromkeys(('size', 'word2idx', 'idx2word'))
        self.params = dict.fromkeys(param_keys)

    def input_fn(self):
        """Wrap the stored arrays in a ``numpy_input_fn`` for an Estimator.

        When ``self.is_training`` is truthy the answers are supplied as labels,
        the data is shuffled and repeated ``args.n_epochs`` times; otherwise no
        labels are supplied and a single unshuffled pass is made.
        """
        values = self.data['val']
        lengths = self.data['len']
        features = {
            'inputs': values['inputs'],
            'questions': values['questions'],
            'inputs_len': lengths['inputs_len'],
            'inputs_sent_len': lengths['inputs_sent_len'],
            'questions_len': lengths['questions_len'],
            'answers_len': lengths['answers_len'],
        }

        if self.is_training:
            labels = values['answers']
            epochs = args.n_epochs
        else:
            labels = None
            epochs = 1

        return tf.estimator.inputs.numpy_input_fn(
            x=features,
            y=labels,
            batch_size=args.batch_size,
            num_epochs=epochs,
            shuffle=self.is_training)


class DataLoader(BaseDataLoader):
    """Load one bAbI file and convert it to padded index arrays.

    In training mode the vocabulary and the maximum lengths are derived from
    the file itself.  Otherwise the ``vocab`` and ``params`` of a previously
    built training loader must be passed in, and the raw tokenised data is kept
    on ``self.demo`` for display.
    """

    SIGNALS = ('<pad>', '<unk>', '<start>', '<end>')

    def __init__(self, path, is_training, vocab=None, params=None):
        """
        Args:
            path: location of the bAbI text file.
            is_training: build vocab/params from this file when true.
            vocab: vocabulary dict to reuse when ``is_training`` is false.
            params: parameter dict to reuse (deep-copied) when
                ``is_training`` is false.
        """
        super().__init__()
        data, lens = self.load_data(path)
        if is_training:
            self.build_vocab(data)
        else:
            self.demo = data
            self.vocab = vocab
            self.params = deepcopy(params)
        self.is_training = is_training
        self.padding(data, lens)

    def load_data(self, path):
        """Parse ``path`` and record the number of examples in ``data['size']``."""
        data, lens = bAbI_data_load(path)
        self.data['size'] = len(data[0])
        return data, lens

    def build_vocab(self, data):
        """Create ``word2idx`` / ``idx2word`` and the vocabulary parameters.

        The four signal tokens always occupy indices 0-3.  The remaining words
        are sorted so that the word -> index mapping is identical on every run
        (iterating a plain ``set`` of strings is not stable across interpreter
        runs, which would make a saved model unusable with a rebuilt vocab).
        """
        signals = list(self.SIGNALS)
        inputs, questions, answers = data

        found = {w for facts in inputs for fact in facts for w in fact}
        found.update(w for question in questions for w in question)
        found.update(w for answer in answers for w in answer)
        # Signal tokens are reserved: never let data words shadow their slots.
        words = sorted(found.difference(signals))

        self.params['vocab_size'] = len(words) + len(signals)
        self.params['<start>'] = signals.index('<start>')
        self.params['<end>'] = signals.index('<end>')
        self.vocab['word2idx'] = {
            word: idx for idx, word in enumerate(signals + words)}
        self.vocab['idx2word'] = {
            idx: word for word, idx in self.vocab['word2idx'].items()}

    def _encode(self, tokens):
        """Map tokens to indices, using the ``<unk>`` index for unknown words."""
        word2idx = self.vocab['word2idx']
        unk = word2idx['<unk>']
        return [word2idx.get(token, unk) for token in tokens]

    def padding(self, data, lens):
        """Index-encode ``data`` and store it, with its lengths, as arrays.

        Facts are padded with index 0 to ``max_sent_len`` words and each story
        to ``max_input_len`` facts; questions are padded to ``max_quest_len``.
        Answers are encoded but NOT padded, so ``data['val']['answers']`` is
        only a rectangular array when every answer has the same length.

        Neither ``data`` nor ``lens`` is modified.
        """
        inputs_len, inputs_sent_len, questions_len, answers_len = lens

        if self.is_training:
            self.params['max_input_len'] = max(inputs_len)
            self.params['max_sent_len'] = max(
                n for story in inputs_sent_len for n in story)
            self.params['max_quest_len'] = max(questions_len)
            self.params['max_answer_len'] = max(answers_len)

        max_input_len = self.params['max_input_len']
        max_sent_len = self.params['max_sent_len']
        max_quest_len = self.params['max_quest_len']

        self.data['len']['inputs_len'] = np.array(inputs_len)
        # Padded (absent) facts get a sentence length of 0.
        self.data['len']['inputs_sent_len'] = np.array([
            list(story) + [0] * (max_input_len - len(story))
            for story in inputs_sent_len])
        self.data['len']['questions_len'] = np.array(questions_len)
        self.data['len']['answers_len'] = np.array(answers_len)

        raw_inputs, raw_questions, raw_answers = data
        empty_fact = [0] * max_sent_len

        inputs = []
        for facts in raw_inputs:
            story = [
                self._encode(fact) + [0] * (max_sent_len - len(fact))
                for fact in facts]
            story += [empty_fact] * (max_input_len - len(story))
            inputs.append(story)
        questions = [
            self._encode(question) + [0] * (max_quest_len - len(question))
            for question in raw_questions]
        answers = [self._encode(answer) for answer in raw_answers]

        self.data['val']['inputs'] = np.array(inputs)
        self.data['val']['questions'] = np.array(questions)
        self.data['val']['answers'] = np.array(answers)


def bAbI_data_load(path, END=['<end>']):
    """Parse a bAbI task file into tokenised stories, questions and answers.

    Each non-question line is a fact; a line containing ``?`` is a question
    whose answer follows the first tab.  A line whose leading index is ``1``
    starts a new story.  Every question yields one example made of a snapshot
    of the facts seen so far in the current story.

    Args:
        path: location of the text file.
        END: token(s) appended to every fact and every answer.  The default
            list is only read, never modified.

    Returns:
        A pair ``(data, lens)`` where
        ``data = [inputs, questions, answers]`` (lists of token lists) and
        ``lens = [inputs_len, inputs_sent_len, questions_len, answers_len]``.
    """
    end = list(END)

    inputs = []
    questions = []
    answers = []

    inputs_len = []
    inputs_sent_len = []
    questions_len = []
    answers_len = []

    # Initialised up front so a file that does not begin at index 1 still parses.
    fact = []

    # The context manager guarantees the file handle is closed.
    with open(path) as lines:
        for d in lines:
            if not d.strip():
                # A blank line would otherwise be stored as a fact of just END.
                continue
            index = d.split(' ')[0]
            if index == '1':
                fact = []
            if '?' in d:
                temp = d.split('\t')
                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']
                a = temp[1].split() + end
                # Snapshot: later facts of the same story must not leak into
                # the example recorded for this question.
                fact_copied = [list(sentence) for sentence in fact]
                inputs.append(fact_copied)
                questions.append(q)
                answers.append(a)

                inputs_len.append(len(fact_copied))
                inputs_sent_len.append([len(s) for s in fact_copied])
                questions_len.append(len(q))
                answers_len.append(len(a))
            else:
                tokens = d.replace('.', '').replace('\n', '').split(' ')[1:] + end
                fact.append(tokens)
    return [inputs, questions, answers], [inputs_len, inputs_sent_len, questions_len, answers_len]

In [4]:
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` supporting the TRAIN and PREDICT modes.

    Args:
        features: dict of feature tensors produced by the input function.
        labels: answer token ids (``None`` when predicting).
        mode: a ``tf.estimator.ModeKeys`` value.
        params: the data loader's parameter dict.

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying predictions (PREDICT) or the
        loss and train op (TRAIN).

    Raises:
        ValueError: for any other mode.  ``forward`` only produces logits in
            TRAIN mode, so no loss can be computed elsewhere; failing loudly
            here replaces an implicit ``None`` return.
    """
    logits_or_ids = forward(features, labels, mode, params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=logits_or_ids)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Every answer position is weighted equally (weights are all ones).
        loss_op = tf.reduce_mean(tf.contrib.seq2seq.sequence_loss(
            logits = logits_or_ids,
            targets = labels,
            weights = tf.ones_like(labels, tf.float32)))

        train_op = tf.train.AdamOptimizer().minimize(
            loss_op,
            global_step = tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode = mode,
                                          loss = loss_op,
                                          train_op = train_op)

    raise ValueError(
        'model_fn only supports TRAIN and PREDICT modes, got: {!r}'.format(mode))


def hop_forward(features, question, memory_o, memory_i, response_proj, params, is_training):
    """Run one memory hop: attend over ``memory_i``, read from ``memory_o``.

    The question is matched against ``memory_i``; the softmax-normalised
    weights are used to sum ``memory_o``; the result is concatenated with the
    question and passed through ``response_proj``.

    ``is_training`` is accepted but not used here.
    """
    memory_i_t = tf.transpose(memory_i, [0, 2, 1])

    # Padded facts are set to -inf so they receive zero weight after softmax.
    attention = tf.nn.softmax(       # (batch, question_maxlen, input_maxlen)
        pre_softmax_masking(
            tf.matmul(question, memory_i_t),
            features['inputs_len'],
            params['max_input_len']))

    # Rows that belong to padded question positions are zeroed out.
    attention = post_softmax_masking(
        attention, features['questions_len'], params['max_quest_len'])

    retrieved = tf.matmul(attention, memory_o)
    return response_proj(tf.concat([retrieved, question], -1))


def forward(features, labels, mode, params):
    """Build the network: encode, run the memory hops, decode the answer.

    Returns whatever ``answer_module`` returns, which depends on whether
    ``mode`` is TRAIN.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    facts = features['inputs']

    with tf.variable_scope('questions'):
        question = quest_mem(features['questions'], params, is_training)

    # The same facts are encoded twice, each scope owning its own lookup table.
    with tf.variable_scope('memory_o'):
        memory_o = input_mem(facts, params, is_training)

    with tf.variable_scope('memory_i'):
        memory_i = input_mem(facts, params, is_training)

    with tf.variable_scope('interaction'):
        # One projection layer object is shared by all hops.
        response_proj = tf.layers.Dense(args.hidden_dim)

        n_hops = args['n_hops']
        for _ in range(n_hops):
            answer = hop_forward(
                features=features,
                question=question,
                memory_o=memory_o,
                memory_i=memory_i,
                response_proj=response_proj,
                params=params,
                is_training=is_training)
            # The output of a hop is the query of the next one.
            question = answer

    # The decoder reuses the lookup table created in the 'memory_o' scope.
    with tf.variable_scope('memory_o', reuse=True):
        embedding = tf.get_variable('lookup_table')

    with tf.variable_scope('answer'):
        return answer_module(
            features, params, answer, embedding, is_training, labels)


def input_mem(x, params, is_training):
    """Embed fact token ids and collapse axis 2 into one vector per fact.

    Embeddings go through dropout, are weighted element-wise by the position
    encoding for ``max_sent_len`` and are then summed over axis 2.
    """
    embedded = tf.layers.dropout(
        embed_seq(x, params), args.dropout_rate, training=is_training)
    weights = position_encoding(params['max_sent_len'], args.hidden_dim)
    return tf.reduce_sum(embedded * weights, 2)


def quest_mem(x, params, is_training):
    """Embed question token ids and apply position-encoding weights.

    Unlike ``input_mem`` no reduction is performed: the weighted embeddings
    are returned as they are.
    """
    embedded = tf.layers.dropout(
        embed_seq(x, params), args.dropout_rate, training=is_training)
    weights = position_encoding(params['max_quest_len'], args.hidden_dim)
    return embedded * weights


def answer_module(features, params, answer, embedding, is_training, labels):
    """Decode the answer with a GRU seeded from the final hop output.

    Returns:
        When ``is_training`` is true, the decoder's ``rnn_output`` (obtained by
        feeding the right-shifted labels); otherwise the ``sample_id`` produced
        by greedy decoding for at most ``params['max_answer_len']`` steps.
    """
    # Creation order of the cell and the two Dense layers is kept as is.
    cell = GRU()
    vocab_proj = tf.layers.Dense(params['vocab_size'])
    state_proj = tf.layers.Dense(args.hidden_dim)

    # The flattened hop output is projected to form the initial GRU state.
    init_state = tf.layers.dropout(
        state_proj(tf.layers.flatten(answer)),
        args.dropout_rate,
        training=is_training)

    if is_training:
        # Teacher forcing: decoder inputs are the labels shifted by one step.
        decoder_inputs = tf.nn.embedding_lookup(
            embedding, shift_right(labels, params))
        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=decoder_inputs,
            sequence_length=tf.to_int32(features['answers_len']))
        max_steps = None
    else:
        start_token = tf.constant([params['<start>']], dtype=tf.int32)
        batch_size = tf.shape(init_state)[0]
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding=embedding,
            start_tokens=tf.tile(start_token, [batch_size]),
            end_token=params['<end>'])
        max_steps = params['max_answer_len']

    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=cell,
        helper=helper,
        initial_state=init_state,
        output_layer=vocab_proj)
    decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=max_steps)

    if is_training:
        return decoder_output.rnn_output
    return decoder_output.sample_id


def pre_softmax_masking(x, seq_len, max_seq_len):
    """Replace entries of ``x`` beyond ``seq_len`` on the last axis by -inf.

    The sequence mask is repeated along axis 1, whose size is read from the
    static shape of ``x``.
    """
    neg_inf = tf.fill(tf.shape(x), float('-inf'))
    n_rows = x.shape.as_list()[1]
    valid = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.float32)
    valid = tf.tile(tf.expand_dims(valid, 1), [1, n_rows, 1])
    is_padding = tf.equal(valid, 0)
    return tf.where(is_padding, neg_inf, x)


def post_softmax_masking(x, seq_len, max_seq_len):
    """Zero the slices of ``x`` whose axis-1 index is beyond ``seq_len``.

    The sequence mask is repeated along the last axis, whose size is read from
    the static shape of ``x``, and multiplied into ``x``.
    """
    n_cols = x.shape.as_list()[-1]
    keep = tf.sequence_mask(seq_len, max_seq_len, dtype=tf.float32)
    keep = tf.tile(tf.expand_dims(keep, -1), [1, 1, n_cols])
    return x * keep


def shift_right(x, params):
    """Prepend a ``<start>`` column to ``x`` and drop its last column."""
    n_rows = tf.shape(x)[0]
    start_column = tf.to_int64(tf.fill([n_rows, 1], params['<start>']))
    without_last = x[:, :-1]
    return tf.concat([start_column, without_last], 1)


def embed_seq(x, params, zero_pad=True):
    """Look up embeddings for the ids in ``x``.

    The ``lookup_table`` variable is created in the current variable scope.
    With ``zero_pad`` the first row of the table is replaced by zeros before
    the lookup, so index 0 always embeds to a zero vector.
    """
    table = tf.get_variable(
        'lookup_table', [params['vocab_size'], args.hidden_dim], tf.float32)
    if not zero_pad:
        return tf.nn.embedding_lookup(table, x)

    zero_row = tf.zeros([1, args.hidden_dim])
    padded_table = tf.concat((zero_row, table[1:, :]), axis=0)
    return tf.nn.embedding_lookup(padded_table, x)


def position_encoding(sentence_size, embedding_size):
    encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_size + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i-1, j-1] = (i - (le-1)/2) * (j - (ls-1)/2)
    encoding = 1 + 4 * encoding / embedding_size / sentence_size
    return np.transpose(encoding)


def GRU(rnn_size=None):
    """Create a GRU cell with an orthogonal kernel initializer.

    ``rnn_size`` falls back to ``args.hidden_dim`` when it is ``None``.
    """
    if rnn_size is None:
        rnn_size = args.hidden_dim
    kernel_init = tf.orthogonal_initializer()
    return tf.nn.rnn_cell.GRUCell(rnn_size, kernel_initializer=kernel_init)

In [5]:
def main():
    """Train on the bAbI qa5 training split, then evaluate and demo on test."""
    tf.logging.set_verbosity(tf.logging.INFO)
    print(json.dumps(args, indent=4))

    train_dl = DataLoader(
        path='../temp/qa5_three-arg-relations_train.txt',
        is_training=True)
    # The test split reuses the vocabulary and size limits of the training split.
    test_dl = DataLoader(
        path='../temp/qa5_three-arg-relations_test.txt',
        is_training=False, vocab=train_dl.vocab, params=train_dl.params)

    model = tf.estimator.Estimator(model_fn, params=train_dl.params)
    model.train(train_dl.input_fn())
    gen = model.predict(test_dl.input_fn())
    preds = np.concatenate(list(gen))
    # Row width is the decoder's step limit (max_answer_len) rather than a
    # hard-coded 2, so the reshape follows the data the model was built for.
    answer_len = test_dl.params['max_answer_len']
    preds = np.reshape(preds, [test_dl.data['size'], answer_len])
    # Accuracy is measured on the first answer token only.
    print('Testing Accuracy:', (test_dl.data['val']['answers'][:, 0] == preds[:, 0]).mean())
    demo(test_dl.demo, test_dl.vocab['idx2word'], preds)


def demo(demo, idx2word, ids, demo_idx=3):
    """Print the facts, the question and the decoded prediction of one example.

    Args:
        demo: triple ``(inputs, questions, answers)`` of tokenised data.
        idx2word: mapping from token id to word.
        ids: predicted token ids, indexable by example.
        demo_idx: which example to show.
    """
    stories, asked, _ = demo

    print()
    pprint.pprint(stories[demo_idx])
    print()
    print('Question:', asked[demo_idx])
    print()
    predicted_words = [idx2word[token_id] for token_id in ids[demo_idx]]
    print('Prediction:', predicted_words)


# Run the full train / predict / demo pipeline only when executed as a script.
if __name__ == '__main__':
    main()

{
    "n_epochs": 20,
    "batch_size": 64,
    "hidden_dim": 64,
    "dropout_rate": 0.3,
    "n_hops": 2
}
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpx2w1hm9z', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11874d7b8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_o