In [1]:
import pickle
import shelve
import marisa_trie
import pandas as pd
from datetime import datetime
import re
import json
import numpy as np

### Load rankers for labs, meds, symptoms, and conditions

In [2]:
from med_lab_autocomplete_utils import LabAutocomplete, MedAutocomplete

In [3]:
from symptom_utils.symptom_autocomplete_utils import SymptomAutocomplete
from symptom_utils.symptom_autocomplete_inferior_models import SymptomAutocompleteChiefComplaint, SymptomAutocompleteLR, SymptomAutocompleteNB

In [4]:
# Instantiate the lab and medication rankers (classes imported from med_lab_autocomplete_utils).
# Both are used by run_all_autocompletes via get_spell_ranking() / get_frequency_ranking().
lab_autocomplete = LabAutocomplete()
medication_autocomplete = MedAutocomplete()

In [5]:
# Symptom rankers: the main model plus the three baselines imported from
# symptom_autocomplete_inferior_models that it is compared against.
symptom_autocomplete = SymptomAutocomplete()
symptom_cc = SymptomAutocompleteChiefComplaint()
symptom_lr = SymptomAutocompleteLR()
symptom_nb = SymptomAutocompleteNB()



### Load data
* ED data contains triage information and notes written in the emergency department, per patient (shuffled in a random order).
* OMR/EHR data refers to all prior clinical notes in a patient's record (keyed by PatientID) 

In [6]:
print('Loading ED and EHR (prior medical record) data...')
# Open the pickle in binary mode and close the handle deterministically
# (text mode fails under Python 3 and the bare open() leaked the descriptor).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('/data/BIDMC/ed_data/visits_full.pkl', 'rb') as visits_file:
    ed_visits_pkl = pickle.load(visits_file)
# The shelf stays open for the whole session; it is indexed by PatientID in the cells below.
omr_data = shelve.open('/data/BIDMC/ed_data/omr/omrShelfPatient_py3_jclinic')

Loading ED and EHR (prior medical record) data...


In [7]:
print('Load Chief Complaints per ED ')
# Later cells index this object by visit position: ed_chief_complaints[visit_ix][-1].
with open('/data/BIDMC/jclinic/extracted_data/ed_chief_complaints.pkl', 'rb') as complaints_file:
    ed_chief_complaints = pickle.load(complaints_file)

Load Chief Complaints per ED 


### Open ontologies and create a trie-based data structure to find clinical concepts

In [8]:
# Read both ontology tables, then turn each 'synonyms' cell back into a Python
# object with pd.eval (presumably the CSV stores the list as its string repr).
hpi_ontology = pd.read_csv('ontologies/hpi_autocomplete_ontology.csv', index_col=0)
symptom_ontology = pd.read_csv('ontologies/symptom_autocomplete_ontology.csv')
for ontology_frame in (hpi_ontology, symptom_ontology):
    ontology_frame['synonyms'] = ontology_frame['synonyms'].apply(pd.eval)

# Medication and lab vocabularies are the 'freq' entries of their JSON files.
with open('ontologies/medication_ontology.json', 'r') as med_file:
    med_list = json.load(med_file)['freq']
with open('ontologies/lab_ontology.json', 'r') as lab_file:
    lab_list = json.load(lab_file)['freq']

In [9]:
# Every surface form (synonym string) that the search trie should recognise.
trie_keys = set() # set of synonyms to look for in the trie
# Reverse lookup used once a surface form has been matched in free text.
term_lookup = {} # mapping from each synonym to (A, b) where A is its type and b is its index in its respective ontology

In [10]:
# Register every searchable surface form in trie_keys / term_lookup.
# Order matters: a string registered by a later loop overwrites the
# term_lookup entry written by an earlier one (e.g. a lab name that is also a
# medication name ends up typed as 'LAB').
for i, med in enumerate(med_list):
    trie_keys.add(med)
    term_lookup[med] = ('MEDICATION', i)

