In [1]:
import re
import os
import sys
import collections
import pandas as pd
import numpy as np

from keras.optimizers import *
from keras.callbacks import *
from keras.models import *
from keras.layers import *
from keras.initializers import *
from keras.activations import *
from keras_layer_normalization import LayerNormalization

import tensorflow as tf
from sklearn.model_selection import train_test_split

import nltk
nltk.download('punkt')

# --- Dataset locations (relative to the notebook's working directory) ---
data_en_path = "data/en/train"
data_en_10k_path = "data/en-10k/"
data_en_valid_path = "data/en-valid/"
data_en_valid_10k_path = "data/en-valid-10k/"

# Names of all twenty task categories, kept as a reference list.
raw_task_category = [
    'single-supporting-fact',
    'two-supporting-facts',
    'three-supporting-facts',
    'two-arg-relations',
    'three-arg-relations',
    'yes-no-questions',
    'counting',
    'lists-sets',
    'simple-negation',
    'indefinite-knowledge',
    'basic-coreference',
    'conjunction',
    'compound-coreference',
    'time-reasoning',
    'basic-deduction',
    'basic-induction',
    'positional-reasoning',
    'size-reasoning',
    'path-finding',
    'agents-motivations',
]

# Tasks that are actually loaded: bAbIUtils.data_processing skips every file
# whose task name is not in this list. Uncomment an entry to enable it.
tasks = [
    'single-supporting-fact',
    # 'two-supporting-facts',
    # 'three-supporting-facts',
    # 'two-arg-relations',
    # 'three-arg-relations',
    # 'yes-no-questions',
    # 'counting',
    # 'lists-sets',
    # 'simple-negation',
    # 'indefinite-knowledge',
    # 'basic-coreference',
    # 'conjunction',
    # 'compound-coreference',
    # 'time-reasoning',
    # 'basic-deduction',
    # 'basic-induction',
    # 'positional-reasoning',
    # 'size-reasoning',
    # 'path-finding',
    # 'agents-motivations',
]

# Rate passed to every Dropout layer of the model.
dropout = 0.3

Using TensorFlow backend.


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ICPS\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
class Sentence:
    """A story line paired with its identifier.

    Attributes:
        num: identifier of the line (the loader passes the text that precedes
            the first space of the raw line).
        sentence: the payload of the line (raw text, or its vectorized form).
    """

    def __init__(self, num, sentence):
        self.num, self.sentence = num, sentence

class bAbISet:
    """A single example: the story so far, one question, and its answer.

    Attributes set by the constructor:
        task            -- task name the example belongs to
        sentences       -- story lines; each item exposes ``num`` and ``sentence``
        story_len       -- number of story lines (0 when no story is given)
        question        -- question text
        answer          -- answer text
        supporting_num  -- supporting-fact field of the question line

    Attributes filled in by ``add_vec_data`` (``None`` until it is called):
        vec_sentences, vec_question, vec_answer
    """

    _SEPARATOR = "============================================================="

    def __init__(self, task=None, sentences=None, question=None, answer=None, supporting_num=None):
        self.task = task
        # The default of None used to crash on len(None); treat it as an empty story.
        self.sentences = sentences if sentences is not None else []
        self.story_len = len(self.sentences)
        self.question = question
        self.answer = answer
        self.supporting_num = supporting_num

        # Initialised here so Print_vec can test for None instead of raising
        # AttributeError when add_vec_data has not been called yet.
        self.vec_sentences = None
        self.vec_question = None
        self.vec_answer = None

    def add_vec_data(self, vec_sentences, vec_question, vec_answer):
        """Attach the vectorized story, question and answer to this example."""
        self.vec_sentences = vec_sentences
        self.vec_question = vec_question
        self.vec_answer = vec_answer

    def Print(self):
        """Print the raw (text) form of the example to stdout."""
        print(self._SEPARATOR)
        print(">> Task: ", self.task)
        print(">> Sentences: ", self.story_len)
        for sentence in self.sentences:
            print(sentence.num, ": ", sentence.sentence)
        print(">> Question: ", self.question)
        print(">> Answer: ", self.answer)
        print(">> Supporting Fact: ", self.supporting_num)
        print(self._SEPARATOR)

    def Print_vec(self):
        """Print the vectorized form of the example; no-op if it is not set.

        ``vec_sentences`` may be either a sequence of per-sentence objects
        (with ``num`` / ``sentence``) or one flat sequence of token ids; the
        flat form is printed on a single line.
        """
        if self.vec_sentences is None:
            return

        print(self._SEPARATOR)
        print(">> Task: ", self.task)
        print(">> Sentences: ", len(self.sentences))
        per_sentence = all(
            hasattr(item, 'num') and hasattr(item, 'sentence')
            for item in self.vec_sentences
        )
        if per_sentence:
            for item in self.vec_sentences:
                print(item.num, ": ", item.sentence)
        else:
            # Flat token-id vector: its elements have no num/sentence attributes.
            print(self.vec_sentences)
        print(">> Question: ", self.vec_question)
        print(">> Answer: ", self.vec_answer)
        print(">> Supporting Fact: ", self.supporting_num)
        print(self._SEPARATOR)
        

