In [1]:
from mmnrm.utils import set_random_seed, load_neural_model, load_model, flat_list
from mmnrm.dataset import TestCollectionV2, sentence_splitter_builderV2, TrainSnippetsCollectionV2
from mmnrm.evaluation import BioASQ_JavaEvaluator

from collections import defaultdict
import os
import pickle
import numpy as np
import sys
import math
import time
import tensorflow as tf
from nltk.tokenize.punkt import PunktSentenceTokenizer
from utils import *

import nltk

In [2]:
# some queries to use as an example
def build_data_generators(tokenizer, use_joint=True, queries_sw=None, docs_sw=None):
    
    punkt_sent_tokenizer = PunktSentenceTokenizer().span_tokenize

    def maybe_tokenize(documents):
        if "tokens" not in documents:
            if use_joint:
                split = []
                spans = []
                for _itter, position in enumerate(punkt_sent_tokenizer(documents["text"])):
                    start, end = position
                    _text = documents["text"][start:end]

                    is_title = True
                    if _itter>0: # fix the start and end position for the abstract
                        
                        if _itter == 1: # auxiliar correction to set the abstract at 0 index
                            diff = (len(documents["title"])-1)+(start-(len(documents["title"])-1))
                            
                        start = start-diff
                        end = end-diff
                        is_title = False
                    
                    split.append(_text)
                    spans.append({"start":start,
                                  "end":end-1,
                                  "text":_text,
                                  "is_title":is_title,
                                  "snippet_id":documents["id"]+"_"+str(_itter),
                                  "doc_id":documents["id"]})
                documents["spans"] = spans
            else:
                split = nltk.sent_tokenize(documents["text"])
                
            documents["tokens"] = tokenizer.texts_to_sequences(split)
            if docs_sw is not None:
                for tokenized_sentence in documents["tokens"]:
                    tokenized_sentence = [token for token in tokenized_sentence if token not in docs_sw]

    def test_generator(data_generator):
        for _id, query, docs in data_generator:
            tokenized_queries = []
            for i in range(len(_id)):
                # tokenization
                tokenized_query = tokenizer.texts_to_sequences([query[i]])[0]

                if queries_sw is not None:
                    tokenized_query = [token for token in tokenized_query if token not in queries_sw] 
                
                tokenized_queries.append(tokenized_query)
                    
        
                for doc in docs[i]:
                    maybe_tokenize(doc)
                                                 
            yield _id, tokenized_queries, docs
            
    return test_generator

def get_test_generator_for_model(model, use_joint):

    if "model" in model.savable_config:
        cfg = model.savable_config["model"]
    
    test_gen = build_data_generators(model.tokenizer, use_joint)
    
    pad_tokens = lambda x, max_len, dtype='int32': tf.keras.preprocessing.sequence.pad_sequences(x, 
                                                                                           maxlen=max_len,
                                                                                           dtype=dtype, 
                                                                                           padding='post', 
                                                                                           truncating='post', 
                                                                                           value=0)

    pad_sentences = lambda x, max_lim, dtype='int32': x[:max_lim] + [[]]*(max_lim-len(x))
    
    def maybe_padding(document):
        if isinstance(document["tokens"], list):
            #overflow prevention
            bounded_doc_passage = min(cfg["max_passages"],len(document["tokens"]))
            document["sentences_mask"] = [True] * bounded_doc_passage + [False] * (cfg["max_passages"]-bounded_doc_passage)
            document["tokens"] = pad_tokens(pad_sentences(document["tokens"], cfg["max_passages"]), cfg["max_p_terms"])
            document["spans"] = document["spans"][:cfg["max_passages"]]
            
    def test_generator(data_generator):
        
        for ids, query, docs in test_gen(data_generator):
            
            docs_spans = []
            docs_ids = []
            docs_array = []
            docs_mask_array = []
            query_array = []
            query_ids = []
            
            for i in range(len(ids)):
                
                for doc in docs[i]:
                    # pad docs, use cache here
                    maybe_padding(doc)
                    docs_array.append(doc["tokens"])
                    docs_mask_array.append(doc["sentences_mask"])
                    docs_ids.append(doc["id"])
                    docs_spans.append(doc["spans"])
                    
                query_tokens = pad_tokens([query[i]], cfg["max_q_terms"])[0]
                query_tokens = [query_tokens] * len(docs[i])
                query_array.append(query_tokens)
                    
                query_ids.append([ids[i]]*len(docs[i]))
            
            #print(np.array(docs_mask_array))
            
            yield flat_list(query_ids), [np.array(flat_list(query_array)), np.array(docs_array), np.array(docs_mask_array)], docs_ids, docs_spans
            
    return test_generator

