In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as Func

In [None]:

# Start (or reuse) the Spark session for this notebook.
spark = SparkSession.builder.appName("ModularCSVLoader").getOrCreate()

# Directory prefix that every CSV file name is appended to.
base_url = "datasets/"

# CSV extracts to load. Each one is stored in spark_dfs under its file name minus ".csv"
# (e.g. "patients.csv" -> spark_dfs["patients"]) for the later filtering/joining cells.
file_names = [
    "allergies.csv",
    "encounters.csv",
    "medications.csv",
    "patients.csv",
    "procedures.csv",
]

spark_dfs = {}

for file_name in file_names:
    print(f"Processing file: {file_name}")
    # header=True takes column names from the first row; inferSchema=True lets Spark pick
    # column types from the data instead of reading every column as a string.
    spark_dfs[file_name.replace('.csv', '')] = spark.read.csv(
        base_url + file_name, header=True, inferSchema=True
    )

In [None]:
# Part 1: Assemble the project cohort
# Keep only encounters with REASONCODE 55680006 whose START is strictly after 1999-07-15.

study_start_ts = Func.lit("1999-07-15 00:00:00").cast("timestamp")

has_overdose_reason = Func.col("REASONCODE") == '55680006'
started_after_cutoff = Func.col("START") > study_start_ts

enc_df = spark_dfs['encounters'].where(has_overdose_reason & started_after_cutoff)


In [None]:
# Patients usable for the cohort: a BIRTHDATE is required to compute AGE_AT_VISIT.
patient_df = spark_dfs['patients'].where(Func.col("BIRTHDATE").isNotNull())

In [None]:
""" Part 1: Assemble the project cohort
The cohort is defined as patients aged 18-35 at the time of the encounter with a specific reason code.
"""

from pyspark.sql import functions as F

# Alias the DataFrames for clarity
pat = patient_df.alias("pat")
enc = enc_df.alias("enc")

# Join on patient Id. Use proper aliases when referring to columns.
joined_df = pat.join(enc, pat["Id"] == enc["PATIENT"], "inner")

# Select and rename same name columns to avoid ambiguity.
joined_df = joined_df.select(
    pat["Id"].alias("pat_Id"), enc['id'].alias("enc_Id"),
    pat["BIRTHDATE"], pat['DEATHDATE'],
    enc["PATIENT"].alias("enc_PATIENT_id"),
    enc["START"].alias("enc_START"),
    enc["STOP"].alias("enc_STOP")
)

# Calculate the patient's age at the time of the encounter.
joined_df = joined_df.withColumn(
    "AGE_AT_VISIT",
    F.floor(F.datediff(F.col("enc_START").cast("date"), F.col("BIRTHDATE").cast("date")) / 365)
)

# Filter the records where age is between 18 and 35
cohort_df = joined_df.filter((F.col("AGE_AT_VISIT") >= 18) & (F.col("AGE_AT_VISIT") <= 35))

cohort_df.show(5)


In [None]:

"""
DEATH_AT_VISIT_IND: 1 if patient died during the drug overdose encounter, 0 if the patient died at a different time
This cell adds a new column to the cohort DataFrame indicating whether the patient died during the encounter.
"""

cohort_df = cohort_df.withColumn(
    "DEATH_AT_VISIT_IND",
    Func.when(
        # Check that DEATHDATE is not "NA" and falls between the START and STOP dates:
        (Func.col("DEATHDATE") != "NA") &
        (Func.to_timestamp(Func.col("DEATHDATE").cast("timestamp"), "yyyy-MM-dd HH:mm:ss").between(Func.col('enc_START').cast("timestamp"), Func.col('enc_STOP').cast("timestamp"))),
        1
    ).otherwise(0)
)

# For debugging or previewing results:
# cohort_df.select("DEATHDATE", "DEATH_AT_VISIT_IND").show(5, truncate=False)


In [None]:
# Sanity check: preview cohort encounters flagged with DEATH_AT_VISIT_IND = 1.
test_cohort_df = cohort_df.where(Func.col("DEATH_AT_VISIT_IND") == 1)
test_cohort_df.show(5, truncate=False)

In [None]:
# Part 2: Create additional fields

