In [None]:
import pickle
import shelve
import marisa_trie
import pandas as pd
from datetime import datetime
import json
import re
import numpy as np 

### Load rankers for labs, meds, symptoms, and conditions

In [None]:
import hpi_utils
reload(hpi_utils)

In [None]:
from med_lab_autocomplete_utils import LabAutocomplete, MedAutocomplete
from symptom_utils.symptom_autocomplete_utils import SymptomAutocomplete
from hpi_utils.hpi_autocomplete_utils import HPIAutocomplete

In [None]:
# Instantiate one ranker per concept type (labs, medications, symptoms, conditions).
lab_autocomplete = LabAutocomplete()
medication_autocomplete = MedAutocomplete()
symptom_autocomplete = SymptomAutocomplete()
hpi_autocomplete = HPIAutocomplete()

### Load data
* ED data contains triage information and notes written in the emergency department, per patient (shuffled in a random order).
* OMR/EHR data refers to all prior clinical notes in a patient's record (keyed by PatientID) 

In [None]:
# Load the ED visits, the per-patient OMR shelf, and the extracted chief complaints.
# NOTE(review): pickle and shelve unpickle arbitrary objects; only load files from a trusted source.
print('Loading ED and OMR data...')
# Binary mode plus a context manager: the handle is closed deterministically and
# the pickle is read the same way as the chief-complaints file below.
with open('/data/BIDMC/ed_data/visits_full.pkl', 'rb') as h:
    ed_visits_pkl = pickle.load(h)
# The shelf stays open for the rest of the notebook; it is queried per patient later on.
omr_data = shelve.open('/data/BIDMC/ed_data/omr/omrShelfPatient_py3_jclinic')
with open('/data/BIDMC/jclinic/extracted_data/ed_chief_complaints.pkl', 'rb') as h:
    ed_chief_complaints = pickle.load(h)

### Open ontologies and create a trie-based data structure to find clinical concepts

In [None]:
# HPI (condition) ontology; the first CSV column is used as the DataFrame index.
hpi_ontology = pd.read_csv('ontologies/hpi_autocomplete_ontology.csv', index_col=0)
# The 'synonyms' column is read as text and evaluated into Python objects
# (presumably one list of strings per row -- TODO confirm against the CSV).
# NOTE(review): pd.eval evaluates expressions taken from the file; if the cells only ever
# hold literals, ast.literal_eval would be a safer parser.
hpi_ontology['synonyms'] = hpi_ontology['synonyms'].apply(pd.eval)
# Symptom ontology, parsed the same way (read with the default integer index).
symptom_ontology = pd.read_csv('ontologies/symptom_autocomplete_ontology.csv')
symptom_ontology['synonyms'] = symptom_ontology['synonyms'].apply(pd.eval)
# Medication and lab vocabularies are the 'freq' entry of their JSON files.
with open('ontologies/medication_ontology.json', 'r') as h:
    med_list = json.load(h)['freq']
with open('ontologies/lab_ontology.json', 'r') as h:
    lab_list = json.load(h)['freq']

In [None]:
# trie_keys: every surface string that should be findable in text.
# term_lookup: surface string -> (concept type, index within that type's ontology).
trie_keys = set()
term_lookup = {}

In [None]:
# Register every ontology surface form in `trie_keys` and map it to its
# (concept type, positional index) in `term_lookup`.
# Types are processed in the order MEDICATION, DISEASE, SYMPTOM, LAB, so when the same
# string occurs in several ontologies the later type overwrites the earlier entry.

# Disease abbreviations that cannot be disambiguated from ordinary words.
AMBIGUOUS_DISEASE_ABBREVIATIONS = {'as', 'ks', 'uc', 'rd', 'vt', 'di'}

for i, med in enumerate(med_list):
    trie_keys.add(med)
    term_lookup[med] = ('MEDICATION', i)

# Each row's synonyms are paired with that same row's `ignore` flag positionally.
# `i` is a position (it later indexes plain lists of synonyms), so a label-based
# `.loc[i]` lookup would only be correct if the index happened to be 0..n-1.
for i, (syns, ignore) in enumerate(zip(hpi_ontology['synonyms'], hpi_ontology['ignore'])):
    if ignore:
        continue
    for s in syns:
        if s in AMBIGUOUS_DISEASE_ABBREVIATIONS:
            continue
        trie_keys.add(s)
        term_lookup[s] = ('DISEASE', i)

