<a href="https://colab.research.google.com/github/masonnlp/bioasq_qa_system/blob/master/BioASQ_QA_System.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive under /content/gdrive so the model, input CSV and XML
# files referenced by the later cells are reachable; force_remount=True asks
# Colab to mount again even when the drive is already mounted.
drive.mount('/content/gdrive', force_remount=True)

In [None]:
import json
import pandas as pd
import numpy as np
!pip install transformers
import torch
import torch.nn.functional as F
device=torch.device('cuda')
from transformers import BertTokenizer,BertForSequenceClassification,AdamW,BertConfig,get_linear_schedule_with_warmup
from lxml import etree as ET
!pip3 install scispacy
!pip3 install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_lg-0.2.4.tar.gz
import spacy
import scispacy
import en_core_sci_lg
nlp = en_core_sci_lg.load()
from bs4 import BeautifulSoup

Read input file (.csv) and predict type for each question

In [None]:
def preprocess(df):
  """Tokenize every question in ``df`` with the module-level ``tokenizer``.

  Each question is encoded exactly once (special tokens included) and both
  the token ids and the attention mask are taken from that single encoding.
  Nothing is written back onto ``df``.

  Parameters
  ----------
  df : pandas.DataFrame
      Frame with a ``'Question'`` column holding the question strings.

  Returns
  -------
  tuple of (list, list)
      ``encoded_tokens`` -- one list of input ids per question, and
      ``attention_mask`` -- the matching attention mask per question.
  """
  encoded_tokens = []
  attention_mask = []
  for text in df['Question']:
    encoding = tokenizer.encode_plus(text, add_special_tokens=True)
    encoded_tokens.append(encoding['input_ids'])
    attention_mask.append(encoding['attention_mask'])
  return encoded_tokens, attention_mask

