# Research and Implementation of a Multi-Type Question Answering System in the Medical Field Based on Machine Reading Comprehension

#### SPAN EXTRACTION - FINE TUNED MODEL ####

In [1]:
from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer

# Local directory holding the fine-tuned span-extraction model.
model_path = "medical_trained_model" # alternative: "Eladio/bert-medical-emrqa-squadv2" (presumably the Hugging Face Hub id of the same model -- confirm)

# Load the question-answering model weights from model_path.
loaded_model = AutoModelForQuestionAnswering.from_pretrained(model_path)

In [2]:
# Tokenizer is loaded from the same location as the model so the two stay in sync.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Span-extraction entry point used by the input cell: qa_pipeline(question=..., context=...).
qa_pipeline = pipeline("question-answering", model=loaded_model, tokenizer=tokenizer)

#### CLOZE TEST MODEL ####

In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os

# Set before torch is imported below.
# NOTE(review): presumably this exposes only physical GPU 1 to the process, which
# would then be addressed as device index 0 (see gpu_device further down) -- confirm
# on the target machine; on a single-GPU host this likely hides the only GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

__author__ = 'Petros' # author of the BioMRC code this cell is taken from

# Seed shared by random, numpy and torch further down in this cell.
my_seed = 1989
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd
import os
import sys
import numpy as np
import pickle
import torch.backends.cudnn as cudnn
import random
import json
import copy
from pytorch_pretrained_bert import BertTokenizer, BertModel
from tqdm import tqdm
from nltk import sent_tokenize

# Seed every random number generator used in this notebook.
random.seed(my_seed)
torch.manual_seed(my_seed)
np.random.seed(my_seed)

# NOTE(review): cudnn.benchmark lets cuDNN pick kernels by timing them, which may
# work against the seeding above if bit-exact reproducibility is needed -- confirm.
cudnn.benchmark = True

# embedding_dim and hidden_dim are not referenced anywhere else in the visible
# notebook (SciBertReaderSum hard-codes its layer sizes).
embedding_dim = 30
hidden_dim = 100
# CUDA device index passed to every .cuda(...) call in the cloze model.
gpu_device = 0
use_cuda = torch.cuda.is_available()
if (use_cuda):
    torch.cuda.manual_seed(my_seed)
    print("Using GPU")

In [4]:
#MAIN FUNCTIONS
def print_params(show_model=True, net=None):
    """Print the module (optionally) and its number of trainable parameters.

    Args:
        show_model: when true, print the module's repr before the count.
        net: module to inspect. When omitted, the notebook-level ``model``
            is used, so existing ``print_params()`` calls behave as before.

    Returns:
        None. Everything is written to stdout.
    """
    if net is None:
        # Backward-compatible fallback to the notebook global.
        net = model
    rule = 40 * '='
    if show_model:
        print(rule)
        print(net)
    print(rule)
    print('Trainable Parameters\n')
    # Only parameters that still receive gradients are counted.
    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(rule)
    print(total_params)
    print(rule)


