In [1]:
import os
# Work from the project root so the relative import of `pubmed_data` and the
# relative 'pubmed_data' path used below resolve.
# NOTE(review): hard-coded absolute path to one user's checkout; os.chdir raises
# if the directory does not exist, so the notebook only runs on that machine.
os.chdir("/home/tiagoalmeida/bioASQ-taskb/")

import sys
import numpy as np
import pickle
import gc
import json

from pubmed_data import pubmed_helper as ph

# Put the local `pubmed_data` directory (resolved against the current working
# directory) on the import path. It is appended at most once, so re-running the
# cell does not add duplicate entries.
module_path = os.path.abspath('pubmed_data')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
from tensorflow.keras.layers import Embedding, Input, LSTM, Dot, Activation, Concatenate
from tensorflow.keras import Model
from tensorflow.keras.callbacks import ModelCheckpoint, LambdaCallback
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow.keras.backend as K

def dssm_projectiom_model(activation='relu', units=200):
    """Create a factory for projection sub-models that share one 3-layer LSTM stack.

    The three LSTM layers are instantiated once, in this call. Every model
    produced by the returned ``build_model`` reuses those same layer objects,
    so all of them share the same recurrent weights.

    Args:
        activation: activation passed to each of the three LSTM layers.
        units: output dimensionality of each LSTM layer. Defaults to 200, the
            value that used to be hard-coded.

    Returns:
        ``build_model(inputs, embedding_layer, model_name)``: a callable that
        wires ``inputs -> embedding_layer -> lstm_1 -> lstm_2 -> lstm_3`` and
        returns the result wrapped in a ``Model`` named ``model_name``.
    """
    # The first two layers return the full sequence so the following LSTM can
    # consume it; the last one returns only its final output.
    lstm_1 = LSTM(units, activation=activation, return_sequences=True)
    lstm_2 = LSTM(units, activation=activation, return_sequences=True)
    lstm_3 = LSTM(units, activation=activation)

    def build_model(inputs, embedding_layer, model_name):
        """Wire ``inputs`` through the embedding and the shared LSTM stack."""
        x = embedding_layer(inputs)
        for shared_lstm in (lstm_1, lstm_2, lstm_3):
            x = shared_lstm(x)

        return Model(inputs=[inputs], outputs=[x], name=model_name)

    return build_model


In [3]:
# Tokenization variant. It is passed to the `ph` loaders below and selects which
# tokenizer, word embeddings and tokenized article collection are read.
MODE = "regex_full_tokens"

In [21]:
# Load the tokenizer and the word embeddings stored for MODE.
# NOTE(review): the logged file names end in `.p`, presumably pickles — only
# load files from a trusted source.
tk = ph.load_tokenizer(MODE)
embedding_dict = ph.load_embeddings(MODE)


Load regex_full_tokens_tokenizer.p
Load regex_full_tokens_word_embedding.p


In [42]:
# Number of distinct words counted by the tokenizer; used as the row count of
# the embedding matrix.
# NOTE(review): if `tk` numbers words starting at 1 (as the Keras Tokenizer
# does), the largest token id equals VOC and would fall outside a matrix with
# VOC rows — confirm whether this should be VOC + 1.
VOC = len(tk.word_counts)

# Every row starts at a small constant instead of zero (presumably so ids
# without a pretrained vector do not get an all-zero embedding).
# np.full builds the matrix in one allocation; the previous `np.zeros(...) + delta`
# produced the same float64 values but needed a second full-size temporary.
delta = 0.0001
embedding_matrix = np.full((VOC, embedding_dict[1].shape[0]), delta)

# Copy the pretrained vectors row by row. Row 0 is never written and keeps the
# `delta` fill.
# NOTE(review): range(1, len(embedding_dict)) stops at len - 1; if
# `embedding_dict` is keyed 1..N the last vector is never copied — confirm.
for i in range(1,len(embedding_dict)):
    embedding_matrix[i] = embedding_dict[i]



In [53]:
# Number of negative documents scored against each (query, positive document)
# pair. It fixes how many negative-document inputs the model has and is the
# default `neg_samples` of DSSM_Train_Generator.
NUM_NEG_EXAMPLES = 4

