In [None]:
import ast
import json
import os
import pickle
import re
import shelve
from datetime import datetime

import marisa_trie
import numpy as np
import pandas as pd

In [None]:
from med_lab_autocomplete_utils import LabAutocomplete, MedAutocomplete
from symptom_utils.symptom_autocomplete_utils import SymptomAutocomplete
from hpi_utils.hpi_autocomplete_utils import HPIAutocomplete

In [None]:
# Instantiate the project-local autocomplete helpers. hpi_autocomplete is the one
# used by later cells (get_omr_buckets, get_lime_data, triage_vectorizer).
lab_autocomplete = LabAutocomplete()
medication_autocomplete = MedAutocomplete()
symptom_autocomplete = SymptomAutocomplete()
hpi_autocomplete = HPIAutocomplete()

In [None]:
# Root directory of the ED / OMR / jclinic data. Set the DATA_BASE_FP environment
# variable instead of committing a machine-specific path here. Paths below are built
# as f'{DATA_BASE_FP}/...', so a trailing '/' is stripped, and an empty value makes
# them resolve from the filesystem root.
DATA_BASE_FP = os.environ.get('DATA_BASE_FP', '').rstrip('/')

In [None]:
print('Loading ED and OMR data...')
# NOTE: pickle and shelve can execute arbitrary code while loading; only point
# DATA_BASE_FP at trusted data. Pickles must be opened in binary mode on Python 3.
with open(f'{DATA_BASE_FP}/ed_data/visits_full.pkl', 'rb') as h:
    ed_visits_pkl = pickle.load(h)
# Read-only: with the default flag ('c') a mistyped path silently creates an empty
# shelf. The shelf stays open because has_omr / generate_lime_data read from it.
omr_data = shelve.open(f'{DATA_BASE_FP}/ed_data/omr/omrShelfPatient_py3_jclinic', flag='r')
with open(f'{DATA_BASE_FP}/jclinic/extracted_data/ed_chief_complaints.pkl', 'rb') as h:
    ed_chief_complaints = pickle.load(h)

In [None]:
# Pickles must be opened in binary mode on Python 3 ('r' raises on load).
with open(f'{DATA_BASE_FP}/jclinic/extracted_data/allowable_umls_lookups.pkl', 'rb') as h:
    allowable_umls_terms = pickle.load(h)

In [None]:
# Pickles must be opened in binary mode on Python 3 ('r' raises on load).
with open(f'{DATA_BASE_FP}/jclinic/extracted_data/umls_to_history_bucket_v2.pkl', 'rb') as h:
    umls_to_hist_bucket = pickle.load(h)

In [None]:
def _parse_synonyms(cell):
    """Parse one 'synonyms' CSV cell (a stringified Python list) into a list.

    ast.literal_eval accepts only literals, unlike pd.eval, which evaluates
    expressions. A non-string cell (e.g. NaN for a blank cell) yields [].
    """
    if isinstance(cell, str):
        return ast.literal_eval(cell)
    return []


hpi_ontology = pd.read_csv('ontologies/hpi_autocomplete_ontology.csv', index_col=0)
hpi_ontology['synonyms'] = hpi_ontology['synonyms'].apply(_parse_synonyms)
symptom_ontology = pd.read_csv('ontologies/symptom_autocomplete_ontology.csv')
symptom_ontology['synonyms'] = symptom_ontology['synonyms'].apply(_parse_synonyms)
# Each JSON ontology stores its term list under the 'freq' key.
with open('ontologies/medication_ontology.json', 'r') as h:
    med_list = json.load(h)['freq']
with open('ontologies/lab_ontology.json', 'r') as h:
    lab_list = json.load(h)['freq']

In [None]:
# Accumulators for the next cell: every surface string to put in the search trie,
# and a map from each string to the (category, position) it came from.
trie_keys, term_lookup = set(), dict()

In [None]:
def _register_terms(category, synonym_lists, ignore_flags=None):
    """Add every synonym to trie_keys and map it to (category, position) in term_lookup.

    ``position`` is the 0-based position of the row within ``synonym_lists``. Rows
    whose ``ignore_flags`` entry is truthy are skipped. A later registration
    overwrites an earlier one for the same string, so call order sets precedence.
    """
    for position, synonyms in enumerate(synonym_lists):
        # NOTE(review): a NaN 'ignore' cell is truthy and skips the row — confirm the column has no blanks.
        if ignore_flags is not None and ignore_flags[position]:
            continue
        for synonym in synonyms:
            trie_keys.add(synonym)
            term_lookup[synonym] = (category, position)


