In [None]:
# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import os
from pyspark.sql.functions import countDistinct
from pyspark.sql.functions import col, count, countDistinct, sum, when, lit, concat_ws, collect_list, size, desc, asc
from pyspark.sql.window import Window

In [None]:
# Select the JDK for PySpark. This is assigned before the session is built so
# that it is already in the environment when Spark is started.
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-17-openjdk-amd64'

# JVM flags applied identically to the driver and to the executors.
jvm_gc_options = "-XX:+UseG1GC -XX:+UnlockExperimentalVMOptions -XX:+UseStringDeduplication"

# Session settings, applied to the builder in the order listed.
#
# Intentionally NOT set:
#   * spark.executor.memoryFraction - not a recognised Spark property, so it
#     had no effect (unified memory is tuned through spark.memory.fraction).
#   * spark.sql.adaptive.shuffle.targetPostShuffleInputSize - deprecated name
#     for spark.sql.adaptive.advisoryPartitionSizeInBytes, which is set below.
spark_settings = {
    # Adaptive query execution
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true",
    "spark.sql.adaptive.skewJoin.enabled": "true",
    "spark.sql.adaptive.localShuffleReader.enabled": "true",
    "spark.sql.adaptive.advisoryPartitionSizeInBytes": "64MB",
    # NOTE(review): the default filesystem points at `localhost`, while other
    # cells in this notebook address HDFS as `hdfs://namenode:8020`.
    # Confirm which host is intended.
    "spark.hadoop.fs.defaultFS": "hdfs://localhost:8020",
    # Memory and cores
    "spark.driver.memory": "2g",
    "spark.driver.maxResultSize": "1g",
    "spark.executor.memory": "2g",
    "spark.executor.cores": "2",
    # Parallelism
    "spark.default.parallelism": "8",
    "spark.sql.shuffle.partitions": "8",
    # Serialization and Arrow-based pandas conversion
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    # Garbage-collector flags
    "spark.executor.extraJavaOptions": jvm_gc_options,
    "spark.driver.extraJavaOptions": jvm_gc_options,
}

builder = SparkSession.builder.appName("MIMIC_Medication_Recommender_Optimized")
for setting_name, setting_value in spark_settings.items():
    builder = builder.config(setting_name, setting_value)
spark = builder.getOrCreate()

# Set log level to reduce verbosity
spark.sparkContext.setLogLevel("WARN")

# Echo the effective settings so the run is self-describing.
print("=== Spark session initialized ===")
print(f"Spark version: {spark.version}")
print(f"Master: {spark.conf.get('spark.master')}")
print(f"Driver memory: {spark.conf.get('spark.driver.memory')}")
print(f"Executor memory: {spark.conf.get('spark.executor.memory')}")
print(f"Shuffle partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")

In [None]:
# Load the four input tables from the Parquet part-files stored on HDFS.
hdfs_uri = "hdfs://namenode:8020"
data_dir = f"{hdfs_uri}/data"

print("Reading Parquet files from HDFS...")

# Each glob matches every part-file belonging to one logical table.
diagnoses_df = spark.read.parquet(f"{data_dir}/diagnoses_icd_part*.parquet")
prescriptions_df = spark.read.parquet(f"{data_dir}/prescriptions_part*.parquet")
icd_descriptions_df = spark.read.parquet(f"{data_dir}/d_icd_diagnoses_part*.parquet")
patients_df = spark.read.parquet(f"{data_dir}/patients_part*.parquet")

print("✓ All Parquet files loaded successfully!")

In [None]:
# Cache the raw tables, report their sizes, then build the cleaned inputs
# used by the join in the next step.
print("=== OPTIMIZED DATA PROCESSING ===")

# Mark every raw table for caching before any action runs.
print("Caching DataFrames for better performance...")
for source_df in (diagnoses_df, prescriptions_df, icd_descriptions_df, patients_df):
    source_df.cache()