class SciBertReaderSum(nn.Module):
    """Cloze reader that scores ``@entityN`` mentions of an abstract against a masked title.

    Each abstract sentence is paired with the title (whose ``XXXX`` placeholder
    becomes ``[MASK]``), encoded with BERT, and every entity mention in the
    sentence is scored by a two-layer MLP applied to the concatenation of the
    mention vector and the ``[MASK]`` vector. Scores of the same entity are
    summed over all of its mentions.

    NOTE(review): no attention mask is passed to BERT, so padded positions are
    presumably attended to. This is left untouched because the loaded
    checkpoint was produced with this exact forward path -- confirm before
    changing.
    """

    def __init__(self, frozen_top, bert_path='./biomrc/scibert_scivocab_uncased/'):
        """Build the tokenizer, the BERT encoder and the scoring head.

        Args:
            frozen_top: initial value of ``self.frozen_top``.
            bert_path: directory holding the BERT vocabulary and weights.
                Defaults to the location the notebook has always used.
        """
        super(SciBertReaderSum, self).__init__()

        self.tok = BertTokenizer.from_pretrained(bert_path)
        self.bert = BertModel.from_pretrained(bert_path)
        self.linear = nn.Linear(2 * 768, 100, bias=True)
        self.linear2 = nn.Linear(100, 1, bias=True)

        self.frozen_top = frozen_top

        # All BERT weights start frozen; unfreeze_top() re-enables the last layer.
        for p in self.bert.parameters():
            p.requires_grad = False

        if use_cuda:
            self.bert = self.bert.cuda(gpu_device)
            self.linear = self.linear.cuda(gpu_device)
            self.linear2 = self.linear2.cuda(gpu_device)

    def freeze_top(self, optim):
        """Freeze the last BERT encoder layer and set every param group's lr to 0.001."""
        for p in self.bert.encoder.layer[-1].parameters():
            p.requires_grad = False
        for g in optim.param_groups:
            g['lr'] = 0.001
        self.frozen_top = True

    def unfreeze_top(self, optim):
        """Unfreeze the last BERT encoder layer and set every param group's lr to 0.0001."""
        for p in self.bert.encoder.layer[-1].parameters():
            p.requires_grad = True
        for g in optim.param_groups:
            g['lr'] = 0.0001
        self.frozen_top = False

    @staticmethod
    def _find_entities(tokens):
        """Locate ``@entityN`` mentions in one tokenized sentence.

        A mention starts at an ``'@'`` token immediately followed by
        ``'entity'`` and extends over the following ``'entity'`` / ``'##...'``
        word pieces.

        Returns:
            (indices, texts): positions of the ``'@'`` tokens and the
            re-joined mention strings (``'##'`` markers removed).
        """
        indices, texts = [], []
        n_tok = len(tokens)
        for j, tok in enumerate(tokens):
            # The bounds check keeps a trailing '@' from raising IndexError.
            if tok != '@' or j + 1 >= n_tok or tokens[j + 1] != 'entity':
                continue
            end = j + 1
            while end < n_tok and (tokens[end].startswith('##') or tokens[end] == 'entity'):
                end += 1
            indices.append(j)
            texts.append(''.join(tokens[j:end]).replace('##', ''))
        return indices, texts

    def fix_input(self, abstract, title):
        """Turn an abstract/title pair into one BERT input per abstract sentence.

        Args:
            abstract: text whose entities are already written as ``@entityN``.
            title: cloze question containing the ``XXXX`` placeholder.

        Returns:
            (combined_inp, mask_indices, entity_indices, entity_texts), one
            entry per sentence: a ``(1, length)`` LongTensor of token ids, the
            position of ``[MASK]``, the positions of the entity mentions and
            their texts.

        Raises:
            ValueError: if the tokenized title contains no ``[MASK]``.
        """
        sentences = [self.tok.tokenize(s) for s in sent_tokenize(abstract, 'english')]
        entity_indices = []
        entity_texts = []
        for tokens in sentences:
            tokens.insert(0, '[CLS]')
            found_idx, found_txt = self._find_entities(tokens)
            entity_indices.append(found_idx)
            entity_texts.append(found_txt)

        ti_tok = self.tok.tokenize(title.replace('.XXXX', ' [MASK]').replace('XXXX', '[MASK]'))
        if '[MASK]' not in ti_tok:
            raise ValueError("title must contain the 'XXXX' placeholder marking the blank to fill")
        ti_tok.insert(0, '[SEP]')

        combined = [tokens + ti_tok for tokens in sentences]
        mask_indices = [c.index('[MASK]') for c in combined]
        combined_inp = [torch.LongTensor(self.tok.convert_tokens_to_ids(c)).unsqueeze(dim=0) for c in combined]

        if use_cuda:
            combined_inp = [e.cuda(gpu_device) for e in combined_inp]

        return combined_inp, mask_indices, entity_indices, entity_texts

    def _score_sentences(self, abstract, title, ignore_big):
        """Shared scoring path of forward() and predict().

        Returns:
            ``None`` when there is nothing to score (no sentences, or
            ``ignore_big`` is set and a sentence/title pair exceeds 512
            tokens). Otherwise ``(out, entity_texts)``: for each sentence that
            has at least one entity mention, a ``(mentions, 1)`` score tensor
            and the matching list of mention texts.
        """
        combined_input, mask_indices, entity_indices, entity_texts = self.fix_input(abstract, title)
        if not combined_input:
            return None
        max_len = max(e.shape[-1] for e in combined_input)
        if ignore_big and max_len > 512:
            return None
        # Right-pad every (1, length) row with id 0 and stack into (sentences, max_len).
        # The rows already live on the right device (see fix_input).
        batch = torch.cat(
            [F.pad(e, [0, max_len - e.shape[-1], 0, 0], 'constant', 0) for e in combined_input], dim=0)
        # Presumably the hidden states of the last encoder layer -- confirm
        # against the pytorch_pretrained_bert BertModel return value.
        bert_out = self.bert(batch)[0][-1]
        out = []
        for i, ent_idx in enumerate(entity_indices):
            if not ent_idx:
                continue
            ent_vecs = bert_out[i, ent_idx, :]
            mask_vec = bert_out[i, mask_indices[i], :].expand_as(ent_vecs)
            features = torch.cat([ent_vecs, mask_vec], dim=-1)
            out.append(self.linear2(F.relu(self.linear(features))))
        return out, [texts for texts in entity_texts if texts]

    @staticmethod
    def _sum_scores(entity_texts, out):
        """Sum the raw (pre-sigmoid) scores of every mention per entity text."""
        preds = dict()
        for names, scores in zip(entity_texts, out):
            for name, score in zip(names, scores.detach().cpu().numpy().tolist()):
                preds[name] = preds.get(name, 0) + score[0]
        return preds

    def forward(self, abstract, title, entity_list, answer, ignore_big=True):
        """Score the candidate entities, keeping gradients.

        Args:
            abstract, title: see fix_input().
            entity_list: candidate entity labels to return scores for.
            answer: unused here; kept for interface compatibility.
            ignore_big: skip inputs longer than 512 tokens instead of failing.

        Returns:
            ``(scores, keys, preds)``: stacked sigmoid scores of the candidates
            that occur in the abstract, their labels, and the dict of summed
            raw scores. ``(None, None, None)`` when the input is skipped, so
            callers can always unpack three values.
        """
        scored = self._score_sentences(abstract, title, ignore_big)
        if scored is None:
            return None, None, None
        out, entity_texts = scored
        preds = self._sum_scores(entity_texts, out)
        entity_outs = dict()
        for ent in entity_list:
            hits = [out[r_i][c_i]
                    for r_i, row in enumerate(entity_texts)
                    for c_i, name in enumerate(row)
                    if name == ent]
            if hits:
                entity_outs[ent] = torch.sigmoid(torch.sum(torch.cat(hits)))
        return torch.stack(list(entity_outs.values())), entity_outs.keys(), preds

    def predict(self, abstract, title, entity_list, ignore_big=True):
        """Return ``{entity_text: summed raw score}`` without tracking gradients.

        ``entity_list`` is accepted for interface compatibility and is not
        used. Returns ``None`` when the input is skipped (see
        _score_sentences()).
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            scored = self._score_sentences(abstract, title, ignore_big)
        if scored is None:
            return None
        out, entity_texts = scored
        return self._sum_scores(entity_texts, out)

In [5]:
# Restore training metadata from the checkpoint when one exists, then build the model.
resume_from = './biomrc/model/scibertreadersum_best_checkpoint.pth.tar' #'./bert_model_results/scibertreadersum_checkpoint.pth.tar'
resumed = False
if not os.path.exists(resume_from):
    print('No checkpoint to load!')
    start_epoch = 0
    best_dev_acc = -1
    best_epoch = -1
    early_stop_counter = 0
    frozen_t = True
else:
    # Tensors are mapped to the CPU on load.
    checkpoint = torch.load(resume_from, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    best_dev_acc = checkpoint['best_acc']
    best_epoch = checkpoint['best_epoch']
    frozen_t = checkpoint['frozen_top']
    early_stop_counter = checkpoint['early_stop']
    print("=> loaded checkpoint '{}' (epoch {})".format(resume_from, checkpoint['epoch']))
    # Show every metadata entry, leaving out the two large state blobs.
    print([(key, checkpoint[key]) for key in checkpoint.keys() if key not in ('state_dict', 'optimizer')])
    resumed = True


model = SciBertReaderSum(frozen_t)

if use_cuda:
    model.cuda(gpu_device)

cross_entropy = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print_params()

=> loaded checkpoint './biomrc/model/scibertreadersum_best_checkpoint.pth.tar' (epoch 0)
[('epoch', 0), ('best_acc', 0.56368), ('best_epoch', 0), ('frozen_top', True), ('early_stop', 0)]
SciBertReaderSum(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOut

In [6]:
#max_epochs = 40

# Restore the weights and optimizer state saved in the checkpoint, if one was loaded.
if resumed:
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Unfreeze top bert layer if we resumed with unfrozen top
    if not model.frozen_top:
        model.unfreeze_top(optimizer)
        print('BERT Top Layer is Unfrozen')

# This notebook only runs inference. Without eval() the BERT dropout layers stay
# active, so the same question could yield different predictions on every run.
model.eval()

In [7]:
import scispacy
import spacy
from spacy import displacy

# NER pipeline used by preprocessing_cloze_test to find the entities that get replaced by @entityN labels.
nlp = spacy.load("en_ner_bc5cdr_md")

##### PREPROCESSING #####

In [13]:
# PREPROCESSING BEFORE RUNNING THE MODEL
def preprocessing_cloze_test(context, question):
    """Anonymise the entities of a context/question pair for the cloze model.

    Entities found by the notebook-level ``nlp`` pipeline are grouped
    case-insensitively, numbered in order of first appearance and replaced by
    ``@entityN`` labels in both texts.

    Args:
        context: passage to search for the answer.
        question: cloze question containing the ``XXXX`` placeholder.

    Returns:
        (context, question, entity_list, entities_context):
        the lowercased context with labels substituted; the question with
        labels substituted (its casing is otherwise kept); the list of all
        labels; and one dict per entity with ``label``, ``text`` (lowercased),
        ``entity`` (NER type) and, for entities seen in the context,
        ``positions`` (character spans in the original context).
    """
    import re  # local import: the notebook's import cells do not provide `re`

    placeholder = 'xxxx'

    def _substitute(text, items):
        # One case-insensitive pass over whole mentions only. Longest
        # alternatives come first so a short entity never rewrites part of a
        # longer one, and inserted labels are never scanned again.
        mapping = {}
        for item in items:
            key = item['text'].lower()
            if key:
                mapping.setdefault(key, item['label'])
        if not mapping:
            return text
        alternatives = '|'.join(re.escape(key) for key in sorted(mapping, key=len, reverse=True))
        pattern = re.compile(r'(?<!\w)(?:' + alternatives + r')(?!\w)', re.IGNORECASE)
        return pattern.sub(lambda m: mapping.get(m.group(0).lower(), m.group(0)), text)

    context_ner = nlp(context)
    question_ner = nlp(question)

    # Group repeated context entities (case-insensitively) under one label.
    entities_context = []
    index_by_text = {}
    for ent in context_ner.ents:
        key = ent.text.lower()
        span = {'start_char': ent.start_char, 'end_char': ent.end_char}
        if key in index_by_text:
            entities_context[index_by_text[key]]['positions'].append(span)
            continue
        index_by_text[key] = len(entities_context)
        entities_context.append({'label': f"@entity{len(entities_context)}", 'text': key,
                                 'entity': ent.label_, 'positions': [span]})

    context = _substitute(context.lower(), entities_context)

    # Entities that only occur in the question continue the same numbering.
    # Anything overlapping the XXXX placeholder is left alone.
    for ent in question_ner.ents:
        key = ent.text.lower()
        if key in index_by_text or placeholder in key:
            continue
        index_by_text[key] = len(entities_context)
        entities_context.append({'label': f"@entity{len(entities_context)}", 'text': key,
                                 'entity': ent.label_})

    replaceable = [item for item in entities_context if placeholder not in item['text']]
    question = _substitute(question, replaceable)
    entity_list = [item['label'] for item in entities_context]

    return context, question, entity_list, entities_context

##### PREDICTION #####

In [17]:
#METHODS
def get_predict(context, question, entities_list, entities_context):
    """Run the cloze model and map its best-scoring label back to the entity text.

    Args:
        context: anonymised context (entities written as ``@entityN``).
        question: anonymised question containing the ``XXXX`` placeholder.
        entities_list: candidate labels, forwarded to ``model.predict``.
        entities_context: entity dicts with at least ``label`` and ``text``.

    Returns:
        (best_prediction, value_entity): the highest-scoring label and the
        text registered for it, or ``None`` as the text when the label is not
        present in ``entities_context``.

    Raises:
        ValueError: if the model produced no prediction (``model.predict``
            returned ``None`` or an empty dict).
    """
    preds = model.predict(context, question, entities_list)
    if not preds:
        raise ValueError("the cloze model returned no prediction for this input "
                         "(input skipped by the model or no entity found in the context)")
    best_prediction = max(preds, key=preds.get) #Maximum value
    value_entity = next(
        (item['text'] for item in entities_context if item['label'] == best_prediction),
        None)
    return best_prediction, value_entity

# PREDICTION
def cloze_test(context, question):#, entity_list
    """Anonymise the inputs, run the cloze model and return (label, entity text)."""
    prepared = preprocessing_cloze_test(context, question)
    anon_context, anon_question, labels, entity_items = prepared
    return get_predict(anon_context, anon_question, labels, entity_items)

#### INPUTS - MODEL OPTIONS DEPENDING ON THE QUESTION ####

In [31]:
# Choose which example question is used; the model itself is selected further
# down by whether the question contains 'XXXX'.
option = 1 # Span Extraction
# option = 2 selects the cloze test. IMPORTANT: the 'XXXX' placeholder must be
# separated from the surrounding text on both sides; according to the original
# author a form such as 'XXXX,' causes an error.

if option == 1:
    question = "Has the 76-year-old female patient with a history of mitral regurgitation, ever been on caltrate plus d2 tablets before"
elif option == 2:
    question = "Has the 76-year-old female patient with a history of XXXX ever been on caltrate plus d2 tablets before"

context = "The patient is a 76-year-old female with a history of mitral regurgitation, congestive heart failure, recurrent UTIs, and uterine prolapse who presented with chills and hypotension and was admitted to the Medical ICU for treatment of septic shock. Mean arterial pressures were kept above 65 with Levophed and antibiotics were changed to penicillin 3 million units IV q.4h. and gentamicin 50 mg IV q.8h. An ATEE on 10/19 showed severe mitral regurgitation with posterior leaflet calcifications and linear density concerning for endocarditis, for which a PICC line was placed on 1/19 for a six-week course of penicillin 3 million units IV q.4h. and two-week course of gentamicin 50 mg IV q.8h. until 2/25. The patient was initially treated with Levophed for her hypotension until 11/0, and was placed on Levofloxacin and Vancomycin to treat Gram-positive cocci bacteremia and UTI. She was maintained on telemetry and was found to be a normal sinus rhythm with ectopy, including short once of nonsustained ventricular tachycardia. She was started on Lopressor 12.5 mg t.i.d. on 3/18, and this was increased to 25 mg b.i.d. at discharge, with her heart rates continuing to be between the 70s and the 90s, however, with less episodes of ectopy. Aspirin was given, and Lipitor was initially held for an initial transaminitis presumed to be secondary to shock liver. She had guaiac positive stools in the medical ICU, her hematocrit was stable around 33%, and her iron studies suggested anemia of chronic disease with possibly overlying iron deficiency. She had a normal random cortisol level of 35.3, and her Hemoglobin A1c was 6.5, so she was maintained thereafter only on insulin sliding scale and rarely required any coverage. The patient was kept on Lovenox and Protonix and her DISCHARGE MEDICATIONS include Aspirin 81 mg daily, iron sulfate 325 mg daily, gentamicin sulfate 50 mg IV q.8h. until 2/25 for a two-week course, penicillin G potassium 3 million units IV q.4h. 
until 0/12 for a six-week course, Lopressor 25 mg b.i.d., Caltrate plus D2 tablets p.o. daily, Lipitor 10 mg daily, and Protonix 40 mg daily. She was discharged to rehabilitation at Acanmingpeerra Virg Tantblu Medical Center in order to be able to get her antibiotic therapy, and her physicians will attempt to add the ACE back onto her medical regimen for better afterload reduction as her blood pressure tolerates, and potentially they will add her back on to the Lasix as well. She will require weekly lab draws to check her electrolytes and CBC while she is on the antibiotics."

# Position of the cloze placeholder in the question, or -1 when there is none.
cloze_test_flag = question.find("XXXX")

if cloze_test_flag == -1:
    # No placeholder: answer with the span-extraction pipeline.
    print("SPAN EXTRACTION\n")
    result = qa_pipeline(question=question, context=context)
else:
    # Placeholder present: answer with the cloze model.
    print("CLOZE TEST\n")
    result = cloze_test(context, question)

print(result)

SPAN EXTRACTION

{'score': 0.14743152260780334, 'start': 2029, 'end': 2065, 'answer': 'Caltrate plus D2 tablets p.o. daily,'}


#### CONTEXT AND ANSWER CLASSIFICATION USING NER ####

In [32]:
# General-purpose English pipeline.
nlp1 = spacy.load("en_core_web_sm")
# "en_ner_bc5cdr_md" is already loaded as `nlp` for the cloze preprocessing;
# reuse that pipeline instead of loading the same model a second time.
nlp2 = nlp
# Med7 pipeline; its NER labels drive the colour map built in the next cell.
med7 = spacy.load("en_core_med7_lg")

In [33]:
# create distinct colours for labels
col_dict = {}
seven_colours = ['#e6194B', '#3cb44b', '#ffe119', '#ffd8b1', '#f58231', '#f032e6', '#42d4f4']
for label, colour in zip(med7.pipe_labels['ner'], seven_colours):
    col_dict[label] = colour

options = {'ents': med7.pipe_labels['ner'], 'colors':col_dict}

##### NER IN CONTEXT #####

In [34]:
# Run each NER pipeline over the full context; the results are rendered in the last cell.
context1 = nlp1(context)  # en_core_web_sm
context2 = nlp2(context)  # en_ner_bc5cdr_md
context3 = med7(context)  # en_core_med7_lg

##### QUESTION AND ANSWER #####

In [35]:
print(f"Question: {question}")
# Branch on what the dispatch cell actually produced rather than on `option`:
# the span-extraction pipeline returns a dict, cloze_test returns (label, text).
# `option` can disagree with the model that ran when the question is edited by hand.
if isinstance(result, dict):
    print(f"Answer: {result['answer']}")
else:
    print(f"Answer: {result[1]}")

Question: Has the 76-year-old female patient with a history of mitral regurgitation, ever been on caltrate plus d2 tablets before
Answer: Caltrate plus D2 tablets p.o. daily,


##### NER IN ANSWER #####

In [36]:
# Branch on the result type (dict from span extraction, tuple from the cloze
# test) so this cell always agrees with the model that actually ran.
if isinstance(result, dict):
    ans1 = nlp1(result['answer'])
    ans2 = nlp2(result['answer'])
    ans3 = med7(result['answer'])
else:
    # The cloze answer is a single entity text; only the bc5cdr pipeline is applied.
    ans2 = nlp2(result[1])

##### LOOK ANSWER ENTITIES #####

In [37]:
print("Answer entities:")

# Same criterion as the cell that builds ans1/ans2/ans3: a dict result means
# span extraction (three documents), anything else is the cloze tuple (ans2 only).
if isinstance(result, dict):
    displacy.render(ans1, style='ent')
    displacy.render(ans2, style='ent')
    spacy.displacy.render(ans3, style='ent', jupyter=True, options=options)
else:
    displacy.render(ans2, style='ent')

Answer entities:


##### LOOK CONTEXT ENTITIES #####

In [38]:
# Render the entities each pipeline found in the full context.
displacy.render(context1, style='ent')
displacy.render(context2, style='ent')
# The med7 document is drawn with the per-label colours defined in `options`.
spacy.displacy.render(context3, style='ent', jupyter=True, options=options)