In [None]:
# !pip install transformers
# !pip install pyserini

In [3]:
# import os

# os.environ['JAVA_HOME'] = '/Library/Java/JavaVirtualMachines/jdk-13.0.2.jdk/Contents/Home'

In [57]:
from matplotlib import pyplot as plot

In [29]:
import torch

In [61]:
import numpy

In [203]:
from transformers import *

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')

# tokenizer = AutoTokenizer.from_pretrained('roberta-base')
# model = AutoModel.from_pretrained('roberta-base')

In [187]:
from pyserini.search import pysearch

In [188]:
searcher = pysearch.SimpleSearcher('./lucene-index-covid/')

In [204]:
query = 'cure for coronavirus which is similar to mers'

In [205]:
hits = searcher.search(query)

In [206]:
def relevant_snippets(query_ids, query_state, abstract_ids, abstract_state, 
                      n_abstract_words = 5, window = 5, prefix='\t'):
    query_state = query_state / torch.sqrt((query_state ** 2).sum(-1, keepdims=True))
    abstract_state = abstract_state / torch.sqrt((abstract_state ** 2).sum(-1, keepdims=True))

    query_state_ = query_state.view(1, query_state.size(1), 1, query_state.size(2))
    abstract_state_ = abstract_state.view(1, 1, abstract_state.size(1), abstract_state.size(2))
    query_abstract_match = (query_state_ * abstract_state_).sum(-1)

    qa_match = query_abstract_match.data.numpy()[0]

    abstract_match = qa_match.max(0)
    abstract_words_idx = numpy.argsort(abstract_match)[-n_abstract_words:][::-1]

    abstract_words = tokenizer.convert_ids_to_tokens(abstract_ids[0])[1:-1]
    important_words = [abstract_words[idx] for idx in abstract_words_idx]

    query_words = tokenizer.convert_ids_to_tokens(query_ids[0])[1:-1]

    for word_idx, word in zip(abstract_words_idx,important_words):
        best_qword = numpy.argmax(qa_match[:, word_idx])
        surround_text_l = ' '.join(abstract_words[word_idx-window:word_idx])
        surround_text_r = ' '.join(abstract_words[word_idx+1:word_idx+window])
        surround_text = '... ' + surround_text_l + F' $${abstract_words[word_idx]}$$ ' + surround_text_r + ' ...'
        if 'Ġ' in surround_text:
            surround_text = ''.join(surround_text.split(' ')).replace('Ġ', ' ')
        else:
            surround_text = surround_text.replace(' ##', '')
        print(F'{prefix}{query_words[best_qword]}: {surround_text}')

In [207]:
query_ids = torch.tensor([tokenizer.encode(query, add_special_tokens=True)])

with torch.no_grad():
    query_state = model(query_ids)[0]
    query_state = query_state[:, 1:-1, :]

In [208]:
n_documents = 10

for ii in range(0, n_documents):
    print(F'{ii+1} {hits[ii].docid} {hits[ii].score} {searcher.doc(hits[ii].ldocid).object.getField("title").stringValue()}')
    
    doc = searcher.doc(hits[ii].ldocid).object
    content = hits[0].content # raw content

    abstract_ids = torch.tensor([tokenizer.encode(doc.getField('abstract').stringValue(), add_special_tokens=True)])
    abstract_ids = abstract_ids[0,:512].unsqueeze(0)
    with torch.no_grad():
        abstract_state = model(abstract_ids)[0]
        abstract_state = abstract_state[:, 1:-1, :]
    
    relevant_snippets(query_ids, query_state, abstract_ids, abstract_state, window=7)

1 40425 6.499300003051758 Clinical Characteristics of Two Human to Human Transmitted Coronaviruses: Corona Virus Disease 2019 versus Middle East Respiratory Syndrome Coronavirus.
	corona: ... 2012 . currently , a novel human $$corona$$virus has caused a major disease ...
	for: ... main complication . the most effective drug $$for$$ mers - cov is rib ...
	mer: ... of the middle east respiratory syndrome ( $$mer$$s ) worldwide in 2012 . ...
	corona: ... the similarities and differences between the two $$corona$$virus diseases remain to be unknown ...
	mer: ... complication . the most effective drug for $$mer$$s - cov is ribavir ...
2 40189 6.06279993057251 Epidemic Situation of Novel Coronavirus Pneumonia in China mainland
	corona: ... objective ] analyze the occurrence of novel $$corona$$virus pneumonia ( ncp ) ...
	##virus: ... ] analyze the occurrence of novel corona $$##virus$$ pneumonia ( ncp ) in ...
	##virus: ... analyze the occurrence of novel coronavirus $$pneumonia$$ ( ncp ) in