In [45]:
import os
import argparse

import pandas as pd
import numpy as np
from datetime import datetime
from collections import Counter

from datetime import datetime, timedelta
import time

import warnings

from preprocess_utils import Sampler, adjust_time, read_csv

from sqlalchemy import create_engine, text
import mysql.connector
from dotenv import load_dotenv
import os

In [51]:
# pd.set_option('display.max_columns', 250)
pd.set_option('display.max_rows', 100)


In [25]:
rng = np.random.default_rng(0)
load_dotenv()

True

In [26]:
def condition_value_shuffler(table, target_cols):
    original_idx = np.arange(len(table))
    rng_val = np.random.default_rng(0)
    shuffled_idx = rng_val.choice(original_idx, len(original_idx), replace=False).tolist()
    table.iloc[original_idx, [table.columns.get_loc(col) for col in target_cols]] = table[target_cols].values[shuffled_idx]
    return table

In [32]:
def first_admit_year_sampler(start_year, span_year, earliest_year=None):
        end_year = start_year + span_year
        prob = np.array(range(1,span_year+2))/(sum(np.array(range(1, span_year+2))))
        sampled_year = rng.choice(range(start_year,end_year+1), p=prob)
        if earliest_year is not None:
            year_adjustment = int(sampled_year - earliest_year)
            return year_adjustment*365*24*60 # in minute
        else:
            return sampled_year


In [33]:
def sample_date_given_year(self, year, num_split=1, frmt = '%Y-%m-%d'):
        start_time = time.mktime(time.strptime(f'{year}-01-01', frmt))
        end_time = time.mktime(time.strptime(f'{year}-12-31', frmt))    
        dts = []
        for split in range(num_split):
            split_seed = split/num_split + rng.random()/num_split
            ptime = start_time + split_seed * (end_time - start_time)
            dt = datetime.fromtimestamp(time.mktime(time.localtime(ptime)))
            dts.append(dt.strftime("%Y-%m-%d"))
        return dts

In [34]:
def adjust_time(table, time_col, patient_col, current_time=None, offset_dict=None):

    shifted_time = []
    for idx, time_val in enumerate(table[time_col].values):
        if pd.notnull(time_val) and time_val!='':
            if offset_dict is not None:
                id_ = table[patient_col].iloc[idx]
                if id_ in offset_dict:
                    if type(time_val)==str: # mimic3
                        time_val = str(datetime.strptime(time_val, '%Y-%m-%d %H:%M:%S') + timedelta(minutes=int(offset_dict[id_])))
                    else: # eicu
                        time_val = str(datetime.strptime(offset_dict[id_], '%Y-%m-%d %H:%M:%S') + timedelta(minutes=int(time_val)))
                else:
                    time_val = None
            if time_val is not None and current_time is not None and current_time < time_val:
                time_val = None
        else:
            time_val = None
        shifted_time.append(time_val)

    return shifted_time

In [35]:
def read_csv(data_dir, filename, columns=None, lower=True, filter=None, dtype=None, memory_efficient=False):
    filepath = os.path.join(data_dir, filename)
    if memory_efficient:
        import dask.dataframe as dd
        from dask.diagnostics import ProgressBar
        ProgressBar().register()

        if filepath.endswith('gz'):
            compression='gzip'
        else:
            compression=None

        if dtype:
            df = dd.read_csv(filepath, blocksize=25e6, dtype=dtype, compression=compression)
        else:
            df = dd.read_csv(filepath, blocksize=25e6, compression=compression)
        if columns is not None:
            df = df[columns]
        if filter is not None:
            for key in filter:
                df = df[df[key].isin(filter[key])]
        df = df.compute()
    else:
        df = pd.read_csv(filepath, usecols=columns)
        if filter is not None:
            for key in filter:
                df = df[df[key].isin(filter[key])]
    if lower:
        df = df.applymap(lambda x: x.lower().strip() if pd.notnull(x) and type(x)==str else x)
    return df

In [36]:
host = os.getenv("DB_HOST")
user = os.getenv("DB_USERNAME")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")

In [59]:
data_dir = "../dataset/mimic-iii-raw"
out_dir = '../dataset/ehrsql'
db_name = "mimic_iii"
num_patient = 1000
sample_icu_patient_only = False
deid = True
timeshift = True
start_year = 2018
time_span = 5
cur_patient_ratio = 0.1
current_time = "2023-12-31 23:59:00"

