In [1]:
# Import packages
import pandas as pd
import numpy as np
import os
import pyspark
from pathlib import Path
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number, unix_timestamp, when, year, round
from tqdm import tqdm

# Make pandas dataframes prettier
from IPython.display import display, HTML

In [2]:
# Initialise PySpark context and session
sc = pyspark.SparkContext(appName="appName")
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("appName").getOrCreate()

23/08/22 20:23:26 WARN Utils: Your hostname, Christians-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.10.107 instead (on interface en0)
23/08/22 20:23:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/08/22 20:23:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Retrieve relevant tables

In [3]:
# Load dataset using PySpark

mimic_dir = "../../data/mimic-iii/"

adm_df = spark.read.option("header", True).option("inferSchema", True).csv(
    mimic_dir + "ADMISSIONS.csv")
pat_df = spark.read.option("header", True).option("inferSchema", True).csv(
    mimic_dir + "PATIENTS.csv")

In [4]:
# List down relevant variables in each table
adm_cols = ["SUBJECT_ID", "HADM_ID", "INSURANCE", "MARITAL_STATUS", "ETHNICITY", 
            "HOSPITAL_EXPIRE_FLAG", "ADMISSION_TYPE", "ADMITTIME", "DISCHTIME", 
            "EDREGTIME", "EDOUTTIME"]
pat_cols = ["SUBJECT_ID", "GENDER", "DOB"]

## Data processing

### Step 1: Identify unique patients with emergency admissions

In [5]:
# Filter patients from emergency setting
emergency_adm_df = adm_df.filter(adm_df.ADMISSION_TYPE == "EMERGENCY").select(*adm_cols)
print("Number of emergency admissions = {}".format(emergency_adm_df.count()))

Number of emergency admissions = 42071


In [6]:
# Identify first admissions for all patients
windowSubj = Window.partitionBy("SUBJECT_ID").orderBy(col("ADMITTIME"))
first_adm_df = emergency_adm_df.withColumn("row", row_number().over(windowSubj)).filter(
    col("row") == 1).drop("row")
print("Number of unique patients = {}".format(first_adm_df.count()))

Number of unique patients = 32610


In [7]:
# Change column names for index admission info
first_adm_df = first_adm_df.withColumnRenamed("ADMITTIME", "INDEX_ADMITTIME")
first_adm_df = first_adm_df.withColumnRenamed("DISCHTIME", "INDEX_DISCHTIME")
first_adm_df = first_adm_df.withColumnRenamed("HADM_ID", "INDEX_HADM_ID")

In [8]:
# Remove patients who died in hospital
first_adm_df = first_adm_df.filter(first_adm_df.HOSPITAL_EXPIRE_FLAG != 1)
print("Number of patients who survived index admission = {}".format(first_adm_df.count()))

Number of patients who survived index admission = 28404


### Step 2: Identify patients with 90-day emergency readmissions

In [9]:
# Join with the bigger table to identify readmissions
first_adm_df_joined = first_adm_df.select(
    "SUBJECT_ID", "INDEX_HADM_ID", "INDEX_DISCHTIME").join(
    emergency_adm_df.select("SUBJECT_ID", "HADM_ID", "ADMITTIME"), 
    first_adm_df.SUBJECT_ID == emergency_adm_df.SUBJECT_ID
)

# Remove index admission records from consideration
first_adm_df_joined = first_adm_df_joined.filter(col("INDEX_HADM_ID") != col("HADM_ID"))

first_adm_df_joined.count()

9430

In [10]:
# Calculate time difference and identify 90-day readmission
first_adm_df_joined = first_adm_df_joined.withColumn(
    "DAYS_FROM_INDEX", 
    (unix_timestamp("ADMITTIME") - unix_timestamp("INDEX_DISCHTIME"))/(3600*24)
)

first_adm_df_joined = first_adm_df_joined.withColumn(
    "90DAYREADM", col("DAYS_FROM_INDEX") <= 90
)

In [11]:
print("Number of readmissions within 90 days = {}".format(
    first_adm_df_joined.filter(first_adm_df_joined["90DAYREADM"] == True).count()))

Number of readmissions within 90 days = 2615


In [12]:
# Join the readmission flag back to the first admission table

first_adm_df_joined = first_adm_df_joined.withColumnRenamed("INDEX_HADM_ID", "INDEX_HADM_ID2")
first_adm_df = first_adm_df.join(
    first_adm_df_joined.select("INDEX_HADM_ID2", "90DAYREADM").filter(
        first_adm_df_joined["90DAYREADM"] == True).dropDuplicates(), 
    first_adm_df.INDEX_HADM_ID == first_adm_df_joined.INDEX_HADM_ID2,
    "left"
)