In [3]:
class bAbIUtils:
    """Loads task files from a directory and converts text to token-id vectors."""

    def __init__(self, path):
        """Collect the paths of all entries of ``path``.

        Entries are sorted because ``os.listdir`` returns them in arbitrary
        order, which would make the example order vary between runs.
        """
        self.files = [os.path.join(path, name) for name in sorted(os.listdir(path))]

    @staticmethod
    def _task_from_filename(file):
        """Return the text between the first two underscores of the file name.

        Only the base name is inspected, so underscores in directory names do
        not disturb the result. Returns '' when the name has fewer than two
        underscores.
        """
        parts = os.path.basename(file).split('_')
        return parts[1] if len(parts) >= 3 else ''

    @staticmethod
    def _split_line(line):
        """Split a raw line into (identifier, remaining text) at the first space."""
        line_id, _, text = line.replace('\n', '').partition(' ')
        return line_id, text

    @staticmethod
    def _split_question(text):
        """Split a tab-separated question line into (question, answer, supporting).

        Missing trailing fields come back as '' so a line with a single tab
        still yields its answer.
        """
        fields = text.split('\t', 2)
        fields += [''] * (3 - len(fields))
        return tuple(fields)

    def make_question_set(self, sentences):
        """Group lower-cased lines into chunks that each end with a question line.

        A line counts as a question when it contains a tab. Lines after the
        last question are dropped.
        """
        all_set = []
        current = []
        for sentence in sentences:
            current.append(sentence.lower())
            if '\t' in sentence:
                all_set.append(current)
                current = []

        return all_set

    def data_processing(self):
        """Parse every file whose task is listed in the module-level ``tasks``.

        Returns:
            list of bAbISet, one per question. When a chunk does not start
            with line id '1', the story of the previous question in the same
            file is carried over and the new lines are appended to it.

        NOTE: every matching file in the directory is loaded; nothing here
        distinguishes between different splits of the same task.
        """
        print("※ Data Processing...")
        all_data = []
        for file in self.files:
            task = self._task_from_filename(file)
            if task not in tasks:
                continue

            print(task)

            with open(file, 'r', encoding='utf-8') as f:
                raw_set = self.make_question_set(f.readlines())

            # Reset per file so a story can never leak from one file into the
            # next (and so the variable is always bound before use).
            previous_story = []
            for one_set in raw_set:
                first_id, _ = self._split_line(one_set[0])
                # Line id '1' marks the start of a new story; anything else
                # continues the story accumulated so far.
                story = [] if first_id == '1' else list(previous_story)

                question = answer = supporting_num = None
                for line in one_set:
                    line_id, text = self._split_line(line)
                    if '\t' in line:  # question line
                        question, answer, supporting_num = self._split_question(text)
                    else:
                        story.append(Sentence(line_id, text))

                previous_story = story
                all_data.append(bAbISet(task, story, question, answer, supporting_num))

        return all_data

    def make_dictionary(self, all_set, max_size=20000):
        """Build a word -> id mapping from all stories, questions and answers.

        Ids 0-3 are reserved for '<PAD>', '<UNK>', '<EOS>' and '<S>'; the
        remaining ids are assigned in order of decreasing word frequency.

        Args:
            all_set: iterable of examples exposing ``sentences``, ``question``
                and ``answer``.
            max_size: upper bound on the number of dictionary entries.
        """
        print(">> Make Dictionary...")

        counts = collections.Counter()
        for one_set in all_set:
            for story_sentence in one_set.sentences:
                counts.update(nltk.word_tokenize(story_sentence.sentence))
            counts.update(nltk.word_tokenize(one_set.question))
            counts.update(nltk.word_tokenize(one_set.answer))

        dictionary = {'<PAD>': 0, '<UNK>': 1, '<EOS>': 2, '<S>': 3}
        idx = 4
        for word, _ in counts.most_common():
            if idx >= max_size:
                break
            if len(word) > 0:
                dictionary[word] = idx
                idx += 1

        return dictionary

    def vectorize_sentence(self, sentence, dictionary):
        """Tokenize ``sentence`` with nltk and map each token to its id.

        Tokens missing from ``dictionary`` map to the id of '<UNK>'.
        Returns a 1-D numpy array of ids.
        """
        ids = [
            dictionary[word] if word in dictionary else dictionary['<UNK>']
            for word in nltk.word_tokenize(sentence)
            if len(word) > 0
        ]
        return np.array(ids)

    def add_padding(self, sentence, max_len, pad_value=0):
        """Right-pad ``sentence`` with ``pad_value`` up to ``max_len`` items.

        Args:
            sentence: 1-D sequence of token ids (list or numpy array).
            max_len: target length; longer inputs are returned unpadded.
            pad_value: fill value; the default 0 is the id that
                ``make_dictionary`` assigns to '<PAD>'.

        Returns:
            numpy array (always, so callers get one consistent type).
        """
        sentence = np.asarray(sentence)
        missing = max_len - len(sentence)
        if missing <= 0:
            return sentence
        # One concatenation instead of one np.append (full copy) per element.
        return np.concatenate([sentence, np.full(missing, pad_value)])
        

