In [1]:
import glob
import html
import itertools as it
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

### Data manipulation

In [2]:
class Document:
    """A single paper with an identifier, an abstract and a body text.

    Attributes:
        paper_id: identifier read from the ``paper_id`` field of the JSON file.
        abstract: abstract paragraphs joined with newlines ('' when absent).
        body_text: body paragraphs joined with newlines ('' when absent).
    """

    def __init__(self, paper_id, abstract, body_text):
        self.paper_id = paper_id
        self.abstract = abstract
        self.body_text = body_text

    @classmethod
    def from_json(cls, path):
        """Build a document from a JSON file.

        The file must contain a ``paper_id`` entry. ``abstract`` and
        ``body_text`` are lists of records with a ``text`` entry; a missing or
        null list is treated as empty instead of raising.

        Args:
            path: path of the JSON file to read.

        Returns:
            A new instance of ``cls``.

        Raises:
            KeyError: if ``paper_id`` is missing.
        """
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(path, 'r', encoding='utf-8') as fd:
            data = json.load(fd)

        paper_id = data['paper_id']
        abstract = cls._join_records(data.get('abstract'))
        body_text = cls._join_records(data.get('body_text'))
        return cls(paper_id, abstract, body_text)

    @staticmethod
    def _join_records(records):
        """Join the ``text`` of every record with newlines (None -> '')."""
        return '\n'.join(record['text'] for record in (records or []))

    @staticmethod
    def _section_html(title, text):
        """Render a titled section with one escaped <p> per line of ``text``."""
        paragraphs = ''.join('<p>' + html.escape(line) + '</p>'
                             for line in text.split('\n'))
        return '<h3>' + title + '</h3>' + paragraphs

    def __repr__(self):
        return f'{self.paper_id}: {self.abstract[:200]} ... {self.body_text[:200]} ...'

    def _repr_html_(self):
        """Rich representation used by Jupyter; all document text is HTML-escaped."""
        paper_html = f'<b>Paper ID:</b> {html.escape(str(self.paper_id))}'
        abstract_html = self._section_html('Abstract', self.abstract)
        body_text_html = self._section_html('Body text', self.body_text)
        return paper_html + abstract_html + body_text_html

In [3]:
class CollectionLoader:
    """Collect JSON document paths from several directories and load them lazily.

    Args:
        dirs: sequence of directory paths; each one is searched recursively
            for ``*.json`` files.
        spec: optional ':'-delimited string with one entry per directory.
            A numeric entry limits how many files are taken from that
            directory (``0`` means none); any non-numeric entry, or an empty
            ``spec``, means "take every file".

    Attributes:
        docfiles: list of selected file paths, in directory order and sorted
            by path inside each directory.

    Raises:
        ValueError: if ``spec`` is given and its number of entries differs
            from ``len(dirs)``.
    """

    def __init__(self, dirs, spec=''):
        limits = self._parse_spec(spec, dirs)

        docfiles = []
        for dirname, limit in zip(dirs, limits):
            # glob returns files in arbitrary order; sort so that the
            # "first N files" selection is reproducible across runs/machines.
            dirfiles = sorted(glob.glob(f'{dirname}/**/*.json', recursive=True))
            # A limit of None slices the whole list; 0 selects nothing.
            docfiles.extend(dirfiles[:limit])

        self.docfiles = docfiles

    @staticmethod
    def _parse_spec(spec, dirs):
        """Translate the spec string into one limit (int or None) per directory."""
        if not spec:
            return [None] * len(dirs)

        entries = [entry.strip() for entry in spec.split(':')]
        spec_to_int = [int(entry) if entry.isdigit() else None
                       for entry in entries]

        if len(dirs) != len(spec_to_int):
            raise ValueError('length of dirs does not match length of spec')

        return spec_to_int

    def __iter__(self):
        """Yield a ``Document`` for every selected file, loading it on demand."""
        for fname in self.docfiles:
            yield Document.from_json(fname)

#### Basic usage 



In [4]:
# specify the list of directories to read from
# note: every directory is always searched recursively for *.json files
dirs = ('./dataset/noncomm_use100', 
        './dataset/comm_use100',
        './dataset/biorxiv_medrxiv100')