# The synonym lists are paired with their 'ignore' flag *positionally*. The
# stored index must be positional because all_synonyms is built with
# list(hpi_ontology['synonyms']); a label-based .loc[i] lookup only agrees with
# that when the CSV index happens to be 0..n-1 (hpi_ontology is read with
# index_col=0), and it also materialises a full row per iteration.
for i, (syns, ignore) in enumerate(zip(hpi_ontology['synonyms'], hpi_ontology['ignore'])):
    if ignore:
        continue
    for s in syns:
        if len(s) < 3: # ignore short synonyms that might be ambiguous such as "MI" or "AS"
            continue
        trie_keys.add(s)
        term_lookup[s] = ('DISEASE', i)

# Symptoms: same positional pairing; no minimum-length filter is applied here.
for i, (syns, ignore) in enumerate(zip(symptom_ontology['synonyms'], symptom_ontology['ignore'])):
    if ignore:
        continue
    for s in syns:
        trie_keys.add(s)
        term_lookup[s] = ('SYMPTOM', i)

for i, lab in enumerate(lab_list):
    trie_keys.add(lab)
    term_lookup[lab] = ('LAB', i)

In [11]:
# Build the prefix trie over every registered surface form; find_concepts
# queries it with search_trie.prefixes(text) to get all keys that start the text.
search_trie = marisa_trie.Trie(trie_keys)

In [12]:
# Per concept type, a list indexed by ontology position whose entries are the
# lists of surface forms for that term. Labs and medications have exactly one
# surface form each, so they are wrapped in single-element lists.
all_synonyms = dict(
    DISEASE=list(hpi_ontology['synonyms']),
    SYMPTOM=list(symptom_ontology['synonyms']),
    LAB=[[lab_name] for lab_name in lab_list],
    MEDICATION=[[med_name] for med_name in med_list],
)

### Find clinical concepts retrospectively in text

In [13]:
def find_concepts(md_comments):
    """Tag ontology terms in a free-text note by greedy longest-prefix matching.

    The note is lower-cased and split on whitespace. At each token position
    that has not already been consumed, the remaining text is looked up in
    ``search_trie``; the longest key that prefixes it is accepted if it ends
    at the end of the text or is followed by one of `` ,;:.``.

    Args:
        md_comments: the note text (str).

    Returns:
        dict mapping 'DISEASE', 'SYMPTOM', 'LAB' and 'MEDICATION' to the list
        of matched surface forms, in order of appearance (repeats are kept).
    """
    # Python 2's marisa_trie wants unicode input; on Python 3 str already is.
    try:
        to_text = unicode
    except NameError:
        to_text = str

    concepts = {'DISEASE' : [], 'SYMPTOM' : [], 'LAB' : [], 'MEDICATION' : []}
    blacklisted_toks = set()  # token positions already covered by a match
    tokens = md_comments.lower().split()
    for i in range(len(tokens)):
        if i in blacklisted_toks:
            continue
        future_txt = ' '.join(tokens[i:])
        potentials = search_trie.prefixes(to_text(future_txt))
        if not potentials:
            continue
        # NOTE(review): only the longest candidate is boundary-checked; a
        # shorter key that would pass is not tried when the longest one fails.
        best = max(potentials, key=len)
        if len(future_txt) == len(best) or future_txt[len(best)] in ' ,;:.':
            concept = term_lookup[best][0]
            concepts[concept].append(best)
            # The match starts at token i and covers one token per
            # whitespace-separated word. (The previous re.split used a
            # capturing group, so the separators were counted as well and the
            # tokens right after a multi-word match were skipped.)
            for j in range(len(best.split())):
                blacklisted_toks.add(i + j)
    return concepts

### Measure MRR, MAPK, and keystroke burden 

In [14]:
def update_ranking(suggested, term):
    """Move the first occurrence of ``term`` to the end of ``suggested``.

    The list is modified in place and also returned. Raises ValueError when
    ``term`` is not in the list.
    """
    position = suggested.index(term)
    suggested.append(suggested.pop(position))
    return suggested

