In [None]:
import pandas as pd
import sqlite3
import os
import numpy as np
import pickle
import time

In [None]:
# Select the data year. It picks the input files (out_dict_<year>.pckl,
# <year>.db, <year>_demo.pckl) and the output file names, and later cells
# derive the cutoff date from it as year*10000 + 631, i.e. just past June 30
# when dates are YYYYMMDD integers.
year = 2015

In [None]:
# Load out_dict: the collection of all patients for the selected year.
# Judging from later cells, each value is a sequence of 4-tuples whose element
# at index 1 is the visit date as a YYYYMMDD integer.
# pickle.load can execute arbitrary code -- only point this at trusted files.
out_dict_path = '/home/data/EHR/db/out_dict_%d.pckl' % year
with open(out_dict_path, 'rb') as fh:
    out_dict = pickle.load(fh)

In [None]:
# Open the year's SQLite database plus one shared cursor used by every later
# query in this notebook.
db_path = '/home/data/EHR/db/%d.db' % year
con = sqlite3.connect(db_path)
con.text_factory = str  # str is the documented default on Python 3; kept explicit
cur = con.cursor()

In [None]:
# Collect every spec_id whose diagnosis code (sick_ab) is I50.
# The code is passed as a bound parameter: a double-quoted "I50" is an SQL
# identifier that SQLite only falls back to treating as a string literal.
cur.execute('SELECT spec_id FROM sick_table WHERE sick_ab = ?', ('I50',))
id_list = [row[0] for row in cur.fetchall()]

In [None]:
# Preprocessing 1: keep only patients whose I50 records ALL fall after June 30.
# Dates are YYYYMMDD integers and the cutoff year*10000 + 631 sits just above
# June 30, so `date < cutoff` means "on or before June 30"; any patient with
# such an early I50 record is removed.
cutoff = year * 10000 + 631
jid_set = set()
jid_remove_set = set()
for spec_id in id_list:
    cur.execute('SELECT jid,date FROM main_table WHERE spec_id = ?', (spec_id,))
    for (jid, date) in cur.fetchall():
        jid_set.add(jid)  # every patient with an I50 record
        if date < cutoff:  # I50 already recorded in the first half of the year
            jid_remove_set.add(jid)
jid_list = list(jid_set)
print(len(jid_list))
jid_remove_list = list(jid_remove_set)
print(len(jid_remove_list))
# Order-preserving O(n) filter (instead of repeated list.remove).
jid_list = [j for j in jid_list if j not in jid_remove_set]
print(len(jid_list))

In [None]:
# Preprocessing 2: remove patients with fewer than `min_visits` visits on or
# before June 30. A patient is dropped when they have fewer than `min_visits`
# visits in total, or when their `min_visits`-th visit (index 1 = YYYYMMDD date)
# already falls after the cutoff. This assumes out_dict visits are in date
# order -- TODO confirm.
cutoff = year * 10000 + 631
min_visits = 5
jid_remove_list = []
for jid in jid_list:
    visits = out_dict[jid]
    if len(visits) < min_visits or visits[min_visits - 1][1] > cutoff:
        jid_remove_list.append(jid)
jid_remove_set = set(jid_remove_list)
# Order-preserving O(n) filter (instead of repeated list.remove).
jid_list = [j for j in jid_list if j not in jid_remove_set]
print(len(jid_list))

In [None]:
# For each positive patient, keep only the leading visits dated on or before
# June 30 (YYYYMMDD integers compared against year*10000 + 631). The scan stops
# at the first later visit, so visits are assumed to be in date order.
cutoff_date = year * 10000 + 631
out_dict2 = {}
for jid in jid_list:
    visits = out_dict[jid]
    n_before = next(
        (i for i, visit in enumerate(visits) if visit[1] > cutoff_date),
        len(visits),
    )
    out_dict2[jid] = visits[:n_before]

In [None]:
# Save the positive cohort (jid -> visits on or before June 30) to
# ../data/preprocessed/I50/out_pos_<year>.pckl.
out_pos_path = '../data/preprocessed/I50/out_pos_%d.pckl' % year
# Create the output directory on a fresh checkout instead of failing with
# FileNotFoundError.
os.makedirs(os.path.dirname(out_pos_path), exist_ok=True)
with open(out_pos_path, 'wb') as f:
    pickle.dump(out_dict2, f)

In [None]:
# Reload the positive cohort written by this notebook, so later cells can run
# from the saved file. pickle.load executes arbitrary code -- trusted files only.
out_pos_path = '../data/preprocessed/I50/out_pos_%d.pckl' % year
with open(out_pos_path, 'rb') as fh:
    out_dict2 = pickle.load(fh)

In [None]:
# Positive patient ids, in the insertion order of out_dict2.
pos_list = list(out_dict2)
print(len(pos_list))

In [None]:
# Load the demographics table for the year (used later for its jid, agg and
# gender columns). pickle.load executes arbitrary code -- trusted files only.
demo_path = '/home/data/EHR/db/%d_demo.pckl' % year
with open(demo_path, 'rb') as fh:
    df_demo = pickle.load(fh)

In [None]:
# For every patient in out_dict, count the leading visits dated on or before
# June 30 (YYYYMMDD integers vs year*10000 + 631). The scan stops at the first
# later visit, so visits are assumed to be in date order. Each visit is
# unpacked as a 4-tuple (date at index 1).
cutoff = year * 10000 + 631
visit_dict = dict()
for k, V in out_dict.items():
    cnt = 0
    for (_, date, _, _) in V:
        if date > cutoff:
            break
        cnt += 1
    visit_dict[k] = cnt