# pass the directories and a spec string
# each ':'-delimited entry of the spec string is the number of json files 
# that will be read from the corresponding directory
collection_loader = CollectionLoader(dirs, spec='2:1:3')
collection = list(collection_loader)

# sanity check: spec '2:1:3' selects at most 2 + 1 + 3 = 6 documents
print(collection[3])
print('number of documents:', len(collection))

00d16927588fb04d4be0e6b269fc02f0d3c2aa7b: Infectious bronchitis (IB) causes significant economic losses in the global poultry industry. Control of infectious bronchitis is hindered by the genetic diversity of the causative agent, infectious b ... Infectious bronchitis (IB), which is caused by infectious bronchitis virus (IBV), is one of the most important diseases of poultry, causing severe economic losses worldwide. 8 Clinical signs of diseas ...
number of documents: 6


In [5]:
# rich output (only available in Jupyter)
# collection[3]

### Preprocessing pipeline

In [5]:
import re

import spacy

In [6]:
class Pipeline:
    """Text-processing pipeline wrapped around a spaCy model.

    Processing happens in three stages: plain-text functions applied before
    tokenization, the model's tokenizer, and a list of components applied to
    the tokenized document.

    Args:
        model: name passed to ``spacy.load``.
        before_tokenizer: optional list of ``str -> str`` callables applied,
            in order, to every raw text.
        after_tokenizer: optional list of components added after the
            tokenizer. A string names a spaCy component; a ``(name, obj)``
            tuple registers a user-defined component under ``name``.
    """

    def __init__(self, model, before_tokenizer=None, after_tokenizer=None):
        self.model = model
        self.before_tokenizer = before_tokenizer if before_tokenizer else []
        self.after_tokenizer = after_tokenizer if after_tokenizer else []

        self._build()

    def _build(self):
        """Assemble ``self.nlp`` from the bare tokenizer plus requested components."""
        pipeline = self._create_tokenizer()

        for entry in self.after_tokenizer:
            if not isinstance(entry, str):
                # user-defined component given as (name, callable)
                label, component = entry
                pipeline.add_pipe(component, name=label)
                continue

            # spaCy component: reuse the pretrained one when the model had it,
            # otherwise ask spaCy to create a fresh instance
            if entry in self._pretrained:
                component = self._pretrained[entry]
            else:
                component = pipeline.create_pipe(entry)
            pipeline.add_pipe(component)

        # the cache of pretrained components is not needed anymore
        del self._pretrained

        self.nlp = pipeline

    def _create_tokenizer(self):
        """Load the model and strip all of its components.

        Removed components that are requested in ``after_tokenizer`` are kept
        in ``self._pretrained`` so ``_build`` can add them back.
        """
        pipeline = spacy.load(self.model)

        # cache for the pretrained components we may have to re-attach later
        cache = self._pretrained = {}

        for pipe_name in pipeline.pipe_names:
            removed_name, removed_obj = pipeline.remove_pipe(pipe_name)
            if removed_name in self.after_tokenizer:
                cache[removed_name] = removed_obj

        return pipeline

    def _apply_before_tokenizer(self, text):
        """Run ``text`` through every ``before_tokenizer`` function in order."""
        result = text
        for transform in self.before_tokenizer:
            result = transform(result)
        return result

    def __call__(self, texts, n_process=1):
        """Lazily yield one processed document per input text."""
        cleaned = map(self._apply_before_tokenizer, texts)

        # nlp.pipe is itself a generator, so simply delegate to it
        yield from self.nlp.pipe(cleaned, n_process=n_process)

Helper functions for text normalization

In [8]:
# Matches any single character that is neither an ASCII letter nor an
# apostrophe. Despite the name, digits are matched (and thus removed) as well.
NON_ALPHANUM_REG = re.compile(r"[^A-Za-z']")

def lowercase(text):
    """Return ``text`` converted to lower case."""
    lowered = text.lower()
    return lowered

def single_space(text):
    """Collapse every run of whitespace in ``text`` into one space character.

    Leading and trailing whitespace is collapsed too, not stripped.
    """
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)

