<a href="https://colab.research.google.com/github/kyunghyuncho/bio-ret-viz/blob/master/pyserini_scibert_covid19_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pyserini Demo on COVID-19 Dataset (Title + Abstract Index) with HuggingFace Transformers-based Visualization


This notebook shows how to get started searching the [COVID-19 Open Research Dataset](https://pages.semanticscholar.org/coronavirus-research) (release of 2020/03/20) from AI2.
In this notebook, we'll be working with the title + abstract index;
the full text is not indexed yet (that will come soon!).


In [None]:
from IPython.display import display, HTML

First, install the Python dependencies:

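In [2]:
%%capture
!pip install pyserini==0.8.1.0
!pip install transformers

import json
import os

# Pyserini runs Lucene on the JVM, so it needs to find a JDK.
# This path points at a local macOS JDK install; change it to match your machine.
os.environ["JAVA_HOME"] = "/Library/Java/JavaVirtualMachines/jdk-13.0.2.jdk/Contents/Home"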

Let's grab the pre-built index:

In [3]:
%%capture
!wget https://www.dropbox.com/s/uvjwgy4re2myq5s/lucene-index-covid-2020-03-20.tar.gz
!tar xvfz lucene-index-covid-2020-03-20.tar.gz

Sanity check of the index size (expected to be about 1.3G; this run reports 1.2G):

In [4]:
!du -h lucene-index-covid-2020-03-20

1.2G	lucene-index-covid-2020-03-20


Let's load BioBERT (https://arxiv.org/abs/1901.08746) from HuggingFace Transformers:

In [8]:
import torch
import numpy

In [9]:
from tqdm import tqdm

In [10]:
from transformers import AutoModel, AutoTokenizer

# Keep the original casing when tokenizing (do_lower_case=False)
tokenizer = AutoTokenizer.from_pretrained('monologg/biobert_v1.0_pubmed_pmc', do_lower_case=False)
model = AutoModel.from_pretrained('monologg/biobert_v1.0_pubmed_pmc')

You can use `pysearch` to search over an index. Here's the basic usage:

In [11]:
query = 'these differences reside in the molecular structure of spike proteins and some other factors.  Which receptor combination(s) will cause maximum harm'

In [12]:
from pyserini.search import pysearch

searcher = pysearch.SimpleSearcher('lucene-index-covid-2020-03-20/')
hits = searcher.search(query)  # top 10 hits by default

display(HTML('<div style="font-family: Times New Roman; font-size: 20px; padding-bottom:12px"><b>Query</b>: '+query+'</div>'))

# Display the top hits (at most 10; fewer if the search returned fewer)
for i in range(min(10, len(hits))):
  display(HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px">' + 
               F'{i+1} {hits[i].docid} ({hits[i].score:1.2f}) -- ' +
               F'{hits[i].lucene_document.get("authors")} et al. ' +
              #  F'{hits[i].lucene_document.get("journal")}. ' +
              #  F'{hits[i].lucene_document.get("publish_time")}. ' +
               F'{hits[i].lucene_document.get("title")}. ' +
               F'<a href="https://doi.org/{hits[i].lucene_document.get("doi")}">{hits[i].lucene_document.get("doi")}</a>.'
               + '</div>'))

From the hits array, use `.lucene_document` to access the underlying indexed Lucene `Document`, and from there call `.get(field)` to fetch specific fields, such as "title", "authors", or "doi" (as in the cell above).
The complete list of available fields is [here](https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/index/generator/CovidGenerator.java#L46).

For hit #1, we don't have the full text, but we can access available information via `.raw`.

In [13]:
hit1_json = json.loads(hits[0].raw)
# print(json.dumps(hit1_json, indent=4))

For hit #7 (`hits[6]`), we have the full text, which we can also fetch via `.raw`. As an example, let's look at which section the second paragraph belongs to.

In [14]:
hit7_json = json.loads(hits[6].raw)
print(hit7_json['body_text'][1]['section'])
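Not every record carries full text (hit #1 above, for instance, does not), so it can help to check for `body_text` before indexing into it. A minimal sketch, assuming the parsed `.raw` JSON has the layout used in this notebook (a `body_text` list of dicts with `text` and `section` keys); the helper `sections_of` and the toy records are hypothetical:

```python
import json


def sections_of(raw):
    """Group paragraph indices by section name for a raw CORD-19-style JSON string.

    Returns an empty dict when the record has no full text (no "body_text" key).
    Assumes each paragraph is a dict with "text" and "section" keys.
    """
    doc = json.loads(raw)
    groups = {}
    for pid, para in enumerate(doc.get("body_text", [])):
        groups.setdefault(para.get("section", ""), []).append(pid)
    return groups


# Toy records (hypothetical) standing in for `hits[k].raw`
with_text = json.dumps({"body_text": [
    {"text": "Coronaviruses are ...", "section": "Introduction"},
    {"text": "Spike proteins bind ...", "section": "Introduction"},
    {"text": "We observed ...", "section": "Results"},
]})
abstract_only = json.dumps({"abstract": [{"text": "..."}]})

print(sections_of(with_text))      # {'Introduction': [0, 1], 'Results': [2]}
print(sections_of(abstract_only))  # {}
```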

Next, let's use BioBERT to extract contextualized vectors of the query and of each full-text paragraph, so that we can highlight the paragraphs most relevant to the query.

First, extract the contextualized vectors of the query above:

$$q_1, \ldots, q_T = \text{BioBERT}(\text{query})$$

In [163]:
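def extract_scibert(text, tokenizer, model, max_len=512):
    # Despite its name, this works with any BERT-style encoder (here BioBERT).
    # Returns the token ids (with [CLS]/[SEP]), the wordpieces without the
    # special tokens, and one contextualized vector per wordpiece.
    text_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
    cls_id, sep_id = text_ids[0, :1], text_ids[0, -1:]
    body_ids = text_ids[0, 1:-1]
    text_words = tokenizer.convert_ids_to_tokens(body_ids.tolist())

    # The encoder accepts at most max_len positions, so split the body into
    # windows of max_len - 2 tokens and re-wrap each one in [CLS] ... [SEP]
    step = max_len - 2
    states = []
    for start in range(0, max(body_ids.size(0), 1), step):
        chunk_ids = torch.cat([cls_id, body_ids[start:start + step], sep_id])
        with torch.no_grad():
            state = model(chunk_ids.unsqueeze(0))[0]  # (1, chunk length, hidden size)
        states.append(state[:, 1:-1, :])              # drop the [CLS]/[SEP] positions

    state = torch.cat(states, dim=1)
    return text_ids, text_words, state[0]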
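BERT-style encoders such as BioBERT accept at most 512 positions, including `[CLS]` and `[SEP]`, so long paragraphs have to be split into windows of at most 510 wordpieces, each re-wrapped in the special tokens. The sketch below isolates that bookkeeping on plain integer ids so it can be checked without downloading a model; the function name `chunk_with_specials` and the toy ids are assumptions for illustration, not part of the original notebook.

```python
def chunk_with_specials(ids, cls_id, sep_id, max_len=512):
    """Split `ids` (starting with [CLS], ending with [SEP]) into model-sized windows.

    Each window is [CLS] + up to (max_len - 2) body tokens + [SEP], so no window
    exceeds max_len. Dropping the first and last position of every window and
    concatenating recovers exactly one entry per body token.
    """
    body = ids[1:-1]
    step = max_len - 2
    windows = []
    for start in range(0, max(len(body), 1), step):
        windows.append([cls_id] + body[start:start + step] + [sep_id])
    return windows


# Toy example: 7 body tokens and a tiny max_len of 5 -> windows of 3 body tokens
ids = [101, 1, 2, 3, 4, 5, 6, 7, 102]
windows = chunk_with_specials(ids, cls_id=101, sep_id=102, max_len=5)
print(windows)
# [[101, 1, 2, 3, 102], [101, 4, 5, 6, 102], [101, 7, 102]]

# Stripping the special positions from each window realigns with the body tokens
print([tok for w in windows for tok in w[1:-1]] == ids[1:-1])  # True
```

With the real tokenizer, `max_len=512` yields windows of up to 510 wordpieces, which keeps every window within BERT's 512 position embeddings.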

In [164]:
query_ids, query_words, query_state = extract_scibert(query, tokenizer, model)

Second, let's extract contextualized vectors of all the paragraphs of hit #7 (`hits[6]`), one paragraph $k$ at a time:

$$p_1^k, \ldots, p_{T_k}^k = \text{BioBERT}(\text{paragraph}^k)$$

In [165]:
ii = 6

doc_json = json.loads(hits[ii].raw)

paragraph_states = []
for pid in tqdm(range(len(doc_json['body_text']))):
    paragraph_states.append(extract_scibert(doc_json['body_text'][pid]['text'], tokenizer, model))

100%|██████████| 36/36 [00:11<00:00,  3.17it/s]


We then compute the cosine similarity matrix between the query and each paragraph:

$$A^k = [a^k_{ij}] \in \mathbb{R}^{|\text{query}| \times |\text{paragraph}^k|},$$

where

$$a^k_{ij} = \frac{q_i^\top p_j^k}{\| q_i \| \| p_j^k \|}$$
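The definition of $a^k_{ij}$ boils down to L2-normalizing every token vector and taking a single matrix product. A small NumPy sketch, with random vectors standing in for BioBERT states and a hypothetical `cosine_matrix` helper, checks the vectorized form against the element-wise formula:

```python
import numpy as np


def cosine_matrix(q, p):
    """Return A with A[i, j] = cos(q[i], p[j]) for row-stacked vectors q (T x H), p (T_k x H)."""
    q_unit = q / np.linalg.norm(q, axis=1, keepdims=True)
    p_unit = p / np.linalg.norm(p, axis=1, keepdims=True)
    return q_unit @ p_unit.T


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))  # 4 "query tokens", hidden size 8
p = rng.normal(size=(6, 8))  # 6 "paragraph tokens"
A = cosine_matrix(q, p)

# Element-wise definition: q_i^T p_j / (||q_i|| ||p_j||)
i, j = 2, 5
direct = q[i] @ p[j] / (np.linalg.norm(q[i]) * np.linalg.norm(p[j]))
print(A.shape, np.isclose(A[i, j], direct))  # (4, 6) True
```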


In [166]:
def cross_match(state1, state2):
    # L2-normalize every token vector so that dot products become cosine similarities
    state1 = state1 / state1.norm(dim=1, keepdim=True)
    state2 = state2 / state2.norm(dim=1, keepdim=True)

    # (|query| x H) @ (H x |paragraph|) -> |query| x |paragraph| similarity matrix
    sim = state1 @ state2.t()

    return sim

In [167]:
sim_matrices = []

for pid in tqdm(range(len(doc_json['body_text']))):
  sim_matrices.append(cross_match(query_state, paragraph_states[pid][-1]))

100%|██████████| 36/36 [00:00<00:00, 112.01it/s]


Let's first retrieve the most relevant paragraphs, where we define the top-$M$ most relevant paragraphs (the code below uses $M=5$) as

$$\arg\text{top-$M$}_{k=1}^K \max_{i=1,\ldots,|\text{query}|} \max_{j=1,\ldots, |\text{paragraph}^k|} A_{ij}^k$$

that is, a paragraph is considered relevant if at least one of its words closely matches some query word.
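As a sanity check of this definition on hand-written matrices (hypothetical values, not BioBERT output), each paragraph is scored by the largest entry of its $A^k$, and the $M$ highest scores are kept. The helper `top_paragraphs` is an illustrative name; the notebook itself uses $M=5$:

```python
import numpy as np


def top_paragraphs(sim_matrices, m):
    """Indices of the m paragraphs whose best query-word/paragraph-word match is highest.

    Each element of sim_matrices is a (|query| x |paragraph_k|) cosine-similarity array.
    Returns indices ordered from most to least relevant.
    """
    scores = np.array([s.max() for s in sim_matrices])
    return np.argsort(scores)[::-1][:m]


sims = [
    np.array([[0.2, 0.4], [0.1, 0.3]]),            # best match 0.4
    np.array([[0.9, 0.1, 0.0], [0.2, 0.2, 0.5]]),  # best match 0.9
    np.array([[0.6], [0.3]]),                      # best match 0.6
]
print(top_paragraphs(sims, m=2))  # [1 2]
```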

In [168]:
paragraph_relevance = [torch.max(sim).item() for sim in sim_matrices]
rel_index = numpy.argsort(paragraph_relevance)[-5:][::-1]

In [170]:
display(HTML('<div style="font-family: Times New Roman; font-size: 20px; padding-bottom:12px"><b>Query</b>: '+query+'</div>'))

display(HTML('<div style="font-family: Times New Roman; font-size: 20px; padding-bottom:12px"><b>Document</b>: '+
             F'{hits[ii].lucene_document.get("authors")} et al. ' +
             F'{hits[ii].lucene_document.get("title")}. ' + 
             F'<a href="https://doi.org/{hits[ii].lucene_document.get("doi")}">{hits[ii].lucene_document.get("doi")}</a>.' +
             '</div>'))

for ri in numpy.sort(rel_index):
  display(HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px; margin-left: 15px">' + 
               F'<b>{doc_json["body_text"][ri]["section"]}</b> -- ' +
               F'{" ".join(paragraph_states[ri][1]).replace(" ##","")}' + ' </div>'))

To look at the matches in more detail, we highlight relevant phrases within each paragraph. We define the relevant words of paragraph $k$ as

$$\arg\text{top-$M$}_{j=1,\ldots, |\text{paragraph}^k|} \max_{i=1,\ldots,|\text{query}|} A_{ij}^k$$

that is, a paragraph word is considered relevant if it is highly similar to at least one query word. The code below keeps the top two such words per paragraph ($M=2$) and highlights the whole sentence containing each of them.
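The word-level score is a column-wise maximum of $A^k$ over query words; keeping the $M$ best positions and sorting them back into reading order gives the tokens to highlight (the notebook uses $M=2$). A toy sketch with a hypothetical `top_words` helper and made-up similarities:

```python
import numpy as np


def top_words(sim, m):
    """Positions (in reading order) of the m paragraph tokens most similar to any query token.

    `sim` is a (|query| x |paragraph|) cosine-similarity matrix.
    """
    word_scores = sim.max(axis=0)        # best match of each paragraph token
    best = np.argsort(word_scores)[-m:]  # m highest-scoring positions
    return np.sort(best)                 # back to left-to-right order


sim = np.array([
    [0.1, 0.7, 0.2, 0.3, 0.0],
    [0.4, 0.1, 0.2, 0.8, 0.1],
])
print(top_words(sim, m=2))  # [1 3]
```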

In [204]:
def is_sentence_end(ptext, kk):
    # A "." followed by a capitalized token or a citation bracket ("[") ends a sentence
    return (ptext[kk] == "." and kk + 1 < len(ptext)
            and (ptext[kk + 1][0].isupper() or ptext[kk + 1][0] == "["))


def highlight_paragraph(ptext, rel_words, max_win=10):
    # Wrap the sentence containing each relevant token in a blue <font> tag.
    # ptext: list of wordpiece tokens; rel_words: sorted token positions.
    # Note: max_win is currently unused; whole sentences are highlighted.
    para = ""
    prev_idx = 0
    for jj in rel_words:
        if prev_idx > jj:  # already inside a previously highlighted sentence
            continue

        # Scan left for the end of the preceding sentence (default: start of the remaining text)
        sent_start = prev_idx - 1
        for kk in range(jj - 1, prev_idx - 1, -1):
            if is_sentence_end(ptext, kk):
                sent_start = kk
                break

        # Scan right for the end of the current sentence (default: end of the paragraph)
        sent_end = len(ptext)
        for kk in range(jj, len(ptext) - 1):
            if is_sentence_end(ptext, kk):
                sent_end = kk
                break

        para = para + " " + " ".join(ptext[prev_idx:sent_start + 1])
        para = para + " <font color='blue'>" + " ".join(ptext[sent_start + 1:sent_end]) + "</font> "
        prev_idx = sent_end

    if prev_idx < len(ptext):
        para = para + " ".join(ptext[prev_idx:])

    return para
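The function above highlights whole sentences, and its `max_win` argument is never read. If a fixed-size window around each relevant token is preferred instead, a minimal alternative might look like the sketch below. `highlight_window` is a hypothetical helper, not part of the original notebook; overlapping windows are merged so that the `<font>` tags never nest.

```python
def highlight_window(ptext, rel_words, max_win=10):
    """Wrap up to `max_win` tokens on each side of every relevant position in a blue <font> tag.

    ptext: list of (wordpiece) tokens; rel_words: token positions to highlight around.
    Overlapping or touching windows are merged into a single span.
    """
    spans = []
    for jj in sorted(rel_words):
        start, end = max(0, jj - max_win), min(len(ptext), jj + max_win + 1)
        if spans and start <= spans[-1][1]:
            spans[-1] = (spans[-1][0], max(spans[-1][1], end))  # merge with previous window
        else:
            spans.append((start, end))

    pieces, prev = [], 0
    for start, end in spans:
        pieces.extend(ptext[prev:start])
        pieces.append("<font color='blue'>" + " ".join(ptext[start:end]) + "</font>")
        prev = end
    pieces.extend(ptext[prev:])
    return " ".join(pieces)


tokens = "the spike protein binds the ace2 receptor on host cells".split()
print(highlight_window(tokens, [1], max_win=1))
# <font color='blue'>the spike protein</font> binds the ace2 receptor on host cells
```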

In [205]:
display(HTML('<div style="font-family: Times New Roman; font-size: 20px; padding-bottom:12px"><b>Query</b>: '+query+'</div>'))

display(HTML('<div style="font-family: Times New Roman; font-size: 20px; padding-bottom:12px"><b>Document</b>: '+
             F'{hits[ii].lucene_document.get("authors")} et al. ' +
             F'{hits[ii].lucene_document.get("title")}. ' + 
             F'<a href="https://doi.org/{hits[ii].lucene_document.get("doi")}">{hits[ii].lucene_document.get("doi")}</a>.' +
             '</div>'))

for ri in numpy.sort(rel_index):
  sim = sim_matrices[ri].data.numpy()
  # Top-2 paragraph tokens by their best match to any query token, in reading order
  rel_words = numpy.sort(numpy.argsort(sim.max(0))[-2:])

  ptext = paragraph_states[ri][1]

  para = highlight_paragraph(ptext, rel_words)

  display(HTML('<div style="font-family: Times New Roman; font-size: 18px; padding-bottom:10px; margin-left: 15px">' + 
               F'<b>{doc_json["body_text"][ri]["section"]}</b> -- ' +
               F'{para.replace(" ##","")}' + ' </div>'))