In [1]:
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

In [2]:
tf.__version__

'2.1.0'

In [3]:
emb_path = './data/word_emb.json'
with open(emb_path) as f:
    embedding_matrix = json.load(f)
embedding_matrix = np.asarray(embedding_matrix, dtype=np.float32)
embedding_matrix.shape

(88714, 300)

In [4]:
learning_rate = 0.001 #"Learning rate"
batch_size = 100 #"Batch size to use"
dropout = 0 #1-keep_prob
context_len = 600 #"The maximum context length of your model"
question_len = 30 #"The maximum question length of your model"
(VOCAB_SIZE, embedding_size) = embedding_matrix.shape
hidden_size = 200 #"Size of the hidden states"

In [6]:
context = layers.Input(shape=(context_len, ))
question = layers.Input(shape=(question_len, ))

context_embs = layers.Embedding(input_dim=VOCAB_SIZE,
                            output_dim=embedding_size,
                            weights=[embedding_matrix],
                            trainable=False)(context)
qn_embs = layers.Embedding(input_dim=VOCAB_SIZE,
                            output_dim=embedding_size,
                            weights=[embedding_matrix],
                            trainable=False)(question)

masking_layer = layers.Masking()
masked_context_embs = masking_layer(context_embs)
masked_qn_embs = masking_layer(qn_embs)

In [15]:
forward_layer = layers.GRU(hidden_size, dropout=dropout, return_sequences=True) 
backward_layer = layers.GRU(hidden_size, dropout=dropout, return_sequences=True, go_backwards=True)
bidirect_gru = layers.Bidirectional(forward_layer, 
                                    backward_layer=backward_layer)
context_hiddens = bidirect_gru(masked_context_embs)
question_hiddens = bidirect_gru(masked_qn_embs)

attn_output = layers.Attention()([context_hiddens, question_hiddens])
attn_output = layers.Dropout(dropout)(attn_output)
blended_reps = layers.concatenate([context_hiddens, attn_output], axis=2)

In [16]:
blended_reps_final = layers.Dense(units=hidden_size,
                                  activation='relu')(blended_reps)

In [17]:
tf.shape(context_hiddens)

<tf.Tensor 'Shape_1:0' shape=(3,) dtype=int32>

In [32]:
logits_start = layers.Dense(units=1, activation=None)(blended_reps_final)
logits_start = tf.squeeze(logits_start, axis=[2])
prob_dist_start = logits_start #layers.Softmax(axis=1)(logits_start)

logits_end = layers.Dense(units=1, activation=None)(blended_reps_final)
logits_end = tf.squeeze(logits_end, axis=[2]) 
prob_dist_end = logits_end #layers.Softmax(axis=1)(logits_end)

In [38]:
def loss(labels, logits):
    #return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

In [39]:
model = models.Model(inputs=[context, question], outputs=[prob_dist_start, prob_dist_end])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
              loss=loss,
              metrics=['accuracy'])

In [40]:
model.summary()

Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 600)]        0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 30)]         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 600, 300)     26614200    input_3[0][0]                    
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 30, 300)      26614200    input_4[0][0]                    
____________________________________________________________________________________________

In [None]:
# ans_span = tf.placeholder(tf.int32, shape=[None, 2])

model.fit(x=None, 
          y=ans_span, 
          batch_size=batch_size, 
          epochs=1, 
          verbose=1, 
          callbacks=None) 