# Pad each batch to its longest sequence, convert it to Torch tensors and move it to `device`
def feed_generator(encoded_tokens, attention_mask, batch_size=16, device='cuda'):
    """Yield padded ``(token_tensor, attention_mask)`` batches in input order.

    Every input sequence is emitted exactly once: the final batch is simply
    shorter when the number of sequences is not a multiple of ``batch_size``
    (no sequences are repeated to fill it up), so the number of rows yielded
    always equals ``len(encoded_tokens)``.

    Parameters
    ----------
    encoded_tokens : list of list of int
        Token ids, one list per sequence.
    attention_mask : list of list of int
        Attention masks aligned with ``encoded_tokens``.
    batch_size : int, optional
        Maximum number of sequences per batch (default 16).
    device : str or torch.device, optional
        Device the tensors are moved to (default ``'cuda'``).

    Yields
    ------
    tuple of (torch.Tensor, torch.Tensor)
        Token ids and attention mask of one batch, both right-padded with 0
        to the length of the longest sequence in that batch.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    for start in range(0, len(encoded_tokens), batch_size):
        batch_tokens = encoded_tokens[start:start + batch_size]
        batch_masks = attention_mask[start:start + batch_size]

        # Pad only up to the longest sequence of this batch, not of the whole set
        maxlen_sent = max(len(tokens) for tokens in batch_tokens)
        token_tensor = torch.tensor(
            [tokens + [0] * (maxlen_sent - len(tokens)) for tokens in batch_tokens])
        mask_tensor = torch.tensor(
            [mask + [0] * (maxlen_sent - len(mask)) for mask in batch_masks])

        yield token_tensor.to(device), mask_tensor.to(device)

def predict(model, data, device='cuda'):
    """Run ``model`` in evaluation mode over ``data`` and return class indices.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier called as ``model(tokens, token_type_ids=None,
        attention_mask=mask)``; the first element of its output is used as
        the logits.
    data : iterable of (torch.Tensor, torch.Tensor)
        Batches of ``(token_tensor, attention_mask)``.
    device : str or torch.device, optional
        Device the model and the batches are moved to (default ``'cuda'``).

    Returns
    -------
    list of int
        The arg-max class index for every row of every batch, in order.
    """
    model.eval()
    model.to(device)
    preds = []
    # Inference only: no gradients are needed for any batch
    with torch.no_grad():
        for token_tensor, attention_mask in data:
            # No-op when the batch already lives on `device`
            token_tensor = token_tensor.to(device)
            attention_mask = attention_mask.to(device)
            logits = model(token_tensor, token_type_ids=None, attention_mask=attention_mask)[0]
            preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return preds

# Load the questions to classify; the CSV must provide a 'Question' column
# (the 'ID' column is only needed later, when the XML file is written).
test_data_path = '/content/gdrive/My Drive/Colab Notebooks/BioASQ/input.csv'
testing_df = pd.read_csv(test_data_path,sep=',',header=0)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Question-type classifier loaded from the model directory on Google Drive
model = BertForSequenceClassification.from_pretrained('/content/gdrive/My Drive/Colab Notebooks/BioASQ/Model/', cache_dir=None)

encoded_tokens_Test,attention_mask_Test = preprocess(testing_df)
data_test = feed_generator(encoded_tokens_Test, attention_mask_Test)
preds_test = predict(model,data_test)

# Class index produced by the classifier -> BioASQ question type
indices_to_label = {0: 'factoid', 1: 'list', 2: 'summary', 3: 'yesno'}

# Keep one prediction per question (the slice guards against a batch
# generator that emits more rows than there are questions). An index that is
# not in the table raises KeyError here instead of being dropped silently.
predict_label = [indices_to_label[i] for i in preds_test[:len(testing_df)]]

testing_df['type'] = predict_label


Create the output XML file, which passes each question's type to the Answer Processing system and a query to the Information Retrieval system

In [None]:
def xml_tree(df, output_path='gdrive/My Drive/Colab Notebooks/BioASQ/qp_demo.xml'):
    """Write the question-processing XML file for the questions in ``df``.

    For every row one ``<Q id=...>`` element is created under the ``<Input>``
    root. Its text is the question and it contains:

    * ``<QP>`` with a ``<Type>`` element (the question type), one
      ``<Entities>`` element per entity found by the module-level ``nlp``
      pipeline, and a ``<Query>`` element holding all entities joined by
      single spaces (empty when no entity was found);
    * an empty ``<IR>`` element that the retrieval step fills in later.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with the columns ``'ID'``, ``'Question'`` and ``'type'``.
    output_path : str, optional
        Where the XML file is written. Defaults to the path used by the rest
        of the notebook.

    Returns
    -------
    None
    """
    root = ET.Element("Input")
    # Iterate the columns positionally so a non-unique index cannot turn a
    # single cell lookup into a Series
    for question_id, question, qtype in zip(df['ID'], df['Question'], df['type']):
        q = ET.SubElement(root, "Q")
        q.set('id', str(question_id))
        q.text = question

        qp = ET.SubElement(q, "QP")
        qp_type = ET.SubElement(qp, 'Type')
        qp_type.text = qtype

        ent_list = [str(ent) for ent in nlp(question).ents]
        for ent_text in ent_list:
            qp_en = ET.SubElement(qp, 'Entities')
            qp_en.text = ent_text

        qp_query = ET.SubElement(qp, 'Query')
        qp_query.text = ' '.join(ent_list)

        # Create IR tag
        ET.SubElement(q, "IR")

    tree = ET.ElementTree(root)
    tree.write(output_path, pretty_print=True)
    

# Write the question-processing XML for the classified questions
xml_tree(testing_df)
    

**Start IR Module**

PubmedArticle

In [None]:
"""
This module implements the class PubmedArticle, a plain container
 for the fields of a single PubMed article