for i, (syns, ignore) in enumerate(zip(symptom_ontology['synonyms'], symptom_ontology['ignore'])):
    if ignore:
        continue
    for s in syns:
        trie_keys.add(s)
        term_lookup[s] = ('SYMPTOM', i)

for i, lab in enumerate(lab_list):
    trie_keys.add(lab)
    term_lookup[lab] = ('LAB', i)

In [None]:
# Trie over all registered surface forms; queried with .prefixes() to find concepts in text.
search_trie = marisa_trie.Trie(trie_keys)

In [None]:
# Synonym lists per concept type, addressed as all_synonyms[type][index].
# Labs and medications have a single surface form each, wrapped in a one-element list.
all_synonyms = dict(
    DISEASE=list(hpi_ontology['synonyms']),
    SYMPTOM=list(symptom_ontology['synonyms']),
    LAB=[[lab] for lab in lab_list],
    MEDICATION=[[med] for med in med_list],
)

### Find clinical concepts retrospectively in text

In [None]:
def find_concepts(tokens):
    """
    Locate ontology concepts in a tokenized text.

    tokens: list of non-empty strings (words, whitespace and punctuation) whose
        concatenation is the original text.

    Returns a dict mapping the index of the token where a concept starts to
    (concept_type, matched_text). At each position the longest trie entry that
    prefixes the remaining text is considered; it is accepted only when it ends
    at the end of the text or is followed by one of ' ,;:.'. Tokens covered by
    an accepted match are not considered as start positions.
    """
    concepts = {}
    covered_positions = set()
    for i in range(len(tokens)):
        if i in covered_positions:
            continue
        future_txt = ''.join(tokens[i:])
        potentials = search_trie.prefixes(unicode(future_txt))
        if not potentials:
            continue
        best = max(potentials, key=len)
        ends_cleanly = len(future_txt) == len(best) or future_txt[len(best)] in ' ,;:.'
        if not ends_cleanly:
            continue
        concepts[i] = (term_lookup[best][0], best)
        # Count the tokens the match spans. Empty pieces produced by adjacent
        # delimiters are dropped because `tokens` contains no empty strings;
        # counting them would mark tokens after the match as covered.
        pieces = [p for p in re.split(r'(\s|[\.!\?,;])', best) if p != '']
        covered_positions.update(range(i, i + len(pieces)))
    return concepts

### Measure MRR, MAPK, and keystroke burden 

In [None]:
def update_ranking(suggested, term):
    """
    Move the first occurrence of `term` to the end of `suggested`.

    The list is modified in place and also returned. Raises ValueError when
    `term` is not present.
    """
    position = suggested.index(term)
    del suggested[position]
    suggested.append(term)
    return suggested

def get_num_keystrokes_singlequery(suggested_ranking, relevant_term, relevant_text, k=1):
    """
    Number of characters of `relevant_text` that must be typed before
    `relevant_term` shows up among the first `k` filtered suggestions.

    suggested_ranking: ordered list of (TERM_TYPE, index) terms.
    relevant_term: the (TERM_TYPE, index) term being typed.
    relevant_text: the text the user types for that term.

    After each typed prefix (starting with the empty prefix) the ranking is
    filtered to terms having at least one synonym that contains the prefix as a
    substring. Returns the length of the first prefix for which the term is in
    the top `k`, or len(relevant_text) if that never happens.
    """
    def has_matching_synonym(typed, candidate):
        kind, position = candidate
        return any(typed in synonym for synonym in all_synonyms[kind][position])

    for n_typed in range(len(relevant_text)):
        typed = relevant_text[:n_typed]
        surviving = [cand for cand in suggested_ranking if has_matching_synonym(typed, cand)]
        if relevant_term in surviving[:k]:
            return n_typed
    return len(relevant_text)

def reciprocal_rank(suggested_ranking, relevant_term, num_relevant):
    """
    Reciprocal rank of `relevant_term`, discounting the first `num_relevant` slots.

    Terms whose 0-based position is below `num_relevant` score 1.0; otherwise the
    score is 1 / (position - num_relevant + 2). Raises ValueError when the term
    is not in the ranking.
    """
    position = suggested_ranking.index(relevant_term)
    if position < num_relevant:
        return 1.0
    return 1.0 / float(position - num_relevant + 2)

