# Personal Notes

On a MacBook Pro (Apple silicon):
- Running the shell natively as arm64 is faster for PySpark jobs than running it as x86_64 (Rosetta emulation): `arch -arm64 /bin/zsh`
  - Verify with `uname -m`, which should print `arm64`.

In [17]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

def start_spark(app_name="Clinical_Notes_Processing", driver_memory="8g", shuffle_partitions=50):
    """Create (or reuse) the local SparkSession used by this notebook.

    Args:
        app_name: Application name shown in the Spark UI.
        driver_memory: Driver JVM heap size. With ``local[*]`` every task runs
            inside the driver JVM, so this is the memory budget that matters.
        shuffle_partitions: Value for ``spark.sql.shuffle.partitions``, i.e. the
            number of partitions produced by joins and aggregations.

    Returns:
        The active ``SparkSession``.

    Note:
        ``getOrCreate()`` returns an already-running session if there is one;
        JVM-level settings such as driver memory cannot change on a running JVM,
        so stop the existing session first when changing them.
    """
    spark = (
        SparkSession.builder
        .master("local[*]")  # one worker thread per logical CPU core
        .appName(app_name)
        .config("spark.driver.memory", driver_memory)
        # NOTE: executor settings have no effect in local mode (tasks run inside
        # the driver JVM); kept so the recorded config matches earlier runs.
        .config("spark.executor.memory", "4g")
        .config("spark.executor.cores", "4")
        # Upper bound on the total size of results pulled back to the driver
        # (collect()/toPandas()), e.g. when collecting embedding vectors.
        .config("spark.driver.maxResultSize", "2g")
        # G1 garbage collector: tends to cope better with large heaps such as
        # those holding many embedding vectors.
        .config("spark.driver.extraJavaOptions", "-XX:+UseG1GC")
        .getOrCreate()
    )

    # Shuffle partitions are configured exactly once, here, as a runtime SQL conf
    # so the value also applies when getOrCreate() returned an existing session.
    spark.conf.set("spark.sql.shuffle.partitions", str(shuffle_partitions))

    print("Spark Version:", spark.version)
    # Spark moves to 4041, 4042, ... when 4040 is busy, so report the real UI URL.
    print("Spark UI:", spark.sparkContext.uiWebUrl or "http://localhost:4040")
    return spark

In [18]:
def clean_resources(spark):
    """Restart Spark: stop ``spark`` and hand back a freshly started session.

    The old session object is unusable afterwards, so the caller should rebind
    its ``spark`` variable to the returned value.
    """
    spark.stop()
    restarted_session = start_spark()
    return restarted_session

In [19]:
# Single shared session; every later cell reads data through `spark`.
spark = start_spark()

Spark Version: 3.5.4
Spark UI: http://localhost:4040


In [20]:
import os

# Root of the shared project folder on the local Google Drive mount.
# Set the GOOGLE_DRIVE_LOCAL_MOUNT environment variable to point at this folder
# on another machine instead of editing the personal default path below.
GOOGLE_DRIVE_LOCAL_MOUNT = os.environ.get(
    'GOOGLE_DRIVE_LOCAL_MOUNT',
    '/Users/sagana/Library/CloudStorage/GoogleDrive-sondande@uchicago.edu/.shortcut-targets-by-id/1O2pwlZERv3B7ki78Wn0brrpnArRBTFdH/MLI_2025 Winter/',
)

# Check if Google Drive is accessible. isdir (not exists) because listdir below
# needs a directory; sorted so the listing is stable between runs.
if os.path.isdir(GOOGLE_DRIVE_LOCAL_MOUNT):
    print("Google Drive is mounted successfully!")
    print("Files in Drive:", sorted(os.listdir(GOOGLE_DRIVE_LOCAL_MOUNT)))
else:
    print("Google Drive is not mounted. Please check your installation.")

