In [1]:
import json
from os.path import join
import os
os.chdir("/home/tiagoalmeida/bioASQ-taskb/")

import sys
import pickle
import gc
import numpy as np
import tarfile
import random
from collections import defaultdict
from bisect import bisect


##add keras to the modules
module_path = os.path.abspath(os.path.join('pubmed_data'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from pubmed_data import pubmed_helper as ph

import


## DeepRank
Reference PAPER :https://arxiv.org/pdf/1710.05649.pdf

### Network structure
 - [General Network Configuration](#var_def)
 - [Input Network](#input_net)
 - [Measure Network](#measure_net)
 - [Aggregation Network](#aggreation_net)
 - [Final Network](#final_net)

In [2]:
#Load tokenizer and the embedding matrix

MODE = "regex_full_tokens"
tk = ph.load_tokenizer(mode=MODE)
emb_dict = ph.load_embeddings(mode=MODE)

assert len(tk.word_counts) == len(emb_dict)

#Number of different words
VOCAB_SIZE = len(tk.word_counts)+1

#Dimension of embeddings
EMB_DIM = emb_dict[1].shape[0]

emb_matrix = np.zeros((VOCAB_SIZE, EMB_DIM))

for i,vector in emb_dict.items():
    emb_matrix[i] = vector


Load regex_full_tokens_tokenizer.p
Load regex_full_tokens_word_embedding.p


In [3]:
#Tokenized stopwords

biomedical_stop_words = ["a", "about", "again", "all", "almost", "also", "although", "always", "among", "an", "and", "another", "any", "are", "as", "at", "be", "because", "been", "before", "being", "between", "both", "but", "by", "can", "could", "did", "do", "does", "done", "due", "during", "each", "either", "enough", "especially", "etc", "for", "found", "from", "further", "had", "has", "have", "having", "here", "how", "however", "i", "if", "in", "into", "is", "it", "its", "itself", "just", "kg", "km", "made", "mainly", "make", "may", "mg", "might", "ml", "mm", "most", "mostly", "must", "nearly", "neither", "no", "nor", "obtained", "of", "often", "on", "our", "overall", "perhaps", "pmid", "quite", "rather", "really", "regarding", "seem", "seen", "several", "should", "show", "showed", "shown", "shows", "significantly", "since", "so", "some", "such", "than", "that", "the", "their", "theirs", "them", "then", "there", "therefore", "these", "they", "this", "those", "through", "thus", "to", "upon", "use", "used", "using", "various", "very", "was", "we", "were", "what", "when", "which", "while", "with", "within", "without", "would"]
biomedical_stop_words_tokens = set(tk.texts_to_sequences([biomedical_stop_words])[0])


<a id='var_def'></a>
## General Network Configuration

In [4]:

from tensorflow import unstack, stack
##Test 
from tensorflow.keras import backend as K
from tensorflow.keras import initializers, regularizers, activations
from tensorflow.keras.initializers import Zeros, Ones, Constant
from tensorflow.keras.layers import Dense, Lambda, Bidirectional, Dot,Masking,Reshape, Concatenate, Layer, Embedding, Input, Conv2D, GlobalMaxPooling2D, Flatten, TimeDistributed, GRU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.activations import tanh, sigmoid


from tensorflow.keras.preprocessing.sequence import pad_sequences

from models.deep_model_for_ir.custom_layers import MaskedSelfAttention

K.clear_session()

#Number max of term per query
MAX_Q_TERM = 13

#Number max of the snippet terms
QUERY_CENTRIC_CONTEX = 15

#Number max of passages per query term
MAX_PASSAGES_PER_QUERY = 5

#Snippet position padding value
SNIPPET_POSITION_PADDING_VALUE = -1

#Mode for the creation of the S matrix
S_MATRIX_MODE = 0
#S_MATRIX_DIMENSION = EMB_DIM*2+1

#Train embedding weights
EMB_TRAINABLE = False

#Number of filters in CNN
CNN_FILTERS = 40
CNN_KERNELS = (3,3)

#RNN DIM
USE_BIDIRECTIONAL = False
GRU_REPRESENTATION_DIM = 58

ACTIVATION_FUNCTION = "selu"

REGULARIZATION = regularizers.l2(0.0001)

#Term gating network mode
TERM_GATING_MODE =  3#2- weigt fixed per position, 1 - DRMM like term gating

assert S_MATRIX_MODE in [0,1]
assert TERM_GATING_MODE in [0,1,2,3]

#MACRO STYLE
def S_MATRIX_3D_DIMENSION():
    if S_MATRIX_MODE==0:
        return 1
    elif S_MATRIX_MODE==1:
        return EMB_DIM*2+1

DEBUG = False

<a id='input_net'></a>
## Input Network


In [5]:

"""""""""""""""""""""""""""
  ---- Custom Layers ----
"""""""""""""""""""""""""""
class SimilarityMatrix(Layer):
    
    def __init__(self, query_max_term, snippet_max_term, interaction_mode=0, **kwargs):
        """
        interaction mode 0: only use similarity matrix
                    mode 1: similarity matrix + query and snippet embeddings
        """
        assert interaction_mode in [0,1] #only valid modes
        
        self.query_max_term = query_max_term
        self.snippet_max_term = snippet_max_term
        self.interaction_mode = interaction_mode
        
        super().__init__(**kwargs)
        
    def call(self,x):
        if self.interaction_mode==0:
            #sim => dot product (None, MAX_Q_TERM, EMB_DIM) x (None, MAX_Q_TERM, MAX_PASSAGE_PER_Q, EMB_DIM, QUERY_CENTRIC_CONTEX)
            query = K.expand_dims(x[0], axis=1) #(None, 1, MAX_Q_TERM, EMB_DIM)
            query = K.expand_dims(query, axis=1) #(None, 1, 1, MAX_Q_TERM, EMB_DIM)
            query = K.repeat_elements(query,x[1].shape[1],axis=1) #(None, MAX_PASSAGE_PER_Q, MAX_Q_TERM, EMB_DIM)
            query = K.repeat_elements(query,x[1].shape[2],axis=2)
            s_matrix = K.batch_dot(query,x[1]) #(None, MAX_PASSAGE_PER_Q, MAX_Q_TERM, #(None, MAX_PASSAGE_PER_Q, MAX_Q_TERM, EMB_DIM)
            
            s_matrix = K.expand_dims(s_matrix)
            
            return s_matrix #Add one more dimension #(None, MAX_PASSAGE_PER_Q, MAX_Q_TERM, #(None, MAX_PASSAGE_PER_Q, MAX_Q_TERM, EMB_DIM, 1)
        elif self.interaction_mode==1:
            raise NotImplementedError("interaction mode of layer SimilarityMatrix is not implemented")
    """                  
    def compute_output_shape(self, input_shape):
        if self.interaction_mode==0:
            return (input_shape[0][0], input_shape[0][1], self.query_max_term, self.snippet_max_term, 1)
        elif self.interaction_mode==1:
            return (input_shape[0][0], input_shape[0][1], self.query_max_term, self.snippet_max_term, input_shape[0][2]*input_shape[1][2]+1) 
    """

"""""""""""""""""""""""""""
     ---- Layers ----
"""""""""""""""""""""""""""
#Embedding Layer
embedding = Embedding(VOCAB_SIZE,EMB_DIM, name="embedding_layer",weights=[emb_matrix], trainable=EMB_TRAINABLE)

#S matrix ref in the paper
similarity_matrix = SimilarityMatrix(MAX_Q_TERM, QUERY_CENTRIC_CONTEX, interaction_mode=S_MATRIX_MODE, name="query_snippet_similarity")

#transpose (None, QUERY_CENTRIC_CONTEX, EMB_DIM) => (None, EMB_DIM, QUERY_CENTRIC_CONTEX) 
transpose_layer = Lambda(lambda x:K.permute_dimensions(x,[0,1,2,4,3]), name="snippet_transpose") 

"""""""""""""""""""""""""""
 ---- Auxiliar Models ----
"""""""""""""""""""""""""""

#Snippet single embedding transformation
snippet_token_input = Input(shape = (MAX_Q_TERM, MAX_PASSAGES_PER_QUERY, QUERY_CENTRIC_CONTEX,), name = "snippet_token")
snippet_emb = embedding(snippet_token_input)
snippet_emb_transpose = transpose_layer(snippet_emb)
snippet_emb_model = Model(inputs = [snippet_token_input], outputs=[snippet_emb_transpose], name = "snippet_emb_model")
print("\n\nsnippet_emb_model summary")
snippet_emb_model.summary()

"""""""""""""""""""""""""""
  ---- Input Network ----
"""""""""""""""""""""""""""

if DEBUG:
    query_token_input = Input(shape=(MAX_Q_TERM,), name="query_tokens")


    snippets_tokens_input = Input(shape = (MAX_Q_TERM, MAX_PASSAGES_PER_QUERY, QUERY_CENTRIC_CONTEX), name = "snippet_tokens_ipmodel") 
    
    query_emb = embedding(query_token_input)

    snippet_emb = embedding(snippets_tokens_input)
    snippet_emb_transpose = transpose_layer(snippet_emb)
    
    sim_matrix_layer = similarity_matrix([query_emb,snippet_emb_transpose])
    
    
    
    input_model = Model(inputs = [query_token_input,snippets_tokens_input], outputs=[sim_matrix_layer], name="input_model")
    print("\n\ninput_model summary")
    input_model.summary()
    
    
    print("\nOutput tensor",sim_matrix_layer)
    



snippet_emb_model summary
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
snippet_token (InputLayer)   (None, 13, 5, 15)         0         
_________________________________________________________________
embedding_layer (Embedding)  (None, 13, 5, 15, 200)    858359000 
_________________________________________________________________
snippet_transpose (Lambda)   (None, 13, 5, 200, 15)    0         
Total params: 858,359,000
Trainable params: 0
Non-trainable params: 858,359,000
_________________________________________________________________


<a id='measure_net'></a>
## Measure Network

In [6]:
class MaskedConv2D(Layer):
    
    def __init__(self, filters, kernel_size, activation, initializer='glorot_normal', regularizer=None, **kargs):
        super(MaskedConv2D, self).__init__(**kargs)

        self.activation = activations.get(activation)
        self.initializer = initializers.get(initializer)
        
        if regularizer is None or isinstance(regularizer,str):
            self.regularizer = regularizers.get(regularizer)
        else:
            self.regularizer = regularizer
        
        self.filters = filters
        self.kernel_size = kernel_size

    def build(self, input_shape):

        input_filter = int(input_shape[-1])
        
        self.kernel_3_3 = self.add_variable(name = "conv_kernel_3_3",
                                   shape = (3,3,input_filter,CNN_FILTERS),
                                   initializer = self.initializer,
                                   regularizer = self.regularizer,)
        
        self.kernel_5_1 = self.add_variable(name = "conv_kernel_5_1",
                                   shape = (5,1,input_filter,CNN_FILTERS),
                                   initializer = self.initializer,
                                   regularizer = self.regularizer,)
        
        self.kernel_1_5 = self.add_variable(name = "conv_kernel_1_5",
                                   shape = (1,5,input_filter,CNN_FILTERS),
                                   initializer = self.initializer,
                                   regularizer = self.regularizer,)
        
        self.kernel_3_3_bias = self.add_variable(name = "conv_kernel_3_3_bias",
                                   shape = (self.filters,),)
        
        self.kernel_5_1_bias = self.add_variable(name = "conv_kernel_5_1_bias",
                                   shape = (self.filters,),)
        
        self.kernel_1_5_bias = self.add_variable(name = "conv_kernel_1_5_bias",
                                   shape = (self.filters,),)
        
        #end dimensions = 7, 9, 100

        
        super(MaskedConv2D, self).build(input_shape)
    
    def call(self, x):
        
        condition = K.all(x) #if all the values are the same
        inv_condition = (1-K.cast(condition, K.floatx()))
        
        kernel_3_3 = K.conv2d(x, self.kernel_3_3)
        kernel_3_3 = K.bias_add(kernel_3_3, self.kernel_3_3_bias)
        kernel_3_3 = self.activation(kernel_3_3)
        kernel_3_3_pool = K.pool2d(kernel_3_3,(11,13))
        
        kernel_5_1 = K.conv2d(x, self.kernel_5_1)
        kernel_5_1 = K.bias_add(kernel_5_1, self.kernel_5_1_bias)
        kernel_5_1 = self.activation(kernel_5_1)
        kernel_5_1_pool = K.pool2d(kernel_5_1,(9,15))
        
        kernel_1_5 = K.conv2d(x, self.kernel_1_5)
        kernel_1_5 = K.bias_add(kernel_1_5, self.kernel_1_5_bias)
        kernel_1_5 = self.activation(kernel_1_5)
        kernel_1_5_pool = K.pool2d(kernel_1_5,(13,11))
        
        print(kernel_3_3_pool)
        print(kernel_5_1_pool)
        print(kernel_1_5_pool)
        
        kernel_3_3_flat = K.reshape(kernel_3_3_pool,(-1,self.filters))
        kernel_5_1_flat = K.reshape(kernel_5_1_pool,(-1,self.filters))
        kernel_1_5_flat = K.reshape(kernel_1_5_pool,(-1,self.filters))
        print(kernel_3_3_flat)
        print(kernel_5_1_flat)
        print(kernel_1_5_flat)
        
        concat =  K.concatenate([kernel_3_3_flat,kernel_5_1_flat,kernel_1_5_flat])
        #print(concat)
        
        #proj = K.dot(concat, self.dense)
        #proj = K.bias_add(proj,self.dense_bias)
        #proj = self.activation(proj)

        
        return concat * inv_condition

In [7]:

    

"""""""""""""""""""""""""""
 ---- Auxiliar Models ----
"""""""""""""""""""""""""""

#Exctrate high-level features from query and snippet interactions with CNN
cnn_extraction_model = Sequential(name="cnn_extraction_model")
cnn_extraction_model.add(MaskedConv2D(input_shape = (MAX_Q_TERM, QUERY_CENTRIC_CONTEX, S_MATRIX_3D_DIMENSION()), filters = CNN_FILTERS, kernel_size=CNN_KERNELS, activation=ACTIVATION_FUNCTION ))
#cnn_extraction_model.add(GlobalMaxPooling2D())
print("\n\ncnn_extraction_model summary")
cnn_extraction_model.summary()


td_cnn_extraction_model = Sequential(name="TD_cnn_extraction_model")
td_cnn_extraction_model.add(TimeDistributed(cnn_extraction_model, input_shape=(MAX_PASSAGES_PER_QUERY, MAX_Q_TERM, QUERY_CENTRIC_CONTEX, S_MATRIX_3D_DIMENSION())))
td_cnn_extraction_model.summary()

"""""""""""""""""""""""""""
     ---- Layers ----
"""""""""""""""""""""""""""
#concatenation layer over the last dimension
concat_snippet_position = Concatenate( name = "concat_snippet_position")

self_attention = MaskedSelfAttention(CNN_FILTERS*3+1)

#add dimension Layer
add_passage_dim = Lambda(lambda x:K.expand_dims(x,axis=1), name="add_passage_dim")#Reshape(target_shape=(1,GRU_REPRESENTATION_DIM))

#add last dimension Layer
add_dim = Lambda(lambda x:K.expand_dims(x), name="add_dim")

#reciprocal function
reciprocal_f = Lambda(lambda x:1/(x+2), name="reciprocal_function")

#concatenation layer over second dimension (passage dimension)
concat_representation = Concatenate(axis = 1,name = "concat_representation")

Tensor("masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("masked_conv2d/MaxPool_2:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("masked_conv2d/Reshape:0", shape=(?, 40), dtype=float32)
Tensor("masked_conv2d/Reshape_1:0", shape=(?, 40), dtype=float32)
Tensor("masked_conv2d/Reshape_2:0", shape=(?, 40), dtype=float32)


cnn_extraction_model summary
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masked_conv2d (MaskedConv2D) (None, 120)               880       
Total params: 880
Trainable params: 880
Non-trainable params: 0
_________________________________________________________________
Tensor("time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("time_distributed/masked_conv2d/MaxPool_2:0", sh

<a id='measure_net'></a>
## Aggregation Network

In [8]:

"""""""""""""""""""""""""""
  ---- Custom Layers ----
"""""""""""""""""""""""""""

snippet_rnn_rep_dim = CNN_FILTERS


    
class TermGatingDRMM_FFN(Layer):
    
    def __init__(self, embedding_dim = EMB_DIM, rnn_dim = snippet_rnn_rep_dim ,activation=None, initializer='glorot_normal', regularizer=None):
        super(TermGatingDRMM_FFN, self).__init__()

        self.activation = activations.get(activation)
        self.initializer = initializers.get(initializer)
        
        if regularizer is None or isinstance(regularizer,str):
            self.regularizer = regularizers.get(regularizer)
        else:
            self.regularizer = regularizer
        
        self.emb_dim = embedding_dim
        self.rnn_dim = rnn_dim

    def build(self, input_shape):
        
        #term gating W
        self.W_query = self.add_variable(name = "term_gating_We",
                                   shape = [self.emb_dim,1],
                                   initializer = self.initializer,
                                   regularizer = self.regularizer,)
        
        self.dense_score = Dense(1,kernel_regularizer = self.regularizer, activation=self.activation)
        
        dense_shape = input_shape[1]
        print(dense_shape)
        
        self.dense_score.build((dense_shape[0],dense_shape[2]))
        self._trainable_weights += self.dense_score.trainable_weights
        #self.ones = K.constant(np.ones((aggreation_dimension,1)))
        
        super(TermGatingDRMM_FFN, self).build(input_shape)
    
    def call(self, x):
        
        query_embeddings = x[0] #(None, MAX_Q_TERM, EMB_SIZE)
        snippet_representation_per_query = x[1] #(None, MAX_Q_TERM, BI_GRU_DIM)
        
        #compute gated weights
        gated_logits = K.squeeze(K.dot(query_embeddings, self.W_query), axis = -1 )
        #print(gated_logits)
        gated_distribution = K.expand_dims(K.softmax(gated_logits))
        #print(gated_distribution)
        #snippet projection
        
        weighted_score = K.sum(snippet_representation_per_query * gated_distribution,  axis = 1)
        print(weighted_score)
        
        return self.dense_score(weighted_score) # Replace with K.sum of all elements?

<a id='final_net'></a>
## Final Network

In [9]:

"""""""""""""""""""""""""""
  ---- Final Network ----
"""""""""""""""""""""""""""
query_token_input = Input(shape=(MAX_Q_TERM,), name="ds_query_tokens")
doc_score_snippet_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), name = "ds_snippet_tokens")
doc_score_snippet_position_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY), name = "ds_snippet_position_tokens")


unstack_by_q_term = Lambda(lambda x:unstack(x,axis=1), name="unstack_query_term")

#doc_score_snippet_by_q_term = unstack_by_q_term(doc_score_snippet_input)
#doc_score_snippet_position_by_q_term = unstack_by_q_term(doc_score_snippet_position_input)

#INPUT in token format
#query_token_input = Input(shape=(MAX_Q_TERM,), name="query_tokens")
#snippets_tokens_input = [Input(shape = (MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), name = "snippet_tokens_"+str(q)) for q in range(MAX_Q_TERM)]
#inputs_contex_position = [Input(shape = (MAX_PASSAGES_PER_QUERY,), name = "q_context_position_"+str(q)) for q in range(MAX_Q_TERM)]

query_emb = embedding(query_token_input)

doc_score_snippet_emb = embedding(doc_score_snippet_input)
doc_score_snippet_emb_transpose = transpose_layer(doc_score_snippet_emb)

query_snippets_s_matrix = similarity_matrix([query_emb,doc_score_snippet_emb_transpose])

list_of_s_matrix_by_q_term = unstack_by_q_term(query_snippets_s_matrix)
list_of_snippet_postion_by_q_term = unstack_by_q_term(doc_score_snippet_position_input)

relevance_representation = []
for i in range(MAX_Q_TERM):
    
    snippet_relative_position = reciprocal_f(list_of_snippet_postion_by_q_term[i])
    
    local_relevance = td_cnn_extraction_model(list_of_s_matrix_by_q_term[i])
    
    local_relevance_position = concat_snippet_position([local_relevance,add_dim(snippet_relative_position)])
    
    relevance_representation.append(add_passage_dim(self_attention(local_relevance_position)))

concat_relevance = concat_representation(relevance_representation)

if TERM_GATING_MODE==0:
    term_gating = TermGating(vocab_size=VOCAB_SIZE, activation=ACTIVATION_FUNCTION)
    document_score = term_gating([query_token_input,concat_relevance])
    
elif TERM_GATING_MODE==1:
    term_gating = TermGatingDRMM()
    document_score = term_gating([query_emb,concat_relevance])

elif TERM_GATING_MODE==2:
    term_gating = TermGatingDRMM_Projection()
    document_score = term_gating([query_emb,concat_relevance])

elif TERM_GATING_MODE==3:
    term_gating = TermGatingDRMM_FFN(activation=ACTIVATION_FUNCTION, regularizer=REGULARIZATION)
    document_score = term_gating([query_emb,concat_relevance])

document_score_model = Model(inputs = [query_token_input, doc_score_snippet_input, doc_score_snippet_position_input], outputs = [document_score], name="query_document_score")
document_score_model.summary()      



Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool_2:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape:0", shape=(?, 40), dtype=float32)
Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape_1:0", shape=(?, 40), dtype=float32)
Tensor("TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape_2:0", shape=(?, 40), dtype=float32)
condition Tensor("masked_self_attention/All:0", shape=(?, 5, 1), dtype=bool)
inv_condition Tensor("masked_self_attention/sub:0", shape=(?, 5, 1), dtype=float32)
x_projection Tensor("masked_self_attention/Reshape_2:0", shape=(?, 5, 121), dtype=float32)
x_tanh Tensor("masked_self_attention/Tanh:0", shape=(?, 5, 121), dtype=float32)
x_attention 

Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/MaxPool_2:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/Reshape:0", shape=(?, 40), dtype=float32)
Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/Reshape_1:0", shape=(?, 40), dtype=float32)
Tensor("TD_cnn_extraction_model_6/time_distributed/masked_conv2d/Reshape_2:0", shape=(?, 40), dtype=float32)
condition Tensor("masked_self_attention_6/All:0", shape=(?, 5, 1), dtype=bool)
inv_condition Tensor("masked_self_attention_6/sub:0", shape=(?, 5, 1), dtype=float32)
x_projection Tensor("masked_self_attention_6/Reshape_2:0", shape=(?, 5, 121), dtype=float32)
x_tanh Tensor("masked_self_attention_6/Tanh:0", shape=(?, 5, 121), dtype=f

x_attention_softmax Tensor("masked_self_attention_12/transpose_3:0", shape=(?, 5, 1), dtype=float32)
x_scored_emb Tensor("masked_self_attention_12/mul_1:0", shape=(?, 5, 121), dtype=float32)
x_attention_rep Tensor("masked_self_attention_12/Sum:0", shape=(?, 121), dtype=float32)
(?, 13, 121)
Tensor("term_gating_drmm_ffn/Sum:0", shape=(?, 121), dtype=float32)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
ds_query_tokens (InputLayer)    (None, 13)           0                                            
__________________________________________________________________________________________________
ds_snippet_tokens (InputLayer)  (None, 13, 5, 15)    0                                            
__________________________________________________________________________________________________
embedding_layer (Embedding)     multiple      

## FINAL Trainable arch

In [10]:


query_token_input = Input(shape=(MAX_Q_TERM,), name="dr_query_tokens")
positive_snippet_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), name = "positive_snippet_tokens")
positive_snippet_position_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY), name = "positive_snippet_position_tokens")
negative_snippet_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), name = "negative_snippet_tokens")
negative_snippet_position_input = Input(shape = (MAX_Q_TERM,MAX_PASSAGES_PER_QUERY), name = "negative_snippet_position_tokens")

