In [1]:
import _pickle as pickle
import csv
import os
import sys
import numpy as np
import sklearn.model_selection as ms
import tensorflow as tf
import pandas as pd

In [72]:
class Patient(object):
    def __init__(self, adm_ids, patient_id, diag_codes, medication_codes, procedure_codes, medical_history):
        self.patient_id = patient_id
        self.nvisits = len(adm_ids)
        self.diagnosis_codes = np.concatenate(diag_codes, axis=0 )
        self.medication_codes = np.concatenate(medication_codes, axis=0)
        self.procedure_codes = np.concatenate(procedure_codes, axis=0)
        
        self.medical_history = medical_history
        
        
def process_patient(infile, patient_history_dict):

    patients = pd.read_parquet(infile)
    
    patient_dict = {}
    
    count=0
    for rowindex, line in patients.iterrows():
        if count == 100:
            break
        if count % 10000 == 0:
            print(count, end='\r')
            
        patient_id = line['subject_id']
        encounter_id = line['hadm_id']
        readmission = line["label"]
        icd_code = line["diagnos_code"]
        ndc = line["medication_code"]
        procedures = line["procedure_code"]

        if patient_id not in patient_dict:
            patient_dict[patient_id] = []
        patient_dict[patient_id].append(encounter_id)
        patient_dict[patient_id].append(readmission)
        patient_dict[patient_id].append(icd_code)
        patient_dict[patient_id].append(ndc)
        patient_dict[patient_id].append(procedures)
        
        count +=1
    
    for patient_id, information in patient_dict.items():
        if sum(information[1]) > 0:
            medical_history = []
            icd_codes, med_codes, proc_codes = [], [], []
            for visit_no in range(information[0].shape[0]):
                c_hadm = information[0][visit_no]
                c_readmission = information[1][visit_no]
                c_icd_code = information[2][visit_no]
                c_ndc = information[3][visit_no]
                c_procedures = information[4][visit_no]
                
                icd_codes.append(c_icd_code)
                med_codes.append(c_ndc)
                proc_codes.append(c_procedures)
                
                medical_history.extend(c_icd_code)
                medical_history.extend(c_ndc)
                medical_history.extend(c_procedures)
                medical_history.extend('SEP')
                
            pat = Patient(patient_id, information[0], icd_codes, med_codes, proc_codes, medical_history)
            patient_history_dict[patient_id] = pat
        
    return patient_history_dict

In [98]:
def count_conditional_prob(patient_object):
    print("Conditional probabilites")
    
    tot_visits = 0
    dx_freqs = {}
    med_freqs = {}
    proc_freqs = {}
    
    dd_freqs = {}
    dm_freqs = {}
    
    for patient, pat_object in patient_object.items():
        #tot_visits +=1
        tot_visits += pat_object.nvisits
        # Calculate occurences of diganose codes
        for dx_code in pat_object.diagnosis_codes:
            if dx_code not in dx_freqs:
                dx_freqs[dx_code] = 0
            dx_freqs[dx_code] += 1
        
        # Calculate occurences of medications codes
        for med_code in pat_object.medication_codes:
            if med_code not in med_freqs:
                med_freqs[med_code] = 0
            med_freqs[med_code] += 1
            
        # Calculate occurences of procedure_codes
        for proc_code in pat_object.procedure_codes:
            if proc_code == -1:
                continue 
            if proc_code not in proc_freqs:
                proc_freqs[proc_code] = 0
            proc_freqs[proc_code] += 1
            
        
        # Calculate occurences of diagnos and diagnos
        for dx_code1 in pat_object.diagnosis_codes:
            for dx_code2 in pat_object.diagnosis_codes:
                comb = str(dx_code1) + ',' + str(dx_code2)
                if comb not in dd_freqs:
                    dd_freqs[comb] = 0
                dd_freqs[comb] += 1
                
        # Calculate occurences of diagnos and medications occurences
        for dx_code in pat_object.diagnosis_codes:
            for med_code in pat_object.medication_codes:
                comb = str(dx_code) + ',' + str(med_code)
                if comb not in dm_freqs:
                    dm_freqs[comb] = 0
                dm_freqs[comb] += 1
                
    #print(tot_visits)
    
    dx_probs = dict([(k, v / float(tot_visits)) for k, v in dx_freqs.items()]) # P(D)
    med_probs = dict([(k, v / float(tot_visits)) for k, v in med_freqs.items()]) # P(M)
    proc_probs = dict([(k, v / float(tot_visits)) for k, v in proc_freqs.items()]) # P(Procs)
    
    dd_probs = dict([(k, v / float(tot_visits)) for k, v in dd_freqs.items()]) # P(D and D)
    dm_probs = dict([(k, v / float(tot_visits)) for k, v in dm_freqs.items()]) # P(D and M)
    
    print(dd_probs.values())

In [74]:
pat_hist = {}
pat_hist = process_patient('../data/datasets/synthea/Smaller_cohorts/train.parquet', pat_hist)

0

In [99]:
count_conditional_prob(pat_hist)

Conditional probabilites
dict_values([0.022569444444444444, 0.0011574074074074073, 0.7013888888888888, 0.034722222222222224, 0.012152777777777778, 0.08217592592592593, 0.026041666666666668, 0.0011574074074074073, 0.0011574074074074073, 0.03530092592592592, 0.004050925925925926, 0.0011574074074074073, 0.0005787037037037037, 0.0011574074074074073, 0.7013888888888888, 0.03530092592592592, 98.34085648148148, 4.092592592592593, 0.546875, 2.794560185185185, 0.9612268518518519, 0.034722222222222224, 0.004050925925925926, 4.092592592592593, 0.33449074074074076, 0.03125, 0.36226851851851855, 0.04456018518518518, 0.012152777777777778, 0.0011574074074074073, 0.546875, 0.03125, 0.013888888888888888, 0.07060185185185185, 0.018518518518518517, 0.08217592592592593, 0.0005787037037037037, 2.794560185185185, 0.36226851851851855, 0.07060185185185185, 1.7615740740740742, 0.13599537037037038, 0.026041666666666668, 0.0011574074074074073, 0.9612268518518519, 0.04456018518518518, 0.018518518518518517, 0.1359