In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as Func

In [2]:
spark = SparkSession.builder.appName("ModularCSVLoader").getOrCreate()

# Base URL for all the CSV files
base_url = "datasets/"

# List of file names to process
file_names = [
"allergies.csv",
"encounters.csv",
"medications.csv",
"patients.csv",
"procedures.csv"
]

# Dictionaries to store the resulting DataFrames for further processing/joining.
# pandas_dfs = {}
spark_dfs = {}

# Process each file and store the DataFrames into the dictionaries
for file_name in file_names:
    name_key = file_name.replace('.csv', '')
    file_url = f"{base_url}{file_name}"
    print(f"Processing file: {file_name}")
    spark_df = spark.read.csv(file_url, header=True, inferSchema=True)
    # pandas_dfs[name_key] = pd_df
    spark_dfs[name_key] = spark_df

25/04/11 17:49:02 WARN Utils: Your hostname, debian-shed resolves to a loopback address: 127.0.1.1; using 192.168.1.11 instead (on interface enp3s0)
25/04/11 17:49:02 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/11 17:49:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Processing file: allergies.csv
Processing file: encounters.csv


                                                                                

Processing file: medications.csv


                                                                                

Processing file: patients.csv
Processing file: procedures.csv


                                                                                

In [3]:
from pyspark.sql import functions as F

enc_df = spark_dfs['encounters'].filter(
    (spark_dfs['encounters'].REASONCODE == '55680006') &
    (spark_dfs['encounters'].START > F.lit("1999-07-15 00:00:00").cast("timestamp"))
)

enc_df.show(5)


+--------------------+-------------------+-------------------+--------------------+--------------------+--------------+--------+--------------------+------+----------+-----------------+
|                  Id|              START|               STOP|             PATIENT|            PROVIDER|ENCOUNTERCLASS|    CODE|         DESCRIPTION|  COST|REASONCODE|REASONDESCRIPTION|
+--------------------+-------------------+-------------------+--------------------+--------------------+--------------+--------+--------------------+------+----------+-----------------+
|2a917920-2701-49f...|2003-03-31 21:50:51|2003-04-08 13:20:43|708b81c9-21a9-411...|fb37c581-84a6-351...|     emergency|50849002|Emergency Room Ad...|105.37|  55680006|    Drug overdose|
|22874b3d-0873-40e...|2012-02-18 21:50:51|2012-02-28 21:12:17|708b81c9-21a9-411...|fb37c581-84a6-351...|     emergency|50849002|Emergency Room Ad...|105.37|  55680006|    Drug overdose|
|134c5ee3-1b72-4e3...|2013-08-03 21:50:51|2013-08-13 07:44:52|708b81c9

In [4]:
patient_df = spark_dfs['patients'].filter(
    (spark_dfs['patients'].BIRTHDATE.isNotNull())
)

patient_df.show(5)

+--------------------+----------+---------+-----------+---------+----------+------+-----------+----------+------+------------+-------+-----+------------+------+--------------------+--------------------+-------------+------------+-----+
|                  Id| BIRTHDATE|DEATHDATE|        SSN|  DRIVERS|  PASSPORT|PREFIX|      FIRST|      LAST|SUFFIX|      MAIDEN|MARITAL| RACE|   ETHNICITY|GENDER|          BIRTHPLACE|             ADDRESS|         CITY|       STATE|  ZIP|
+--------------------+----------+---------+-----------+---------+----------+------+-----------+----------+------+------------+-------+-----+------------+------+--------------------+--------------------+-------------+------------+-----+
|3d8e57b2-3de5-4fb...|1943-03-11|       NA|999-86-7250|S99939389| X3970685X|  Mrs.|   Allyn942|Kreiger457|    NA|Bartoletti50|      M|asian|asian_indian|     F|Muhlenberg  Penns...|372 Marks Heights...|Middle Paxton|Pennsylvania|   NA|
|7f4ea9fb-f436-411...|1980-09-28|       NA|999-90-4314|S

In [5]:
from pyspark.sql import functions as F

# Alias the DataFrames for clarity
pat = patient_df.alias("pat")
enc = enc_df.alias("enc")

# Join on patient Id. Use proper aliases when referring to columns.
joined_df = pat.join(enc, pat["Id"] == enc["PATIENT"], "inner")

