In [8]:
from pyserini.search import pysearch
import tensorflow as tf
import tensorflow_hub as hub
USE_SUMMARY = True
FIND_PDFS = True
SEARCH_MEDRXIV = True
SEARCH_PUBMED = True

import torch
luceneDir = 'lucene-index-covid-2020-04-10/'
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
from transformers import BartTokenizer, BartForConditionalGeneration
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

QA_MODEL = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_TOKENIZER = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_MODEL.to(torch_device)
QA_MODEL.eval()

if USE_SUMMARY:
    SUMMARY_TOKENIZER = BartTokenizer.from_pretrained('bart-large-cnn')
    SUMMARY_MODEL = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
    SUMMARY_MODEL.to(torch_device)
    SUMMARY_MODEL.eval()

from Bio import Entrez, Medline

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
import re

import json
searcher = pysearch.SimpleSearcher(luceneDir)
hits = searcher.search(query + '. ' + keywords)
n_hits = len(hits)
## collect the relevant data in a hit dictionary
hit_dictionary = {}
for i in range(0, n_hits):
    doc_json = json.loads(hits[i].raw)
    idx = str(hits[i].docid)
    hit_dictionary[idx] = doc_json
    hit_dictionary[idx]['title'] = hits[i].lucene_document.get("title")
    hit_dictionary[idx]['authors'] = hits[i].lucene_document.get("authors")
    hit_dictionary[idx]['doi'] = hits[i].lucene_document.get("doi")


## scrub the abstracts in prep for BERT-SQuAD
for idx,v in hit_dictionary.items():
    abs_dirty = v['abstract']
    # looks like the abstract value can be an empty list
    v['abstract_paragraphs'] = []
    v['abstract_full'] = ''

    if abs_dirty:
        # looks like if it is a list, then the only entry is a dictionary wher text is in 'text' key
        # looks like it is broken up by paragraph if it is in that form.  lets make lists for every paragraph
        # and a new entry that is full abstract text as both could be valuable for BERT derrived QA


        if isinstance(abs_dirty, list):
            for p in abs_dirty:
                v['abstract_paragraphs'].append(p['text'])
                v['abstract_full'] += p['text'] + ' \n\n'

        # looks like in some cases the abstract can be straight up text so we can actually leave that alone
        if isinstance(abs_dirty, str):
            v['abstract_paragraphs'].append(abs_dirty)
            v['abstract_full'] += abs_dirty + ' \n\n'

def embed_useT(module):
    with tf.Graph().as_default():
        sentences = tf.compat.v1.placeholder(tf.string)
        embed = hub.Module(module)
        embeddings = embed(sentences)
        session = tf.compat.v1.train.MonitoredSession()
    return lambda x: session.run(embeddings, {sentences: x})

embed_fn = embed_useT('kaggle/working/sentence_wise_email/module/module_useT')


import numpy as np
def reconstructText(tokens, start=0, stop=-1):
    tokens = tokens[start: stop]
    if '[SEP]' in tokens:
        sepind = tokens.index('[SEP]')
        tokens = tokens[sepind+1:]
    txt = ' '.join(tokens)
    txt = txt.replace(' ##', '')
    txt = txt.replace('##', '')
    txt = txt.strip()
    txt = " ".join(txt.split())
    txt = txt.replace(' .', '.')
    txt = txt.replace('( ', '(')
    txt = txt.replace(' )', ')')
    txt = txt.replace(' - ', '-')
    txt_list = txt.split(' , ')
    txt = ''
    nTxtL = len(txt_list)
    if nTxtL == 1:
        return txt_list[0]
    newList =[]
    for i,t in enumerate(txt_list):
        if i < nTxtL -1:
            if t[-1].isdigit() and txt_list[i+1][0].isdigit():
                newList += [t,',']
            else:
                newList += [t, ', ']
        else:
            newList += [t]
    return ''.join(newList)