def get_mrr(suggested, relevant):
    """Mean reciprocal-rank style score of ``relevant`` terms within ``suggested``.

    Each relevant term is scored in turn against a working copy of the
    ranking. A term at 0-based position ``p`` scores 1.0 when
    ``p < len(relevant)`` and ``1 / (p - len(relevant) + 2)`` otherwise. After
    scoring, the term is moved to the end of the working copy
    (``update_ranking``) before the next term is scored.

    Args:
        suggested: ranked list of term ids; not modified.
        relevant: term ids that were actually used, in order.

    Returns:
        None if either argument is None, NaN if ``relevant`` is empty,
        otherwise the mean score.

    Raises:
        ValueError: if a relevant term is missing from ``suggested``.
    """
    if suggested is None or relevant is None:
        return None
    if len(relevant) == 0:
        # np.mean([]) also yields NaN but emits a RuntimeWarning per call.
        return np.nan
    curr_suggested = suggested[:]
    n_relevant = len(relevant)
    mrrs = []
    for b in relevant:
        position = curr_suggested.index(b)
        if position >= n_relevant:
            mrrs.append(1.0 / float(position - n_relevant + 2))
        else:
            mrrs.append(1.0)
        curr_suggested = update_ranking(curr_suggested, b)
    return np.mean(mrrs)

def get_mapk(suggested, relevant):
    """Mean precision at the rank of each relevant term.

    For every relevant term, take the ranking up to and including that term
    and compute the fraction of those entries that are themselves relevant;
    the result is the mean of those fractions.

    Args:
        suggested: ranked list of term ids; not modified.
        relevant: term ids that were actually used.

    Returns:
        None if either argument is None (same convention as ``get_mrr``), NaN
        if ``relevant`` is empty, otherwise the mean precision.

    Raises:
        ValueError: if a relevant term is missing from ``suggested``.
    """
    if suggested is None or relevant is None:
        return None
    if len(relevant) == 0:
        # np.mean([]) also yields NaN but emits a RuntimeWarning per call.
        return np.nan
    mapks = []
    for r in relevant:
        cutoff = suggested.index(r) + 1
        hits = sum(1 for candidate in suggested[:cutoff] if candidate in relevant)
        mapks.append(hits / float(cutoff))
    return np.mean(mapks)

def get_num_keystrokes_singlequery(suggested, relevant_term, relevant_text, concept_type, k=3):
    """Characters that must be typed before ``relevant_term`` is among the top ``k``.

    For each typed prefix of ``relevant_text`` (starting with the empty
    prefix), the ranking is filtered to terms having at least one synonym that
    *contains* the prefix as a substring. The first prefix length at which
    ``relevant_term`` is among the first ``k`` survivors is returned.

    Args:
        suggested: ranked list of term ids (indices into
            ``all_synonyms[concept_type]``).
        relevant_term: id of the term the user wants.
        relevant_text: surface form the user would type.
        concept_type: key into ``all_synonyms``.
        k: number of suggestions shown at once (non-negative).

    Returns:
        An int in ``[0, len(relevant_text)]``; ``len(relevant_text)`` means
        the term never surfaced and had to be typed in full.
    """
    synonyms_by_term = all_synonyms[concept_type]

    def word_in_set(word, syns):
        return any(word in s for s in syns)

    for i in range(len(relevant_text)):
        word = relevant_text[:i]
        shown = 0
        # Only the first k matching suggestions are visible, so stop scanning
        # the ranking once k of them have been seen.
        for s in suggested:
            if shown >= k:
                break
            if word_in_set(word, synonyms_by_term[s]):
                if s == relevant_term:
                    return i
                shown += 1
    return len(relevant_text)