positive_documents_score = document_score_model([query_token_input, positive_snippet_input, positive_snippet_position_input])
negative_documents_score = document_score_model([query_token_input, negative_snippet_input, negative_snippet_position_input])

#stack_socres = stack_scores_layer([positive_documents_score,negative_documents_score])



#pairwise_loss_layer = Lambda(pairwise_hinge_loss, name="pairwise_hinge")
#pairwise_loss = pairwise_loss_layer([positive_documents_score,negative_documents_score])



inputs = [query_token_input, positive_snippet_input, positive_snippet_position_input, negative_snippet_input, negative_snippet_position_input]

deepRank_model = Model(inputs = inputs, outputs = [positive_documents_score, negative_documents_score], name="deep_rank")


p_loss = K.mean(K.maximum(0.0, 1.0 - positive_documents_score + negative_documents_score))

deepRank_model.add_loss(p_loss)

deepRank_model.summary() 
#m.predict([Q, Q_t1_passage, Q_t2_passage, Q_t3_passage, Q_t1_passage_pos, Q_t2_passage_pos, Q_t3_passage_pos])

#deepRank_model.compile(loss=pairwise_hinge_loss, optimizer='sgd')

Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/MaxPool_2:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape_1:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model/time_distributed/masked_conv2d/Reshape_2:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model_1/time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score/TD_cnn_extraction_model_1/time_distributed/masked_conv2

