In [1]:
import os

import pysolr

from whoosh.index import create_in, open_dir
from whoosh.fields import *
from whoosh.qparser import QueryParser, MultifieldParser, OrGroup
from whoosh import scoring

In [2]:
all_fields = [
    "id", "brief_summary", "brief_title",
    "minimum_age", "gender",
    "primary_outcome", "detailed_description",
    "keywords", "official_title",
    "intervention_type", 
    "intervention_name",
    "intervention_browse",
    "condition_browse",
    "inclusion", "exclusion",
]

In [3]:
def run_query(text, index, bm25f_params={}, task="clinical", **kwargs):
    """
    Queries TREC Abstracts/Trials and returns their fields.
    
    Params
    ------
    
    text: Text to formulate query
    index: Whoosh.index object made from index.open_dir(<path-to-index>)
    bm25f_params: Tuned parameters for BM25f model
    
    **kwargs:
        qf: query field, search 'text' parameter in specified query fields
        return_fields: specify which fields to return
        size: Maximum results to return
        max_year: Filters returns up until 'max_year'
        check_input: print out your inputs (sanity checks)
    
    """
    
    base_field = "brief_summary^1" if task is "clinical" else "text^1"
    qf = base_field if 'qf' not in kwargs else kwargs['qf']
    return_fields = ['id','score'] if 'return_fields' not in kwargs else kwargs['return_fields'] #return fields
    
    size = 1000 if 'size' not in kwargs else kwargs['size']
    max_year = 0 if 'max_year' not in kwargs else kwargs['max_year']
#    parser='edismax' if 'parser' not in kwargs else kwargs['parser']
    
    qf_fields = [s.split("^")[0] for s in qf.split()]
    qf_boosts = [1 if len(s.split("^")) == 1 else float(s.split("^")[1]) for s in qf.split()]
    
    print(f"qf_fields: {qf_fields}")
    print(f"qf_boosts: {qf_boosts}")
    
    qff = [f for f, b in zip(qf_fields, qf_boosts) if b != 0]
    qfb = [b for f, b in zip(qf_fields, qf_boosts) if b != 0]
    
    boost_dict = {}
    for f, b in zip(qff, qfb):
        boost_dict[f] = b
        
    check_input = False if 'check_input' not in kwargs else kwargs["check_input"]
    if check_input:
        print(f"text: {text}")
        print(f"query fields: {qf}")
        print(f"boost_dict: {boost_dict}")
    
    output = []
    if len(bm25f_params) > 0:
        w = scoring.BM25F(**bm25f_params)
    else:
        w = scoring.BM25F()
        print('Default scoring')
    with index.searcher(weighting=w) as searcher:
        query = MultifieldParser(qff, index.schema,
                                 fieldboosts=boost_dict,
                                 group=OrGroup).parse(text)
        if max_year > 0:
            mask_q = QueryParser("year", index.schema).parse("date_i:["+str(max_year)+" to]")
            results = searcher.search(query, limit=size, mask=mask_q)
        else:
            results = searcher.search(query, limit=size)
            
        print("Returning results")
        for r in results:
            results_row = {}
            results_row['score'] = r.score
            for f in return_fields:
                if f not in results_row:
                    if f in r:
                        results_row[f] = r[f]
                    else:
                        results_row[f] = ''
            output.append(results_row)
    return output, return_fields

In [4]:
ct_17_path = "A2A4UMA/indices/ct17_whoosh"
ct_17_idx = open_dir(ct_17_path)
fields = ["id", "score"]

test = run_query(
    text="NCT00001452", 
    index=ct_17_idx,
    qf = "id^1",
    return_fields=all_fields
)

EmptyIndexError: Index 'MAIN' does not exist in FileStorage('A2A4UMA/indices/ct17_whoosh')

In [None]:
test[0]

In [7]:
import glob

path="../../bert_example/bbcsport"

for file in glob.glob(f"{path}/*/*.txt"):
    print(file)

../../bert_example/bbcsport/football/262.txt
../../bert_example/bbcsport/football/060.txt
../../bert_example/bbcsport/football/074.txt
../../bert_example/bbcsport/football/048.txt
../../bert_example/bbcsport/football/114.txt
../../bert_example/bbcsport/football/100.txt
../../bert_example/bbcsport/football/128.txt
../../bert_example/bbcsport/football/129.txt
../../bert_example/bbcsport/football/101.txt
../../bert_example/bbcsport/football/115.txt
../../bert_example/bbcsport/football/049.txt
../../bert_example/bbcsport/football/075.txt
../../bert_example/bbcsport/football/061.txt
../../bert_example/bbcsport/football/263.txt
../../bert_example/bbcsport/football/261.txt
../../bert_example/bbcsport/football/249.txt
../../bert_example/bbcsport/football/088.txt
../../bert_example/bbcsport/football/077.txt
../../bert_example/bbcsport/football/063.txt
../../bert_example/bbcsport/football/103.txt
../../bert_example/bbcsport/football/117.txt
../../bert_example/bbcsport/football/116.txt
../../bert

1 0 NCT00001452 0
1 0 NCT00003911 0
1 0 NCT00016263 0
1 0 NCT00112216 0
1 0 NCT00260390 0
1 0 NCT00264056 0
1 0 NCT00288938 0
1 0 NCT00302588 0
1 0 NCT00339222 0
1 0 NCT00341991 0

In [15]:
sports = ["football", "cricket"]

