In [None]:
# !pip install transformers
# !pip install pyserini

In [3]:
# import os

# os.environ['JAVA_HOME'] = '/Library/Java/JavaVirtualMachines/jdk-13.0.2.jdk/Contents/Home'

In [57]:
from matplotlib import pyplot as plot

In [29]:
import torch

In [61]:
import numpy

In [224]:
from transformers import *

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')

# tokenizer = AutoTokenizer.from_pretrained('roberta-base')
# model = AutoModel.from_pretrained('roberta-base')

In [225]:
from pyserini.search import pysearch

In [226]:
searcher = pysearch.SimpleSearcher('./lucene-index-covid/')

In [288]:
query = 'degradation integration rna virus'

In [289]:
hits = searcher.search(query)

In [290]:
def relevant_snippets(query_ids, query_state, abstract_ids, abstract_state, 
                      n_abstract_words = 5, window = 5, prefix='\t'):
    query_state = query_state / torch.sqrt((query_state ** 2).sum(-1, keepdims=True))
    abstract_state = abstract_state / torch.sqrt((abstract_state ** 2).sum(-1, keepdims=True))

    query_state_ = query_state.view(1, query_state.size(1), 1, query_state.size(2))
    abstract_state_ = abstract_state.view(1, 1, abstract_state.size(1), abstract_state.size(2))
    query_abstract_match = (query_state_ * abstract_state_).sum(-1)

    qa_match = query_abstract_match.data.numpy()[0]

    ''' best matching words from the abstract overall '''
    abstract_match = qa_match.max(0)
    abstract_words_idx = numpy.argsort(abstract_match)[-n_abstract_words:][::-1]

    abstract_words = tokenizer.convert_ids_to_tokens(abstract_ids[0])[1:-1]
    important_words = [abstract_words[idx] for idx in abstract_words_idx]

    query_words = tokenizer.convert_ids_to_tokens(query_ids[0])[1:-1]

    print()
    for word_idx, word in zip(abstract_words_idx,important_words):
        best_qword = numpy.argmax(qa_match[:, word_idx])
        surround_text_l = ' '.join(abstract_words[word_idx-window:word_idx])
        surround_text_r = ' '.join(abstract_words[word_idx+1:word_idx+window])
        surround_text = '... ' + surround_text_l + F' $${abstract_words[word_idx]}$$ ' + surround_text_r + ' ...'
        if 'Ġ' in surround_text:
            surround_text = ''.join(surround_text.split(' ')).replace('Ġ', ' ')
        else:
            surround_text = surround_text.replace(' ##', '')
        print(F'{prefix}{query_words[best_qword]}: {surround_text}')
        
    ''' best matching phrases from the abstract per query word '''
    print()
    for wii, (word_idx, word) in enumerate(zip(query_ids[0][1:-1], query_words)):
        best_aword = numpy.argmax(qa_match[wii, :])
        surround_text_l = ' '.join(abstract_words[best_aword-window:best_aword])
        surround_text_r = ' '.join(abstract_words[best_aword+1:best_aword+window])
        surround_text = '... ' + surround_text_l + F' $${abstract_words[best_aword]}$$ ' + surround_text_r + ' ...'
        if 'Ġ' in surround_text:
            surround_text = ''.join(surround_text.split(' ')).replace('Ġ', ' ')
        else:
            surround_text = surround_text.replace(' ##', '')
        print(F'{prefix}{word}: {surround_text}')

In [291]:
query_ids = torch.tensor([tokenizer.encode(query, add_special_tokens=True)])

with torch.no_grad():
    query_state = model(query_ids)[0]
    query_state = query_state[:, 1:-1, :]

In [292]:
n_documents = 10

print(F'Query: {query}\n')

for ii in range(0, n_documents-1):
#     print(F'{ii+1} {hits[ii].docid} {hits[ii].score} {searcher.doc(hits[ii].ldocid).object.getField("title").stringValue()}')
    print(F'{ii+1} {hits[ii].docid} '
          F'Title: {searcher.doc(hits[ii].ldocid).object.getField("title").stringValue()} \n'
          F'Authors: {searcher.doc(hits[ii].ldocid).object.getField("author_string").stringValue()} \n'
          F'Journal: {searcher.doc(hits[ii].ldocid).object.getField("journal").stringValue()} '
          F'({searcher.doc(hits[ii].ldocid).object.getField("publish_time").stringValue()})')
    
    doc = searcher.doc(hits[ii].ldocid).object
    content = hits[ii].content # raw content

    abstract_ids = torch.tensor([tokenizer.encode(doc.getField('contents').stringValue(), add_special_tokens=True)])
    abstract_ids = abstract_ids[0,:512].unsqueeze(0)
    with torch.no_grad():
        abstract_state = model(abstract_ids)[0]
        abstract_state = abstract_state[:, 1:-1, :]
    
    relevant_snippets(query_ids, query_state, abstract_ids, abstract_state, window=7)
    print()

Query: degradation integration rna virus

1 28164.0032 Title: Towards standardization of RNA quality assessment using user-independent classifiers of microcapillary electrophoresis traces 
Authors: Imbeaud, Sandrine; Graudens, Esther; Boulanger, Virginie; Barlet, Xavier; Zaborski, Patrick; Eveno, Eric; Mueller, Odilo; Schroeder, Andreas; Auffray, Charles 
Journal: Nucleic Acids Res (2005 Mar 30)

	degradation: ... in a mathematical calculation together with the $$ribosomal$$ peak heights . it allowed character ...
	degradation: ... classes of samples , namely good ( $$hc$$  ...
	degradation: ... relevant classes of samples , namely good $$($$ hc ...
	degradation: ... - validation of the user - independent $$qualification$$ systems tested . both resulted in ...
	degradation: ... based on the identification of additional ' $$degradation$$ peak signals ' and their integration ...

	degradation: ... in a mathematical calculation together with the $$ribosomal$$ peak heights . it allowed cha