In [1]:
import json
import pickle
from collections import Counter

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AdamW,
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    BertConfig,
    BertTokenizer,
    DistilBertForQuestionAnswering,
    DistilBertModel,
    DistilBertTokenizerFast,
    TFBertForSequenceClassification,
)

# Fix NumPy's global RNG so the randomly drawn interface examples are repeatable.
np.random.seed(1)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

# Load Data

In [2]:
def load_qa_samples(path):
    """Read a processed QA parquet file and add a `context_question` column.

    `context_question` is the space-joined tokens of `sub_context` followed by
    the tokens of `question`; it is the key used to look up a
    (context, question) pair in the frame.
    """
    qa_frame = pd.read_parquet(path)
    qa_frame['context_question'] = qa_frame.apply(
        lambda row: ' '.join(list(row.sub_context) + list(row.question)),
        axis=1,
    )
    return qa_frame


medication_samples = load_qa_samples('processed_data/medication_qa.parquet')
relations_samples = load_qa_samples('processed_data/relations_qa.parquet')

# Dataset name -> samples frame.
dataset_to_samples = {
    'Medication': medication_samples,
    'Relations': relations_samples,
}

In [3]:
# Checkpoint names of the two tokenizers used by the models below.
DISTILBERT_CHECKPOINT = 'distilbert-base-uncased'
BIOBERT_CHECKPOINT = "dmis-lab/biobert-v1.1"

tokenizer = DistilBertTokenizerFast.from_pretrained(DISTILBERT_CHECKPOINT)
tokenizer_bio = AutoTokenizer.from_pretrained(BIOBERT_CHECKPOINT)

In [4]:
class DistilBERTEncoder(torch.nn.Module):
    """Wrapper around the pretrained `distilbert-base-uncased` encoder.

    Args:
        frozen: when True, every encoder parameter is excluded from gradient
            computation and the encoder is switched to eval mode.
    """

    def __init__(self, frozen=True):
        super(DistilBERTEncoder, self).__init__()
        self.encoder = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states = True)
        self.encoder.to(device)
        if frozen:
            # `requires_grad_` flips the flag on every parameter. Assigning
            # `module.requires_grad = False` merely creates an attribute on the
            # module object and leaves all weights trainable.
            self.encoder.requires_grad_(False)
            self.encoder.eval()

    def forward(self, input_ids, attention_mask):
        """Return the encoder's last hidden state for every token.

        The result is `output.last_hidden_state`, i.e. one vector per input
        token: [batch, seq_len, hidden_size].
        """
        output = self.encoder(input_ids, attention_mask = attention_mask)
        return output.last_hidden_state
    
    