In [13]:
print("Sanity check: Number of unique patients = {}".format(first_adm_df.count()))
print("Number of patients with 90-day readmissions = {}".format(
    first_adm_df.filter(first_adm_df["90DAYREADM"] == True).count()))

Sanity check: Number of unique patients = 28404
Number of patients with 90-day readmissions = 2259


In [14]:
# Drop redundant columns and fill missing readmission flags
first_adm_df = first_adm_df.drop("INDEX_HADM_ID2", "HOSPITAL_EXPIRE_FLAG", "ADMISSION_TYPE")
first_adm_df = first_adm_df.fillna({"90DAYREADM": False})

In [15]:
# View output table
first_adm_df.show(10)

23/08/22 20:23:39 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
+----------+-------------+---------+--------------+--------------------+-------------------+-------------------+-------------------+-------------------+----------+
|SUBJECT_ID|INDEX_HADM_ID|INSURANCE|MARITAL_STATUS|           ETHNICITY|    INDEX_ADMITTIME|    INDEX_DISCHTIME|          EDREGTIME|          EDOUTTIME|90DAYREADM|
+----------+-------------+---------+--------------+--------------------+-------------------+-------------------+-------------------+-------------------+----------+
|        37|       188670| Medicare|       MARRIED|               WHITE|2183-08-21 16:48:00|2183-08-26 18:54:00|2183-08-21 05:58:00|2183-08-21 17:45:00|     false|
|        43|       146828|  Private|       MARRIED|               WHITE|218

### Step 3: Merge gender and DOB data

In [16]:
# Join with patients table. Rename to allow easier removal of variables later
# as otherwise PySpark keeps a duplicate
pat_df2 = pat_df.select(*pat_cols).withColumnRenamed("SUBJECT_ID", "SUBJECT_ID2")
combined_df = first_adm_df.join(
    pat_df2, first_adm_df.SUBJECT_ID == pat_df2.SUBJECT_ID2, "left"
)

# Remove redundant variables
combined_df = combined_df.drop("SUBJECT_ID2")

print("Sanity check: Number of unique patients = {}".format(combined_df.count()))

Sanity check: Number of unique patients = 28404


### Step 4: Create derived variables

#### Length of stay (LOS)

In [17]:
# Calculate length of stay
combined_df = combined_df.withColumn(
    "LOS", 
    round((unix_timestamp("INDEX_DISCHTIME") - unix_timestamp("INDEX_ADMITTIME"))/(3600*24) + 1)
)

In [18]:
# Observe summary statistics to assess abnormalities
combined_df.describe("LOS").show()

+-------+------------------+
|summary|               LOS|
+-------+------------------+
|  count|             28404|
|   mean|11.181277284889452|
| stddev| 10.52733385349223|
|    min|               1.0|
|    max|             204.0|
+-------+------------------+



#### Emergency department (ED) visit of more than 6 hours

In [19]:
# Calculate time at ED if present
combined_df = combined_df.withColumn(
    "TIME_AT_ED",
    when(col("EDREGTIME").isNotNull(), 
         (unix_timestamp("EDOUTTIME") - unix_timestamp("EDREGTIME"))/3600)
)
#combined_df = combined_df.fillna({"TIME_AT_ED": 0})

In [20]:
# Create flag of ED visit of more than 6 hours
combined_df = combined_df.withColumn("ED_6HRS", col("TIME_AT_ED") >= 6)

In [21]:
print("Number of patients with no ED visit =", combined_df.filter(col("ED_6HRS").isNull()).count())

Number of patients with no ED visit = 8595


In [22]:
print("Number of patients with ED visit longer than 6 hrs =", 
      combined_df.filter(col("ED_6HRS") == True).count())

Number of patients with ED visit longer than 6 hrs = 6737


#### Age at index admission

In [23]:
# Calculate age
combined_df = combined_df.withColumn("AGE", (year("INDEX_ADMITTIME") - year("DOB")))

In [24]:
# Observe summary statistics to assess abnormalities
combined_df.describe("AGE").show()

+-------+-----------------+
|summary|              AGE|
+-------+-----------------+
|  count|            28404|
|   mean|73.08717082101113|
| stddev|55.94813102926004|
|    min|                0|
|    max|              310|
+-------+-----------------+



In [25]:
print("Number of patients with recorded age above 99 = {}".format(
    combined_df.filter(combined_df.AGE > 99).count()))

Number of patients with recorded age above 99 = 1470


In [26]:
print("Number of patients with recorded age of 0 = {}".format(
    combined_df.filter(combined_df.AGE == 0).count()))

Number of patients with recorded age of 0 = 202