def get_num_keystrokes(suggested, relevant, relevant_text, concept_type):
    """Mean keystroke cost over all relevant terms of one note.

    Terms are processed in order; after each one it is moved to the end of a
    working copy of the ranking (``update_ranking``) before the next term is
    evaluated.

    Args:
        suggested: ranked list of term ids; not modified.
        relevant: term ids that were used, aligned with ``relevant_text``.
        relevant_text: surface form typed for each relevant term.
        concept_type: key into ``all_synonyms``.

    Returns:
        Mean of ``get_num_keystrokes_singlequery`` over the terms, or NaN when
        there are no terms.
    """
    if len(relevant) == 0 or len(relevant_text) == 0:
        # zip() would produce nothing; avoid np.mean([]) and its RuntimeWarning.
        return np.nan
    num_keystrokes = []
    curr_suggested = suggested[:]
    for r, r_text in zip(relevant, relevant_text):
        num_keystrokes.append(get_num_keystrokes_singlequery(curr_suggested, r, r_text, concept_type))
        curr_suggested = update_ranking(curr_suggested, r)
    return np.mean(num_keystrokes)

def get_metrics(ranking, ground_truth, ground_truth_text, concept_type):
    """Return ``(mrr, map_k, num_keystrokes)`` for one ranking against one note.

    ``ground_truth`` holds the ids of the terms found in the note and
    ``ground_truth_text`` the matching surface forms; ``concept_type`` selects
    the synonym table used for the keystroke computation.
    """
    return (
        get_mrr(ranking, ground_truth),
        get_mapk(ranking, ground_truth),
        get_num_keystrokes(ranking, ground_truth, ground_truth_text, concept_type),
    )

In [15]:
np.random.seed(0) # fix random seed so train/test split is preserved 
# Draw 25000 visit positions from range(len(ed_visits_pkl)).
# NOTE(review): np.random.choice samples WITH replacement by default, so the
# draw can contain repeated indices. Do not switch to replace=False without
# regenerating everything that relies on this exact draw (the split would change).
test_indices = np.random.choice(len(ed_visits_pkl), 25000)

### Get metrics for labs, meds

In [16]:
def run_all_autocompletes(visit_ix):
    """Evaluate the lab and medication rankers on one ED visit.

    Args:
        visit_ix: position of the visit in ``ed_visits_pkl``.

    Returns:
        ``{'LAB': {...}, 'MEDICATION': {...}}`` where each inner dict maps
        'regular', 'spell', 'freq' and 'contextual' to an
        ``(mrr, map_k, keystrokes)`` tuple. Every value is None when the visit
        has no MD comment. 'regular' (no autocomplete) only carries the mean
        length of the tagged surface forms; 'contextual' is a tuple of Nones
        because these rankers are only asked for spell and frequency rankings.
    """
    res = {
        'LAB' : {'regular': None, 'spell' : None, 'freq' : None, 'contextual' : None},
        'MEDICATION' : {'regular': None, 'spell' : None, 'freq' : None, 'contextual' : None},
    } # initialize results dict with Nones
    visit = ed_visits_pkl[visit_ix] # get visit level information
    md_comments = visit['MDcomments'][0] # get MD comment, or ED note
    if not md_comments:
        return res
    # Date, patient id, triage fields and chief complaint are not needed: the
    # lab/medication rankers here use no per-visit context.
    concepts = find_concepts(md_comments)

    for concept_type, ranker in [('LAB', lab_autocomplete), ('MEDICATION', medication_autocomplete)]:
        ground_truth_text = concepts.get(concept_type)
        ground_truth = [term_lookup[text][1] for text in ground_truth_text]
        # for no autocomplete, spell based autocomplete, freq-based autocomplete, contextual autocomplete...
        # find (MRR, MAPK, keystroke burden)
        # NaN (without numpy's empty-mean RuntimeWarning) when nothing was tagged.
        typed_in_full = np.mean([len(text) for text in ground_truth_text]) if ground_truth_text else np.nan
        res[concept_type]['regular'] = None, None, typed_in_full
        res[concept_type]['spell'] = get_metrics(ranker.get_spell_ranking(), ground_truth, ground_truth_text, concept_type)
        res[concept_type]['freq'] = get_metrics(ranker.get_frequency_ranking(), ground_truth, ground_truth_text, concept_type)
        res[concept_type]['contextual'] = None, None, None
    return res

