In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import sys
import time
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
import textwrap
import re
import attr
import abc
import string
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from IPython.display import HTML
from os import listdir
from os.path import isfile, join


In [2]:
import warnings

# Silence all Python warnings for the remainder of the kernel session.
warnings.filterwarnings('ignore')

# NOTE(review): not read by the visible cells -- get_answers has its own max_articles default.
MAX_ARTICLES = 1000

# Input locations: the article CSV and the QA checkpoint handed to from_pretrained.
base_dir = '/kaggle/input'
data_path = join(base_dir, 'covid19', 'covid19.csv')
model_path = 'clagator/biobert_squad2_cased'

## Query-based document retrieval (TF-IDF)

In [3]:
# Peek at the raw article table; show only the first rows instead of echoing the whole frame.
dff = pd.read_csv(data_path)
dff.head()

In [4]:
class Retrieval(abc.ABC):
    """Base class for retrieval methods.

    Subclasses must fill ``self._doc_vecs`` with one row per document (aligned
    with ``self._docs``) and implement :meth:`retrieve`.
    """

    def __init__(self, docs, keys=None):
        """
        Args:
          docs: a list or pd.Series of strings. The text to retrieve.
          keys: optional list or pd.Series of keys (e.g. ID, title), one per
            document; when given it replaces the index of the stored documents.
        """
        # pd.Series() accepts plain lists (as documented) and keeps the index and
        # name of an incoming Series; copy() isolates us from later caller mutation.
        self._docs = pd.Series(docs).copy()
        if keys is not None:
            self._docs.index = keys
        self._model = None
        self._doc_vecs = None

    def _top_documents(self, q_vec, top_n=10):
        """Return the documents most cosine-similar to a single query vector.

        Args:
          q_vec: one query row in the same vector space as ``self._doc_vecs``.
          top_n: maximum number of documents to return; ``None`` returns all.

        Returns:
          pd.Series of documents in decreasing similarity order, indexed by key.
        """
        # For a single query this is an (n_docs, 1) matrix; ravel keeps it 1-D
        # even when there is only one document (squeeze would yield a 0-d array).
        similarity = np.ravel(cosine_similarity(self._doc_vecs, q_vec))
        rankings = np.argsort(similarity)[::-1]
        # Select by position: a label lookup duplicates rows for repeated keys, and
        # integer slicing of a Series with an integer index is ambiguous in pandas.
        return self._docs.iloc[rankings[:top_n]]

    @abc.abstractmethod
    def retrieve(self, query, top_n=10):
        """Return the ``top_n`` documents best matching ``query``."""
        pass
    
class TFIDFRetrieval(Retrieval):
    """Rank documents by cosine similarity between their TF-IDF vectors and the query's."""

    def __init__(self, docs, keys=None):
        """
        Args:
          docs: a list or pd.Series of strings. The text to retrieve.
          keys: a list or pd.Series. Keys (e.g. ID, title) associated with each document.
        """
        super().__init__(docs, keys)
        vectorizer = TfidfVectorizer()
        self._model = vectorizer
        # Fit vocabulary/IDF weights on the corpus and keep the document-term matrix.
        self._doc_vecs = vectorizer.fit_transform(docs)

    def retrieve(self, query, top_n=10):
        """Vectorize ``query`` (one string) and return the ``top_n`` closest documents."""
        query_matrix = self._model.transform([query])
        return self._top_documents(query_matrix, top_n)

## Question-answering model