# Registration order (last wins on duplicate strings): MEDICATION, DISEASE, SYMPTOM, LAB.
_register_terms('MEDICATION', [[med] for med in med_list])
# Ontology rows are addressed by position rather than index label (the HPI CSV is read
# with index_col=0), so the stored position lines up with the positional lists in all_synonyms.
_register_terms('DISEASE', hpi_ontology['synonyms'].tolist(), hpi_ontology['ignore'].tolist())
_register_terms('SYMPTOM', symptom_ontology['synonyms'].tolist(), symptom_ontology['ignore'].tolist())
_register_terms('LAB', [[lab] for lab in lab_list])

In [None]:
# Build a marisa trie over every registered surface string; find_concepts queries it
# with .prefixes() to get all registered strings that start a given piece of text.
search_trie = marisa_trie.Trie(trie_keys)

In [None]:
# Synonym lists per category, positionally aligned with the positions recorded in
# term_lookup. Ignored ontology rows are kept so positions stay aligned. Labs and
# medications have a single name each, wrapped in a one-element list.
all_synonyms = dict(
    DISEASE=list(hpi_ontology['synonyms']),
    SYMPTOM=list(symptom_ontology['synonyms']),
    LAB=[[lab_name] for lab_name in lab_list],
    MEDICATION=[[med_name] for med_name in med_list],
)

In [None]:
# Same separator pattern used to tokenize the HPI in generate_lime_data; the capture
# group keeps separators as their own tokens.
_TOKEN_SPLIT_RE = re.compile(r'(\s|[\.!\?,;])')
# Punctuation allowed right after a match (any whitespace is also allowed): the
# tokenizer's separators plus ':'.
_BOUNDARY_PUNCT = '.!?,;:'


def _ends_on_boundary(text, end):
    """Return True if a match covering text[:end] stops at a token boundary."""
    return end == len(text) or text[end].isspace() or text[end] in _BOUNDARY_PUNCT


def _count_tokens(term):
    """Number of non-empty tokens ``term`` splits into (empty split pieces are dropped,
    exactly as the HPI tokens are filtered)."""
    return sum(1 for piece in _TOKEN_SPLIT_RE.split(term) if piece)


def find_concepts(tokens, trie=None, lookup=None):
    """Greedily match registered terms against a token sequence.

    Args:
        tokens: Non-empty tokens produced by splitting text on ``_TOKEN_SPLIT_RE``
            (separators included as tokens).
        trie: Object with a ``prefixes(text)`` method; defaults to ``search_trie``.
        lookup: Mapping term -> (category, position); defaults to ``term_lookup``.

    Returns:
        dict mapping the index of the first token of each match to
        ``(category, matched_term)``. Tokens covered by a match are not reused.
    """
    if trie is None:
        trie = search_trie
    if lookup is None:
        lookup = term_lookup
    text = ''.join(tokens)
    concepts = {}
    next_free = 0  # first token index not covered by an earlier match
    offset = 0     # character offset of tokens[i] within text
    for i, tok in enumerate(tokens):
        start, offset = offset, offset + len(tok)
        if i < next_free:
            continue
        future_txt = text[start:]
        # Longest candidate first; fall back to shorter ones when the longer one would end mid-word.
        for candidate in sorted(trie.prefixes(future_txt), key=len, reverse=True):
            if candidate and _ends_on_boundary(future_txt, len(candidate)):
                concepts[i] = (lookup[candidate][0], candidate)
                next_free = i + _count_tokens(candidate)
                break
    return concepts

In [None]:
# Fixed seed so the sampled evaluation subset is reproducible across runs.
np.random.seed(0)
# Sample without replacement so no visit appears twice; cap at the number of visits.
N_TEST_VISITS = 25000
test_indices = np.random.choice(
    len(ed_visits_pkl),
    size=min(N_TEST_VISITS, len(ed_visits_pkl)),
    replace=False,
)

##  Logic to detect scope and type 

In [None]:
def _omr_note_time(note):
    """Return an OMR note's ``time`` as a datetime.

    Notes store it as '%Y-%m-%d %H:%M:%S' text; a datetime is passed through unchanged.
    """
    note_time = note['time']
    if isinstance(note_time, datetime):
        return note_time
    return datetime.strptime(note_time, '%Y-%m-%d %H:%M:%S')