In [27]:
print("Number of patients with recorded (non-zero) age below 18 = {}".format(
    combined_df.filter((combined_df.AGE > 0) & (combined_df.AGE < 18)).count()))

Number of patients with recorded (non-zero) age below 18 = 71


#### Final step: cleanup

In [28]:
# Remove redundant variables
combined_df = combined_df.drop("EDREGTIME", "EDOUTTIME", "TIME_AT_ED", 
                               "INDEX_DISCHTIME", "INDEX_ADMITTIME", "DOB")

print("Sanity check: Number of unique patients = {}".format(combined_df.count()))

Sanity check: Number of unique patients = 28404


In [29]:
# View first few columns
combined_df.show(10)

+----------+-------------+---------+--------------+--------------------+----------+------+----+-------+---+
|SUBJECT_ID|INDEX_HADM_ID|INSURANCE|MARITAL_STATUS|           ETHNICITY|90DAYREADM|GENDER| LOS|ED_6HRS|AGE|
+----------+-------------+---------+--------------+--------------------+----------+------+----+-------+---+
|        37|       188670| Medicare|       MARRIED|               WHITE|     false|     M| 6.0|   true| 69|
|        43|       146828|  Private|       MARRIED|               WHITE|     false|     M|11.0|  false| 33|
|        26|       197661| Medicare|        SINGLE|UNKNOWN/NOT SPECI...|     false|     M| 8.0|   null| 72|
|        34|       115799| Medicare|       MARRIED|               WHITE|     false|     M| 3.0|   null|300|
|        13|       143045| Medicaid|          null|               WHITE|     false|     F| 8.0|   null| 40|
|        52|       190797|  Private|        SINGLE|               WHITE|     false|     M|11.0|   null| 39|
|        22|       165315|  

### Step 5: Standardise recording of missing values

In [30]:
# Standardise recording of missing values
combined_df = combined_df.withColumn(
    "MARITAL_STATUS_CLEAN", 
    when(~col("MARITAL_STATUS").contains("UNKNOWN"), col("MARITAL_STATUS"))
)

In [31]:
# Convert abnormal age to missing values
combined_df = combined_df.withColumn(
    "AGE_CLEAN",
    when((~(col("AGE") == 0)) & (~(col("AGE") > 99)), col("AGE"))
)

# Observe summary statistics to assess abnormalities
combined_df.describe("AGE_CLEAN").show()

+-------+-----------------+
|summary|        AGE_CLEAN|
+-------+-----------------+
|  count|            26732|
|   mean|61.15894807721083|
| stddev|17.62539593688143|
|    min|                1|
|    max|               89|
+-------+-----------------+



In [32]:
# Clean up column names
combined_df = combined_df.drop("MARITAL_STATUS", "AGE")
combined_df = combined_df.withColumnRenamed("MARITAL_STATUS_CLEAN", "MARITAL_STATUS")
combined_df = combined_df.withColumnRenamed("AGE_CLEAN", "AGE")

### Step 6: Remove patients below 18 years old

In [33]:
# Filter non-adult patients
combined_df = combined_df.filter((combined_df.AGE >= 18) | (col("AGE").isNull()))
print("Number of patients after excluding paediatrics = {}".format(combined_df.count()))

Number of patients after excluding paediatrics = 28333


### Step 7: Conversion to Pandas

In [34]:
# Convert to Pandas DataFrame
final_df = combined_df.toPandas()

In [35]:
# Initial check on number of missing values
final_df.isna().sum()

SUBJECT_ID           0
INDEX_HADM_ID        0
INSURANCE            0
ETHNICITY            0
90DAYREADM           0
GENDER               0
LOS                  0
ED_6HRS           8578
MARITAL_STATUS    1882
AGE               1672
dtype: int64

### Step XX: Standardise recording of demographic variables

#### Ethnicity

In [36]:
# Look at unique values
final_df["ETHNICITY"].value_counts(dropna=False)

WHITE                                                       20019
UNKNOWN/NOT SPECIFIED                                        2343
BLACK/AFRICAN AMERICAN                                       2309
HISPANIC OR LATINO                                            799
OTHER                                                         680
UNABLE TO OBTAIN                                              521
ASIAN                                                         475
PATIENT DECLINED TO ANSWER                                    218
ASIAN - CHINESE                                               134
HISPANIC/LATINO - PUERTO RICAN                                111
BLACK/CAPE VERDEAN                                            108
WHITE - RUSSIAN                                                80
MULTI RACE ETHNICITY                                           65
BLACK/HAITIAN                                                  60
HISPANIC/LATINO - DOMINICAN                                    48
ASIAN - AS