# Select the columns of interest from each DataFrame and rename encounter columns to avoid ambiguity.
joined_df = joined_df.select(
    pat["Id"].alias("pat_Id"), enc['id'].alias("enc_Id"),
    pat["BIRTHDATE"], pat['DEATHDATE'],
    enc["PATIENT"].alias("enc_PATIENT_id"),
    enc["START"].alias("enc_START"),
    enc["STOP"].alias("enc_STOP")
)

# Calculate the patient's age at the time of the encounter.
# Assuming "BIRTHDATE" is in a format recognized by Spark for date conversion.
joined_df = joined_df.withColumn(
    "AGE_AT_VISIT",
    F.floor(F.datediff(F.col("enc_START").cast("date"), F.col("BIRTHDATE").cast("date")) / 365)
)

# Filter the records where age is between 18 and 35
cohort_df = joined_df.filter((F.col("AGE_AT_VISIT") >= 18) & (F.col("AGE_AT_VISIT") <= 35))

cohort_df.show(5)


+--------------------+--------------------+----------+---------+--------------------+-------------------+-------------------+------------+
|              pat_Id|              enc_Id| BIRTHDATE|DEATHDATE|      enc_PATIENT_id|          enc_START|           enc_STOP|AGE_AT_VISIT|
+--------------------+--------------------+----------+---------+--------------------+-------------------+-------------------+------------+
|708b81c9-21a9-411...|22874b3d-0873-40e...|1986-10-28|       NA|708b81c9-21a9-411...|2012-02-18 21:50:51|2012-02-28 21:12:17|          25|
|708b81c9-21a9-411...|134c5ee3-1b72-4e3...|1986-10-28|       NA|708b81c9-21a9-411...|2013-08-03 21:50:51|2013-08-13 07:44:52|          26|
|708b81c9-21a9-411...|6125f147-72d4-48a...|1986-10-28|       NA|708b81c9-21a9-411...|2014-12-08 21:50:51|2014-12-17 12:25:27|          28|
|708b81c9-21a9-411...|f837dcf8-af7d-43b...|1986-10-28|       NA|708b81c9-21a9-411...|2015-08-31 21:50:51|2015-09-08 12:04:08|          28|
|65b093e4-b353-447...|01059

In [6]:

cohort_df = cohort_df.withColumn(
    "DEATH_AT_VISIT_IND",
    Func.when(
        # Check that DEATHDATE is not "NA" and falls between the START and STOP dates:
        (Func.col("DEATHDATE") != "NA") &
        (Func.to_timestamp(Func.col("DEATHDATE").cast("timestamp"), "yyyy-MM-dd HH:mm:ss").between(Func.col('enc_START').cast("timestamp"), Func.col('enc_STOP').cast("timestamp"))),
        1
    ).otherwise(0)
)

# For debugging or previewing results:
cohort_df.select("DEATHDATE", "DEATH_AT_VISIT_IND").show(5, truncate=False)


+---------+------------------+
|DEATHDATE|DEATH_AT_VISIT_IND|
+---------+------------------+
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
+---------+------------------+
only showing top 5 rows



In [7]:
# Filter records where DEATH_AT_VISIT_IND = 1
test_cohort_df = cohort_df.filter(Func.col("DEATH_AT_VISIT_IND") == 1)

# Show the filtered records
test_cohort_df.show(5, truncate=False)

+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+
|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|
+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+
|dacea80b-75dd-42d6-a5c0-be369c3e4ebf|1deb51e9-ef0b-4013-b46f-f8efcd836842|1979-12-11|2009-08-24|dacea80b-75dd-42d6-a5c0-be369c3e4ebf|2009-08-20 01:30:34|2009-08-24 21:37:59|29          |1                 |
|978114f9-f9f2-4361-b64d-8045ab8f1602|41c9fbf6-a7d9-48a1-93fc-9fa32d554083|1976-01-07|2009-09-01|978114f9-f9f2-4361-b64d-8045ab8f1602|2009-08-22 16:34:32|2009-09-01 17:22:5

In [8]:
# COUNT_CURRENT_MEDS: Count of active medications at the start of the drug overdose encounter

from pyspark.sql import functions as F

# Alias the DataFrames for clarity
med = spark_dfs['medications'].alias("m")
cohort = cohort_df.alias("c")

# Join the DataFrames using their aliases for clarity in the join condition
med_df = med.join(cohort, F.col("c.pat_id") == F.col("m.PATIENT"), "inner") \
    .filter(
        F.col("m.START").cast("timestamp") >= F.col("c.enc_START").cast("timestamp")
    )