In [19]:
# Evaluate the lab/medication rankers on every sampled test visit (one result dict per visit).
lab_med_res = [run_all_autocompletes(i) for i in test_indices]

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


In [20]:
# Flatten the per-visit results into one column per
# (concept type, ranking strategy, metric) combination.
lab_med_df = {}
for dt in ['LAB', 'MEDICATION']:
    for rank_type in ['regular', 'spell', 'freq', 'contextual']:
        for metric_i, metric in enumerate(['mrr', 'map', 'keystrokes']):
            col_header = '{}_{}_{}'.format(dt, rank_type, metric)
            # Visits with no result tuple (e.g. no MD comment) contribute None.
            lab_med_df[col_header] = [
                r[dt][rank_type][metric_i] if r[dt][rank_type] else None
                for r in lab_med_res
            ]

In [21]:
import numpy as np
import scipy.stats as st


def conf_interval(a):
    """Return the 95% Student-t confidence interval ``(low, high)`` for the mean of ``a``.

    Uses ``len(a) - 1`` degrees of freedom, centred on ``np.mean(a)`` and
    scaled by the standard error ``st.sem(a)``.
    """
    degrees_of_freedom = len(a) - 1
    centre = np.mean(a)
    standard_error = st.sem(a)
    return st.t.interval(0.95, degrees_of_freedom, loc=centre, scale=standard_error)

In [22]:
# Turn the dict of equal-length column lists built above into a DataFrame.
# Note that this rebinds the name lab_med_df from a dict to a DataFrame.
lab_med_df = pd.DataFrame(lab_med_df)

In [876]:
# lab_med_df.to_csv('labmed_only_analysis.csv')

### Test symptom autocompletes against each other 

In [17]:
def run_symptom_autocompletes(visit_ix):
    """Evaluate the four symptom rankers on one ED visit.

    Args:
        visit_ix: position of the visit in ``ed_visits_pkl`` (and in
            ``ed_chief_complaints``).

    Returns:
        Dict keyed by 'SYMPTOM_CC', 'SYMPTOM_CCVIT', 'SYMPTOM_LR' and
        'SYMPTOM_NB'. For a visit with an MD comment each inner dict maps
        'regular', 'spell', 'freq' and 'contextual' to an
        ``(mrr, map_k, keystrokes)`` tuple; otherwise the initial dict of
        Nones is returned unchanged.
    """
    res = {
        'SYMPTOM_CC' : {'spell' : None, 'freq' : None, 'contextual' : None},
        'SYMPTOM_CCVIT' : {'spell' : None, 'freq' : None, 'contextual' : None},
        'SYMPTOM_LR' : {'regular': None, 'spell' : None, 'freq' : None, 'contextual' : None},
        'SYMPTOM_NB' : {'regular': None, 'spell' : None, 'freq' : None, 'contextual' : None},
    } # compare symptom models against each other, in the order presented in MLHC paper
    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    if not md_comments:
        return res
    # Triage context handed to every ranker's get_ranking().
    vitals = {
        'Age' : visit['Age'][0],
        'Temp' : visit['TriageTemp'][0],
        'RR' : visit['TriageRR'][0],
        'Pulse' : visit['TriageHR'][0],
        'O2Sat' : visit['TriageSaO2'][0],
        'BP' : visit['TriageBP'][0],
        'Sex' : visit['Sex'][0],
        'Acuity' : visit['TriageAcuity'][0][0] if visit['TriageAcuity'][0] else None
    }
    chief_complaint = list(ed_chief_complaints[visit_ix][-1])
    chief_complaint = chief_complaint[0] if chief_complaint else None

    # The ground truth is the same for every ranker, so compute it once.
    ground_truth_text = find_concepts(md_comments).get('SYMPTOM')
    ground_truth = [term_lookup[text][1] for text in ground_truth_text]
    # NaN (without numpy's empty-mean RuntimeWarning) when nothing was tagged.
    typed_in_full = np.mean([len(text) for text in ground_truth_text]) if ground_truth_text else np.nan

    for concept_type, ranker in [('SYMPTOM_CCVIT', symptom_autocomplete), ('SYMPTOM_CC', symptom_cc), ('SYMPTOM_LR', symptom_lr), ('SYMPTOM_NB', symptom_nb)]:
        res[concept_type]['regular'] = None, None, typed_in_full
        res[concept_type]['spell'] = get_metrics(ranker.get_spell_ranking(), ground_truth, ground_truth_text, 'SYMPTOM')
        res[concept_type]['freq'] = get_metrics(ranker.get_frequency_ranking(), ground_truth, ground_truth_text, 'SYMPTOM')
        res[concept_type]['contextual'] = get_metrics(ranker.get_ranking(chief_complaint, vitals), ground_truth, ground_truth_text, 'SYMPTOM')
    return res