In [60]:
if timeshift:
    assert start_year is not None, 'To do a time shift, "start_year" must be specified' 
    assert time_span is not None, 'To do a time shift, "time_span" must be specified' 
    assert current_time is not None, 'To do a time shift, "current_time" must be specified'    

# MIMIC III

In [61]:
out_dir = os.path.join(out_dir, db_name)
num_cur_patient = int(num_patient * cur_patient_ratio)
num_non_cur_patient = num_patient - num_cur_patient
if timeshift:
    start_pivot_datetime = datetime(year=start_year, month=1, day=1)

In [62]:
chartevent2itemid = {
    "Temperature C (calc)".lower(): "677",  # body temperature
    "SaO2".lower(): "834",  # Sao2
    "heart rate".lower(): "211",  # heart rate
    "Respiratory Rate".lower(): "618",  # respiration rate
    "Arterial BP [Systolic]".lower(): "51",  # systolic blood pressure
    "Arterial BP [Diastolic]".lower(): "8368",  # diastolic blood pressure
    "Arterial BP Mean".lower(): "52",  # mean blood pressure
    "Admit Wt".lower(): "762",  # weight
    "Admit Ht".lower(): "920",  # height
}

### build_admission_table

In [63]:
print("Processing PATIENTS, ADMISSIONS, ICUSTAYS, TRANSFERS")
start_time = time.time()

# read patients
PATIENTS_table = read_csv(data_dir, "PATIENTS.csv", columns=["row_id", "subject_id", "gender", "dob", "dod"], lower=True)
# print(PATIENTS_table)
subjectid2dob = {pid: dob for pid, dob in zip(PATIENTS_table["subject_id"].values, PATIENTS_table["dob"].values)}

# read admissions
ADMISSIONS_table = read_csv(
    data_dir,
    "ADMISSIONS.csv",
    columns=["row_id", "subject_id", "hadm_id", "admittime", "dischtime", "admission_type", "admission_location", "discharge_location", "insurance", "language", "marital_status", "ethnicity"],
    lower=True,
)
ADMISSIONS_table["AGE"] = [
    int((datetime.strptime(admtime, "%Y-%m-%d %H:%M:%S") - datetime.strptime(subjectid2dob[pid], "%Y-%m-%d %H:%M:%S")).days / 365.25)
    for pid, admtime in zip(ADMISSIONS_table["subject_id"].values, ADMISSIONS_table["admittime"].values)
]

# save original admittime
hadm_id2admtime_dict = {hadm: admtime for hadm, admtime in zip(ADMISSIONS_table["hadm_id"].values, ADMISSIONS_table["admittime"].values)}
hadm_id2dischtime_dict = {hadm: dischtime for hadm, dischtime in zip(ADMISSIONS_table["hadm_id"].values, ADMISSIONS_table["dischtime"].values)}

# get earlist admission time
admittime_earliest = {subj_id: min(ADMISSIONS_table["admittime"][ADMISSIONS_table["subject_id"] == subj_id].values) for subj_id in ADMISSIONS_table["subject_id"].unique()}
if timeshift:
    subjectid2admittime_dict = {
        subj_id: first_admit_year_sampler(start_year, time_span, datetime.strptime(admittime_earliest[subj_id], "%Y-%m-%d %H:%M:%S").year)
        for subj_id in ADMISSIONS_table["subject_id"].unique()
    }

# read icustays
ICUSTAYS_table = read_csv(
    data_dir, "ICUSTAYS.csv", columns=["row_id", "subject_id", "hadm_id", "icustay_id", "first_careunit", "last_careunit", "first_wardid", "last_wardid", "intime", "outtime"], lower=True
)
# subset only icu patients
if sample_icu_patient_only:
    ADMISSIONS_table = ADMISSIONS_table[ADMISSIONS_table["subject_id"].isin(set(ICUSTAYS_table["subject_id"]))]

# read transfer
TRANSFERS_table = read_csv(
    data_dir, "TRANSFERS.csv", columns=["row_id", "subject_id", "hadm_id", "icustay_id", "eventtype", "curr_careunit", "curr_wardid", "intime", "outtime"], lower=True
)
TRANSFERS_table = TRANSFERS_table.rename(columns={"curr_careunit": "careunit", "curr_wardid": "wardid"})
TRANSFERS_table = TRANSFERS_table.dropna(subset=["intime"])

