In [1]:
import os
import pickle
import random
import sqlite3
import time

# Root of the EHRVis data tree. Set the EHRVIS_DATA_DIR environment variable to point the
# pipeline elsewhere instead of editing this machine-specific default.
data_root = os.environ.get('EHRVIS_DATA_DIR', '/home/mjc/github/EHRVis/data')
db_dir = os.path.join(data_root, 'database', '')             # one SQLite file per year: <year>.db
processed_dir = os.path.join(data_root, 'preprocessed', '')  # intermediate pickles written by this notebook
dict_dir = os.path.join(data_root, 'dictionaries', '')       # code -> index dictionaries (d2i, p2i, sick_converter)
years = [2014, 2015]

In [2]:
# functions
def remove_unknown(lst):
    """Deduplicate a list of vocabulary indices and drop the 'unknown' sentinels.

    The sentinels 0, 500 and 768 are the values to_dict_idx returns when a
    diagnosis, sickness or prescription code cannot be resolved.

    Returns the surviving indices (one copy each), or [0] if none survive.
    """
    unknown_markers = (0, 500, 768)
    known = [code for code in set(lst) if code not in unknown_markers]
    return known or [0]

def get_classified_sickness(sample,alphabet_dict1,alphabet_dict2):
    """Map a code of the form <letter><number> to the value of its matching table entry.

    Parameters:
        sample: code string; its first character is the letter, the rest must parse as int.
        alphabet_dict1: letter -> list of entries, each either a range '<L><lo>-<L><hi>'
            (inclusive on both ends) or a single code '<L><n>'.
        alphabet_dict2: letter -> list of values aligned position-by-position with
            alphabet_dict1[letter].

    Returns:
        The value for the first matching entry, or 0 when the code is shorter than
        3 characters, its numeric part is not an integer, or no entry matches.

    Raises:
        KeyError: if the letter is missing from alphabet_dict1 (to_dict_idx relies on
        this to map such codes to its own 'unknown' value).
    """
    if len(sample) < 3:
        return 0
    letter = sample[0]
    try:
        num = int(sample[1:])
    except ValueError:
        # A non-numeric suffix (e.g. 'A0X') used to escape to_dict_idx, which only
        # catches KeyError, and abort the whole extraction loop. Treat it as unmatched.
        return 0
    for i, rng in enumerate(alphabet_dict1[letter]):
        if '-' in rng:
            lower, upper = rng.split('-')
            if int(lower[1:]) <= num <= int(upper[1:]):
                return alphabet_dict2[letter][i]
        elif num == int(rng[1:]):
            return alphabet_dict2[letter][i]
    return 0

def to_dict_idx(string, data_type, D):
    """Translate one raw code into its index in the combined vocabulary.

    data_type picks the lookup and the offset added to its result:
        'diag' -> D[string]                                      (missing key: 0)
        'sick' -> get_classified_sickness(string, D[0], D[1]) + 500  (KeyError: 500)
        'pres' -> D[string] + 768                                (missing key: 768)
    Any other data_type returns None.
    """
    if data_type not in ('diag', 'sick', 'pres'):
        return None
    try:
        if data_type == 'sick':
            return get_classified_sickness(string, D[0], D[1]) + 500
        if data_type == 'pres':
            return D[string] + 768
        return D[string]
    except KeyError:
        fallback = {'diag': 0, 'sick': 500, 'pres': 768}
        return fallback[data_type]

In [3]:
# 0 - open database
# Connect to <db_dir>/<year>.db for the selected year; TEXT values come back as str.
year = years[1]
db_path = os.path.join(db_dir, '%d.db' % year)
con = sqlite3.connect(db_path)
cur = con.cursor()
con.text_factory = str

In [None]:
# 1 - select all ids from main_table
# The query already skips visits without an end date and sorts by patient then end date.
# Keep only rows whose to_date is 8 characters long and whose fom is 31 or 21.
print('Select all ids from main table')
cur.execute('SELECT jid,id,to_date,msick_ab,fom FROM main_table WHERE to_date IS NOT NULL ORDER BY jid,to_date')
rows = cur.fetchall()
rows2 = [row for row in rows
         if len(str(row[2])) == 8 and row[4] in (31, 21)]
del rows

Select all ids from main table


In [None]:
# 2 - remove events that happen more than once
# rows2 is sorted by (jid, to_date), so repeated visits are adjacent: keep only the first row
# of each run sharing the same patient AND end date. Comparing the date alone would also drop
# a patient's first visit whenever it fell on the previous patient's last visit date.
rows3 = []
prev_key = None
for tup in rows2:
    key = (tup[0], tup[2])  # (jid, to_date)
    if key != prev_key:
        rows3.append(tup)
    prev_key = key
del rows2

In [None]:
# 3 - replace current msick with msick of next stage
# Pair each visit (jid, id, to_date, fom) with the msick_ab of the row after it. The final row
# has no successor and is dropped; pairs whose successor belongs to another patient are
# removed later, in step 5. Materialised as a list: a pickled zip object is a one-shot iterator.
list1 = [(jid, id_, date, fom, next_row[3])
         for (jid, id_, date, _msick, fom), next_row in zip(rows3, rows3[1:])]
del rows3
with open(os.path.join(processed_dir,'%d_list1.pckl'%year),'wb') as f:
    pickle.dump(list1,f)