Google Drive is mounted successfully!
Files in Drive: ['(ReferHere)Final_Dataset_Data_Folder ', 'merged_5000_patient_radio.csv', 'mimic-iv-ext-clinical-decision-making-a-mimic-iv-derived-dataset-for-evaluation-of-large-language-models-on-the-task-of-clinical-decision-making-for-abdominal-pathologies-1.1.zip', '.DS_Store', 'extracted_zip', 'Project_Presentation.pptx', 'JM outputs', 'SQL DB Export', 'mimiciv.db', 'mimic-iv-3.1.zip', 'Machine Learning I Team 5 Project Proposal.gdoc', 'YY_codes', 'mimic-iv-note-deidentified-free-text-clinical-notes-2.2.zip', 'merged_5000_patient.csv', 'Project Idea.gdoc', 'Final_Dataset_Data_Folder_unzip', 'MLI_2025_Winter', 'Sagana Outputs', 'merged_5000_patient_radio_disc.csv', 'Project Milestone-I.gdoc', 'Dataset Readme.gdoc']


In [21]:
from pyspark.sql.types import *
from pyspark.sql.functions import collect_set, collect_list, struct, col, when, count, countDistinct, lit
import pandas as pd
import ast

# schema.csv has one row per exported table: the table name plus its column
# list stored as a Python-literal string (parsed with ast.literal_eval below).
schema_csv_path = f'{GOOGLE_DRIVE_LOCAL_MOUNT}/SQL DB Export/CSV/schema.csv'
schemas_df = spark.read.option("header", True).csv(schema_csv_path)
schemas_df.show(5)

