In [1]:
import csv
import os
import sys
import numpy as np
import sklearn.model_selection as ms
import pandas as pd
import _pickle as pickle

In [2]:
class Patient:
    """Flattened view of all visits recorded for one patient.

    The per-visit code lists are concatenated along axis 0 into a single
    array per code family; the matching ``*_length`` lists are stored as
    given so that callers can recover the visit boundaries.

    Args:
        adm_ids: sized collection of admission identifiers; only its length
            is kept (as ``nvisits``).
        patient_id: identifier stored unchanged.
        diag_codes: sequence of per-visit diagnosis code arrays.
        medication_codes: sequence of per-visit medication code arrays.
        procedure_codes: sequence of per-visit procedure code arrays.
        medical_history: stored unchanged.
        icd_length: stored unchanged as ``diagnosis_length``.
        med_length: stored unchanged as ``medication_length``.
        proc_length: stored unchanged as ``procedure_length``.
    """

    def __init__(self, adm_ids, patient_id, diag_codes, medication_codes, procedure_codes, medical_history, icd_length, med_length, proc_length):
        def flatten(per_visit_codes):
            # Join the per-visit arrays end to end.
            return np.concatenate(per_visit_codes, axis=0)

        self.patient_id = patient_id
        self.nvisits = len(adm_ids)

        # Flattened code arrays, one per code family.
        self.diagnosis_codes = flatten(diag_codes)
        self.medication_codes = flatten(medication_codes)
        self.procedure_codes = flatten(procedure_codes)

        # Caller-supplied length lists, kept alongside the flattened arrays.
        self.diagnosis_length = icd_length
        self.medication_length = med_length
        self.procedure_length = proc_length

        self.medical_history = medical_history
        
        
def process_patient(infile, patient_history_dict):
    """Build one ``Patient`` per qualifying patient found in ``infile``.

    Args:
        infile: object exposing ``iterrows()`` (e.g. a pandas DataFrame) whose
            rows provide the keys ``subject_id``, ``hadm_id``, ``label``,
            ``diagnos_code``, ``medication_code`` and ``procedure_code``.
            Only the first row seen for a patient is used below, and its
            values are indexed per visit, so this presumably expects one row
            per patient holding per-visit sequences -- TODO confirm against
            the upstream preprocessing.
        patient_history_dict: dict that receives ``patient_id -> Patient``.
            It is updated in place and also returned.

    Returns:
        ``patient_history_dict`` with an entry for every patient whose labels
        sum to more than zero.
    """
    patients = infile

    # patient_id -> flat list [hadm_ids, labels, diagnoses, medications, procedures, ...]
    patient_dict = {}

    for count, (_, line) in enumerate(patients.iterrows()):
        if count % 10000 == 0:
            # Progress indicator rewritten in place on the same console line.
            print(count, end='\r')

        record = patient_dict.setdefault(line['subject_id'], [])
        record.append(line['hadm_id'])
        record.append(line["label"])
        record.append(line["diagnos_code"])
        record.append(line["medication_code"])
        record.append(line["procedure_code"])

    for patient_id, information in patient_dict.items():
        # Patients whose labels sum to zero are skipped.
        if sum(information[1]) > 0:
            medical_history = []
            icd_codes, med_codes, proc_codes = [], [], []
            icd_length, med_length, proc_length = [], [], []

            for visit_no in range(len(information[0])):
                c_icd_code = information[2][visit_no]
                c_ndc = information[3][visit_no]
                c_procedures = information[4][visit_no]

                icd_length.append(len(c_icd_code))
                med_length.append(len(c_ndc))
                proc_length.append(len(c_procedures))

                icd_codes.append(c_icd_code)
                med_codes.append(c_ndc)
                proc_codes.append(c_procedures)

                medical_history.extend(c_icd_code)
                medical_history.extend(c_ndc)
                medical_history.extend(c_procedures)
                # append, not extend: extending with a string would add the
                # three characters 'S', 'E', 'P' as separate history tokens.
                medical_history.append('SEP')

            pat = Patient(information[0], patient_id, icd_codes, med_codes, proc_codes, medical_history, icd_length, med_length, proc_length)
            patient_history_dict[patient_id] = pat

    return patient_history_dict