In [4]:
# Index the files of the 10k English directory, then parse every file whose
# task is enabled in `tasks` into a list of bAbISet examples.
bAbI_util = bAbIUtils(data_en_10k_path)
bAbI = bAbI_util.data_processing()
# Word -> id vocabulary over all loaded stories, questions and answers;
# later cells use it to vectorize the examples and to size the model output.
dictionary = bAbI_util.make_dictionary(bAbI)

print("Dictionary Size: ", len(dictionary))
print("Done")

※ Data Processing...
single-supporting-fact
single-supporting-fact
>> Make Dictionary...
Dictionary Size:  25
Done


In [5]:
# Running maxima; they determine the widths of the padded training arrays.
max_sentences_len = 0
max_question_len = 0
max_answer_len = 1  # fixed, not measured: examples with longer answers are printed below
max_story_size = 0

# Step 1: turn every example into token ids and track the maxima.
print("Step 1")
count = 0
for count, one_set in enumerate(bAbI, start=1):
    print("\rSet : {0} / {1}".format(count, len(bAbI)), end='')
    max_story_size = max(max_story_size, one_set.story_len)

    # The whole story is flattened into one token sequence.
    vec_sentences = []
    for sentence in one_set.sentences:
        vec_sentences.extend(bAbI_util.vectorize_sentence(sentence.sentence, dictionary))
    # The list only grows inside the loop above, so measuring once is enough.
    max_sentences_len = max(max_sentences_len, len(vec_sentences))

    vec_question = bAbI_util.vectorize_sentence(one_set.question, dictionary)
    vec_answer = bAbI_util.vectorize_sentence(one_set.answer, dictionary)
    one_set.add_vec_data(vec_sentences, vec_question, vec_answer)

    # Surface any example whose answer is more than one token.
    if len(vec_answer) > 1:
        one_set.Print()

    max_question_len = max(max_question_len, len(vec_question))

# Step 2: pad every story and question to the maxima found in step 1.
print("\nStep 2")
count = 0
for count, one_set in enumerate(bAbI, start=1):
    print("\rSet : {0} / {1}".format(count, len(bAbI)), end='')
    one_set.vec_sentences = bAbI_util.add_padding(one_set.vec_sentences, max_sentences_len)
    one_set.vec_question = bAbI_util.add_padding(one_set.vec_question, max_question_len)

print("\nDone")

Step 1
Set : 11000 / 11000
Step 2
Set : 11000 / 11000
Done


In [6]:
# MAX_DICT_SIZE = len(dictionary)
# MAX_SEN_LEN = max_len
# MAX_STORY_SIZE = max_story_size
# MAX_ANSWER_LEN = max_answer_len
# D_MODEL = 64

# print("MAX_DICT_SIZE = ", MAX_DICT_SIZE)
# print("MAX_STORY_SIZE = ", MAX_STORY_SIZE)
# print("MAX_SEN_LEN = ", MAX_SEN_LEN)
# print("MAX_ANSWER_LEN = ", MAX_ANSWER_LEN)

# story_train = np.zeros((len(bAbI), MAX_STORY_SIZE, MAX_SEN_LEN), dtype=int)
# question_train = np.zeros((len(bAbI), MAX_SEN_LEN), dtype=int)
# answer_train = np.zeros((len(bAbI), MAX_ANSWER_LEN), dtype=int)

# count = 0
# for idx, one_set in enumerate(bAbI):
#     count = count + 1
#     print("\rSet : {0} / {1}".format(idx + 1, len(bAbI)), end='')
    
#     for i, sentence in enumerate(one_set.vec_sentences):
#         story_train[idx][i] = sentence.sentence
        
#     question_train[idx] = one_set.vec_question.copy()
#     answer_train[idx] = one_set.vec_answer.copy()
    
# story_train = np.array(story_train)
# question_train = np.array(question_train)
# answer_train = np.array(answer_train)

# print("\nstory_train = ", story_train.shape)
# print("question_train = ", question_train.shape)
# print("answer_train = ", answer_train.shape)

In [7]:
# Freeze the preprocessing results into the constants the model cell reads.
MAX_DICT_SIZE = len(dictionary)
MAX_SEN_LEN = max_sentences_len
MAX_Q_LEN = max_question_len
MAX_STORY_SIZE = max_story_size
MAX_ANSWER_LEN = max_answer_len
D_MODEL = 64