class SimpleReader(torch.nn.Module):
    """Binary classifier on top of a trainable `DistilBERTEncoder`.

    The embedding of the first token is passed through a linear layer and a
    sigmoid, so the output is a value in (0, 1) per example, not a raw logit.

    Args:
        in_features: size of the encoder's token embeddings.
        out_features: number of outputs of the linear head.
    """

    def __init__(self, in_features=768, out_features=1):
        super(SimpleReader, self).__init__()
        self.encoder = DistilBERTEncoder(frozen=False)
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return the sigmoid score of the first token: [batch, out_features]."""
        embeddings = self.encoder(input_ids, attention_mask = attention_mask)  # [batch, seq_len, in_features]
        # Integer indexing already drops the sequence axis, so no squeeze is
        # needed; squeezing axis 1 would wrongly collapse the feature axis
        # whenever in_features == 1.
        first_token = embeddings[:, 0, :]  # [batch, in_features]
        return self.sigmoid(self.linear(first_token))

In [5]:
# Checkpoints already loaded in this cell, keyed by path, so a file shared by
# several registry entries is read and placed on the device only once.
_loaded_models = {}


def load_model(path):
    """Load a pickled model from `path` onto `device`, in eval mode, once per path.

    `map_location=device` lets checkpoints saved on a GPU load on a CPU-only
    machine. `.eval()` turns off training-only behaviour such as dropout,
    which would otherwise stay active if the module was pickled in train mode.

    NOTE: `torch.load` unpickles arbitrary Python objects, so only load
    checkpoints from a trusted source.
    NOTE(review): these files appear to hold whole pickled modules; recent
    PyTorch releases may need `weights_only=False` to load them - confirm
    against the installed version.
    """
    if path not in _loaded_models:
        _loaded_models[path] = torch.load(path, map_location=device).to(device).eval()
    return _loaded_models[path]


# Each entry is (tokenizer, [models], threshold).
# 'EHReader' holds two models: the SimpleReader classifier first, then the
# span-extraction model; the other entries hold only a span-extraction model.
medication_models = {
    'EHReader': (
        tokenizer,
        [
            load_model('models/SimpleReader/m_1_uf_e_2_vl_0.1676.model'),
            load_model('models/BERT/m_1_uf_e_6_vl_0.4139.model'),
        ],
        -0.55
    ),
    'distilBERT': (
        tokenizer,
        [
            load_model('models/BERT/m_1_uf_e_6_vl_0.4139.model'),
        ],
        -.63
    ),
    'bioBERT': (
        tokenizer_bio,
        [
            load_model('models/bioBERT/m_1_uf_e_8_vl_0.7559.model'),
        ],
        # NOTE(review): the only positive threshold in either registry - confirm it is not a sign typo.
        .51
    ),
}

relations_models = {
    'EHReader': (
        tokenizer,
        [
            load_model('models/SimpleReader/r_1_uf_e_2_vl_0.1141.model'),
            load_model('models/BERT/r_1_uf_e_7_vl_0.1851.model'),
        ],
        -0.43
    ),
    'distilBERT': (
        tokenizer,
        [
            load_model('models/BERT/r_1_uf_e_7_vl_0.1851.model'),
        ],
        -.65
    ),
    'bioBERT': (
        tokenizer_bio,
        [
            load_model('models/bioBERT/r_1_uf_e_2_vl_0.4504.model'),
        ],
        -0.59
    ),
}

# Dataset name -> model registry for that dataset.
dataset_to_models = {
    'Medication': medication_models,
    'Relations': relations_models
}

# Backend

In [6]:
def find_truth(samples, context, question):
    """Look up the gold answer for a (context, question) pair.

    Args:
        samples: frame with `context_question`, `char_start` and `char_end` columns.
        context: context text, as shown to the model.
        question: question text.

    Returns:
        'not in dataset' when no row matches `context + ' ' + question`;
        otherwise the slice of `context` given by the first matching row's
        character offsets, or "null" when that row's span is empty
        (`char_start >= char_end`).
    """
    lookup_key = context + ' ' + question
    matches = samples[samples.context_question == lookup_key]

    if len(matches) == 0:
        return 'not in dataset'

    first_match = matches.iloc[0]
    span_start, span_end = first_match.char_start, first_match.char_end
    return context[span_start: span_end] if span_start < span_end else "null"

        
def compute_tav_fast(start_logit, end_logit):
    """Score "no answer" against the best answer span in a single pass.

    Both logit vectors are softmaxed over the token axis. Position 0 is the
    "no answer" slot: its start and end probabilities sum to the null score.
    For every end position >= 1 the loop keeps the most probable start seen so
    far (positions 1..end), which yields the span with start <= end that
    maximises `start_prob + end_prob`.

    Args:
        start_logit: 1-D tensor of start logits for one example.
        end_logit: 1-D tensor of end logits for one example.

    Returns:
        (s_diff, best_span): `s_diff` is null score minus best span score as a
        Python float (lower means an answer is more likely); `best_span` is
        the `(start, end)` token pair, `(0, 0)` when no position >= 1 exists.
    """
    p_start = torch.softmax(start_logit, dim=0)
    p_end = torch.softmax(end_logit, dim=0)

    null_score = p_start[0] + p_end[0]
    top_score = 0
    top_span = (0, 0)
    argmax_start = 1  # most probable start among positions 1..end_idx

    for end_idx in range(1, len(start_logit)):
        if p_start[end_idx] > p_start[argmax_start]:
            argmax_start = end_idx
        candidate = p_start[argmax_start] + p_end[end_idx]
        if candidate > top_score:
            top_score = candidate
            top_span = (argmax_start, end_idx)

    return (null_score - top_score).item(), top_span


def compute_s_ext(sr_logit):
    """Map a single-element score tensor `p` to `1 - 2*p` as a Python float.

    With `p` in [0, 1] the result lies in [-1, 1]: `p = 1` gives -1 and
    `p = 0` gives 1.
    """
    score = sr_logit.item()
    return 1 - 2*score


def compute_rv(sr_logit, start_logit, end_logit, weight=.5):
    """Blend the classifier score and the span score into one verification score.

    Args:
        sr_logit: single-element tensor passed to `compute_s_ext`.
        start_logit: start logits passed to `compute_tav_fast`.
        end_logit: end logits passed to `compute_tav_fast`.
        weight: weight of the span-based score; the classifier-based score
            gets `1 - weight`.

    Returns:
        (rv, best_span): the weighted score and the best `(start, end)` span
        reported by `compute_tav_fast`.
    """
    external_score = compute_s_ext(sr_logit)
    span_gap, best_span = compute_tav_fast(start_logit, end_logit)
    return weight * span_gap + (1-weight) * external_score, best_span


def _span_to_answer(encodings, context, best_span):
    """Turn a predicted token span into the matching substring of `context`.

    The +1 shift on both token indices is kept from the original code.
    NOTE(review): presumably the span positions the models were trained on are
    offset by one from the tokenizer's positions - confirm against training.
    """
    start_chars = encodings.token_to_chars(best_span[0]+1)
    end_chars = encodings.token_to_chars(best_span[1]+1)
    # Tokens without a character span (special or padding tokens) can map to
    # None; such a span cannot be cut out of the context.
    if start_chars is None or end_chars is None:
        return "null"
    return context[start_chars.start: end_chars.end]


def model_question_answer(tokenizer, model_name, model, context, question, threshold):
    """Answer `question` from `context` with one registry entry.

    Args:
        tokenizer: tokenizer matching the models in `model`.
        model_name: 'EHReader' selects the two-model path; any other name uses
            the single span model.
        model: list of models; `[classifier, span_model]` for 'EHReader',
            `[span_model]` otherwise.
        context: text the answer is extracted from.
        question: question text.
        threshold: the answer is kept only when the score is below this value.

    Returns:
        The answer substring of `context`, or "null" when the score is not
        below `threshold`, the span is not strictly increasing, or the span
        cannot be mapped back to characters.
    """
    encodings = tokenizer.batch_encode_plus(
          [[context, question]],
          return_tensors='pt',
          add_special_tokens = True,
          return_token_type_ids=True,
          padding='max_length',
          max_length=512,
          return_attention_mask = True,
          truncation='longest_first',
        )
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        if model_name == 'EHReader':
            sr_out = model[0](input_ids=input_ids, attention_mask=attention_mask)
            dr_out = model[1](input_ids=input_ids, attention_mask=attention_mask)
            score, best_span = compute_rv(sr_out[0], dr_out.start_logits[0], dr_out.end_logits[0])
        # TAV Models
        else:
            output = model[0](input_ids=input_ids, attention_mask=attention_mask)
            score, best_span = compute_tav_fast(output.start_logits[0], output.end_logits[0])

    if score < threshold and best_span[0] < best_span[1]:
        return _span_to_answer(encodings, context, best_span)
    return "null"


def get_sample(index, dataset):
    """Return `[context, question, dataset]` for the row at position `index`.

    `index` is positional (`.iloc`) within `dataset_to_samples[dataset]`; the
    row's `sub_context` and `question` token sequences are joined with spaces.
    """
    row = dataset_to_samples[dataset].iloc[index]
    return [' '.join(row.sub_context), ' '.join(row.question), dataset]

In [7]:
# Medication samples flagged as answerable (answerability == 1).
med_is_answerable = medication_samples['answerability'] == 1
pos_med = medication_samples[med_is_answerable]
pos_med.head(50)

Unnamed: 0,note_id,note_question_id,sub_context,question,answerability,token_start,token_end,char_start,char_end,context_question
0,91996,0,"[RECORD, #91996, 150823816, |, DMC, |, 6095698...","[Was, the, patient, ever, prescribed, aspirin]",1,94,106,548,613,RECORD #91996 150823816 | DMC | 60956989 | | 9...
7,91996,1,"[RECORD, #91996, 150823816, |, DMC, |, 6095698...","[Is, there, history, of, use, of, fioricet.]",1,94,106,548,613,RECORD #91996 150823816 | DMC | 60956989 | | 9...
14,91996,2,"[RECORD, #91996, 150823816, |, DMC, |, 6095698...","[has, there, been, a, prior, tylenol]",1,94,106,548,613,RECORD #91996 150823816 | DMC | 60956989 | | 9...
21,91996,3,"[RECORD, #91996, 150823816, |, DMC, |, 6095698...","[Is, there, a, mention, of, of, valproic, acid...",1,196,208,1191,1254,RECORD #91996 150823816 | DMC | 60956989 | | 9...
29,91996,4,"[documented, recent, seizures;, history, of, a...","[Has, the, patient, ever, tried, medicines]",1,10,19,66,132,documented recent seizures; history of asthma ...
37,91996,5,"[of, chronic, headaches, for, which, she, was,...","[has, the, patient, had, fioricet]",1,8,20,50,115,of chronic headaches for which she was maintai...
44,91996,6,"[of, chronic, headaches, for, which, she, was,...","[Has, the, pt., ever, been, on, demerol, before]",1,52,62,307,374,of chronic headaches for which she was maintai...
51,91996,7,"[of, chronic, headaches, for, which, she, was,...","[Was, the, patient, ever, prescribed, percocet]",1,62,72,375,438,of chronic headaches for which she was maintai...
58,91996,8,"[of, chronic, headaches, for, which, she, was,...","[Has, this, patient, ever, been, prescribed, t...",1,62,72,375,438,of chronic headaches for which she was maintai...
65,91996,9,"[of, chronic, headaches, for, which, she, was,...","[has, the, patient, had, hydrochlorothiazide]",1,191,200,1141,1207,of chronic headaches for which she was maintai...


In [8]:
# Relations samples flagged as answerable (answerability == 1).
rel_is_answerable = relations_samples['answerability'] == 1
pos_rel = relations_samples[rel_is_answerable]
pos_rel.head(50)

Unnamed: 0,note_id,note_question_id,sub_context,question,answerability,token_start,token_end,char_start,char_end,context_question
4,329,0,"[2014-01-20, 10:46, PM, BLOOD, Glucose, -, 114...","[Has, the, patient, had, any, rll, pneumonia, ...",1,51,80,262,408,2014-01-20 10:46 PM BLOOD Glucose - 114 * Lact...
7,329,1,"[However, ,, when, asked, if, she, thinks, he,...","[Has, the, patient, had, any, mild, background...",1,71,79,377,426,"However , when asked if she thinks he would wa..."
17,329,2,"[His, family, was, notified, that, he, was, do...","[Has, the, patient, had, any, early, sepsis, a...",1,65,75,334,386,His family was notified that he was doing wors...
20,329,3,"[Physical, Exam, :, Vitals, :, T, 102.8, HR, 1...","[Has, the, patient, had, any, accessory, muscl...",1,18,42,79,212,Physical Exam : Vitals : T 102.8 HR 113 BP 157...
25,329,4,"[However, ,, when, asked, if, she, thinks, he,...","[When, was, the, patient, evaluated, for, chf,...",1,91,96,465,485,"However , when asked if she thinks he would wa..."
31,329,5,"[However, ,, when, asked, if, she, thinks, he,...","[When, was, the, patient, evaluated, for, deme...",1,113,122,573,618,"However , when asked if she thinks he would wa..."
37,329,6,"[However, ,, when, asked, if, she, thinks, he,...","[has, the, patient, had, lvef, for, chf]",1,91,96,465,485,"However , when asked if she thinks he would wa..."
43,329,7,"[However, ,, when, asked, if, she, thinks, he,...","[Did, the, patient, ever, have, baseline, mmse...",1,113,122,573,618,"However , when asked if she thinks he would wa..."
49,329,8,"[However, ,, when, asked, if, she, thinks, he,...","[What, work, up, has, been, done, for, the, pa...",1,91,96,465,485,"However , when asked if she thinks he would wa..."
55,329,9,"[However, ,, when, asked, if, she, thinks, he,...","[What, work, up, has, been, done, for, the, pa...",1,113,122,573,618,"However , when asked if she thinks he would wa..."


In [9]:
# Smoke test: run one model on one sample and compare with the gold answer.
index = 101
dataset = 'Relations'
model_name = 'EHReader'

samples = dataset_to_samples[dataset]
sub_context, question = get_sample(index, dataset)[:2]
truth = find_truth(samples, sub_context, question)

# Unpack into demo-specific names: reusing `tokenizer` here would overwrite the
# module-level DistilBERT tokenizer that the model registries were built with.
demo_tokenizer, demo_models, demo_threshold = dataset_to_models[dataset][model_name]
answer = model_question_answer(demo_tokenizer, model_name, demo_models, sub_context, question, demo_threshold)

print("Context:\n{}".format(sub_context))
print("\nQuestion:\n{}".format(question))
print("\nPredicted Answer:\n{}".format(answer))
print("\nGolden Truth:\n{}".format(truth))

Context:
His family was notified that he was doing worse and came in to visit him . It was his family 's wish that he be made comfortable . Based upon this wish the BIPAP was removed and he was placed on nasal cannula . He was written for a morphine drip however he expired before this could be initiated . Secondary issues Elevated lactate : This was concerning for early sepsis given RLL PNA . His blood pressures remained elevated and he was hydrated with NS conservatively , given h/o CHF and dyspnea . DM2 : He was covered with a Insulin drip while unstable CV : CAD : He was not taking any PO medications per Hunt Center Rehab so all medications were held . Medications on Admission : Meds : - Paxil - Terazosin - Glucophage - Glyburide - Atenolol - Imdur Discharge Disposition : Expired Discharge Diagnosis : Pneumonia Discharge Condition : expired Discharge Instructions : N/A Followup Instructions : N/A Daniel Monica MD 55-804 Completed by : Sheri John MD 30-456 2014-01-21 @ 0659 Signed el

# Interface

In [10]:
def interface_main(context, question, dataset):
    """Gradio callback: gold answer plus the answer of each registered model.

    Args:
        context: context text entered in the interface.
        question: question text entered in the interface.
        dataset: key into `dataset_to_samples` / `dataset_to_models`.

    Returns:
        `(truth, answer_1, answer_2, answer_3)`, where the answers follow the
        insertion order of the dataset's model registry.
    """
    truth = find_truth(dataset_to_samples[dataset], context, question)

    answers = [
        model_question_answer(model_tokenizer, name, model_list, context, question, cutoff)
        for name, (model_tokenizer, model_list, cutoff) in dataset_to_models[dataset].items()
    ]

    return truth, answers[0], answers[1], answers[2]

In [11]:
inputs = [
    gr.inputs.Textbox(lines=15, label="Context"),
    gr.inputs.Textbox(lines=3, label="Question"),
    gr.inputs.Radio(['Medication', 'Relations'], label="Dataset"),
]

# Output order must match the tuple returned by `interface_main`:
# gold truth first, then one answer per model in registry order.
outputs = [
    gr.outputs.Textbox(label="Golden Truth"),
    gr.outputs.Textbox(label="Quick Reader-DistilBERT(+RV) Answer"),
    gr.outputs.Textbox(label="DistilBERT(+TAV) Answer"),
    gr.outputs.Textbox(label="bioBERT(+TAV) Answer"),

]


def answerable_positions(samples):
    """Row positions (usable with `.iloc`) of the samples with answerability == 1."""
    return np.flatnonzero((samples['answerability'] == 1).to_numpy())


def sample_with_next(position, dataset):
    """Return the sample at `position` and the one right after it.

    The follow-up position is clamped to the last row so that picking the
    final row of a dataset cannot index past the end of the frame.
    """
    last_position = len(dataset_to_samples[dataset]) - 1
    return [
        get_sample(position, dataset),
        get_sample(min(position + 1, last_position), dataset),
    ]


examples = []
num_examples = 100
# `get_sample` indexes with `.iloc`, so draw row positions rather than index
# labels; the two only coincide when the frame has a default RangeIndex.
med_index = np.random.choice(answerable_positions(medication_samples), num_examples)
rel_index = np.random.choice(answerable_positions(relations_samples), num_examples)

for i in range(num_examples):
    examples.extend(sample_with_next(med_index[i], 'Medication'))
    examples.extend(sample_with_next(rel_index[i], 'Relations'))

interface = gr.Interface(
    interface_main,
    inputs,
    outputs,
    title="Question Answering for Electronic Health Records",
    examples=examples
)

# NOTE(review): share=True publishes the app, including the example clinical
# notes, on a public URL - confirm this exposure is intended.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7860/
Running on public URL: https://56315.gradio.app

This share link will expire in 72 hours. To get longer links, send an email to: support@gradio.app


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7860/',
 'https://56315.gradio.app')