In [4]:
def compact_matrix(axay_sparse):
    """Average pair values over all visit pairs that lie the same distance apart.

    Args:
        axay_sparse: dict whose keys are strings of exactly four
            comma-separated fields ``"v1,v2,code1,code2"`` (``v1``/``v2`` must
            parse as integers) and whose values are numbers.

    Returns:
        dict keyed by ``"<abs(v1 - v2)>,code1,code2"`` holding the arithmetic
        mean of the input values that share that key. The code fields are
        copied verbatim, including any whitespace they contain.
    """
    totals = {}
    occurrences = {}

    for pair_key, value in axay_sparse.items():
        first_visit, second_visit, code_a, code_b = pair_key.split(",")
        distance = abs(int(first_visit) - int(second_visit))
        bucket = ",".join((str(distance), code_a, code_b))

        # Running sum and number of contributions for this bucket.
        totals[bucket] = totals.get(bucket, 0) + value
        occurrences[bucket] = occurrences.get(bucket, 0) + 1

    return {bucket: total / occurrences[bucket] for bucket, total in totals.items()}
    
def calculate_conditional(dx_probs, dx_probs2, dd_probs):
    """Derive conditional probabilities for every pair of codes from two families.

    Args:
        dx_probs: dict mapping ``(visit_no, code)`` to the marginal
            probability of the first family.
        dx_probs2: dict mapping ``(visit_no, code)`` to the marginal
            probability of the second family.
        dd_probs: dict mapping ``"v1,v2, code1,code2"`` strings to the joint
            probability of a first-family and a second-family code.

    Returns:
        Two dicts, each passed through ``compact_matrix``: the first built
        from joint / first-family marginal, the second (stored under the
        mirrored key) from joint / second-family marginal. Pairs that are
        missing from ``dd_probs`` contribute 0.0 to both.
    """
    def pair_key(left, right):
        # Same layout as the joint-probability keys: "v_l,v_r, code_l,code_r".
        return str(left[0]) + ',' + str(right[0]) + ", " + str(left[1]) + ',' + str(right[1])

    forward = {}
    backward = {}

    for first_key, first_prob in dx_probs.items():
        for second_key, second_prob in dx_probs2.items():
            pair = pair_key(first_key, second_key)
            mirrored = pair_key(second_key, first_key)

            if pair not in dd_probs:
                # Never observed together: keep an explicit zero entry.
                forward[pair] = backward[mirrored] = 0.0
                continue

            joint = dd_probs[pair]
            forward[pair] = joint / first_prob
            backward[mirrored] = joint / second_prob

    return compact_matrix(forward), compact_matrix(backward)

def _visit_indices(visit_lengths):
    """Return the visit index of every code in a flattened per-patient code array."""
    return [visit_no for visit_no, n_codes in enumerate(visit_lengths) for _ in range(n_codes)]


def _count_codes(codes, visit_lengths, freqs, skip_code=None):
    """Add one occurrence of each ``(visit_no, code)`` to ``freqs``, ignoring ``skip_code``."""
    for visit_no, code in zip(_visit_indices(visit_lengths), codes):
        if skip_code is not None and code == skip_code:
            continue
        key = (visit_no, code)
        freqs[key] = freqs.get(key, 0) + 1


def _count_pairs(codes_a, lengths_a, codes_b, lengths_b, freqs):
    """Add one occurrence of every (code_a, code_b) combination, keyed by both visit indices."""
    visits_b = _visit_indices(lengths_b)
    for visit_a, code_a in zip(_visit_indices(lengths_a), codes_a):
        for visit_b, code_b in zip(visits_b, codes_b):
            comb = str(visit_a) + "," + str(visit_b) + ", " + str(code_a) + ',' + str(code_b)
            freqs[comb] = freqs.get(comb, 0) + 1


def _to_probs(freqs, tot_visits):
    """Divide every count in ``freqs`` by the total number of visits."""
    denominator = float(tot_visits)
    return {key: count / denominator for key, count in freqs.items()}