print("MAX_DICT_SIZE = ", MAX_DICT_SIZE)
print("MAX_STORY_SIZE = ", MAX_STORY_SIZE)
print("MAX_SEN_LEN = ", MAX_SEN_LEN)
print("MAX_Q_LEN = ", MAX_Q_LEN)
print("MAX_ANSWER_LEN = ", MAX_ANSWER_LEN)

# One row per example; each row receives that example's padded id vector.
story_train = np.zeros(shape=(len(bAbI), MAX_SEN_LEN), dtype=int)
question_train = np.zeros(shape=(len(bAbI), MAX_Q_LEN), dtype=int)
answer_train = np.zeros(shape=(len(bAbI), MAX_ANSWER_LEN), dtype=int)

count = 0
for idx, one_set in enumerate(bAbI):
    count = idx + 1
    print("\rSet : {0} / {1}".format(count, len(bAbI)), end='')

    # Row assignment copies the values into the preallocated arrays.
    story_train[idx, :] = one_set.vec_sentences
    question_train[idx, :] = one_set.vec_question
    answer_train[idx, :] = one_set.vec_answer

# The three names above already are numpy arrays, so no re-wrapping is needed.
print("\nstory_train = ", story_train.shape)
print("question_train = ", question_train.shape)
print("answer_train = ", answer_train.shape)

MAX_DICT_SIZE =  25
MAX_STORY_SIZE =  10
MAX_SEN_LEN =  68
MAX_Q_LEN =  4
MAX_ANSWER_LEN =  1


Set : 1 / 11000Set : 2 / 11000Set : 3 / 11000Set : 4 / 11000Set : 5 / 11000Set : 6 / 11000Set : 7 / 11000Set : 8 / 11000Set : 9 / 11000Set : 10 / 11000Set : 11 / 11000Set : 12 / 11000Set : 13 / 11000Set : 14 / 11000Set : 15 / 11000Set : 16 / 11000Set : 17 / 11000Set : 18 / 11000Set : 19 / 11000Set : 20 / 11000Set : 21 / 11000Set : 22 / 11000Set : 23 / 11000Set : 24 / 11000Set : 25 / 11000Set : 26 / 11000Set : 27 / 11000Set : 28 / 11000Set : 29 / 11000Set : 30 / 11000Set : 31 / 11000Set : 32 / 11000Set : 33 / 11000Set : 34 / 11000Set : 35 / 11000Set : 36 / 11000Set : 37 / 11000Set : 38 / 11000Set : 39 / 11000Set : 40 / 11000Set : 41 / 11000Set : 42 / 11000Set : 43 / 11000Set : 44 / 11000Set : 45 / 11000Set : 46 / 11000Set : 47 / 11000Set : 48 / 11000Set : 49 / 11000Set : 50 / 11000Set : 51 / 11000Set : 52 / 11000Set : 53 / 11000Set : 54 / 11000Set : 55 / 11000Set : 56 / 11000Set : 57 / 11000Set : 58 / 11000Set : 59 / 11000Set :

Set : 1906 / 11000Set : 1907 / 11000Set : 1908 / 11000Set : 1909 / 11000Set : 1910 / 11000Set : 1911 / 11000Set : 1912 / 11000Set : 1913 / 11000Set : 1914 / 11000Set : 1915 / 11000Set : 1916 / 11000Set : 1917 / 11000Set : 1918 / 11000Set : 1919 / 11000Set : 1920 / 11000Set : 1921 / 11000Set : 1922 / 11000Set : 1923 / 11000Set : 1924 / 11000Set : 1925 / 11000Set : 1926 / 11000Set : 1927 / 11000Set : 1928 / 11000Set : 1929 / 11000Set : 1930 / 11000Set : 1931 / 11000Set : 1932 / 11000Set : 1933 / 11000Set : 1934 / 11000Set : 1935 / 11000Set : 1936 / 11000Set : 1937 / 11000Set : 1938 / 11000Set : 1939 / 11000Set : 1940 / 11000Set : 1941 / 11000Set : 1942 / 11000Set : 1943 / 11000Set : 1944 / 11000Set : 1945 / 11000Set : 1946 / 11000Set : 1947 / 11000Set : 1948 / 11000Set : 1949 / 11000Set : 1950 / 11000Set : 1951 / 11000Set : 1952 / 11000Set : 1953 / 11000Set : 1954 / 11000Set : 1955 / 11000Set : 1956 / 11000Set : 1957 / 11000Set : 1958 

