In [1]:
import argparse

from mmnrm.utils import set_random_seed, load_neural_model, load_model, load_sentence_generator, flat_list
from nir.embeddings import FastText, Word2Vec

set_random_seed()

import io
from nir.tokenizers import Regex, BioCleanTokenizer, BioCleanTokenizer2, Regex2
import numpy as np
import math
import os 
import json

import tensorflow as tf
from tensorflow.keras import backend as K

from mmnrm.dataset import TrainCollectionV2, TrainSnippetsCollectionV2, TestCollectionV2, sentence_splitter_builderV2, TrainPairwiseCollection
from mmnrm.modelsv2 import sibm2_wSnippets
from mmnrm.callbacks import TriangularLR, WandBValidationLogger, LearningRateScheduler
from mmnrm.training import PairwiseTraining, pairwise_cross_entropy
from mmnrm.utils import merge_dicts
from nltk.tokenize.punkt import PunktSentenceTokenizer

import nltk



def build_data_generators(tokenizer, queries_sw=None, docs_sw=None):
    
    punkt_sent_tokenizer = PunktSentenceTokenizer().span_tokenize
    
    def sent_tokenize(document):
        return [ document[start:end] for start,end in punkt_sent_tokenizer(document) ]
    
    def maybe_tokenize(documents):
        if "tokens" not in documents:
            split = sent_tokenize(documents["text"])
            documents["tokens"] = tokenizer.texts_to_sequences(split)
            if docs_sw is not None:
                for tokenized_sentence in documents["tokens"]:
                    tokenized_sentence = [token for token in tokenized_sentence if token not in docs_sw]
    
    def maybe_tokenize_train(documents):
        if "tokens" not in documents:
            #split = nltk.sent_tokenize(documents["text"])
            documents["tokens"] = tokenizer.texts_to_sequences(map(lambda x: x["text"],documents["snippets"]))
            if docs_sw is not None:
                for tokenized_sentence in documents["tokens"]:
                    tokenized_sentence = [token for token in tokenized_sentence if token not in docs_sw]
    
    def train_generator(data_generator):
        while True:

            # get the batch triplet
            query, pos_docs, pos_label, neg_docs = next(data_generator)

            # tokenization, this can be cached for efficientcy porpuses NOTE!!
            tokenized_query = tokenizer.texts_to_sequences(query)

            if queries_sw is not None:
                for tokens in tokenized_query:
                    tokenized_query = [token for token in tokens if token not in queries_sw] 
            
            saveReturn = True
            
            for batch_index in range(len(pos_docs)):
                
                # tokenizer with cache in [batch_index][tokens]
                maybe_tokenize_train(pos_docs[batch_index])
                
                # assertion
                if all([ len(sentence)==0  for sentence in pos_docs[batch_index]["tokens"]]):
                    saveReturn = False
                    break # try a new resampling, NOTE THIS IS A EASY FIX PLS REDO THIS!!!!!!!
                          # for obvious reasons
                
                maybe_tokenize_train(neg_docs[batch_index])
                
            if saveReturn: # this is not true, if the batch is rejected
                yield tokenized_query, pos_docs, pos_label, neg_docs

    def test_generator(data_generator):
        for _id, query, docs in data_generator:
            tokenized_queries = []
            for i in range(len(_id)):
                # tokenization
                tokenized_query = tokenizer.texts_to_sequences([query[i]])[0]

                if queries_sw is not None:
                    tokenized_query = [token for token in tokenized_query if token not in queries_sw] 
                
                tokenized_queries.append(tokenized_query)
                    
        
                for doc in docs[i]:
                    maybe_tokenize(doc)
                                                 
            yield _id, tokenized_queries, docs
            
    return train_generator, test_generator

