# Personal Notes

On MacBook Pro
- Running under the arm64 architecture is faster than x86_64 for PySpark jobs: `arch -arm64 /bin/zsh`
  - Verify with `uname -m` (should print `arm64`)

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

def start_spark():
    """Create (or reuse) the local SparkSession for clinical-notes processing.

    The session runs in ``local[*]`` mode and is tuned for a single laptop.
    Connection details are printed so the Spark UI can be opened from the
    notebook.

    Returns:
        The active ``SparkSession``.
    """
    spark = (
        SparkSession.builder
        # local[*]: run in-process with one worker thread per logical CPU core.
        .master("local[*]")
        .appName("Clinical_Notes_Processing")
        .config("spark.driver.memory", "8g")  # half of the RAM on this Mac
        # NOTE(review): in local mode the executor lives inside the driver JVM,
        # so the executor memory/cores settings may have no effect - confirm.
        .config("spark.executor.memory", "4g")
        .config("spark.executor.cores", "4")
        # Raised because collected embedding results grow quickly.
        .config("spark.driver.maxResultSize", "2g")
        # Single source of truth for shuffle parallelism. Previously this was
        # "8" here and then overridden to "50" right after session creation.
        .config("spark.sql.shuffle.partitions", "50")
        # G1 garbage collector; helpful for large heaps holding embedding vectors.
        .config("spark.driver.extraJavaOptions", "-XX:+UseG1GC")
        .getOrCreate()
    )

    print("Spark Version:", spark.version)
    # Report the URL Spark actually bound to rather than assuming port 4040;
    # Spark moves to the next free port when 4040 is already taken.
    print("Spark UI:", spark.sparkContext.uiWebUrl)
    return spark

In [None]:
def clean_resources(spark):
    """Stop the given Spark session and return a newly started one.

    Args:
        spark: The session to shut down.

    Returns:
        Whatever ``start_spark()`` returns after the old session is stopped.
    """
    spark.stop()
    fresh_session = start_spark()
    return fresh_session

In [None]:
# Start (or reuse) the Spark session that every later cell refers to as `spark`.
spark = start_spark()

In [None]:
import os

# Local filesystem path where the shared Google Drive project folder is
# expected to be mounted.
# NOTE: the value ends with '/', and the cells that build paths from it add
# another '/' of their own.
GOOGLE_DRIVE_LOCAL_MOUNT='/Users/sagana/Library/CloudStorage/GoogleDrive-sondande@uchicago.edu/.shortcut-targets-by-id/1O2pwlZERv3B7ki78Wn0brrpnArRBTFdH/MLI_2025 Winter/'

# Report whether the mount is reachable before any reads depend on it.
drive_is_mounted = os.path.exists(GOOGLE_DRIVE_LOCAL_MOUNT)
if not drive_is_mounted:
    print("Google Drive is not mounted. Please check your installation.")
else:
    print("Google Drive is mounted successfully!")
    print("Files in Drive:", os.listdir(GOOGLE_DRIVE_LOCAL_MOUNT))

In [None]:
from pyspark.sql.types import *
from pyspark.sql.functions import collect_set, collect_list, struct, col, when, count, countDistinct, lit
import pandas as pd
import ast

# Load the schema lookup file; later cells filter it by its "table" column
# and parse the matching "schema" value into a list of column names.
schema_csv_path = f'{GOOGLE_DRIVE_LOCAL_MOUNT}/SQL DB Export/CSV/schema.csv'
schemas_df = spark.read.options(header=True).csv(schema_csv_path)
schemas_df.show(n=5)

In [None]:
# Build an all-string schema for the radiology export from its schema.csv entry.
radiology_schema_row = (
    schemas_df
    .filter(col("table") == 'radiology')
    .select(col("schema"))
    .collect()[0]
)
radiology_schema_list = ast.literal_eval(radiology_schema_row[0])
radiology_schema = StructType(
    [StructField(column_name, StringType(), True) for column_name in radiology_schema_list]
)

# Load the pipe-delimited radiology notes. multiLine lets a quoted note body
# that contains newlines stay within a single record.
radiology_reader = (
    spark.read
    .option("delimiter", "|")
    .option("quote", '"')
    .option("multiLine", "true")
)
radiology_df = radiology_reader.csv(
    f'{GOOGLE_DRIVE_LOCAL_MOUNT}/Sagana Outputs/Clinical Notes Creation/Input Data/radiology.csv',
    schema=radiology_schema,
)
radiology_df.show(truncate=80)

In [None]:
# Build an all-string schema for the discharge export from its schema.csv entry.
discharge_schema_row = (
    schemas_df
    .filter(col("table") == 'discharge')
    .select(col("schema"))
    .collect()[0]
)
discharge_schema_list = ast.literal_eval(discharge_schema_row[0])
discharge_schema = StructType(
    [StructField(column_name, StringType(), True) for column_name in discharge_schema_list]
)