"""
COUNT_CURRENT_MEDS: number of medications active at the START of each cohort encounter.

A medication is treated as active when it started at or before the encounter START and either
has no STOP or stopped at/after the encounter START.

The cohort is LEFT-joined to medications so encounters without any active medication keep a
row (medication columns null, COUNT_CURRENT_MEDS = 0). med_df therefore has one row per
(cohort encounter, active medication) pair, plus one row for each encounter with none.
"""

from pyspark.sql import functions as F

med = spark_dfs['medications'].alias("m")
cohort = cohort_df.alias("c")

enc_start_ts = F.col("c.enc_START").cast("timestamp")
med_start_ts = F.col("m.START").cast("timestamp")
med_stop_ts = F.col("m.STOP").cast("timestamp")

# Join condition: same patient AND the medication interval covers the encounter START.
# It lives in the join (not a post-join filter) so the left join keeps unmatched encounters.
active_at_encounter_start = (
    (F.col("c.pat_Id") == F.col("m.PATIENT"))
    & (med_start_ts <= enc_start_ts)
    & (med_stop_ts.isNull() | (med_stop_ts >= enc_start_ts))
)

med_df = (
    cohort.join(med, active_at_encounter_start, "left")
    .withColumnRenamed("PATIENT", "med_PATIENT")
    .withColumnRenamed("START", "med_START")
    .withColumnRenamed("STOP", "med_STOP")
)

# Count per cohort encounter. count(column) skips nulls, so encounters that matched no
# medication (null med_PATIENT from the left join) get 0.
med_counts = med_df.groupBy("enc_Id").agg(F.count("med_PATIENT").alias("COUNT_CURRENT_MEDS"))

med_df = med_df.join(med_counts, on="enc_Id", how="left")

# Cohort and medication data are now merged.
med_df.show(5, truncate=False)


In [None]:
# CURRENT_OPIOID_IND: 1 if the patient had at least one active medication at the start of the
# overdose encounter that is on the Opioids List below, else 0.
#
# Opioids List (as provided):
#   Hydromorphone 325Mg
#   Fentanyl – 100 MCG
#   Oxycodone-acetaminophen 100 Ml
#
# Medications are matched on DESCRIPTION: a case-insensitive prefix of the drug name followed
# by the listed strength. The Fentanyl pattern allows spaces/dashes before "100" and requires
# that strength, so other Fentanyl products are not flagged.
#
# This cell also collapses med_df (one row per active medication) back to ONE ROW PER COHORT
# ENCOUNTER, the grain that the readmission window and the export expect.
# NOTE: cohort_df is overwritten here; re-run from the cohort cell before re-running the
# medications cell.

import functools
import operator

opioid_patterns = [
    r"(?i)^Hydromorphone 325",
    r"(?i)^Fentanyl\W*100",
    r"(?i)^Oxycodone-acetaminophen 100",
]

# OR of all patterns. A null DESCRIPTION (encounter without active meds) yields null -> not opioid.
is_opioid = functools.reduce(
    operator.or_, (Func.col("DESCRIPTION").rlike(pattern) for pattern in opioid_patterns)
)

# For reference: the medication CODEs matched by the patterns among the active medications.
opioid_codes_list = [row["CODE"] for row in med_df.filter(is_opioid).select("CODE").distinct().collect()]
print(opioid_codes_list)

# Per-encounter attributes carried over from the cohort; each is fixed for a given enc_Id.
encounter_columns = [
    "pat_Id", "enc_Id", "BIRTHDATE", "DEATHDATE", "enc_PATIENT_id",
    "enc_START", "enc_STOP", "AGE_AT_VISIT", "DEATH_AT_VISIT_IND",
]

cohort_df = (
    med_df
    .withColumn("is_opioid_med", Func.when(is_opioid, 1).otherwise(0))
    .groupBy(*encounter_columns)
    .agg(
        # Expected to be constant within an encounter; max() just picks that value.
        Func.max("COUNT_CURRENT_MEDS").alias("COUNT_CURRENT_MEDS"),
        # 1 if any active medication of the encounter is an opioid.
        Func.max("is_opioid_med").alias("CURRENT_OPIOID_IND"),
    )
)

# Show the results
# cohort_df.show(5, truncate=False)

In [None]:
# Sanity check: preview cohort rows flagged with CURRENT_OPIOID_IND = 1.
test_cohort_df = cohort_df.where(Func.col("CURRENT_OPIOID_IND") == 1)
test_cohort_df.show(5, truncate=False)