In [None]:
import numpy as np 
np.random.seed(0) # fix random seed so train/test split is preserved 
# Draw 25000 visit indices to evaluate on.
# NOTE(review): np.random.choice samples WITH replacement by default, so the same visit can
# appear more than once; replace=False would give unique visits but changes the existing split.
test_indices = np.random.choice(len(ed_visits_pkl), 25000)

##  Logic to detect autocomplete scope and type 

In [None]:
# Trigger phrases and the concept type each one opens an autocomplete scope for.
# Keys must be unique: a repeated key in a dict literal silently keeps only the last value.
scope_trigger_types = { 
    'p/w' : 'SYMPTOM',
    'presents with' : 'SYMPTOM',
    'presented with' : 'SYMPTOM',
    'presenting with' : 'SYMPTOM',
    'presents w/' : 'SYMPTOM',
    'presented w/' : 'SYMPTOM',
    'presenting w/' : 'SYMPTOM',
    'came in with' : 'SYMPTOM',
    'c/o' : 'SYMPTOM',
    'complains of' : 'SYMPTOM',
    'complained of' : 'SYMPTOM',
    'complaining of' : 'SYMPTOM',
    's/p' : 'SYMPTOM',
    'status post' : 'SYMPTOM',
    'h/o' : 'DISEASE',
    'hx of' : 'DISEASE',
    'pmh' : 'DISEASE',
    'history of' : 'DISEASE',
    'on' : 'MEDICATION',
    'had' : 'SYMPTOM',
    'had r' : 'SYMPTOM',
    'had l' : 'SYMPTOM',
    'had right' : 'SYMPTOM',
    'had left' : 'SYMPTOM',
    'but no' : 'SYMPTOM',
    'onset of' : 'SYMPTOM',
    'describes' : 'SYMPTOM',
    'describes having' : 'SYMPTOM',
    'denies' : 'SYMPTOM',
    'notes' : 'SYMPTOM',
    'diagnosed with' : 'DISEASE',
    # 'has' was previously listed twice (SYMPTOM, then DISEASE); DISEASE was the value
    # actually in effect, so it is the one kept.
    'has' : 'DISEASE',
    'felt like' : 'SYMPTOM', 
    'takes' : 'MEDICATION',
    'treated with' : 'MEDICATION',
    'with' : 'SYMPTOM'
}

# Integer code -> concept type name, in code order 0..3.
term_trigger_types = dict(enumerate(['DISEASE', 'SYMPTOM', 'MEDICATION', 'LAB']))

# Trie over the trigger phrases, queried with .prefixes() against the upcoming text.
scope_trigger_trie = marisa_trie.Trie(scope_trigger_types.keys())

# continuation tokens do not affect scope: seeing one of them neither opens nor closes
# an autocomplete scope
scope_continuation_tokens = ['and', 'any', 'or', 'no', 'of', 'with', 'but', ',', '"', 'abd', 'l', 'r', 'left', 'right']

def predict_autocomplete_scope(tokens, is_structured):
    r"""
    Greedily determine which concept type (if any) should be autocompleted
    after the given tokens.

    tokens: list of tokens including whitespace and punctuation, e.g. the
        non-empty pieces of re.split('(\s|[\.!\?,;])', text).
    is_structured: list aligned with `tokens`. Each entry is -1 when the token
        is free text, otherwise the integer code (a key of term_trigger_types)
        of the concept the token belongs to. Iteration stops at the shorter of
        the two lists.

    Tokens are scanned left to right and each one does one of:
    1. Start a scope: a trigger phrase begins at this token, or the token
       belongs to a structured concept.
    2. Continue the current scope: the token is part of a trigger phrase, a
       continuation token, or whitespace. Tokens containing '(' are skipped.
    3. Stop the scope: any other token.

    Returns the concept type name of the scope active after the last token, or
    None when no scope is active.
    """
    autocomplete_type = None
    trigger_positions = set()
    for i, (token, structured_code) in enumerate(zip(tokens, is_structured)):
        if '(' in token:
            continue
        # NOTE(review): triggers are matched as raw prefixes of the remaining text with no
        # word-boundary check, so e.g. 'on' also fires on a word that merely starts with "on".
        triggers = scope_trigger_trie.prefixes(unicode(''.join(tokens[i:])))
        if triggers:
            trigger = max(triggers, key=len)
            autocomplete_type = scope_trigger_types[trigger]
            # Remember every token position the trigger phrase spans so that its
            # later words do not close the scope it just opened.
            for j in range(len(re.split(r'(\s|[\.!\?,;])', trigger))):
                trigger_positions.add(i + j)
            continue
        if structured_code != -1:
            autocomplete_type = term_trigger_types[structured_code]
            continue
        keep_scope = (
            i in trigger_positions
            or token in scope_continuation_tokens
            or bool(re.match(r'\s', token))
        )
        if not keep_scope:
            autocomplete_type = None
    return autocomplete_type