def count_conditional_prob(patient_object, output_path):
    """Estimate code/code conditional probabilities and pickle them to disk.

    For every patient the function counts how often each diagnosis,
    medication and procedure code occurs per visit index, and how often each
    pair of codes co-occurs per pair of visit indices. Counts are divided by
    the total number of visits, turned into conditionals by
    ``calculate_conditional`` and written as nine ``*_cond_probs.empirical.p``
    pickle files.

    Args:
        patient_object: dict of patient id -> object exposing ``nvisits``,
            ``diagnosis_codes``/``diagnosis_length``,
            ``medication_codes``/``medication_length`` and
            ``procedure_codes``/``procedure_length``.
        output_path: directory that receives the pickle files; it must
            already exist.

    Returns:
        None.
    """
    print("Conditional probabilites")

    dx_freqs, med_freqs, proc_freqs = {}, {}, {}
    mp_freqs, dp_freqs, dm_freqs = {}, {}, {}
    dd_freqs, mm_freqs, pp_freqs = {}, {}, {}

    tot_visits = 0
    # Iterate the argument, not a notebook-level global, so the function works
    # on whatever mapping the caller passes in.
    for pat_object in patient_object.values():
        tot_visits += pat_object.nvisits

        dx, dx_len = pat_object.diagnosis_codes, pat_object.diagnosis_length
        med, med_len = pat_object.medication_codes, pat_object.medication_length
        proc, proc_len = pat_object.procedure_codes, pat_object.procedure_length

        # Occurrences of single codes per visit index.
        _count_codes(dx, dx_len, dx_freqs)
        _count_codes(med, med_len, med_freqs)
        # -1 is excluded from the procedure marginals; it appears to be a
        # placeholder for visits without procedures -- TODO confirm.
        _count_codes(proc, proc_len, proc_freqs, skip_code=-1)

        # Co-occurrences, P(X AND Y).
        _count_pairs(med, med_len, proc, proc_len, mp_freqs)  # M and P
        _count_pairs(dx, dx_len, proc, proc_len, dp_freqs)    # D and P
        _count_pairs(dx, dx_len, med, med_len, dm_freqs)      # D and M
        _count_pairs(dx, dx_len, dx, dx_len, dd_freqs)        # D and D
        _count_pairs(med, med_len, med, med_len, mm_freqs)    # M and M
        _count_pairs(proc, proc_len, proc, proc_len, pp_freqs)  # P and P

    print(tot_visits)
    dx_probs = _to_probs(dx_freqs, tot_visits)      # P(D)
    med_probs = _to_probs(med_freqs, tot_visits)    # P(M)
    proc_probs = _to_probs(proc_freqs, tot_visits)  # P(Procs)

    print("dx, med and proc probs is done")
    mp_probs = _to_probs(mp_freqs, tot_visits)  # P(M and P)
    dp_probs = _to_probs(dp_freqs, tot_visits)  # P(D and P)
    dm_probs = _to_probs(dm_freqs, tot_visits)  # P(D and M)
    dd_probs = _to_probs(dd_freqs, tot_visits)  # P(D and D)
    mm_probs = _to_probs(mm_freqs, tot_visits)  # P(M and M)
    pp_probs = _to_probs(pp_freqs, tot_visits)  # P(P and P)

    print("P(A and B) is done")

    d1d2_conditional, _ = calculate_conditional(dx_probs, dx_probs, dd_probs)
    print("d1d2 is done")
    m1m2_conditional, _ = calculate_conditional(med_probs, med_probs, mm_probs)
    print("m1m2 is done")
    p1p2_conditional, _ = calculate_conditional(proc_probs, proc_probs, pp_probs)
    print("p1p2 is done")
    dm_conditional, md_conditional = calculate_conditional(dx_probs, med_probs, dm_probs)
    print("dm md is done is done")
    dp_conditional, pd_conditional = calculate_conditional(dx_probs, proc_probs, dp_probs)
    print("dp pd is done is done")
    mp_conditional, pm_conditional = calculate_conditional(med_probs, proc_probs, mp_probs)

    print("Lets dump the conditional values")
    outputs = [
        ('d1d2_cond_probs.empirical.p', d1d2_conditional),
        ('m1m2_cond_probs.empirical.p', m1m2_conditional),
        ('p1p2_cond_probs.empirical.p', p1p2_conditional),
        ('dm_cond_probs.empirical.p', dm_conditional),
        ('md_cond_probs.empirical.p', md_conditional),
        ('dp_cond_probs.empirical.p', dp_conditional),
        ('pd_cond_probs.empirical.p', pd_conditional),
        ('mp_cond_probs.empirical.p', mp_conditional),
        ('pm_cond_probs.empirical.p', pm_conditional),
    ]
    for file_name, conditional in outputs:
        # The context manager flushes and closes every file, even on error.
        with open(output_path + '/' + file_name, 'wb') as handle:
            pickle.dump(conditional, handle, -1)

In [5]:
# Input parquet file and the directory that receives the pickled statistics.
pathdata = '../data/datasets/MIMIC/mimic_done3.parquet'
output_path = '../data/train_stats/MIMIC2/'

# Accumulator handed to process_patient, which fills it with Patient objects.
pat_hist = {}

# Rename the stored columns to the names process_patient reads.
pat2 = (
    pd.read_parquet(pathdata)
    .rename(
        columns={
            'ccsr': 'diagnos_code',
            'ndc': 'medication_code',
            'procedure_ccsr': 'procedure_code',
        }
    )
)

In [6]:
# Build the per-patient history objects from the loaded frame, then compute the
# conditional statistics and write them as pickle files under output_path.
# The second call depends on the result of the first, so the order matters.
pat_hist = process_patient(pat2, pat_hist)
count_conditional_prob(pat_hist, output_path)

Conditional probabilites
134968
dx, med and proc probs is done
P(A and B) is done
d1d2 is done
m1m2 is done
p1p2 is done
dm md is done is done
dp pd is done is done
Lets dump the conditional values