def remove_non_alpha(text):
    """Replace each character matched by ``NON_ALPHANUM_REG`` with a space.

    Only ASCII letters and apostrophes survive; the length of the text is
    unchanged because every removed character becomes exactly one space.
    """
    return NON_ALPHANUM_REG.sub(' ', text)

The next step is to build a class that gives convenient access to the tokens of a document

In [9]:
class DocViewer:
    """Column-style access to the tokens of a processed document.

    ``viewer[key]`` returns a list with the value of attribute ``key`` for
    every token of the wrapped document.

    Args:
        doc: iterable of tokens; each token must expose a ``_`` attribute
            holding its extension attributes.
    """

    def __init__(self, doc):
        self.doc = doc

    def __getitem__(self, key):
        """Return ``key`` for every token.

        The key is first looked up as a regular token attribute; if that
        fails it is looked up on the token's extension holder ``token._``.

        Raises:
            AttributeError: if ``key`` is neither a regular nor an extension
                attribute.
        """
        try:
            return [getattr(token, key) for token in self.doc]
        except AttributeError:
            # Not a built-in attribute: fall back to the extension attributes.
            # Only AttributeError is caught so that unrelated failures (and
            # KeyboardInterrupt) are not silently swallowed.
            pass

        return [getattr(token._, key) for token in self.doc]

There is no stemmer in spacy, so let's provide one

In [10]:
from nltk.stem.snowball import SnowballStemmer
from spacy.tokens import Token

class Stemmer:
    """Pipeline component that stores a Snowball stem on every token.

    Creating an instance registers the ``stem`` token extension (with
    ``force=True``); the stem is then readable as ``token._.stem``.

    Args:
        language: language name handed to ``SnowballStemmer``.
    """

    def __init__(self, language='english'):
        self._stemmer = SnowballStemmer(language)
        Token.set_extension('stem', default=None, force=True)

    def __call__(self, doc):
        """Set the ``stem`` extension of each token in ``doc`` and return ``doc``."""
        stem_of = self._stemmer.stem
        for tok in doc:
            stemmed = stem_of(tok.text)
            tok._.set('stem', stemmed)
        return doc

#### Usage

In [11]:
# Text is lowercased, stripped of non-letter characters and whitespace-normalized
# before tokenization; afterwards the custom stemmer and spaCy's tagger are applied.
pipeline = Pipeline(model='en_core_web_sm', before_tokenizer= [lowercase, remove_non_alpha, single_space], 
                                            after_tokenizer= [('stemmer', Stemmer()), 'tagger'])

# Process the abstracts of the small demo collection and inspect the fourth one.
processed = list(pipeline([doc.abstract for doc in collection]))
viewer = DocViewer(processed[3])

In [12]:
# print textual representation
# print(viewer['text'])

In [13]:
# print stems
# print(viewer['stem'])

In [14]:
# print lemmas (in spacy, lemmatization is performed by default)
# print(viewer['lemma_'])

In [15]:
# check if a word is a stopword
# print(viewer['is_stop'])

In [16]:
# print POS tags
# print(viewer['tag_'])

### Corpus

In [13]:
def extract_doc_field(documents, field='abstract'):
    """Return the attribute named ``field`` of every document, in order."""
    values = []
    for document in documents:
        values.append(getattr(document, field))
    return values

def remove_empty_doc(corpus):
    """Return a list with the falsy entries (e.g. empty strings) of ``corpus`` removed."""
    return list(filter(None, corpus))

In [14]:
# Keep only documents that have an abstract, so that DOCUMENTS, RAW_CORPUS and
# everything derived from RAW_CORPUS share the same indices. Ranking results are
# corpus indices and are later mapped back to papers with DOCUMENTS[doc_id];
# dropping empty abstracts from the corpus alone would shift those indices.
DOCUMENTS = [doc for doc in CollectionLoader(dirs, spec='40:40:40') if doc.abstract]
RAW_CORPUS = extract_doc_field(DOCUMENTS)
assert len(RAW_CORPUS) == len(DOCUMENTS)
PIPELINE = Pipeline(model='en_core_web_sm', before_tokenizer=[remove_non_alpha, single_space])