In [None]:
def has_omr(visit_ix):
    """
    Report whether the patient of visit `visit_ix` has any OMR note dated
    before the ED visit.

    Returns None when the visit has no MD comments, True when at least one
    note precedes the visit date, and False otherwise (including patients
    absent from `omr_data`).
    """
    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if not md_comments:
        return None
    if pid not in omr_data:
        return False
    for note in omr_data[pid]:
        note_time = note['time']
        # Note times are stored as '%Y-%m-%d %H:%M:%S' text; comparing that text
        # directly with a datetime raises TypeError, so parse it first. Values that
        # are already datetimes are used as they are.
        if not isinstance(note_time, datetime):
            note_time = datetime.strptime(note_time, '%Y-%m-%d %H:%M:%S')
        if note_time < visit_date:
            return True
    return False

In [None]:
import re

def omr_date_parser(x):
    """Parse the 'time' field of an OMR note ('%Y-%m-%d %H:%M:%S') into a datetime."""
    return datetime.strptime(x['time'], '%Y-%m-%d %H:%M:%S')

In [None]:
def run_all_autocompletes(visit_ix):
    """
    Replay the first line of a visit's MD comments through the autocomplete
    pipeline and record per-concept metrics.

    visit_ix: index into `ed_visits_pkl` (and `ed_chief_complaints`).

    Returns (None, None, None, None) when the visit has no MD comments; callers
    tell this apart from a real result by the tuple length. Otherwise returns a
    7-tuple of parallel lists with one entry per evaluated concept:
        regular_metric   -- length of the concept text
        keystroke_metric -- get_num_keystrokes_singlequery(...) with k=3
        scope_misses     -- 1 when no autocomplete scope was predicted, else 0
        type_misses      -- 1 when the predicted type (DISEASE when no scope was
                            predicted) differs from the concept's type, else 0
        typed_text_list  -- the concept text
        in_omr_list      -- whether the concept's index is in the OMR terms
        mrrs             -- reciprocal_rank(...) within the concept's own type
    """
    token_pattern = r'(\s|[\.!\?,;])'

    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if not md_comments:
        return None, None, None, None
    triage_assessment = visit['TriageAssessment'][0]
    vitals = {
        'Age' : visit['Age'][0],
        'Temp' : visit['TriageTemp'][0],
        'RR' : visit['TriageRR'][0],
        'Pulse' : visit['TriageHR'][0],
        'O2Sat' : visit['TriageSaO2'][0],
        'BP' : visit['TriageBP'][0],
        'Sex' : visit['Sex'][0],
        'Acuity' : visit['TriageAcuity'][0][0] if visit['TriageAcuity'][0] else None
    }
    chief_complaint = list(ed_chief_complaints[visit_ix][-1])
    chief_complaint = chief_complaint[0] if chief_complaint else None
    # Only the first line of the MD comments is evaluated, lower-cased.
    hpi = md_comments.split('\n')[0].lower()
    tokens = [t for t in re.split(token_pattern, hpi) if t != '']
    concepts = find_concepts(tokens)

    omr_buckets, omr_terms = [], []
    if pid in omr_data:
        # Keep the notes that come before the first note dated on/after the visit.
        # When no note is dated on/after the visit, every note is prior and all are
        # kept (a 0-valued "no break" sentinel would wrongly discard them).
        prior_notes = []
        for note in omr_data[pid]:
            if omr_date_parser(note) >= visit_date:
                break
            prior_notes.append(note)
        omr_buckets, omr_terms = hpi_autocomplete.get_omr_buckets(prior_notes)

    # can be changed to get_freq_ranking and get_spell_ranking for others
    lab_ranking = lab_autocomplete.get_frequency_ranking()
    med_ranking = medication_autocomplete.get_frequency_ranking()
    symptom_ranking = symptom_autocomplete.get_ranking(chief_complaint, vitals)  # can change to get_freq_ranking() or get_spell_ranking() for freq/spell baselines
    hpi_ranking = hpi_autocomplete.get_ranking(triage_assessment, list(omr_buckets), omr_terms, None) # can change to get_freq_ranking() or get_spell_ranking() for freq and spell baselines
    rankings = [
        [('DISEASE', t) for t in hpi_ranking],
        [('SYMPTOM', t) for t in symptom_ranking],
        [('MEDICATION', t) for t in med_ranking],
        [('LAB', t) for t in lab_ranking]
    ]
    type_map = {
        'DISEASE': 0,
        'SYMPTOM': 1,
        'MEDICATION': 2,
        'LAB': 3
    }
    is_structured = []
    keystroke_metric = []
    regular_metric = []
    mrrs = []
    scope_misses = []
    type_misses = []
    typed_text_list = []
    in_omr_list = []

    i = 0
    # NOTE(review): the bound stops before the final token, so a concept that starts on the
    # very last token is never evaluated -- confirm whether that is intended.
    while i < len(tokens) - 1:
        if i not in concepts:
            is_structured.append(-1)
            i += 1
            continue

        autocomplete_type = predict_autocomplete_scope(tokens[:i], is_structured) # predicted autocomplete type
        if autocomplete_type is None:
            scope_misses.append(1) # no scope was predicted; fall back to DISEASE
            autocomplete_type = 'DISEASE'
        else:
            scope_misses.append(0)

        relevant_text = concepts[i][1]
        relevant_term = term_lookup[relevant_text]
        relevant_term_type = relevant_term[0]
        relevant_type_ix = type_map[relevant_term_type]
        # Number of tokens the concept spans, counted the same way `tokens` was built
        # (empty split pieces dropped) so `i` and `is_structured` stay aligned with
        # `tokens`. At least 1 so the loop always advances.
        span = max(1, len([p for p in re.split(token_pattern, relevant_text) if p != '']))

        type_misses.append(int(relevant_term_type != autocomplete_type))

        # Suggestions of the predicted type come first, followed by the other types in order.
        predicted_type_ix = type_map[autocomplete_type]
        new_rankings = list(rankings[predicted_type_ix])
        for j in range(4):
            if j != predicted_type_ix:
                new_rankings.extend(rankings[j])

        num_keystrokes = get_num_keystrokes_singlequery(new_rankings, relevant_term, relevant_text, k=3)
        mrr = reciprocal_rank(rankings[relevant_type_ix], relevant_term, len(concepts))
        mrrs.append(mrr)
        keystroke_metric.append(num_keystrokes)
        regular_metric.append(len(relevant_text))
        typed_text_list.append(relevant_text)
        # A term that has just been documented is moved to the end of its type's ranking.
        update_ranking(rankings[relevant_type_ix], relevant_term)
        is_structured.extend([relevant_type_ix] * span)
        i += span
        in_omr_list.append(relevant_term[1] in omr_terms)
    return regular_metric, keystroke_metric, scope_misses, type_misses, typed_text_list, in_omr_list, mrrs