def has_omr(visit_ix, visits=None, omr=None):
    """Report whether the patient of an ED visit has any OMR note dated before the visit.

    Args:
        visit_ix: Position of the visit in ``visits``.
        visits: Sequence of visit records; defaults to ``ed_visits_pkl``.
        omr: Mapping patient ID -> list of OMR notes; defaults to ``omr_data``.

    Returns:
        None if the visit has no MD comments; otherwise True if at least one of the
        patient's notes has a ``time`` strictly before the visit date, else False.
    """
    if visits is None:
        visits = ed_visits_pkl
    if omr is None:
        omr = omr_data
    visit = visits[visit_ix]
    if not visit['MDcomments'][0]:
        return None
    # Characters 14:24 of the raw Date field are parsed as the YYYY-MM-DD visit date.
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if pid not in omr:
        return False
    # Compare as datetimes: the raw note time is text, which cannot be ordered against a datetime.
    return any(_omr_note_time(note) < visit_date for note in omr[pid])

In [None]:
def omr_date_parser(note):
    """Parse an OMR note's ``time`` field ('%Y-%m-%d %H:%M:%S' text) into a datetime."""
    return datetime.strptime(note['time'], '%Y-%m-%d %H:%M:%S')


def generate_lime_data(visit_ix, verbose=False):
    """Build LIME data for one ED visit from its triage text and the patient's prior OMR notes.

    Args:
        visit_ix: Position of the visit in ``ed_visits_pkl``.
        verbose: If True, print the first (lower-cased) line of the MD comments,
            the OMR terms, and the triage text with its OMR buckets.

    Returns:
        None when the visit has no MD comments or no triage assessment; otherwise
        the ``(X1, X2, y)`` tuple from ``hpi_autocomplete.get_lime_data``.
    """
    visit = ed_visits_pkl[visit_ix]
    md_comments = visit['MDcomments'][0]
    triage_assessment = visit['TriageAssessment'][0]
    if not md_comments or not triage_assessment:
        return None
    # Characters 14:24 of the raw Date field are parsed as the YYYY-MM-DD visit date.
    visit_date = datetime.strptime(visit['Date'][0][14:24], '%Y-%m-%d')
    pid = visit['PatientID'][0]
    if verbose:
        print(md_comments.split('\n')[0].lower())
    omr_buckets, omr_terms = [], []
    if pid in omr_data:
        # Keep every note strictly before the visit date. The previous break-on-first-later-note
        # logic returned no notes at all when every note predated the visit.
        prior_notes = [note for note in omr_data[pid] if omr_date_parser(note) < visit_date]
        omr_buckets, omr_terms = hpi_autocomplete.get_omr_buckets(prior_notes)
    if verbose:
        print(omr_terms)
        print(triage_assessment, list(omr_buckets))
    X1, X2, y = hpi_autocomplete.get_lime_data(triage_assessment, list(omr_buckets))
    return X1, X2, y

In [None]:
from sklearn.linear_model import Lasso

In [None]:
# Build LIME data for the first N_LIME_VISITS visits, capped at the dataset size.
# Entries are None for visits without MD comments or triage text (dropped in the next cell).
N_LIME_VISITS = 10000
lime_data = [
    generate_lime_data(visit_ix)
    for visit_ix in range(min(N_LIME_VISITS, len(ed_visits_pkl)))
]

In [None]:
# Drop visits for which generate_lime_data returned None (missing MD comments or triage text).
lime_data = list(filter(lambda row: row is not None, lime_data))

In [None]:
def _split_lime_rows(rows):
    """Split (x1, x2, y) triples into three arrays, squeezing singleton axes of x1 and x2."""
    x1_rows, x2_rows, y_rows = [], [], []
    for x1, x2, y in rows:
        x1_rows.append(np.squeeze(x1))
        x2_rows.append(np.squeeze(x2))
        y_rows.append(y)
    return np.array(x1_rows), np.array(x2_rows), np.array(y_rows)


X1, X2, Y = _split_lime_rows(lime_data)

In [None]:
# L1-regularisation strength of the linear model fit on the LIME data.
LASSO_ALPHA = 0.001
reg_model = Lasso(alpha=LASSO_ALPHA)

In [None]:
# Column of Y to explain; presumably the diabetes output given the name — confirm.
diabetes_index = 1
# Assumes Y holds probabilities in [0, 1]. Exact 0/1 would give an infinite logit,
# which sklearn's input validation rejects, so clip into (eps, 1 - eps) first.
LOGIT_EPS = 1e-6
diabetes_prob = np.clip(Y[:, diabetes_index], LOGIT_EPS, 1 - LOGIT_EPS)
diabetes_logit = np.log(diabetes_prob) - np.log1p(-diabetes_prob)
reg_model = reg_model.fit(np.hstack((X1, X2)), diabetes_logit)

In [None]:
# Invert the vectorizer's vocabulary_ (presumably term -> feature index, as in sklearn)
# so a feature index can be mapped back to its word.
tfidf_cats = dict(
    (feature_idx, term)
    for term, feature_idx in hpi_autocomplete.triage_vectorizer.vocabulary_.items()
)