In [15]:
def process(corpus, pipeline, out_field='text'):
    """Run ``corpus`` through ``pipeline`` and extract one token field per document.

    Args:
        corpus: iterable of raw texts.
        pipeline: callable that yields one processed document per text.
        out_field: token attribute (regular or extension) to extract.

    Returns:
        A list with one entry per document, each being the list of
        ``out_field`` values of that document's tokens.
    """
    token_fields = []
    # The pipeline output is materialized first, exactly one pass over the corpus.
    for processed_doc in list(pipeline(corpus)):
        token_fields.append(DocViewer(processed_doc)[out_field])
    return token_fields

In [16]:
# One list of token texts per abstract (process() extracts the 'text' field by default).
PROCESSED_CORPUS = process(RAW_CORPUS, PIPELINE)

### Metrics

In [106]:
def precision_at_k(retrieved_doc_ids, relevant_doc_ids, k):
    """Fraction of the top-``k`` retrieved documents that are relevant.

    Args:
        retrieved_doc_ids: ranked sequence of retrieved document ids.
        relevant_doc_ids: container of relevant document ids.
        k: cut-off rank. The count of hits is always divided by ``k``, even
            when fewer than ``k`` documents were retrieved.

    Returns:
        A 0-d ``numpy`` array holding the precision; ``0.0`` when ``k`` is 0.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError('k must be non-negative')
    if k == 0:
        # Nothing is inspected, so avoid dividing by zero.
        return np.array(0.0)

    rel_cnt = sum(1 for doc_id in retrieved_doc_ids[:k]
                  if doc_id in relevant_doc_ids)
    return np.array(rel_cnt / k)

def r_precision(retrieved_doc_ids, relevant_doc_ids):
    """Precision at rank R, where R is the number of relevant documents."""
    cutoff = len(relevant_doc_ids)
    return precision_at_k(retrieved_doc_ids, relevant_doc_ids, cutoff)

def average_precision(retrieved_doc_ids, relevant_doc_ids):
    """Mean of the precision values taken at each rank where a relevant document appears.

    The average is taken over the relevant documents that were actually
    retrieved, not over all relevant documents.

    Args:
        retrieved_doc_ids: ranked iterable of retrieved document ids.
        relevant_doc_ids: container of relevant document ids.

    Returns:
        The average precision as a ``numpy`` float; ``0.0`` when no relevant
        document was retrieved (instead of NaN from an empty mean).
    """
    precisions = []
    rel_cnt = 0
    for rank, doc_id in enumerate(retrieved_doc_ids, start=1):
        if doc_id in relevant_doc_ids:
            rel_cnt += 1
            precisions.append(rel_cnt / rank)

    if not precisions:
        return np.float64(0.0)
    return np.array(precisions).mean()

def mean_average_precision(retrieved_doc_ids, relevant_doc_ids):
    """Mean of ``average_precision`` over several queries.

    Args:
        retrieved_doc_ids: one ranked sequence of retrieved ids per query.
        relevant_doc_ids: one container of relevant ids per query, in the
            same order as ``retrieved_doc_ids``.

    Returns:
        The mean average precision as a ``numpy`` float; ``0.0`` when there
        are no queries.

    Raises:
        AssertionError: if the two arguments have different lengths.
    """
    assert len(retrieved_doc_ids) == len(relevant_doc_ids), \
        'one list of relevant ids is required per query'

    average_precisions = [average_precision(ret_ids, rel_ids)
                          for ret_ids, rel_ids
                          in zip(retrieved_doc_ids, relevant_doc_ids)]
    if not average_precisions:
        return np.float64(0.0)
    return np.array(average_precisions).mean()

### Ranking algorithms 

#### BM25

In [17]:
import rank_bm25

In [18]:
def query(text, pipeline=PIPELINE, out_field='text'):
    """Process a single query string and return its token field values.

    Args:
        text: raw query text.
        pipeline: pipeline used to process the query (defaults to ``PIPELINE``).
        out_field: token attribute to extract.

    Returns:
        The list of ``out_field`` values of the query's tokens.
    """
    processed_docs = list(pipeline([text]))
    query_doc = processed_docs[0]
    return DocViewer(query_doc)[out_field]

In [19]:
# Example: tokenize a free-text query with the default PIPELINE (case is preserved).
query('Role of the environment in transmission')

['Role', 'of', 'the', 'environment', 'in', 'transmission']

In [20]:
class BM25:
    """Thin wrapper that selects one of the ``rank_bm25`` ranking algorithms by name.

    Args:
        alg_type: one of the keys of ``NAME2ALG``.
        documents: tokenized corpus handed to the selected algorithm.
        params: optional dict of keyword arguments for the algorithm's
            constructor.

    Raises:
        ValueError: if ``alg_type`` is not a key of ``NAME2ALG``.
    """

    # Supported algorithm names mapped to their rank_bm25 classes.
    NAME2ALG = {
        'BM25Okapi': rank_bm25.BM25Okapi,
        'BM25L': rank_bm25.BM25L,
        'BM25Plus': rank_bm25.BM25Plus,
    }

    def __init__(self, alg_type, documents, params=None):
        if alg_type not in BM25.NAME2ALG:
            raise ValueError(f'{alg_type} is not supported')

        algorithm_cls = BM25.NAME2ALG[alg_type]
        keyword_args = params if params else {}
        self.alg = algorithm_cls(documents, **keyword_args)
        self.documents = documents

    def get_scores(self, query):
        """Return the wrapped algorithm's score of ``query`` for every document."""
        scores = self.alg.get_scores(query)
        return scores

    def top_n(self, query, n=5):
        """Return ``(document index, score)`` pairs of the ``n`` best-scoring documents.

        Pairs are ordered from the highest to the lowest score.
        """
        scores = self.get_scores(query)
        descending = np.argsort(scores)[::-1]

        ranking = []
        for index in descending[:n]:
            ranking.append((index, scores[index]))
        return ranking