for file in glob.glob(f"{path}/*/*.txt"):
#     print(file)
    if any(s in file for s in sports):
        print(file)

../../bert_example/bbcsport/football/262.txt
../../bert_example/bbcsport/football/060.txt
../../bert_example/bbcsport/football/074.txt
../../bert_example/bbcsport/football/048.txt
../../bert_example/bbcsport/football/114.txt
../../bert_example/bbcsport/football/100.txt
../../bert_example/bbcsport/football/128.txt
../../bert_example/bbcsport/football/129.txt
../../bert_example/bbcsport/football/101.txt
../../bert_example/bbcsport/football/115.txt
../../bert_example/bbcsport/football/049.txt
../../bert_example/bbcsport/football/075.txt
../../bert_example/bbcsport/football/061.txt
../../bert_example/bbcsport/football/263.txt
../../bert_example/bbcsport/football/261.txt
../../bert_example/bbcsport/football/249.txt
../../bert_example/bbcsport/football/088.txt
../../bert_example/bbcsport/football/077.txt
../../bert_example/bbcsport/football/063.txt
../../bert_example/bbcsport/football/103.txt
../../bert_example/bbcsport/football/117.txt
../../bert_example/bbcsport/football/116.txt
../../bert

In [13]:
if any(s in text for s in sports):
    

TypeError: 'in <string>' requires string as left operand, not bool

In [None]:
# This is irrelevant
# solr = pysolr.Solr("http://130.155.204.198:8983/solr/trec-cds-2016", timeout=1200)
# solr

# qf="title_text_en^2 abstract_text_en^2 body_text_en^1.1",
# fields=['id','score'], size=1000, max_year=2016):

In [44]:
import random

RANDOM_STATE = 2019
test_size = 0.4
random.seed(2019)

X = [[random.randint(0, 15) for i in range(3)] for j in range(20)]
X_mask_all = [[random.randint(1, 2) for i in range(3)] for j in range(20)]
y = [random.randint(0, 1) for i in range(20)]

In [68]:
X

[[4, 7, 15],
 [5, 7, 7],
 [10, 9, 13],
 [7, 13, 14],
 [2, 10, 1],
 [11, 11, 10],
 [8, 6, 15],
 [3, 3, 5],
 [9, 14, 13],
 [1, 5, 3],
 [1, 1, 2],
 [12, 10, 14],
 [13, 9, 12],
 [5, 9, 6],
 [9, 3, 9],
 [6, 10, 0],
 [7, 9, 6],
 [10, 0, 5],
 [10, 2, 4],
 [2, 13, 12]]

In [69]:
X_mask_all

[[1, 2, 2],
 [1, 2, 1],
 [2, 2, 2],
 [2, 2, 2],
 [2, 2, 2],
 [2, 1, 2],
 [2, 2, 1],
 [1, 1, 2],
 [2, 1, 2],
 [2, 2, 2],
 [1, 1, 2],
 [1, 2, 1],
 [1, 1, 2],
 [2, 1, 1],
 [1, 1, 2],
 [2, 1, 1],
 [2, 2, 2],
 [1, 1, 2],
 [2, 2, 1],
 [2, 2, 1]]

In [48]:
print(y)

[1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]


In [61]:
from sklearn.model_selection import train_test_split


X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    random_state=RANDOM_STATE,
    test_size=test_size,
    stratify=y
)
X_mask, X_mask_test, y_mask_train, y_mask_test = train_test_split(
    X_mask_all, y,
    random_state=RANDOM_STATE,
    test_size=test_size,
    stratify=y
)

In [62]:
print("X_train")
print(X_train)

print("X_test")
print(X_test)

X_train
[[6, 10, 0], [13, 9, 12], [1, 5, 3], [10, 0, 5], [8, 6, 15], [10, 2, 4], [2, 10, 1], [9, 14, 13], [3, 3, 5], [7, 13, 14], [5, 7, 7], [10, 9, 13]]
X_test
[[2, 13, 12], [4, 7, 15], [5, 9, 6], [9, 3, 9], [7, 9, 6], [11, 11, 10], [12, 10, 14], [1, 1, 2]]


In [63]:
print("y_train")
print(y_train)

print("y_mask_train")
print(y_mask_train)

y_train
[1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0]
y_mask_train
[1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0]


In [67]:
print("y_test")
print(y_test)

print("y_mask_test")
print(y_mask_test)

y_test
[0, 1, 0, 1, 1, 0, 0, 0]
y_mask_test
[0, 1, 0, 1, 1, 0, 0, 0]


In [52]:
print("X_mask")
print(X_mask)

print("X_mask_test")
print(X_mask_test)

X_mask
[[2, 1, 1], [1, 1, 2], [2, 2, 2], [1, 1, 2], [2, 2, 1], [2, 2, 1], [2, 2, 2], [2, 1, 2], [1, 1, 2], [2, 2, 2], [1, 2, 1], [2, 2, 2]]
X_mask_test
[[2, 2, 1], [1, 2, 2], [2, 1, 1], [1, 1, 2], [2, 2, 2], [2, 1, 2], [1, 2, 1], [1, 1, 2]]


In [70]:
print("X_test")
print(X_test)

X_test
[[2, 13, 12], [4, 7, 15], [5, 9, 6], [9, 3, 9], [7, 9, 6], [11, 11, 10], [12, 10, 14], [1, 1, 2]]