In [None]:
# NOTE(review): freq_results is consumed by the aggregation cell below but is only present
# here as a commented-out initialisation; it has to be produced (by rerunning
# run_all_autocompletes with the frequency baseline rankings) before that cell can run.
#freq_results = []
#spell_results = []
# One result tuple per sampled visit, in the order of test_indices.
all_results = []
for j, i in enumerate(test_indices):
    all_results.append(run_all_autocompletes(i))

In [None]:
# Per-concept aggregation comparing the frequency baseline (freq_results, `x`) with the
# contextual run (all_results, `y`). All dicts are keyed by the (concept type, index) tuple
# from term_lookup -- the variable is called `term_type` but holds the whole tuple.
# NOTE(review): freq_results is not defined in the visible cells; on a fresh kernel this
# cell raises NameError.
word_counts = {}
keystrokes_per_concept = {}
keystrokes_per_concept_omr = {}
keystrokes_per_concept_no_prior = {}
keystrokes_per_concept_prior_no_omr = {}
mrr_per_concept_freq = {}
mrr_per_concept_cont = {}
k_burden = []
# zip stops at the shortest input, so only the first 2000 test visits are aggregated.
for x, y, index in zip(freq_results, all_results, test_indices[:2000]):
    omr_present = has_omr(index)
    # Visits without MD comments yield a 4-tuple of None instead of the 7-tuple of lists.
    if len(x) != 7:
        continue
    # Tuple layout: [1] keystroke counts, [-3] typed concept texts, [-2] in-OMR flags, [-1] MRRs.
    # The *_spell names hold values taken from freq_results.
    words_spell = x[-3]
    ks_spell = x[1]
    words_all = y[-3]
    ks_all = y[1]
    for word, k_spell, _, k_all, in_omr, mrr_freq, mrr_cont in zip(words_spell, ks_spell, words_all, ks_all, x[-2], x[-1], y[-1]):
        if word == 'as':
            continue
        term_type = term_lookup[word]
        # Only DISEASE concepts are analysed here.
        if term_type[0] != 'DISEASE':
            continue
        word_counts[term_type] = word_counts.get(term_type, 0) + 1
        # Keystroke savings of the contextual run relative to the baseline.
        keystrokes_per_concept[term_type] = keystrokes_per_concept.get(term_type, []) + [k_spell - k_all]
        mrr_per_concept_freq[term_type] = mrr_per_concept_freq.get(term_type, []) + [mrr_freq]
        mrr_per_concept_cont[term_type] = mrr_per_concept_cont.get(term_type, []) + [mrr_cont]
        if in_omr:
            keystrokes_per_concept_omr[term_type] = keystrokes_per_concept_omr.get(term_type, []) + [k_spell - k_all]
            # k_burden is only collected for concepts found in the patient's OMR.
            k_burden.append(k_all)
        if not omr_present:
            keystrokes_per_concept_no_prior[term_type] = keystrokes_per_concept_no_prior.get(term_type, []) + [k_spell - k_all]
        if omr_present and not in_omr:
            keystrokes_per_concept_prior_no_omr[term_type] = keystrokes_per_concept_prior_no_omr.get(term_type, []) + [k_spell - k_all]

