## From pattern matching via BERT to a summary of COVID-19 risk factors

### The official question I want to answer with this notebook is "What do we know about COVID-19 risk factors?"

We first find relevant text passages via pattern matching and then apply BERT to them in a question-answering fashion. Finally, we apply BART to summarize the answers.

This notebook is based on my previous kernel [Pattern Matching to find risk factors](https://www.kaggle.com/n3xtvision/riskfactors-patternmatching) and on another kernel, not written by me, [A QA model to answer them all](https://www.kaggle.com/jonathanbesomi/a-qa-model-to-answer-them-all#9.-Export-solutions).

In [1]:
import numpy as np
import pandas as pd
import glob
import os
import json
import string
from collections import Counter
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords as nltkstopwords
from nltk.tokenize import word_tokenize
import re
from gensim.parsing.preprocessing import preprocess_string, strip_tags, strip_punctuation, strip_numeric, remove_stopwords, strip_multiple_whitespaces
import matplotlib.pyplot as plt
import torch
from transformers import BertTokenizer, AutoTokenizer, BertForQuestionAnswering, AutoModelForQuestionAnswering

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Setting up all the paths

In [2]:
base_path = "../input/CORD-19-research-challenge"

biorxiv_medrxiv = "biorxiv_medrxiv/biorxiv_medrxiv/pdf_json"
comm_use_subset = "comm_use_subset/comm_use_subset/pdf_json"
noncomm_use_subset = "noncomm_use_subset/noncomm_use_subset/pdf_json"
custom_license = "custom_license/custom_license/pdf_json"

Setting up the questions (they should all be related to risk factors, because we only search texts on that topic).

In [3]:
questions = {
    "questions": [
        "risk factors covid 19",
        "risk factors corona",
        "What risk factors contribute to the severity of 2019-nCoV?",
        "What do we know about COVID-19 risk factors?",
    ]
}

In [4]:
class Document:
    """
    Helper class to hold information about a document
    """
    def __init__(self, title, file_path, body_text, abstract):
        self.title = title
        self.file_path = file_path
        self.body_text = body_text
        self.abstract = abstract

Reading in all data. We are only interested in the body text, but we also store the title, abstract text, and file path.

In [5]:
%%time
files = []
file_names = []
documents = []
def read_files(directory):
    files_in_dir = [f for f in glob.glob(os.path.join(base_path, directory) + "/*.json", recursive=True)]
    files.extend(files_in_dir)
    for file_path in files_in_dir:
        # use a context manager so the file handle is closed after reading
        with open(file_path) as f:
            data = json.load(f)
        body_text = ""
        for i in range(len(data["body_text"])):
            body_text += " " + data["body_text"][i]["text"]
        
        body_text = re.sub(' +', ' ', body_text)
        abstract_text = ""
        for i in range(len(data["abstract"])):
            abstract_text += " " + data["abstract"][i]["text"]
        title = data["metadata"]["title"]
        documents.append(Document(title, file_path, body_text, abstract_text))
    return len(files_in_dir)

print("Number of biorxiv_medrxiv documents: {}".format(read_files(biorxiv_medrxiv)))
print("Number of comm_use_subset documents: {}".format(read_files(comm_use_subset)))
print("Number of noncomm_use_subset documents: {}".format(read_files(noncomm_use_subset)))
print("Number of custom license documents: {}".format(read_files(custom_license)))
print("Total number of documents: {}".format(len(documents)))

Number of biorxiv_medrxiv documents: 1625
Number of comm_use_subset documents: 9524
Number of noncomm_use_subset documents: 2490
Number of custom license documents: 26505
Total number of documents: 40144
CPU times: user 2min 38s, sys: 12.8 s, total: 2min 51s
Wall time: 3min 37s


In [6]:
%%time
# filter duplicates where text is exactly the same
duplicates = []
for i in range(len(documents)):
    text_a = documents[i].body_text
    for j in range(i+1, len(documents)):
        text_b = documents[j].body_text
        if text_a == text_b:
            duplicates.append(j)
print("Found {} duplicates".format(len(duplicates)))
filtered_documents = [doc for idx, doc in enumerate(documents) if idx not in duplicates]

Found 530 duplicates
CPU times: user 9min 43s, sys: 0 ns, total: 9min 43s
Wall time: 9min 43s
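
The pairwise comparison above grows quadratically with the number of documents. A minimal sketch of an alternative, assuming that dropping exact-match duplicates is all we need: keep a set of hashes of the body texts seen so far and skip repeats in a single pass. `Doc` and `drop_exact_duplicates` are hypothetical names used for illustration only; they are not part of the notebook.

```python
import hashlib
from collections import namedtuple

# Hypothetical stand-in for the Document class above (illustration only).
Doc = namedtuple("Doc", ["title", "body_text"])

def drop_exact_duplicates(docs):
    """Keep the first document for each distinct body text, in one pass."""
    seen = set()
    unique_docs = []
    for doc in docs:
        # Hash the text so the set stores short digests, not full articles.
        digest = hashlib.sha1(doc.body_text.encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique_docs.append(doc)
    return unique_docs

sample_docs = [Doc("a", "same text"), Doc("b", "other text"), Doc("c", "same text")]
unique_sample = drop_exact_duplicates(sample_docs)
print([d.title for d in unique_sample])
```

Each text is hashed once, so the run time of this approach grows roughly linearly with the corpus size.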


### Search papers by regex patterns

Let us define some regex patterns to find text passages where risk factors might be mentioned.
To do so we are looking for phrases/words which might appear in such a context.

In [7]:
# set up some constants
NUMBER_OF_FILES = len(filtered_documents)

#set up patterns, which might appear in the context of high risk factors
PATTERNS = [
    r"(factor(.){0,9}risk)", # for example for "factors of risk"
    r"(risk(.){0,4}factor)", # for example for "risk factors"
    r"(high(.){0,6}risk)", # for example for "high risk" or "highly risky"
    r"(advers(.){0,4}outcome)"
]

POTENTIAL_RISKS = [
    "smoking",
    "pulmonary diseas", 
    "elder",
    "diabetes",
    "old",
    "age",
    "cancer", 
    "cardiac",
    "cardio"]

CUSTOM_FILTERS = [lambda x: x.lower(), strip_tags, strip_punctuation]
CUSTOM_FILTERS_EXCLUDE_NUMERIC = [lambda x: x.lower(), strip_tags, strip_punctuation, strip_numeric]

Now we go through the texts and search for the patterns defined above.
We append the start and end indices of each match to a list.

In [8]:
%%time
def extract_windows_containing(docs, pattern, print_out=True):
    indices = []
    for idx in range(NUMBER_OF_FILES):
        body_text = docs[idx].body_text
        
        indices_of_file = [(m.start(0), m.end(0)) for m in re.finditer(pattern, body_text)]
        indices.append(indices_of_file)
    
    return indices

indices = [[] for _ in range(NUMBER_OF_FILES)]
for pattern in PATTERNS:
    indices_ = extract_windows_containing(filtered_documents, pattern)
    for i in range(len(indices_)):
        indices[i].extend(indices_[i])

print("Found {} candidate text extracts".format(len([1 for a in indices if len(a)!=0])))

Found 9077 candidate text extracts
CPU times: user 10.7 s, sys: 0 ns, total: 10.7 s
Wall time: 10.7 s
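
To make the patterns concrete, here is a small check on a made-up sentence (the sentence is invented for illustration and not taken from the corpus). `find_spans` is a hypothetical helper. Note that `re.finditer` is case-sensitive by default, so a sentence starting with "Risk factors" only matches if `re.IGNORECASE` is passed or the text is lowercased first.

```python
import re

# Same patterns as in the cell above.
PATTERNS = [
    r"(factor(.){0,9}risk)",
    r"(risk(.){0,4}factor)",
    r"(high(.){0,6}risk)",
    r"(advers(.){0,4}outcome)",
]

def find_spans(text, patterns=PATTERNS, flags=0):
    """Return (start, end, matched_text) for every pattern match in text."""
    spans = []
    for pattern in patterns:
        for m in re.finditer(pattern, text, flags):
            spans.append((m.start(0), m.end(0), m.group(0)))
    return sorted(spans)

sample = "Smoking is a risk factor, and elderly patients are at high risk of adverse outcomes."
for start, end, matched in find_spans(sample):
    print(start, end, repr(matched))
```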


So far we only have the indices where the patterns matched.
Now we go through them and extract a text window around each match. We choose a window of 500 characters (WINDOW_SIZE), which means that we extract at most 250 characters before and after the match.

While iterating through all matches we remove special characters like punctuation and numeric values.

In [9]:
%%time

WINDOW_SIZE = 500
def process_indices_for_file(file_number, indices_of_file, filters, docs):
    contexts = []
    for match in indices_of_file:
        start = max(0, match[0]-int(WINDOW_SIZE/2))
        end = min(len(docs[file_number].body_text), match[1]+int(WINDOW_SIZE/2))
        context = docs[file_number].body_text[start:end]
        contexts.append(context)

    return contexts

def remove_special_character(text):
    return text.translate(str.maketrans('', '', string.punctuation))

def tokenize(text):
    words = nltk.word_tokenize(text)
    return [str(word).lower() for word in words if len(word) > 1 and not word.isnumeric()]

potential_contexts = []
length_of_longest_context = 0
for file_number, indices_of_file in enumerate(indices):
    if len(indices_of_file) != 0:
        for context in process_indices_for_file(file_number, indices_of_file, CUSTOM_FILTERS_EXCLUDE_NUMERIC, filtered_documents):
            processed_context = " ".join(tokenize(remove_special_character(context)))
            if len(processed_context) > length_of_longest_context:
                length_of_longest_context = len(processed_context)
            potential_contexts.append(processed_context)

print("Length of longest context:", length_of_longest_context)
print("Number of potential_contexts:", len(potential_contexts))

Length of longest context: 513
Number of potential_contexts: 32864
CPU times: user 33.8 s, sys: 0 ns, total: 33.8 s
Wall time: 33.8 s
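
The window arithmetic can be checked in isolation. A minimal sketch, assuming a match span `(start, end)` and the same clamping to the text boundaries as above; `extract_window` is a hypothetical helper name.

```python
def extract_window(text, span, window_size=500):
    """Return the text around span, with window_size // 2 characters on each side."""
    half = window_size // 2
    start = max(0, span[0] - half)        # do not run past the beginning
    end = min(len(text), span[1] + half)  # do not run past the end
    return text[start:end]

text = "x" * 1000 + "risk factor" + "y" * 1000
span = (1000, 1011)  # position of "risk factor" in text
window = extract_window(text, span)
print(len(window))
```

The extracted window contains the match itself plus up to 250 characters on each side, so it can be longer than `WINDOW_SIZE` by the length of the match.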


Now we have a set of contexts which hopefully contain information about risk factors of COVID-19.
Next we go through the matches again, but this time we only collect those contexts which contain one of our previously defined POTENTIAL_RISKS.

In [10]:
%%time
def refined_contexts():
    print_counter = 0
    overall_processed_indices_count = 0
    refined_potential_contexts = []
    for file_number, indices_of_file in enumerate(indices):
        if len(indices_of_file) != 0:
            matches = process_indices_for_file(file_number, indices_of_file, CUSTOM_FILTERS, filtered_documents)
            processed_indices_of_file = []
            indices_of_file = sorted(indices_of_file)
            # adjust indices to avoid big overlaps
            for i in range(len(indices_of_file)):
                if i != 0 and len(processed_indices_of_file) != 0:
                    if abs(processed_indices_of_file[-1][0] - indices_of_file[i][0]) > 100 and abs(processed_indices_of_file[-1][1] - indices_of_file[i][1]) > 100:
                        processed_indices_of_file.append(indices_of_file[i])
                    else:
                        min_ = min(indices_of_file[i][0], processed_indices_of_file[-1][0])
                        max_ = max(indices_of_file[i][1], processed_indices_of_file[-1][1])
                        del processed_indices_of_file[-1]
                        processed_indices_of_file.append((min_, max_))
                else:
                    processed_indices_of_file.append(indices_of_file[i])
            overall_processed_indices_count += len(processed_indices_of_file)

            for index, match in zip(processed_indices_of_file, matches):
                for pattern in POTENTIAL_RISKS:
                    if pattern in match:
                        windows_size = max(int(len(" ".join(match))/2), WINDOW_SIZE)
                        text = filtered_documents[file_number].body_text
                        start_idx = max(0, index[0]-windows_size)
                        end_idx = min(len(text), index[1]+windows_size)
                        refined_potential_contexts.append(text[start_idx:end_idx])
                        break
    return overall_processed_indices_count, refined_potential_contexts

overall_processed_indices_count, refined_potential_contexts = refined_contexts()

CPU times: user 735 ms, sys: 0 ns, total: 735 ms
Wall time: 732 ms
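
The overlap handling above boils down to merging match spans that lie close together. A minimal standalone sketch of that idea, assuming spans are `(start, end)` tuples and that two spans are merged when the gap between them is at most `max_gap` characters. This is a simplified criterion, not the exact rule used in the cell above, and `merge_close_spans` is a hypothetical name.

```python
def merge_close_spans(spans, max_gap=100):
    """Merge (start, end) spans whose gap to the previous span is at most max_gap."""
    merged = []
    for start, end in sorted(spans):
        if merged and start - merged[-1][1] <= max_gap:
            # Close to (or overlapping) the previous span: extend it.
            prev_start, prev_end = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged

print(merge_close_spans([(400, 411), (10, 21), (50, 59), (1000, 1015)]))
```

Sorting first means each span only has to be compared with the most recently merged one.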


In [11]:
print("Total number of possible text passages about risk factors: {}".format(overall_processed_indices_count))

Total number of possible text passages about risk factors: 31289


### Applying BERT for question answering

Now we apply BERT to the extracted contexts to answer the defined questions. We apply it separately, first to the initial contexts and then to the refined ones.

I tested different models, and by visual inspection "ktrapeznikov/scibert_scivocab_uncased_squad_v2" seems to work best.

Tested models:
 - allenai/scibert_scivocab_uncased
 - bert-large-uncased-whole-word-masking-finetuned-squad
 - mrm8488/scibert_scivocab-finetuned-CORD19
 - ktrapeznikov/scibert_scivocab_uncased_squad_v2
 - ahotrod/roberta_large_squad2

In [12]:
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL_AND_TOKENIZER = "ktrapeznikov/scibert_scivocab_uncased_squad_v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_AND_TOKENIZER)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_AND_TOKENIZER)

model = model.to(torch_device)
model.eval()


def generate_answer(question, context):
    encoded_dict = tokenizer.encode_plus(
                        question, context,
                        add_special_tokens = True,
                        max_length = 256,
                        pad_to_max_length = True,
                        return_tensors = 'pt'
                   )
    
    input_ids = encoded_dict['input_ids'].to(torch_device)
    token_type_ids = encoded_dict['token_type_ids'].to(torch_device)
    
    start_scores, end_scores = model(input_ids, token_type_ids=token_type_ids)

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)
    
    answer = tokenizer.convert_tokens_to_string(all_tokens[start_index:end_index+1])
    answer = answer.replace('[CLS]', '')
    answer = answer.replace('[PAD]', '')
    return answer
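
Taking the two argmaxes independently can produce an end index that lies before the start index, in which case the token slice is empty. A common alternative is to pick the pair `(start, end)` with `start <= end` that maximizes the sum of the two scores. A minimal sketch in plain Python, assuming the scores are available as lists of floats (for a tensor of shape `(1, seq_len)`, something like `scores[0].tolist()` would give that). `best_span` and `max_answer_len` are hypothetical names; they are not part of the notebook or of `transformers`.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end) maximizing start_scores[start] + end_scores[end]
    subject to start <= end < start + max_answer_len."""
    best, best_score = (0, 0), float("-inf")
    for s, s_score in enumerate(start_scores):
        last = min(len(end_scores), s + max_answer_len)
        for e in range(s, last):
            score = s_score + end_scores[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

# Made-up scores: independent argmaxes would give start=3, end=1.
start_scores = [0.1, 0.2, 0.0, 2.0, 0.3]
end_scores = [0.0, 1.5, 0.1, 0.4, 1.0]
print(best_span(start_scores, end_scores))
```

The `max_answer_len` cap keeps the search cheap and avoids answers that span most of the context.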

In [13]:

def create_output_results(question, 
                          all_contexts, 
                          all_answers):
    
    def find_start_end_index_substring(context, answer):   
        search_re = re.search(re.escape(answer.lower()), context.lower())
        if search_re:
            return search_re.start(), search_re.end()
        else:
            return 0, len(context)
        
    output = {}
    output['question'] = question
    results = []
    for c, a in zip(all_contexts, all_answers):
        span = {}
        span['context'] = c
        span['answer'] = a
        span['start_index'], span['end_index'] = find_start_end_index_substring(c,a)
        results.append(span)
    
    output['results'] = results
        
    return output

    
def get_results(question,
                contexts):
    answers = []
    for context in contexts:
        answers.append(generate_answer(question, context))
    
    return create_output_results(question, contexts, answers)

from IPython.display import display, Markdown, Latex, HTML

def layout_style():
    style = """
        div {
            color: black;
        }
        .single_answer {
            border-left: 3px solid #dc7b15;
            padding-left: 10px;
            font-family: Arial;
            font-size: 16px;
            color: #777777;
            margin-left: 5px;

        }
        .answer{
            color: #dc7b15;
        }
        .question_title {
            color: grey;
            display: block;
            text-transform: none;
        }      
        div.output_scroll { 
            height: auto; 
        }
    """
    return "<style>" + style + "</style>"

def dm(x): display(Markdown(x))
def dh(x): display(HTML(layout_style() + x))
    
def display_single_context(context, start_index, end_index):    
    before_answer = context[:start_index]
    answer = context[start_index:end_index]
    after_answer = context[end_index:]

    content = before_answer + "<span class='answer'>" + answer + "</span>" + after_answer

    return dh("""<div class="single_answer">{}</div>""".format(content))

def display_question_title(question):
    return dh("<h2 class='question_title'>{}</h2>".format(question.capitalize()))


def display_all_contexts(index, question):
    def answer_not_found(context, start_index, end_index):
        return (start_index == 0 and len(context) == end_index) or (start_index == 0 and end_index == 0)

    display_question_title(str(index + 1) + ". " + question['question'].capitalize())
    
    # display context
    for i in question['results']:
        if answer_not_found(i['context'], i['start_index'], i['end_index']):
            continue  # skip contexts in which no answer was found
        display_single_context(i['context'], i['start_index'], i['end_index'])

def display_result(result):
    for i, question in enumerate(result):
        display_all_contexts(i, question)

Now we answer questions based on the found contexts.
We only use the first 100, to speed things up.

In [14]:
%%time
contexts = potential_contexts[:100]
result_on_potential_contexts = []
for q in questions['questions']:
    res = get_results(q, contexts)
    result_on_potential_contexts.append(res)
display_result(result_on_potential_contexts)

CPU times: user 7min 25s, sys: 1min 16s, total: 8min 42s
Wall time: 4min 22s


Next we answer the questions based on the refined contexts. Again, we only use the first 100 contexts.

In [15]:
%%time
contexts = refined_potential_contexts[:100]
result_on_refined_contexts = []
for q in questions['questions']:
    res = get_results(q, contexts)
    result_on_refined_contexts.append(res)
display_result(result_on_refined_contexts)

CPU times: user 7min 25s, sys: 1min 17s, total: 8min 42s
Wall time: 4min 23s


We can see that it works in some cases. I don't think we can say in general on which set of contexts it works better.

Remark:
We applied it only to the first 100 contexts without any further selection, so one might consider ranking the contexts and applying the QA model only to the top-ranked ones.

## Summarization 

I used another language model to summarize the concatenation of all answers for each question separately. The results are not very good. I think this is because summarizing this kind of text is hard in general, and summarizing concatenated passages from different texts makes it even harder.

In [16]:
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelWithLMHead
torch_device = "cpu"
text = " ".join(potential_contexts[500:1000])

model_used = "bart-large-cnn"

tokenizer_summarize = BartTokenizer.from_pretrained(model_used)
model_summarize = BartForConditionalGeneration.from_pretrained(model_used)

model_summarize = model_summarize.to(torch_device)
model_summarize.eval()
print()

In [17]:
def generate_summary(text):
    out = tokenizer_summarize.batch_encode_plus(
        [text], return_tensors='pt', max_length=512
    )

    input_ids = out['input_ids'].to(torch_device)
    attention_mask = out['attention_mask'].to(torch_device)

    beam_outputs = model_summarize.generate(input_ids,
                                           attention_mask=attention_mask,
                                           num_beams=5,
                                           max_length=256,
                                           early_stopping=True,
                                          )

    summary = tokenizer_summarize.decode(beam_outputs[0],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=True)
    summary = summary.replace(u'\xa0', u' ')
    summary = summary.replace('[CLS]', '')
    return summary
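
`generate_summary` encodes its input with `max_length=512`, so a long concatenation of answers may not fit into a single call. One possible workaround, sketched here under the assumption that a whitespace word count is an acceptable proxy for the token count, is to split the text into word chunks, summarize each chunk, and join the partial summaries. `chunk_words` and `summarize_in_chunks` are hypothetical helpers; the `summarize` argument stands for any function mapping text to text, such as `generate_summary`.

```python
def chunk_words(text, max_words=300):
    """Split text into consecutive chunks of at most max_words words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

def summarize_in_chunks(text, summarize, max_words=300):
    """Summarize each chunk separately and join the partial summaries."""
    return " ".join(summarize(chunk) for chunk in chunk_words(text, max_words))

# Stand-in summarizer for demonstration: keeps the first three words of a chunk.
def first_three_words(chunk):
    return " ".join(chunk.split()[:3])

demo_text = " ".join("w{}".format(i) for i in range(10))
print(chunk_words(demo_text, max_words=4))
print(summarize_in_chunks(demo_text, first_three_words, max_words=4))
```

With the real model, the call would look like `summarize_in_chunks(answers, generate_summary)`; whether the joined partial summaries read better than a single truncated summary would have to be checked by inspection.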

In [18]:
for question in result_on_potential_contexts:
    print("Question:", question["question"])
    results = question["results"]
    answers = ""
    count = 0
    for result in results:
        answer = result["answer"]
        if answer != "":
            answers += " " + answer
            count += 1
    print("\tNumber of answers: {}".format(count))
    print("\tLength of concatenation: {}".format(len(answers)))   
    answers = answers.replace('[SEP]', '')
    print("\n\tConcatenated answers:", answers)
    summary = generate_summary(answers)
    print("\n\tSummary:", summary)
    print("\n\n")


Question: risk factors covid 19
	Number of answers: 15
	Length of concatenation: 3489

	Concatenated answers:  older age covid 19  ly we observed higher gene expression of ace2 in cells infected by rsv or merscov we speculated that patients simultaneously infected by other viruses may have higher expression of ace2 indicating those who bear coinfection are possibly at higher risk and are more susceptible to sarscov2 infection it may explain the high incidence of hospitalrelated transmissioninfection of sarscov2 in conclusion our findings indicate that patients with pulmonary fibrosis heart failure and virus in          blood group was associated with lower risk of death compared with nono groups with an or of ci                  herd immunity policy will face substantial mortality among high risk population we propose that policy that expects herd immunity are dangerous in the areas indicated in brown red and magenta in fig in areas marked in yellow policies that balance social closure


	Summary: Cardiovascular disease and diabetes put patients at higher risk of mortality. identification of novel risk factors predictive for patients outcome including mortality is needed. Using publicly available clinical data from kaggle we have employed machine learning tool to identify the risk factors that could potentially contribute to the mortality of older age and delayed hospitalisation.



Question: What do we know about COVID-19 risk factors?
	Number of answers: 20
	Length of concatenation: 5632

	Concatenated answers:   what do we know about covid - 19 risk factors ?  pneumonia leading to respiratory failure epidemiological evidence suggests that older age and the associated comorbidities such as cardiovascular disease and diabetes put patients at higher risk of mortality thus identification of novel risk factors predictive for patients outcome including mortality is needed here using the publicly available clinical data from kaggle we have employed machine learning tool t

In [19]:
for question in result_on_refined_contexts:
    print("Question:", question["question"])
    results = question["results"]
    answers = ""
    count = 0
    for result in results:
        answer = result["answer"]
        if answer != "":
            answers += " " + answer
            count += 1
    print("\tNumber of answers: {}".format(count))
    print("\tLength of concatenation: {}".format(len(answers)))   
    answers = answers.replace('[SEP]', '')
    print("\n\tConcatenated answers:", answers)
    summary = generate_summary(answers)
    print("\n\tSummary:", summary)
    print("\n\n")


Question: risk factors covid 19
	Number of answers: 19
	Length of concatenation: 2378

	Concatenated answers:  older age , high sequential organ failure assessment ( sofa ) score and elevated d - dimer levels  risk factors covid 19  ations of in - hospital death were found to be present with older age , high sequential organ failure assessment ( sofa ) score and elevated d - dimer levels very high risk of spread and impact of covid - 19 people aged between 21 and 60 years old persons with wuhan travel histories blood group a was associated with a significantly higher risk for covid - 19 age and underlying diseases  risk factors covid 19  tes were decreased ( 87 . 5 % ) , suggesting that the rising of neutrophils , pct , crp , ctni , d - dimer and ldh levels can be used as indicators of disease progression , as well as the decline of lymphocytes counts . this was a small sample size retrospective study , which was limited by the small numbers of patients and by using a retrospective met


	Summary: What risk factors contribute to the severity of 2019 - ncov?  mitation. the time between introduction 144 of zikv and the actual publication of a research study is dependent on factors both within and between 145 study designs. patients with pulmonary fibrosis, heart failure, and virus infection medical comorbidities.



Question: What do we know about COVID-19 risk factors?
	Number of answers: 6
	Length of concatenation: 1062

	Concatenated answers:   what do we know about covid - 19 risk factors ?  s with covid - 19 . mortality of critically ill patients of covid - 19 is high and co - morbidities including hypertension , diabetes and coronary artery disease are often present in hospitalised patients . though 48 % of the non - survivors had a co - morbid disease , in multivariate analyses , independent associations of in - hospital death were found to be present with older age , high sequential organ failure assessment ( sofa ) score and elevated d - dimer levels ( 6 ) . an

## Results

The first issue is the extraction of relevant texts, which was done here via pattern matching. As shown in my other kernel/notebook, this extracts some good but also many bad text passages. To improve the extraction of relevant text, one can consider better search/ranking algorithms like BM25.
This filtering step is very important, not only because it improves the performance of the downstream QA model, but also because it massively reduces the amount of data. Without it, we cannot apply a huge language model.
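
As an illustration of what such a ranking could look like, here is a compact from-scratch sketch of Okapi BM25 scoring. It assumes plain lowercase whitespace tokenization and the commonly used defaults `k1=1.5` and `b=0.75`; the function name `bm25_scores` and the example contexts are made up for this sketch and are not part of the notebook.

```python
import math
from collections import Counter

def bm25_scores(query, documents, k1=1.5, b=0.75):
    """Score each document against the query with Okapi BM25.

    Tokenization is a plain lowercase whitespace split (an assumption made
    for brevity; real text needs proper tokenization).
    """
    if not documents:
        return []
    tokenized = [doc.lower().split() for doc in documents]
    n_docs = len(tokenized)
    avg_len = sum(len(tokens) for tokens in tokenized) / n_docs
    # Document frequency: in how many documents each term occurs.
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))

    scores = []
    for tokens in tokenized:
        term_freq = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in term_freq:
                continue
            # Non-negative IDF variant: ln(1 + (N - n + 0.5) / (n + 0.5)).
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            tf = term_freq[term]
            norm = tf + k1 * (1 - b + b * len(tokens) / avg_len)
            score += idf * tf * (k1 + 1) / norm
        scores.append(score)
    return scores

contexts = [
    "older age and diabetes are risk factors for severe covid",
    "the virus genome was sequenced in january",
    "smoking is discussed as a risk factor",
]
scores = bm25_scores("covid risk factors", contexts)
ranked = sorted(range(len(contexts)), key=lambda i: scores[i], reverse=True)
print(ranked)
```

The QA model would then only be applied to the top-ranked contexts instead of the first 100 in corpus order.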

One problem of the QA approach is that the model is not perfectly suited to this setting. SciBERT was pre-trained on scientific articles, but not for question answering. Models fine-tuned on SQuAD, a question-answering dataset, do exist (the checkpoint used here is one, judging by its name), but the SQuAD data is not from the medical domain.

I think the performance can be improved by first using a better ranking algorithm to identify relevant documents/text passages and then training or fine-tuning a language model on scientific/medical question answering.

The summarization part is probably hard to improve. A model trained on domain-specific data should help, but summarizing non-contiguous text passages seems to be quite hard.

Some resources used:

- https://huggingface.co/blog/how-to-generate
- https://www.kaggle.com/jonathanbesomi/a-qa-model-to-answer-them-all#9.-Export-solutions (parts of this notebook are based on that, not my work)
- https://www.kaggle.com/n3xtvision/riskfactors-patternmatching (parts of this notebook are based on that, my work)