# Load the pipe-delimited discharge notes. multiLine lets a quoted note body
# that contains newlines stay within a single record.
discharge_reader = (
    spark.read
    .option("delimiter", "|")
    .option("quote", '"')
    .option("multiLine", "true")
)
discharge_df = discharge_reader.csv(
    f'{GOOGLE_DRIVE_LOCAL_MOUNT}/Sagana Outputs/Clinical Notes Creation/Input Data/discharge.csv',
    schema=discharge_schema,
)
discharge_df.show(truncate=80)

In [None]:
# Keep just the patient key and the note body from each source.
needed_columns = ['subject_id', 'text']
radiology_df_filtered = radiology_df.select(*needed_columns)
discharge_df_filtered = discharge_df.select(*needed_columns)

In [None]:
# Cleaned patient roster; the following cells use it to keep only notes whose
# subject_id appears here.
patients_csv_path = f'{GOOGLE_DRIVE_LOCAL_MOUNT}/JM outputs/patients_cleaned.csv'
patients_df = spark.read.options(header=True).csv(patients_csv_path)
patients_df.show(n=5)

In [None]:
# Semi-join: keep radiology notes whose subject_id exists in patients_df,
# without bringing in any patient columns.
radiology_has_patient = radiology_df_filtered.subject_id == patients_df.subject_id
final_radiology_df = radiology_df_filtered.join(
    patients_df,
    on=radiology_has_patient,
    how='left_semi',
)
# Rename so the note text stays distinguishable once both sources are joined.
final_radiology_df_updated = final_radiology_df.withColumnRenamed('text', 'radiology_text')

In [None]:
# Semi-join: keep discharge notes whose subject_id exists in patients_df,
# without bringing in any patient columns.
discharge_has_patient = discharge_df_filtered.subject_id == patients_df.subject_id
final_discharge_df = discharge_df_filtered.join(
    patients_df,
    on=discharge_has_patient,
    how='left_semi',
)
# Rename so the note text stays distinguishable once both sources are joined.
final_discharge_df_updated = final_discharge_df.withColumnRenamed('text', 'discharge_text')

In [None]:
# Inner join on subject_id: each radiology note is paired with every
# discharge note belonging to the same subject.
combined_df = final_radiology_df_updated.join(
    final_discharge_df_updated,
    on=['subject_id'],
    how='inner',
)
combined_df.show(n=1)

In [None]:
# Total number of radiology/discharge note pairs after the join (runs a full Spark job).
combined_df.count()

In [None]:
# Draw a small sample of the joined notes to use as the embedding sample set.
# - Without replacement, so no joined row is picked twice.
# - Fixed seed, so the same rows are drawn on every run.
# Sampling the already-joined frame means every sampled row carries both the
# radiology and the discharge text for its subject_id.
combined_df_sample = combined_df.sample(
    withReplacement=False,
    fraction=0.002,
    seed=42,
)

In [None]:
# Size of the sample actually drawn (sample() fractions are approximate, so check it).
combined_df_sample.count()

In [None]:
# Preview the sampled rows.
combined_df_sample.show()

In [None]:
# combined_df_sample.write.mode("overwrite").option("compression", "snappy").parquet('combined_cn_sample/')

In [None]:
# Print the column names and types of the sampled frame.
combined_df_sample.printSchema()

In [None]:
# final_radiology_df_count  = final_radiology_df.count()
# distinct_radiology_df_count = final_radiology_df.select('subject_id').distinct().count()
# final_discharge_df_count = final_discharge_df.count()
# distinct_discharge_df_count = final_discharge_df.select('subject_id').distinct().count()

# sampled_rad_count = sampled_final_radiology_df.count()
# distinct_rad_count = sampled_final_radiology_df.select('subject_id').distinct().count()
# sampled_dis_count = sampled_final_discharge_df.count()
# distinct_dis_count = sampled_final_discharge_df.select('subject_id').distinct().count()

# print(f"Original Radiology Count: {final_radiology_df_count}, Distinct Radiology Count: {distinct_radiology_df_count}")
# print(f"Sampled Radiology Count: {sampled_rad_count}, Distinct Radiology Count: {distinct_rad_count}")
# print(f"Original Discharge Count: {final_discharge_df_count}, Distinct Discharge Count: {distinct_discharge_df_count}")
# print(f"Sampled Discharge Count: {sampled_dis_count}, Distinct Discharge Count: {distinct_dis_count}")

In [None]:
# spark.stop()

In [None]:
# # Suppose "sampled_df" has a vector column "embeddings"
# # 1. Convert Spark vector to array, then collect
# from pyspark.sql.functions import udf
# from pyspark.ml.linalg import VectorUDT
# import numpy as np