In [37]:
# Remap ethnicity
def _map_ethnicity(x):
    if "ASIAN" in x:
        return "ASIAN"
    elif x in ["PATIENT DECLINED TO ANSWER", "UNKNOWN/NOT SPECIFIED", "UNABLE TO OBTAIN"]:
        return np.NaN
    elif "WHITE" in x:
        return "WHITE"
    elif "BLACK" in x:
        return "BLACK"
    else:
        return "OTHERS"

final_df["ETHNICITY"] = final_df["ETHNICITY"].apply(_map_ethnicity)

final_df["ETHNICITY"].value_counts(dropna=False)

WHITE     20181
NaN        3082
BLACK      2500
OTHERS     1858
ASIAN       712
Name: ETHNICITY, dtype: int64

#### Marital Status

In [38]:
# Look at unique values
final_df["MARITAL_STATUS"].value_counts(dropna=False)

MARRIED         12912
SINGLE           7552
WIDOWED          3948
None             1882
DIVORCED         1736
SEPARATED         293
LIFE PARTNER       10
Name: MARITAL_STATUS, dtype: int64

In [39]:
# Remap marital status
def _map_marital_status(x):
    if pd.isnull(x):
        return np.NaN
    elif x in ["MARRIED", "LIFE PARTNER"]:
        return "PARTNERED"
    elif x in ["WIDOWED", "SEPARATED", "DIVORCED"]:
        return "WIDOWED OR SEPARATED"
    else:
        return x

final_df["MARITAL_STATUS"] = final_df["MARITAL_STATUS"].apply(_map_marital_status)

final_df["MARITAL_STATUS"].value_counts(dropna=False)

PARTNERED               12922
SINGLE                   7552
WIDOWED OR SEPARATED     5977
NaN                      1882
Name: MARITAL_STATUS, dtype: int64

#### Insurance

In [40]:
# Look at unique values
final_df["INSURANCE"].value_counts(dropna=False)

Medicare      14602
Private        9692
Medicaid       2620
Government      983
Self Pay        436
Name: INSURANCE, dtype: int64

In [41]:
# Remap insurance
final_df["INSURANCE"] = final_df["INSURANCE"].map({
    "Medicare" : "Government",
    "Medicaid" : "Government",
    "Government" : "Government",
    "Private" : "Private or Self Pay",
    "Self Pay" : "Private or Self Pay"
})
final_df["INSURANCE"].value_counts(dropna=False)

Government             18205
Private or Self Pay    10128
Name: INSURANCE, dtype: int64

### Step 8: Final cleanup and save full table

In [42]:
# Remove columns related to unique IDs
final_df = final_df.drop(["SUBJECT_ID", "INDEX_HADM_ID"], axis=1)

In [43]:
# Reencode boolean as integers
final_df["90DAYREADM"] = final_df["90DAYREADM"].astype(int)
final_df["ED_6HRS"] = final_df["ED_6HRS"].astype(float) # since there are NAs

In [44]:
# View final table
final_df.head(10)

Unnamed: 0,INSURANCE,ETHNICITY,90DAYREADM,GENDER,LOS,ED_6HRS,MARITAL_STATUS,AGE
0,Government,WHITE,0,F,5.0,0.0,WIDOWED OR SEPARATED,85.0
1,Government,BLACK,0,M,3.0,1.0,WIDOWED OR SEPARATED,63.0
2,Private or Self Pay,WHITE,0,F,31.0,0.0,SINGLE,21.0
3,Government,WHITE,1,F,9.0,1.0,PARTNERED,75.0
4,Private or Self Pay,WHITE,0,F,4.0,,PARTNERED,63.0
5,Government,WHITE,0,M,10.0,0.0,WIDOWED OR SEPARATED,86.0
6,Government,WHITE,0,M,3.0,0.0,WIDOWED OR SEPARATED,
7,Private or Self Pay,,0,M,3.0,,,
8,Government,WHITE,0,F,6.0,0.0,PARTNERED,80.0
9,Government,WHITE,0,F,10.0,0.0,WIDOWED OR SEPARATED,44.0


In [45]:
# Assess missingness
final_df.isna().sum()

INSURANCE            0
ETHNICITY         3082
90DAYREADM           0
GENDER               0
LOS                  0
ED_6HRS           8578
MARITAL_STATUS    1882
AGE               1672
dtype: int64

In [46]:
# Save final table as CSV
final_df.to_csv(mimic_dir + "processed.csv", index=False)

## Create random samples

In [47]:
# Sample size to be constructed
Ns = [20000, 10000, 5000, 2000, 1000]

# Set random seed
SEED = 2023

# Construct random samples
for n in tqdm(Ns):
    sampled_df = final_df.sample(n, replace=False, random_state=SEED)
    sampled_df.to_csv(mimic_dir + "processed{}.csv".format(n), index=False)

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 26.13it/s]
