# THE OBJECTIVE OF THIS PROJECT IS TO BUILD AN AI SEARCH ENGINE THAT QUERIES A DATABASE,FILTERS OUT THE RESULTS ANDRETURNS THE ANSWER BASED ON THE RELEVANCE

#THE METHODS I WILL USE INCLUDE

1.BERT TRANSFORMER 
2.BART FOR ABSTRACT SUMMARIZATION
3.META PUB TO FIND PDFs WITH THE RESEARCH QUESTION
4.GOOGLE UNIVERSAL SENTENCE ENCODER 
5.USING SEGMENT SCORES TO CHOOSE THE BEST ANSWERS

In [None]:
# First I will set some default parametres
FIND_PDFS = True
USE_SUMMARY = True
SEARCH_PUBMED = True
SEARCH_MEDXRIV = True
# Find pdfs will be linked to Metapub,Use_summary will be used with transformers

In [1]:
# I will download JDK and have it set up
import os
!curl -O https://download.java.net/java/GA/jdk11/9/GPL/openjdk-11.0.2_linux-x64_bin.tar.gz
!mv openjdk-11.0.2_linux-x64_bin.tar.gz /usr/lib/jvm/; cd /usr/lib/jvm/; tar -zxvf openjdk-11.0.2_linux-x64_bin.tar.gz
!update-alternatives --install /usr/bin/java java /usr/lib/jvm/jdk-11.0.2/bin/java 1
!update-alternatives --set java /usr/lib/jvm/jdk-11.0.2/bin/java
os.environ["JAVA_HOME"] = "/usr/lib/jvm/jdk-11.0.2"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0curl: (6) Could not resolve host: download.java.net
'mv' is not recognized as an internal or external command,
operable program or batch file.
'update-alternatives' is not recognized as an internal or external command,
operable program or batch file.
'update-alternatives' is not recognized as an internal or external command,
operable program or batch file.


In [None]:
# Since we will be using Anserini ,I will have to download pyserini ,an application for python and anserini
!pip install pyserini
from pyserini import pysearch

In [None]:
# Lucene Database ,where we will get some of our data on the CORD 19 database
!wget -O lucene.tar.gz https://www.dropbox.com/s/j55t617yhvmegy8/lucene-index-covid-2020-04-10.tar.gz
!tar xvfz lucene.tar.gz
minDate = '2020/04/09'
luceneDir = 'lucene-index-covid-2020-04-10/'

In [None]:
# To use Google Universal Sentence Encoder ,I will use tensorflow and set it up
import tensorflow as tf
import tensorflow_hub as hub
!curl -L "https://tfhub.dev/google/universal-sentence-encoder-large/3?tf-hub-format=compressed" | tar -zxvC /kaggle/working//sentence_wise_email/module/module_useT

In [None]:
# Transformers for Pretrained Models and Abstractive Text Summarization
# Bert -Squad Pretrained Model and Bart Text Sumarization- trained on cnn data
import torch
from transformers import BertForQuestionAnswering,BertTokenizer
from transformers import BartForConditionalGeneration,BartTokenizer
device = 'cuda' if torch.cuda.is_available else 'cpu'

