In [None]:
import csv
import os
import pickle
import sys
import time

import sklearn.model_selection as ms
import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm

In [None]:
import pyhealth
from pyhealth.data import Event, Visit, Patient

import numpy as np

# Seed NumPy's global random state so NumPy-based randomness is repeatable
# across runs of this notebook.
# NOTE(review): torch is imported above but is not seeded in this cell —
# confirm whether torch.manual_seed is also needed for downstream steps.
np.random.seed(1234)

In [None]:
from pyhealth.datasets import eICUDataset

def _load_eicu_dev():
    """Build a dev-mode eICUDataset over the diagnosis, treatment and admissionDx tables."""
    return eICUDataset(
        tables=["diagnosis", "treatment", "admissionDx"],
        root='../../eicu_csv',
        dev=True,
        refresh_cache=False,
    )


# Two separately constructed datasets built from identical arguments.
# `dataset` is later handed to process_patient, which deletes visits from it;
# `dataset_const` is not modified by any cell shown in this notebook —
# presumably it is kept as an untouched reference (TODO confirm).
dataset_const = _load_eicu_dev()
dataset = _load_eicu_dev()

In [None]:
# Call the dataset's stat() and info() reporting helpers. info() is the final
# expression of the cell, so any value it returns is echoed by the notebook.
dataset.stat()
dataset.info()

In [None]:
# Display every value stored in dataset.patients.
# NOTE(review): this echoes the whole collection; a short preview may be
# easier to read — confirm whether the full dump is needed.
dataset.patients.values()

In [None]:
# Pick one patient by its key in dataset.patients and pull out that patient's
# visits; the visits mapping is the cell's displayed result.
patient_id = '002-9990+146474'
patient = dataset.patients[patient_id]
visits = patient.visits
visits

In [None]:
# Select one entry of `visits` by its visit id and list the attribute and
# method names available on that visit object.
visit_id = '163891'
dir(visits[visit_id])

In [None]:
# Inspect the events recorded on the selected visit, one table at a time.
visit = visits[visit_id]
print("### Accessing the diagnosis events ###")
print(visit.get_event_list('diagnosis'))
# Only the final expression of a cell is echoed by the notebook, so the code
# list has to be printed explicitly to show up in the output.
print(visit.get_code_list('diagnosis'))

print("### Accessing the admissionDx events ###")
print(visit.get_event_list('admissionDx'))

print("### Accessing the treatment events ###")
print(visit.get_event_list('treatment'))

In [None]:
# Drop visits that lasted less than `hour_threshold` hours, where the duration
# is visit.discharge_time - visit.encounter_time.
# (Original author's note: this duration should correspond to the
# 'unitdischargeoffset' data entry.)
def process_patient(ds, hour_threshold=24, inplace=True):
    """Remove every visit shorter than ``hour_threshold`` hours from a dataset.

    Args:
        ds: Object exposing ``patients``, a mapping of patient id to patient.
            Each patient exposes ``visits``, a mapping of visit id to visit,
            and each visit exposes ``encounter_time`` and ``discharge_time``
            whose difference can be compared with a ``np.timedelta64``.
        hour_threshold: Minimum visit duration, in whole hours, for a visit to
            be kept. A visit lasting exactly this long is kept.
        inplace: When True (the default, matching the historical behaviour)
            the visits are deleted from ``ds`` itself and ``ds`` is returned.
            When False, ``ds`` is deep-copied first and left untouched; the
            filtered copy is returned.

    Returns:
        The filtered dataset: ``ds`` itself if ``inplace`` is True, otherwise
        a filtered deep copy.

    Notes:
        Only visits are removed. A patient whose visits are all dropped stays
        in ``patients`` with an empty ``visits`` mapping.
        A one-line summary of processed/deleted visit counts is printed.
    """
    if not inplace:
        # Local import keeps this cell self-contained.
        import copy

        # Work on a private copy so the caller's dataset keeps all its visits.
        ds = copy.deepcopy(ds)

    # Loop-invariant threshold, built once.
    min_duration = np.timedelta64(hour_threshold, 'h')
    encounter_processed_count = 0
    encounter_deleted_count = 0

    for patient in ds.patients.values():
        # Iterate over a snapshot: entries are deleted from patient.visits
        # while walking it.
        for visit_id, visit in list(patient.visits.items()):
            encounter_processed_count += 1
            if (visit.discharge_time - visit.encounter_time) < min_duration:
                encounter_deleted_count += 1
                del patient.visits[visit_id]

    print("Processed {} encounters, deleted {} encounters".format(encounter_processed_count, encounter_deleted_count))
    return ds