Set : 2405 / 11000Set : 2406 / 11000Set : 2407 / 11000Set : 2408 / 11000Set : 2409 / 11000Set : 2410 / 11000Set : 2411 / 11000Set : 2412 / 11000Set : 2413 / 11000Set : 2414 / 11000Set : 2415 / 11000Set : 2416 / 11000Set : 2417 / 11000Set : 2418 / 11000Set : 2419 / 11000Set : 2420 / 11000Set : 2421 / 11000Set : 2422 / 11000Set : 2423 / 11000Set : 2424 / 11000Set : 2425 / 11000Set : 2426 / 11000Set : 2427 / 11000Set : 2428 / 11000Set : 2429 / 11000Set : 2430 / 11000Set : 2431 / 11000Set : 2432 / 11000Set : 2433 / 11000Set : 2434 / 11000Set : 2435 / 11000Set : 2436 / 11000Set : 2437 / 11000Set : 2438 / 11000Set : 2439 / 11000Set : 2440 / 11000Set : 2441 / 11000Set : 2442 / 11000Set : 2443 / 11000Set : 2444 / 11000Set : 2445 / 11000Set : 2446 / 11000Set : 2447 / 11000Set : 2448 / 11000Set : 2449 / 11000Set : 2450 / 11000Set : 2451 / 11000Set : 2452 / 11000Set : 2453 / 11000Set : 2454 / 11000Set : 2455 / 11000Set : 2456 / 11000Set : 2457 

Set : 3405 / 11000Set : 3406 / 11000Set : 3407 / 11000Set : 3408 / 11000Set : 3409 / 11000Set : 3410 / 11000Set : 3411 / 11000Set : 3412 / 11000Set : 3413 / 11000Set : 3414 / 11000Set : 3415 / 11000Set : 3416 / 11000Set : 3417 / 11000Set : 3418 / 11000Set : 3419 / 11000Set : 3420 / 11000Set : 3421 / 11000Set : 3422 / 11000Set : 3423 / 11000Set : 3424 / 11000Set : 3425 / 11000Set : 3426 / 11000Set : 3427 / 11000Set : 3428 / 11000Set : 3429 / 11000Set : 3430 / 11000Set : 3431 / 11000Set : 3432 / 11000Set : 3433 / 11000Set : 3434 / 11000Set : 3435 / 11000Set : 3436 / 11000Set : 3437 / 11000Set : 3438 / 11000Set : 3439 / 11000Set : 3440 / 11000Set : 3441 / 11000Set : 3442 / 11000Set : 3443 / 11000Set : 3444 / 11000Set : 3445 / 11000Set : 3446 / 11000Set : 3447 / 11000Set : 3448 / 11000Set : 3449 / 11000Set : 3450 / 11000Set : 3451 / 11000Set : 3452 / 11000Set : 3453 / 11000Set : 3454 / 11000Set : 3455 / 11000Set : 3456 / 11000Set : 3457 

Set : 3904 / 11000Set : 3905 / 11000Set : 3906 / 11000Set : 3907 / 11000Set : 3908 / 11000Set : 3909 / 11000Set : 3910 / 11000Set : 3911 / 11000Set : 3912 / 11000Set : 3913 / 11000Set : 3914 / 11000Set : 3915 / 11000Set : 3916 / 11000Set : 3917 / 11000Set : 3918 / 11000Set : 3919 / 11000Set : 3920 / 11000Set : 3921 / 11000Set : 3922 / 11000Set : 3923 / 11000Set : 3924 / 11000Set : 3925 / 11000Set : 3926 / 11000Set : 3927 / 11000Set : 3928 / 11000Set : 3929 / 11000Set : 3930 / 11000Set : 3931 / 11000Set : 3932 / 11000Set : 3933 / 11000Set : 3934 / 11000Set : 3935 / 11000Set : 3936 / 11000Set : 3937 / 11000Set : 3938 / 11000Set : 3939 / 11000Set : 3940 / 11000Set : 3941 / 11000Set : 3942 / 11000Set : 3943 / 11000Set : 3944 / 11000Set : 3945 / 11000Set : 3946 / 11000Set : 3947 / 11000Set : 3948 / 11000Set : 3949 / 11000Set : 3950 / 11000Set : 3951 / 11000Set : 3952 / 11000Set : 3953 / 11000Set : 3954 / 11000Set : 3955 / 11000Set : 3956 