QA_MODEL = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_TOKENIZER = BertTokenizer('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_MODEL.to(device)
QA_MODEL.eval()

if USE_SUMMARY :
    SUMMARY_MODEL = BartForConditionalGeneration('bart-large-cnn')
    SUMMARY_MODEL = BartTokenizer('bart-large-cnn')
    SUMMARY_MODEL.to(device)
    SUMMARY_MODEL.eval()


In [None]:
# Set up FIND PDFs and BioPython before we query our data 
if FIND_PDFS:
    !pip install metapub

In [None]:
# Biopython 
# The purpose of biopython is so that we could search the pubmed database if need be
!pip install biopython 
from Bio import Entrez,Medline
try :
    from StringIO import StringIO
except ImportError:
    from io import StringIO
import re

In [None]:
# Next I will set up a sample question and some keywords
question = "What are the main symptoms of Covid-19 disease"
keywords  = "COVID-19,symptoms,main"
# So we will use Pyserini to do a search on the Lucene database
# We will then store the information in a dictionary 
# For each of our entries we will check for the abstracts and in a way remove them 
# We will then remain with only relevant data
import json 
searcher = pysearch.SimpleSearcher(luceneDir)
results  = searcher.search(question+"."+keywords)
n_results = len(results)
result_dict = {}
for i in range(0,n_results):
    doc_json = json.load(results[i].raw)
    idx =str(results[i].docid)
    result_dict [idx] = doc_json
    result_dict[idx]['title'] = results[i].lucene_document.get("title")
    result_dict[idx]['authors'] = results[i].lucene_document.get("authors")
    result_dict[idx]['doi'] = results[i].lucene_document.get("doi")
    # We need to scrubb off some abstracts
    for idx,v in result_dict.items():
        abs_dirty = v['abstract']
        v['abstract_paragraph'] =[]
        v['abstract_full'] = ''
        # Here I make an assumption that abstract paragraphs are actually lists
        if abs_dirty:
            if isinstance(abs_dirty,list):
                for p in abs_dirty:
                    v['abstract_paragraph'].append(p['text'])
                    v['abstract_full'] +=p['text']+'\n\n'
            if isinstance(abs_dirty,str):
                v['abstract_paragraph'].append(abs_dirty)
                v['abstract_full']+= abs_dirty+'\n\n'

In [None]:
# I will create a function that reconstructs the text before we pass it to our BertModel
# We first remove all the hasttags and then replace long commas with shorter ones
def text_process(tokens,start=0,stop=1):
    if ['SEP'] in tokens:
        sepind =tokens.index(['SEP'])
        tokens = tokens[sepind+1:]
        txt = ''.join(token)
        txt = txt.replace('##','')
        txt = txt.replace("##",'')
        txt = txt.strip()
        txt = "".join(txt.split())
        txt = txt.replace(' ,',',')
        txt = txt.replace(' .','.')
        txt = txt.replace(' (','(')
        txt = txt.replace(' )',')')
        txt = txt.replace(' _','_')
        txt_list = txt.split(',')
        txt = ''
        nTxtl = len(txt_list)
        if nTxtl==1:
            return txt_list[0]
        new_list = []
        for i,t in enumerate(txt_list):
            if i<nTxtl-1:
                if t[-1].isdigit() and txt_list[i+1][0].isdigit():
                    new_list +=[t,',']
                else :
                    new_list +=[t,' ,']
            else :
                new_list+=[t]
        return ''.join(new_list)
        
    

# BERT PREDICTION MODEL

In [1]:
import numpy as np
def BertSquadPrediction(document,question):
    # Using BartTokenizer we will encode all the words and convert to tokens
    # This function will have to rewrite the document into 250-300 pages with 50 overlaps on either end
    nWords = len(document.split())
    input_ids_all = QA_TOKENIZER.encode(question,document)
    tokens_all =QA_TOKENIZER.convert_ids_to_tokens(input_ids_all)
    overlapFac = 1.1
    if len(input_ids_all)*overlapFac>2048:
        nSearch_Words = int(np.ceil(nWords/5))
        quarter = int(np.ceil(nWords/4))
        docsplit = document.split()
        doc_pieces = [''.join(docsplit[:int(nSearch_Words*overlapFac)]),
                     ''.join(docsplit[quarter-int(nSearch_Words*overlapFac/2):quarter+int(quarter*overlapFac/2)]),
                    ''.join(docsplit[quarter*3-int(nSearch_Words*overlapFac/2):quarter*3+int(quarter*overlapFac/2)]),
                    ''.join(docsplit[-int(nSearch_Words*overlapFac/2):])]
        input_ids = [QA_TOKENIZER.encode(question,dp)for dp in docpieces]
    elif  len(input_ids_all)*overlapFac >1536:
        nSearch_Words = int(np.ceil(nWords/4))
        third = int(np.ceil(nWords/3))
        docsplit = document.split()
        doc_pieces = [''.join(docsplit[:int(nSearch_Words*overlapFac)]),
                     ''.join(docsplit[third-int(nSearch_Words*overlapFac/2):third+int(nSearch_Words*overlapFac/2)]),
                     ''.join(docsplit[third*2-int(nSearch_Words*overlapFac/2):third*2+int(nSearch_Words*overlapFac/2)]),
                     ''.join(docsplit[-int(nSearch_Words*overlapFac)]:)]
        input_ids = [QA_TOKENIZER.encode(question,dp)for dp in docpieces]
    elif len(input_ids_all)*overlapFac >1024 :
        nSearch_Words = int(np.ceil(nWords/3))
        middle = int(np.ceil(nWords/2))
        docsplit = document.split()
        docpieces = [''.join(docsplit[:int(nSearch_Words*overlapFac)]),
                    ''.join(docplit[middle-int(nSearch_Words*overlapFac/2):middle+int(nSearch_Words*overlapFac/2)]),
                    ''.join(docsplit[-int(nSearch_Words*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question,dp)for dp in docpieces]
    elif len(input_ids_all)*ovelapFac > 512 :
        nSearch_Words = int(np.ceil(nWords/2))
        docsplit = document.split()
        docpieces  = [''.join(docsplit[:int(nSearch_Words*overlapFac)]),
                     ''.join(docsplit[-int(nSearch_Words*overlapFac):])]
        input_ids = [QA_TOKENIZER.encode(question,dp) for dp in docpieces]
    else :
        input_ids = input_ids_all
    absTooLong = False
    answers = []
    cons = []
    
    for iptIds in input_ids :
        tokens = QA_TOKENIZER.convert_ids_to_tokens(iptIds)
        sep_index = iptIds.index(QA_TOKENIZER.sep_token_id)
        num_seg_a = sep_index+1
        num_seg_b = len(iptIds)-num_seg_a
        segment_ids = [0]* num_seg_a + [1]*num_seg_b
        n_ids = len(segment_ids)
        if n_ids < 512 :
            start_scores,end_scores = QA_MODEL(torch.tensor([iptIds]).to(device),
                                               token_type_ids = torch.tensor([segment_ids]).to(device))
        else :
            # For those texts that have more than 512 words :
            print(f'****Document is too long we consider {nWords} it has {n_ids}')
            absTooLong =True  
            start_scores,end_scores = QA_MODEL(torch.tensor(iptIds[:512]).to(device),
                                               token_type_ids= torch.tensor(segment_ids[:512]).to(device))
            # declare what start scores and end scores are
            start_scores = [:1:-1]
            end_scores = [:1:-1]
            # We will be considering the highest scoring scores among a list of possible matches and then we return the top value
            answer_start = torch.argmax(start_scores)
            answer_end = torch.argmax(end_scores)
            # The reason we add +2 on our stop is because we index everything upto -1
            answer = text_process(tokens = tokens,start =answer_start,stop=answer_end+2)
        if answer.startswith('. ') or answer.startswith(', '):
            answer = answer[2:]
            c = start_scores[0,answer_start].item()+end_scores[0,answer_end].item()
            cons.append(c)
        maxC = max(cons)
        iMaxC =[i for i,j in enumerate(cons) if j == maxC][0]
        confidence = cons[iMaxC]
        answer = answers[iMaxC]
        sep_index =tokens_all.index(['SEP'])
        full_txt_token = tokens_all[sep_index+1:]
        abs_returned = text_process(full_txt_tokens)
        ans ={}
        ans['answer'] = answer
        if answer.startswith(['CLS']) or answer_end.item() <sep_index or answer.end_with(['SEP']):
            ans['confidence'] = 1000000
        else :
            ans['confidence'] = confidence
            ans['abstract_bert'] =abs_returned
            ans['abs_too_long'] = absTooLong
        return ans

SyntaxError: invalid syntax (774918465.py, line 14)

# Open Domain QA on our Abstracts

In [2]:
from tqdm import tqdm
def AbstractSearch(result_dict,question):
    abstractResults = {}
    for k,v in tqdm(result_dict.items()):
        abstract = v['abstract_full']
        ans = BertSquadPrediction(abstract,question)
        if ans['answer']:
            confidence=ans['confidence']
            abstractResults[confidence] ={}
            abstractResults[confidence][answer] = ans['answer']
            abstractResults[confidence]['abstract_bert'] = ans['abstract_bert']
            abstractResults[confidence]['abs_too_long'] = ans['abs_too_long']
            absractResult[confidence]['idx'] = k
        Result_List = list(abstractResults.keys())
        if Result_List:
            maxScore = max(Result_List)
            total =0.0
            exp_score=[]
            for c in Result_List:
                s = np.exp(c-maxScore)
                exp_scores.append(s)
                total = sum(exp_scores)
                
        for i,c in enumerate(cList):
            abstractResults [exp_scores[i]/total] = abstractResults.pop(c)
        return abstractResults

IndentationError: expected an indented block (2457645645.py, line 1)