# Report row counts, one count() action per table, in a fixed order.
for label, source_df in (
    ("Diagnoses records", diagnoses_df),
    ("Prescriptions records", prescriptions_df),
    ("ICD descriptions", icd_descriptions_df),
    ("Patients", patients_df),
):
    print(f"{label}: {source_df.count():,}")

# Keep only rows that carry both the value of interest and the join key,
# then hash-partition each side on the join key (8 partitions).
print("\nCleaning data...")
diagnoses_clean = (
    diagnoses_df
    .where(col("icd_code").isNotNull() & col("hadm_id").isNotNull())
    .repartition(8, "hadm_id")
)
prescriptions_clean = (
    prescriptions_df
    .where(col("drug").isNotNull() & col("hadm_id").isNotNull())
    .repartition(8, "hadm_id")
)

print("✓ Data cleaned and optimized for processing")

In [None]:
# Join diagnoses to prescriptions and count how often each drug co-occurs
# with each diagnosis code.
print("=== OPTIMIZED JOIN AND AGGREGATION ===")

# Pair every diagnosis row with every prescription row that shares its
# hadm_id, then attach the ICD description. The description join is a LEFT
# join, so diagnosis codes without a matching description are kept.
print("Performing optimized joins...")
diagnoses_with_prescriptions = diagnoses_clean.alias("d").join(
    prescriptions_clean.alias("p"),
    "hadm_id",
    "inner"
)
diagnosis_med_join = diagnoses_with_prescriptions.join(
    icd_descriptions_df.alias("desc"),
    ["icd_code", "icd_version"],
    "left"
)

# Repartition the joined data on the grouping key used below.
diagnosis_med_join = diagnosis_med_join.repartition(16, "icd_code")

# One output row per (icd_code, description, drug).
# NOTE(review): `unique_patients_count` counts distinct hadm_id values. If
# hadm_id identifies an admission rather than a patient, the column name
# overstates what is measured - confirm against the source schema.
print("Performing optimized aggregation...")
medication_frequency = diagnosis_med_join.groupBy(
    col("d.icd_code"),
    col("desc.long_title").alias("diagnosis_description"),
    col("p.drug")
).agg(
    count("*").alias("prescription_count"),
    countDistinct("d.hadm_id").alias("unique_patients_count")
)

# Mark the aggregate for caching; it is read more than once when the
# recommendation table is built. No Spark action runs in this cell.
medication_frequency.cache()

print("✓ Join and aggregation completed successfully")

# The per-diagnosis ranking and top-5 selection are performed in the
# "CREATING RECOMMENDATION TABLE" step; the duplicate copy that used to live
# here was removed.

In [None]:
# Build the recommendation table: the five most-prescribed drugs for each
# diagnosis code, with each drug's share of that diagnosis's prescriptions.
print("=== CREATING RECOMMENDATION TABLE ===")

# Rank drugs inside each diagnosis by prescription count. `drug` is a
# secondary sort key so that ties are broken the same way on every run;
# ordering by the count alone leaves row_number() free to order tied rows
# arbitrarily.
print("Adding rankings...")
window_spec = Window.partitionBy("icd_code").orderBy(
    col("prescription_count").desc(),
    col("drug").asc(),
)

medication_ranking = medication_frequency.withColumn(
    "rank",
    row_number().over(window_spec)
)

# Get top 5 medications for each diagnosis
top_medications_per_diagnosis = medication_ranking.filter(col("rank") <= 5)

# Per-diagnosis totals, computed over ALL drugs (not only the top five).
# countDistinct is used so a drug that appears in more than one aggregate row
# for the same icd_code is counted once.
diagnosis_stats = medication_frequency.groupBy("icd_code").agg(
    countDistinct("drug").alias("total_different_meds"),
    sum("prescription_count").alias("total_prescriptions")
)

print("Creating final recommendation table...")
top_with_stats = top_medications_per_diagnosis.join(
    diagnosis_stats, "icd_code", "left"
)

