First, let's look at how the bAbI dataset is structured, then work out how to reshape it for training.

In [2]:
with open('./data/qa1_single-supporting-fact_train.txt') as f:
    lines = f.readlines()
print(lines[:20])

['1 Mary moved to the bathroom.\n', '2 John went to the hallway.\n', '3 Where is Mary? \tbathroom\t1\n', '4 Daniel went back to the hallway.\n', '5 Sandra moved to the garden.\n', '6 Where is Daniel? \thallway\t4\n', '7 John moved to the office.\n', '8 Sandra journeyed to the bathroom.\n', '9 Where is Daniel? \thallway\t4\n', '10 Mary moved to the hallway.\n', '11 Daniel travelled to the office.\n', '12 Where is Daniel? \toffice\t11\n', '13 John went back to the garden.\n', '14 John moved to the bedroom.\n', '15 Where is Sandra? \tbathroom\t8\n', '1 Sandra travelled to the office.\n', '2 Sandra went to the bathroom.\n', '3 Where is Sandra? \tbathroom\t2\n', '4 Mary went to the bedroom.\n', '5 Daniel moved to the hallway.\n']


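Lines that hold a Question and an Answer separated by tabs (QA lines) are interleaved with the sentences that make up a Story. Each QA line carries the Question (1), the Answer (2), and the Supporting fact (3), which is the number of the sentence that justifies the answer. The data has to be preprocessed so that Story, Question, and Answer are fed to the model as inputs. Using the sentence numbers above as an example: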

S:[1,2],Q:[3(1)],A:[3(2)]

S:[1,2,4,5],Q:[6(1)],A:[6(2)]

S:[1,2,4,5,7,8],Q:[9(1)],A:[9(2)]  
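To make the tab-separated layout of a QA line concrete, here is a minimal sketch that splits one line into its fields. `parse_qa_line` is a hypothetical helper introduced only for this illustration; the notebook's own parsing happens in `split_SQA`.

```python
def parse_qa_line(line):
    """Split one bAbI QA line into (line number, question, answer, supporting line number)."""
    # The sentence number is separated from the rest by the first space.
    nid, rest = line.split(' ', 1)
    # The remaining text holds three tab-separated fields.
    question, answer, supporting = rest.rstrip('\n').split('\t')
    return int(nid), question.strip(), answer, int(supporting)

print(parse_qa_line('3 Where is Mary? \tbathroom\t1\n'))
```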


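In the actual implementation, we fix the maximum number of sentences a Story can hold (the memory length) and zero-pad the unused slots. If a story is longer than the memory length, the oldest sentences at the front of the story are cut off.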

memory_length=5

S:[1,2,0,0,0],Q:[3(1)],A:[3(2)]

S:[1,2,4,5,0],Q:[6(1)],A:[6(2)]

S:[2,4,5,7,8],Q:[9(1)],A:[9(2)]  
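The truncate-then-pad rule can be sketched on the sentence numbers themselves. `fit_to_memory` is a hypothetical helper used only for this illustration, and it assumes `memory_length` is at least 1.

```python
def fit_to_memory(sentence_ids, memory_length):
    """Keep the most recent sentences, then right-pad with zeros."""
    kept = sentence_ids[-memory_length:]            # drop the oldest sentences
    return kept + [0] * (memory_length - len(kept))  # zero-pad the unused slots

print(fit_to_memory([1, 2], 5))              # [1, 2, 0, 0, 0]
print(fit_to_memory([1, 2, 4, 5], 5))        # [1, 2, 4, 5, 0]
print(fit_to_memory([1, 2, 4, 5, 7, 8], 5))  # [2, 4, 5, 7, 8]
```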


Each sentence is made up of words, so sentences are likewise zero-padded up to a maximum length. As a result, the dimension of Story will be


batch_size X memory_length X sentence_length X word_embedding_dimension.
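A quick way to see where the four axes come from is to index a random embedding table with a 3-D array of word indices. This is a NumPy sketch only; the sizes and the vocabulary size of 20 are assumptions chosen for the example.

```python
import numpy as np

batch_size, memory_length, sentence_length, emb_dim = 2, 10, 6, 12
vocab_size = 20  # assumed value, including index 0 for padding

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, emb_dim))

# One word index per (batch, memory slot, word position).
story_idx = rng.integers(0, vocab_size,
                         size=(batch_size, memory_length, sentence_length))

# Looking up each index appends the embedding axis.
story_emb = embedding_table[story_idx]
print(story_emb.shape)  # (2, 10, 6, 12)
```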


Now let's first implement a function that tokenizes a sentence, and then a function that builds the dataset.

In [3]:
import re

def tokenize(sentence):
    # Split on runs of non-word characters and keep them, so punctuation becomes its own token.
    return [w.strip() for w in re.split(r"(\W+)", sentence) if w.strip()]

print(tokenize("Mary moved to the bathroom.\n"))
print(tokenize("Daniel travelled to the office.\n"))

['Mary', 'moved', 'to', 'the', 'bathroom', '.']
['Daniel', 'travelled', 'to', 'the', 'office', '.']




In [4]:
def split_SQA(lines):
    """Return (data, story_len, sentence_len); data is a list of [story, question, answer, supporting id]."""
    data = []
    story_len = []
    sentence_len = []
    story = None
    num_questions = None
    for line in lines:
        # Each line starts with its sentence number; split it off from the text.
        nid, line = line.split(' ',1)
        nid = int(nid)
    
        if nid == 1:
            story = [] # init story
            num_questions = [0] #init num_questions
            question_count = 0
            
        if '\t' not in line: #normal story sentence if '\t' is not in line
            line = tokenize(line)
            line = line[:-1] if line[-1] == '.' else line
            story.append(line)
            sentence_len.append(len(line))
    
        else : #QA sentence if '\t' is in the line
            q, a, sid = line.split('\t')
            q = tokenize(q)
            q = q[:-1] if q[-1] == '?' else q
            sid = int(sid) - num_questions[int(sid)]
            data.append([story[:], q, a, sid])
            story_len.append(len(story))
            question_count += 1
            
        num_questions.append(question_count) #need to match sentence index without question index
            
    return data, story_len, sentence_len       

In [20]:
data,story_len, sentence_len = split_SQA(lines)
print(data[1])
print("The longest story length:{0}\nThe longest sentence length:{1}".format(max(story_len), max(sentence_len)))

[[['Mary', 'moved', 'to', 'the', 'bathroom'], ['John', 'went', 'to', 'the', 'hallway'], ['Daniel', 'went', 'back', 'to', 'the', 'hallway'], ['Sandra', 'moved', 'to', 'the', 'garden']], ['Where', 'is', 'Daniel'], 'hallway', 3]
The longest story length:10
The longest sentence length:6


The longest story has 10 sentences and the longest sentence has 6 words. The dataset is made up of fairly simple sentences.
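`split_SQA` also converts the supporting-fact line number into a position within the story, because question lines are counted in the file's numbering but are not stored in the story. A minimal sketch of that remapping follows; `to_story_position` is a hypothetical helper that takes the question line numbers explicitly instead of keeping a running count.

```python
def to_story_position(supporting_nid, question_nids):
    """1-based position of the supporting line among story sentences only."""
    # Every question line that precedes the supporting line shifts it down by one.
    questions_before = sum(1 for q in question_nids if q < supporting_nid)
    return supporting_nid - questions_before

# Line 4 with one question before it (line 3) is the 3rd story sentence.
print(to_story_position(4, [3, 6]))          # 3
# Line 11 with questions at lines 3, 6, 9 before it is the 8th story sentence.
print(to_story_position(11, [3, 6, 9, 12]))  # 8
```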

In [6]:
def make_dictionary(lines):
    word2idx = {}
    idx = 1  # index 0 is reserved for zero-padding
    for line in lines:
        _, line = line.split(' ',1)
        if '\t' in line:
            line = line.split('\t')[0]  # keep only the question part of a QA line
        line = tokenize(line)
        # drop the trailing punctuation token
        line = line[:-1] if line[-1] in ('?', '.') else line
        for w in line:
            if w not in word2idx:
                word2idx[w] = idx
                idx += 1
    return word2idx

In [21]:
dic = make_dictionary(lines)
print(dic)

{'Mary': 1, 'moved': 2, 'to': 3, 'the': 4, 'bathroom': 5, 'John': 6, 'went': 7, 'hallway': 8, 'Where': 9, 'is': 10, 'Daniel': 11, 'back': 12, 'Sandra': 13, 'garden': 14, 'office': 15, 'journeyed': 16, 'travelled': 17, 'bedroom': 18, 'kitchen': 19}


The vocabulary is also very small. Now let's convert the sentences, which are lists of words, into lists of indices. Sentence zero-padding and story zero-padding are applied in the same step.
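For a single sentence, the conversion amounts to a dictionary lookup followed by right-padding. The sketch below uses a hand-written toy vocabulary and a hypothetical helper `encode_sentence`; it is meant only to illustrate the idea on one sentence.

```python
def encode_sentence(words, word2idx, sentence_len):
    """Map words to indices and zero-pad on the right up to sentence_len."""
    ids = [word2idx[w] for w in words]
    return ids + [0] * (sentence_len - len(ids))

toy_dic = {'Mary': 1, 'moved': 2, 'to': 3, 'the': 4, 'bathroom': 5}
print(encode_sentence(['Mary', 'moved', 'to', 'the', 'bathroom'], toy_dic, 6))
# [1, 2, 3, 4, 5, 0]
```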

In [94]:
import numpy as np

def make_batch(data, sentence_len, memory_len, dic):
    S = [];Q = []; A = []; Support = []
    for story, question, answer, support in data:
        # drop the front part of the story that exceeds the memory length
        start = max(len(story) - memory_len,0)
        story = story[start:]
        # (1) convert words to indices and (2) zero-pad to match the sentence length
        story_idx = []
        for sentence in story:
            story_idx.append([dic[w] for w in sentence] + [0]*(sentence_len - len(sentence)))
        # zero-pad with empty sentences to match the memory length
        for _ in range(memory_len - len(story_idx)):
            story_idx.append([0]*sentence_len)
        
        question_idx = [[dic[w] for w in question] + [0]*(sentence_len - len(question))]
        answer_idx = dic[answer]
        
        S.append(story_idx); Q.append(question_idx); A.append(answer_idx); Support.append(support)
    return np.array(S),np.array(Q),np.array(A),np.array(Support)
        
        

Let's walk through the internal implementation of the Memory Network using a batch of size 2 as an example.

In [95]:
ss_len = max(sentence_len)
mem_len = max(story_len)
S,Q,A,Support = make_batch(data[:2], ss_len, mem_len, dic)
print(S)

[[[ 1  2  3  4  5  0]
  [ 6  7  3  4  8  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]]

 [[ 1  2  3  4  5  0]
  [ 6  7  3  4  8  0]
  [11  7 12  3  4  8]
  [13  2  3  4 14  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]]]


In [96]:
print(S.shape,Q.shape,A.shape)

(2, 10, 6) (2, 1, 6) (2,)


In [97]:
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.layers as layers

In [98]:
emb_a = layers.Embedding(input_dim = len(dic)+1, output_dim=12)
emb_b = layers.Embedding(input_dim = len(dic)+1, output_dim=12)
emb_c = layers.Embedding(input_dim = len(dic)+1, output_dim=12)

In [99]:
a = emb_a(S)
print(a.shape)
b = emb_b(Q)
print(b.shape)
c = emb_c(S)
print(c.shape)

(2, 10, 6, 12)
(2, 1, 6, 12)
(2, 10, 6, 12)


dimension of Story: batch_size X memory_length X sentence_length X word_embedding_dimension

dimension of Question: batch_size X 1 X sentence_length X word_embedding_dimension

In [100]:
def get_avg_word_emb(sentence_word_idx, sentence_word_emb):
    '''
    input: 
        sentence_word_idx : [batch_size,memory_length,sentence_length]
        sentence_word_emb : [batch_size,memory_length,sentence_length,word_emb]
    output: 
        sentence_emb: [batch_size,memory_length,word_emb]
        sum of the non-padding word embeddings of each sentence
        (the sum is not divided by the sentence length)
    '''
    # sentence_word_idx -> not_zero: [batch_size,memory_length,sentence_length,1]
    # 1 if word index is not zero, else 0
    not_zero = tf.not_equal(sentence_word_idx, 0)
    not_zero = tf.cast(tf.expand_dims(not_zero,-1), tf.float32)
    
    # zero out padding positions, then sum over the sentence_length axis
    mul = tf.multiply(sentence_word_emb,not_zero)
    return tf.reduce_sum(mul,-2)

In [101]:
keys = get_avg_word_emb(S,a)
query = get_avg_word_emb(Q,b)
values = get_avg_word_emb(S,c)

print(keys.shape, query.shape, values.shape)

(2, 10, 12) (2, 1, 12) (2, 10, 12)
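The masking step in `get_avg_word_emb` can be checked on a tiny example. The sketch below re-expresses the same idea in NumPy, an assumption made so that it runs without TensorFlow; `masked_bow_sum` is a hypothetical name.

```python
import numpy as np

def masked_bow_sum(word_idx, word_emb):
    """Sum word embeddings over the word axis, ignoring positions whose index is 0."""
    mask = (word_idx != 0)[..., None].astype(word_emb.dtype)  # [..., words, 1]
    return (word_emb * mask).sum(axis=-2)

word_idx = np.array([[1, 2, 0]])   # one sentence: two real words, one padding slot
word_emb = np.ones((1, 3, 4))      # every real word embeds to a vector of ones
word_emb[0, 2] = 99.0              # the padding row is non-zero on purpose

sentence_emb = masked_bow_sum(word_idx, word_emb)
print(sentence_emb)  # the padding row does not contribute to the sum
```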


In [107]:
def get_attention_score(keys, query):
    '''
    input:
        keys: [batch, mem_size, sentence_emb]
        query: [batch, 1, sentence_emb] (broadcast against every memory slot)
    output:
        attn_score: [batch,mem_size], 
        attention scores for each memory
    '''
    # calculate dot product
    # dot product -> logits: [batch,mem_size]
    elemwise_mul = tf.multiply(keys, query)
    logits = tf.reduce_sum(elemwise_mul,-1)
    
    # a logit of exactly zero is treated as a padding sentence: push it to a large negative value
    logits = logits + tf.cast(tf.equal(logits,0.),tf.float32)*-1e+10
    attn_score = tf.nn.softmax(logits)
    
    return attn_score

In [110]:
attn_score = get_attention_score(keys,query)
print(attn_score)

tf.Tensor(
[[0.49940035 0.5005996  0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.25186282 0.24918939 0.24856508 0.25038275 0.         0.
  0.         0.         0.         0.        ]], shape=(2, 10), dtype=float32)
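`get_attention_score` detects padding by checking whether a logit is exactly zero. One alternative, which is not what the notebook does, is to derive the mask from the word indices, so that a real sentence whose logit happens to be zero is not masked. A NumPy sketch under that assumption (`masked_softmax` and `is_padding` are hypothetical names):

```python
import numpy as np

def masked_softmax(logits, is_padding):
    """Softmax over the last axis that gives padded slots (almost) zero weight."""
    masked = np.where(is_padding, -1e10, logits)
    shifted = masked - masked.max(axis=-1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

story_idx = np.array([[[1, 2, 0], [3, 4, 5], [0, 0, 0]]])  # [batch, mem, words]
is_padding = (story_idx != 0).sum(axis=-1) == 0            # True for all-zero sentences
logits = np.array([[0.0, 1.0, 0.0]])                       # slot 0 is real but scores 0

probs = masked_softmax(logits, is_padding)
print(probs)
```

Here the first slot keeps a non-zero weight even though its logit is zero, while the all-padding third slot is suppressed.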


In [None]:
def get_output_memory_representation():
    pass
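The function above has no body yet. In end-to-end memory networks the output memory representation is usually the attention-weighted sum of the value (output) embeddings, so one plausible completion is sketched in NumPy below. This is an assumption, not the author's code, and the function name and signature are hypothetical.

```python
import numpy as np

def output_memory_representation(attn_score, values):
    """Attention-weighted sum over memory slots.

    attn_score: [batch, mem_size]
    values:     [batch, mem_size, emb]
    returns:    [batch, emb]
    """
    weighted = attn_score[..., None] * values  # broadcast weights over the emb axis
    return weighted.sum(axis=1)

attn = np.array([[0.5, 0.5, 0.0]])
vals = np.array([[[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]])
out = output_memory_representation(attn, vals)
print(out)  # [[0.5 0.5]]
```

With the shapes used in this notebook, `attn_score` of shape (2, 10) and `values` of shape (2, 10, 12) would give an output of shape (2, 12).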

In [101]:
c = tf.cast(tf.expand_dims(z,-1), tf.float32)

In [107]:
t = tf.multiply(o,c)
print(t)
t.shape

tf.Tensor(
[[[[-0.04579781 -0.03127763  0.02413743 ...  0.0434393   0.04954794
     0.03501357]
   [-0.04754801 -0.02171919 -0.00236537 ... -0.01449496  0.02593795
     0.01264102]
   [ 0.01127317  0.02023664 -0.0270755  ...  0.03585057 -0.02598071
     0.01751279]
   [-0.01103199 -0.04447675  0.04603953 ...  0.04143052  0.02226057
    -0.0499106 ]
   [-0.0175122   0.03206041  0.02658495 ... -0.03840851  0.04460161
     0.04957979]
   [-0.          0.         -0.         ...  0.          0.
     0.        ]]

  [[-0.00634755 -0.03342913  0.00969497 ... -0.03706624 -0.02374871
    -0.00558214]
   [ 0.04983275 -0.04245949 -0.0041325  ... -0.02977382  0.04058473
     0.02706334]
   [ 0.01127317  0.02023664 -0.0270755  ...  0.03585057 -0.02598071
     0.01751279]
   [-0.01103199 -0.04447675  0.04603953 ...  0.04143052  0.02226057
    -0.0499106 ]
   [-0.00936512 -0.04077234 -0.02926142 ... -0.03386031 -0.03828931
     0.04946932]
   [-0.          0.         -0.         ...  0.          0.


TensorShape([2, 10, 6, 12])

In [106]:
tf.reduce_sum(t,-2)

<tf.Tensor: id=251, shape=(2, 10, 12), dtype=float32, numpy=
array([[[-0.11061684, -0.04517652,  0.06732103,  0.10635129,
         -0.09996976,  0.08914635, -0.04336112,  0.04540521,
          0.0312294 ,  0.06781693,  0.11636735,  0.06483655],
        [ 0.03436127, -0.14090106, -0.00473492,  0.0174717 ,
          0.02808759, -0.01624556,  0.02488458, -0.02590782,
         -0.07966189, -0.02341929, -0.02517342,  0.03855269],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.       