# # Convert Spark Vector to Python list
# def to_array(v):
#     return v.toArray().tolist()

# to_array_udf = udf(to_array, "array<double>")
# sampled_df_array = sampled_df.withColumn("embeddings_array", to_array_udf("embeddings"))

# # 2. Collect to Pandas
# pdf = sampled_df_array.select("embeddings_array").limit(5000).toPandas()  # limit for memory safety
# X = np.array(pdf["embeddings_array"].tolist())  # shape: (n_samples, embed_dim)

In [None]:
# final_radiology_df.write.mode("overwrite").option("compression", "snappy").parquet('radiology_filtered/')

In [None]:
# final_discharge_df.write.mode("overwrite").option("compression", "snappy").parquet('discharge_filtered/')

In [None]:
# spark.stop()

## Review Processed datasets

In [None]:
# discharge_processed_df = spark.read.parquet('discharge_processed/')
# radio_processed_df = spark.read.parquet('radiology_processed/')

In [None]:
# from pyspark.sql.functions import expr
# from pyspark.sql.functions import to_json, col

# df = discharge_processed_df.withColumn("sections", to_json(col("sections")))  # Convert map column to JSON string
# df = df.withColumn("entities", to_json(col("entities")))
# # df.write.csv("output_directory", header=True, mode="overwrite")
# # df.coalesce(1).write.csv("discharge_processed_csv/", header=True, mode="overwrite")
# df_PD = df.toPandas()

In [None]:
# df_PD.to_csv('discharge_processed_csv/discharge_processed.csv', index=False)

In [None]:
# from pyspark.sql.functions import expr
# from pyspark.sql.functions import to_json, col

# df = radio_processed_df.withColumn("sections", to_json(col("sections")))  # Convert map column to JSON string
# df = df.withColumn("entities", to_json(col("entities")))
# # df.write.csv("output_directory", header=True, mode="overwrite")
# # df.coalesce(1).write.csv("discharge_processed_csv/", header=True, mode="overwrite")
# df_PD = df.toPandas()
# df_PD.to_csv('radiology_processed_csv/radio_processed.csv', index=False)

In [None]:
# df_PD.head()

## Review embeddings output

In [None]:
# spark = start_spark()

In [None]:
# Load the embedded sample notes written out as parquet, one dataset per
# note type (paths are relative to the working directory).
discharge_embeddings_path = 'discharge_text/clinical_notes_sampled_embedded/'
radiology_embeddings_path = 'radiology_text/clinical_notes_sampled_embedded/'
embedded_dis = spark.read.parquet(discharge_embeddings_path)
embedded_radio = spark.read.parquet(radiology_embeddings_path)

In [None]:
# Peek at one embedded discharge row, cutting each cell off at 200 characters.
embedded_dis.show(n=1, truncate=200)

In [None]:
# Preview the embedded radiology notes.
embedded_radio.show()

In [None]:
# Reduce each embedded dataset to the patient key plus its embedding column.
radiology_embedding_columns = ['subject_id', 'embedding_radiology_text']
discharge_embedding_columns = ['subject_id', 'embedding_discharge_text']
filtered_radio_emb = embedded_radio.select(*radiology_embedding_columns)
filtered_dis_emb = embedded_dis.select(*discharge_embedding_columns)

In [None]:
# Confirm the column types of the two-column radiology embedding frame.
filtered_radio_emb.printSchema()

In [None]:
# Gather every radiology embedding belonging to the same patient.
# filtered_radio_emb only carries `subject_id` and `embedding_radiology_text`,
# so the earlier `collect_set("icd_code")` referenced a column that does not
# exist in this frame and the cell could not run. collect_list keeps one entry
# per note, including duplicates.
filtered_radio_agg = (
    filtered_radio_emb
    .groupBy("subject_id")
    .agg(
        collect_list("embedding_radiology_text").alias("radiology_embeddings")
    )
)

In [None]:
# Inner join on subject_id: pair each radiology embedding with every
# discharge embedding of the same subject.
clinical_notes_combined_em = filtered_radio_emb.join(
    filtered_dis_emb,
    on=['subject_id'],
    how='inner',
)
clinical_notes_combined_em.show()

In [None]:
# Number of radiology/discharge embedding pairs in the joined frame.
clinical_notes_combined_em.count()

In [None]:
# c_notes_PD = clinical_notes_combined_em.toPandas()
# c_notes_PD.head()

In [None]:
# c_notes_PD.to_csv('embedded_clinical_notes_combined.csv', index=False)

In [None]:
# c_notes_PD.info()

In [None]:
# clinical_notes_combined_em.write.mode("overwrite").option("compression", "snappy").parquet('embedded_clinical_notes_combined/')

In [None]:
# Shut the Spark session down at the end of the notebook.
spark.stop()