# Reset the Keras backend state so re-running this cell rebuilds the graph
# from scratch.
K.clear_session()

# Frozen embedding lookup initialised from the matrix built above.
# NOTE(review): `mask_zero` is not set, so the zero post-padding produced by the
# training generator is fed through the LSTMs like a regular token — confirm
# this is intended.
embedding_layer = Embedding( input_dim = embedding_matrix.shape[0],
           output_dim = embedding_matrix.shape[1], 
           weights=[embedding_matrix], 
           trainable=False,
           name="embedding_layer")
# Follow the architecture of the paper.

# Top-level inputs: one query, one positive document and NUM_NEG_EXAMPLES
# negative documents, each a variable-length sequence.
# (The original comment said the input is the result of a hash-trick layer; in
# this notebook the inputs go through `embedding_layer` inside the projection
# model.)
query = Input(shape = (None,), name = "dssm_query_input")
pos_doc = Input(shape = (None,), name = "dssm_pos_doc_input")
neg_docs = [Input(shape = (None,), name = ("dssm_neg_doc_input_"+str(i))) for i in range(NUM_NEG_EXAMPLES)]



# Siamese sub-model: takes a (query, document) pair and outputs their similarity.
# It has two inputs, the query and the document.
q_input = Input(shape = (None,), name= "q_input")
doc_input = Input(shape = (None,), name= "doc_input")


projection_input = Input(shape = (None,), name= "projection_input")

# One projection model (embedding + 3 LSTMs); its weights are shared by every
# place it is applied below.
dssm_model_builder = dssm_projectiom_model()
projection_model = dssm_model_builder(projection_input,embedding_layer,"projection_model")



# The query and the document are projected by the very same model object.
query_projection_model = projection_model
doc_projection_model = projection_model


query_projection = query_projection_model(q_input)
doc_projection = doc_projection_model(doc_input)
# Similarity between the query and the document: with normalize=True the dot
# product is taken over L2-normalised vectors, i.e. cosine similarity.
q_doc_sim = Dot(axes=1,normalize=True)([query_projection,doc_projection])

sub_model = Model(inputs=[q_input, doc_input], outputs=[q_doc_sim], name="siamese_model")
sub_model.summary()

# Softmax approximation over 1 positive and N negative documents: score the
# query against each document with the shared siamese sub-model.
q_doc_pos_output = sub_model([query,pos_doc])
q_doc_neg_output = [sub_model([query,neg_doc]) for neg_doc in neg_docs]

# Column 0 holds the positive document's similarity, followed by the negatives.
concat = Concatenate(axis=1)([q_doc_pos_output]+q_doc_neg_output)

# TODO: the smoothing factor that should scale the similarities before the
# softmax is still missing.
prob = Activation("softmax")(concat)

dssm_model = Model(inputs=[query,pos_doc]+neg_docs,outputs=prob)

dssm_model.summary()