def model_train_generator_for_model(model):

    if "model" in model.savable_config:
        cfg = model.savable_config["model"]
    
    train_gen, test_gen = build_data_generators(model.tokenizer)
    
    pad_tokens = lambda x, max_len, dtype='int32': tf.keras.preprocessing.sequence.pad_sequences(x, 
                                                                                           maxlen=max_len,
                                                                                           dtype=dtype, 
                                                                                           padding='post', 
                                                                                           truncating='post', 
                                                                                           value=0)

    pad_sentences = lambda x, max_lim, dtype='int32': x[:max_lim] + [[]]*(max_lim-len(x))
    
    pad_labels = lambda x, max_lim, dtype='int32': x[:max_lim] + [0]*(max_lim-len(x))
    
    def maybe_padding(document, labels = None):
        if isinstance(document["tokens"], list):
            #overflow prevention
            bounded_doc_passage = min(cfg["max_passages"],len(document["tokens"]))
            document["sentences_mask"] = [True] * bounded_doc_passage + [False] * (cfg["max_passages"]-bounded_doc_passage)
            document["tokens"] = pad_tokens(pad_sentences(document["tokens"], cfg["max_passages"]), cfg["max_p_terms"])
            if labels is not None:
                document["sentences_labels"] = pad_labels(labels, cfg["max_passages"])
                
    def train_generator(data_generator):
 
        for query, pos_docs, pos_label, neg_docs in train_gen(data_generator):
            
            query = pad_tokens(query, cfg["max_q_terms"])
            
            pos_docs_array = []
            pos_snippets_labels = []
            pos_docs_mask_array = []
            neg_docs_array = []
            neg_docs_mask_array = []
            
            # pad docs, use cache here
            for batch_index in range(len(pos_docs)):
                maybe_padding(pos_docs[batch_index], pos_label[batch_index])
                pos_docs_array.append(pos_docs[batch_index]["tokens"])
                pos_snippets_labels.append(pos_docs[batch_index]["sentences_labels"])
                pos_docs_mask_array.append(pos_docs[batch_index]["sentences_mask"])
                maybe_padding(neg_docs[batch_index])
                neg_docs_array.append(neg_docs[batch_index]["tokens"])
                neg_docs_mask_array.append(neg_docs[batch_index]["sentences_mask"])
            
            yield [query, np.array(pos_docs_array), np.array(pos_docs_mask_array), np.array(pos_snippets_labels)], [query, np.array(neg_docs_array), np.array(neg_docs_mask_array)]
            
    def test_generator(data_generator):
        
        for ids, query, docs in test_gen(data_generator):
            
            docs_ids = []
            docs_array = []
            docs_mask_array = []
            query_array = []
            query_ids = []
            
            for i in range(len(ids)):
                
                for doc in docs[i]:
                    # pad docs, use cache here
                    maybe_padding(doc)
                    docs_array.append(doc["tokens"])
                    docs_mask_array.append(doc["sentences_mask"])
                    docs_ids.append(doc["id"])
                
                query_tokens = pad_tokens([query[i]], cfg["max_q_terms"])[0]
                query_tokens = [query_tokens] * len(docs[i])
                query_array.append(query_tokens)
                    
                query_ids.append([ids[i]]*len(docs[i]))
            
            #print(np.array(docs_mask_array))
            
            yield flat_list(query_ids), [np.array(flat_list(query_array)), np.array(docs_array), np.array(docs_mask_array)], docs_ids, None
            
    return train_generator, test_generator

In [2]:

min_freq = 0
mun_itter = 15
emb_size = 200

train_batch_size=32
type_split_mode=4
use_query_sw = False
use_docs_sw = False

cache_folder = "/backup/BioASQ-9b"
index_name = "bioasq_9b"

tokenizer_class = Regex
tokenizer_cfg = {"class":tokenizer_class,
                    "attr":{
                        "cache_folder": os.path.join(cache_folder, "tokenizers"),
                        "prefix_name": index_name
                    },
                    "min_freq":min_freq}

embeddind_class = Word2Vec
embedding_cfg = {
    "class":embeddind_class,
    "attr":{
        "cache_folder": os.path.join(cache_folder, "embeddings"),
        "prefix_name":index_name,
        "path":"/backup/pre-trained_embeddings/word2vec/"+index_name+"_gensim_iter_"+str(mun_itter)+"_freq"+str(min_freq)+"_"+str(emb_size)+"_"+tokenizer_class.__name__+"_word2vec.bin",
    }
}

