## From pattern matching via BERT to a summary of COVID-19 risk factors

### The official question I want to answer with this notebook is: "What do we know about COVID-19 risk factors?"

This notebook applies a pattern-matching approach to the question stated above.

In [1]:
import numpy as np
import pandas as pd
import glob
import os
import json
import string
from collections import Counter
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords as nltkstopwords
from nltk.tokenize import word_tokenize
import re
from gensim.parsing.preprocessing import preprocess_string, strip_tags, strip_punctuation, strip_numeric, remove_stopwords, strip_multiple_whitespaces
import matplotlib.pyplot as plt
import torch
from transformers import BertTokenizer, AutoTokenizer, BertForQuestionAnswering, AutoModelForQuestionAnswering

[nltk_data] Downloading package punkt to /home/nic/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/nic/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
unable to import 'smart_open.gcs', disabling that module


Setting up all the paths

In [2]:
base_path = "../input/CORD-19-research-challenge"

biorxiv_medrxiv = "biorxiv_medrxiv/biorxiv_medrxiv/pdf_json"
comm_use_subset = "comm_use_subset/comm_use_subset/pdf_json"
noncomm_use_subset = "noncomm_use_subset/noncomm_use_subset/pdf_json"
custom_license = "custom_license/custom_license/pdf_json"

In [3]:
covid_kaggle_questions = {
"data":[
          {
              "task": "What do we know about COVID-19 risk factors?",
              "questions": [
                  "risk factors covid 19",
                  "risk factors corona",
                  "What risk factors contribute to the severity of 2019-nCoV?",
                  #"How does hypertension affect patients?",
                  #"How does heart disease affect patients?",
                  #"How does copd affect patients?",
                  #"How does smoking affect patients?",
                  #"How does pregnancy affect patients?",
                  #"What is the fatality rate of 2019-nCoV?",
                  #"What public health policies prevent or control the spread of 2019-nCoV?"
              ]
          }
   ]
}

In [4]:
class Document:
    
    def __init__(self, title, file_path, body_text, abstract):
        self.title = title
        self.file_path = file_path
        self.body_text = body_text
        self.abstract = abstract

In [5]:
%%time
files = []
file_names = []
documents = []
def read_files(directory):
    files_in_dir = [f for f in glob.glob(os.path.join(base_path, directory) + "/*.json", recursive=True)]
    files.extend(files_in_dir)
    for file_path in files_in_dir:
        with open(file_path) as f:  # close each file handle after reading
            data = json.load(f)
        body_text = ""
        for i in range(len(data["body_text"])):
            body_text += " " + data["body_text"][i]["text"]
        
        body_text = re.sub(' +', ' ', body_text)
        abstract_text = ""
        for i in range(len(data["abstract"])):
            abstract_text += " " + data["abstract"][i]["text"]
        title = data["metadata"]["title"]
        documents.append(Document(title, file_path, body_text, abstract_text))
    return len(files_in_dir)

print("Number of biorxiv_medrxiv documents: {}".format(read_files(biorxiv_medrxiv)))
#print("Number of comm_use_subset documents: {}".format(read_files(comm_use_subset)))
#print("Number of noncomm_use_subset documents: {}".format(read_files(noncomm_use_subset)))
#print("Number of custom license documents: {}".format(read_files(custom_license)))
print("Total number of documents: {}".format(len(documents)))

Number of biorxiv_medrxiv documents: 1342
Total number of documents: 1342
CPU times: user 1.99 s, sys: 87.4 ms, total: 2.08 s
Wall time: 2.43 s
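The parsing in read_files relies on the CORD-19 pdf_json layout used above: lists of paragraph dicts under "body_text" and "abstract", each with a "text" field, and the title under "metadata". As a sketch only, here is the same field extraction as a pure function applied to a hand-made record. The function name extract_document_fields and the .get() fallbacks for missing keys are my own assumptions, not part of the dataset tooling.

```python
import re


def extract_document_fields(data):
    """Return (title, body_text, abstract) from one CORD-19-style JSON record.

    Hypothetical helper: mirrors the extraction in read_files, but falls back
    to empty values when a key is missing (an assumption, not dataset behaviour).
    """
    # join all body paragraphs, then collapse runs of spaces as read_files does
    body_text = " ".join(p["text"] for p in data.get("body_text", []))
    body_text = re.sub(" +", " ", body_text).strip()
    abstract = " ".join(p["text"] for p in data.get("abstract", []))
    title = data.get("metadata", {}).get("title", "")
    return title, body_text, abstract


# hand-made record for illustration only (not taken from the dataset)
record = {
    "metadata": {"title": "Example paper"},
    "abstract": [{"text": "First abstract paragraph."}, {"text": "Second one."}],
    "body_text": [{"text": "Intro  with  spaces."}, {"text": "Results."}],
}
title, body, abstract = extract_document_fields(record)
print(title)  # Example paper
print(body)   # Intro with spaces. Results.
```

Unlike read_files, this sketch strips the leading space that the per-paragraph `" " +` concatenation produces.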


In [6]:
# filter duplicates: collect the index of every document whose body text equals that of an earlier one
duplicates = set()
for i in range(len(documents)):
    text_a = documents[i].body_text
    for j in range(i+1, len(documents)):
        text_b = documents[j].body_text
        if text_a == text_b:
            duplicates.add(j)
print("Found {} duplicates".format(len(duplicates)))

Found 1 duplicates


In [7]:
filtered_documents = [doc for idx, doc in enumerate(documents) if idx not in duplicates]
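The duplicate check above compares every pair of documents, so its cost grows quadratically with the number of documents. That is fine for the 1342 biorxiv_medrxiv files, but once the other subsets are loaded, a single pass over a set of hashes scales linearly. A minimal sketch, assuming exact body-text equality is still the duplicate criterion (find_duplicate_indices is a hypothetical name):

```python
import hashlib


def find_duplicate_indices(texts):
    """Return indices of texts identical to an earlier text (each later copy reported once)."""
    seen = set()
    duplicates = []
    for idx, text in enumerate(texts):
        # hash the text so the set stores short digests instead of full documents
        digest = hashlib.sha1(text.encode("utf-8")).hexdigest()
        if digest in seen:
            duplicates.append(idx)
        else:
            seen.add(digest)
    return duplicates


print(find_duplicate_indices(["a", "b", "a", "a"]))  # [2, 3]
```

In the notebook this would be called as `find_duplicate_indices([doc.body_text for doc in documents])`.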

### Search papers by regex patterns

Let us define some regex patterns to find text passages that might mention risk factors.
To do so, we look for phrases and words that are likely to appear in such a context.

In [8]:
# set up some constants
NUMBER_OF_FILES = len(filtered_documents)

# set up patterns that might appear in the context of risk factors
factor___risk_pattern = r"(factor(.){0,9}risk)" # for example for "factors of risk"
risk_factor_pattern = r"(risk(.){0,4}factor)" # for example for "risk factors"
high_risk_pattern = r"(high(.){0,6}risk)" # for example for "high risk" or "highly risky"
adverse_outcomes_pattern = r"(advers(.){0,4}outcome)" # for example for "adverse outcomes"
#risk_pattern = r"(risk)"
#comorbidit_pattern = r"(comorbidit)"
#co_infects_pattern = r"(co(.){0,4}infect)"
#neonat_pattern = r"(neonat)"
#pregnant_pattern = r"(pregnant)"
#smoking_pattern = r"(smoking)"
#cancer_pattern = r"(cancer)"

PATTERNS = [
    factor___risk_pattern,
    risk_factor_pattern,
    high_risk_pattern,
    adverse_outcomes_pattern
]
POTENTIAL_RISKS = [
    "smoking",
    "pulmonary diseas", 
    "elder",
    "diabetes",
    "old",
    "age",
    "cancer", 
    "cardiac",
    "cardio"]

CUSTOM_FILTERS = [lambda x: x.lower(), strip_tags, strip_punctuation]
CUSTOM_FILTERS_EXCLUDE_NUMERIC = [lambda x: x.lower(), strip_tags, strip_punctuation, strip_numeric]
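The risk terms in POTENTIAL_RISKS are later checked with a plain substring test, so short terms such as "old" and "age" also fire inside unrelated words like "threshold", "stage" or "damage". A sketch of one possible alternative that anchors each term at the start of a word with `\b`; this is my suggestion rather than what the notebook does, and both helper names are hypothetical. The term list is copied from above.

```python
import re

# copied from the cell above
POTENTIAL_RISKS = ["smoking", "pulmonary diseas", "elder", "diabetes",
                   "old", "age", "cancer", "cardiac", "cardio"]


def risk_terms_substring(text, terms=POTENTIAL_RISKS):
    # plain substring test, as used when filtering the passages
    return [t for t in terms if t in text]


def risk_terms_word_start(text, terms=POTENTIAL_RISKS):
    # require each term to begin at a word boundary; prefixes such as
    # "pulmonary diseas" or "elder" still match "disease" or "elderly"
    return [t for t in terms if re.search(r"\b" + re.escape(t), text)]


print(risk_terms_substring("the threshold stage showed tissue damage"))   # ['old', 'age']
print(risk_terms_word_start("the threshold stage showed tissue damage"))  # []
```

The trade-off: a word-start anchor still lets "age" match "agency", while adding a word-end anchor as well would lose useful variants such as "aged" or "older".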

Now we go through the texts and search for the patterns defined above.
For each match we append its start and end indices to a list.

In [9]:
%%time
def extract_windows_containing(docs, pattern):
    # collect the (start, end) character offsets of every match of `pattern`, per document
    indices = []
    for doc in docs:
        indices_of_file = [(m.start(0), m.end(0)) for m in re.finditer(pattern, doc.body_text)]
        indices.append(indices_of_file)
    
    return indices

indices = [[] for _ in range(NUMBER_OF_FILES)]
for pattern in PATTERNS:
    indices_ = extract_windows_containing(filtered_documents, pattern)
    for i in range(len(indices_)):
        indices[i].extend(indices_[i])

# note: this counts documents with at least one match, not individual matches
print("Found {} candidate text extracts".format(len([1 for a in indices if len(a)!=0])))

Found 244 candidate text extracts
CPU times: user 95.5 ms, sys: 84 µs, total: 95.6 ms
Wall time: 95.2 ms
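re.finditer is called without flags, so the patterns are case-sensitive: a sentence that starts with "Risk factors ..." is not matched by risk_factor_pattern. A small self-contained check, with the four patterns copied from above and a made-up sample sentence. Passing re.IGNORECASE is an option I am suggesting, not something the notebook does, and find_matches is a hypothetical helper.

```python
import re

# the four patterns from the cell above
PATTERNS = [
    r"(factor(.){0,9}risk)",
    r"(risk(.){0,4}factor)",
    r"(high(.){0,6}risk)",
    r"(advers(.){0,4}outcome)",
]


def find_matches(text, patterns, ignore_case=False):
    """Return all matched substrings, ordered by their position in `text`."""
    flags = re.IGNORECASE if ignore_case else 0
    found = []
    for pattern in patterns:
        for m in re.finditer(pattern, text, flags):
            found.append((m.start(0), m.group(0)))
    return [s for _, s in sorted(found)]


sample = "Risk factors include age. Smokers were at high risk of adverse outcomes."
print(find_matches(sample, PATTERNS))                    # ['high risk', 'adverse outcome']
print(find_matches(sample, PATTERNS, ignore_case=True))  # ['Risk factor', 'high risk', 'adverse outcome']
```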


In [23]:
len(filtered_documents), len(indices)

(1341, 1341)

So far we only have the indices where the patterns matched.
Next we go through them and extract a text window around each match. We use a window of 500 characters (WINDOW_SIZE), which means that we extract at most 250 characters before and after the match.

While iterating through all matches we also clean the text: punctuation is removed, the text is lowercased, and numeric and single-character tokens are dropped.

In [10]:
%%time

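WINDOW_SIZE = 500
def process_indices_for_file(file_number, indices_of_file, filters, docs):
    # note: `filters` is currently unused; cleaning happens in remove_special_character/tokenize
    contexts = []
    text = docs[file_number].body_text
    for match in indices_of_file:
        # clamp at 0: a negative start would make Python slice from the end of the string
        start = max(0, match[0] - WINDOW_SIZE // 2)
        end = match[1] + WINDOW_SIZE // 2
        contexts.append(text[start:end])

    return contexts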

def remove_special_character(text):
    return text.translate(str.maketrans('', '', string.punctuation))

def tokenize(text):
    words = nltk.word_tokenize(text)
    return [str(word).lower() for word in words if len(word) > 1 and not word.isnumeric()]

potential_contexts = []
length_of_longest_context = 0
for file_number, indices_of_file in enumerate(indices):
    if len(indices_of_file) != 0:
        for context in process_indices_for_file(file_number, indices_of_file, CUSTOM_FILTERS_EXCLUDE_NUMERIC, filtered_documents):
            processed_context = " ".join(tokenize(remove_special_character(context)))
            if len(processed_context) > length_of_longest_context:
                length_of_longest_context = len(processed_context)
            potential_contexts.append(processed_context)

print("Length of longest context:", length_of_longest_context)
print("Number of potential_contexts:", len(potential_contexts))

Length of longest context: 507
Number of potential_contexts: 682
CPU times: user 262 ms, sys: 3.13 ms, total: 265 ms
Wall time: 265 ms
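One detail of the window extraction is worth spelling out. When a match starts fewer than WINDOW_SIZE/2 characters into a text, subtracting half the window gives a negative start, and without clamping Python interprets a negative slice start as counting from the end of the string, which gives an empty or misplaced window. A minimal sketch of the clamped version (window_around is a hypothetical helper):

```python
WINDOW_SIZE = 500


def window_around(text, start, end, window_size=WINDOW_SIZE):
    """Return up to window_size // 2 characters on each side of text[start:end]."""
    half = window_size // 2
    left = max(0, start - half)          # never let the slice start go negative
    right = min(len(text), end + half)   # slicing past the end is safe, min() just makes it explicit
    return text[left:right]


text = "x" * 100 + "risk factor" + "y" * 1000
print(len(window_around(text, 100, 111)))  # 361
print(repr(text[100 - 250:111 + 250]))     # '' (unclamped: the start wraps around to 961)
```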


In [30]:
%%time
def refine_contexts():
    MAX_REFINED_CONTEXTS = 50  # stop collecting after this many passages
    overall_processed_indices_count = 0
    refined_potential_contexts = []
    for file_number, indices_of_file in enumerate(indices):
        if len(indices_of_file) == 0:
            continue
        # merge a match into the previous one if their starts or their ends are at most 100 characters apart
        processed_indices_of_file = []
        for start, end in sorted(indices_of_file):
            if processed_indices_of_file and (
                    abs(processed_indices_of_file[-1][0] - start) <= 100
                    or abs(processed_indices_of_file[-1][1] - end) <= 100):
                prev_start, prev_end = processed_indices_of_file.pop()
                processed_indices_of_file.append((min(start, prev_start), max(end, prev_end)))
            else:
                processed_indices_of_file.append((start, end))
        overall_processed_indices_count += len(processed_indices_of_file)

        # extract windows for the merged indices so that each window lines up with its index
        matches = process_indices_for_file(file_number, processed_indices_of_file, CUSTOM_FILTERS, filtered_documents)
        text = filtered_documents[file_number].body_text
        for index, match in zip(processed_indices_of_file, matches):
            # keep a passage once (not once per term) if it mentions any potential risk term
            if any(risk in match for risk in POTENTIAL_RISKS) and len(refined_potential_contexts) < MAX_REFINED_CONTEXTS:
                window_size = max(len(match) // 2, 300)
                start_idx = max(0, index[0] - window_size)
                end_idx = min(len(text), index[1] + window_size)
                refined_potential_contexts.append(text[start_idx:end_idx])
    return overall_processed_indices_count, refined_potential_contexts

overall_processed_indices_count, refined_potential_contexts = refine_contexts()

CPU times: user 12.5 ms, sys: 0 ns, total: 12.5 ms
Wall time: 11.8 ms
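The index merging in the cell above folds a match into the previous one when either their start offsets or their end offsets are at most 100 characters apart. The same rule as a standalone function makes it easy to check on toy spans; merge_nearby and max_gap are names I introduce here.

```python
def merge_nearby(spans, max_gap=100):
    """Merge (start, end) spans whose starts or ends lie within max_gap of the previous span."""
    merged = []
    for start, end in sorted(spans):
        if merged and (abs(merged[-1][0] - start) <= max_gap
                       or abs(merged[-1][1] - end) <= max_gap):
            prev_start, prev_end = merged.pop()
            merged.append((min(start, prev_start), max(end, prev_end)))
        else:
            merged.append((start, end))
    return merged


print(merge_nearby([(0, 10), (50, 60), (500, 510)]))  # [(0, 60), (500, 510)]
```

Because each span is compared only with the last merged span, a chain of closely spaced matches can grow into one long passage.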


In [31]:
print("Total number of possible text passages about risk factors: {}".format(overall_processed_indices_count))

Total number of possible text passages about risk factors: 647


I tested different models:

 - SCIBERT: "allenai/scibert_scivocab_uncased"
 - BERT_SQUAD: 'bert-large-uncased-whole-word-masking-finetuned-squad'
 - "mrm8488/scibert_scivocab-finetuned-CORD19"
 - "ktrapeznikov/scibert_scivocab_uncased_squad_v2"
 - "ahotrod/roberta_large_squad2"

They were loaded with model = BertForQuestionAnswering.from_pretrained(OTHER) and tokenizer = BertTokenizer.from_pretrained(OTHER), where OTHER is the model name. The cell below uses the AutoTokenizer and AutoModelForQuestionAnswering classes instead, with "ktrapeznikov/scibert_scivocab_uncased_squad_v2".

In [32]:
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

SCIBERT = "allenai/scibert_scivocab_uncased"
BERT_SQUAD = 'bert-large-uncased-whole-word-masking-finetuned-squad'
#OTHER = "mrm8488/scibert_scivocab-finetuned-CORD19"
OTHER = "ktrapeznikov/scibert_scivocab_uncased_squad_v2"
#OTHER = "ahotrod/roberta_large_squad2"
#model = BertForQuestionAnswering.from_pretrained(OTHER)
#tokenizer = BertTokenizer.from_pretrained(OTHER)

tokenizer = AutoTokenizer.from_pretrained(OTHER)

model = AutoModelForQuestionAnswering.from_pretrained(OTHER)


model = model.to(torch_device)
model.eval()


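def answer_question(question, context):
    # encode the (question, context) pair; only the context is truncated if the pair is too long
    # (padding/truncation arguments require transformers >= 3.0)
    encoded_dict = tokenizer.encode_plus(
                        question, context,
                        add_special_tokens = True,
                        max_length = 256,
                        truncation = 'only_second',
                        padding = 'max_length',
                        return_tensors = 'pt'
                   )
    
    input_ids = encoded_dict['input_ids'].to(torch_device)
    token_type_ids = encoded_dict['token_type_ids'].to(torch_device)
    attention_mask = encoded_dict['attention_mask'].to(torch_device)  # ignore padding tokens
    
    with torch.no_grad():  # inference only
        outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    # index access works for both tuple outputs (older versions) and ModelOutput objects
    start_scores, end_scores = outputs[0], outputs[1]

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    start_index = torch.argmax(start_scores).item()
    end_index = torch.argmax(end_scores).item()
    
    answer = tokenizer.convert_tokens_to_string(all_tokens[start_index:end_index+1])
    answer = answer.replace('[CLS]', '')
    return answer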
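answer_question takes the argmax of the start and end scores independently. Nothing stops the end index from landing before the start index (which yields an empty answer) or inside the question tokens. A common remedy is to search only over pairs with start ≤ end, a bounded answer length and, optionally, a minimum start index. A sketch in plain Python on toy score lists; best_span, max_answer_len and min_index are hypothetical names, and adding the two scores is one common scoring choice, not the only one.

```python
def best_span(start_scores, end_scores, max_answer_len=30, min_index=0):
    """Return (start, end) maximising start_scores[start] + end_scores[end]
    subject to min_index <= start <= end < start + max_answer_len."""
    best_score, best_start, best_end = float("-inf"), min_index, min_index
    for i in range(min_index, len(start_scores)):
        for j in range(i, min(i + max_answer_len, len(end_scores))):
            score = start_scores[i] + end_scores[j]
            if score > best_score:
                best_score, best_start, best_end = score, i, j
    return best_start, best_end


start_scores = [0.0, 0.0, 5.0, 0.0]
end_scores = [4.0, 1.0, 0.5, 0.2]
# independent argmax would pick start=2 and end=0, i.e. an empty answer
print(best_span(start_scores, end_scores))               # (2, 2)
print(best_span(start_scores, end_scores, min_index=3))  # (3, 3)
```

With the model outputs, the inputs would be `start_scores[0].tolist()` and `end_scores[0].tolist()`, and min_index could be set to the position of the first context token (for BERT-style models, the first position where token_type_ids is 1).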

In [33]:
NUM_CONTEXT_FOR_EACH_QUESTION = 10

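def get_all_context(query, num_results):
    # Return the abstracts of the num_results papers that best match the query.
    # Note: relies on a search object `cse` that is not defined in this notebook;
    # the function is not called below.
    papers_df = cse.search(query, num_results)
    return papers_df['abstract'].str.replace("Abstract", "").tolist()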


def get_all_answers(question, all_contexts):
    # Ask the same question to all contexts (all papers)
    
    all_answers = []
    
    for context in all_contexts:
        all_answers.append(answer_question(question, context))
    return all_answers


def create_output_results(question, 
                          all_contexts, 
                          all_answers, 
                          summary_answer='', 
                          summary_context=''):
    # Return results in json format
    
    def find_start_end_index_substring(context, answer):   
        search_re = re.search(re.escape(answer.lower()), context.lower())
        if search_re:
            return search_re.start(), search_re.end()
        else:
            return 0, len(context)
        
    output = {}
    output['question'] = question
    output['summary_answer'] = summary_answer
    output['summary_context'] = summary_context
    results = []
    for c, a in zip(all_contexts, all_answers):

        span = {}
        span['context'] = c
        span['answer'] = a
        span['start_index'], span['end_index'] = find_start_end_index_substring(c,a)

        results.append(span)
    
    output['results'] = results
        
    return output

    
def get_results(question,
                all_contexts,
                summarize=False, 
                num_results=NUM_CONTEXT_FOR_EACH_QUESTION,
                verbose=True):
    # note: num_results and verbose are currently unused
    all_answers = get_all_answers(question, all_contexts)
    
    summary_answer, summary_context = '', ''
    if summarize:
        # not implemented yet: get_summary is not defined in this notebook
        summary_answer = get_summary(all_answers)
        summary_context = get_summary(all_contexts)
    
    return create_output_results(question, 
                                 all_contexts, 
                                 all_answers,
                                 summary_answer,
                                 summary_context)
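find_start_end_index_substring looks for the decoded answer verbatim in the context. The decoded answer comes out of convert_tokens_to_string, which for an uncased WordPiece model is lowercased and can contain extra spaces around punctuation (for example "covid - 19"), so the exact search can miss answers that are really there. A sketch of a more tolerant lookup that allows any whitespace between answer tokens and matches case-insensitively. It keeps the original fallbacks, (0, len(context)) when nothing is found and (0, 0) for an empty answer; the function name is hypothetical.

```python
import re


def find_span_ignoring_whitespace(context, answer):
    """Locate `answer` in `context`, ignoring case and whitespace between answer tokens."""
    tokens = answer.split()
    if not tokens:
        return 0, 0  # empty answer, same result as searching for "" in the original
    # allow any amount of whitespace (including none) between answer tokens
    pattern = r"\s*".join(re.escape(tok) for tok in tokens)
    # IGNORECASE instead of lowercasing keeps offsets valid for the original string
    match = re.search(pattern, context, re.IGNORECASE)
    if match:
        return match.start(), match.end()
    return 0, len(context)


ctx = "Known risk factors include COVID-19 exposure."
s, e = find_span_ignoring_whitespace(ctx, "covid - 19")
print(ctx[s:e])  # COVID-19
```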

In [34]:
from IPython.display import display, Markdown, Latex, HTML

def layout_style():
    style = """
        div {
            color: black;
        }
        .single_answer {
            border-left: 3px solid #dc7b15;
            padding-left: 10px;
            font-family: Arial;
            font-size: 16px;
            color: #777777;
            margin-left: 5px;

        }
        .answer{
            color: #dc7b15;
        }
        .question_title {
            color: grey;
            display: block;
            text-transform: none;
        }      
        div.output_scroll { 
            height: auto; 
        }
    """
    return "<style>" + style + "</style>"

def dm(x): display(Markdown(x))
def dh(x): display(HTML(layout_style() + x))

In [35]:
def display_single_context(context, start_index, end_index):    
    before_answer = context[:start_index]
    answer = context[start_index:end_index]
    after_answer = context[end_index:]

    content = before_answer + "<span class='answer'>" + answer + "</span>" + after_answer

    return dh("""<div class="single_answer">{}</div>""".format(content))

def display_question_title(question):
    return dh("<h2 class='question_title'>{}</h2>".format(question.capitalize()))


def display_all_contexts(index, question):
    def answer_not_found(context, start_index, end_index):
        return (start_index == 0 and len(context) == end_index) or (start_index == 0 and end_index == 0)

    display_question_title(str(index + 1) + ". " + question['question'].capitalize())
    
    # display context
    for i in question['results']:
        if answer_not_found(i['context'], i['start_index'], i['end_index']):
            continue # skip not found questions
        display_single_context(i['context'], i['start_index'], i['end_index'])

def display_task_title(index, task):
    task_title = "Task " + str(index) + ": " + task
    return dh("<h1 class='task_title'>{}</h1>".format(task_title))

def display_single_task(index, task):    
    display_task_title(index, task['task'])
    
    for i, question in enumerate(task['questions']):
        display_all_contexts(i, question)
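display_single_context inserts the raw context into HTML. Paper text can contain characters such as `<` or `&`, which the browser would then interpret as markup. A sketch of the same highlighting with html.escape from the standard library; slicing happens before escaping so that start_index and end_index still refer to the raw text (highlight is a hypothetical name).

```python
import html


def highlight(context, start_index, end_index):
    """Wrap context[start_index:end_index] in an answer span, escaping all text parts."""
    before = html.escape(context[:start_index])
    answer = html.escape(context[start_index:end_index])
    after = html.escape(context[end_index:])
    return before + "<span class='answer'>" + answer + "</span>" + after


print(highlight("a < b risk", 6, 10))  # a &lt; b <span class='answer'>risk</span>
```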

In [36]:
%%time
all_tasks = []

all_contexts = potential_contexts[:10]

for i, t in enumerate(covid_kaggle_questions['data']):
    print("Answering questions to task {} ...".format(i+1))
    answers_to_question = []
    for q in t['questions']:
        res = get_results(q, all_contexts, verbose=False)
        answers_to_question.append(res)
    task = {}
    task['task'] = t['task']
    task['questions'] = answers_to_question
    
    all_tasks.append(task)
    print(" answered")

all_answers = {}
all_answers['data'] = all_tasks

task = 1
display_single_task(task, all_tasks[task-1])

Answering questions to task 1 ...
 answered


CPU times: user 1.15 s, sys: 8.5 ms, total: 1.16 s
Wall time: 664 ms


In [47]:
%%time
all_tasks = []

all_contexts = refined_potential_contexts[:1000]

for i, t in enumerate(covid_kaggle_questions['data']):
    print("Answering questions to task {} ...".format(i+1))
    answers_to_question = []
    for q in t['questions']:
        res = get_results(q, all_contexts, verbose=False)
        answers_to_question.append(res)
    task = {}
    task['task'] = t['task']
    task['questions'] = answers_to_question
    
    all_tasks.append(task)
    print(" answered")

all_answers = {}
all_answers['data'] = all_tasks

task = 1
display_single_task(task, all_tasks[task-1])

Answering questions to task 1 ...
 answered


CPU times: user 3.1 s, sys: 12.7 ms, total: 3.11 s
Wall time: 3.11 s


## Summarization 

I tried using another language model (BART) to summarize a concatenation of the extracted contexts, but the results were very poor. I think this is because summarizing this kind of text is hard in general, and summarizing concatenated passages from different papers makes it even harder.

In [172]:
from transformers import BartTokenizer, BartForConditionalGeneration

#text = documents[0].body_text
text = " ".join(potential_contexts)

# hub id of the BART model fine-tuned on CNN/DailyMail summarization
model_used = "facebook/bart-large-cnn"

tokenizer_summarize = BartTokenizer.from_pretrained(model_used)
model_summarize = BartForConditionalGeneration.from_pretrained(model_used).to(torch_device)
_ = model_summarize.eval()


In [173]:
out = tokenizer_summarize.batch_encode_plus(
    [text], return_tensors='pt', max_length=1024, truncation=True  # BART accepts at most 1024 tokens
)
input_ids = out['input_ids']
attention_mask = out['attention_mask']

input_ids = input_ids.to(torch_device)
attention_mask = attention_mask.to(torch_device)

In [174]:
text[:500]

'referent hospitals in spain hospital clinic hospital del mar hospital universitari bellvitge hospital sant pau and hospital vall dhebron located in barcelona and hospital de poniente in almeria province aimed to evaluate the prevalence and risk factors associated with strongyloides stercoralis infection particularly related to immunosuppressed patients although systematic screening has not been widely implemented at national level it was established that individuals either hospitalized or be ual'

In [178]:
beam_outputs = model_summarize.generate(input_ids,
                                       attention_mask=attention_mask,
                                       num_beams=10,
                                       max_length=500,
                                       early_stopping=True,
                                       no_repeat_ngram_size=2,
                                      )

for i, beam_output in enumerate(beam_outputs):
    summary = tokenizer_summarize.decode(beam_output,
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=True)
    print("{}: {}".format(i, summary.replace(u'\xa0', u' ')))

0: Strongyloides stercoralis infection particularly related to immunosuppressed patients although systematic screening has not been widely implemented at national level it was established that individuals either hospitalized or be uals seems to be lower but these are limited data. High prevalence found in our study supports the need of screening strategies in patients that are potentially immuno-sick.
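Part of the problem is the input limit. The encoding cell uses max_length=1024 and BART accepts at most 1024 input tokens, so only the beginning of the concatenated text reaches the model, and the summary above only paraphrases the passage shown by text[:500]. An option not explored in this notebook is to pack the contexts into chunks below the limit and summarize each chunk separately. A minimal sketch of the packing step; it uses whitespace-separated words as a rough stand-in for tokens (an assumption, since BART's tokenizer usually produces more tokens than words), and pack_contexts and max_words are hypothetical names.

```python
def pack_contexts(contexts, max_words=700):
    """Greedily join consecutive contexts into chunks of at most max_words words.

    A single context longer than max_words becomes its own (oversized) chunk.
    """
    chunks, current, current_len = [], [], 0
    for ctx in contexts:
        n = len(ctx.split())
        # start a new chunk if adding this context would exceed the budget
        if current and current_len + n > max_words:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(ctx)
        current_len += n
    if current:
        chunks.append(" ".join(current))
    return chunks


print(pack_contexts(["a b c", "d e", "f g h i"], max_words=5))  # ['a b c d e', 'f g h i']
```

Each chunk could then be encoded and passed to model_summarize.generate on its own, which keeps every passage within reach of the model but still mixes texts from different papers inside a chunk.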


Some resources used:

 - https://huggingface.co/blog/how-to-generate
 - https://www.kaggle.com/jonathanbesomi/a-qa-model-to-answer-them-all#9.-Export-solutions