queries = load_queries("BioASQ-task9bPhaseA-testset3", maps=[("body","query")])

In [11]:
ranking_model = load_model("trained_models/honest-morning-60_val_collection0_doc_map@10")

DEBUG created tokenizer bioasq_9b_RegexTokenizer
False False
[LOAD FROM CACHE] Load embedding matrix from /backup/BioASQ-9b/embeddings/WORD2VEC_embedding_bioasq_9b_gensim_iter_15_freq0_200_Regex_word2vec_bioasq_9b_RegexTokenizer
Using einsum for mask bq,bps->bpqs and with embedding dim bqe,bpse->bpqs
[EMBEDDING MATRIX SHAPE] (5322623, 200)


In [12]:
test_input_generator = get_test_generator_for_model(ranking_model, use_joint=True)

In [13]:
def rank(model, t_collection):

    generator_Y = t_collection.generator()
                
    q_scores = defaultdict(list)

    for query_id, Y, docs_ids, docs_spans in generator_Y:
        s_time = time.time()
        
        scores = model.predict(Y)
        doc_scores = scores[0][:,0].tolist()
        snippets_scores = scores[1].tolist()
        
        for i in range(len(doc_scores)):
            
            for j in range(len(docs_spans[i])):
                docs_spans[i][j]["score"] = snippets_scores[i][j][0]
            
            #q_scores[query_id].extend(list(zip(docs_ids,scores)))
            q_scores[query_id[i]].append({"id":docs_ids[i],
                                          "score":doc_scores[i],
                                          "snippets":docs_spans[i]})
        
        print("\rEvaluation {} | time {}".format(len(q_scores), time.time()-s_time), end="\r")

    # sort the rankings
    for query_id in q_scores.keys():
        q_scores[query_id].sort(key=lambda x:-x["score"])
        q_scores[query_id] = q_scores[query_id]
    
    return q_scores

def rerank_run(baseline_file, top_k):
    run = load_document_run(baseline_file, dict_format=True)

    tCollection = TestCollectionV2(queries, run)\
                      .batch_size(top_k)\
                      .set_transform_inputs_fn(test_input_generator)
    
    results = rank(ranking_model, tCollection)
    
    return results ## update the run results

In [14]:
rerank = rerank_run(os.path.join("runs/rnd3", "bm25-baseline.run"), 100)

Evaluation 100 | time 0.31703734397888184

In [15]:
def snippetRank_bySort(results):
    snippets_results = {}
    for q in results.keys():
        snippets_list = flat_list([x["snippets"] for x in results[q]])
        snippets_results[q] = sorted(snippets_list, key=lambda x: -x["score"])
        
    return snippets_results

def snippetRank_byThreshold(results, threshold):
    snippets_results = {}
    # this will follow the document order first
    for q in results.keys():
        snippets_results[q] = [y for y in flat_list([x["snippets"] for x in results[q]]) if y["score"] >= threshold]
        
    return snippets_results

def snippetRank_byThreshold_and_TopK(results, threshold, topK):
    snippets_results = {}
    # this will follow the document order first
    for q in results.keys():
        snippets_results[q] = [y for y in flat_list([x["snippets"] for x in results[q][:topK]]) if y["score"] >= threshold]
        
    return snippets_results

In [8]:
evaluator = BioASQ_JavaEvaluator("BioASQ-task9bPhaseB-testset3", 
                                 "/home/tiagoalmeida/BioASQ-9b/BioASQEvaluator/BioASQEvaluation.jar",
                                 write_as_bioasq)

snippets = snippetRank_bySort(rerank)

run = create_document_run(queries, rerank, snippets)

evaluator.evaluate(run)

Remove /tmp/tmprt2bttm0