model_cfg = {
    "max_q_terms": 50,
    "max_passages": 20,
    "max_p_terms": 70,
    "filters": 16,
    "match_threshold": 0.99,
    "activation": "mish",
    "use_mlp_sentence_scores": False,
    "use_cnn_sentence_scores": True,
    "top_k_list": [3, 5, 10, 15],
    "use_avg_pool":True,
    "use_kmax_avg_pool":True,
    "semantic_normalized_query_match" : False,
}

cfg = {"model":model_cfg, "tokenizer": tokenizer_cfg, "embedding": embedding_cfg}


K.clear_session()

ranking_model = sibm2_wSnippets(**cfg)

DEBUG created tokenizer bioasq_9b_RegexTokenizer
False False
[LOAD FROM CACHE] Load embedding matrix from /backup/BioASQ-9b/embeddings/WORD2VEC_embedding_bioasq_9b_gensim_iter_15_freq0_200_Regex_word2vec_bioasq_9b_RegexTokenizer
Using einsum for mask bq,bps->bpqs and with embedding dim bqe,bpse->bpqs
[EMBEDDING MATRIX SHAPE] (5322623, 200)


In [3]:
train_input_generator, test_input_generator = model_train_generator_for_model(ranking_model)

training_data_used = "joint_training_hardlabel_batch_05_0.6_0.51_250"
train_collection = TrainSnippetsCollectionV2\
                            .load(training_data_used)\
                            .batch_size(train_batch_size)\
                            .set_transform_inputs_fn(train_input_generator)

In [20]:
pos_doc, neg_doc = next(train_collection.generator()) 

In [59]:
y_pos_doc = pos_doc[:3]
label_pos_doc_snppets = pos_doc[3]

pos = ranking_model(y_pos_doc)
neg = ranking_model(neg_doc)

In [22]:
## pairwise doc loss

p_wise_loss = pairwise_cross_entropy(pos_doc_score, neg_doc_score)
p_wise_loss

<tf.Tensor: shape=(), dtype=float32, numpy=0.6818266>

In [23]:
## pairwise snippet loss

pos_sentence_labels = tf.cast(tf.reshape(label_pos_doc_snppets, (-1, model_cfg["max_passages"], 1)), tf.int32)
pos_sentence_scores = tf.cast(tf.math.exp(tf.reshape(pos_snippet_score, (-1, model_cfg["max_passages"], 1))), tf.float32)

In [25]:
def transformations_for_pairwise(vector):
        
    vector_t = tf.transpose(vector, perm=[0,2,1])
    vector_repeat = tf.repeat(vector, model_cfg["max_passages"], axis=-1)
    vector_t_repeat = tf.repeat(vector_t, model_cfg["max_passages"], axis=-2)

    vector_xor = tf.cast(tf.math.logical_xor(tf.cast(vector_repeat, tf.bool), 
                                                          tf.cast(vector_t_repeat, tf.bool)), 
                                      tf.float32)

    vector_repeat = tf.cast(vector_repeat, tf.float32)
    vector_t_repeat = tf.cast(vector_t_repeat, tf.float32)

    return vector_repeat, vector_t_repeat, vector_xor

In [26]:
pos_sentence_labels_repeat, pos_sentence_labels_repeat_transpose, pos_sentence_lables_xor =  transformations_for_pairwise(pos_sentence_labels)

In [27]:
pos_sentence_scores_repeat, pos_sentence_scores_repeat_transpose, pos_sentence_scores_xor =  transformations_for_pairwise(pos_sentence_scores)

In [40]:
snippet_loss_numerator = pos_sentence_lables_xor*(pos_sentence_labels_repeat*pos_sentence_scores_repeat + pos_sentence_labels_repeat_transpose*pos_sentence_scores_repeat_transpose)
print(snippet_loss_numerator[0,:])
snippet_loss_denominator = pos_sentence_scores_repeat*pos_sentence_lables_xor + pos_sentence_lables_xor*pos_sentence_scores_repeat_transpose + 0.0000001
print(snippet_loss_denominator[0,:])