In [18]:
# Evaluate the symptom rankers on the first 5000 sampled test visits.
# (The captured output below shows this run was interrupted with KeyboardInterrupt.)
symptom_results = [run_symptom_autocompletes(i) for i in test_indices[:5000]]

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


KeyboardInterrupt: 

In [None]:
# Flatten the per-visit symptom results into one column per
# (model, ranking strategy, metric) combination.
symptom_df = {}
for dt in ['SYMPTOM_CC', 'SYMPTOM_CCVIT', 'SYMPTOM_LR', 'SYMPTOM_NB']:
    for rank_type in ['regular', 'spell', 'freq', 'contextual']:
        for metric_i, metric in enumerate(['mrr', 'map', 'keystrokes']):
            col_header = '{}_{}_{}'.format(dt, rank_type, metric)
            # None when the model, the strategy or its result tuple is missing.
            symptom_df[col_header] = [
                r[dt][rank_type][metric_i]
                if (r.get(dt) and rank_type in r[dt] and r[dt][rank_type])
                else None
                for r in symptom_results
            ]

In [None]:
# Turn the dict of equal-length column lists built above into a DataFrame.
# Note that this rebinds the name symptom_df from a dict to a DataFrame.
symptom_df = pd.DataFrame(symptom_df)

In [875]:
# symptom_df.to_csv('symptom_only_analysis.csv')

### Test HPI autocomplete models against each other

In [16]:
def omr_date_parser(x):
    """Parse the 'time' field of an OMR note dict ('%Y-%m-%d %H:%M:%S') into a datetime."""
    return datetime.strptime(x['time'], '%Y-%m-%d %H:%M:%S')