#### Examples

In [21]:
# Tokenized query reused by the BM25 examples below.
q = query('covid transmission')

In [22]:
# Score every document of PROCESSED_CORPUS against q with the BM25Okapi variant.
bm25okapi = BM25('BM25Okapi', PROCESSED_CORPUS)
bm25okapi.get_scores(q)

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       1.51634976, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 3.26065278, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       2.05718893, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 3.8591491 , 0.        ,
       2.25098326, 0.        , 0.        , 0.        , 0.     

In [23]:
# Same query and corpus, scored with the BM25L variant.
bm25l = BM25('BM25L', PROCESSED_CORPUS)
bm25l.get_scores(q)

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        2.33813654,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  7.33776457,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        2.72401203,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        , 12.56779579,  0.  

In [24]:
# Same query and corpus, scored with the BM25Plus variant.
# In the output below, most documents share one identical non-zero score.
bm25plus = BM25('BM25Plus', PROCESSED_CORPUS)
bm25plus.get_scores(q)

array([ 6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        8.48899843,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035, 10.37202935,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        9.07285162,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035,  6.85205035,  6.85205035,
        6.85205035,  6.85205035,  6.85205035, 11.01812527,  6.85

In [105]:
# Best-scoring document for q according to BM25Okapi.
# NOTE(review): doc_id is an index into PROCESSED_CORPUS; DOCUMENTS[doc_id] is the
# same paper only if DOCUMENTS and the corpus contain the same documents in the
# same order (i.e. no document was dropped while building the corpus) - confirm.
doc_id, score = bm25okapi.top_n(q, 1)[0]
print('score:', score)
print(DOCUMENTS[doc_id])

score: 4.230003768125611
0a27cb2cd52229472fcfc3e49d3a3cb7179867e4: Strict interventions were successful to control the novel coronavirus (COVID-19) outbreak in China. As transmission intensifies in other countries, the interplay between age, contact patterns, social  ... also estimate age differences in susceptibility to infection and clinical disease based on contact tracing information gathered by the Hunan Provincial Center for Disease Control and Prevention (CDC), ...