tf.Tensor(
[[0.        1.5615495 1.5615495 1.5615495 0.        0.        0.
  0.        1.5615495 1.5615495 0.        1.5615495 1.5615495 1.5615495
  1.5615495 1.5615495 1.5615495 1.5615495 1.5615495 1.5615495]
 [1.5615495 0.        0.        0.        1.5319595 1.5319486 1.5318403
  1.5318676 0.        0.        1.5172572 0.        0.        0.
  0.        0.        0.        0.        0.        0.       ]
 [1.5615495 0.        0.        0.        1.5319595 1.5319486 1.5318403
  1.5318676 0.        0.        1.5172572 0.        0.        0.
  0.        0.        0.        0.        0.        0.       ]
 [1.5615495 0.        0.        0.        1.5319595 1.5319486 1.5318403
  1.5318676 0.        0.        1.5172572 0.        0.        0.
  0.        0.        0.        0.        0.        0.       ]
 [0.        1.5319595 1.5319595 1.5319595 0.        0.        0.
  0.        1.5319595 1.5319595 0.        1.5319595 1.5319595 1.5319595
  1.5319595 1.5319595 1.5319595 1.5319595 1.5319595 

In [96]:
snippet_loss = snippet_loss_numerator/snippet_loss_denominator
snippet_loss = tf.reshape(snippet_loss, (-1,))

snippet_loss_mask = snippet_loss>0.01
snippet_loss_indices = tf.cast(tf.where(snippet_loss_mask), tf.int32)

In [97]:
snippet_loss = tf.gather_nd(snippet_loss, snippet_loss_indices)

In [98]:
snippet_loss = -tf.math.log(snippet_loss)
snippet_loss  


<tf.Tensor: shape=(2160,), dtype=float32, numpy=
array([0.68411934, 0.67430913, 0.69311476, ..., 0.7043274 , 0.7043274 ,
       0.7043274 ], dtype=float32)>

In [99]:
snippet_loss = tf.math.reduce_mean(snippet_loss)
snippet_loss

<tf.Tensor: shape=(), dtype=float32, numpy=0.68908364>

In [64]:
snippet_loss = snippet_loss_sum/(snippet_loss_num+ 0.0000001)
snippet_loss = tf.math.reduce_mean(snippet_loss)
snippet_loss

<tf.Tensor: shape=(), dtype=float32, numpy=0.64544165>

In [4]:
gamma = 0.65
def transformations_for_pairwise(vector):
        
    vector_t = tf.transpose(vector, perm=[0,2,1])
    vector_repeat = tf.repeat(vector, model_cfg["max_passages"], axis=-1)
    vector_t_repeat = tf.repeat(vector_t, model_cfg["max_passages"], axis=-2)

    vector_xor = tf.cast(tf.math.logical_xor(tf.cast(vector_repeat, tf.bool), 
                                                          tf.cast(vector_t_repeat, tf.bool)), 
                                      tf.float32)

    vector_repeat = tf.cast(vector_repeat, tf.float32)
    vector_t_repeat = tf.cast(vector_t_repeat, tf.float32)

    return vector_repeat, vector_t_repeat, vector_xor
def joint_loss_pairwise(pos, neg, pos_label, neg_label):
        
    pos_score = pos[0]
    neg_score = neg[0]

    p_wise_loss = pairwise_cross_entropy(pos_score, neg_score)

    ## PAIRWISE SNIPPET SCORE LOSS

    pos_sentence_labels = tf.cast(tf.reshape(pos_label, (-1, model_cfg["max_passages"], 1)), tf.int32)
    pos_sentence_scores = tf.cast(tf.math.exp(tf.reshape(pos[1], (-1, model_cfg["max_passages"], 1))), tf.float32)

    # labels
    pos_sentence_labels_repeat, pos_sentence_labels_repeat_transpose, pos_sentence_lables_xor =  transformations_for_pairwise(pos_sentence_labels)

    # scores
    pos_sentence_scores_repeat, pos_sentence_scores_repeat_transpose, pos_sentence_scores_xor =  transformations_for_pairwise(pos_sentence_scores)

    snippet_loss_numerator = pos_sentence_lables_xor*(pos_sentence_labels_repeat*pos_sentence_scores_repeat + pos_sentence_labels_repeat_transpose*pos_sentence_scores_repeat_transpose)

    snippet_loss_denominator = pos_sentence_scores_repeat*pos_sentence_lables_xor + pos_sentence_lables_xor*pos_sentence_scores_repeat_transpose + 0.0000001
    """
    snippet_loss = tf.math.log(snippet_loss_numerator/snippet_loss_denominator)

    snippet_loss = tf.reshape(snippet_loss, (-1,model_cfg["max_passages"]*model_cfg["max_passages"]))

    snippet_loss_mask = tf.cast(snippet_loss>-5, tf.float32)

    snippet_loss = -snippet_loss

    snippet_loss_sum = tf.math.reduce_sum(snippet_loss*snippet_loss_mask, axis=-1)
    snippet_loss_num = tf.math.reduce_sum(snippet_loss_mask,  axis=-1)


    snippet_loss = snippet_loss_sum/(snippet_loss_num+ 0.0000001)

    snippet_loss = tf.math.reduce_mean(snippet_loss)
    """

    snippet_loss = snippet_loss_numerator/snippet_loss_denominator
    snippet_loss = tf.reshape(snippet_loss, (-1,))

    snippet_loss_mask = snippet_loss>0.01
    snippet_loss_indices = tf.cast(tf.where(snippet_loss_mask), tf.int32)

    snippet_loss = tf.gather_nd(snippet_loss, snippet_loss_indices)

    snippet_loss = -tf.math.log(snippet_loss)

    snippet_loss = tf.math.reduce_mean(snippet_loss)

    return (gamma * p_wise_loss) + ((1-gamma) * snippet_loss)

In [5]:

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

@tf.function
def train_step(pos_doc, neg_doc):
    

    y_pos_doc = pos_doc[:3]
    label_pos_doc_snppets = pos_doc[3]

    with tf.GradientTape() as tape:

        pos = ranking_model(y_pos_doc)
        neg = ranking_model(neg_doc)

        loss = joint_loss_pairwise(pos, neg, label_pos_doc_snppets, None)


    grads = tape.gradient(loss, ranking_model.trainable_weights)
    #print(grads)
    optimizer.apply_gradients(zip(grads, ranking_model.trainable_weights))
    return loss
        #input("Press Enter to continue...")

In [7]:
gen = train_collection.generator()

for i in range(800):
    pos_doc, neg_doc = next(gen)
    
    loss = train_step(pos_doc, neg_doc)
    
    print(loss)

tf.Tensor(0.6913051, shape=(), dtype=float32)
tf.Tensor(0.6843848, shape=(), dtype=float32)
tf.Tensor(0.67998147, shape=(), dtype=float32)
tf.Tensor(0.68046916, shape=(), dtype=float32)
tf.Tensor(0.6847, shape=(), dtype=float32)
tf.Tensor(0.68611443, shape=(), dtype=float32)
tf.Tensor(0.6825864, shape=(), dtype=float32)
tf.Tensor(0.6823709, shape=(), dtype=float32)
tf.Tensor(0.6823955, shape=(), dtype=float32)
tf.Tensor(0.6734829, shape=(), dtype=float32)
tf.Tensor(0.6774239, shape=(), dtype=float32)
tf.Tensor(0.6797923, shape=(), dtype=float32)
tf.Tensor(0.682059, shape=(), dtype=float32)
tf.Tensor(0.6846226, shape=(), dtype=float32)
tf.Tensor(0.685148, shape=(), dtype=float32)
tf.Tensor(0.67517567, shape=(), dtype=float32)
tf.Tensor(0.673415, shape=(), dtype=float32)
tf.Tensor(0.6837712, shape=(), dtype=float32)
tf.Tensor(0.6749089, shape=(), dtype=float32)
tf.Tensor(0.66995656, shape=(), dtype=float32)
tf.Tensor(0.6668084, shape=(), dtype=float32)
tf.Tensor(0.6843816, shape=(), dtyp

KeyboardInterrupt: 