# process patients
if timeshift:
    PATIENTS_table["dob"] = adjust_time(PATIENTS_table, "dob", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    PATIENTS_table["dod"] = adjust_time(PATIENTS_table, "dod", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    PATIENTS_table = PATIENTS_table.dropna(subset=["dob"])

# process admissions
if timeshift:
    ADMISSIONS_table["admittime"] = adjust_time(ADMISSIONS_table, "admittime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    ADMISSIONS_table = ADMISSIONS_table.dropna(subset=["admittime"])
    ADMISSIONS_table["dischtime"] = adjust_time(ADMISSIONS_table, "dischtime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    ADMISSIONS_table['discharge_location'] = [loc if pd.notnull(t) else None for loc, t in zip(ADMISSIONS_table["discharge_location"], ADMISSIONS_table["dischtime"])]

# process icustays
if timeshift:
    ICUSTAYS_table["intime"] = adjust_time(ICUSTAYS_table, "intime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    ICUSTAYS_table["outtime"] = adjust_time(ICUSTAYS_table, "outtime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    ICUSTAYS_table = ICUSTAYS_table.dropna(subset=["intime"])

# process transfers
if timeshift:
    TRANSFERS_table["intime"] = adjust_time(TRANSFERS_table, "intime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    TRANSFERS_table["outtime"] = adjust_time(TRANSFERS_table, "outtime", current_time=current_time, offset_dict=subjectid2admittime_dict, patient_col="subject_id")
    TRANSFERS_table = TRANSFERS_table.dropna(subset=["intime"])

################################################################################        
"""
Decide the final cohort of patients: `cur_patient_list` and `non_cur_patient`
"""
# sample current patients
try:
    cur_patient_list = rng.choice(
        ADMISSIONS_table["subject_id"][ADMISSIONS_table["dischtime"].isnull()].unique(),
        num_cur_patient,
        replace=False,
    ).tolist()
except:
    print("Cannot take a larger sample than population when 'replace=False")
    print("Use all available patients instead.")
    cur_patient_list = ADMISSIONS_table["subject_id"][ADMISSIONS_table["dischtime"].isnull()].unique().tolist()

# sample non-current patients
try:
    non_cur_patient = rng.choice(
        ADMISSIONS_table["subject_id"][(ADMISSIONS_table["dischtime"].notnull()) & (~ADMISSIONS_table["subject_id"].isin(cur_patient_list))].unique(),
        num_non_cur_patient,
        replace=False,
    ).tolist()
except:
    print("Cannot take a larger sample than population when 'replace=False")
    print("Use all available patients instead.")
    non_cur_patient = ADMISSIONS_table["subject_id"][(ADMISSIONS_table["dischtime"].notnull()) & (~ADMISSIONS_table["subject_id"].isin(cur_patient_list))].unique().tolist()

patient_list = cur_patient_list + non_cur_patient
print(f"num_cur_patient: {len(cur_patient_list)}")
print(f"num_non_cur_patient: {len(non_cur_patient)}")
print(f"num_patient: {len(patient_list)}")

PATIENTS_table = PATIENTS_table[PATIENTS_table["subject_id"].isin(patient_list)]
ADMISSIONS_table = ADMISSIONS_table[ADMISSIONS_table["subject_id"].isin(patient_list)]

hadm_list = list(set(ADMISSIONS_table["hadm_id"]))
ICUSTAYS_table = ICUSTAYS_table[ICUSTAYS_table["hadm_id"].isin(hadm_list)]
TRANSFERS_table = TRANSFERS_table[TRANSFERS_table["hadm_id"].isin(hadm_list)]

if deid:
    icu2careunit = {}
    icu2wardid = {}
    random_indices = rng.choice(len(ICUSTAYS_table), len(ICUSTAYS_table), replace=False).tolist()
    for idx, icu in enumerate(ICUSTAYS_table["icustay_id"]):
        icu2careunit[icu] = {}
        icu2careunit[icu][ICUSTAYS_table["first_careunit"][ICUSTAYS_table["icustay_id"] == icu].values[0]] = ICUSTAYS_table["first_careunit"].iloc[random_indices[idx]]
        icu2careunit[icu][ICUSTAYS_table["last_careunit"][ICUSTAYS_table["icustay_id"] == icu].values[0]] = ICUSTAYS_table["last_careunit"].iloc[random_indices[idx]]
        ICUSTAYS_table["first_careunit"][ICUSTAYS_table["icustay_id"] == icu] = ICUSTAYS_table["first_careunit"].iloc[random_indices[idx]]
        ICUSTAYS_table["last_careunit"][ICUSTAYS_table["icustay_id"] == icu] = ICUSTAYS_table["last_careunit"].iloc[random_indices[idx]]
        icu2wardid[icu] = {}
        icu2wardid[icu][ICUSTAYS_table["first_wardid"][ICUSTAYS_table["icustay_id"] == icu].values[0]] = ICUSTAYS_table["first_wardid"].iloc[random_indices[idx]]
        icu2wardid[icu][ICUSTAYS_table["last_wardid"][ICUSTAYS_table["icustay_id"] == icu].values[0]] = ICUSTAYS_table["last_wardid"].iloc[random_indices[idx]]
        ICUSTAYS_table["first_wardid"][ICUSTAYS_table["icustay_id"] == icu] = ICUSTAYS_table["first_wardid"].iloc[random_indices[idx]]
        ICUSTAYS_table["last_wardid"][ICUSTAYS_table["icustay_id"] == icu] = ICUSTAYS_table["last_wardid"].iloc[random_indices[idx]]

    for icu in ICUSTAYS_table["icustay_id"]:
        TRANSFERS_table["careunit"][TRANSFERS_table["icustay_id"] == icu] = TRANSFERS_table["careunit"][TRANSFERS_table["icustay_id"] == icu].replace(icu2careunit[icu])
        TRANSFERS_table["wardid"][TRANSFERS_table["icustay_id"] == icu] = TRANSFERS_table["wardid"][TRANSFERS_table["icustay_id"] == icu].replace(icu2wardid[icu])

PATIENTS_table["row_id"] = range(len(PATIENTS_table))
ADMISSIONS_table["row_id"] = range(len(ADMISSIONS_table))
ICUSTAYS_table["row_id"] = range(len(ICUSTAYS_table))
TRANSFERS_table["row_id"] = range(len(TRANSFERS_table))

PATIENTS_table.to_csv(os.path.join(out_dir, "PATIENTS.csv"), index=False)
ADMISSIONS_table.to_csv(os.path.join(out_dir, "ADMISSIONS.csv"), index=False)
ICUSTAYS_table.to_csv(os.path.join(out_dir, "ICUSTAYS.csv"), index=False)
TRANSFERS_table.to_csv(os.path.join(out_dir, "TRANSFERS.csv"), index=False)

print(f"PATIENTS, ADMISSIONS, ICUSTAYS, TRANSFERS processed (took {round(time.time() - start_time, 4)} secs)")

Processing PATIENTS, ADMISSIONS, ICUSTAYS, TRANSFERS
Cannot take a larger sample than population when 'replace=False
Use all available patients instead.
Cannot take a larger sample than population when 'replace=False
Use all available patients instead.
num_cur_patient: 1
num_non_cur_patient: 96
num_patient: 97
PATIENTS, ADMISSIONS, ICUSTAYS, TRANSFERS processed (took 0.6855 secs)


In [64]:
PATIENTS_table.at[0, 'dod'] = None
display(PATIENTS_table)

Unnamed: 0,row_id,subject_id,gender,dob,dod
0,0,10006,f,1953-04-09 00:00:00,
1,1,10011,f,1986-07-01 00:00:00,2022-09-22 00:00:00
2,2,10013,f,1936-09-28 00:00:00,2023-11-01 00:00:00
3,3,10017,f,1947-10-23 00:00:00,
4,4,10019,m,1972-07-24 00:00:00,2021-06-18 00:00:00
5,5,10026,f,1721-06-28 00:00:00,2022-01-05 00:00:00
6,6,10027,f,1938-02-25 00:00:00,2020-10-25 00:00:00
7,7,10029,m,1944-05-09 00:00:00,2023-10-20 00:00:00
8,8,10032,m,1932-04-27 00:00:00,2020-06-18 00:00:00
10,9,10035,m,1944-05-10 00:00:00,


In [66]:

conn_string = f'mysql+mysqlconnector://{user}:{password}@{host}/{database}'
engine = create_engine(conn_string, echo=True)
conn = engine.connect()

rows = read_csv(out_dir, "PATIENTS.csv")
rows.to_sql("PATIENTS", conn, if_exists="append", index=False)

2024-02-10 15:37:59,709 INFO sqlalchemy.engine.Engine SELECT DATABASE()
2024-02-10 15:37:59,710 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-02-10 15:37:59,861 INFO sqlalchemy.engine.Engine SELECT @@sql_mode
2024-02-10 15:37:59,862 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-02-10 15:37:59,970 INFO sqlalchemy.engine.Engine SELECT @@lower_case_table_names
2024-02-10 15:37:59,971 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-02-10 15:38:00,085 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-02-10 15:38:00,179 INFO sqlalchemy.engine.Engine DESCRIBE `test`.`PATIENTS`
2024-02-10 15:38:00,180 INFO sqlalchemy.engine.Engine [raw sql] {}
2024-02-10 15:38:00,255 INFO sqlalchemy.engine.Engine INSERT INTO `PATIENTS` (row_id, subject_id, gender, dob, dod) VALUES (%(row_id)s, %(subject_id)s, %(gender)s, %(dob)s, %(dod)s)
2024-02-10 15:38:00,255 INFO sqlalchemy.engine.Engine [generated in 0.02897s] [{'row_id': 0, 'subject_id': 10006, 'gender': 'f', 'dob': '1953-04-09 00:00:00', 'dod': No

DataError: (mysql.connector.errors.DataError) 1292 (22007): target: test.-.primary: vttablet: rpc error: code = InvalidArgument desc = Incorrect datetime value: '1953-04-09 00:00:00' for column 'DOB' at row 1 (errno 1292) (sqlstate 22007) (CallerID: g3k5rtdlq1vbv9x9rjz9): Sql: "insert into PATIENTS(row_id, subject_id, gender, dob, dod) values (:vtg1 /* INT64 */, :vtg2 /* INT64 */, :vtg3 /* VARCHAR */, :vtg4 /* VARCHAR */, null), (:vtg5 /* INT64 */, :vtg6 /* INT64 */, :vtg7 /* VARCHAR */, :vtg8 /* VARCHAR */, :vtg9 /* VARCHAR */), (:vtg10 /* INT64 */, :vtg11 /* INT6
[SQL: INSERT INTO `PATIENTS` (row_id, subject_id, gender, dob, dod) VALUES (%(row_id)s, %(subject_id)s, %(gender)s, %(dob)s, %(dod)s)]
[parameters: [{'row_id': 0, 'subject_id': 10006, 'gender': 'f', 'dob': '1953-04-09 00:00:00', 'dod': None}, {'row_id': 1, 'subject_id': 10011, 'gender': 'f', 'dob': '1986-07-01 00:00:00', 'dod': '2022-09-22 00:00:00'}, {'row_id': 2, 'subject_id': 10013, 'gender': 'f', 'dob': '1936-09-28 00:00:00', 'dod': '2023-11-01 00:00:00'}, {'row_id': 3, 'subject_id': 10017, 'gender': 'f', 'dob': '1947-10-23 00:00:00', 'dod': None}, {'row_id': 4, 'subject_id': 10019, 'gender': 'm', 'dob': '1972-07-24 00:00:00', 'dod': '2021-06-18 00:00:00'}, {'row_id': 5, 'subject_id': 10026, 'gender': 'f', 'dob': '1721-06-28 00:00:00', 'dod': '2022-01-05 00:00:00'}, {'row_id': 6, 'subject_id': 10027, 'gender': 'f', 'dob': '1938-02-25 00:00:00', 'dod': '2020-10-25 00:00:00'}, {'row_id': 7, 'subject_id': 10029, 'gender': 'm', 'dob': '1944-05-09 00:00:00', 'dod': '2023-10-20 00:00:00'}  ... displaying 10 of 97 total bound parameter sets ...  {'row_id': 95, 'subject_id': 44212, 'gender': 'f', 'dob': '1978-07-11 00:00:00', 'dod': None}, {'row_id': 96, 'subject_id': 44222, 'gender': 'm', 'dob': '1949-08-04 00:00:00', 'dod': None}]]
(Background on this error at: https://sqlalche.me/e/20/9h9h)