def run_hpi_autocompletes(visit_ix):
    """Evaluate the three HPI (disease/condition) rankers on one ED visit.

    Args:
        visit_ix: position of the visit in ``ed_visits_pkl``.

    Returns:
        Dict keyed by 'HPI_NN', 'HPI_LR' and 'HPI_LR_AUG'; each inner dict
        maps 'regular', 'spell', 'freq' and 'contextual' to an
        ``(mrr, map_k, keystrokes)`` tuple or None. Only 'HPI_NN' gets the
        non-contextual baselines; all three get 'contextual'. Every value is
        None when the visit has no MD comment.
    """
    res = {
        'HPI_NN' : {'regular' : None, 'spell' : None, 'freq' : None, 'contextual' : None},
        'HPI_LR' : {'regular' : None, 'spell' : None, 'freq' : None, 'contextual' : None},
        'HPI_LR_AUG' : {'regular': None, 'spell' : None, 'freq' : None, 'contextual' : None},
    }
    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if not md_comments:
        return res
    triage_assessment = visit['TriageAssessment'][0]

    # The ground truth is the same for every ranker, so compute it once.
    ground_truth_text = find_concepts(md_comments).get('DISEASE')
    ground_truth = [term_lookup[text][1] for text in ground_truth_text]

    omr_buckets, omr_terms = [], []
    last_mentioned = {b : (visit_date - datetime(1, 1, 1)).days for b in range(hpi_autocomplete_lr_aug_model.max_bucket)}
    # for each model relevance bucket, initialize the days since it was last mentioned as infinite
    if len(ground_truth_text) > 0 and pid in omr_data: # if there is OMR data...
        patient_notes = omr_data[pid]  # read the shelf entry once
        # Find the first note dated on/after the visit. When there is none,
        # every note is prior history and all of them are kept (previously
        # this case fell through to an empty list and dropped all history).
        cutoff = len(patient_notes)
        reference_date = visit_date
        for i, note in enumerate(patient_notes):
            note_time = omr_date_parser(note)
            if note_time >= visit_date:
                # NOTE(review): the reference date stays the boundary note's
                # timestamp as before; count_omr passes visit_date instead —
                # confirm which one get_omr_buckets expects.
                cutoff, reference_date = i, note_time
                break
        omr_notes = patient_notes[:cutoff] # filter OMR notes for those that occurred prior to ED visit date
        for o in omr_notes:
            o['time'] = omr_date_parser(o)
        omr_buckets, omr_terms, last_mentioned = hpi_autocomplete_lr_aug_model.get_omr_buckets(omr_notes, reference_date)

    for concept_type, ranker in [('HPI_NN', hpi_autocomplete), ('HPI_LR', hpi_autocomplete_lr_model), ('HPI_LR_AUG', hpi_autocomplete_lr_aug_model)]:
        if concept_type == 'HPI_NN':
            # NaN (without numpy's empty-mean RuntimeWarning) when nothing was tagged.
            typed_in_full = np.mean([len(text) for text in ground_truth_text]) if ground_truth_text else np.nan
            res[concept_type]['regular'] = None, None, typed_in_full
            res[concept_type]['spell'] = get_metrics(ranker.get_spell_ranking(), ground_truth, ground_truth_text, 'DISEASE')
            res[concept_type]['freq'] = get_metrics(ranker.get_frequency_ranking(), ground_truth, ground_truth_text, 'DISEASE')
        res[concept_type]['contextual'] = get_metrics(ranker.get_ranking(triage_assessment, list(omr_buckets), omr_terms, last_mentioned), ground_truth, ground_truth_text, 'DISEASE')
    return res

In [17]:
def has_omr(visit_ix):
    """Return True if the visit's patient has at least one OMR note dated before the visit.

    Args:
        visit_ix: position of the visit in ``ed_visits_pkl``.

    Returns:
        False when the patient is absent from ``omr_data`` or has no note
        strictly earlier than the visit date, True otherwise.
    """
    visit = ed_visits_pkl[visit_ix]
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if pid not in omr_data:
        return False
    for note in omr_data[pid]:
        note_time = note['time']
        # run_hpi_autocompletes treats the stored 'time' as a string needing
        # omr_date_parser; accept both that and an already-parsed datetime so
        # the comparison below is always datetime vs datetime.
        if not isinstance(note_time, datetime):
            note_time = omr_date_parser(note)
        if note_time < visit_date:
            return True
    return False

In [18]:
import hpi_utils.hpi_autocomplete_utils
import hpi_utils.hpi_autocomplete_lr
import hpi_utils.hpi_autocomplete_lr_aug

In [19]:
# Instantiate the three HPI rankers that run_hpi_autocompletes compares
# (reported there as 'HPI_NN', 'HPI_LR' and 'HPI_LR_AUG' respectively).
# These globals must exist before run_hpi_autocompletes / count_omr are called.
hpi_autocomplete = hpi_utils.hpi_autocomplete_utils.HPIAutocomplete()
hpi_autocomplete_lr_model = hpi_utils.hpi_autocomplete_lr.HPIAutocomplete_LR()
hpi_autocomplete_lr_aug_model = hpi_utils.hpi_autocomplete_lr_aug.HPIAutocomplete_LRAug()