# The percentage divides by the diagnosis-wide total. Computing the
# denominator with a window over the already-filtered top-5 rows would give
# each drug's share of the top five only, which disagrees with the
# `total_prescriptions` column shown next to it.
recommendation_columns = [
    col("icd_code"),
    col("diagnosis_description"),
    col("drug"),
    col("prescription_count"),
    col("unique_patients_count"),
    col("rank"),
    format_string(
        "%.2f",
        col("prescription_count") * 100.0 / col("total_prescriptions")
    ).alias("percentage_within_diagnosis"),
]

recommendation_table = top_with_stats.select(*recommendation_columns)

final_recommendation_table = top_with_stats.select(
    *recommendation_columns,
    col("total_different_meds"),
    col("total_prescriptions"),
)

# Cache the final result; the count() below is the first action on it.
final_recommendation_table.cache()

print("✓ Recommendation table created successfully")
print(f"Total recommendations: {final_recommendation_table.count():,}")

In [None]:
# Persist the recommendation table to HDFS, then print a sample and summary.
print("=== SAVING RESULTS ===")

# Output location in HDFS
hdfs_output_path = "hdfs://namenode:8020/processed/medication_recommendations"

print("Saving recommendation table to HDFS...")
try:
    # Primary format: Snappy-compressed Parquet, replacing any previous run.
    final_recommendation_table.write \
        .mode("overwrite") \
        .option("compression", "snappy") \
        .parquet(hdfs_output_path)
except Exception as e:
    print(f"Error saving results: {e}")
    print("Trying alternative save method...")

    # Best-effort fallback: write CSV next to the intended Parquet path.
    try:
        final_recommendation_table.write \
            .mode("overwrite") \
            .option("header", "true") \
            .csv(hdfs_output_path + "_csv")
        print("✓ Successfully saved as CSV!")
    except Exception as e2:
        print(f"Failed to save results: {e2}")
else:
    # Only reached when the Parquet write succeeded. The reporting lives
    # outside the `try` body so that a failure while displaying results is
    # not reported as a save error and does not trigger the CSV fallback.
    print("✓ Successfully saved to HDFS!")
    print(f"Output path: {hdfs_output_path}")

    # Show sample results
    print("\n=== SAMPLE RECOMMENDATIONS ===")
    final_recommendation_table.show(20, truncate=False)

    # Show summary statistics
    print("\n=== SUMMARY STATISTICS ===")
    print(f"Total unique diagnoses: {final_recommendation_table.select('icd_code').distinct().count():,}")
    print(f"Total unique medications: {final_recommendation_table.select('drug').distinct().count():,}")

    # avg() over zero groups yields None; guard so an empty table does not
    # raise while formatting.
    avg_meds = final_recommendation_table.groupBy('icd_code').count().agg({'count': 'avg'}).collect()[0][0]
    avg_text = f"{avg_meds:.1f}" if avg_meds is not None else "n/a"
    print(f"Average medications per diagnosis: {avg_text}")

print("\n=== PROCESSING COMPLETED ===")

In [None]:
# Release cached DataFrames and shut the Spark session down.
print("=== CLEANING UP ===")

# Names of the DataFrames that earlier cells marked for caching. They are
# looked up by name so that a frame that was never created (for example
# because an earlier cell did not run) is skipped instead of aborting the
# whole clean-up.
cached_frame_names = [
    "diagnoses_df",
    "prescriptions_df",
    "icd_descriptions_df",
    "patients_df",
    "medication_frequency",
    "final_recommendation_table",
]

unpersist_failures = 0
for frame_name in cached_frame_names:
    cached_frame = globals().get(frame_name)
    if cached_frame is None:
        continue
    # Handle each frame on its own: one failure must not stop the remaining
    # frames from being released, and the error is reported rather than
    # silently discarded.
    try:
        cached_frame.unpersist()
    except Exception as exc:
        unpersist_failures += 1
        print(f"Warning: could not unpersist {frame_name}: {exc}")

if unpersist_failures == 0:
    print("✓ Cached DataFrames cleared")

# Stop Spark session
spark.stop()
print("✓ Spark session stopped")
print("=== NOTEBOOK EXECUTION COMPLETED ===")