{'doc_p@10': 0.11499999999999995,
 'doc_r@10': 0.4654141414141413,
 'doc_f1@10': 0.16651067224828836,
 'doc_map@10': 0.35071740362811793,
 'doc_gmap@10': 0.0069504130851034225,
 'snippet_p@10': 0.08505336416681805,
 'snippet_r@10': 0.17641778927184135,
 'snippet_f1@10': 0.10271486432212351,
 'snippet_map@10': 0.22603522246306731,
 'snippet_gmap@10': 0.0011442407919515175}

In [52]:
run[0]

{'id': '6057c78994d57fd879000034',
 'type': 'factoid',
 'query': 'Which protein is involved in the organization and regulation of pluripotency-associated three-dimensional enhancer networks?',
 'documents': [{'id': '31548608',
   'score': 15.519436836242676,
   'snippets': [{'start': 0,
     'end': 114,
     'text': 'KLF4 is involved in the organization and regulation of pluripotency-associated three-dimensional enhancer networks.',
     'is_title': True,
     'doc_id': '31548608',
     'score': 0.7081087231636047},
    {'start': 0,
     'end': 259,
     'text': 'Cell fate transitions are accompanied by global transcriptional, epigenetic and topological changes driven by transcription factors, as is exemplified by reprogramming somatic cells to pluripotent stem cells through the expression of OCT4, KLF4, SOX2 and cMYC.',
     'is_title': False,
     'doc_id': '31548608',
     'score': 0.15228056907653809},
    {'start': 261,
     'end': 390,
     'text': 'How transcription factors orch

In [54]:
snippets["6057c78994d57fd879000034"]

[{'start': 0,
  'end': 114,
  'text': 'KLF4 is involved in the organization and regulation of pluripotency-associated three-dimensional enhancer networks.',
  'is_title': True,
  'doc_id': '31548608',
  'score': 0.7081087231636047},
 {'start': 0,
  'end': 259,
  'text': 'Cell fate transitions are accompanied by global transcriptional, epigenetic and topological changes driven by transcription factors, as is exemplified by reprogramming somatic cells to pluripotent stem cells through the expression of OCT4, KLF4, SOX2 and cMYC.',
  'is_title': False,
  'doc_id': '31548608',
  'score': 0.15228056907653809},
 {'start': 392,
  'end': 669,
  'text': 'Here, using KLF4 as a paradigm, we provide a transcription-factor-centric view of chromatin reorganization and its association with three-dimensional enhancer rewiring and transcriptional changes during the reprogramming of mouse embryonic fibroblasts to pluripotent stem cells.',
  'is_title': False,
  'doc_id': '31548608',
  'score': 0.3979711

In [11]:
#run = create_document_run(queries, rerank, snippets)

In [18]:

with open("snippet_bioasq_9b_batch_03_snippetRank_byThreshold"+".csv", "w") as fOut:
    metrics = ["snippet_p@10", "snippet_r@10", "snippet_f1@10", "snippet_map@10", "snippet_gmap@10"]
    
    header = "f_name,threshold," + ",".join(metrics)+"\n"
    fOut.write(header)
    
    T = list(map(lambda x:x/10, range(0,101,1)))
    for t in T:
    
        snippets = snippetRank_byThreshold(rerank,t)
        _run = create_document_run(queries, rerank, snippets)
        m = evaluator.evaluate(_run)
        fOut.write("snippetRank_byThreshold," + str(t)+"," + ",".join([ str(m[n]) for n in metrics]) + "\n")
        fOut.flush()
    # snippetRank_byThreshold_and_TopK



Remove /tmp/tmpzgb_vo8x
Remove /tmp/tmp30havt61
Remove /tmp/tmpxg9g9cad
Remove /tmp/tmpj6cn5epz
Remove /tmp/tmpueui852v
Remove /tmp/tmpbr1bg66v
Remove /tmp/tmp33bsg8id
Remove /tmp/tmpxlrmjsgg
Remove /tmp/tmp7b1w23st
Remove /tmp/tmpnp59en7u
Remove /tmp/tmpw0sg6syr
Remove /tmp/tmpxduwpg6b
Remove /tmp/tmpqcfuq880
Remove /tmp/tmpfc_z8xkm
Remove /tmp/tmp0aa57tvm
Remove /tmp/tmpm_82f2l1
Remove /tmp/tmpv1n01i88
Remove /tmp/tmpflmbqluy
Remove /tmp/tmpeqk2jktm
Remove /tmp/tmp0n8iv9vl
Remove /tmp/tmpae8bglav
Remove /tmp/tmp7inyg372
Remove /tmp/tmpye5q_mw8
Remove /tmp/tmpk3q4qy4j
Remove /tmp/tmpei5_vmnn
Remove /tmp/tmprr1epcqu
Remove /tmp/tmpoa2srni9
Remove /tmp/tmpkl80nh9m
Remove /tmp/tmpqq8r6mdv
Remove /tmp/tmpk7fgf2k4
Remove /tmp/tmpem31wlj3
Remove /tmp/tmp9m1mlr2n
Remove /tmp/tmpo1d8mfie
Remove /tmp/tmpc4fndxru
Remove /tmp/tmp0sggiqh5
Remove /tmp/tmpf3jycklk
Remove /tmp/tmpms52i_6p
Remove /tmp/tmpi_b0yiil
Remove /tmp/tmp1zva_qaa
Remove /tmp/tmp5c142fc_
Remove /tmp/tmpxpcqyk23
Remove /tmp/tmpn

KeyboardInterrupt: 

In [16]:
import itertools

with open("snippet_bioasq_9b_batch_03_snippetRank_byThreshold_and_TopK"+".csv", "w") as fOut:
    metrics = ["snippet_p@10", "snippet_r@10", "snippet_f1@10", "snippet_map@10", "snippet_gmap@10"]
    
    header = "f_name,threshold,topK" + ",".join(metrics)+"\n"
    fOut.write(header)
    
    T = list(map(lambda x:x/100, range(0,51,1)))
    TOPK = list(range(1,5,1))
    
    combinations = list(itertools.product(T,TOPK))
    
    for t,k in combinations:
    
        snippets = snippetRank_byThreshold_and_TopK(rerank,t,k)
        _run = create_document_run(queries, rerank, snippets)
        m = evaluator.evaluate(_run)
        fOut.write("snippetRank_byThreshold_and_TopK," + str(t)+"," +str(k)+","+ ",".join([ str(m[n]) for n in metrics]) + "\n")
        fOut.flush()

Remove /tmp/tmp1wwvt48z
Remove /tmp/tmpaequkr0h
Remove /tmp/tmpvc15mw42
Remove /tmp/tmp98uetr0j
Remove /tmp/tmpq70ab_3_
Remove /tmp/tmpjp6haap8
Remove /tmp/tmpmw6qpwar
Remove /tmp/tmpfxpq94lf
Remove /tmp/tmp9q80cql1
Remove /tmp/tmp0mox8ihe
Remove /tmp/tmpyq6e8iad
Remove /tmp/tmpp75bed5a
Remove /tmp/tmp6egxz5qb
Remove /tmp/tmpsanj4sf1
Remove /tmp/tmp2ogewxgl
Remove /tmp/tmputfaywd6
Remove /tmp/tmpt165w6a1
Remove /tmp/tmp1bfbhyka
Remove /tmp/tmpjijj1xt0
Remove /tmp/tmpctmbm9z1
Remove /tmp/tmpkib_y8up
Remove /tmp/tmpy3pql0ef
Remove /tmp/tmp2lm4v1ek
Remove /tmp/tmpv19mmu9e
Remove /tmp/tmpovzu979o
Remove /tmp/tmpfvagcnu_
Remove /tmp/tmphe_pizsr
Remove /tmp/tmpuoyt2so8
Remove /tmp/tmpep9jfzpl
Remove /tmp/tmp1svccdsy
Remove /tmp/tmpyyifnwot
Remove /tmp/tmpwgfrg8_b
Remove /tmp/tmp6atqx1gn
Remove /tmp/tmp2a9wm7qi
Remove /tmp/tmpidldt3t4
Remove /tmp/tmpjgd1ef77
Remove /tmp/tmplan_kdis
Remove /tmp/tmp0uaywycy
Remove /tmp/tmppnlwxxj7
Remove /tmp/tmpz2_04w8p
Remove /tmp/tmpdwlvx8q6
Remove /tmp/tmp4

In [None]:
import pandas as pd
df = pd.read_csv("snippet_bioasq_9b_batch_03_snippetRank_byThreshold.csv")
df = df.sort_values('snippet_f1@10',ascending=False)

In [17]:
import pandas as pd
df = pd.read_csv("snippet_bioasq_9b_batch_03_snippetRank_byThreshold_and_TopK.csv")
df.sort_values('snippet_f1@10',ascending=False)

Unnamed: 0,f_name,threshold,topKsnippet_p@10,snippet_r@10,snippet_f1@10,snippet_map@10,snippet_gmap@10
snippetRank_byThreshold_and_TopK,0.08,1,0.167207,0.215893,0.178072,0.672674,0.000465
snippetRank_byThreshold_and_TopK,0.12,1,0.169362,0.212623,0.177712,0.643631,0.000459
snippetRank_byThreshold_and_TopK,0.11,1,0.167702,0.213821,0.177454,0.652801,0.000461
snippetRank_byThreshold_and_TopK,0.09,1,0.166875,0.214446,0.177414,0.659153,0.000463
snippetRank_byThreshold_and_TopK,0.10,1,0.167288,0.213821,0.177234,0.653516,0.000462
...,...,...,...,...,...,...,...
snippetRank_byThreshold_and_TopK,0.32,2,0.139549,0.227787,0.160634,0.564154,0.000898
snippetRank_byThreshold_and_TopK,0.33,4,0.138182,0.231159,0.160190,0.561348,0.000900
snippetRank_byThreshold_and_TopK,0.33,3,0.138182,0.231159,0.160190,0.561348,0.000900
snippetRank_byThreshold_and_TopK,0.32,3,0.137730,0.230491,0.159620,0.566314,0.000901


In [45]:
punkt_sent_tokenizer = PunktSentenceTokenizer().span_tokenize

In [47]:
list(punkt_sent_tokenizer("asdfa. asdf asdf asdf. asdf asdf"))

[(0, 6), (7, 22), (23, 32)]

In [4]:
# Test Callback Validation with snippets
from mmnrm.callbacks import Validation
from mmnrm.evaluation import BioASQ_JavaEvaluator

evaluator = BioASQ_JavaEvaluator("/home/tiagoalmeida/BioASQ-9b/yearly_data/8B5_golden.json", 
                                 "/home/tiagoalmeida/BioASQ-9b/BioASQEvaluator/BioASQEvaluation.jar",
                                 write_as_bioasq)

validation_collection = TestCollectionV2.load("validation_batch_05_0.4_0.44_100")\
                                    .batch_size(100)\
                                    .set_name("Validation TOP (recall) 100")
validation_collection.evaluator = evaluator

validation_collection.save("validation_wSnippets_batch_05_0.4_0.44_100")

In [4]:
from mmnrm.callbacks import Validation
from mmnrm.evaluation import BioASQ_JavaEvaluator

def snippetRank_byThreshold(threshold):
    
    def snippetRank(results):
    
        snippets_results = {}
        # this will follow the document order first
        for q in results.keys():
            snippets_results[q] = [y for y in flat_list([x["snippets"] for x in results[q]]) if y["score"] >= threshold]

        return snippets_results
    
    return snippetRank

test_input_generator = get_test_generator_for_model(ranking_model, use_joint=True)

evaluator = BioASQ_JavaEvaluator("/home/tiagoalmeida/BioASQ-9b/yearly_data/8B5_golden.json", 
                                 "/home/tiagoalmeida/BioASQ-9b/BioASQEvaluator/BioASQEvaluation.jar",
                                 write_as_bioasq)

validation_collection = TestCollectionV2.load("validation_batch_05_0.4_0.44_100")\
                                    .batch_size(100)\
                                    .set_transform_inputs_fn(test_input_generator)\
                                    .set_name("Validation TOP (recall) 100")

validation_collection.evaluator = evaluator

snippet_rank_f = snippetRank_byThreshold(0.08) 

val = Validation(validation_collection = [validation_collection],# TODO
                 output_metrics=["doc_r@10", "doc_map@10","snippet_f1@10","snippet_map@10"],
                 path_store = None,
                 snippet_rank_f = snippet_rank_f)

In [5]:

val.evaluate(ranking_model.predict, validation_collection)

Evaluation 50 | avg 50-time 0.3404881954193115

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Remove /tmp/tmpqi5vduy0


{'doc_p@10': '0.1499999999999999',
 'doc_r@10': '0.5829252969252969',
 'doc_f1@10': '0.21479179024935535',
 'doc_map@10': '0.4581660918997229',
 'doc_gmap@10': '0.02964828870067375',
 'snippet_p@10': '0.17166116324064698',
 'snippet_r@10': '0.28734795914322975',
 'snippet_f1@10': '0.19953835567849598',
 'snippet_map@10': '0.7424392903661619',
 'snippet_gmap@10': '0.003843435329089425'}

In [6]:


validation_collection.save("validation_wSnippets_batch_05_0.4_0.44_100")

In [7]:
validation_collection = TestCollectionV2.load("validation_wSnippets_batch_05_0.4_0.44_100")

In [8]:
validation_collection.evaluator

<mmnrm.evaluation.BioASQ_JavaEvaluator at 0x7fba15bebf98>

In [13]:
runs = ["8b-b04-bioinfo-2.json"]

evaluator = BioASQ_JavaEvaluator("yearly_data/8B4_golden.json", 
                                 "/home/tiagoalmeida/BioASQ-9b/BioASQEvaluator/BioASQEvaluation.jar")

for r in runs:
    m = evaluator.evaluate(r)
    out = f"{r}:\n\tDoc:\n\t\tMap@10: {m['doc_map@10']}\n\t\tRecall@10: {m['doc_r@10']}"
    if float(m['snippet_r@10'])>0:
        out += f"\n\tSnippet:\n\t\tMap@10: {m['snippet_map@10']}\n\t\tRecall@10: {m['snippet_r@10']}\n\t\tF1@10: {m['snippet_f1@10']}"
    print(out)


Remove /tmp/tmper3ivh24
8b-b04-bioinfo-2.json:
	Doc:
		Map@10: 0.40057199546485267
		Recall@10: 0.5552987012987013
	Snippet:
		Map@10: 0.3658896457209854
		Recall@10: 0.2584178478525874
		F1@10: 0.17229742697268072


In [16]:
runs = ["8b-b04-bioinfo-2.json"]

evaluator = BioASQ_JavaEvaluator("yearly_data/8B4_golden.json", 
                                 "/home/tiagoalmeida/BioASQ-9b/BioASQEvaluator/BioASQEvaluation.jar")

for r in runs:
    m = evaluator.evaluate(r)
    out = f"{r}:\n\tDoc:\n\t\tMap@10: {m['doc_map@10']}\n\t\tRecall@10: {m['doc_r@10']}"
    if float(m['snippet_r@10'])>0:
        out += f"\n\tSnippet:\n\t\tMap@10: {m['snippet_map@10']}\n\t\tRecall@10: {m['snippet_r@10']}\n\t\tF1@10: {m['snippet_f1@10']}"
    print(out)


Remove /tmp/tmp5db62sdn
8b-b04-bioinfo-2.json:
	Doc:
		Map@10: 0.48099051555807515
		Recall@10: 0.6006318518646105
	Snippet:
		Map@10: 0.40186301388153983
		Recall@10: 0.27501143521870997
		F1@10: 0.20362609203479196


In [6]:
[t/50 for t in range(0,10)]

[0.0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18]

In [5]:
def build_data_generators(tokenizer, queries_sw=None, docs_sw=None):
    
    def maybe_tokenize(documents):
        if "tokens" not in documents:
            #split = nltk.sent_tokenize(documents["text"])
            documents["tokens"] = tokenizer.texts_to_sequences(map(lambda x: x["text"],documents["snippets"]))
            if docs_sw is not None:
                for tokenized_sentence in documents["tokens"]:
                    tokenized_sentence = [token for token in tokenized_sentence if token not in docs_sw]
    
    def train_generator(data_generator):
        while True:

            # get the batch triplet
            query, pos_docs, pos_label, neg_docs = next(data_generator)

            # tokenization, this can be cached for efficientcy porpuses NOTE!!
            tokenized_query = tokenizer.texts_to_sequences(query)

            if queries_sw is not None:
                for tokens in tokenized_query:
                    tokenized_query = [token for token in tokens if token not in queries_sw] 
            
            saveReturn = True
            
            for batch_index in range(len(pos_docs)):
                
                # tokenizer with cache in [batch_index][tokens]
                maybe_tokenize(pos_docs[batch_index])
                
                # assertion
                if all([ len(sentence)==0  for sentence in pos_docs[batch_index]["tokens"]]):
                    saveReturn = False
                    break # try a new resampling, NOTE THIS IS A EASY FIX PLS REDO THIS!!!!!!!
                          # for obvious reasons
                
                maybe_tokenize(neg_docs[batch_index])
                
            if saveReturn: # this is not true, if the batch is rejected
                yield tokenized_query, pos_docs, pos_label, neg_docs
                
            
    return train_generator, None

def train_generator_for_model(model):

    if "model" in model.savable_config:
        cfg = model.savable_config["model"]
    
    train_gen, test_gen = build_data_generators(model.tokenizer)
    
    pad_tokens = lambda x, max_len, dtype='int32': tf.keras.preprocessing.sequence.pad_sequences(x, 
                                                                                           maxlen=max_len,
                                                                                           dtype=dtype, 
                                                                                           padding='post', 
                                                                                           truncating='post', 
                                                                                           value=0)

    pad_sentences = lambda x, max_lim, dtype='int32': x[:max_lim] + [[]]*(max_lim-len(x))
    
    pad_labels = lambda x, max_lim, dtype='int32': x[:max_lim] + [0]*(max_lim-len(x))
    
    def maybe_padding(document, labels = None):
        if isinstance(document["tokens"], list):
            #overflow prevention
            bounded_doc_passage = min(cfg["max_passages"],len(document["tokens"]))
            document["sentences_mask"] = [True] * bounded_doc_passage + [False] * (cfg["max_passages"]-bounded_doc_passage)
            document["tokens"] = pad_tokens(pad_sentences(document["tokens"], cfg["max_passages"]), cfg["max_p_terms"])
            if labels is not None:
                document["sentences_labels"] = pad_labels(labels, cfg["max_passages"])
            
    def train_generator(data_generator):
 
        for query, pos_docs, pos_label, neg_docs in train_gen(data_generator):
            
            query = pad_tokens(query, cfg["max_q_terms"])
            
            pos_docs_array = []
            pos_snippets_labels = []
            pos_docs_mask_array = []
            neg_docs_array = []
            neg_docs_mask_array = []
            
            # pad docs, use cache here
            for batch_index in range(len(pos_docs)):
                maybe_padding(pos_docs[batch_index], pos_label[batch_index])
                pos_docs_array.append(pos_docs[batch_index]["tokens"])
                pos_snippets_labels.append(pos_docs[batch_index]["sentences_labels"])
                pos_docs_mask_array.append(pos_docs[batch_index]["sentences_mask"])
                maybe_padding(neg_docs[batch_index])
                neg_docs_array.append(neg_docs[batch_index]["tokens"])
                neg_docs_mask_array.append(neg_docs[batch_index]["sentences_mask"])
            
            yield [query, np.array(pos_docs_array), np.array(pos_docs_mask_array), np.array(pos_snippets_labels)], [query, np.array(neg_docs_array), np.array(neg_docs_mask_array)]
            
    return train_generator

model = load_model("trained_models/earthy-glade-11_val_collection0_map@10")
train_input_generator = train_generator_for_model(model)

training_data_used = "joint_training_batch_05_0.6_0.51_250"
train_collection = TrainSnippetsCollectionV2\
                            .load(training_data_used)\
                            .batch_size(32)\
                            .set_transform_inputs_fn(train_input_generator)

DEBUG created tokenizer bioasq_9b_RegexTokenizer
False False
[LOAD FROM CACHE] Load embedding matrix from /backup/BioASQ-9b/embeddings/WORD2VEC_embedding_bioasq_9b_gensim_iter_15_freq0_200_Regex_word2vec_bioasq_9b_RegexTokenizer
Using einsum for mask bq,bps->bpqs and with embedding dim bqe,bpse->bpqs
[EMBEDDING MATRIX SHAPE] (5322623, 200)


In [6]:
q, pos_doc, pos_l_snippet, neg_doc = next(train_collection.generator())

ValueError: not enough values to unpack (expected 4, got 2)

In [None]:
pos_doc