Set : 4904 / 11000Set : 4905 / 11000Set : 4906 / 11000Set : 4907 / 11000Set : 4908 / 11000Set : 4909 / 11000Set : 4910 / 11000Set : 4911 / 11000Set : 4912 / 11000Set : 4913 / 11000Set : 4914 / 11000Set : 4915 / 11000Set : 4916 / 11000Set : 4917 / 11000Set : 4918 / 11000Set : 4919 / 11000Set : 4920 / 11000Set : 4921 / 11000Set : 4922 / 11000Set : 4923 / 11000Set : 4924 / 11000Set : 4925 / 11000Set : 4926 / 11000Set : 4927 / 11000Set : 4928 / 11000Set : 4929 / 11000Set : 4930 / 11000Set : 4931 / 11000Set : 4932 / 11000Set : 4933 / 11000Set : 4934 / 11000Set : 4935 / 11000Set : 4936 / 11000Set : 4937 / 11000Set : 4938 / 11000Set : 4939 / 11000Set : 4940 / 11000Set : 4941 / 11000Set : 4942 / 11000Set : 4943 / 11000Set : 4944 / 11000Set : 4945 / 11000Set : 4946 / 11000Set : 4947 / 11000Set : 4948 / 11000Set : 4949 / 11000Set : 4950 / 11000Set : 4951 / 11000Set : 4952 / 11000Set : 4953 / 11000Set : 4954 / 11000Set : 4955 / 11000Set : 4956 

Set : 5435 / 11000Set : 5436 / 11000Set : 5437 / 11000Set : 5438 / 11000Set : 5439 / 11000Set : 5440 / 11000Set : 5441 / 11000Set : 5442 / 11000Set : 5443 / 11000Set : 5444 / 11000Set : 5445 / 11000Set : 5446 / 11000Set : 5447 / 11000Set : 5448 / 11000Set : 5449 / 11000Set : 5450 / 11000Set : 5451 / 11000Set : 5452 / 11000Set : 5453 / 11000Set : 5454 / 11000Set : 5455 / 11000Set : 5456 / 11000Set : 5457 / 11000Set : 5458 / 11000Set : 5459 / 11000Set : 5460 / 11000Set : 5461 / 11000Set : 5462 / 11000Set : 5463 / 11000Set : 5464 / 11000Set : 5465 / 11000Set : 5466 / 11000Set : 5467 / 11000Set : 5468 / 11000Set : 5469 / 11000Set : 5470 / 11000Set : 5471 / 11000Set : 5472 / 11000Set : 5473 / 11000Set : 5474 / 11000Set : 5475 / 11000Set : 5476 / 11000Set : 5477 / 11000Set : 5478 / 11000Set : 5479 / 11000Set : 5480 / 11000Set : 5481 / 11000Set : 5482 / 11000Set : 5483 / 11000Set : 5484 / 11000Set : 5485 / 11000Set : 5486 / 11000Set : 5487 

Set : 6637 / 11000Set : 6638 / 11000Set : 6639 / 11000Set : 6640 / 11000Set : 6641 / 11000Set : 6642 / 11000Set : 6643 / 11000Set : 6644 / 11000Set : 6645 / 11000Set : 6646 / 11000Set : 6647 / 11000Set : 6648 / 11000Set : 6649 / 11000Set : 6650 / 11000Set : 6651 / 11000Set : 6652 / 11000Set : 6653 / 11000Set : 6654 / 11000Set : 6655 / 11000Set : 6656 / 11000Set : 6657 / 11000Set : 6658 / 11000Set : 6659 / 11000Set : 6660 / 11000Set : 6661 / 11000Set : 6662 / 11000Set : 6663 / 11000Set : 6664 / 11000Set : 6665 / 11000Set : 6666 / 11000Set : 6667 / 11000Set : 6668 / 11000Set : 6669 / 11000Set : 6670 / 11000Set : 6671 / 11000Set : 6672 / 11000Set : 6673 / 11000Set : 6674 / 11000Set : 6675 / 11000Set : 6676 / 11000Set : 6677 / 11000Set : 6678 / 11000Set : 6679 / 11000Set : 6680 / 11000Set : 6681 / 11000Set : 6682 / 11000Set : 6683 / 11000Set : 6684 / 11000Set : 6685 / 11000Set : 6686 / 11000Set : 6687 / 11000Set : 6688 / 11000Set : 6689 

Set : 7184 / 11000Set : 7185 / 11000Set : 7186 / 11000Set : 7187 / 11000Set : 7188 / 11000Set : 7189 / 11000Set : 7190 / 11000Set : 7191 / 11000Set : 7192 / 11000Set : 7193 / 11000Set : 7194 / 11000Set : 7195 / 11000Set : 7196 / 11000Set : 7197 / 11000Set : 7198 / 11000Set : 7199 / 11000Set : 7200 / 11000Set : 7201 / 11000Set : 7202 / 11000Set : 7203 / 11000Set : 7204 / 11000Set : 7205 / 11000Set : 7206 / 11000Set : 7207 / 11000Set : 7208 / 11000Set : 7209 / 11000Set : 7210 / 11000Set : 7211 / 11000Set : 7212 / 11000Set : 7213 / 11000Set : 7214 / 11000Set : 7215 / 11000Set : 7216 / 11000Set : 7217 / 11000Set : 7218 / 11000Set : 7219 / 11000Set : 7220 / 11000Set : 7221 / 11000Set : 7222 / 11000Set : 7223 / 11000Set : 7224 / 11000Set : 7225 / 11000Set : 7226 / 11000Set : 7227 / 11000Set : 7228 / 11000Set : 7229 / 11000Set : 7230 / 11000Set : 7231 / 11000Set : 7232 / 11000Set : 7233 / 11000Set : 7234 / 11000Set : 7235 / 11000Set : 7236 

