In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as Func

In [2]:
spark = SparkSession.builder.appName("ModularCSVLoader").getOrCreate()

# Base URL for all the CSV files
base_url = "datasets/"

# List of file names to process
file_names = [
"allergies.csv",
"encounters.csv",
"medications.csv",
"patients.csv",
"procedures.csv"
]

# Dictionaries to store the resulting DataFrames for further processing/joining.
# pandas_dfs = {}
spark_dfs = {}

# Process each file and store the DataFrames into the dictionaries
for file_name in file_names:
    name_key = file_name.replace('.csv', '')
    file_url = f"{base_url}{file_name}"
    print(f"Processing file: {file_name}")
    spark_df = spark.read.csv(file_url, header=True, inferSchema=True)
    # pandas_dfs[name_key] = pd_df
    spark_dfs[name_key] = spark_df

25/04/11 13:17:38 WARN Utils: Your hostname, debian-shed resolves to a loopback address: 127.0.1.1; using 192.168.1.11 instead (on interface enp3s0)
25/04/11 13:17:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/11 13:17:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Processing file: allergies.csv
Processing file: encounters.csv


                                                                                

Processing file: medications.csv


                                                                                

Processing file: patients.csv
Processing file: procedures.csv


                                                                                

In [3]:
from pyspark.sql import functions as F

enc_df = spark_dfs['encounters'].filter(
    (spark_dfs['encounters'].REASONCODE == '55680006') &
    (spark_dfs['encounters'].START > F.lit("1999-07-15 00:00:00").cast("timestamp"))
)

enc_df.show(5)


+--------------------+-------------------+-------------------+--------------------+--------------------+--------------+--------+--------------------+------+----------+-----------------+
|                  Id|              START|               STOP|             PATIENT|            PROVIDER|ENCOUNTERCLASS|    CODE|         DESCRIPTION|  COST|REASONCODE|REASONDESCRIPTION|
+--------------------+-------------------+-------------------+--------------------+--------------------+--------------+--------+--------------------+------+----------+-----------------+
|2a917920-2701-49f...|2003-03-31 21:50:51|2003-04-08 13:20:43|708b81c9-21a9-411...|fb37c581-84a6-351...|     emergency|50849002|Emergency Room Ad...|105.37|  55680006|    Drug overdose|
|22874b3d-0873-40e...|2012-02-18 21:50:51|2012-02-28 21:12:17|708b81c9-21a9-411...|fb37c581-84a6-351...|     emergency|50849002|Emergency Room Ad...|105.37|  55680006|    Drug overdose|
|134c5ee3-1b72-4e3...|2013-08-03 21:50:51|2013-08-13 07:44:52|708b81c9

In [4]:
patient_df = spark_dfs['patients'].filter(
    (spark_dfs['patients'].BIRTHDATE.isNotNull())
)

patient_df.show(5)

+--------------------+----------+---------+-----------+---------+----------+------+-----------+----------+------+------------+-------+-----+------------+------+--------------------+--------------------+-------------+------------+-----+
|                  Id| BIRTHDATE|DEATHDATE|        SSN|  DRIVERS|  PASSPORT|PREFIX|      FIRST|      LAST|SUFFIX|      MAIDEN|MARITAL| RACE|   ETHNICITY|GENDER|          BIRTHPLACE|             ADDRESS|         CITY|       STATE|  ZIP|
+--------------------+----------+---------+-----------+---------+----------+------+-----------+----------+------+------------+-------+-----+------------+------+--------------------+--------------------+-------------+------------+-----+
|3d8e57b2-3de5-4fb...|1943-03-11|       NA|999-86-7250|S99939389| X3970685X|  Mrs.|   Allyn942|Kreiger457|    NA|Bartoletti50|      M|asian|asian_indian|     F|Muhlenberg  Penns...|372 Marks Heights...|Middle Paxton|Pennsylvania|   NA|
|7f4ea9fb-f436-411...|1980-09-28|       NA|999-90-4314|S

In [5]:
from pyspark.sql import functions as F

# Alias the DataFrames for clarity
pat = patient_df.alias("pat")
enc = enc_df.alias("enc")

# Join on patient Id. Use proper aliases when referring to columns.
joined_df = pat.join(enc, pat["Id"] == enc["PATIENT"], "inner")