# Trying plain SGD as the optimizer; the target is a one-hot vector over the
# (1 + NUM_NEG_EXAMPLES) candidate documents.
dssm_model.compile(optimizer='sgd',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
q_input (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
doc_input (InputLayer)          (None, None)         0                                            
__________________________________________________________________________________________________
projection_model (Model)        (None, 200)          859321200   q_input[0][0]                    
                                                                 doc_input[0][0]                  
__________________________________________________________________________________________________
dot (Dot)                       (None, 1)            0           projection_model[1][0]           
          

In [4]:
# Read the whole tokenized article collection into one in-memory list,
# one chunk (as yielded by the generator) at a time.
articles_gen = ph.create_tokenized_pubmed_collection_generator(mode=MODE)

tokenized_articles = []
for articles in articles_gen():
    tokenized_articles += articles

# Drop the reference to the last chunk, then force a garbage collection.
del articles
gc.collect()

Open /backup/pubmed_archive_tokenized/regex_full_tokens_title_abs.tar.gz
Creating generator
Open the file: regex_full_tokens_file_000_title_abs_pubmed.p
Returning: 3690895 articles
Force garbage collector 0
Open the file: regex_full_tokens_file_001_title_abs_pubmed.p
Returning: 3643138 articles
Force garbage collector 0
Open the file: regex_full_tokens_file_002_title_abs_pubmed.p
Returning: 3790281 articles
Force garbage collector 0
Open the file: regex_full_tokens_file_003_title_abs_pubmed.p
Returning: 3838006 articles
Force garbage collector 0
Open the file: regex_full_tokens_file_004_title_abs_pubmed.p
Returning: 3862035 articles
Force garbage collector 0


In [23]:
# Load the mapping used below to translate the document identifiers listed in
# the BioASQ data (presumably PubMed ids) into indexes of the article collection.
pmid_index_map = ph.pmid_index_mapping()

# Load the BioASQ training and test splits.
# The files are opened with `with` so the handles are closed as soon as the
# JSON is parsed (the previous `json.load(open(...))` never closed them).
bioASQ_data_path = "/backup/BioASQ-training7b/"

with open(os.path.join(bioASQ_data_path,"7b_train_split.json")) as train_file:
    bioASQ_data_train = json.load(train_file)

with open(os.path.join(bioASQ_data_path,"7b_test_split.json")) as test_file:
    bioASQ_data_test = json.load(test_file)



Load /backup/saved_models/pmid_index_mapping.p


In [24]:
def convert_format(data):
    """Convert raw BioASQ questions to the format the training generator expects.

    Each question ``x`` becomes
    ``{"query": <token ids of x["body"]>, "documents": [pmid_index_map[d] for d in x["documents"]]}``.
    """
    converted = []
    for question in data:
        query_tokens = tk.texts_to_sequences([question["body"]])[0]
        doc_indexes = [pmid_index_map[doc_id] for doc_id in question["documents"]]
        converted.append({"query": query_tokens, "documents": doc_indexes})
    return converted

# NOTE: both names are overwritten with their converted form, so this cell
# cannot be re-run without reloading the raw splits first (the converted
# entries no longer have a "body" key).
bioASQ_data_train = convert_format(bioASQ_data_train)
bioASQ_data_test = convert_format(bioASQ_data_test)

In [50]:
#Create training generator
#For each query get true positive
#For each true positive par with negative samples

class DSSM_Train_Generator(object):
    """Endless source of DSSM training batches.

    Iterating yields ``(X, Y)`` tuples forever, cycling over ``training_data``:

    * ``X`` is ``[queries, pos_docs, neg_docs_0, ..., neg_docs_{k-1}]`` with
      ``k = neg_samples``; every entry is a zero post-padded 2-D array with one
      row per training sample.
    * ``Y`` is a one-hot array with ``1 + neg_samples`` columns whose first
      column (the positive document) is always 1.

    One training sample is produced for every distinct positive document of
    every query, paired with ``neg_samples`` randomly drawn documents.
    """

    def __init__(self, collection, training_data, query_batch_size=16, neg_samples = NUM_NEG_EXAMPLES):
        """
        collection: list of the tokenized articles, indexed by article index
        training_data: list of dict following {"query": [token ids], "documents": [article index list]}
        query_batch_size: number of accumulated samples from which a batch is emitted
        neg_samples: number of random negative documents drawn per positive document
        """

        self.collection = collection
        self.num_docs = len(collection)
        self.training_data = training_data
        self.neg_samples = neg_samples
        self.query_batch_size = query_batch_size

        # Approximate number of batches in one pass over the data: batches are
        # only closed between queries, so a batch can hold more than
        # `query_batch_size` samples.
        training_samples = sum(len(q["documents"]) for q in training_data)
        self.num_steps = training_samples//query_batch_size

    def _negative_random_documents(self, exclude):
        """Draw `neg_samples` random article indexes, none of which is in `exclude`.

        Plain rejection sampling: the whole draw is repeated until it contains
        no excluded index. This is cheap only because `exclude` is tiny compared
        to the collection; the same index may be drawn more than once.
        """
        while True:
            neg_random_indexs = np.random.randint(0, self.num_docs, (self.neg_samples,))
            if not any(i in exclude for i in neg_random_indexs):
                return neg_random_indexs

    def _empty_batch(self):
        """Return fresh, empty column lists: (queries, pos_docs, neg_docs)."""
        return [], [], [[] for _ in range(self.neg_samples)]

    def _build_batch(self, queries, pos_docs, neg_docs, max_len):
        """Pad every accumulated column to `max_len` and attach the one-hot targets."""
        # pad_sequences takes the plain lists of token lists directly. No
        # intermediate np.array() is built: it was redundant, and recent NumPy
        # versions refuse to build an array from ragged sequences.
        # NOTE(review): queries are padded/truncated to the longest *document*
        # of the batch, exactly as before — confirm that is intended.
        X = [pad_sequences(column, padding="post", maxlen=max_len)
             for column in [queries, pos_docs] + neg_docs]

        # The positive document always sits in column 0 of the model output.
        Y = np.array([[1] + [0]*self.neg_samples] * len(queries))

        return (X, Y)

    def __iter__(self):
        # Without a single positive document no batch could ever fill up and the
        # endless loop below would spin forever; stop immediately instead.
        if not any(len(q["documents"]) for q in self.training_data):
            return

        # Each list represents one data column of the batch being accumulated.
        queries, pos_docs, neg_docs = self._empty_batch()
        max_len_doc_in_batch = 0

        # Loop over the training data forever.
        while True:

            for query_data in self.training_data:

                # Emit the accumulated batch once it is large enough ...
                if len(queries) >= self.query_batch_size:
                    yield self._build_batch(queries, pos_docs, neg_docs, max_len_doc_in_batch)
                    queries, pos_docs, neg_docs = self._empty_batch()
                    max_len_doc_in_batch = 0

                # ... and then always process the current query. Previously this
                # part lived in an `else` branch, so the query being visited when
                # a batch was emitted was silently dropped.
                pos_doc_set = set(query_data["documents"])

                for index_article in pos_doc_set:

                    queries.append(query_data["query"])

                    pos_doc = self.collection[index_article]
                    pos_docs.append(pos_doc)
                    batch_doc_len = [len(pos_doc)]

                    # Negatives must not be any of this query's positive documents.
                    neg_random_indexs = self._negative_random_documents(pos_doc_set)

                    for i in range(self.neg_samples):
                        neg_doc = self.collection[neg_random_indexs[i]]
                        neg_docs[i].append(neg_doc)
                        batch_doc_len.append(len(neg_doc))

                    # Longest document seen so far in this batch.
                    max_len_doc_in_batch = max([max_len_doc_in_batch] + batch_doc_len)

    def __len__(self):
        """Number of steps per epoch (see `num_steps` in `__init__`)."""
        return self.num_steps


In [51]:
# Batch generator over the in-memory article collection and the converted
# BioASQ training split (default batch size and number of negatives).
gen = DSSM_Train_Generator(tokenized_articles, bioASQ_data_train)

In [46]:
# Smoke test: pull a single (X, Y) batch from the generator.
d = next(iter(gen))

{5140771, 8647685, 8378186, 17319307, 6030864, 16603954, 7654035, 8420211, 7655446, 8651132, 17305694, 8141695}
{15793729, 12123889}
{14305434}
{12906624, 15282756, 17014153, 6122281, 12961678, 473877, 1698102, 10869623, 9052246, 6245657}
{16866240, 2519681, 10334051, 16597829, 2290572, 5985871, 11077039, 18183477, 17139414, 8940792, 16874073}


In [54]:
# Train for 2 epochs; the generator is endless, so the epoch length comes from len(gen).
# NOTE(review): `fit_generator` is deprecated in newer TensorFlow releases in
# favour of `fit` — confirm against the installed version before changing.
dssm_model.fit_generator(iter(gen), steps_per_epoch=len(gen), epochs=2)

Epoch 1/2
   9/1510 [..............................] - ETA: 5:51:38 - loss: 1.6086 - acc: 0.5439

KeyboardInterrupt: 