In [5]:
class ResearchQA(object):
    """Extractive question answering over a CSV of research articles.

    Candidate articles come either from a keyword filter or from TF-IDF retrieval
    over abstracts; a TF question-answering model then extracts one answer span
    from each candidate's text.
    """

    def __init__(self, data_path, model_path):
        """
        Args:
          data_path: CSV path; the methods below read the columns 'abstract',
            'doi', 'authors', 'journal', 'publish_time' and 'title'.
          model_path: checkpoint name or path for the tokenizer and the QA model
            (model weights are converted from PyTorch, ``from_pt=True``).
        """
        print('Loading data from', data_path)
        self.df = pd.read_csv(data_path)
        print('Initializing model from', model_path)
        self.model = TFAutoModelForQuestionAnswering.from_pretrained(model_path, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.retrievers = {}
        self.build_retrievers()
        self.main_question_dict = dict()  # not read by any method of this class

    def build_retrievers(self):
        """Build a TF-IDF retriever over the non-null abstracts, keyed 'abstract'."""
        df = self.df
        abstracts = df[df.abstract.notna()].abstract
        self.retrievers['abstract'] = TFIDFRetrieval(abstracts)

    def retrieve_candidates(self, section_path, question, top_n):
        """Return the full ``self.df`` rows of the ``top_n`` best-matching documents.

        ``section_path[0]`` selects the retriever (only 'abstract' is built).
        """
        candidates = self.retrievers[section_path[0]].retrieve(question, top_n)
        return self.df.loc[candidates.index]

    def get_answers(self, question, section='abstract', keyword=None, max_articles=1000, batch_size=12):
        """Answer ``question`` from the ``section`` text of candidate articles.

        Args:
          question: natural-language question.
          section: column holding the text; the part before the first '/' picks
            the retriever / keyword-search column.
          keyword: if given, candidates are rows whose section column contains it
            (case-insensitive) instead of TF-IDF retrieval results.
          max_articles: cap on the number of candidates (falsy disables the cap
            for keyword search).
          batch_size: number of texts per model forward pass; must be >= 1.

        Returns:
          pd.DataFrame with columns answer, start_score, end_score, context, doi,
          authors, journal, publish_time, title, sorted by (start_score,
          end_score) descending. Texts without a usable answer span are omitted.

        Raises:
          ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        df = self.df
        section_path = section.split('/')

        if keyword:
            candidates = df[df[section_path[0]].str.contains(keyword, na=False, case=False)]
        else:
            candidates = self.retrieve_candidates(section_path, question, top_n=max_articles)
        if max_articles:
            candidates = candidates.head(max_articles)

        columns = ['doi', 'authors', 'journal', 'publish_time', 'title']

        # Keep rows with non-empty string text and capture their metadata now,
        # so answers can be joined back without a per-row index lookup.
        text_list = []
        metadata = []
        for _, row in candidates.iterrows():
            text = row[section]
            if text and isinstance(text, str):
                text_list.append(text)
                metadata.append(list(row[columns].values))

        all_answers = []
        for start in range(0, len(text_list), batch_size):
            batch = text_list[start:start + batch_size]
            all_answers.extend(self.get_answers_from_text_list(question, batch))

        processed_answers = [
            [a.text, a.start_score, a.end_score, a.input_text] + meta
            for a, meta in zip(all_answers, metadata)
            if a is not None
        ]
        answer_df = pd.DataFrame(processed_answers, columns=(['answer', 'start_score',
                                                 'end_score', 'context'] + columns))
        return answer_df.sort_values(['start_score', 'end_score'], ascending=False)

    def get_answers_from_text_list(self, question, text_list, max_tokens=512):
        """Extract the most likely answer span to ``question`` from each text.

        Args:
          question: the question (first sequence of every encoded pair).
          text_list: context strings (second sequence of every pair).
          max_tokens: every pair is padded to this length; only the context is
            truncated to fit.

        Returns:
          A list aligned with ``text_list``: an ``Answer`` per text, or ``None``
          when the predicted span is empty or does not lie inside the context.
        """
        tokenizer = self.tokenizer
        # `padding`/`truncation` replace the deprecated `pad_to_max_length` /
        # `truncation_strategy`; with the old kwargs and max_length set, newer
        # tokenizers default to 'longest_first' and ignore 'only_second'.
        inputs = tokenizer.batch_encode_plus(
            [(question, text) for text in text_list], add_special_tokens=True, return_tensors='tf',
            max_length=max_tokens, truncation='only_second', padding='max_length')
        input_ids = inputs['input_ids'].numpy()
        output = self.model(inputs)
        # Convert logits to numpy once rather than once per text.
        start_logits = output.start_logits.numpy()
        end_logits = output.end_logits.numpy()
        # Independently most likely start token and exclusive end (argmax + 1).
        answer_start = start_logits.argmax(axis=1)
        answer_end = end_logits.argmax(axis=1) + 1
        sep_id = tokenizer.sep_token_id

        answers = []
        for i in range(len(text_list)):
            ids = input_ids[i]
            start, end = answer_start[i], answer_end[i]
            input_text = tokenizer.decode(ids, clean_up_tokenization_spaces=True)
            # Decoded pair looks like "[CLS] q [SEP] context [SEP] ..."; keep the context.
            input_text = input_text.split('[SEP] ', 2)[1]
            answer = tokenizer.decode(ids[start:end], clean_up_tokenization_spaces=True)
            if answer and '[CLS]' not in answer and self._span_in_context(ids, start, end, sep_id):
                answers.append(Answer(answer, start_logits[i][start], end_logits[i][end - 1], input_text))
            else:
                answers.append(None)
        return answers

    @staticmethod
    def _span_in_context(ids, start, end, sep_id):
        """Return True if tokens [start, end) lie strictly between the first two separators.

        The context of a (question, context) pair sits between the first and second
        separator token. If ``sep_id`` is None or fewer than two separators are
        present, the span is accepted (the earlier, unchecked behavior).
        """
        if sep_id is None:
            return True
        sep_positions = np.flatnonzero(np.asarray(ids) == sep_id)
        if len(sep_positions) < 2:
            return True
        return sep_positions[0] < start and end <= sep_positions[1]

In [6]:
@attr.attrs
class Answer(object):
    """A single extracted answer span plus the scores and context it came from."""

    text = attr.attrib()         # decoded answer tokens
    start_score = attr.attrib()  # start logit at the chosen start token
    end_score = attr.attrib()    # end logit at the chosen (inclusive) end token
    input_text = attr.attrib()   # decoded context passage the span was taken from

In [7]:
# Build the pipeline once: reads the CSV, loads model + tokenizer, fits the TF-IDF index.
qa = ResearchQA(data_path=data_path, model_path=model_path)

In [8]:
# Cytokine question, answered from the 5 abstracts ranked highest by TF-IDF retrieval.
answers = qa.get_answers(
    question='What kind of cytokines play a major role in host response?',
    max_articles=5,
)
answers["answer"]

In [9]:
# Drug-efficacy question, answered from the 5 abstracts ranked highest by TF-IDF retrieval.
answers = qa.get_answers(
    question='What drugs are effective?',
    max_articles=5,
)
answers["answer"]