condition Tensor("query_document_score/masked_self_attention/All:0", shape=(?, 5, 1), dtype=bool)
inv_condition Tensor("query_document_score/masked_self_attention/sub:0", shape=(?, 5, 1), dtype=float32)
x_projection Tensor("query_document_score/masked_self_attention/Reshape_2:0", shape=(?, 5, 121), dtype=float32)
x_tanh Tensor("query_document_score/masked_self_attention/Tanh:0", shape=(?, 5, 121), dtype=float32)
x_attention Tensor("query_document_score/masked_self_attention/Reshape_5:0", shape=(?, 5, 1), dtype=float32)
x_attention_maked Tensor("query_document_score/masked_self_attention/add:0", shape=(?, 5, 1), dtype=float32)
x_attention_softmax Tensor("query_document_score/masked_self_attention/transpose_3:0", shape=(?, 5, 1), dtype=float32)
x_scored_emb Tensor("query_document_score/masked_self_attention/mul_1:0", shape=(?, 5, 121), dtype=float32)
x_attention_rep Tensor("query_document_score/masked_self_attention/Sum:0", shape=(?, 121), dtype=float32)
condition Tensor("query_document_

x_attention_rep Tensor("query_document_score/masked_self_attention_9/Sum:0", shape=(?, 121), dtype=float32)
condition Tensor("query_document_score/masked_self_attention_10/All:0", shape=(?, 5, 1), dtype=bool)
inv_condition Tensor("query_document_score/masked_self_attention_10/sub:0", shape=(?, 5, 1), dtype=float32)
x_projection Tensor("query_document_score/masked_self_attention_10/Reshape_2:0", shape=(?, 5, 121), dtype=float32)
x_tanh Tensor("query_document_score/masked_self_attention_10/Tanh:0", shape=(?, 5, 121), dtype=float32)
x_attention Tensor("query_document_score/masked_self_attention_10/Reshape_5:0", shape=(?, 5, 1), dtype=float32)
x_attention_maked Tensor("query_document_score/masked_self_attention_10/add:0", shape=(?, 5, 1), dtype=float32)
x_attention_softmax Tensor("query_document_score/masked_self_attention_10/transpose_3:0", shape=(?, 5, 1), dtype=float32)
x_scored_emb Tensor("query_document_score/masked_self_attention_10/mul_1:0", shape=(?, 5, 121), dtype=float32)
x_atten

Tensor("query_document_score_1/TD_cnn_extraction_model_8/time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_8/time_distributed/masked_conv2d/MaxPool_2:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_8/time_distributed/masked_conv2d/Reshape:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_8/time_distributed/masked_conv2d/Reshape_1:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_8/time_distributed/masked_conv2d/Reshape_2:0", shape=(?, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_9/time_distributed/masked_conv2d/MaxPool:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_9/time_distributed/masked_conv2d/MaxPool_1:0", shape=(?, 1, 1, 40), dtype=float32)
Tensor("query_document_score_1/TD_cnn_extraction_model_9

x_attention Tensor("query_document_score_1/masked_self_attention_4/Reshape_5:0", shape=(?, 5, 1), dtype=float32)
x_attention_maked Tensor("query_document_score_1/masked_self_attention_4/add:0", shape=(?, 5, 1), dtype=float32)
x_attention_softmax Tensor("query_document_score_1/masked_self_attention_4/transpose_3:0", shape=(?, 5, 1), dtype=float32)
x_scored_emb Tensor("query_document_score_1/masked_self_attention_4/mul_1:0", shape=(?, 5, 121), dtype=float32)
x_attention_rep Tensor("query_document_score_1/masked_self_attention_4/Sum:0", shape=(?, 121), dtype=float32)
condition Tensor("query_document_score_1/masked_self_attention_5/All:0", shape=(?, 5, 1), dtype=bool)
inv_condition Tensor("query_document_score_1/masked_self_attention_5/sub:0", shape=(?, 5, 1), dtype=float32)
x_projection Tensor("query_document_score_1/masked_self_attention_5/Reshape_2:0", shape=(?, 5, 121), dtype=float32)
x_tanh Tensor("query_document_score_1/masked_self_attention_5/Tanh:0", shape=(?, 5, 121), dtype=float3

Tensor("query_document_score_1/term_gating_drmm_ffn/Sum:0", shape=(?, 121), dtype=float32)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
dr_query_tokens (InputLayer)    (None, 13)           0                                            
__________________________________________________________________________________________________
positive_snippet_tokens (InputL (None, 13, 5, 15)    0                                            
__________________________________________________________________________________________________
positive_snippet_position_token (None, 13, 5)        0                                            
__________________________________________________________________________________________________
negative_snippet_tokens (InputL (None, 13, 5, 15)    0                                            
__________________

In [11]:
WRITE_GRAPH = False
if WRITE_GRAPH:
    from tensorflow.summary import FileWriter

    graph = K.get_session().graph
     # Your model implementation
    #with graph.as_default():
      # compile method actually creates the model in the graph.
      #deepRank_model.compile(loss=identity_loss, optimizer='adam', metrics=['accuracy'])
    writer = FileWriter(logdir='tensorboard/deepRank', graph=graph)
    writer.flush()




In [12]:
#document_score.predict([Q, Q_t1_passage, Q_t2_passage, Q_t3_passage, Q_t1_passage_pos, Q_t2_passage_pos, Q_t3_passage_pos])

In [13]:
path_dl_train = "/backup/results/fast_method_relevant_results/train_data_deep_models_v2.tar.gz"



tar = tarfile.open(path_dl_train)
#open
print("Open",path_dl_train)
m = tar.getmembers()[0]
f = tar.extractfile(m)
train_articles_collection = pickle.load(f)

f.close()

Open /backup/results/fast_method_relevant_results/train_data_deep_models_v2.tar.gz


In [14]:
PARTILLY_POSITIVE_SAMPLES = 2
NEGATIVE_SAMPLES = 4

class TrainDataGenerator(object):
    def __init__(self, article_collection, tokenizer, batch_queries_size):
        
        self.batch_size = batch_queries_size
        self.tokenizer = tokenizer
        
        self.train_data = article_collection["bioasq_data"]
        self.articles = article_collection["collection"]
        self.irrelevant_pmid = article_collection["irrelevant_pmid"]
        
        self.num_steps = len(self.train_data)//self.batch_size
    
    def __len__(self):
        return self.num_steps
    
    def __iter__(self):
        
        
        query = []
        query_positive_doc = []
        query_positive_doc_position = []
        query_negative_doc = []
        query_negative_doc_position = []
        
        while True:
            
            #stop condition
            if len(query)>=self.batch_size:
                #missing fill the gap for the missing query_terms
                query = np.array(query)
                p=np.random.permutation(query.shape[0])
                query = query[p]
                query_positive_doc = np.array(query_positive_doc)[p]
                query_positive_doc_position = np.array(query_positive_doc_position)[p]
                query_negative_doc = np.array(query_negative_doc)[p]
                query_negative_doc_position =  np.array(query_negative_doc_position)[p]
                
                X = [query, query_positive_doc, query_positive_doc_position, query_negative_doc, query_negative_doc_position]
                #Y = [np.zeros((len(query))),np.zeros((len(query)))]
                yield X


                #reset
                query = []
                query_positive_doc = []
                query_positive_doc_position = []
                query_negative_doc = []
                query_negative_doc_position = []
            
            #select a random question
            random_query_index = random.randint(0, len(self.train_data)-1) 
            query_data = self.train_data[random_query_index]
            
            #list of partilly relevant documents
            partilly_positive_pmid_docs = query_data["partilly_positive_pmid"]

            tokenized_query = query_data["query"][:MAX_Q_TERM]
            
            for j in range(PARTILLY_POSITIVE_SAMPLES+NEGATIVE_SAMPLES):
                #select a random positive
                random_doc_index = random.randint(0, len(query_data["positive_pmid"])-1) 
                doc_pmid = query_data["positive_pmid"][random_doc_index]

                tokenized_positive_doc = self.articles[doc_pmid]
                positive_snippets, positive_snippets_position = self.__snippet_interaction(tokenized_query, tokenized_positive_doc)
                
                if j<PARTILLY_POSITIVE_SAMPLES:
                    #select the partilly posivite doc
                    random_ind = bisect(query_data["partially_positive_cumulative_prob"],random.random())
                    random_negative_doc_pmid = query_data["partilly_positive_pmid"][random_ind]
                    #print(self.__get_article(random_negative_doc_pmid))
                    tokenized_negative_doc = self.articles[random_negative_doc_pmid]
                    negative_snippets, negative_snippets_position = self.__snippet_interaction(tokenized_query, tokenized_negative_doc)
                else:
                    #select a random negative
                    random_doc_index = random.randint(0, len(self.irrelevant_pmid)-1) 
                    doc_pmid = self.irrelevant_pmid[random_doc_index]
                    
                    tokenized_negative_doc = self.articles[doc_pmid]
                    negative_snippets, negative_snippets_position = self.__snippet_interaction(tokenized_query, tokenized_negative_doc)
                
                
                ### add ###

                #not efficient
                query.append(tokenized_query)

                #positive doc
                query_positive_doc.append(positive_snippets)
                query_positive_doc_position.append(positive_snippets_position)

                #negative doc
                query_negative_doc.append(negative_snippets)
                query_negative_doc_position.append(negative_snippets_position)
            

            
    def __snippet_interaction(self, tokenized_query, tokenized_doc, snippet_length=QUERY_CENTRIC_CONTEX):
        
        snippets = []
        snippets_position = [] 

        half_size = snippet_length//2
        
        #O(n^2) complexity, probably can do better with better data struct TODO see if is worthit
        for query_token in tokenized_query:
            
            snippets_per_token = []
            snippets_per_token_position = []
            
            if query_token != 0: #jump padded token
            
                for i,doc_token in enumerate(tokenized_doc):

                    if doc_token==query_token:

                        lower_index = i-half_size
                        lower_index = max(0,lower_index)

                        higher_index = i+half_size
                        higher_index = min(len(tokenized_doc),higher_index)

                        snippets_per_token.append(tokenized_doc[lower_index:higher_index])
                        snippets_per_token_position.append(i)
            
            if len(snippets_per_token)==0:
                snippets.append(np.zeros((MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), dtype=np.int32))
                snippets_position.append(np.zeros((MAX_PASSAGES_PER_QUERY), dtype=np.int32)+SNIPPET_POSITION_PADDING_VALUE)
                continue
                
            max_snippets_len = min(MAX_PASSAGES_PER_QUERY, len(snippets_per_token))
            
            ### snippets in matrix format
            #pad
            snippets_per_token = pad_sequences(snippets_per_token, maxlen = QUERY_CENTRIC_CONTEX, padding="post")
            #fill the gaps
            _temp = np.zeros((MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), dtype=np.int32)
            _temp[:max_snippets_len] = snippets_per_token[:max_snippets_len]
            snippets.append(_temp)
            
            ### snippets_position in matrix format
            #pad
            snippets_per_token_position = pad_sequences([snippets_per_token_position], maxlen = MAX_PASSAGES_PER_QUERY, padding="post", value=SNIPPET_POSITION_PADDING_VALUE)[0]
            snippets_position.append(snippets_per_token_position)
            
        return snippets, snippets_position
            
        
        


## Test (validation) data generator


In [15]:
path_dl_test = "/backup/results/fast_method_relevant_results/test_data_deep_models_v2.tar.gz"



tar = tarfile.open(path_dl_test)
#open
print("Open",path_dl_test)
m = tar.getmembers()[0]
f = tar.extractfile(m)
test_articles_collection = pickle.load(f)

f.close()

Open /backup/results/fast_method_relevant_results/test_data_deep_models_v2.tar.gz


In [16]:

class TestDataGenerator(object):
    def __init__(self, article_collection, tokenizer):
        
        self.tokenizer = tokenizer
        
        self.test_data = article_collection["bioasq_data"] 
        self.articles = article_collection["collection"]
        
        self.num_steps = len(self.test_data)
        
    
    def __get_article(self, pmid):
        return self.article_map(self.articles[pmid])
    
    def __len__(self):
        return self.num_steps
    
    def __iter__(self):
        
        
        query = []
        query_doc = []
        query_doc_position = []
        

        for query_data in self.test_data:

            #tokenized_query = self.tokenizer.texts_to_sequences([query_data["query"]])[0]
            tokenized_query = query_data["query"][:MAX_Q_TERM]
            #manualy remove the stopwords
            #tokenized_query = [ token for token in tokenized_query if token not in biomedical_stop_words_tokens]

            #tokenized_query = pad_sequences([tokenized_query], maxlen = MAX_Q_TERM, padding="post")[0]

            for doc_pmid in query_data["documents"]:
                #positive

                tokenized_doc = self.articles[doc_pmid]
                doc_snippets, doc_snippets_position = self.__snippet_interaction(tokenized_query, tokenized_doc)

                ### add ###

                query.append(tokenized_query)

                #positive doc
                query_doc.append(doc_snippets)
                query_doc_position.append(doc_snippets_position)


            #missing fill the gap for the missing query_terms

            X = [np.array(query), np.array(query_doc), np.array(query_doc_position)]

            yield X

            #reset
            query = []
            query_doc = []
            query_doc_position = []

                
    def __snippet_interaction(self, tokenized_query, tokenized_doc, snippet_length=QUERY_CENTRIC_CONTEX):
        
        snippets = []
        snippets_position = [] 

        half_size = snippet_length//2
        
        #O(n^2) complexity, probably can do better with better data struct TODO see if is worthit
        for query_token in tokenized_query:
            
            snippets_per_token = []
            snippets_per_token_position = []
            
            if query_token != 0: #jump padded token
                
                for i,doc_token in enumerate(tokenized_doc):

                    if doc_token==query_token:

                        lower_index = i-half_size
                        lower_index = max(0,lower_index)

                        higher_index = i+half_size
                        higher_index = min(len(tokenized_doc),higher_index)

                        snippets_per_token.append(tokenized_doc[lower_index:higher_index])
                        snippets_per_token_position.append(i)

            if len(snippets_per_token)==0:
                snippets.append(np.zeros((MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), dtype=np.int32))
                snippets_position.append(np.zeros((MAX_PASSAGES_PER_QUERY), dtype=np.int32)+SNIPPET_POSITION_PADDING_VALUE)
                continue
                
            max_snippets_len = min(MAX_PASSAGES_PER_QUERY, len(snippets_per_token))
            
            ### snippets in matrix format
            #pad
            snippets_per_token = pad_sequences(snippets_per_token, maxlen = QUERY_CENTRIC_CONTEX, padding="post")
            #fill the gaps
            _temp = np.zeros((MAX_PASSAGES_PER_QUERY,QUERY_CENTRIC_CONTEX), dtype=np.int32)
            _temp[:max_snippets_len] = snippets_per_token[:max_snippets_len]
            snippets.append(_temp)
            
            ### snippets_position in matrix format
            #pad
            snippets_per_token_position = pad_sequences([snippets_per_token_position], maxlen = MAX_PASSAGES_PER_QUERY, padding="post",value=SNIPPET_POSITION_PADDING_VALUE)[0]
            snippets_position.append(snippets_per_token_position)
            
        return snippets, snippets_position

In [17]:
validation_articles_collection = {"bioasq_data":[],"collection":test_articles_collection["collection"]}

_temp = test_articles_collection["bioasq_data"][:]
random.shuffle(_temp)
print(len(_temp))

validation_percentage = 0.15

split_index = int(len(_temp)*validation_percentage)
print(split_index)

validation_articles_collection["bioasq_data"] = _temp[:split_index]
#test_articles_collection["bioasq_data"] = _temp[split_index:]

print("validation size",len(validation_articles_collection["bioasq_data"]),"test size",len(test_articles_collection["bioasq_data"]))

549
82
validation size 82 test size 549


In [18]:
def validate_test_data(data):
    data_generator = TestDataGenerator(data, tk)
    data_generator = iter(data_generator)

    query_results = {}

    for i,X in enumerate(data_generator):
        print("Predict query:",i,end="\r")
        deep_ranking = document_score_model.predict(X)
        deep_ranking = map(lambda x:x[0],deep_ranking.tolist())
        bm25_results = data["bioasq_data"][i]["documents"]
        deep_ranking_pmid = list(zip(bm25_results,deep_ranking))
        deep_ranking_pmid.sort(key=lambda x:-x[1])
        query_results[data["bioasq_data"][i]["id"]] = {"result":deep_ranking_pmid,"goldstandard":data["bioasq_data"][i]["positive_pmid"]}
        #print("save query results:",i,end="\r")
        
    return query_results


In [19]:
#validate score

def validation_score(deep_rank_test_query_results):
    id_to_remove = []
    for k,v in deep_rank_test_query_results.items():
        if len(v["goldstandard"]) == 0:
            id_to_remove.append(k)

    for k in id_to_remove:
        del deep_rank_test_query_results[k]

    deep_rank_test_query_results = list(deep_rank_test_query_results.values())

    print("TEST set, len ",len(deep_rank_test_query_results))

    expectations = list(map(lambda x:x["goldstandard"],deep_rank_test_query_results))
    predictions = list(map(lambda x:x["result"],deep_rank_test_query_results))

    #print("Recall:",f_recall(predictions,expectations,at=1000))
    bioasq_map = f_map(predictions,expectations,bioASQ=True)
    print("MAP @10 bioASQ:", bioasq_map)
    print("MAP @25:",f_map(predictions,expectations, bioASQ=True, at=25))
    print("MAP @50:",f_map(predictions,expectations, bioASQ=True, at=50))
    print("MAP @100:",f_map(predictions,expectations, bioASQ=True, at=100))
    print("MAP @200:",f_map(predictions,expectations, bioASQ=True, at=200))
    print("MAP @300:",f_map(predictions,expectations, bioASQ=True, at=300))

    print("MAP:",f_map(predictions,expectations, use_len=True))
    
    print("RECALL@10:",f_recall(predictions,expectations, at=10))
    print("RECALL@50:",f_recall(predictions,expectations, at=50))
    print("RECALL@100:",f_recall(predictions,expectations, at=100))
    return bioasq_map, predictions, expectations

## Train 

In [20]:
from tensorflow.keras.optimizers import SGD, Adam,  Adadelta

#sgd = SGD(lr=0.001)
#adam = Adam(lr=0.001)
adadelta = Adadelta(lr=2)

deepRank_model.compile( optimizer=adadelta)



In [21]:
from models.generic_model import ModelAPI, f_recall, f_map

gen = TrainDataGenerator(train_articles_collection, tk, 256)

gen_iter = iter(gen)

loss = []

for i,line in enumerate(loss):
    
    print("Epoach:",i,"| avg loss:",np.mean(loss[i]),"| max loss:",np.max(loss[i]),"| min loss:",np.min(loss[i]))

import time

max_bio_map_val = 0.166
max_bio_map_test = 0

for epoach in range(1,200):
    loss_per_epoach = []
    for step in range(len(gen)):
        X = next(gen_iter)
        
        start = time.time()
        loss_per_epoach.append(deepRank_model.train_on_batch(X))
        print("Step:",step,"| loss:",loss_per_epoach[-1],"| current max loss:",np.max(loss_per_epoach),"| current min loss:",np.min(loss_per_epoach),"| time:",time.time()-start,end="\r")
        

    
    
    if epoach%10==0:
        print("")
        validate_query_results = validate_test_data(validation_articles_collection)
        print("")
        bio_map_val, _, _ = validation_score(validate_query_results)
        if bio_map_val >= max_bio_map_val:
            max_bio_map_val = bio_map_val
            print("")
            print("Run for the test set")
            test_query_results = validate_test_data(test_articles_collection)
            bio_map_test, _, _ = validation_score(test_query_results)

            if bio_map_test >= max_bio_map_test:
                max_bio_map_test = bio_map_test
                
                deepRank_model.save_weights("deep_rank_v5_weights.h5")
                #deepRank_model.save("deep_rank_model.h5")
                
    loss.append(loss_per_epoach)
    print("",end="\r")#clear the line
    print("Epoach:",epoach,"| avg loss:",np.mean(loss[-1]),"| max loss:",np.max(loss[-1]),"| min loss:",np.min(loss[-1]))
#deepRank_model.fit_generator(gen_iter, steps_per_epoch=len(gen), verbose=1, epochs=27)

Epoach: 2 | avg loss: 0.40768892 | max loss: 0.53801274 | min loss: 0.304041480.30404148 | time: 6.665841579437256
Epoach: 3 | avg loss: 0.27792883 | max loss: 0.30427703 | min loss: 0.236191180.23619118 | time: 6.604721784591675
Epoach: 4 | avg loss: 0.24360353 | max loss: 0.27660736 | min loss: 0.20170411.20170411 | time: 6.6383769512176518
Epoach: 5 | avg loss: 0.22092175 | max loss: 0.2607345 | min loss: 0.181415110.18141511 | time: 6.7410035133361823
Epoach: 6 | avg loss: 0.19556233 | max loss: 0.2728382 | min loss: 0.147979070.14797907 | time: 6.627788782119751
Epoach: 7 | avg loss: 0.16901559 | max loss: 0.19523887 | min loss: 0.1300078 0.1300078 | time: 6.6236169338226329
Epoach: 8 | avg loss: 0.16823205 | max loss: 0.21954162 | min loss: 0.132492070.13249207 | time: 6.626317977905273
Epoach: 9 | avg loss: 0.13484278 | max loss: 0.19471215 | min loss: 0.098840386098840386 | time: 6.64189910888671993
Step: 7 | loss: 0.17270325 | current max loss: 0.19086082 | current min loss: 0

Epoach: 54 | avg loss: 0.11777384 | max loss: 0.16076581 | min loss: 0.10310105.10310105 | time: 6.60422611236572373
Epoach: 55 | avg loss: 0.11555499 | max loss: 0.18000238 | min loss: 0.077465020.07746502 | time: 6.564379930496216
Epoach: 56 | avg loss: 0.11012677 | max loss: 0.1386212 | min loss: 0.08543737.08543737 | time: 6.61575913429260250225
Epoach: 57 | avg loss: 0.12539126 | max loss: 0.1717216 | min loss: 0.089674175.089674175 | time: 6.63014650344848654
Epoach: 58 | avg loss: 0.117917046 | max loss: 0.15706517 | min loss: 0.08812550588125505 | time: 6.69188880920410266
Epoach: 59 | avg loss: 0.1155462 | max loss: 0.14528348 | min loss: 0.082360710.08236071 | time: 6.52817082405090326
Step: 7 | loss: 0.13142581 | current max loss: 0.14645652 | current min loss: 0.07414343 | time: 6.5420241355896395
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14267509920634916
MAP @25: 0.1867262154474334
MAP @50: 0.21659930483592973
MAP @100: 0.2366132364058704
MAP @200: 0.243469037

Epoach: 106 | avg loss: 0.10524091 | max loss: 0.15994973 | min loss: 0.044178147044178147 | time: 6.6230583190917975
Epoach: 107 | avg loss: 0.103249095 | max loss: 0.15968287 | min loss: 0.058363485836348 | time: 6.6496825218200685
Epoach: 108 | avg loss: 0.10421786 | max loss: 0.14180015 | min loss: 0.07998557479985574 | time: 6.6740245819091809
Epoach: 109 | avg loss: 0.10680083 | max loss: 0.15517062 | min loss: 0.06045667560456675 | time: 6.66395282745361355
Step: 7 | loss: 0.051427174 | current max loss: 0.16211236 | current min loss: 0.051427174 | time: 6.579084396362305
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14408234126984126
MAP @25: 0.20151347643477097
MAP @50: 0.23116764427406028
MAP @100: 0.24852335337180703
MAP @200: 0.25536605265642776
MAP @300: 0.2590835149853847
MAP: 0.29953991831421733
RECALL@10: 0.4565406865612457
RECALL@50: 0.6898932383191283
RECALL@100: 0.7610401283514673
Epoach: 110 | avg loss: 0.11023672 | max loss: 0.16211236 | min loss: 0.0514271

Epoach: 158 | avg loss: 0.098402716 | max loss: 0.17036521 | min loss: 0.059220655922065 | time: 6.708971023559572835
Epoach: 159 | avg loss: 0.10253416 | max loss: 0.11882699 | min loss: 0.09040440690404406 | time: 6.54732060432434115
Step: 7 | loss: 0.08078201 | current max loss: 0.13074794 | current min loss: 0.07708773 | time: 6.7263538837432865
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14608085317460315
MAP @25: 0.19757751430977302
MAP @50: 0.2292144385375543
MAP @100: 0.2465068524110678
MAP @200: 0.25326074788863856
MAP @300: 0.257039216193323
MAP: 0.3046795766513028
RECALL@10: 0.4657106346991215
RECALL@50: 0.6915287366339513
RECALL@100: 0.7630047044888375
Epoach: 160 | avg loss: 0.09582196 | max loss: 0.13074794 | min loss: 0.07708773
Epoach: 161 | avg loss: 0.08770729 | max loss: 0.10745096 | min loss: 0.05163588451635884 | time: 6.5934236049652143
Epoach: 162 | avg loss: 0.09907551 | max loss: 0.13771719 | min loss: 0.06011619860116198 | time: 6.7035443782806472
Ep

In [25]:
#max_bio_map_test = 0.155

test_query_results = validate_test_data(test_articles_collection)
bio_map_test, _, _ = validation_score(test_query_results)

if bio_map_test >= max_bio_map_test:
    max_bio_map_test = bio_map_test

    deepRank_model.save_weights("deep_rank_v5_weights.h5")


TEST set, len  542
MAP @10 bioASQ: 0.16037383588121584
MAP @25: 0.23260139423996665
MAP @50: 0.27425804306276125
MAP @100: 0.2948849950455449
MAP @200: 0.3043438557295708
MAP @300: 0.30763173528831983
MAP: 0.2765173823526237
RECALL@10: 0.40202250614540214
RECALL@50: 0.6916148570996985
RECALL@100: 0.785004485743941


In [24]:

gen = TrainDataGenerator(train_articles_collection, tk, 256)

gen_iter = iter(gen)

#loss = []

for i,line in enumerate(loss):
    
    print("Epoach:",i,"| avg loss:",np.mean(loss[i]),"| max loss:",np.max(loss[i]),"| min loss:",np.min(loss[i]))

import time

#max_bio_map_val = 0.166
#max_bio_map_test = 0

for epoach in range(200,600):
    loss_per_epoach = []
    for step in range(len(gen)):
        X = next(gen_iter)
        
        start = time.time()
        loss_per_epoach.append(deepRank_model.train_on_batch(X))
        print("Step:",step,"| loss:",loss_per_epoach[-1],"| current max loss:",np.max(loss_per_epoach),"| current min loss:",np.min(loss_per_epoach),"| time:",time.time()-start,end="\r")
        

    
    
    if epoach%10==0:
        print("")
        validate_query_results = validate_test_data(validation_articles_collection)
        print("")
        bio_map_val, _, _ = validation_score(validate_query_results)
        if bio_map_val >= max_bio_map_val:
            max_bio_map_val = bio_map_val
            print("")
            print("Run for the test set")
            test_query_results = validate_test_data(test_articles_collection)
            bio_map_test, _, _ = validation_score(test_query_results)

            if bio_map_test >= max_bio_map_test:
                max_bio_map_test = bio_map_test
                
                deepRank_model.save_weights("deep_rank_v5_weights.h5")
                #deepRank_model.save("deep_rank_model.h5")
                
    loss.append(loss_per_epoach)
    print("",end="\r")#clear the line
    print("Epoach:",epoach,"| avg loss:",np.mean(loss[-1]),"| max loss:",np.max(loss[-1]),"| min loss:",np.min(loss[-1]))
#deepRank_model.fit_generator(gen_iter, steps_per_epoch=len(gen), verbose=1, epochs=27)

Epoach: 0 | avg loss: 0.8321499 | max loss: 0.98413956 | min loss: 0.66884595
Epoach: 1 | avg loss: 0.40768892 | max loss: 0.53801274 | min loss: 0.30404148
Epoach: 2 | avg loss: 0.27792883 | max loss: 0.30427703 | min loss: 0.23619118
Epoach: 3 | avg loss: 0.24360353 | max loss: 0.27660736 | min loss: 0.20170411
Epoach: 4 | avg loss: 0.22092175 | max loss: 0.2607345 | min loss: 0.18141511
Epoach: 5 | avg loss: 0.19556233 | max loss: 0.2728382 | min loss: 0.14797907
Epoach: 6 | avg loss: 0.16901559 | max loss: 0.19523887 | min loss: 0.1300078
Epoach: 7 | avg loss: 0.16823205 | max loss: 0.21954162 | min loss: 0.13249207
Epoach: 8 | avg loss: 0.13484278 | max loss: 0.19471215 | min loss: 0.098840386
Epoach: 9 | avg loss: 0.1669903 | max loss: 0.19086082 | min loss: 0.13702452
Epoach: 10 | avg loss: 0.14433524 | max loss: 0.19414073 | min loss: 0.10807654
Epoach: 11 | avg loss: 0.13415176 | max loss: 0.15312915 | min loss: 0.11469062
Epoach: 12 | avg loss: 0.14275536 | max loss: 0.186520

Step: 7 | loss: 0.3153788 | current max loss: 0.3153788 | current min loss: 0.15015712 | time: 6.561886548995972325
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.1471190476190476
MAP @25: 0.2031237373058004
MAP @50: 0.23465863649205637
MAP @100: 0.24982584614790002
MAP @200: 0.25785806160688385
MAP @300: 0.26130441828096757
MAP: 0.3026465603204521
RECALL@10: 0.4656430512104854
RECALL@50: 0.6990161206451355
RECALL@100: 0.7652277815851732
Epoach: 200 | avg loss: 0.22828594 | max loss: 0.3153788 | min loss: 0.15015712
Epoach: 201 | avg loss: 0.21347909 | max loss: 0.3063248 | min loss: 0.1635835816358358 | time: 6.5157291889190672
Epoach: 202 | avg loss: 0.21698585 | max loss: 0.25497717 | min loss: 0.17182371718237 | time: 6.6319456100463875
Epoach: 203 | avg loss: 0.22261891 | max loss: 0.30146992 | min loss: 0.1699086116990861 | time: 6.453084230422974
Epoach: 204 | avg loss: 0.19445388 | max loss: 0.26273087 | min loss: 0.135962013596201 | time: 6.5555143356323241
Epoach: 205 

Epoach: 251 | avg loss: 0.19925462 | max loss: 0.24205582 | min loss: 0.165934626593462 | time: 6.68982553482055785
Epoach: 252 | avg loss: 0.22925942 | max loss: 0.26912072 | min loss: 0.1759893417598934 | time: 6.592452287673951
Epoach: 253 | avg loss: 0.20746121 | max loss: 0.28320408 | min loss: 0.1407528514075285 | time: 6.634201765060425
Epoach: 254 | avg loss: 0.22232227 | max loss: 0.27130094 | min loss: 0.1653312716533127 | time: 6.5558781623840335
Epoach: 255 | avg loss: 0.21640968 | max loss: 0.29265815 | min loss: 0.1729493317294933 | time: 6.6561615467071534
Epoach: 256 | avg loss: 0.25080276 | max loss: 0.29607013 | min loss: 0.1782515517825155 | time: 6.4931960105896686
Epoach: 257 | avg loss: 0.22579378 | max loss: 0.28147253 | min loss: 0.192418049241804 | time: 6.61109137535095225
Epoach: 258 | avg loss: 0.23322314 | max loss: 0.3164362 | min loss: 0.1382735.1382735 | time: 6.55853128433227557
Epoach: 259 | avg loss: 0.21859571 | max loss: 0.26080313 | min loss: 0.152

Epoach: 304 | avg loss: 0.21022655 | max loss: 0.2994533 | min loss: 0.1346981513469815 | time: 6.5646626949310385
Epoach: 305 | avg loss: 0.1888673 | max loss: 0.24476983 | min loss: 0.12750065.12750065 | time: 6.6005206108093265
Epoach: 306 | avg loss: 0.21121311 | max loss: 0.30465582 | min loss: 0.1181858858185885 | time: 6.62452650070190426
Epoach: 307 | avg loss: 0.20638251 | max loss: 0.24155006 | min loss: 0.1766238.1766238 | time: 6.52595853805542665
Epoach: 308 | avg loss: 0.23054263 | max loss: 0.31069753 | min loss: 0.1692671816926718 | time: 6.6174378395080574
Epoach: 309 | avg loss: 0.19951415 | max loss: 0.28349724 | min loss: 0.1550144615501446 | time: 6.6312055587768555
Step: 7 | loss: 0.19441085 | current max loss: 0.24759448 | current min loss: 0.13697295 | time: 6.628967523574829
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14719593253968255
MAP @25: 0.20604292323944323
MAP @50: 0.23621282448682512
MAP @100: 0.25287734529159656
MAP @200: 0.2595332671691699


Epoach: 356 | avg loss: 0.20711686 | max loss: 0.3110124 | min loss: 0.1375435313754353 | time: 6.547342061996465
Epoach: 357 | avg loss: 0.21570858 | max loss: 0.31618002 | min loss: 0.1590669815906698 | time: 6.5439414978027346
Epoach: 358 | avg loss: 0.22131073 | max loss: 0.27506414 | min loss: 0.1514928615149286 | time: 6.5889191627502445
Epoach: 359 | avg loss: 0.19664246 | max loss: 0.24822205 | min loss: 0.1568109515681095 | time: 6.551016807556152
Step: 7 | loss: 0.21885537 | current max loss: 0.27002645 | current min loss: 0.18861534 | time: 6.5196812152862555
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14855059523809516
MAP @25: 0.2078707547135153
MAP @50: 0.23841817343767344
MAP @100: 0.2546278833476608
MAP @200: 0.26118752850367855
MAP @300: 0.26475224919896123
MAP: 0.29506297857785946
RECALL@10: 0.4391246737546079
RECALL@50: 0.7062654215176952
RECALL@100: 0.7714248561036487
Epoach: 360 | avg loss: 0.22263142 | max loss: 0.27002645 | min loss: 0.18861534
Epoach: 

Epoach: 409 | avg loss: 0.21417177 | max loss: 0.24598931 | min loss: 0.171760417176041 | time: 6.53281998634338425
Step: 7 | loss: 0.23764661 | current max loss: 0.2652886 | current min loss: 0.14692824 | time: 6.6394503116607676
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14582688492063484
MAP @25: 0.20742408814903893
MAP @50: 0.23619757028813707
MAP @100: 0.2531901555520491
MAP @200: 0.2597854035656408
MAP @300: 0.26351419164536205
MAP: 0.2937195725625784
RECALL@10: 0.4415398974198316
RECALL@50: 0.705537781711108
RECALL@100: 0.7774110354944332
Epoach: 410 | avg loss: 0.20994022 | max loss: 0.2652886 | min loss: 0.14692824
Epoach: 411 | avg loss: 0.19099072 | max loss: 0.27596888 | min loss: 0.1266282512662825 | time: 6.4717223644256595
Epoach: 412 | avg loss: 0.22992656 | max loss: 0.27804518 | min loss: 0.1465876514658765 | time: 6.560704469680786
Epoach: 413 | avg loss: 0.20747258 | max loss: 0.287146 | min loss: 0.1331137413311374 | time: 6.73568153381347756
Epoach: 414

Epoach: 461 | avg loss: 0.20519497 | max loss: 0.27486202 | min loss: 0.1364640613646406 | time: 6.599867820739746
Epoach: 462 | avg loss: 0.1855808 | max loss: 0.2347573 | min loss: 0.13200851.13200851 | time: 6.6398158073425295
Epoach: 463 | avg loss: 0.20338343 | max loss: 0.32490495 | min loss: 0.136411713641171 | time: 6.7280066013336187
Epoach: 464 | avg loss: 0.20392218 | max loss: 0.25717628 | min loss: 0.1566443615664436 | time: 6.587696313858032
Epoach: 465 | avg loss: 0.20888144 | max loss: 0.27753228 | min loss: 0.1286703312867033 | time: 6.5931847095489515
Epoach: 466 | avg loss: 0.22082673 | max loss: 0.2897444 | min loss: 0.165764756576475 | time: 6.69291162490844718
Epoach: 467 | avg loss: 0.20868447 | max loss: 0.26130518 | min loss: 0.1445599514455995 | time: 6.5876038074493415
Epoach: 468 | avg loss: 0.18451095 | max loss: 0.21968827 | min loss: 0.1515901215159012 | time: 6.655316352844238
Epoach: 469 | avg loss: 0.23957302 | max loss: 0.29080448 | min loss: 0.206139

Epoach: 514 | avg loss: 0.20829582 | max loss: 0.27369833 | min loss: 0.1635375516353755 | time: 6.5458853244781494
Epoach: 515 | avg loss: 0.22117099 | max loss: 0.32017148 | min loss: 0.1841390.184139 | time: 6.52384948730468757
Epoach: 516 | avg loss: 0.18846717 | max loss: 0.2546912 | min loss: 0.1505329815053298 | time: 6.5931687355041596
Epoach: 517 | avg loss: 0.21212713 | max loss: 0.28025013 | min loss: 0.1844366518443665 | time: 6.7001490592956545
Epoach: 518 | avg loss: 0.20639388 | max loss: 0.23499621 | min loss: 0.1276117112761171 | time: 6.6472306251525885
Epoach: 519 | avg loss: 0.18985757 | max loss: 0.22872236 | min loss: 0.1354240313542403 | time: 6.5481302738189755
Step: 7 | loss: 0.1749748 | current max loss: 0.2451573 | current min loss: 0.13844758 | time: 6.66505622863769575
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14654315476190471
MAP @25: 0.20919477491009877
MAP @50: 0.23818001262120508
MAP @100: 0.2544878821658683
MAP @200: 0.2619737118002656
MAP

Epoach: 567 | avg loss: 0.22401659 | max loss: 0.28705812 | min loss: 0.1817863418178634 | time: 6.5294883251190186
Epoach: 568 | avg loss: 0.21090639 | max loss: 0.28363767 | min loss: 0.1389360713893607 | time: 6.6138694286346436
Epoach: 569 | avg loss: 0.21528359 | max loss: 0.2933114 | min loss: 0.1482379614823796 | time: 6.6710588932037355
Step: 7 | loss: 0.17595659 | current max loss: 0.25721458 | current min loss: 0.15529639 | time: 6.576533555984497
Predict query: 81
TEST set, len  80
MAP @10 bioASQ: 0.14692906746031745
MAP @25: 0.20794546293172292
MAP @50: 0.23510993660859897
MAP @100: 0.2521255375698135
MAP @200: 0.258883082329317
MAP @300: 0.26259068276033737
MAP: 0.2975388849721606
RECALL@10: 0.44441307808906483
RECALL@50: 0.7052321482410493
RECALL@100: 0.7998045468033858
Epoach: 570 | avg loss: 0.18849008 | max loss: 0.25721458 | min loss: 0.15529639
Epoach: 571 | avg loss: 0.21394023 | max loss: 0.27822956 | min loss: 0.1498499814984998 | time: 6.543672084808353
Epoach: 5

In [29]:
path_save = "/backup/results/deep_rank"
path_save = os.path.join(path_save, "deep_rank_v2_17_1_test_data.p")

with open(path_save, "wb") as f:
    pickle.dump(test_query_results,f)

In [None]:
## PREPARE SUBMISSION


test_bioASQ_results_results = list(map(lambda k:{"id":k[0],"documents":list(map(lambda x:"http://www.ncbi.nlm.nih.gov/pubmed/"+str(x[0]), k[1]["result"]))[:10]}, test_bioASQ_results.items()))
_temp = []

for query in bioASQ_data:
    _jump = False
    for r_query in test_bioASQ_results_results:
        if query["id"] == r_query["id"]:
            _jump = True
    
    #no match so add
    if not _jump:
        _temp.append({"id":query["id"],"documents":[]})

test_bioASQ_results_results.extend(_temp)

print(len(test_bioASQ_results_results))
assert len(test_bioASQ_results_results) == 100
a = {"questions": test_bioASQ_results_results}
with open("5b_phaseA_01.json","w") as f:
    json.dump(a,f)
    

test_bioASQ_results_results[0]


In [26]:
query_to_test_index = 0

data_generator = TestDataGenerator(test_articles_collection, tk)
data_generator = iter(data_generator)
for _ in range(query_to_test_index+1):
    X = next(data_generator)


In [65]:
re_ranking = document_score_model.predict(X)

In [66]:
re_ranking = map(lambda x:x[0],re_ranking.tolist())

In [67]:
bm25_results = test_articles_collection["bioasq_data"][query_to_test_index]["documents"]
positive_docs = test_articles_collection["bioasq_data"][query_to_test_index]["positive_pmid"]

In [68]:
re_ranking_pmid = list(zip(bm25_results,re_ranking))

In [69]:
re_ranking_pmid.sort(key=lambda x:-x[1])

In [70]:
re_ranking_pmid[:10]

[('24794627', 5.3322014808654785),
 ('30251567', 5.313037872314453),
 ('28796422', 5.227417945861816),
 ('30114722', 5.093368053436279),
 ('29947303', 5.0901007652282715),
 ('30697454', 5.0804266929626465),
 ('30569414', 4.911670207977295),
 ('28901190', 4.814671039581299),
 ('24577791', 4.803395748138428),
 ('26907255', 4.67585563659668)]

In [71]:
test_articles_collection["bioasq_data"][query_to_test_index]["query"]

array([13502,    43,   478,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0], dtype=int32)

In [72]:
positive_docs

['24554704',
 '24784583',
 '24577791',
 '23197849',
 '24035588',
 '21060967',
 '25479728',
 '21755313',
 '24469711',
 '22512788',
 '24911883',
 '24794627',
 '21464439',
 '25059784']

In [73]:
positive_docs_ranked = []
for i,result in enumerate(re_ranking_pmid):
    if result[0] in set(positive_docs):
        positive_docs_ranked.append((i,result[0],result[1]))
        
true_ranked = []
for i,pmid in enumerate(bm25_results):
    if pmid in set(positive_docs):
        true_ranked.append((i,pmid))

print(positive_docs_ranked)
print(true_ranked)

[(0, '24794627', 5.3322014808654785), (8, '24577791', 4.803395748138428), (11, '24554704', 4.6346845626831055), (12, '23197849', 4.612618446350098), (16, '21060967', 4.446898460388184), (18, '24469711', 4.439567565917969), (19, '21755313', 4.431821823120117), (25, '22512788', 4.27816104888916), (34, '21464439', 4.063064098358154), (35, '25059784', 4.056085586547852), (38, '24784583', 4.003556728363037), (43, '25479728', 3.937878131866455), (50, '24035588', 3.7958528995513916), (343, '24911883', 2.838761806488037)]
[(3, '23197849'), (4, '21755313'), (6, '25479728'), (7, '24784583'), (8, '24577791'), (9, '24035588'), (15, '22512788'), (17, '24911883'), (19, '21464439'), (22, '24794627'), (126, '25059784'), (134, '24554704'), (466, '21060967'), (2793, '24469711')]


# Test with train set, check overfit

In [38]:
query_tokens = np.array([X[0][0]])
snippet_list = np.array([X[1][0]])

query_tokens = X[0][:2]
snippet_list = X[1][:2]

model_input = [query_tokens, snippet_list]

In [39]:

print(query_tokens.shape)
print(snippet_list.shape)

(2, 15)
(2, 15, 3, 15)


In [40]:
matrix = input_model.predict(model_input)
np.array(matrix).shape

(2, 15, 3, 15, 15, 1)

In [21]:
matrix[0][4][0].shape

(15, 15, 1)

In [42]:
np.squeeze(matrix[0][7][0])

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.

In [148]:
query_tokens[0]

array([  165, 13502,    26,    61,     8,     1,    43,     2,   478],
      dtype=int32)

In [168]:
snippet_list[0][1][0]

array([  363,     5,  2386,    97,  7598,   774,    32, 13502,    18,
          15,     7, 11695,   117,    17,     0], dtype=int32)

In [169]:
emb_dict[13502]

array([ 0.05703647,  0.19594026,  0.03365219,  0.15514491,  0.00540348,
       -0.02335026, -0.06095085,  0.0226689 , -0.05668721,  0.01571985,
       -0.09896637,  0.13836679,  0.02710932,  0.06420047, -0.03692323,
        0.03899341,  0.00553868, -0.08639584, -0.05358738, -0.02609682,
        0.06495432, -0.00129713, -0.01882407, -0.10850747, -0.02421302,
        0.05556208,  0.00291283, -0.04882976,  0.01770345,  0.0035051 ,
        0.07192209, -0.00432884, -0.15161929, -0.07024549, -0.04793473,
        0.01823143,  0.10337584, -0.04076301,  0.01026187,  0.12004871,
        0.03939956, -0.03548966, -0.10689223, -0.16337523,  0.10883316,
        0.01135785,  0.03041399,  0.06011688, -0.09919181,  0.01741308,
       -0.04328503, -0.00256405, -0.11370766,  0.0522779 ,  0.0702537 ,
        0.01021139,  0.06773005,  0.01114117, -0.05878652,  0.0720681 ,
        0.05551391,  0.08731035,  0.07339004,  0.0031227 ,  0.10792159,
        0.12050318, -0.05851915, -0.08350374, -0.03341928,  0.12

In [147]:
data_generator = TrainDataGenerator(train_articles_collection, tk, 256)
data_generator = iter(data_generator)
for i in range(9):
    print(i,end="\r")
    X,Y = next(data_generator)

8

In [151]:
list(map(lambda x: len(list(filter(lambda y:y!=0,x))),X[0]))

[8,
 8,
 8,
 8,
 8,
 8,
 8,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 12,
 12,
 12,
 12,
 12,
 4,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 10,
 7,
 7,
 7,
 8,
 8,
 8,
 8,
 8,
 8,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 4,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 4,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 9,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 3,
 3,
 3,
 3,
 3

In [157]:
X[0][26]

array([ 988,  988,  279, 1208,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0], dtype=int32)

In [161]:
X[3][26]

array([[[  1075,      2,    986,      4,      6,    248,    988,   2250,
             44,    713,   6300,   6300,      0],
        [  3974,      2,    986,      4,      6,    248,    988,   2250,
             18,    713,   6300,   6300,      0],
        [    17,     17,  46298,  59753,   9143,   2106,    988,   2250,
             16,      6,    200,    174,      0]],

       [[  1075,      2,    986,      4,      6,    248,    988,   2250,
             44,    713,   6300,   6300,      0],
        [  3974,      2,    986,      4,      6,    248,    988,   2250,
             18,    713,   6300,   6300,      0],
        [    17,     17,  46298,  59753,   9143,   2106,    988,   2250,
             16,      6,    200,    174,      0]],

       [[     2,    248,    986,      4,      1,  10169,    279,      0,
              0,      0,      0,      0,      0],
        [     0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0],
        [   

In [164]:
pos,neg = deepRank_model.predict(X)

In [256]:
query_to_test_index = 1

data_generator = TestDataGenerator(train_articles_collection, tk)
data_generator = iter(data_generator)
for _ in range(query_to_test_index+1):
    X = next(data_generator)


KeyError: 'bioasq_data'

In [257]:
re_ranking = document_score_model.predict(X)

bm25_results = train_articles_collection["test_data"][query_to_test_index]["documents"]
positive_docs = train_articles_collection["test_data"][query_to_test_index]["positive_pmid"]

re_ranking_pmid = list(zip(bm25_results,re_ranking.tolist()))

re_ranking_pmid.sort(key=lambda x:-x[1])

re_ranking_pmid[:10]

[('26671317', 7.931817054748535),
 ('20975159', 7.902041435241699),
 ('20650709', 7.8478264808654785),
 ('19805301', 7.842199802398682),
 ('21731768', 7.818233013153076),
 ('24681619', 7.759010314941406),
 ('26631348', 7.714381217956543),
 ('22196114', 7.696432590484619),
 ('23817568', 7.648404598236084),
 ('26410599', 7.3919854164123535)]

In [216]:
print(tk.sequences_to_texts([train_articles_collection["test_data"][query_to_test_index]["query"]]))
positive_docs

['kind enzyme encoded proto oncogene abl1']


['21435002',
 '20841568',
 '9500553',
 '24012954',
 '18796434',
 '23842646',
 '18528425']

In [217]:
positive_docs_ranked = []
for i,result in enumerate(re_ranking_pmid):
    if result[0] in set(positive_docs):
        positive_docs_ranked.append((i,result[0],result[1]))
        
true_ranked = []
for i,pmid in enumerate(bm25_results):
    if pmid in set(positive_docs):
        true_ranked.append((i,pmid))

print(positive_docs_ranked)
print(true_ranked)

[(5, '24012954', 5.5485920906066895), (29, '9500553', 4.8889594078063965), (87, '21435002', 4.592401504516602), (112, '23842646', 4.525805473327637), (155, '18796434', 4.400295257568359), (342, '18528425', 4.139955043792725), (2244, '20841568', 3.1609909534454346)]
[(1, '9500553'), (29, '21435002'), (187, '24012954'), (309, '18528425'), (333, '23842646'), (610, '20841568'), (2354, '18796434')]


In [None]:
bm25_results