In [None]:
# 4 - create list of all tuples
# For each (visit, next-visit msick) pair from step 3, gather the vocabulary indices of the
# visit's diagnosis, sickness and prescription codes, and convert the label with the
# sickness table (offset 500 removed).
with open(os.path.join(processed_dir,'%d_list1.pckl'%year),'rb') as f:
    list1 = pickle.load(f)
with open(os.path.join(dict_dir,'d2i.pckl'),'rb') as f:
    d2i = pickle.load(f)
with open(os.path.join(dict_dir,'sick_converter.pckl'),'rb') as f:
    s2i = pickle.load(f)
with open(os.path.join(dict_dir,'p2i.pckl'),'rb') as f:
    p2i = pickle.load(f)
# (query, to_dict_idx data_type, dictionary) for each per-visit code table.
# The visit id is bound as a parameter instead of %-formatted into the SQL: the query text
# stays constant, so sqlite3 reuses its cached prepared statement for every visit.
visit_code_queries = [
    ('SELECT div_code FROM diag_table WHERE id = ?', 'diag', d2i),
    ('SELECT sick_code_ab FROM sick_table WHERE id = ?', 'sick', s2i),
    ('SELECT gnl_code FROM pres_table WHERE id = ?', 'pres', p2i),
]
list2 = []
start = time.time()
for cnt, (jid, visit_id, date, fom, out) in enumerate(list1, 1):
    out_list = []
    for query, data_type, mapping in visit_code_queries:
        cur.execute(query, (visit_id,))
        out_list.extend(to_dict_idx(str(row[0]).strip(), data_type, mapping)
                        for row in cur.fetchall())
    out_list = remove_unknown(out_list)  # deduplicates and drops unknown sentinels
    # NOTE(review): unlike the codes above, the label is passed without str()/strip().
    out = to_dict_idx(out,'sick',s2i)-500
    list2.append((jid, visit_id, date, fom, out_list, out))
    if cnt % 100000 == 0:
        print(cnt, time.time() - start)
with open(os.path.join(processed_dir,'%d_list2.pckl'%year),'wb') as f:
    pickle.dump(list2,f)
del list1

In [None]:
# 5 - create dict containing each jid
# Group list2 by patient. list2 keeps the (jid, to_date) order from step 1, so each patient's
# rows are contiguous. Each patient's last row is dropped because step 3 labelled it with the
# msick of the following row, which belongs to the next patient.
out_dict = dict()
# Start from the first jid actually present; a hard-coded 1 made the first iteration pop
# from an empty list (IndexError) whenever the data did not start at jid 1.
current_jid = list2[0][0] if list2 else None
patient_rows = []  # (date, fom, out_list, out) for current_jid
for jid, visit_id, date, fom, out_list, out in list2:
    if jid != current_jid:
        patient_rows.pop(-1)  # remove last item (its label comes from another patient)
        out_dict[current_jid] = patient_rows
        patient_rows = []
        current_jid = jid
    patient_rows.append((date, fom, out_list, out))
if patient_rows:
    # NOTE(review): step 3 already dropped the final row of rows3. This pop therefore removes
    # a cross-patient pair only when the last patient had a single visit. Otherwise it discards
    # a valid pair. Confirm intent before changing.
    patient_rows.pop(-1)  # for last jid
    out_dict[current_jid] = patient_rows
with open(os.path.join(processed_dir,'%d_out_dict.pckl'%year),'wb') as f:
    pickle.dump(out_dict,f)
del list2

In [None]:
# 6 - split data into train, val, test sets
# Bucket patients by sequence length, then assign whole length-buckets to train/val/test so
# that every batch file holds equal-length sequences.
# NOTE: the 70/5/25 ratio is over distinct lengths, not over patients.
len_dict = dict()
for k, v in out_dict.items():
    len_dict.setdefault(len(v), []).append(k)
# Dict keys come back in insertion order, not numerically, so sort before slicing:
# skip the 5 shortest distinct lengths and keep at most the next 295.
all_lengths = sorted(len_dict)[5:300]
split_seed = 0  # fixed seed so the split is reproducible across runs
random.Random(split_seed).shuffle(all_lengths)
# train/val/test: 70/5/25
n_lengths = len(all_lengths)
type_dict = dict()
type_dict['train'] = all_lengths[:int(n_lengths*0.7)]
type_dict['val'] = all_lengths[int(n_lengths*0.7):int(n_lengths*0.75)]
type_dict['test'] = all_lengths[int(n_lengths*0.75):]
types = ['train','test','val']
batch_dir = '/home/mjc/github/EHRVis/data/batches/'
for typ in types:
    os.makedirs(os.path.join(batch_dir, typ), exist_ok=True)
    for item in type_dict[typ]:
        out_list = [out_dict[jid] for jid in len_dict[item]]
        with open(os.path.join(batch_dir, typ, '%d_%d.pckl' % (year, item)), 'wb') as f:
            pickle.dump(out_list, f)
# out_list (the last batch written) is kept so the next cell can inspect it.
del out_dict

In [None]:
# Peek at the first patient sequence of the last batch written in step 6.
# NOTE(review): needs `out_list` to still be defined after step 6; if step 6 `del`s it,
# this cell raises NameError on Restart & Run All.
out_list[0]