# Recorded output of an earlier run:
#   Processed 200859 encounters, deleted 67959 encounters
# NOTE(review): `dataset` is loaded with dev=True in this notebook, so these
# counts presumably came from a full (non-dev) load — confirm before relying
# on them.
# process_patient deletes visits from `dataset` itself and returns that same
# object, so `dataset_processed` and `dataset` name the same dataset afterwards.
dataset_processed = process_patient(dataset)


In [None]:
# readmission prediction
from pyhealth.tasks import readmission_prediction_eicu_fn
# Dev-mode eICU dataset that additionally loads the physicalExam, medication
# and lab tables; it is the base for the task datasets built from here on.
eicu_base = eICUDataset(
    tables=["diagnosis", "treatment", "admissionDx", "physicalExam", "medication", "lab"],
    root='../../eicu_csv',
    refresh_cache=False,
    dev=True,
)
# Derive the task dataset by applying the readmission task function.
sample_dataset = eicu_base.set_task(readmission_prediction_eicu_fn)

In [None]:
# Call stat() on the task dataset, then print its available keys and its
# first sample.
sample_dataset.stat()
print(sample_dataset.available_keys)
print(sample_dataset.samples[0])

In [None]:
from pyhealth.tasks import mortality_prediction_eicu_fn
# Build the task dataset for the mortality task function. This rebinds
# `sample_dataset`, so the readmission task dataset created earlier is no
# longer reachable under that name.
sample_dataset = eicu_base.set_task(mortality_prediction_eicu_fn)
# Same inspection as before: stat(), available keys, first sample.
sample_dataset.stat()
print(sample_dataset.available_keys)
print(sample_dataset.samples[0])

In [None]:
from load_eicu import readmission_prediction_eicu_fn_customized
# Run the customized readmission task function over every patient.
# samples_list holds one entry per patient that produced at least one sample;
# each entry is that patient's list of sample dicts.
samples_list = []
# Number of patients having at least one sample whose label is 1.
readmission_count = 0
for patient in eicu_base.patients:
    # NOTE(review): the meaning of the second argument (20) is defined in
    # load_eicu, which is not shown here — confirm.
    samples = readmission_prediction_eicu_fn_customized(eicu_base.patients[patient], 20)
    if len(samples) == 0:
        continue
    for sample in samples:
        if sample['label'] == 1:
            # Count each patient once, however many positive samples it has.
            readmission_count += 1
            break
    samples_list.append(samples)

# len(samples_list) counts patients, not samples, so it is labelled as such
# and the true sample total is reported separately.
print("Total number of patients: {}".format(len(eicu_base.patients)))
print("Total number of patients with samples: {}".format(len(samples_list)))
print("Total number of samples: {}".format(sum(len(patient_samples) for patient_samples in samples_list)))
print("Total number of patients with readmission: {}".format(readmission_count))


In [None]:
# Take the second sample of entry 42 in samples_list and pull out the first
# element of three of its fields.
_picked_sample = samples_list[42][1]
admissionDx = _picked_sample['admissionDx'][0]
conditions = _picked_sample['conditions'][0]
treatment = _picked_sample['treatment'][0]
# Displayed result: the admissionDx entries, lower-cased.
[code.lower() for code in admissionDx]

In [None]:
# Resolve the sample inspected above back to its Visit object in eicu_base,
# using the patient and visit ids stored on the sample. This rebinds `visit`.
_ref_sample = samples_list[42][1]
_ref_patient = eicu_base.patients[_ref_sample['patient_id']]
visit = _ref_patient.visits[_ref_sample['visit_id']]

In [None]:
# Display the admissionDx event list of the visit currently bound to `visit`.
visit.get_event_list('admissionDx')

In [None]:
# Display the admissionDx code list of the visit currently bound to `visit`.
visit.get_code_list('admissionDx')

In [None]:
# Print the visit id and label of every sample in entry 42 of samples_list,
# one sample per line.
for sample in samples_list[42]:
    print(sample['visit_id'], sample['label'])