"""
from typing import List

class PubmedArticle:
    """
    Plain container for the fields of one PubMed article: pmid, title,
    journal, year, abstract_text and mesh_major.
    """

    @classmethod
    def fromDict(cls, data: dict):
        """
        Alternate constructor building an article from a dictionary.

        Declared as a classmethod so it works both as
        ``PubmedArticle.fromDict(data)`` and on an existing instance.

        Parameters
        ----------
        data: dict
            Must contain the keys "pmid", "title", "journal", "meshMajor",
            "year" and "abstractText"

        Returns
        -------
        PubmedArticle
            A new article populated from ``data``

        Raises
        ------
        KeyError
            If one of the required keys is missing
        """
        return cls(pmid=data["pmid"],
                   title=data["title"],
                   journal=data["journal"],
                   year=data["year"],
                   abstract_text=data["abstractText"],
                   mesh_major=data["meshMajor"])

    def __init__(self, pmid: str, title: str, journal: str,
                 year: str, abstract_text: str, mesh_major: List[str]):
        """
        Stores every argument unchanged on the attribute of the same name
        """
        self.pmid = pmid
        self.title = title
        self.journal = journal
        self.year = year
        self.abstract_text = abstract_text
        self.mesh_major = mesh_major

PubmedReader

In [None]:
"""
This module implements reading PubMed XML fragments
"""
import os
import gzip
import xml.etree.ElementTree as ET
from typing import List
# No need to import PubmedArticle since it's in the same notebook


class PubmedReader:
    """
    This class is responsible for reading the Pubmed dataset: it locates the
    gzipped XML shards in a directory and turns every <PubmedArticle> record
    into a PubmedArticle object
    """

    def __init__(self):
        """
        default constructor doesn't do anything
        """
        pass

    def get_xml_frags(self, dir: str) -> List[str]:
        """
        given a directory where all the xml fragments reside
        will return the list of all the xml fragments, i.e. the file names
        that start with "pubmed" and end with ".xml.gz".
        The names are sorted so that runs are reproducible (os.listdir
        gives no ordering guarantee)
        """
        return sorted(
            name for name in os.listdir(dir)
            if name.startswith("pubmed") and name.endswith(".xml.gz"))

    def process_xml_frags(
            self, dir: str,
            max_article_count: int):
        """
        Generator over the articles of all xml fragments found in dir.
        Fragments are read one after the other until max_article_count
        articles have been produced or no fragment is left. A fragment
        that contains no article is skipped; it does not end the run
        """
        remaining_count = max_article_count
        for frag in self.get_xml_frags(dir):
            if remaining_count <= 0:
                break
            articles = self.process_xml_frag(
                os.path.join(dir, frag), remaining_count)
            remaining_count -= len(articles)
            for article in articles:
                yield article

    def process_xml_frag(
            self, fname: str, max_article_count:
            int):
        """
        This method reads a complete gzipped xml file
        and extracts each PubmedArticle, and returns a list
        of PubmedArticle objects that contain all the relevant
        fields. At most max_article_count articles are returned and
        reading stops as soon as that many have been collected
        """
        articles = []
        # Lines of the article currently being read; joined once per article
        # instead of growing a string line by line
        article_lines = []
        record = False
        with gzip.open(fname, 'rt', encoding="utf-8") as f:
            for line in f:
                if len(articles) >= max_article_count:
                    print("reached max article count ending read")
                    break
                if '<PubmedArticle>' in line:
                    record = True
                if not record:
                    # Outside of an article (file header, other record types)
                    continue
                article_lines.append(line)
                if '</PubmedArticle>' in line:
                    record = False
                    articles.append(
                        self.process_pubmed_article_xml(
                            "".join(article_lines)))
                    article_lines = []
        print("fname", fname, "articles", len(articles))
        return articles

    @staticmethod
    def _node_text(node):
        """
        Full text of an element including the text of nested inline markup
        (e.g. <i>, <sub>); None when the element does not exist
        """
        if node is None:
            return None
        return "".join(node.itertext())

    def process_pubmed_article_xml(self, txt: str) -> "PubmedArticle":
        """
        this method takes an XML fragment of a single Pubmed article
        entry and parses it for data
        It returns a populated PubmedArticle object

        Notes:
        1. Title and abstract keep the text of nested inline markup
        2. All <AbstractText> sections of a structured <Abstract> are
           joined with single spaces
        3. The year is always a string; "0" is used when the article
           has no PubDate/Year
        """
        root = ET.fromstring(txt)
        pmid = root.findtext('.//PMID')
        title = self._node_text(root.find('.//ArticleTitle'))
        sections = root.findall('.//Abstract/AbstractText')
        if sections:
            abstract_text = " ".join(
                self._node_text(section) for section in sections)
        else:
            # No <Abstract> parent: fall back to the first AbstractText found
            abstract_text = self._node_text(root.find('.//AbstractText'))
        journal = root.findtext('.//Title')
        year = root.findtext('.//PubDate/Year') or "0"
        mesh_major = [
            descriptor.text
            for descriptor in root.findall(".//DescriptorName")]
        return PubmedArticle(
            pmid, title, journal, year, abstract_text, mesh_major)


PubmedIndexer

Install Whoosh

In [None]:
!pip install whoosh

In [None]:
"""
This module indexes the Pubmed dataset using Whoosh
"""
import os
import os.path
import shutil
from whoosh import index
from whoosh.fields import Schema, TEXT, IDLIST, ID, NUMERIC
from whoosh.analysis import StemmingAnalyzer
from whoosh.qparser import QueryParser
from datetime import datetime
from typing import List


class PubmedIndexer:
    """
    PubmedIndexer is the main class that clients are expected to use.
    The primary functions it performs are:
    1. Indexing the pubmed articles into a Whoosh index
    2. Allowing the free text searching of the pubmed articles

    NOTES:
    1. The pubmed data is provided here:
      ftp://ftp.ncbi.nlm.nih.gov/pubmed/updatefiles/
    2. We do not index all the fields per article -- we index:
      a. The pubmed ID
      b. The Journal name
      c. The Year of publication
      d. The Article title
      e. The Article Abstract
    3. The complete pubmed dataset is just under 7 GB of compressed
      XML shards (as of this writing)
    4. This module allows all this data to be indexed
    5. The index takes about 5 hours to generate on a medium powered laptop
    6. The index directory is roughly 7 GB
    7. The index directory can be tarred(zipped) and shared between users
    8. We will probably rename this module pubmed_ir soon and release it to PyPI

    MISSING & DESIRABLE FUNCTIONALITY
    1. It would be good to have utility function that is able to download
      the pub med data
    2. We should get __init__.py, etc. files done so we can publish to PyPi
    3. We should have a partial indexing feature that indexes only data needed
       for biosqr task b
    4. We might make the index generation system more customizable in terms
       of things such as Analyzers, stop-words, etc.
    5. We may need a customizable result scoring function -- beyond BM25
    6. We may want a more sophisticated querying interface, boolean queries, etc
    7. We need a lot of testing to certify the system
    8. It is not clear if we can add documents to an existing index
    9. It is not clear how we can re-index an existing index
    10. We should swap out prints with a formal logging framework
    11. We should have example modules which demonstrate the use of this system
    12. We really need to modify the directory structure of the project

    BUGS & KNOWN LIMITATIONS
    1. At the moment the free text query only searches the Abstract Text
      it does not search the title
    2. mk_index must be called before index_docs or search, because it
      creates the schema and the index object they use

    """

    def __init__(self):
        """
        default constructor it does nothing at the moment
        """
        pass

    def mk_index(self, indexpath: str = "indexdir",
                 overwrite: bool = False) -> None:
        """
        creates a Whoosh based index for subsequent IR operations

        Parameters
        ----------
        indexpath: str
            The absolute or relative path where you want the index to be stored
               Note: the index path is a directory
               this directory will contain all the Whoosh files
               Missing parent directories are created as well
        overwrite: boolean
            This will overwrite any existing index (directory) if set to True
            The default value is set to False (safe setting)

        Returns
        -------
        None
            it is a void method and returns the None value
        """
        if overwrite and os.path.exists(indexpath):
            shutil.rmtree(indexpath)
        os.makedirs(indexpath, exist_ok=True)
        self.pubmed_article_schema = Schema(
            pmid=ID(stored=True),
            title=TEXT(stored=True),
            journal=TEXT(stored=True),
            mesh_major=IDLIST(stored=True),
            year=NUMERIC(stored=True),
            abstract_text=TEXT(stored=True, analyzer=StemmingAnalyzer()))
        # Reuse an index only when one is really present: a directory that
        # merely exists (e.g. created empty) cannot be opened as an index
        use_existing_index = index.exists_in(
            indexpath, indexname="pubmed_articles")
        print(use_existing_index)
        if not use_existing_index:
            self.pubmed_article_ix = index.create_in(
                indexpath,
                self.pubmed_article_schema,
                indexname="pubmed_articles")
        else:
            self.pubmed_article_ix = index.open_dir(
                indexpath, indexname="pubmed_articles")
        print("index object created")

    def rm_index(self, indexpath: str = "indexdir") -> None:
        """
        This is a utility function to delete an existing index

        Parameters
        ----------
        indexpath: str
            The absolute or relative path of the index location

        Returns
        -------
        None
            This void method returns nothing
        """
        if os.path.exists(indexpath):
            # The index directory contains the Whoosh files, so it has to be
            # removed recursively (os.rmdir only removes empty directories)
            shutil.rmtree(indexpath)

    def index_docs(self, articles,
                   limit: int):
        """
        indexes documents into the Whoosh index

        Parameters
        ----------
        articles: Iterable[PubmedArticle]
            The articles to be added to the index (a list or a generator)
        limit: int
            This is a cutoff, beyond which the indexing process will cease
            The purpose of this parameter is to limit the amount of documents
            to be indexed for testing purposes or quick function execution for
            experimental methods

        Returns
        -------
        None:
           this is a void method and returns nothing

        TODO: add handling LockError
        TODO: add handling test for LockError
        """
        print("adding documents")
        count = 0
        # The writer is used as a context manager: it commits when the block
        # ends normally and cancels (releasing the write lock) on an error
        with self.pubmed_article_ix.writer() as pubmed_article_writer:
            for article in articles:
                if count >= limit:
                    break
                pubmed_article_writer.add_document(
                    pmid=article.pmid,
                    title=article.title,
                    journal=article.journal,
                    mesh_major=article.mesh_major,
                    year=article.year,
                    abstract_text=article.abstract_text)
                # Counted after the add so the total below is the number of
                # documents that were really written
                count += 1
        print("commiting index, added", count, "documents")

    def search(self, query,
               max_results: int = 1):
        """
        This is our simple starter method to query the index

        Parameters
        ----------
        query: str
           This is a plain text query string that Whoosh searches
           the index for matches (only the abstract text is searched)
        max_results: int
           This parameter sets the maximum number of results the
           method will return

        Returns
        -------
        List[PubmedArticle]
           The matching articles rebuilt from the stored fields,
           an empty list when nothing matches
        """
        res = []
        qp = QueryParser("abstract_text", schema=self.pubmed_article_schema)
        q = qp.parse(query)
        with self.pubmed_article_ix.searcher() as s:
            results = s.search(q, limit=max_results)
            # Copy the stored fields while the searcher is still open
            for result in results:
                pa = PubmedArticle(result['pmid'],
                                   result['title'],
                                   result['journal'],
                                   result['year'],
                                   result['abstract_text'],
                                   result['mesh_major'])
                res.append(pa)
        return res


XML Extractor

In [None]:
import lxml.etree as ET

def extract_and_write(filename, results, question_id, query):
    """
    Extract information from IR system and write to XML file.

    The results are appended to the <IR> element of every <Q> whose "id"
    attribute equals question_id. For each result the following is added:
    <QueryUsed>The query string</QueryUsed>
    <Result PMID=1>
        <Journal>Title of journal</Journal>
        <Year>Year published</Year>
        <Title>Title of article</Title>
        <Abstract>Abstract (~couple of sentences/a paragraph)</Abstract>
        <MeSH>tag1</MeSH>
        <MeSH>tag2</MeSH>
    </Result>
    The file is parsed once and written back once, in place.

    :param filename: Name of the XML file used in the QA system
    :param results: Iterable of articles exposing pmid, journal, year,
        title, abstract_text and mesh_major
    :param question_id: Value of the "id" attribute of the target <Q>
        element (a string, as read from the XML file)
    :param query: The query string that produced the results
    """
    origTree = ET.parse(filename)
    root = origTree.getroot()

    # Find the IR element to write to
    for question in root.findall("Q"):
        if question.get("id") != question_id:
            continue
        IR = question.find("IR")
        if IR is None:
            # Tolerate a question that was written without its IR placeholder
            IR = ET.SubElement(question, "IR")
        # Create a subelement for each part of the result (there can be many)
        for pa in results:
            queryUsed = ET.SubElement(IR, "QueryUsed")
            queryUsed.text = query
            result = ET.SubElement(IR, "Result")
            result.set("PMID", pa.pmid)
            journal = ET.SubElement(result, "Journal")
            journal.text = pa.journal
            year = ET.SubElement(result, "Year")
            # The year may come back from the index as a number, but element
            # text has to be a string (or None)
            year.text = None if pa.year is None else str(pa.year)
            title = ET.SubElement(result, "Title")
            title.text = pa.title
            abstract = ET.SubElement(result, "Abstract")
            abstract.text = pa.abstract_text
            for mesh in pa.mesh_major:
                mesh_major = ET.SubElement(result, "MeSH")
                mesh_major.text = mesh

    # Write once, after all matching questions have been updated
    tree = ET.ElementTree(root)
    tree.write(filename, pretty_print=True)

Index Documents

In [None]:
# Create new index
# overwrite=True removes any existing 'indexdir2' directory before the index
# is created, so every run of this cell starts from an empty index.
pubmed_indexer = PubmedIndexer()
pubmed_indexer.mk_index('indexdir2', overwrite=True)
reader = PubmedReader()
# process_xml_frags is a generator: the shards are only read while index_docs
# consumes the articles. Both calls cap the run at 10000 articles.
articles = reader.process_xml_frags('gdrive/My Drive/Colab Notebooks/BioASQ/data2', max_article_count=10000)
pubmed_indexer.index_docs(articles, limit=10000)

Run XML extractor

In [None]:
# Run the retrieval step for every question of the question-processing file
# and append the results to that same file.
file = 'gdrive/My Drive/Colab Notebooks/BioASQ/qp_demo.xml'
origTree = ET.parse(file)
root = origTree.getroot()
for question in root.findall('Q'):
    # Question ID to write IR results to the appropriate question
    qid = question.get("id")

    # If there is no query, use the original question. findtext gives None
    # when <QP>/<Query> is missing and '' when it is empty, so both cases
    # fall back to the question text.
    query = question.findtext("QP/Query") or question.text
    if not query:
        # Nothing to search for
        continue

    results = pubmed_indexer.search(query)

    # Only want to call write method if a result was found for the query
    if results:
      extract_and_write(file, results, qid, query)