Set : 8300 / 11000Set : 8301 / 11000Set : 8302 / 11000Set : 8303 / 11000Set : 8304 / 11000Set : 8305 / 11000Set : 8306 / 11000Set : 8307 / 11000Set : 8308 / 11000Set : 8309 / 11000Set : 8310 / 11000Set : 8311 / 11000Set : 8312 / 11000Set : 8313 / 11000Set : 8314 / 11000Set : 8315 / 11000Set : 8316 / 11000Set : 8317 / 11000Set : 8318 / 11000Set : 8319 / 11000Set : 8320 / 11000Set : 8321 / 11000Set : 8322 / 11000Set : 8323 / 11000Set : 8324 / 11000Set : 8325 / 11000Set : 8326 / 11000Set : 8327 / 11000Set : 8328 / 11000Set : 8329 / 11000Set : 8330 / 11000Set : 8331 / 11000Set : 8332 / 11000Set : 8333 / 11000Set : 8334 / 11000Set : 8335 / 11000Set : 8336 / 11000Set : 8337 / 11000Set : 8338 / 11000Set : 8339 / 11000Set : 8340 / 11000Set : 8341 / 11000Set : 8342 / 11000Set : 8343 / 11000Set : 8344 / 11000Set : 8345 / 11000Set : 8346 / 11000Set : 8347 / 11000Set : 8348 / 11000Set : 8349 / 11000Set : 8350 / 11000Set : 8351 / 11000Set : 8352 

Set : 8901 / 11000Set : 8902 / 11000Set : 8903 / 11000Set : 8904 / 11000Set : 8905 / 11000Set : 8906 / 11000Set : 8907 / 11000Set : 8908 / 11000Set : 8909 / 11000Set : 8910 / 11000Set : 8911 / 11000Set : 8912 / 11000Set : 8913 / 11000Set : 8914 / 11000Set : 8915 / 11000Set : 8916 / 11000Set : 8917 / 11000Set : 8918 / 11000Set : 8919 / 11000Set : 8920 / 11000Set : 8921 / 11000Set : 8922 / 11000Set : 8923 / 11000Set : 8924 / 11000Set : 8925 / 11000Set : 8926 / 11000Set : 8927 / 11000Set : 8928 / 11000Set : 8929 / 11000Set : 8930 / 11000Set : 8931 / 11000Set : 8932 / 11000Set : 8933 / 11000Set : 8934 / 11000Set : 8935 / 11000Set : 8936 / 11000Set : 8937 / 11000Set : 8938 / 11000Set : 8939 / 11000Set : 8940 / 11000Set : 8941 / 11000Set : 8942 / 11000Set : 8943 / 11000Set : 8944 / 11000Set : 8945 / 11000Set : 8946 / 11000Set : 8947 / 11000Set : 8948 / 11000Set : 8949 / 11000Set : 8950 / 11000Set : 8951 / 11000Set : 8952 / 11000Set : 8953 

Set : 9901 / 11000Set : 9902 / 11000Set : 9903 / 11000Set : 9904 / 11000Set : 9905 / 11000Set : 9906 / 11000Set : 9907 / 11000Set : 9908 / 11000Set : 9909 / 11000Set : 9910 / 11000Set : 9911 / 11000Set : 9912 / 11000Set : 9913 / 11000Set : 9914 / 11000Set : 9915 / 11000Set : 9916 / 11000Set : 9917 / 11000Set : 9918 / 11000Set : 9919 / 11000Set : 9920 / 11000Set : 9921 / 11000Set : 9922 / 11000Set : 9923 / 11000Set : 9924 / 11000Set : 9925 / 11000Set : 9926 / 11000Set : 9927 / 11000Set : 9928 / 11000Set : 9929 / 11000Set : 9930 / 11000Set : 9931 / 11000Set : 9932 / 11000Set : 9933 / 11000Set : 9934 / 11000Set : 9935 / 11000Set : 9936 / 11000Set : 9937 / 11000Set : 9938 / 11000Set : 9939 / 11000Set : 9940 / 11000Set : 9941 / 11000Set : 9942 / 11000Set : 9943 / 11000Set : 9944 / 11000Set : 9945 / 11000Set : 9946 / 11000Set : 9947 / 11000Set : 9948 / 11000Set : 9949 / 11000Set : 9950 / 11000Set : 9951 / 11000Set : 9952 / 11000Set : 9953 