def makeBERTSQuADPrediction(document, question):
    ## we need to rewrite this function so that it chuncks the document into 250-300 word segments with
    ## 50 word overlaps on either end so that it can understand and check longer abstracts
    nWords = len(document.split())
    input_ids_all = QA_TOKENIZER.encode(question, document)
    tokens_all = QA_TOKENIZER.convert_ids_to_tokens(input_ids_all)
    overlapFac = 1.1
    if len(input_ids_all)*overlapFac > 2048:
        nSearchWords = int(np.ceil(nWords/5))
        quarter = int(np.ceil(nWords/4))
        docSplit = document.split()
        docPieces = [' '.join(docSplit[:int(nSearchWords*overlapFac)]), 
                     ' '.join(docSplit[quarter-int(nSearchWords*overlapFac/2):quarter+int(quarter*overlapFac/2)]),
                     ' '.join(docSplit[quarter*2-int(nSearchWords*overlapFac/2):quarter*2+int(quarter*overlapFac/2)]),
                     ' '.join(docSplit[quarter*3-int(nSearchWords*overlapFac/2):quarter*3+int(quarter*overlapFac/2)]),
                     ' '.join(docSplit[-int(nSearchWords*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question, dp) for dp in docPieces]        
        
    elif len(input_ids_all)*overlapFac > 1536:
        nSearchWords = int(np.ceil(nWords/4))
        third = int(np.ceil(nWords/3))
        docSplit = document.split()
        docPieces = [' '.join(docSplit[:int(nSearchWords*overlapFac)]), 
                     ' '.join(docSplit[third-int(nSearchWords*overlapFac/2):third+int(nSearchWords*overlapFac/2)]),
                     ' '.join(docSplit[third*2-int(nSearchWords*overlapFac/2):third*2+int(nSearchWords*overlapFac/2)]),
                     ' '.join(docSplit[-int(nSearchWords*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question, dp) for dp in docPieces]        
        
    elif len(input_ids_all)*overlapFac > 1024:
        nSearchWords = int(np.ceil(nWords/3))
        middle = int(np.ceil(nWords/2))
        docSplit = document.split()
        docPieces = [' '.join(docSplit[:int(nSearchWords*overlapFac)]), 
                     ' '.join(docSplit[middle-int(nSearchWords*overlapFac/2):middle+int(nSearchWords*overlapFac/2)]),
                     ' '.join(docSplit[-int(nSearchWords*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question, dp) for dp in docPieces]
    elif len(input_ids_all)*overlapFac > 512:
        nSearchWords = int(np.ceil(nWords/2))
        docSplit = document.split()
        docPieces = [' '.join(docSplit[:int(nSearchWords*overlapFac)]), ' '.join(docSplit[-int(nSearchWords*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question, dp) for dp in docPieces]
    else:
        input_ids = [input_ids_all]
    absTooLong = False    
    
    answers = []
    cons = []
    for iptIds in input_ids:
        tokens = QA_TOKENIZER.convert_ids_to_tokens(iptIds)
        sep_index = iptIds.index(QA_TOKENIZER.sep_token_id)
        num_seg_a = sep_index + 1
        num_seg_b = len(iptIds) - num_seg_a
        segment_ids = [0]*num_seg_a + [1]*num_seg_b
        assert len(segment_ids) == len(iptIds)
        n_ids = len(segment_ids)
        #print(n_ids)

        if n_ids < 512:
            start_scores, end_scores = QA_MODEL(torch.tensor([iptIds]).to(torch_device), 
                                     token_type_ids=torch.tensor([segment_ids]).to(torch_device))
        else:
            #this cuts off the text if its more than 512 words so it fits in model space
            #need run multiple inferences for longer text. add to the todo
            print('****** warning only considering first 512 tokens, document is '+str(nWords)+' words long.  There are '+str(n_ids)+ ' tokens')
            absTooLong = True
            start_scores, end_scores = QA_MODEL(torch.tensor([iptIds[:512]]).to(torch_device), 
                                     token_type_ids=torch.tensor([segment_ids[:512]]).to(torch_device))
        start_scores = start_scores[:,1:-1]
        end_scores = end_scores[:,1:-1]
        answer_start = torch.argmax(start_scores)
        answer_end = torch.argmax(end_scores)
        #print(answer_start, answer_end)
        answer = reconstructText(tokens, answer_start, answer_end+2)
    
        if answer.startswith('. ') or answer.startswith(', '):
            answer = answer[2:]
            
        c = start_scores[0,answer_start].item()+end_scores[0,answer_end].item()
        answers.append(answer)
        cons.append(c)
    
    maxC = max(cons)
    iMaxC = [i for i, j in enumerate(cons) if j == maxC][0]
    confidence = cons[iMaxC]
    answer = answers[iMaxC]
    
    sep_index = tokens_all.index('[SEP]')
    full_txt_tokens = tokens_all[sep_index+1:]
    
    abs_returned = reconstructText(full_txt_tokens)

    ans={}
    ans['answer'] = answer
    #print(answer)
    if answer.startswith('[CLS]') or answer_end.item() < sep_index or answer.endswith('[SEP]'):
        ans['confidence'] = -1000000
    else:
        #confidence = torch.max(start_scores) + torch.max(end_scores)
        #confidence = np.log(confidence.item())
        ans['confidence'] = confidence
    #ans['start'] = answer_start.item()
    #ans['end'] = answer_end.item()
    ans['abstract_bert'] = abs_returned
    ans['abs_too_long'] = absTooLong
    return ans

from tqdm import tqdm
def searchAbstracts(hit_dictionary, question):
    abstractResults = {}
    for k,v in tqdm(hit_dictionary.items()):
        abstract = v['abstract_full']
        if abstract:
            ans = makeBERTSQuADPrediction(abstract, question)
            if ans['answer']:
                confidence = ans['confidence']
                abstractResults[confidence]={}
                abstractResults[confidence]['answer'] = ans['answer']
                #abstractResults[confidence]['start'] = ans['start']
                #abstractResults[confidence]['end'] = ans['end']
                abstractResults[confidence]['abstract_bert'] = ans['abstract_bert']
                abstractResults[confidence]['idx'] = k
                abstractResults[confidence]['abs_too_long'] = ans['abs_too_long']
                
    cList = list(abstractResults.keys())

    if cList:
        maxScore = max(cList)
        total = 0.0
        exp_scores = []
        for c in cList:
            s = np.exp(c-maxScore)
            exp_scores.append(s)
        total = sum(exp_scores)
        for i,c in enumerate(cList):
            abstractResults[exp_scores[i]/total] = abstractResults.pop(c)
    return abstractResults

answers = searchAbstracts(hit_dictionary, query)

workingPath = '/kaggle/working'
import pandas as pd
if FIND_PDFS:
    from metapub import UrlReverse
    from metapub import FindIt
from IPython.core.display import display, HTML

#from summarizer import Summarizer
#summarizerModel = Summarizer()
def displayResults(hit_dictionary, answers, question):
    
    question_HTML = '<div style="font-family: Times New Roman; font-size: 28px; padding-bottom:28px"><b>Query</b>: '+question+'</div>'
    #all_HTML_txt = question_HTML
    confidence = list(answers.keys())
    confidence.sort(reverse=True)
    
    confidence = list(answers.keys())
    confidence.sort(reverse=True)
    

    for c in confidence:
        if c>0 and c <= 1 and len(answers[c]['answer']) != 0:
            if 'idx' not in  answers[c]:
                continue
            rowData = []
            idx = answers[c]['idx']
            title = hit_dictionary[idx]['title']
            authors = hit_dictionary[idx]['authors'] + ' et al.'
            doi = '<a href="https://doi.org/'+hit_dictionary[idx]['doi']+'" target="_blank">' + title +'</a>'

            
            full_abs = answers[c]['abstract_bert']
            bert_ans = answers[c]['answer']
            
            
            split_abs = full_abs.split(bert_ans)
            sentance_beginning = split_abs[0][split_abs[0].rfind('.')+1:]
            if len(split_abs) == 1:
                sentance_end_pos = len(full_abs)
                sentance_end =''
            else:
                sentance_end_pos = split_abs[1].find('. ')+1
                if sentance_end_pos == 0:
                    sentance_end = split_abs[1]
                else:
                    sentance_end = split_abs[1][:sentance_end_pos]
                
            #sentance_full = sentance_beginning + bert_ans+ sentance_end
            answers[c]['full_answer'] = sentance_beginning+bert_ans+sentance_end
            answers[c]['sentence_beginning'] = sentance_beginning
            answers[c]['sentence_end'] = sentance_end
            answers[c]['title'] = title
            answers[c]['doi'] = doi
            if 'pdfLink' in hit_dictionary[idx]:
                answers[c]['pdfLink'] = hit_dictionary[idx]['pdfLink']
                
        else:
            answers.pop(c)
    
    
    ## now rerank based on semantic similarity of the answers to the question
    cList = list(answers.keys())
    allAnswers = [answers[c]['full_answer'] for c in cList]
    
    messages = [question]+allAnswers
    
    encoding_matrix = embed_fn(messages)
    similarity_matrix = np.inner(encoding_matrix, encoding_matrix)
    rankings = similarity_matrix[1:,0]
    
    for i,c in enumerate(cList):
        answers[rankings[i]] = answers.pop(c)

    ## now form pandas dv
    confidence = list(answers.keys())
    confidence.sort(reverse=True)
    pandasData = []
    ranked_aswers = []
    for c in confidence:
        rowData=[]
        title = answers[c]['title']
        doi = answers[c]['doi']
        idx = answers[c]['idx']
        rowData += [idx]            
        sentance_html = '<div>' +answers[c]['sentence_beginning'] + " <font color='red'>"+answers[c]['answer']+"</font> "+answers[c]['sentence_end']+'</div>'
        
        rowData += [sentance_html, c, doi]
        pandasData.append(rowData)
        ranked_aswers.append(' '.join([answers[c]['full_answer']]))
    
    if FIND_PDFS or SEARCH_MEDRXIV:
        pdata2 = []
        for rowData in pandasData:
            rd = rowData
            idx = rowData[0]
            if 'pdfLink' in answers[rowData[2]]:
                rd += ['<a href="'+answers[rowData[2]]['pdfLink']+'" target="_blank">PDF Link</a>']
            elif FIND_PDFS:
                if str(idx).startswith('pm_'):
                    pmid = idx[3:]
                else:
                    try:
                        test = UrlReverse('https://doi.org/'+hit_dictionary[idx]['doi'])
                        if test is not None:
                            pmid = test.pmid
                        else:
                            pmid = None
                    except:
                        pmid = None
                pdfLink = None
                if pmid is not None:
                    try:
                        pdfLink = FindIt(str(pmid))
                    except:
                        pdfLink = None
                if pdfLink is not None:
                    pdfLink = pdfLink.url

                if pdfLink is None:

                    rd += ['Not Available']
                else:
                    rd += ['<a href="'+pdfLink+'" target="_blank">PDF Link</a>']
            else:
                rd += ['Not Available']
            pdata2.append(rowData)
    else:
        pdata2 = pandasData
        
    
    display(HTML(question_HTML))
    
    if USE_SUMMARY:
        ## try generating an exacutive summary with extractive summarizer
        allAnswersTxt = ' '.join(ranked_aswers[:6]).replace('\n','')
    #    exec_sum = summarizerModel(allAnswersTxt, min_length=1, max_length=500)    
     #   execSum_HTML = '<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:18px"><b>BERT Extractive Summary:</b>: '+exec_sum+'</div>'

        answers_input_ids = SUMMARY_TOKENIZER.batch_encode_plus([allAnswersTxt], return_tensors='pt', max_length=1024)['input_ids'].to(torch_device)
        summary_ids = SUMMARY_MODEL.generate(answers_input_ids,
                                               num_beams=10,
                                               length_penalty=1.2,
                                               max_length=1024,
                                               min_length=64,
                                               no_repeat_ngram_size=4)

        exec_sum = SUMMARY_TOKENIZER.decode(summary_ids.squeeze(), skip_special_tokens=True)
        execSum_HTML = '<div style="font-family: Times New Roman; font-size: 18px; margin-bottom:1pt"><b>BART Abstractive Summary:</b>: '+exec_sum+'</div>'
        display(HTML(execSum_HTML))
        warning_HTML = '<div style="font-family: Times New Roman; font-size: 12px; padding-bottom:12px; color:#CCCC00; margin-top:1pt"> Warning this is an autogenerated summary based on semantic search of abstracts, always examine the sources before accepting this conclusion.  If the evidence only mentions topic in passing or the evidence is not clear, the summary will likely not clearly answer the question.</div>'
        display(HTML(warning_HTML))

#    display(HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:18px"><b>Body of Evidence:</b></div>'))
    
    if FIND_PDFS or SEARCH_MEDRXIV:
        df = pd.DataFrame(pdata2, columns = ['Lucene ID', 'BERT-SQuAD Answer with Highlights', 'Confidence', 'Title/Link','PDF Link'])
    else:
        df = pd.DataFrame(pdata2, columns = ['Lucene ID', 'BERT-SQuAD Answer with Highlights', 'Confidence', 'Title/Link'])
        
    display(HTML(df.to_html(render_links=True, escape=False)))
    
#displayResults(hit_dictionary, answers, query)

import requests
import json
from bs4 import BeautifulSoup
import re
import datetime
import dateutil.parser as dparser

def medrxivSearch(query, n_pages=1):
    results = {}
    q = query
    for x in range(n_pages):
        PARAMS = {
            'page': x
        }
        r = requests.get('https://www.medrxiv.org/search/' + q, params = PARAMS)
        content = r.text
        page = BeautifulSoup(content, 'lxml')
        
        for entry in page.find_all("a", attrs={"class": "highwire-cite-linked-title"}):
            title = ""
            url = ""
            pubDate = ""
            journal = None
            abstract = ""
            authors = []
            database = "medRxiv"
            
            
            url = "https://www.medrxiv.org" + entry.get('href')
            
            request_entry = requests.get(url)
            content_entry = request_entry.text
            page_entry = BeautifulSoup(content_entry, 'lxml')
            doi = page_entry.find("span", attrs={"class": "highwire-cite-metadata-doi"}).text.split('doi.org/')[-1]
            #print(page_entry)

            #getting pubDate
            pubDate = page_entry.find_all("div", attrs = {"class": "pane-content"})
            pubDate = pubDate[10]
            pubDate = str(dparser.parse(pubDate, fuzzy = True))
            pubDate = datetime.datetime.strptime(pubDate, '%Y-%m-%d %H:%M:%S')
            pubDate = pubDate.strftime('%b %d %Y')
            date = pubDate.split()
            month = date[0]
            day = date[1]
            year = date[2]
            pubDate = {
                'year': year,
                'month': month,
                'day': day
            }

            #getting title
            title = page_entry.find("h1", attrs={"class": "highwire-cite-title"}).text
            
            #getting abstract
            abstract = page_entry.find("p", attrs = {"id": "p-2"}).text.replace('\n', ' ')

            #getting authors 
            givenNames = page_entry.find_all("span", attrs={"class": "nlm-given-names"})
            surnames = page_entry.find_all("span",  attrs={"class": "nlm-surname"})
            names = list(zip(givenNames,surnames))
            for author in names:
                name = author[0].text + ' ' + author[1].text
                if name not in authors:
                    authors.append(name)
            
            result = {
                'title': title,
                'url': url,
                'pubDate': pubDate,
                'journal': journal,
                'abstract': abstract,
                'authors': authors[0],
                'database': database,
                'doi': doi,
                'pdfLink': url+'.full.pdf'
            }
            results['mrx_'+result['doi'].split('/')[-1]] = result
            #break
    return results

def searchDatabase(question, keywords, pysearch, lucene_database, pm_kw = '', minDate='2019/12/01', k=20):
    ## search the lucene database with a combination of the question and the keywords
    searcher = pysearch.SimpleSearcher(lucene_database)
    hits = searcher.search(question + '. ' + keywords, k=k)
    n_hits = len(hits)
    ## collect the relevant data in a hit dictionary
    hit_dictionary = {}
    for i in range(0, n_hits):
        doc_json = json.loads(hits[i].raw)
        idx = str(hits[i].docid)
        hit_dictionary[idx] = doc_json
        hit_dictionary[idx]['title'] = hits[i].lucene_document.get("title")
        hit_dictionary[idx]['authors'] = hits[i].lucene_document.get("authors")
        hit_dictionary[idx]['doi'] = hits[i].lucene_document.get("doi")
        
    
    titleList = [h['title'] for h in hit_dictionary.values()]
    
    if pm_kw:
        if SEARCH_PUBMED:
            new_hits = pubMedSearch(pm_kw, db='pubmed', mindate=minDate)
            for id,h in new_hits.items():
                if h['title'] not in titleList:
                    titleList.append(h['title'])
                hit_dictionary[id] = h
        if SEARCH_MEDRXIV:
            new_hits = medrxivSearch(pm_kw)
            for id,h in new_hits.items():
                if h['title'] not in titleList:
                    titleList.append(h['title'])
                hit_dictionary[id] = h
    
    ## scrub the abstracts in prep for BERT-SQuAD
    for idx,v in hit_dictionary.items():
        abs_dirty = v['abstract']
        # looks like the abstract value can be an empty list
        v['abstract_paragraphs'] = []
        v['abstract_full'] = ''

        if abs_dirty:
            # looks like if it is a list, then the only entry is a dictionary wher text is in 'text' key
            # looks like it is broken up by paragraph if it is in that form.  lets make lists for every paragraph
            # and a new entry that is full abstract text as both could be valuable for BERT derrived QA


            if isinstance(abs_dirty, list):
                for p in abs_dirty:
                    v['abstract_paragraphs'].append(p['text'])
                    v['abstract_full'] += p['text'] + ' \n\n'

            # looks like in some cases the abstract can be straight up text so we can actually leave that alone
            if isinstance(abs_dirty, str):
                v['abstract_paragraphs'].append(abs_dirty)
                v['abstract_full'] += abs_dirty + ' \n\n'
    ## Search collected abstracts with BERT-SQuAD
    answers = searchAbstracts(hit_dictionary, question)
    
    ## display results in a nice format
    displayResults(hit_dictionary, answers, question)

NameError: name 'query' is not defined