visits_path = '../data/preprocessed/I50/visits_%d.pckl' % year
# Create the output directory on a fresh checkout instead of failing with
# FileNotFoundError.
os.makedirs(os.path.dirname(visits_path), exist_ok=True)
with open(visits_path, 'wb') as f:
    pickle.dump(visit_dict, f)

In [None]:
# Exclusion list: every patient with at least one I50 record who is NOT in the
# positive cohort (e.g. dropped by the preprocessing filters).
# candidate_list = all patients minus that exclusion list. Positives therefore
# remain candidates -- presumably so their own demographics land in demo_table.
cur.execute('SELECT spec_id FROM sick_table WHERE sick_ab = ?', ('I50',))
id_list = [row[0] for row in cur.fetchall()]
i50_jids = set()
for spec_id in id_list:
    cur.execute('SELECT jid FROM main_table WHERE spec_id = ?', (spec_id,))
    i50_jids.update(jid for (jid,) in cur.fetchall())
exclude_list = list(i50_jids - set(pos_list))
candidate_list = list(set(out_dict.keys()) - set(exclude_list))

In [None]:
# Number of I50 patients excluded from the negative pool.
n_excluded = len(exclude_list)
n_excluded

In [None]:
# Keep only the patient id and the two matching keys (agg, gender).
D = df_demo.loc[:, ['jid', 'agg', 'gender']]

In [None]:
# Keep the first row of each distinct (jid, agg, gender) combination.
D = D.loc[~D.duplicated()]

In [None]:
# Restrict demographics to candidate patients (positives plus non-I50 patients).
is_candidate = D['jid'].isin(candidate_list)
D2 = D.loc[is_candidate]

In [None]:
# Drop any demo_table left over from a previous run. IF EXISTS keeps this cell
# from raising sqlite3.OperationalError on a database without the table.
cur.execute('DROP TABLE IF EXISTS demo_table')

In [None]:
# (Re)build demo_table from the candidate demographics. if_exists='replace' and
# IF NOT EXISTS make this cell re-runnable even if the DROP cell was skipped.
# The indexes back the per-jid lookup and the (agg, gender) matching query
# used by the negative-matching cell.
D2.to_sql('demo_table', con=con, if_exists='replace')
cur.execute('CREATE INDEX IF NOT EXISTS idx_demo1 ON demo_table(jid)')
cur.execute('CREATE INDEX IF NOT EXISTS idx_demo2 ON demo_table(agg,gender)')
con.commit()

In [None]:
# Disabled alternative: load the unfiltered df_demo (instead of the
# candidate-filtered D2) into demo_table. Kept for reference only; note it names
# the jid index idx_demo rather than idx_demo1.
# df_demo.to_sql('demo_table',con=con)
# cur.execute('CREATE INDEX idx_demo ON demo_table(jid)')
# cur.execute('CREATE INDEX idx_demo2 ON demo_table(agg,gender)')
# con.commit()

In [None]:
# Match each positive patient with `n_neg` negatives from demo_table that share
# its agg and gender, and whose pre-cutoff visit count lies in
# [length, max_ratio * length]. Every negative thus has at least as many visits
# as the positive; the next cell truncates them to the positive's length.
# (`time` is imported in the first cell.)
n_neg = 10
max_ratio = 1.3
pos_set = set(pos_list)
start = time.time()
out_list = []
missing_demo = []  # positives with no demographics row -> cannot be matched
cnt = 0
for jid in pos_list:
    length = visit_dict[jid]
    cnt += 1
    if cnt % 100 == 0:
        print(cnt)
        print(time.time() - start)
    cur.execute('SELECT agg,gender FROM demo_table WHERE jid = ?', (jid,))
    out = cur.fetchone()
    if out is None:
        missing_demo.append(jid)
        continue
    cur.execute('SELECT jid FROM demo_table WHERE agg = ? AND gender = ?', out)
    results = list(set(x[0] for x in cur.fetchall()))
    neg = []
    for jid2 in results:
        # demo_table also contains positives (they stay in candidate_list),
        # including `jid` itself. They all have an I50 record and must never
        # be labelled as negatives.
        if jid2 in pos_set:
            continue
        len2 = visit_dict.get(jid2)
        if len2 is None:
            continue
        if length <= len2 <= length * max_ratio:
            neg.append(jid2)
            if len(neg) >= n_neg:
                break
    if len(neg) == n_neg:
        out_list.append((jid, neg))
print('positives without demographics:', len(missing_demo))

In [None]:
# Build training samples. For each (positive, negatives) pair, the input is the
# positive's visit list followed by each negative's, all truncated to the
# positive's pre-cutoff visit count. The target has 1 for the positive (first
# entry) and 0 for every negative.
out_list2 = []
answers = [1] + list(np.zeros(10, dtype=int))
for pos, neg_list in out_list:
    visit_len = visit_dict[pos]
    members = [pos] + list(neg_list)
    input_list = [out_dict[member][:visit_len] for member in members]
    target_list = [1] + list(np.zeros(len(neg_list), dtype=int))
    out_list2.append((input_list, target_list))

In [None]:
# Save the matched samples: a list of (input_list, target_list) pairs.
list_data_path = '../data/preprocessed/I50/list_data_%d.pckl' % year
# Create the output directory on a fresh checkout instead of failing with
# FileNotFoundError.
os.makedirs(os.path.dirname(list_data_path), exist_ok=True)
with open(list_data_path, 'wb') as f:
    pickle.dump(out_list2, f)

In [None]:
# Close the SQLite connection opened near the top of the notebook; no queries
# are issued after this point.
con.close()

In [None]:
# Number of matched samples (each: one positive plus 10 negatives).
n_samples = len(out_list2)
n_samples