In [7]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.data import Patient, Visit, Event
from tqdm import tqdm
import pandas as pd
import pickle

# Directory that holds all MIMIC-3 dataset files. It is passed as `root` to
# MIMIC3Dataset when the dataset is loaded, so point it at your local copy.
data_root = "data"

In [2]:
# Build the base MIMIC-3 dataset from the files under `data_root`, requesting
# the diagnosis and procedure ICD tables.
# NOTE(review): dev=False presumably selects the full dataset rather than a
# small development subset -- confirm against the PyHealth documentation.
mimic3_ds = MIMIC3Dataset(
    root=data_root,
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD"],
    dev=False,
)

In [3]:
# Print dataset statistics for the freshly loaded (unfiltered) dataset.
# stat() prints the report and also returns it as a string; the trailing
# semicolon stops the notebook from echoing that string a second time as a
# raw repr below the printed report.
mimic3_ds.stat();


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n'

In [4]:
# Find all diagnosis codes in the dataset and keep only those that occur at
# least MIN_DIAG_CODE_OCCURRENCES times (codes with fewer occurrences are
# removed).
# NOTE(review): the counts are built from visit.get_code_list(), so whether a
# code repeated inside one visit counts once or several times depends on that
# method's de-duplication behaviour -- confirm against the PyHealth docs.

# Minimum number of occurrences a diagnosis code needs in order to be kept.
MIN_DIAG_CODE_OCCURRENCES = 5

all_diagnosis_codes = []
for patient in mimic3_ds.patients.values():
    # Iterating a patient yields its visits in index order.
    for visit in patient:
        all_diagnosis_codes.extend(visit.get_code_list(table="DIAGNOSES_ICD"))

codes = pd.Series(all_diagnosis_codes)
diag_code_counts = codes.value_counts()  # code -> number of occurrences

# Array of the codes that meet the threshold, and how many of them there are.
filtered_diag_codes = diag_code_counts[
    diag_code_counts >= MIN_DIAG_CODE_OCCURRENCES
].index.values
n_unique_diag_codes = len(filtered_diag_codes)

In [5]:
MIN_N_VISITS_PER_PATIENT = 2

# Filter Dataset to requirements specified in paper:
#   * drop diagnosis events whose code is not in `filtered_diag_codes`
#     (i.e. codes below the occurrence cutoff);
#   * drop visits that are left without any diagnosis event;
#   * keep only patients with at least MIN_N_VISITS_PER_PATIENT remaining visits.
#
# Filtered patients and visits are rebuilt as new objects. The event lists of
# the source visits are left untouched: new lists are created instead of
# removing items from the originals.

# Hash-based membership test instead of scanning the code array per event.
frequent_diag_codes = set(filtered_diag_codes)

# Tables copied over without any code filtering, when the visit has them.
# NOTE(review): "PRESCRIPTIONS" is not among the tables requested when the
# dataset is loaded, so that entry presumably never matches -- confirm.
UNFILTERED_TABLES = ("PROCEDURES_ICD", "PRESCRIPTIONS")

filtered_patients = {}
for patient_id, patient in tqdm(mimic3_ds.patients.items()):

    # Fresh patient carrying the same demographic attributes but no visits yet.
    filtered_patient: Patient = Patient(
        patient_id=patient.patient_id,
        birth_datetime=patient.birth_datetime,
        death_datetime=patient.death_datetime,
        gender=patient.gender,
        ethnicity=patient.ethnicity
    )

    for visit in patient:
        # Don't include visits with no diagnoses
        if len(visit.get_code_list("DIAGNOSES_ICD")) == 0:
            continue

        # Keep only diagnosis events whose code meets the occurrence cutoff.
        diagnosis_events = [
            event
            for event in visit.event_list_dict["DIAGNOSES_ICD"]
            if event.code in frequent_diag_codes
        ]
        if len(diagnosis_events) == 0:
            continue  # every diagnosis of this visit was below the cutoff

        filtered_visit: Visit = Visit(
            visit_id=visit.visit_id,
            patient_id=visit.patient_id,
            encounter_time=visit.encounter_time,
            discharge_time=visit.discharge_time,
            discharge_status=visit.discharge_status
        )
        filtered_visit.set_event_list("DIAGNOSES_ICD", diagnosis_events)

        for table in UNFILTERED_TABLES:
            if len(visit.get_code_list(table)) > 0:
                # Shallow copy so the new visit does not share a list object
                # with the source visit.
                filtered_visit.set_event_list(table, list(visit.event_list_dict[table]))

        filtered_patient.add_visit(filtered_visit)

    if len(filtered_patient.visits) >= MIN_N_VISITS_PER_PATIENT:
        filtered_patients[patient_id] = filtered_patient


100%|██████████| 46520/46520 [01:08<00:00, 675.28it/s]


In [6]:
# Swap the filtered patients into the dataset object and print the updated
# statistics.
# NOTE: this overwrites mimic3_ds.patients in place, so any cell that reads
# mimic3_ds.patients and is re-run after this one will see the filtered data.
# The trailing semicolon stops the notebook from echoing the string returned
# by stat() in addition to the printed report.
mimic3_ds.patients = filtered_patients
mimic3_ds.stat();


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 7496
	- Number of visits: 19905
	- Number of visits per patient: 2.6554
	- Number of events per visit in DIAGNOSES_ICD: 12.9735
	- Number of events per visit in PROCEDURES_ICD: 4.0975



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 7496\n\t- Number of visits: 19905\n\t- Number of visits per patient: 2.6554\n\t- Number of events per visit in DIAGNOSES_ICD: 12.9735\n\t- Number of events per visit in PROCEDURES_ICD: 4.0975\n'

In [11]:
# Serialize the dataset object (now holding the filtered patients) to a
# pickle file in the current working directory.
with open("mimic3_dataset.pkl", "wb") as pkl_out:
    pickle.dump(mimic3_ds, pkl_out)