# Select the columns of interest from each DataFrame and rename encounter columns to avoid ambiguity.
joined_df = joined_df.select(
    pat["Id"].alias("pat_Id"), enc['id'].alias("enc_Id"),
    pat["BIRTHDATE"], pat['DEATHDATE'],
    enc["PATIENT"].alias("enc_PATIENT_id"),
    enc["START"].alias("enc_START"),
    enc["STOP"].alias("enc_STOP")
)

# Calculate the patient's age at the time of the encounter.
# Assuming "BIRTHDATE" is in a format recognized by Spark for date conversion.
joined_df = joined_df.withColumn(
    "age",
    F.floor(F.datediff(F.col("enc_START").cast("date"), F.col("BIRTHDATE").cast("date")) / 365)
)

# Filter the records where age is between 18 and 35
cohort_df = joined_df.filter((F.col("age") >= 18) & (F.col("age") <= 35))

cohort_df.show(5)


+--------------------+--------------------+----------+---------+--------------------+-------------------+-------------------+---+
|              pat_Id|              enc_Id| BIRTHDATE|DEATHDATE|      enc_PATIENT_id|          enc_START|           enc_STOP|age|
+--------------------+--------------------+----------+---------+--------------------+-------------------+-------------------+---+
|708b81c9-21a9-411...|22874b3d-0873-40e...|1986-10-28|       NA|708b81c9-21a9-411...|2012-02-18 21:50:51|2012-02-28 21:12:17| 25|
|708b81c9-21a9-411...|134c5ee3-1b72-4e3...|1986-10-28|       NA|708b81c9-21a9-411...|2013-08-03 21:50:51|2013-08-13 07:44:52| 26|
|708b81c9-21a9-411...|6125f147-72d4-48a...|1986-10-28|       NA|708b81c9-21a9-411...|2014-12-08 21:50:51|2014-12-17 12:25:27| 28|
|708b81c9-21a9-411...|f837dcf8-af7d-43b...|1986-10-28|       NA|708b81c9-21a9-411...|2015-08-31 21:50:51|2015-09-08 12:04:08| 28|
|65b093e4-b353-447...|010594a6-a6ff-487...|1995-12-03|       NA|65b093e4-b353-447...|2018-

In [6]:

cohort_df = cohort_df.withColumn(
    "DEATH_AT_VISIT_IND",
    Func.when(
        # Check that DEATHDATE is not "NA" and falls between the START and STOP dates:
        (Func.col("DEATHDATE") != "NA") &
        (Func.to_timestamp(Func.col("DEATHDATE").cast("timestamp"), "yyyy-MM-dd HH:mm:ss").between(Func.col('enc_START').cast("timestamp"), Func.col('enc_STOP').cast("timestamp"))),
        1
    ).otherwise(0)
)

# For debugging or previewing results:
cohort_df.select("DEATHDATE", "DEATH_AT_VISIT_IND").show(5, truncate=False)


+---------+------------------+
|DEATHDATE|DEATH_AT_VISIT_IND|
+---------+------------------+
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
|NA       |0                 |
+---------+------------------+
only showing top 5 rows



In [7]:
# Filter records where DEATH_AT_VISIT_IND = 1
cohort_df = cohort_df.filter(Func.col("DEATH_AT_VISIT_IND") == 1)

# Show the filtered records
cohort_df.show(5, truncate=False)

+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+---+------------------+
|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |age|DEATH_AT_VISIT_IND|
+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+---+------------------+
|dacea80b-75dd-42d6-a5c0-be369c3e4ebf|1deb51e9-ef0b-4013-b46f-f8efcd836842|1979-12-11|2009-08-24|dacea80b-75dd-42d6-a5c0-be369c3e4ebf|2009-08-20 01:30:34|2009-08-24 21:37:59|29 |1                 |
|978114f9-f9f2-4361-b64d-8045ab8f1602|41c9fbf6-a7d9-48a1-93fc-9fa32d554083|1976-01-07|2009-09-01|978114f9-f9f2-4361-b64d-8045ab8f1602|2009-08-22 16:34:32|2009-09-01 17:22:53|33 |1                 |
|7cdbf215-