In [None]:
# Mean MRR and error-bar size per case-study condition, for the frequency baseline and the
# contextual ranker.
# NOTE(review): conf_interval is not defined in the visible cells; it is presumably a helper
# returning (lower, upper) bounds -- confirm. This cell also requires mrr_per_concept_* to
# still hold per-occurrence lists, i.e. it must run before they are collapsed to means.
mean_freqs, mean_conts, conf_freqs, conf_conts, cnames = [], [], [], [], []
for c in [2, 19, 23, 54, 153, 159, 223, 226]: # global interpretability case studies
    common_name = hpi_ontology.loc[c]['common_name']
    term_type = ('DISEASE', c)
    mean_freq = np.mean(mrr_per_concept_freq[term_type])
    conf_freq = conf_interval(mrr_per_concept_freq[term_type])
    mean_cont = np.mean(mrr_per_concept_cont[term_type])
    conf_cont = conf_interval(mrr_per_concept_cont[term_type])
    cnames.append(common_name)
    mean_freqs.append(mean_freq)
    mean_conts.append(mean_cont)
    # Error-bar length: distance from the mean to the first element of the interval.
    conf_freqs.append(mean_freq - conf_freq[0])
    conf_conts.append(mean_cont - conf_cont[0])

In [None]:
# Grouped horizontal bars: contextual vs. frequency MRR per case-study condition, with
# error bars, saved to mrr_by_concept.pdf.
# NOTE(review): `plt` is not imported in the visible cells (matplotlib.pyplot is presumably
# expected to be in scope -- confirm).
plt.barh(np.arange(len(mean_conts)), mean_conts, 0.3, xerr = conf_conts, alpha=0.8, label='Contextual', capsize=3, color='purple')
plt.barh(np.arange(len(mean_conts)) + 0.3, mean_freqs, 0.3, xerr = conf_freqs, alpha=0.8, label='Frequency', capsize=3, color='orange')
plt.legend()
plt.xlim(0, 1.4)
plt.yticks(np.arange(len(mean_conts)), cnames);
plt.xlabel('MRR')
# NOTE(review): tight_layout runs before the figure is resized below, so the layout is
# computed for the old size; consider resizing first.
plt.tight_layout()
fig = plt.gcf()
fig.set_size_inches(7, 7)
plt.savefig('mrr_by_concept.pdf')

In [None]:
# Collapse the per-occurrence MRR lists to a single mean per concept.
# NOTE(review): this rebinds the same names from lists to scalars, so earlier cells that
# expect lists (e.g. the case-study statistics) cannot be re-run after this one.
mrr_per_concept_freq = { k : np.mean(v) for k, v in mrr_per_concept_freq.items()}
mrr_per_concept_cont = { k : np.mean(v) for k, v in mrr_per_concept_cont.items()}