Set : 10484 / 11000Set : 10485 / 11000Set : 10486 / 11000Set : 10487 / 11000Set : 10488 / 11000Set : 10489 / 11000Set : 10490 / 11000Set : 10491 / 11000Set : 10492 / 11000Set : 10493 / 11000Set : 10494 / 11000Set : 10495 / 11000Set : 10496 / 11000Set : 10497 / 11000Set : 10498 / 11000Set : 10499 / 11000Set : 10500 / 11000Set : 10501 / 11000Set : 10502 / 11000Set : 10503 / 11000Set : 10504 / 11000Set : 10505 / 11000Set : 10506 / 11000Set : 10507 / 11000Set : 10508 / 11000Set : 10509 / 11000Set : 10510 / 11000Set : 10511 / 11000Set : 10512 / 11000Set : 10513 / 11000Set : 10514 / 11000Set : 10515 / 11000Set : 10516 / 11000Set : 10517 / 11000Set : 10518 / 11000Set : 10519 / 11000Set : 10520 / 11000Set : 10521 / 11000Set : 10522 / 11000Set : 10523 / 11000Set : 10524 / 11000Set : 10525 / 11000Set : 10526 / 11000Set : 10527 / 11000Set : 10528 / 11000Set : 10529 / 11000Set : 10530 / 11000Set : 10531 / 11000Set : 10532 / 11000Set : 10533 / 11000

In [10]:
class MEMN2N:
    """Memory-network style question-answering model over a story and a question.

    Reads the module-level constants MAX_DICT_SIZE, MAX_SEN_LEN, MAX_Q_LEN,
    D_MODEL and ``dropout``; they must be defined before instantiation.
    After ``compile`` the Keras model is available as ``self.model``.
    """

    def __init__(self):
        """Create the three embedding stacks (A and C for the story, B for the question)."""
        # Embedding's input_dim is the vocabulary size (number of distinct
        # token ids), not the sequence length. Passing MAX_SEN_LEN here only
        # worked while the vocabulary happened to be smaller than the longest
        # story; any token id >= input_dim is an out-of-range lookup.
        self.A = Sequential()
        self.A.add(Embedding(input_dim=MAX_DICT_SIZE, output_dim=D_MODEL))
        self.A.add(Dropout(dropout))

        self.B = Sequential()
        self.B.add(Embedding(input_dim=MAX_DICT_SIZE, output_dim=D_MODEL))
        self.B.add(Dropout(dropout))

        # Output width MAX_Q_LEN so this encoding can be added element-wise to
        # the (story position x question position) match matrix in compile().
        self.C = Sequential()
        self.C.add(Embedding(input_dim=MAX_DICT_SIZE, output_dim=MAX_Q_LEN))
        self.C.add(Dropout(dropout))

    def compile(self, optimizer='adam'):
        """Wire the layers into ``self.model`` and compile it.

        Args:
            optimizer: Keras optimizer name or instance.

        The model takes [story ids, question ids] and outputs a softmax over
        MAX_DICT_SIZE token ids; the loss expects integer class labels.
        """
        story_input = Input(shape=(MAX_SEN_LEN, ), dtype='int32')
        question_input = Input(shape=(MAX_Q_LEN,), dtype='int32')

        story_encoded_m = self.A(story_input)   # (samples, MAX_SEN_LEN, D_MODEL)
        story_encoded_c = self.C(story_input)   # (samples, MAX_SEN_LEN, MAX_Q_LEN)
        u = self.B(question_input)              # (samples, MAX_Q_LEN, D_MODEL)

        # Match every story position against every question position.
        m = dot([story_encoded_m, u], axes=(2, 2))  # (samples, MAX_SEN_LEN, MAX_Q_LEN)
        p = Activation('softmax')(m)

        c = add([p, story_encoded_c])           # (samples, MAX_SEN_LEN, MAX_Q_LEN)
        o = Permute((2, 1))(c)                  # (samples, MAX_Q_LEN, MAX_SEN_LEN)

        # Join the response with the question encoding along the last axis.
        answer = concatenate([o, u])
        answer = LSTM(32)(answer)  # (samples, 32)
        answer = Dropout(dropout)(answer)  # (samples, 32)
        answer = Dense(MAX_DICT_SIZE, use_bias=False)(answer)
        answer = Activation('softmax')(answer)

        self.model = Model([story_input, question_input], answer)
        self.model.compile(optimizer=optimizer,
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])


In [11]:
# Build the network, compile it with an explicitly configured Adam optimizer
# (first three positional arguments as given, plus epsilon), then train with
# the last 10% of the samples held out for validation.
memn2n = MEMN2N()
memn2n.compile(optimizer=Adam(0.001, 0.9, 0.98, epsilon=1e-9))
memn2n.model.fit(
    x=[story_train, question_train],
    y=answer_train,
    batch_size=32,
    epochs=100,
    validation_split=0.1,
)

Train on 9900 samples, validate on 1100 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100

KeyboardInterrupt: 