In [None]:
"""
This code creates a patient-partitioned window to compute each encounter’s next start date 
    and the day difference between encounters, then assigns 90-day and 30-day readmission indicators, 
    conditionally replaces the next encounter start date with "N/A" when appropriate, renames that column to FIRST_READMISSION_DATE, 
    and finally displays filtered results.
"""

from pyspark.sql.window import Window

# Define a window partitioned by pat_Id and ordered by enc_STOP (converted to timestamp)
patient_window = Window.partitionBy("pat_Id").orderBy(Func.col("enc_STOP").cast("timestamp"))

# Get the next encounter's start date using lead().
df_with_next = cohort_df.withColumn("next_enc_START", Func.lead("enc_START").over(patient_window))

# Calculate the difference in days between the current encounter and the next encounter.
df_with_diff = df_with_next.withColumn("diff_days", Func.datediff(Func.col("next_enc_START"), Func.col("enc_START")))

# Create the READMISSION_90_DAY_IND indicator: flag as 1 if the next encounter is within 90 days and not 0, else 0.
df_with_indicator = df_with_diff.withColumn(
    "READMISSION_90_DAY_IND",
    Func.when(
        (Func.col("diff_days").isNotNull()) &
        (Func.col("diff_days") != 0) &
        (Func.col("diff_days") <= 90), Func.lit(1)
    ).otherwise(Func.lit(0))
)

#  create a 30-day readmission indicator.
df_with_indicator = df_with_indicator.withColumn(
    "READMISSION_30_DAY_IND",
    Func.when(
        (Func.col("diff_days").isNotNull()) &
        (Func.col("diff_days") != 0) &
        (Func.col("diff_days") <= 30), Func.lit(1)
    ).otherwise(Func.lit(0))
)

# Show filtered results: those with diff_days between 1 and 89.
df_with_indicator.filter(Func.col("diff_days").between(1, 89)).show(5, truncate=False)
df_with_indicator.filter(Func.col("READMISSION_30_DAY_IND") == 1).show(5, truncate=False)

# Update next_enc_START to "N/A" when diff_days is not positive or exceeds 90 (i.e. otherwise set to "N/A").
df_with_indicator = df_with_indicator.withColumn(
    "next_enc_START",
    Func.when(
        (Func.col("diff_days") <= 90) & (Func.col("diff_days") != 0), Func.col("next_enc_START")
    ).otherwise(Func.lit("N/A"))
)

# Rename next_enc_START to FIRST_READMISSION_DATE.
df_with_indicator = df_with_indicator.withColumnRenamed('next_enc_START', 'FIRST_READMISSION_DATE')

df_with_indicator.show(5, truncate=False)


In [None]:
# Spot checks: encounters followed by another overdose 1-89 days later, then 30-day readmissions.
readmitted_1_to_89_days_df = df_with_indicator.where(Func.col("diff_days").between(1, 89))
readmitted_1_to_89_days_df.show(5, truncate=False)
df_with_indicator.where(Func.col("READMISSION_30_DAY_IND") == 1).show(5, truncate=False)

In [None]:
# Part 3: Export the data to a CSV file

# (source column, output column) pairs, in the order required for the export.
export_columns = [
    ("pat_Id", "PATIENT_ID"),
    ("enc_Id", "ENCOUNTER_ID"),
    ("enc_START", "HOSPITAL_ENCOUNTER_DATE"),
    ("AGE_AT_VISIT", "AGE_AT_VISIT"),
    ("DEATH_AT_VISIT_IND", "DEATH_AT_VISIT_IND"),
    ("COUNT_CURRENT_MEDS", "COUNT_CURRENT_MEDS"),
    ("CURRENT_OPIOID_IND", "CURRENT_OPIOID_IND"),
    ("READMISSION_90_DAY_IND", "READMISSION_90_DAY_IND"),
    ("READMISSION_30_DAY_IND", "READMISSION_30_DAY_IND"),
    ("FIRST_READMISSION_DATE", "FIRST_READMISSION_DATE"),
]

output_df = df_with_indicator.select(
    *[Func.col(source).alias(target) for source, target in export_columns]
)

# Spark writes a directory at output_path containing part-*.csv files, each with a header row.
output_path = "output/df_with_indicators.csv"
output_df.write.mode("overwrite").option("header", True).csv(output_path)

print(f"DataFrame written to {output_path}")