# Now use the actual column names for renaming; after join, med.PATIENT will appear as "PATIENT"
med_df = med_df.withColumnRenamed("PATIENT", "med_PATIENT") \
    .withColumnRenamed("START", "med_START") \
    .withColumnRenamed("STOP", "med_STOP") 

grouped_med = med_df.groupBy(
    "CODE", "ENCOUNTER", "med_PATIENT"
).agg(F.count("*").alias("med_cnt"))

grouped_med = grouped_med.groupBy("med_PATIENT", "ENCOUNTER").agg(
    F.sum("med_cnt").alias("COUNT_CURRENT_MEDS")
)

med_df = med_df.join(grouped_med,
    (med_df["med_PATIENT"] == grouped_med["med_PATIENT"]) &
    (med_df["ENCOUNTER"] == grouped_med["ENCOUNTER"]),
    "inner"
).select(
    med_df["*"],
    grouped_med["COUNT_CURRENT_MEDS"]
)

med_df.show(5, truncate=False)


                                                                                

+----------+----------+------------------------------------+------------------------------------+-------+------------------------------------------------------------+------+---------+---------+----------+-----------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE   |DESCRIPTION                                                 |COST  |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|
+----------+----------+------------------------------------+------------------------------------+-------+-

In [9]:
# Calcuale and add CURRENT_OPIOID_IND	
# if the patient had at least one active medication at the start of the overdose encounter that is on the Opioids List (provided below)

# GET CODES FOR  Opioids List:
# Hydromorphone 325Mg
# Fentanyl – 100 MCG
# Oxycodone-acetaminophen 100 Ml

from pyspark.sql import functions as F

patterns = [
    "(?i)^Hydromorphone 325", 
    "(?i)^Fentanyl", 
    "(?i)^Oxycodone-acetaminophen 100"
]

# Build a filter condition by OR-ing each pattern on the DESCRIPTION column
filter_condition = None
for pattern in patterns:
    cond = F.col("DESCRIPTION").rlike(pattern)
    filter_condition = cond if filter_condition is None else filter_condition | cond

# Filter med_df using the combined condition and return only the distinct CODE column.
# Then extract the CODE values as a Python list.
opioid_codes_list = [row["CODE"] for row in med_df.filter(filter_condition).select("CODE").distinct().collect()]

# Print the resulting list of opioid codes
print(opioid_codes_list)

# Add the CURRENT_OPIOID_IND column: 1 if med_df.CODE is in restricted_codes_list, else 0.
cohort_df = med_df.withColumn(
    "CURRENT_OPIOID_IND",
    F.when(F.col("CODE").isin(*opioid_codes_list), F.lit(1)).otherwise(F.lit(0))
)

# Show the results
cohort_df.show(5, truncate=False)

                                                                                

[316049, 429503]


                                                                                

+----------+----------+------------------------------------+------------------------------------+-------+------------------------------------------------------------+------+---------+---------+----------+-----------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE   |DESCRIPTION                                                 |COST  |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOID_IND|
+----------+----------+------------------------------------+--------

In [10]:
# Filter records where DEATH_AT_VISIT_IND = 1
test_cohort_df = cohort_df.filter(Func.col("CURRENT_OPIOID_IND") == 1)

# Show the filtered records
test_cohort_df.show(5, truncate=False)

                                                                                

+----------+----------+------------------------------------+------------------------------------+------+--------------------+------+---------+---------+----------+-----------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE  |DESCRIPTION         |COST  |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOID_IND|
+----------+----------+------------------------------------+------------------------------------+------+--------------------+------+---------+--------

In [11]:
cohort_df.printSchema()

root
 |-- med_START: date (nullable = true)
 |-- med_STOP: string (nullable = true)
 |-- med_PATIENT: string (nullable = true)
 |-- ENCOUNTER: string (nullable = true)
 |-- CODE: integer (nullable = true)
 |-- DESCRIPTION: string (nullable = true)
 |-- COST: double (nullable = true)
 |-- DISPENSES: integer (nullable = true)
 |-- TOTALCOST: double (nullable = true)
 |-- REASONCODE: string (nullable = true)
 |-- REASONDESCRIPTION: string (nullable = true)
 |-- pat_Id: string (nullable = true)
 |-- enc_Id: string (nullable = true)
 |-- BIRTHDATE: date (nullable = true)
 |-- DEATHDATE: string (nullable = true)
 |-- enc_PATIENT_id: string (nullable = true)
 |-- enc_START: timestamp (nullable = true)
 |-- enc_STOP: timestamp (nullable = true)
 |-- AGE_AT_VISIT: long (nullable = true)
 |-- DEATH_AT_VISIT_IND: integer (nullable = false)
 |-- COUNT_CURRENT_MEDS: long (nullable = true)
 |-- CURRENT_OPIOID_IND: integer (nullable = false)