In [None]:
# Evaluate the HPI rankers on the first 2000 sampled test visits.
# An explicit append loop (rather than a comprehension) means the results
# gathered so far remain in hpi_results if the cell is interrupted.
hpi_results = []
for i in test_indices[:2000]:
    hpi_results.append(run_hpi_autocompletes(i))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


In [98]:
# Flatten the per-visit HPI results into one column per
# (model, ranking strategy, metric) combination.
hpi_df = {}
for dt in ['HPI_NN', 'HPI_LR', 'HPI_LR_AUG']:
    for rank_type in ['regular', 'spell', 'freq', 'contextual']:
        for metric_i, metric in enumerate(['mrr', 'map', 'keystrokes']):
            col_header = '{}_{}_{}'.format(dt, rank_type, metric)
            # None when the model, the strategy or its result tuple is missing.
            hpi_df[col_header] = [
                r[dt][rank_type][metric_i]
                if (r.get(dt) and rank_type in r[dt] and r[dt][rank_type])
                else None
                for r in hpi_results
            ]

In [99]:
# Turn the dict of equal-length column lists built above into a DataFrame.
# Note that this rebinds the name hpi_df from a dict to a DataFrame.
hpi_df = pd.DataFrame(hpi_df)

In [None]:
def count_omr(visit_ix):
    """Count prior-record terms and tagged concepts for one ED visit.

    Only the first line of the MD comment is tagged.

    Args:
        visit_ix: position of the visit in ``ed_visits_pkl``.

    Returns:
        ``(num_omr_terms, num_disease, num_symptom, num_lab, num_medication)``,
        or a tuple of five Nones when the visit has no MD comment.
        ``num_omr_terms`` is 0 unless at least one disease was tagged and the
        patient has an entry in ``omr_data``.
    """
    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if not md_comments:
        return None, None, None, None, None
    hpi = md_comments.split('\n')[0].lower()
    concepts = find_concepts(hpi)

    def note_time(note):
        # run_hpi_autocompletes treats the stored 'time' as a string needing
        # omr_date_parser; accept that as well as an already-parsed datetime.
        stamp = note['time']
        return stamp if isinstance(stamp, datetime) else omr_date_parser(note)

    omr_terms = []
    if len(concepts.get('DISEASE')) > 0 and pid in omr_data:
        patient_notes = omr_data[pid]  # read the shelf entry once
        # Index of the first note dated on/after the visit. When there is
        # none, every note is prior history and all of them are kept
        # (previously this case fell through to an empty list).
        cutoff = len(patient_notes)
        for i, note in enumerate(patient_notes):
            if note_time(note) >= visit_date:
                cutoff = i
                break
        omr_notes = patient_notes[:cutoff]
        for o in omr_notes:
            o['time'] = note_time(o)  # same normalisation as run_hpi_autocompletes
        _, omr_terms, _ = hpi_autocomplete_lr_aug_model.get_omr_buckets(omr_notes, visit_date)
    return len(omr_terms), len(concepts.get('DISEASE')), len(concepts.get('SYMPTOM')), len(concepts.get('LAB')), len(concepts.get('MEDICATION'))

In [987]:
# NOTE(review): count_df_data is not defined anywhere in the visible notebook,
# so this cell fails on a fresh kernel. Presumably it was a list of count_omr(i)
# tuples over some set of visit indices (the five columns match count_omr's
# return value) — confirm and restore the defining cell.
count_df = pd.DataFrame(data=count_df_data, columns=['num_omr_terms', 'num_tagged_disease', 'num_tagged_symp', 'num_tagged_labs', 'num_tagged_meds'])

In [994]:
# Drop rows with missing counts, then append the total number of tagged
# concepts across the four concept types.
count_df = count_df.dropna().assign(
    total_tagged=lambda frame: (
        frame['num_tagged_disease']
        + frame['num_tagged_symp']
        + frame['num_tagged_labs']
        + frame['num_tagged_meds']
    )
)

In [997]:
# Keep only the rows (visits) for which at least one OMR term was found.
t = count_df[count_df['num_omr_terms'] > 0]