+----------------+--------------------+
|           table|              schema|
+----------------+--------------------+
|   diagnoses_icd|['subject_id', 'h...|
|       discharge|['subject_id', 'h...|
|        drgcodes|['subject_id', 'h...|
| d_icd_diagnoses|['icd_code', 'icd...|
|d_icd_procedures|['icd_code', 'icd...|
+----------------+--------------------+
only showing top 5 rows



In [22]:
# Construct an all-string schema for the radiology export from schema.csv.
# ast.literal_eval only accepts Python literals, so parsing the stored column
# list cannot execute code.
radiology_schema_rows = (
    schemas_df.filter(col("table") == 'radiology').select(col("schema")).take(1)
)
if not radiology_schema_rows:
    raise ValueError("schema.csv has no entry for table 'radiology'")
radiology_schema_list = ast.literal_eval(radiology_schema_rows[0][0])
radiology_schema = StructType([
    StructField(column_name, StringType(), True) for column_name in radiology_schema_list
])

# Read in radiology dataset: pipe-delimited, no header row (the explicit schema
# names the columns), and multiLine so quoted note text may contain newlines.
radiology_df = (
    spark.read
    .option("delimiter", "|")
    .option("quote", '"')
    .option("multiLine", "true")
    .csv(
        f'{GOOGLE_DRIVE_LOCAL_MOUNT}/Sagana Outputs/Clinical Notes Creation/Input Data/radiology.csv',
        schema=radiology_schema,
    )
)
radiology_df.show(truncate=80)

+----------+--------+-------------------+--------------------------------------------------------------------------------+
|subject_id| hadm_id|          charttime|                                                                            text|
+----------+--------+-------------------+--------------------------------------------------------------------------------+
|  10000117|    NULL|2175-05-10 10:12:00|BILATERAL DIGITAL SCREENING MAMMOGRAM WITH CAD\\n\\nHISTORY:  Baseline screen...|
|  10000117|    NULL|2177-05-23 13:18:00|INDICATION:  ___ female with right epigastric pain radiating to back,\\nrule ...|
|  10000117|    NULL|2178-08-29 13:39:00|CLINICAL HISTORY:  Right upper quadrant pain, evaluate for gallstones.\\n\\nA...|
|  10000117|22927623|2181-11-15 00:40:00|EXAMINATION:   CHEST (PA AND LAT)\\n\\nINDICATION:  History: ___ with PMH GER...|
|  10000117|22927623|2181-11-15 00:47:00|EXAMINATION:   NECK SOFT TISSUES\\n\\nINDICATION:  ___ woman with dysphasia. ...|
|  10000117|    

In [23]:
# Construct an all-string schema for the discharge export from schema.csv.
# ast.literal_eval only accepts Python literals, so parsing the stored column
# list cannot execute code.
discharge_schema_rows = (
    schemas_df.filter(col("table") == 'discharge').select(col("schema")).take(1)
)
if not discharge_schema_rows:
    raise ValueError("schema.csv has no entry for table 'discharge'")
discharge_schema_list = ast.literal_eval(discharge_schema_rows[0][0])
discharge_schema = StructType([
    StructField(column_name, StringType(), True) for column_name in discharge_schema_list
])

# Read in discharge dataset: pipe-delimited, no header row (the explicit schema
# names the columns), and multiLine so quoted note text may contain newlines.
discharge_df = (
    spark.read
    .option("delimiter", "|")
    .option("quote", '"')
    .option("multiLine", "true")
    .csv(
        f'{GOOGLE_DRIVE_LOCAL_MOUNT}/Sagana Outputs/Clinical Notes Creation/Input Data/discharge.csv',
        schema=discharge_schema,
    )
)
discharge_df.show(truncate=80)

+----------+--------+-------------------+--------------------------------------------------------------------------------+
|subject_id| hadm_id|          charttime|                                                                            text|
+----------+--------+-------------------+--------------------------------------------------------------------------------+
|  10000117|27988844|2183-09-21 00:00:00| \\nName:  ___                 Unit No:   ___\\n \\nAdmission Date:  ___     ...|
|  10000117|22927623|2181-11-15 00:00:00| \\nName:  ___                 Unit No:   ___\\n \\nAdmission Date:  ___     ...|
|  10000248|20600184|2192-11-30 00:00:00| \\nName:  ___                      Unit No:   ___\\n \\nAdmission Date:  ___...|
|  10000560|28979390|2189-10-17 00:00:00| \\nName:  ___                     Unit No:   ___\\n \\nAdmission Date:  ___ ...|
|  10000764|27897940|2132-10-19 00:00:00| \\nName:  ___               Unit No:   ___\\n \\nAdmission Date:  ___       ...|
|  10000826|2828

In [24]:
# Keep just the join key and the note body for both note types; hadm_id and
# charttime are dropped at this point.
radiology_df_filtered, discharge_df_filtered = (
    notes.select('subject_id', 'text') for notes in (radiology_df, discharge_df)
)

In [25]:
# Cleaned patient table; below it serves only as a subject_id filter so the
# note tables are cut down to patients we actually have.
patients_csv_path = f'{GOOGLE_DRIVE_LOCAL_MOUNT}/JM outputs/patients_cleaned.csv'
patients_df = spark.read.option("header", True).csv(patients_csv_path)
patients_df.show(5)

+----------+------+----------+-----------+---------+--------+--------------+-----+-----------------------+------------------------+----+------+------+----+
|subject_id|gender|anchor_age|anchor_year|insurance|language|marital_status| race|blood_pressure_systolic|blood_pressure_diastolic| bmi|height|weight|egfr|
+----------+------+----------+-----------+---------+--------+--------------+-----+-----------------------+------------------------+----+------+------+----+
|  10000117|     F|        48|       2174| Medicaid| English|      DIVORCED|WHITE|                    108|                      74|18.9|    64|   110|NULL|
|  10000161|     M|        60|       2163| Medicaid| English|        SINGLE|WHITE|                    106|                      92|NULL|  NULL|  NULL|NULL|
|  10000248|     M|        34|       2192|  Private| English|       MARRIED|WHITE|                   NULL|                    NULL|25.5|    68|   168|NULL|
|  10000280|     M|        20|       2151|  Private| English|   

In [26]:
# Semi join: keep radiology rows whose subject_id exists in patients_df; only
# the left-side columns survive. Then give the text a source-specific name.
final_radiology_df = radiology_df_filtered.join(patients_df, on='subject_id', how='left_semi')
final_radiology_df_updated = final_radiology_df.withColumnRenamed('text', 'radiology_text')

In [27]:
# Semi join: keep discharge rows whose subject_id exists in patients_df; only
# the left-side columns survive. Then give the text a source-specific name.
final_discharge_df = discharge_df_filtered.join(patients_df, on='subject_id', how='left_semi')
final_discharge_df_updated = final_discharge_df.withColumnRenamed('text', 'discharge_text')

In [28]:
# NOTE(review): subject_id is not unique on either side (a patient can have many
# radiology and many discharge notes), so this inner join emits every
# radiology x discharge note pair per patient; that is why the count below is
# in the millions. Pairing per admission would need hadm_id, which was dropped
# earlier (and is NULL for some radiology notes).
combined_df = final_radiology_df_updated.join(
    final_discharge_df_updated,
    on=['subject_id'],
    how='inner',
)
combined_df.show(1)

[Stage 38:>                                                         (0 + 1) / 1]

+----------+--------------------+--------------------+
|subject_id|      radiology_text|      discharge_text|
+----------+--------------------+--------------------+
|  10001663|Addendum:\\n\\nAd...| \\nName:  ___   ...|
+----------+--------------------+--------------------+
only showing top 1 row



                                                                                

In [29]:
# Number of radiology/discharge note pairs produced by the subject_id join.
combined_df.count()

                                                                                

4650350

In [30]:
# Draw a seeded ~0.2% sample of rows without replacement; Spark keeps each row
# independently with this probability, so the size is only approximately
# fraction * total. Every row already holds a radiology and a discharge note of
# the same subject_id, so the two note types stay aligned by construction.
# NOTE(review): rows are note pairs, so patients with many notes are
# over-represented relative to the source note tables; sampling distinct
# subject_ids first would avoid that.
SAMPLE_FRACTION = 0.002
SAMPLE_SEED = 42
combined_df_sample = combined_df.sample(
    withReplacement=False,
    fraction=SAMPLE_FRACTION,
    seed=SAMPLE_SEED,
)

In [31]:
# Actual sample size (probabilistic sampling, so not exactly fraction * total).
combined_df_sample.count()

                                                                                

9235

In [32]:
# Preview the first rows of the sampled note pairs (text columns truncated).
combined_df_sample.show()

[Stage 64:>                                                         (0 + 1) / 1]

+----------+--------------------+--------------------+
|subject_id|      radiology_text|      discharge_text|
+----------+--------------------+--------------------+
|  10023708|INDICATION:  Preo...| \\nName:  ___   ...|
|  10023708|HISTORY:  Screeni...| \\nName:  ___   ...|
|  10023708|EXAMINATION:  WRI...| \\nName:  ___   ...|
|  10071302|INDICATION:  ___ ...| \\nName:  ___   ...|
|  10076263|HISTORY:  ___ fem...| \\nName:  ___   ...|
|  10076263|EXAMINATION:  DX ...| \\nName:  ___   ...|
|  10076263|INDICATION:  ___ ...| \\nName:  ___   ...|
|  10076263|EXAMINATION:  CT ...| \\nName:  ___   ...|
|  10076263|INDICATION:  ___ ...| \\nName:  ___   ...|
|  10076263|EXAMINATION:  UNI...| \\nName:  ___   ...|
|  10076263|EXAMINATION:  DX ...| \\nName:  ___   ...|
|  10076263|EXAMINATION:  DX ...| \\nName:  ___   ...|
|  10076263|INDICATION:  ___ ...| \\nName:  ___   ...|
|  10076263|INDICATION:  ___ ...| \\nName:  ___   ...|
|  10102822|INDICATION:  ___ ...| \\nName:  ___   ...|
|  1010443

                                                                                

In [33]:
# combined_df_sample.write.mode("overwrite").option("compression", "snappy").parquet('combined_cn_sample/')

In [34]:
# Every column is a string because the note tables were read with all-StringType schemas.
combined_df_sample.printSchema()

root
 |-- subject_id: string (nullable = true)
 |-- radiology_text: string (nullable = true)
 |-- discharge_text: string (nullable = true)



In [35]:
# final_radiology_df_count  = final_radiology_df.count()
# distinct_radiology_df_count = final_radiology_df.select('subject_id').distinct().count()
# final_discharge_df_count = final_discharge_df.count()
# distinct_discharge_df_count = final_discharge_df.select('subject_id').distinct().count()

# sampled_rad_count = sampled_final_radiology_df.count()
# distinct_rad_count = sampled_final_radiology_df.select('subject_id').distinct().count()
# sampled_dis_count = sampled_final_discharge_df.count()
# distinct_dis_count = sampled_final_discharge_df.select('subject_id').distinct().count()

# print(f"Original Radiology Count: {final_radiology_df_count}, Distinct Radiology Count: {distinct_radiology_df_count}")
# print(f"Sampled Radiology Count: {sampled_rad_count}, Distinct Radiology Count: {distinct_rad_count}")
# print(f"Original Discharge Count: {final_discharge_df_count}, Distinct Discharge Count: {distinct_discharge_df_count}")
# print(f"Sampled Discharge Count: {sampled_dis_count}, Distinct Discharge Count: {distinct_dis_count}")

In [36]:
# spark.stop()

In [37]:
# # Suppose "sampled_df" has a vector column "embeddings"
# # 1. Convert Spark vector to array, then collect
# from pyspark.sql.functions import udf
# from pyspark.ml.linalg import VectorUDT
# import numpy as np

# # Convert Spark Vector to Python list
# def to_array(v):
#     return v.toArray().tolist()

# to_array_udf = udf(to_array, "array<double>")
# sampled_df_array = sampled_df.withColumn("embeddings_array", to_array_udf("embeddings"))

# # 2. Collect to Pandas
# pdf = sampled_df_array.select("embeddings_array").limit(5000).toPandas()  # limit for memory safety
# X = np.array(pdf["embeddings_array"].tolist())  # shape: (n_samples, embed_dim)

In [38]:
# final_radiology_df.write.mode("overwrite").option("compression", "snappy").parquet('radiology_filtered/')

In [39]:
# final_discharge_df.write.mode("overwrite").option("compression", "snappy").parquet('discharge_filtered/')

In [40]:
# spark.stop()

## Review processed datasets

In [41]:
# discharge_processed_df = spark.read.parquet('discharge_processed/')
# radio_processed_df = spark.read.parquet('radiology_processed/')

In [42]:
# from pyspark.sql.functions import expr
# from pyspark.sql.functions import to_json, cola

# df = discharge_processed_df.withColumn("sections", to_json(col("sections")))  # Convert map column to JSON string
# df = df.withColumn("entities", to_json(col("entities")))
# # df.write.csv("output_directory", header=True, mode="overwrite")
# # df.coalesce(1).write.csv("discharge_processed_csv/", header=True, mode="overwrite")
# df_PD = df.toPandas()

In [43]:
# df_PD.to_csv('discharge_processed_csv/discharge_processed.csv', index=False)

In [44]:
# from pyspark.sql.functions import expr
# from pyspark.sql.functions import to_json, col

# df = radio_processed_df.withColumn("sections", to_json(col("sections")))  # Convert map column to JSON string
# df = df.withColumn("entities", to_json(col("entities")))
# # df.write.csv("output_directory", header=True, mode="overwrite")
# # df.coalesce(1).write.csv("discharge_processed_csv/", header=True, mode="overwrite")
# df_PD = df.toPandas()
# df_PD.to_csv('radiology_processed_csv/radio_processed.csv', index=False)

In [45]:
# df_PD.head()

## Review embeddings output

In [46]:
# spark = start_spark()

In [47]:
# Embedded sample outputs, one parquet directory per note type. The paths are
# relative, so they resolve against the notebook's working directory.
embedded_radio = spark.read.parquet('radiology_text/clinical_notes_sampled_embedded/')
embedded_dis = spark.read.parquet('discharge_text/clinical_notes_sampled_embedded/')

In [48]:
# Preview discharge rows: raw text, cleaned text and the embedding array.
embedded_dis.show()

+----------+--------------------+---------------------------+------------------------+
|subject_id|      discharge_text|cleaned_discharge_text_text|embedding_discharge_text|
+----------+--------------------+---------------------------+------------------------+
|  10004296| \\nName:  ___   ...|       name: ___ unit no...|    [0.07342699, 0.06...|
|  10007058| \\nName:  ___   ...|       name: ___ unit no...|    [0.3247316, -0.22...|
|  10031575| \\nName:  ___   ...|       name: ___ unit no...|    [0.11593194, -0.0...|
|  10031575| \\nName:  ___   ...|       name: ___ unit no...|    [0.062262576, 0.0...|
|  10036821| \\nName:  ___   ...|       name: ___ unit no...|    [0.2427187, -0.10...|
|  10036821| \\nName:  ___   ...|       name: ___ unit no...|    [0.1506541, -0.00...|
|  10048105| \\nName:  ___   ...|       name: ___ unit no...|    [0.17518048, -0.4...|
|  10076617| \\nName:  ___.  ...|       name: ___. unit n...|    [0.13433282, -0.1...|
|  10076617| \\nName:  ___.  ...|       nam

In [49]:
# Preview radiology rows: raw text, cleaned text and the embedding array.
embedded_radio.show()

+----------+--------------------+---------------------------+------------------------+
|subject_id|      radiology_text|cleaned_radiology_text_text|embedding_radiology_text|
+----------+--------------------+---------------------------+------------------------+
|  10004296|EXAMINATION:  PEL...|       examination: pelv...|    [0.042560313, 0.1...|
|  10007058|EXAMINATION:  CTA...|       examination: cta ...|    [0.3039025, -0.13...|
|  10031575|EXAMINATION:  CT ...|       examination: ct h...|    [0.12216909, -0.1...|
|  10031575|EXAMINATION:  UNI...|       examination: unil...|    [0.10871682, 0.08...|
|  10036821|EXAMINATION:  SEC...|       examination: seco...|    [0.33056372, -0.1...|
|  10036821|EXAMINATION:  CT ...|       examination: ct c...|    [0.30515096, -0.1...|
|  10048105|CHEST RADIOGRAPH\...|       chest radiograph ...|    [0.19040951, -0.1...|
|  10076617|INDICATION:  ___ ...|       indication: ___ w...|    [0.3708051, -0.10...|
|  10076617|CHEST RADIOGRAPH\...|       che

In [50]:
# Keep only the join key and the embedding vector for each note type.
filtered_dis_emb = embedded_dis.select('subject_id', 'embedding_discharge_text')
filtered_radio_emb = embedded_radio.select('subject_id', 'embedding_radiology_text')

In [58]:
# Embeddings are stored as plain array<float> columns (not ML Vector types).
filtered_radio_emb.printSchema()

root
 |-- subject_id: string (nullable = true)
 |-- embedding_radiology_text: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [None]:
# One row per subject_id holding that patient's distinct radiology embeddings.
# The previous version aggregated "icd_code", a column filtered_radio_emb does not
# have (it only has subject_id and embedding_radiology_text), so the cell would
# fail with an AnalysisException.
# collect_set drops exact-duplicate vectors (the same radiology note can appear
# several times, presumably once per paired discharge note in the sample); the
# order of the collected arrays is not guaranteed.
filtered_radio_agg = (
    filtered_radio_emb
    .groupBy("subject_id")
    .agg(
        collect_set("embedding_radiology_text").alias("radiology_embeddings")
    )
)

In [51]:
# NOTE(review): both inputs can hold several rows per subject_id, so this inner
# join emits every radiology x discharge embedding pair per patient. Note the
# repeated rows in the preview and the ~568k count produced from a ~9k-row sample.
clinical_notes_combined_em = filtered_radio_emb.join(
    filtered_dis_emb,
    on=['subject_id'],
    how='inner',
)
clinical_notes_combined_em.show()

+----------+------------------------+------------------------+
|subject_id|embedding_radiology_text|embedding_discharge_text|
+----------+------------------------+------------------------+
|  10023486|    [0.21450946, -0.1...|    [0.073009044, -0....|
|  10023708|    [0.22877766, -0.1...|    [0.13038772, -0.0...|
|  10023708|    [0.22877766, -0.1...|    [0.14673024, -0.1...|
|  10023708|    [0.14008474, 0.02...|    [0.13038772, -0.0...|
|  10023708|    [0.14008474, 0.02...|    [0.14673024, -0.1...|
|  10063368|    [0.18466042, -0.1...|    [0.2575382, 0.030...|
|  10065615|    [0.05487843, -0.2...|    [0.19949329, -0.4...|
|  10065615|    [0.05487843, -0.2...|    [0.19949329, -0.4...|
|  10065615|    [0.30068773, -0.2...|    [0.19949329, -0.4...|
|  10065615|    [0.30068773, -0.2...|    [0.19949329, -0.4...|
|  10066489|    [0.30751058, -0.2...|    [0.124997, -0.092...|
|  10076263|    [0.23147626, -0.0...|    [0.27583912, -0.1...|
|  10076263|    [0.23147626, -0.0...|    [0.28715703, -

In [54]:
# Row count after the subject_id join of radiology and discharge embeddings.
clinical_notes_combined_em.count()

568278

In [55]:
# c_notes_PD = clinical_notes_combined_em.toPandas()
# c_notes_PD.head()

In [56]:
# c_notes_PD.to_csv('embedded_clinical_notes_combined.csv', index=False)

In [57]:
# c_notes_PD.info()