In [12]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Assume cohort_df has at least the following columns: PATIENT and enc_STOP.
# Define a window partitioned by PATIENT and ordered by enc_STOP.
patient_window = Window.partitionBy("pat_Id").orderBy(F.col("enc_STOP").cast("timestamp"))

# Get the next encounter's start date using lead().
df_with_next = cohort_df.withColumn("next_enc_START", F.lead("enc_START").over(patient_window))

# Calculate the difference in days between the current encounter and the next encounter.
df_with_diff = df_with_next.withColumn("diff_days", F.datediff(F.col("next_enc_START"), F.col("enc_START")))

# Create the READMISSION_90_DAY_IND indicator: flag as 1 if the next encounter is within 90 days, else 0.
df_with_indicator = df_with_diff.withColumn(
    "READMISSION_90_DAY_IND",
    F.when((F.col("diff_days").isNotNull()) & \
           (F.col("diff_days") != 0) & \
            (F.col("diff_days") <= 90), F.lit(1)).otherwise(F.lit(0))
)

df_with_indicator = df_with_indicator.withColumn(
    "READMISSION_30_DAY_IND",
    F.when((F.col("diff_days").isNotNull()) & \
           (F.col("diff_days") != 0) & \
            (F.col("diff_days") <= 30), F.lit(1)).otherwise(F.lit(0))
)
# df_with_indicator.show(5, truncate=False)
df_with_indicator.filter(F.col("diff_days").between(1, 89)).show(5, truncate=False)
df_with_indicator.filter(F.col("READMISSION_30_DAY_IND") == 1).show(5, truncate=False)

# 4. Update next_enc_START to "N/A" when diff_days equals 0.
df_with_indicator = df_with_indicator.withColumn(
    "next_enc_START",
    F.when((F.col("diff_days") <= 90) & (F.col("diff_days") != 0), F.col("next_enc_START")).otherwise(F.lit("N/A"))
)
# cohort_df = df_with_indicator.drop( "diff_days")
df_with_indicator = df_with_indicator.withColumnRenamed('next_enc_START', 'FIRST_READMISSION_DATE')
df_with_indicator.show(5, truncate=False)


25/04/11 17:49:37 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+----------+----------+------------------------------------+------------------------------------+------+-----------------------------------------+-----+---------+---------+----------+------------------------------------+------------------------------------+------------------------------------+----------+---------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+-------------------+---------+----------------------+----------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE  |DESCRIPTION                              |COST |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION                   |pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE|enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOID_IND

                                                                                

+----------+----------+------------------------------------+------------------------------------+------+---------------------------------------------------+-----+---------+---------+----------+--------------------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+-------------------+---------+----------------------+----------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE  |DESCRIPTION                                        |COST |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION         |pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOID_I

                                                                                

+----------+----------+------------------------------------+------------------------------------+-------+-----------------------------------------------------------------------------------------------------------------+------+---------+---------+----------+---------------------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+----------------------+---------+----------------------+----------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE   |DESCRIPTION                                                                                                      |COST  |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION          |pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT

In [13]:
df_with_indicator.filter(F.col("diff_days").between(1, 89)).show(5, truncate=False)
df_with_indicator.filter(F.col("READMISSION_30_DAY_IND") == 1).show(5, truncate=False)

                                                                                

+----------+----------+------------------------------------+------------------------------------+------+-----------------------------------------+-----+---------+---------+----------+------------------------------------+------------------------------------+------------------------------------+----------+---------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+----------------------+---------+----------------------+----------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE  |DESCRIPTION                              |COST |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION                   |pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE|enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOID_

                                                                                

+----------+----------+------------------------------------+------------------------------------+------+---------------------------------------------------+-----+---------+---------+----------+--------------------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+------------+------------------+------------------+------------------+----------------------+---------+----------------------+----------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE  |DESCRIPTION                                        |COST |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION         |pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |AGE_AT_VISIT|DEATH_AT_VISIT_IND|COUNT_CURRENT_MEDS|CURRENT_OPIOI