In [None]:
# COUNT_CURRENT_MEDS: Count of active medications at the start of the drug overdose encounter

from pyspark.sql import functions as F

# Alias the DataFrames for clarity
med = spark_dfs['medications'].alias("m")
cohort = cohort_df.alias("c")

# Join the DataFrames using their aliases for clarity in the join condition
med_df = med.join(cohort, F.col("c.pat_id") == F.col("m.PATIENT"), "inner") \
    .filter(
        F.col("m.START").cast("timestamp") >= F.col("c.enc_START").cast("timestamp")
    )

# Now use the actual column names for renaming; after join, med.PATIENT will appear as "PATIENT"
med_df = med_df.withColumnRenamed("PATIENT", "med_PATIENT") \
    .withColumnRenamed("START", "med_START") \
    .withColumnRenamed("STOP", "med_STOP") 


med_df.show(5, truncate=False)




+----------+----------+------------------------------------+------------------------------------+-------+-------------------------------------+-------+---------+---------+----------+-----------------+------------------------------------+------------------------------------+----------+----------+------------------------------------+-------------------+-------------------+---+------------------+
|med_START |med_STOP  |med_PATIENT                         |ENCOUNTER                           |CODE   |DESCRIPTION                          |COST   |DISPENSES|TOTALCOST|REASONCODE|REASONDESCRIPTION|pat_Id                              |enc_Id                              |BIRTHDATE |DEATHDATE |enc_PATIENT_id                      |enc_START          |enc_STOP           |age|DEATH_AT_VISIT_IND|
+----------+----------+------------------------------------+------------------------------------+-------+-------------------------------------+-------+---------+---------+----------+-----------------+------

                                                                                

In [9]:
grouped_med = med_df.groupBy(
    "CODE", "ENCOUNTER", "med_PATIENT"
).agg(F.count("*").alias("med_cnt"))

grouped_med = grouped_med.groupBy("med_PATIENT", "ENCOUNTER").agg(
    F.sum("med_cnt").alias("total_med_count")
)

grouped_med.show(truncate=False)



+------------------------------------+------------------------------------+---------------+
|med_PATIENT                         |ENCOUNTER                           |total_med_count|
+------------------------------------+------------------------------------+---------------+
|124bb400-58b3-4352-acaa-f79c84397795|6582cf77-7929-4b34-97ff-88b7cd11f49c|1              |
|5fa3711b-4012-4c34-aaf8-0097d0785418|68988987-ad1e-40c3-8c7d-5fb7e099480b|1              |
|a20356aa-08cd-4c33-9667-08ffd3130676|23bffe9b-7d1e-4985-9260-c860c4b72315|2              |
|77b4cadf-2eb6-4019-96f0-d4a8b344401d|e723d4a7-a71d-4951-bdc6-617c9b19bf55|1              |
|dacea80b-75dd-42d6-a5c0-be369c3e4ebf|c3c43259-21c7-4bbc-accc-011c50a0ef77|1              |
|28202667-2977-4d2f-ba68-977a98827104|cb13d070-c2b3-4ac6-9920-0d89c216442a|1              |
|0ea047ce-faa7-4320-be2f-672a93fc9cfb|868c907d-4892-4771-a509-4e0f85015a4d|1              |
|28202667-2977-4d2f-ba68-977a98827104|23ae1d57-58e7-4cbb-b51b-171d03d46041|1    

                                                                                

In [None]:
# %pip install matplotlib
import matplotlib.pyplot as plt  # type: ignore

# Count the total number of rows
total_rows = patient_df.count()

# Calculate the number of nulls for each column
null_counts = patient_df.select(
    [(Func.count(Func.when(Func.col(c).isNull(), c)) / total_rows).alias(c) for c in patient_df.columns]
).collect()[0]

# Convert the null counts to a dictionary
null_counts_dict = {col: null_counts[i] for i, col in enumerate(patient_df.columns)}

# Plot the null percentages
plt.figure(figsize=(12, 6))
plt.bar(null_counts_dict.keys(), null_counts_dict.values(), color='skyblue')
plt.xticks(rotation=90)
plt.xlabel('Columns')
plt.ylabel('Percentage of Null Values')
plt.title('Null Value Percentage by Column in patient_df')
plt.show()