# ICU Length of Stay Prediction - MIMIC-III Pipeline

## üéØ Objective
Predict ICU stay duration using PySpark ML on MIMIC-III dataset

## üìä Data & Constraints
- **Sources**: 6 MIMIC-III tables (CHARTEVENTS, LABEVENTS, ICUSTAYS, etc.)
- **Filters**: 
        - Patient Age 18-80
        - LOS 0.1-15 days
        - Valid time sequences
- **Timeframe**: Vitals (first 24h), Labs (6h pre to 24h post ICU)


## üåÄ Big Data Processing

- **Storage**: We used Google Cloud Dataproc and Google Storage Buckets for MIMIC-III storage 
- **CHARTEVENTS**: Chart Events table has +330 million rows
- **Parquet**: Converted "CHARTEVENTS" and "LABEVENTS" tables to Parquet format for efficient storage and processing
- **Filtering**: We filtered immediately when loading to optimize CHARTEVENTS DataFrame

## üîß Features (39 total)
- **Demographics (2)**: Age, gender
- **Admission (8)**: Emergency/elective, timing, insurance
- **ICU Units (6)**: Care unit types, transfers
- **Vitals (11)**: HR, BP, RR, temp, SpO2 (avg/std)
- **Labs (8)**: Creatinine, glucose, electrolytes, blood counts
- **Diagnoses (4)**: Total count, sepsis, respiratory failure

## ü§ñ Models & Results
- **Linear Regression**: 
- **Random Forest**: 

## ‚òÅÔ∏è Infrastructure
- **GCP Dataproc**: 1x Master and 2x Workers, n2-standard-4  (12 vCPUs, 48GB RAM, 400GB Disk Storage)
- **Optimizations**: Smart sampling, aggressive filtering, 80/20 split





## Cenas a acresentar no relatorio:

* justificar o pq de cada uma das colunas
* dar tune aos hiperparametros do modelo
* referencias e bibliografias :
    *

## Import Libraries

In [2]:
# üì¶ PySpark Core Imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.functions import col, sum as sql_sum, count
from pyspark.sql.types import *
from pyspark.sql.window import Window


# üî¢ Data Processing & Feature Engineering
from pyspark.ml.feature import (
    VectorAssembler,
    StandardScaler,
    StringIndexer,
    MinMaxScaler,
    Imputer
)
from pyspark.ml.functions import vector_to_array


# ü§ñ Machine Learning Models
from pyspark.ml.regression import (
    RandomForestRegressor,
    LinearRegression
    # GBTRegressor
)


# üìä Model Evaluation & Tuning
from pyspark.ml.evaluation import RegressionEvaluator
# from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml import Pipeline
import operator
import builtins


# ‚è±Ô∏è Date/Time Utilities
from datetime import datetime, timedelta
import time


print("\n‚úÖ All imports loaded successfully!")
print(f"‚è∞ Notebook started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")


‚úÖ All imports loaded successfully!
‚è∞ Notebook started at: 2025-06-07 12:46:29



## Setup Spark Session

In [3]:
spark = SparkSession.builder \
    .appName("Forecast-LOS") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    \
    .config("spark.executor.memory", "5g") \
    .config("spark.executor.cores", "2") \
    .config("spark.executor.instances", "2") \
    \
    .config("spark.driver.memory", "10g") \
    .config("spark.driver.cores", "3") \
    .config("spark.driver.maxResultSize", "2g") \
    \
    .config("spark.sql.shuffle.partitions", "32") \
    .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "64MB") \
    .config("spark.sql.files.maxPartitionBytes", "128MB") \
    .config("spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold", "32MB") \
    \
    .config("spark.network.timeout", "600s") \
    .config("spark.sql.broadcastTimeout", "300s") \
    .config("spark.rpc.askTimeout", "300s") \
    \
    .config("spark.executor.heartbeatInterval", "20s") \
    .config("spark.dynamicAllocation.enabled", "false") \
    \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .config("spark.sql.adaptive.localShuffleReader.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "128MB") \
    \
    .config("spark.executor.memoryOffHeap.enabled", "true") \
    .config("spark.executor.memoryOffHeap.size", "1g") \
    \
    .getOrCreate()
print("‚úÖ Spark session created successfully!")
print(f"üìä Spark Version: {spark.version}")
print(f"üîß Application Name: {spark.sparkContext.appName}")
print(f"üíæ Available cores: {spark.sparkContext.defaultParallelism}")
print(f"\n‚è∞ Spark session initialised at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/07 12:46:31 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


‚úÖ Spark session created successfully!
üìä Spark Version: 4.0.0
üîß Application Name: Forecast-LOS
üíæ Available cores: 8

‚è∞ Spark session initialised at: 2025-06-07 12:46:33


# Load Data

Strategy: Pre-filter CHARTEVENTS to find ICU stays with required vital signs, then efficiently load all tables using broadcast joins and lookup tables.
Key Steps:

- Filter for ICU stays with ‚â•1 of 6 vital signs (HR, BP, RR, Temp, SpO2)
- Create lookup tables for ICUSTAY_ID, HADM_ID, SUBJECT_ID
- Load all tables with pre-filtering using broadcast joins
- Convert large files to "Parquet" for performance

Result: Memory-efficient loading of only relevant data with quality assurance that all ICU stays have vital signs measurements.

In [4]:
# Configuration flags
SAMPLE_ENABLE = False
SAMPLE_SIZE = 20000
MIMIC_PATH = "gs://dataproc-staging-europe-west2-851143487985-hir6gfre/mimic-data"



print("üè• Loading MIMIC-III data...")

# Step 1: First, find ICUSTAY_IDs that have ALL required vital signs
print("üìÇ Loading CHARTEVENTS...")

try:
   chartevents_df = spark.read.parquet(f"{MIMIC_PATH}/CHARTEVENTS.parquet")
   print("‚úÖ Loaded CHARTEVENTS from parquet")
except:
   print("üìÑ Converting CHARTEVENTS.csv.gz to parquet...")
   chartevents_csv = spark.read.option("header", "true").option("inferSchema", "false").csv(f"{MIMIC_PATH}/CHARTEVENTS.csv.gz")
   chartevents_csv.write.mode("overwrite").parquet(f"{MIMIC_PATH}/CHARTEVENTS.parquet")
   chartevents_df = spark.read.parquet(f"{MIMIC_PATH}/CHARTEVENTS.parquet")
   print("‚úÖ Converted and loaded CHARTEVENTS")



# Step 2: Load ICUSTAYS 
print("\nüìÇ Loading and filtering ICUSTAYS...")
icustays_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/ICUSTAYS.csv.gz")



# Step 3: Apply sampling if enabled
if SAMPLE_ENABLE:
   print(f"üéØ Sampling {SAMPLE_SIZE} ICU stays...")
   icustays_df = icustays_df.limit(SAMPLE_SIZE)
   icustays_df.cache()
   actual_sample_size = icustays_df.count()
   print(f"‚úÖ Final sample: {actual_sample_size} ICU stays")
else:
   icustays_df.cache()
   actual_sample_size = icustays_df.count()

   
   
# Step 4: Create efficient lookup tables
print("üìã Creating ID lookup tables...")
icu_lookup = icustays_df.select("ICUSTAY_ID").distinct().cache()
hadm_lookup = icustays_df.select("HADM_ID").distinct().cache()
subject_lookup = icustays_df.select("SUBJECT_ID").distinct().cache()

icu_lookup.count()  # Trigger caching
hadm_lookup.count()
subject_lookup.count()

# Step 5: Load other tables with optimized joins
print("üìÇ Loading PATIENTS table...")
patients_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/PATIENTS.csv.gz")
patients_df = patients_df.join(broadcast(subject_lookup), "SUBJECT_ID", "inner")

print("üìÇ Loading ADMISSIONS table...")
admissions_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/ADMISSIONS.csv.gz")
admissions_df = admissions_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

print("üìÇ Loading DIAGNOSES_ICD table...")
diagnoses_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/DIAGNOSES_ICD.csv.gz")
diagnoses_df = diagnoses_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

# Step 6: Load and filter CHARTEVENTS efficiently
print("üìÇ Loading CHARTEVENTS table... [FILTERING BY ICUSTAY_ID]")
chartevents_df = chartevents_df \
   .select("ICUSTAY_ID", "CHARTTIME", "ITEMID", "VALUE", "VALUEUOM", "VALUENUM") \
   .join(broadcast(icu_lookup), "ICUSTAY_ID", "inner")

# Step 7: Load LABEVENTS
print("üìÇ Loading LABEVENTS table... [FILTERING BY HADM_ID]")
try:
   labevents_df = spark.read.parquet(f"{MIMIC_PATH}/LABEVENTS.parquet")
except:
   print("üìÑ Converting LABEVENTS.csv.gz to parquet...")
   labevents_csv = spark.read.option("header", "true").option("inferSchema", "false").csv(f"{MIMIC_PATH}/LABEVENTS.csv.gz")
   labevents_csv.write.mode("overwrite").parquet(f"{MIMIC_PATH}/LABEVENTS.parquet")
   labevents_df = spark.read.parquet(f"{MIMIC_PATH}/LABEVENTS.parquet")

labevents_df = labevents_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

# Final summary
print("\n‚úÖ Data loading complete!")
print(f"üìä ICUSTAYS: {icustays_df.count():,} rows")
print(f"üìä PATIENTS: {patients_df.count():,} rows") 
print(f"üìä ADMISSIONS: {admissions_df.count():,} rows")
print(f"üìä DIAGNOSES_ICD: {diagnoses_df.count():,} rows")
print(f"üìä CHARTEVENTS (filtered): {chartevents_df.count():,} rows")
print(f"üìä LABEVENTS (filtered): {labevents_df.count():,} rows")
print(f"\n‚è∞ Data loaded at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

üè• Loading MIMIC-III data...
üìÇ Loading CHARTEVENTS...


25/06/07 12:46:34 WARN FileStreamSink: Assume no metadata directory. Error while looking for metadata directory in the path: gs://dataproc-staging-europe-west2-851143487985-hir6gfre/mimic-data/CHARTEVENTS.parquet.
org.apache.hadoop.fs.UnsupportedFileSystemException: No FileSystem for scheme "gs"
	at org.apache.hadoop.fs.FileSystem.getFileSystemClass(FileSystem.java:3581)
	at org.apache.hadoop.fs.FileSystem.createFileSystem(FileSystem.java:3612)
	at org.apache.hadoop.fs.FileSystem.access$300(FileSystem.java:172)
	at org.apache.hadoop.fs.FileSystem$Cache.getInternal(FileSystem.java:3716)
	at org.apache.hadoop.fs.FileSystem$Cache.get(FileSystem.java:3667)
	at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:557)
	at org.apache.hadoop.fs.Path.getFileSystem(Path.java:366)
	at org.apache.spark.sql.execution.streaming.FileStreamSink$.hasMetadata(FileStreamSink.scala:55)
	at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:381)
	at org.apache.spark.sql.

üìÑ Converting CHARTEVENTS.csv.gz to parquet...


Py4JJavaError: An error occurred while calling o83.csv.
: org.apache.hadoop.fs.UnsupportedFileSystemException: No FileSystem for scheme "gs"
	at org.apache.hadoop.fs.FileSystem.getFileSystemClass(FileSystem.java:3581)
	at org.apache.hadoop.fs.FileSystem.createFileSystem(FileSystem.java:3612)
	at org.apache.hadoop.fs.FileSystem.access$300(FileSystem.java:172)
	at org.apache.hadoop.fs.FileSystem$Cache.getInternal(FileSystem.java:3716)
	at org.apache.hadoop.fs.FileSystem$Cache.get(FileSystem.java:3667)
	at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:557)
	at org.apache.hadoop.fs.Path.getFileSystem(Path.java:366)
	at org.apache.spark.sql.execution.datasources.DataSource$.$anonfun$checkAndGlobPathIfNecessary$1(DataSource.scala:777)
	at scala.collection.immutable.List.map(List.scala:247)
	at scala.collection.immutable.List.map(List.scala:79)
	at org.apache.spark.sql.execution.datasources.DataSource$.checkAndGlobPathIfNecessary(DataSource.scala:775)
	at org.apache.spark.sql.execution.datasources.DataSource.checkAndGlobPathIfNecessary(DataSource.scala:575)
	at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:419)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.org$apache$spark$sql$catalyst$analysis$ResolveDataSource$$loadV1BatchSource(ResolveDataSource.scala:143)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.$anonfun$applyOrElse$2(ResolveDataSource.scala:61)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.applyOrElse(ResolveDataSource.scala:61)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.applyOrElse(ResolveDataSource.scala:45)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$3(AnalysisHelper.scala:139)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:86)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$1(AnalysisHelper.scala:139)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:416)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning(AnalysisHelper.scala:135)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning$(AnalysisHelper.scala:131)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUpWithPruning(LogicalPlan.scala:37)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:112)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:111)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:37)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.apply(ResolveDataSource.scala:45)
	at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.apply(ResolveDataSource.scala:43)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:242)
	at scala.collection.LinearSeqOps.foldLeft(LinearSeq.scala:183)
	at scala.collection.LinearSeqOps.foldLeft$(LinearSeq.scala:179)
	at scala.collection.immutable.List.foldLeft(List.scala:79)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:239)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:231)
	at scala.collection.immutable.List.foreach(List.scala:334)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:231)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:290)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$execute$1(Analyzer.scala:286)
	at org.apache.spark.sql.catalyst.analysis.AnalysisContext$.withNewAnalysisContext(Analyzer.scala:234)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:286)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:249)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:201)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:89)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:201)
	at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.resolveInFixedPoint(HybridAnalyzer.scala:190)
	at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.$anonfun$apply$1(HybridAnalyzer.scala:76)
	at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.withTrackedAnalyzerBridgeState(HybridAnalyzer.scala:111)
	at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.apply(HybridAnalyzer.scala:71)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:280)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:423)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:280)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$lazyAnalyzed$2(QueryExecution.scala:110)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:148)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:278)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:654)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:278)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
	at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:277)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$lazyAnalyzed$1(QueryExecution.scala:110)
	at scala.util.Try$.apply(Try.scala:217)
	at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1378)
	at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1439)
	at org.apache.spark.util.LazyTry.get(LazyTry.scala:58)
	at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:121)
	at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:80)
	at org.apache.spark.sql.classic.Dataset$.$anonfun$ofRows$1(Dataset.scala:115)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
	at org.apache.spark.sql.classic.Dataset$.ofRows(Dataset.scala:113)
	at org.apache.spark.sql.classic.DataFrameReader.load(DataFrameReader.scala:109)
	at org.apache.spark.sql.classic.DataFrameReader.load(DataFrameReader.scala:58)
	at org.apache.spark.sql.DataFrameReader.csv(DataFrameReader.scala:392)
	at org.apache.spark.sql.classic.DataFrameReader.csv(DataFrameReader.scala:259)
	at org.apache.spark.sql.classic.DataFrameReader.csv(DataFrameReader.scala:58)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:569)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:108)
	at java.base/java.lang.Thread.run(Thread.java:840)
	Suppressed: org.apache.spark.util.Utils$OriginalTryStackTraceException: Full stacktrace of original doTryWithCallerStacktrace caller
		at org.apache.hadoop.fs.FileSystem.getFileSystemClass(FileSystem.java:3581)
		at org.apache.hadoop.fs.FileSystem.createFileSystem(FileSystem.java:3612)
		at org.apache.hadoop.fs.FileSystem.access$300(FileSystem.java:172)
		at org.apache.hadoop.fs.FileSystem$Cache.getInternal(FileSystem.java:3716)
		at org.apache.hadoop.fs.FileSystem$Cache.get(FileSystem.java:3667)
		at org.apache.hadoop.fs.FileSystem.get(FileSystem.java:557)
		at org.apache.hadoop.fs.Path.getFileSystem(Path.java:366)
		at org.apache.spark.sql.execution.datasources.DataSource$.$anonfun$checkAndGlobPathIfNecessary$1(DataSource.scala:777)
		at scala.collection.immutable.List.map(List.scala:247)
		at scala.collection.immutable.List.map(List.scala:79)
		at org.apache.spark.sql.execution.datasources.DataSource$.checkAndGlobPathIfNecessary(DataSource.scala:775)
		at org.apache.spark.sql.execution.datasources.DataSource.checkAndGlobPathIfNecessary(DataSource.scala:575)
		at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:419)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.org$apache$spark$sql$catalyst$analysis$ResolveDataSource$$loadV1BatchSource(ResolveDataSource.scala:143)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.$anonfun$applyOrElse$2(ResolveDataSource.scala:61)
		at scala.Option.getOrElse(Option.scala:201)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.applyOrElse(ResolveDataSource.scala:61)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource$$anonfun$apply$1.applyOrElse(ResolveDataSource.scala:45)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$3(AnalysisHelper.scala:139)
		at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:86)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$1(AnalysisHelper.scala:139)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:416)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning(AnalysisHelper.scala:135)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning$(AnalysisHelper.scala:131)
		at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUpWithPruning(LogicalPlan.scala:37)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:112)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:111)
		at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:37)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.apply(ResolveDataSource.scala:45)
		at org.apache.spark.sql.catalyst.analysis.ResolveDataSource.apply(ResolveDataSource.scala:43)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:242)
		at scala.collection.LinearSeqOps.foldLeft(LinearSeq.scala:183)
		at scala.collection.LinearSeqOps.foldLeft$(LinearSeq.scala:179)
		at scala.collection.immutable.List.foldLeft(List.scala:79)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:239)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:231)
		at scala.collection.immutable.List.foreach(List.scala:334)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:231)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:290)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$execute$1(Analyzer.scala:286)
		at org.apache.spark.sql.catalyst.analysis.AnalysisContext$.withNewAnalysisContext(Analyzer.scala:234)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:286)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:249)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:201)
		at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:89)
		at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:201)
		at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.resolveInFixedPoint(HybridAnalyzer.scala:190)
		at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.$anonfun$apply$1(HybridAnalyzer.scala:76)
		at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.withTrackedAnalyzerBridgeState(HybridAnalyzer.scala:111)
		at org.apache.spark.sql.catalyst.analysis.resolver.HybridAnalyzer.apply(HybridAnalyzer.scala:71)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:280)
		at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:423)
		at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:280)
		at org.apache.spark.sql.execution.QueryExecution.$anonfun$lazyAnalyzed$2(QueryExecution.scala:110)
		at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:148)
		at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:278)
		at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:654)
		at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:278)
		at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:804)
		at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:277)
		at org.apache.spark.sql.execution.QueryExecution.$anonfun$lazyAnalyzed$1(QueryExecution.scala:110)
		at scala.util.Try$.apply(Try.scala:217)
		at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1378)
		at org.apache.spark.util.LazyTry.tryT$lzycompute(LazyTry.scala:46)
		at org.apache.spark.util.LazyTry.tryT(LazyTry.scala:46)
		... 23 more


In [None]:
"""#Configuration flags
SAMPLE_ENABLE = False
SAMPLE_SIZE = 20000
MIMIC_PATH = "mimic-db-short"



print("üè• Loading MIMIC-III data...")

# Step 1: First, find ICUSTAY_IDs that have ALL required vital signs
print("üìÇ Loading CHARTEVENTS...")


chartevents_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/CHARTEVENTS.csv")
print("‚úÖ Loaded CHARTEVENTS from parquet")
d_items_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/D_ITEMS.csv")




# Step 2: Load ICUSTAYS 
print("\nüìÇ Loading and filtering ICUSTAYS...")
icustays_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/ICUSTAYS.csv")



# Step 3: Apply sampling if enabled
if SAMPLE_ENABLE:
    print(f"üéØ Sampling {SAMPLE_SIZE} ICU stays...")
    icustays_df = icustays_df.limit(SAMPLE_SIZE)
    icustays_df.cache()
    actual_sample_size = icustays_df.count()
    print(f"‚úÖ Final sample: {actual_sample_size} ICU stays")
else:
    icustays_df.cache()
    actual_sample_size = icustays_df.count()

  
  
# Step 4: Create efficient lookup tables
print("üìã Creating ID lookup tables...")
icu_lookup = icustays_df.select("ICUSTAY_ID").distinct().cache()
hadm_lookup = icustays_df.select("HADM_ID").distinct().cache()
subject_lookup = icustays_df.select("SUBJECT_ID").distinct().cache()

icu_lookup.count()  # Trigger caching
hadm_lookup.count()
subject_lookup.count()

# Step 5: Load other tables with optimized joins
print("üìÇ Loading PATIENTS table...")
patients_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/PATIENTS.csv")
patients_df = patients_df.join(broadcast(subject_lookup), "SUBJECT_ID", "inner")

print("üìÇ Loading ADMISSIONS table...")
admissions_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/ADMISSIONS.csv")
admissions_df = admissions_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

print("üìÇ Loading DIAGNOSES_ICD table...")
diagnoses_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/DIAGNOSES_ICD.csv")
diagnoses_df = diagnoses_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

# Step 6: Load and filter CHARTEVENTS efficiently
print("üìÇ Loading CHARTEVENTS table... [FILTERING BY ICUSTAY_ID]")
chartevents_df = chartevents_df \
    .select("ICUSTAY_ID", "CHARTTIME", "ITEMID", "VALUE", "VALUEUOM", "VALUENUM") \
    .join(broadcast(icu_lookup), "ICUSTAY_ID", "inner")

# Step 7: Load LABEVENTS
print("üìÇ Loading LABEVENTS table... [FILTERING BY HADM_ID]")

labevents_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/LABEVENTS.csv")
d_labitems_df = spark.read.option("header", "true").option("inferSchema", "true").csv(f"{MIMIC_PATH}/D_LABITEMS.csv")


labevents_df = labevents_df.join(broadcast(hadm_lookup), "HADM_ID", "inner")

# Final summary
print("\n‚úÖ Data loading complete!")
print(f"üìä ICUSTAYS: {icustays_df.count():,} rows")
print(f"üìä PATIENTS: {patients_df.count():,} rows") 
print(f"üìä ADMISSIONS: {admissions_df.count():,} rows")
print(f"üìä DIAGNOSES_ICD: {diagnoses_df.count():,} rows")
print(f"üìä CHARTEVENTS (filtered): {chartevents_df.count():,} rows")
print(f"üìä LABEVENTS (filtered): {labevents_df.count():,} rows")
print(f"\n‚è∞ Data loaded at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")"""

üè• Loading MIMIC-III data...
üìÇ Loading CHARTEVENTS...
‚úÖ Loaded CHARTEVENTS from parquet

üìÇ Loading and filtering ICUSTAYS...
üìã Creating ID lookup tables...
üìÇ Loading PATIENTS table...


25/06/07 12:41:40 WARN CacheManager: Asked to cache already cached data.
25/06/07 12:41:40 WARN CacheManager: Asked to cache already cached data.
25/06/07 12:41:40 WARN CacheManager: Asked to cache already cached data.


üìÇ Loading ADMISSIONS table...
üìÇ Loading DIAGNOSES_ICD table...
üìÇ Loading CHARTEVENTS table... [FILTERING BY ICUSTAY_ID]
üìÇ Loading LABEVENTS table... [FILTERING BY HADM_ID]

‚úÖ Data loading complete!
üìä ICUSTAYS: 20 rows
üìä PATIENTS: 20 rows
üìä ADMISSIONS: 20 rows
üìä DIAGNOSES_ICD: 212 rows
üìä CHARTEVENTS (filtered): 57,973 rows
üìä LABEVENTS (filtered): 5,895 rows

‚è∞ Data loaded at: 2025-06-07 12:41:41


# Features Engineering



## Extracting Data From ICUSTAYS

**Purpose**: Create comprehensive ICU dataset by joining ICU stays with patient demographics and admission details.

**Key Features**:
- **Target Variable**: ICU_LOS_DAYS (length of stay)
- **Demographics**: Age (18-80), gender, ethnicity
- **Clinical**: Care units, admission type/location, insurance
- **Outcomes**: Hospital/patient death flags
- **Identifiers**: ICUSTAY_ID, SUBJECT_ID, HADM_ID

**Age Filter**: Adults only (18-80 years) to exclude pediatric/very elderly edge cases.

**Alive Filter**: Only include people who did survive the ICU stay.

**LOS Filter**: Get only LOS values within a range that does'nt include outliers.

**Result**: Clean base dataset ready for vital signs feature engineering.

In [None]:
print("üìä Step 1: Creating base ICU dataset with patient demographics...")

base_icu_df = icustays_df.alias("icu") \
    .join(patients_df.alias("pat"), "SUBJECT_ID", "inner") \
    .join(admissions_df.alias("adm"), ["SUBJECT_ID", "HADM_ID"], "inner") \
    .select(
        # ICU stay identifiers
        col("icu.ICUSTAY_ID"),
        col("icu.SUBJECT_ID"), 
        col("icu.HADM_ID"),
        
        # Target variable - Length of Stay in ICU (days)
        col("icu.LOS").alias("ICU_LOS_DAYS"),
        
        # ICU characteristics
        col("icu.FIRST_CAREUNIT"),
        col("icu.LAST_CAREUNIT"), 
        col("icu.INTIME").alias("ICU_INTIME"),
        col("icu.OUTTIME").alias("ICU_OUTTIME"),
        
        # Patient demographics
        col("pat.GENDER"),
        col("pat.EXPIRE_FLAG").alias("PATIENT_DIED"),
        col("pat.DOB"),
        
        # Admission details
        col("adm.ADMITTIME"),
        col("adm.DISCHTIME"), 
        col("adm.ADMISSION_TYPE"),
        col("adm.ADMISSION_LOCATION"),
        col("adm.INSURANCE"),
        col("adm.ETHNICITY"),
        col("adm.MARITAL_STATUS"),
        col("adm.RELIGION"),
        col("adm.HOSPITAL_EXPIRE_FLAG").alias("HOSPITAL_DEATH"),
        col("adm.DIAGNOSIS").alias("ADMISSION_DIAGNOSIS")
    )

# Calculate age at ICU admission
base_icu_df = base_icu_df.withColumn("AGE_AT_ICU_ADMISSION", \
                                     floor(datediff(col("ICU_INTIME"), col("DOB")) / 365.25)) \
                                     .filter(col("AGE_AT_ICU_ADMISSION").between(18,80)) \
                                    .filter(col("PATIENT_DIED").isin(0))


icustays_df.unpersist()


print("‚úÖ Created base ICU dataset!")

üìä Step 1: Creating base ICU dataset with patient demographics...
‚úÖ Created base ICU dataset!


In [None]:
print("\nüìà ICU Length of Stay Statistics (Days):")
base_icu_df.select("ICU_LOS_DAYS").describe().show()


üìà ICU Length of Stay Statistics (Days):
+-------+-----------------+
|summary|     ICU_LOS_DAYS|
+-------+-----------------+
|  count|               12|
|   mean|2.360116666666667|
| stddev|2.264151572918985|
|    min|            0.848|
|    max|           8.9163|
+-------+-----------------+



We kept every ICU STAY that had duration (LOS) between 0.0 and 9.1 days, considered normal legnths since:

| Statistic                | Value (days)                                    |
| ------------------------ | ----------------------------------------------- |
| **Minimum**              | 0.0 (can be admission + discharge on same day)  |
| **25th percentile (Q1)** | \~1.1                                           |
| **Median (Q2)**          | \~2.                                            |
| **75th percentile (Q3)** | \~4.3                                           |
| **Maximum**              | \~101.739                                       |
| **Mean**                 | \~3.49                                          |

Using interquartile range (IQR) method:

* IQR = Q3 - Q1 = 4.3 - 1.1 = ~3.2

* Upper Bound for outliers = Q3 + 1.5 √ó IQR ‚âà 4.3 + 4.8 = ~9.1 days

* Lower Bound = Q1 - 1.5 √ó IQR ‚âà 1.1 - 4.8 = < 0, which is ignored since LOS can‚Äôt be negative

So:

* Typical ICU LOS: 1.1 to 4.3 days

* Outliers: ICU stays longer than ~9.1 days

In [None]:
# Print initial dataset size
print(f"Number of rows before removing LOS outliers: {base_icu_df.count()}")

print("üìä Cleaning target variable...")

# Filter to keep only records with ICU_LOS_DAYS between 0 and 9.1 days
base_icu_df = base_icu_df.filter(
    (col("ICU_LOS_DAYS") >= 0.0) & 
    (col("ICU_LOS_DAYS") <= 9.1)
).cache()

print("‚úÖ Base ICU Dataset - Outliers Removed")

# Print filtered dataset size
print(f"Number of rows after removing LOS outliers: {base_icu_df.count()}")

Number of rows before removing LOS outliers: 12
üìä Cleaning target variable...
‚úÖ Base ICU Dataset - Outliers Removed
Number of rows after removing LOS outliers: 12


25/06/07 12:41:42 WARN CacheManager: Asked to cache already cached data.


## Extracting Categorical Features

**Features Created**:
- **GENDER_BINARY**: Male = 1, Female = 0
- **CAME_FROM_ER**: Emergency admission = 1
- **HAS_INSURANCE**: Medicare = 1, other = 0
- **ADMISSION_TYPE_ENCODED**: Emergency=1, Elective=2, Urgent=3, Other=0
- **ETHNICITY_ENCODED**: White=1, Black=2, Hispanic=3, Asian=4, Other=5
- **MARITAL_STATUS_ENCODED**: Married=1, Single=2, Divorced=3, Widowed=4, Separated=5, LifePartener=6, Other=0
- **RELIGION_ENCODED**: Catholic=1, Protestant=2, Jewish=3, Other=0
- **FIRST_UNIT_ENCODED**: Numerical encoding of ICU units MICU (Medical) = 1, SICU (Surgical) = 2,  CSRU (Cardiac Surgery) = 3, CCU (Coronary Care) = 4, TSICU (Trauma Surgical) = 5, Other = 0
- **CHANGED_ICU_UNIT**: Binary flag (1 if patient transferred between units)

**Clinical Significance**: Different ICU types have varying complexity and typical LOS patterns. Unit transfers often indicate complications.

**Result**: Categorical variables converted to numerical format for ML models.

In [None]:
print("üìä Step 2: Engineering categorical features...")
base_icu_df = base_icu_df \
    .withColumn("GENDER_BINARY", when(col("GENDER") == "M", 1).otherwise(0)) \
    .withColumn("CAME_FROM_ER", when(col("ADMISSION_LOCATION").contains("EMERGENCY"), 1).otherwise(0)) \
    .withColumn("HAS_INSURANCE", when(col("INSURANCE") == "Medicare", 1).otherwise(0)) \
    .withColumn("ADMISSION_TYPE_ENCODED", 
                when(col("ADMISSION_TYPE") == "EMERGENCY", 1)
                .when(col("ADMISSION_TYPE") == "ELECTIVE", 2)
                .when(col("ADMISSION_TYPE") == "URGENT", 3)
                .otherwise(0)) \
    .withColumn("ETHNICITY_ENCODED",
                when(col("ETHNICITY").contains("WHITE"), 1)
                .when(col("ETHNICITY").contains("BLACK"), 2)
                .when(col("ETHNICITY").contains("HISPANIC"), 3)
                .when(col("ETHNICITY").contains("ASIAN"), 4)
                .otherwise(5)) \
    .withColumn("MARITAL_STATUS_ENCODED",
                when(col("MARITAL_STATUS") == "MARRIED", 1)
                .when(col("MARITAL_STATUS") == "SINGLE", 2)
                .when(col("MARITAL_STATUS") == "DIVORCED", 3)
                .when(col("MARITAL_STATUS") == "WIDOWED", 4)
                .when(col("MARITAL_STATUS") == "SEPARATED", 5)
                .when(col("MARITAL_STATUS") == "LIFE PARTNER", 6)
                .otherwise(0)) \
    .withColumn("RELIGION_ENCODED",
                when(col("RELIGION").contains("CATHOLIC"), 1)
                .when(col("RELIGION").contains("PROTESTANT"), 2)
                .when(col("RELIGION").contains("JEWISH"), 3)
                .otherwise(0)) \
    .withColumn("FIRST_UNIT_ENCODED", 
                when(col("FIRST_CAREUNIT") == "MICU", 1)
                .when(col("FIRST_CAREUNIT") == "SICU", 2)
                .when(col("FIRST_CAREUNIT") == "CSRU", 3)
                .when(col("FIRST_CAREUNIT") == "CCU", 4)
                .when(col("FIRST_CAREUNIT") == "TSICU", 5)
                .otherwise(0)) \
    .withColumn("CHANGED_ICU_UNIT", 
                when(col("FIRST_CAREUNIT") != col("LAST_CAREUNIT"), 1).otherwise(0))

print("‚úÖ Base ICU Dataset - Categorical Features")


üìä Step 2: Engineering categorical features...
‚úÖ Base ICU Dataset - Categorical Features


## Extracting Time-based Features

**Action**: Filter out invalid records where INTIME >= OUTTIME.


In [None]:
print("üìä Step 4: Creating time-based features...")
base_icu_df = base_icu_df \
    .filter(col("ICU_INTIME") < col("ICU_OUTTIME"))
print("‚úÖ Base ICU Dataset - Time Based Features")

üìä Step 4: Creating time-based features...
‚úÖ Base ICU Dataset - Time Based Features


In [None]:
print("üìä Step 5: Dropping useless columns...")

# List of columns to drop (fixed syntax)
drop_cols = [
    "FIRST_CAREUNIT",
    "LAST_CAREUNIT",
    "GENDER",
    "PATIENT_DIED",
    "DOB",
    "ADMITTIME",
    "DISCHTIME",
    "ADMISSION_TYPE",
    "ADMISSION_LOCATION",
    "INSURANCE",
    "ETHNICITY",
    "MARITAL_STATUS",
    "RELIGION",
    "HOSPITAL_DEATH",
    "ADMISSION_DIAGNOSIS"
]

# Keep all columns except those in drop_cols
base_icu_df = base_icu_df.drop(*drop_cols)

print("‚úÖ Base ICU Dataset - Finalized")
base_icu_df.show(5)  # Showing first 5 rows for brevity

üìä Step 5: Dropping useless columns...
‚úÖ Base ICU Dataset - Finalized
+----------+----------+-------+------------+-------------------+-------------------+--------------------+-------------+------------+-------------+----------------------+-----------------+----------------------+----------------+------------------+----------------+
|ICUSTAY_ID|SUBJECT_ID|HADM_ID|ICU_LOS_DAYS|         ICU_INTIME|        ICU_OUTTIME|AGE_AT_ICU_ADMISSION|GENDER_BINARY|CAME_FROM_ER|HAS_INSURANCE|ADMISSION_TYPE_ENCODED|ETHNICITY_ENCODED|MARITAL_STATUS_ENCODED|RELIGION_ENCODED|FIRST_UNIT_ENCODED|CHANGED_ICU_UNIT|
+----------+----------+-------+------------+-------------------+-------------------+--------------------+-------------+------------+-------------+----------------------+-----------------+----------------------+----------------+------------------+----------------+
|    231977|      8470| 184688|      0.9792|2174-09-01 18:14:58|2174-09-02 17:45:00|                  30|            0|           0|  

## Extracting Clinical Events

**Purpose**: Extract top 20 most common CHARTEVENTS as features for ML models.

**Process**:
1. **Identify**: Find 20 most frequent CHARTEVENTS (typically vital signs)
2. **Calculate**: Average value of each test in first 24 hours of ICU stay
3. **Handle Missing**: Set missing values to **-1** (not null) for ML compatibility

**Time Window**: First 24 hours after ICU admission (INTIME + 24h)

**Result**: 20 vital signs features with consistent **-1** encoding for missing data, ensuring ML algorithm compatibility.


In [None]:
icu_stay_ids = base_icu_df.select("ICUSTAY_ID").distinct()

icu_stay_ids.cache()
print(f"üìå Filtering to {icu_stay_ids.count()} ICU stays")

25/06/07 12:41:42 WARN CacheManager: Asked to cache already cached data.


üìå Filtering to 12 ICU stays


In [None]:
# Cache the filtered data only once
chartevents_filtered = chartevents_df.select(
    "ICUSTAY_ID", "ITEMID", "VALUENUM", "CHARTTIME"
).join(
    icu_stay_ids, "ICUSTAY_ID", "inner"
).filter(
    col("VALUENUM").isNotNull() & col("CHARTTIME").isNotNull()
).cache()

# Join with d_items and cache since we'll use it multiple times
chartevents_with_categories = chartevents_filtered.join(
    d_items_df.select("ITEMID", "CATEGORY"), "ITEMID", "left"
).cache()

# Get top categories with handling for "null"
top_categories = chartevents_with_categories.groupBy("CATEGORY").agg(
    count("*").alias("count")
).orderBy(
    col("count").desc()
).limit(7).select("CATEGORY").collect()

top_categories = [cat for cat in [row["CATEGORY"] for row in top_categories] if cat != None  ]

chartevents_top_categories = chartevents_with_categories.filter(
    col("CATEGORY").isin(top_categories)
).cache()

patient_category_stats = chartevents_top_categories.groupBy("ICUSTAY_ID").pivot(
    "CATEGORY", top_categories
).agg(
    sql_sum("VALUENUM").alias("_sum"), 
    count("VALUENUM").alias("_count")
).fillna(0)  

# Clean up cached DataFrames
chartevents_filtered.unpersist()
chartevents_with_categories.unpersist()
chartevents_top_categories.unpersist()

patient_category_stats.show()

+----------+------------------------+--------------------------+-----------------+------------------+------------------+-----------+-----------+-------------+-----------------+-------------------+-----------------+-------------------+
|ICUSTAY_ID|Routine Vital Signs__sum|Routine Vital Signs__count| Respiratory__sum|Respiratory__count|         Labs__sum|Labs__count|Alarms__sum|Alarms__count|Neurological__sum|Neurological__count|Hemodynamics__sum|Hemodynamics__count|
+----------+------------------------+--------------------------+-----------------+------------------+------------------+-----------+-----------+-------------+-----------------+-------------------+-----------------+-------------------+
|    234929|                 21690.4|                       285|         20308.37|               201| 8817.480000000001|        146|     2743.0|           61|             75.0|                 21|7460.600000000001|                292|
|    207525|                 12608.6|                       

## Extracting Laboratory Events

**Purpose**: Extract top 20 most common lab tests as features for ML models.

**Process**:
1. **Identify**: Find 20 most frequent LABEVENTS (blood tests, chemistry panels)
2. **Time Window**: 6 hours before ICU admission + first 24 hours in ICU (30h total)
3. **Calculate**: Average value of each lab test within the 30-hour window

**Time Range**: ICU_INTIME - 6h to ICU_INTIME + 24h

**Result**: 20 lab test features with consistent -1 encoding for missing data, capturing pre-ICU and early ICU clinical status.

In [None]:
hadm_ids = base_icu_df.select("HADM_ID").distinct()

hadm_ids.cache()
print(f"üìå Filtering to {hadm_ids.count()} HADM ids")

üìå Filtering to 12 HADM ids


25/06/07 12:41:43 WARN CacheManager: Asked to cache already cached data.


In [None]:
# First, filter the labevents data similarly to how you filtered chartevents
labevents_filtered = labevents_df \
    .select("HADM_ID", "ITEMID", "VALUENUM", "CHARTTIME") \
    .join(hadm_ids, "HADM_ID", "inner") \
    .filter(col("VALUENUM").isNotNull()) \
    .filter(col("CHARTTIME").isNotNull())
labevents_filtered.cache()

# Join with d_labitems to get the categories
labevents_with_categories = labevents_filtered \
    .join(d_labitems_df.select("ITEMID", "CATEGORY"), "ITEMID", "left")

# Get the top 20 categories
top_lab_categories = labevents_with_categories \
    .groupBy("CATEGORY") \
    .count() \
    .orderBy(col("count").desc()) \
    .limit(7) \
    .select("CATEGORY") \
    .collect()
top_lab_categories = [row["CATEGORY"] for row in top_lab_categories]

# Filter to only include top categories
labevents_top_categories = labevents_with_categories.filter(
    col("CATEGORY").isin(top_lab_categories)
)

patient_lab_category_stats = labevents_top_categories.groupBy("HADM_ID", "CATEGORY") \
    .agg(
        sum("VALUENUM").alias("sum_val"),
        count("*").alias("count_val")
    )

# Pivot both metrics
sum_pivot = patient_lab_category_stats.groupBy("HADM_ID") \
    .pivot("CATEGORY", top_lab_categories) \
    .sum("sum_val")

count_pivot = patient_lab_category_stats.groupBy("HADM_ID") \
    .pivot("CATEGORY", top_lab_categories) \
    .sum("count_val")

# Rename count columns before joining
for category in top_lab_categories:
    count_pivot = count_pivot.withColumnRenamed(
        category, f"{category}_count"
    )
    sum_pivot = sum_pivot.withColumnRenamed(
        category, f"{category}_sum"
    )

# Join the results
final_lab_stats = sum_pivot.join(count_pivot, "HADM_ID", "inner")

final_lab_stats = final_lab_stats.fillna(0)
# Show the final result
final_lab_stats.show()

25/06/07 12:41:43 WARN CacheManager: Asked to cache already cached data.


+-------+------------------+------------------+------------------+----------------+---------------+---------------+
|HADM_ID|    Hematology_sum|     Chemistry_sum|     Blood Gas_sum|Hematology_count|Chemistry_count|Blood Gas_count|
+-------+------------------+------------------+------------------+----------------+---------------+---------------+
| 152943|12508.405999999997|21614.250000000007| 6574.119999999999|             224|            281|             80|
| 163177|2428.5839999999994|1625.6999999999998|              40.0|              68|             42|              2|
| 109820|           3414.12|            4182.5|            585.46|              63|             52|             13|
| 181763| 3908.620000000001|1009.1999999999999| 7738.740000000001|              86|             30|            121|
| 110972| 8419.420999999998| 6029.299999999999|1772.8200000000002|             188|            141|             31|
| 109131|2850.6750000000006| 3791.999999999999|           5323.58|      

## Diagnosis ICD

**Purpose**: Extract diagnosis patterns as ML features from ICD-9 codes.

**Process**:
1. **Top 3**: Get top 3 diagnoses by person, using HADM_ID, to future join with other tables. 
2. **Encode**: Encode the ICD9 diagnoses into a wide range of diagnoses.
3. **Pivot**: Pivot to create the 3 columns with the encoded diagnose type.
4. **Handle Missing Values**: Input -1 in the NULL entries of the table.


**Features Created**:
- **TOTAL_DIAGNOSES**: Count of all diagnoses (comorbidity indicator)
- **PRIMARY_DIAGNOSIS**: Most significant diagnose, encoded.
- **SECONDARY_DIAGNOSIS**: Second most significant diagnose, encoded.
- **TERCIARY_DIAGNOSIS**: Third most significant diagnose, encoded.

**Result**: ??????????????????????????????????????????????????????????

In [None]:
def icd9_to_chapter(code):
    # Convert to string and clean
    code_str = str(code).strip()
    
    # Handle V codes (supplementary classification)
    if code_str.startswith('V'):
        return 18 #'Supplemental'
    
    # Handle E codes (external causes of injury)
    if code_str.startswith('E'):
        return 19 #'External_Injury'
    
    # Extract first 3 digits for numeric codes
    try:
        # Handle codes like '4280' (convert to 428) or '486' (stays 486)
        numeric_part = code_str.split('.')[0] if '.' in code_str else code_str
        code_num = float(numeric_part[:3])
    except:
        return 0 #'Unknown'
    
    # Map to chapters
    if 1 <= code_num <= 139: return 1 #'Infectious'
    elif 140 <= code_num <= 239: return 2 # 'Neoplasms'
    elif 240 <= code_num <= 279: return 3 #'Endocrine'
    elif 280 <= code_num <= 289: return 4 #'Blood'
    elif 290 <= code_num <= 319: return 5 #'Mental'
    elif 320 <= code_num <= 389: return 6 #'Nervous'
    elif 390 <= code_num <= 459: return 7 #'Circulatory'
    elif 460 <= code_num <= 519: return 8 #'Respiratory'
    elif 520 <= code_num <= 579: return 9 #'Digestive'
    elif 580 <= code_num <= 629: return 10 #'Genitourinary'
    elif 630 <= code_num <= 679: return 11 #'Pregnancy'
    elif 680 <= code_num <= 709: return 12 #'Skin'
    elif 710 <= code_num <= 739: return 13 #'Musculoskeletal'
    elif 740 <= code_num <= 759: return 14 #'Congenital'
    elif 760 <= code_num <= 779: return 15 #'Perinatal'
    elif 780 <= code_num <= 799: return 16 #'Ill-defined'
    elif 800 <= code_num <= 999: return 17 #'Injury'
    else: return 20 #'Other' 

In [None]:
print("\nüè• Creating diagnosis features (optimized pipeline)...")

start_time = time.time()

# 1. First filter to only top 3 diagnoses per admission
window_spec = Window.partitionBy("HADM_ID").orderBy("SEQ_NUM")

top_3_filtered = diagnoses_df \
    .withColumn("row_num", row_number().over(window_spec)) \
    .filter(col("row_num") <= 3) \
    .cache()

# 2. Register UDF with Integer return type
icd9_chapter_udf = udf(icd9_to_chapter, IntegerType())  # Changed to IntegerType

# 3. Encode ONLY the top 3 diagnoses
top_3_encoded = top_3_filtered.withColumn(
    "DISEASE_CHAPTER", 
    icd9_chapter_udf(col("ICD9_CODE"))
)


diagnosis_count = diagnoses_df.groupBy("HADM_ID").count().withColumnRenamed("count", "TOTAL_DIAGNOSES")


diagnoses_df.unpersist()


# 4. Pivot to create columns
diagnosis_features = top_3_encoded \
    .groupBy("HADM_ID") \
    .pivot("row_num", [1, 2, 3]) \
    .agg(first("DISEASE_CHAPTER")) \
    .select(
        "HADM_ID",
        col("1").alias("PRIMARY_DIAGNOSIS").cast(IntegerType()),
        col("2").alias("SECONDARY_DIAGNOSIS").cast(IntegerType()),
        col("3").alias("TERTIARY_DIAGNOSIS").cast(IntegerType())
    ) \
    .join(diagnosis_count, "HADM_ID", "left")

# 5. Fill NULLs and ensure consistent types
diagnosis_features = diagnosis_features.fillna(-1, subset=[
    "PRIMARY_DIAGNOSIS",
    "SECONDARY_DIAGNOSIS",
    "TERTIARY_DIAGNOSIS"
])


print("üìä Optimized diagnosis features:")
diagnosis_features.select(
    "HADM_ID",
    "TOTAL_DIAGNOSES",
    "PRIMARY_DIAGNOSIS",
    "SECONDARY_DIAGNOSIS",
    "TERTIARY_DIAGNOSIS"
).show(20, truncate=False)

print(f"‚è∞ Completed in: {time.time() - start_time:.2f}s")


üè• Creating diagnosis features (optimized pipeline)...


25/06/07 12:41:44 WARN CacheManager: Asked to cache already cached data.


üìä Optimized diagnosis features:




+-------+---------------+-----------------+-------------------+------------------+
|HADM_ID|TOTAL_DIAGNOSES|PRIMARY_DIAGNOSIS|SECONDARY_DIAGNOSIS|TERTIARY_DIAGNOSIS|
+-------+---------------+-----------------+-------------------+------------------+
|152943 |7              |7                |6                  |6                 |
|163177 |7              |9                |8                  |5                 |
|110159 |12             |17               |9                  |8                 |
|109820 |11             |1                |8                  |8                 |
|181763 |12             |7                |17                 |4                 |
|150954 |6              |7                |8                  |18                |
|177309 |16             |1                |12                 |10                |
|110972 |13             |17               |7                  |17                |
|197549 |15             |17               |7                  |17                |
|109

                                                                                

## Joining All Features

In [None]:
print("üìä Joining all features and selecting final features for regression modeling...")

# Define feature columns to exclude
exclude_columns = {"ICUSTAY_ID", "HADM_ID", "SUBJECT_ID", "ICU_INTIME", "ICU_OUTTIME"}

# Join all features and immediately select desired columns
modeling_dataset = base_icu_df \
    .join(patient_category_stats, "ICUSTAY_ID", "left") \
    .join(final_lab_stats, "HADM_ID", "left") \
    .join(diagnosis_features, "HADM_ID", "left") \
    .select(*[name for name in base_icu_df \
        .join(patient_category_stats, "ICUSTAY_ID", "left") \
        .join(final_lab_stats, "HADM_ID", "left") \
        .join(diagnosis_features, "HADM_ID", "left") \
        .columns if name not in exclude_columns])

# Cleanup
base_icu_df.unpersist()
patient_category_stats.unpersist()
final_lab_stats.unpersist()
diagnosis_features.unpersist()

# Display final info
print(f"‚úÖ Final modeling dataset created with {modeling_dataset.count()} records")
print("üìã Sample of final modeling dataset:")
modeling_dataset.show(5, truncate=False)


üìä Joining all features and selecting final features for regression modeling...
‚úÖ Final modeling dataset created with 12 records
üìã Sample of final modeling dataset:


                                                                                

+------------+--------------------+-------------+------------+-------------+----------------------+-----------------+----------------------+----------------+------------------+----------------+------------------------+--------------------------+----------------+------------------+------------------+-----------+-----------+-------------+-----------------+-------------------+-----------------+-------------------+------------------+------------------+------------------+----------------+---------------+---------------+-----------------+-------------------+------------------+---------------+
|ICU_LOS_DAYS|AGE_AT_ICU_ADMISSION|GENDER_BINARY|CAME_FROM_ER|HAS_INSURANCE|ADMISSION_TYPE_ENCODED|ETHNICITY_ENCODED|MARITAL_STATUS_ENCODED|RELIGION_ENCODED|FIRST_UNIT_ENCODED|CHANGED_ICU_UNIT|Routine Vital Signs__sum|Routine Vital Signs__count|Respiratory__sum|Respiratory__count|Labs__sum         |Labs__count|Alarms__sum|Alarms__count|Neurological__sum|Neurological__count|Hemodynamics__sum|Hemodynamics

## Normalization & Handeling Missing Values

Display Missing Values by column.

In [None]:
null_counts = modeling_dataset.select(
    [sum(col(c).isNull().cast("int")).alias(c) for c in modeling_dataset.columns]
).collect()[0]

null_counts_dict = {col: null_counts[col] for col in modeling_dataset.columns}
print(null_counts_dict)

                                                                                

{'ICU_LOS_DAYS': 0, 'AGE_AT_ICU_ADMISSION': 0, 'GENDER_BINARY': 0, 'CAME_FROM_ER': 0, 'HAS_INSURANCE': 0, 'ADMISSION_TYPE_ENCODED': 0, 'ETHNICITY_ENCODED': 0, 'MARITAL_STATUS_ENCODED': 0, 'RELIGION_ENCODED': 0, 'FIRST_UNIT_ENCODED': 0, 'CHANGED_ICU_UNIT': 0, 'Routine Vital Signs__sum': 2, 'Routine Vital Signs__count': 2, 'Respiratory__sum': 2, 'Respiratory__count': 2, 'Labs__sum': 2, 'Labs__count': 2, 'Alarms__sum': 2, 'Alarms__count': 2, 'Neurological__sum': 2, 'Neurological__count': 2, 'Hemodynamics__sum': 2, 'Hemodynamics__count': 2, 'Hematology_sum': 0, 'Chemistry_sum': 0, 'Blood Gas_sum': 0, 'Hematology_count': 0, 'Chemistry_count': 0, 'Blood Gas_count': 0, 'PRIMARY_DIAGNOSIS': 0, 'SECONDARY_DIAGNOSIS': 0, 'TERTIARY_DIAGNOSIS': 0, 'TOTAL_DIAGNOSES': 0}


## Normalization & Missing Values Strategy

**Approach**: Min-Max scaling chosen over standardization because -1 represents missing values. With standardization (Gaussian approximation), -1 could correspond to an actual test result rather than indicating a missing/non-existent test result.

**Implementation**: Applied only to float columns since others are binary or integer (like age). Final results limited to 3 decimal places maximum.

In [None]:
print("üìä Filling NULL entries with 0...")
modeling_dataset = modeling_dataset.na.fill(0)

print("üìä Applying StandardScaling to _sum columns...")
std_columns = [c for c in modeling_dataset.columns if c.endswith('_sum')]

if std_columns:
    # Create vector assembler for the columns to scale
    assembler = VectorAssembler(
        inputCols=std_columns,
        outputCol="features_to_scale",
        handleInvalid="keep"
    )
    
    # Create MinMaxScaler
    scaler = StandardScaler(
        inputCol="features_to_scale",
        outputCol="scaled_features"
    )
    
    # Create pipeline
    pipeline = Pipeline(stages=[assembler, scaler])
    
    # Fit and transform
    scaler_model = pipeline.fit(modeling_dataset)
    scaled_data = scaler_model.transform(modeling_dataset)
    
    # Extract scaled values back to individual columns
    scaled_data = scaled_data.withColumn("scaled_array", vector_to_array("scaled_features"))
    
    # Replace original columns with scaled values
    for i, col_name in enumerate(std_columns):
        scaled_data = scaled_data.withColumn(
            col_name,
            scaled_data["scaled_array"][i]
        )
    
    # Drop temporary columns
    modeling_dataset = scaled_data.drop("features_to_scale", "scaled_features", "scaled_array")
    
    print(f"‚úÖ Scaled {len(std_columns)} _sum columns")
else:
    print("‚ö†Ô∏è No _sum columns found to scale")
    

print("‚úÖ Data set ready for Machine Learning!")
modeling_dataset.show(5, truncate=False)

num_rows = modeling_dataset.count()
num_cols = len(modeling_dataset.columns)
print(f"Final DataSet shape: ({num_rows}, {num_cols})")

üìä Filling NULL entries with 0...
üìä Applying StandardScaling to _sum columns...
‚úÖ Scaled 9 _sum columns
‚úÖ Data set ready for Machine Learning!


                                                                                

+------------+--------------------+-------------+------------+-------------+----------------------+-----------------+----------------------+----------------+------------------+----------------+------------------------+--------------------------+------------------+------------------+------------------+-----------+------------------+-------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+-------------------+----------------+---------------+---------------+-----------------+-------------------+------------------+---------------+
|ICU_LOS_DAYS|AGE_AT_ICU_ADMISSION|GENDER_BINARY|CAME_FROM_ER|HAS_INSURANCE|ADMISSION_TYPE_ENCODED|ETHNICITY_ENCODED|MARITAL_STATUS_ENCODED|RELIGION_ENCODED|FIRST_UNIT_ENCODED|CHANGED_ICU_UNIT|Routine Vital Signs__sum|Routine Vital Signs__count|Respiratory__sum  |Respiratory__count|Labs__sum         |Labs__count|Alarms__sum       |Alarms__count|Neurological__sum  |Neurological__count|Hemod

In [None]:
print("üìä Applying MinMaxScaling to _count columns...")
minmax_columns = [c for c in modeling_dataset.columns if c.endswith('_count')]

if minmax_columns:
    # Create vector assembler for the columns to scale
    assembler = VectorAssembler(
        inputCols=minmax_columns,
        outputCol="features_to_scale",
        handleInvalid="keep"
    )
    
    # Create MinMaxScaler
    scaler = MinMaxScaler(
        inputCol="features_to_scale",
        outputCol="scaled_features",
        min=0,  # Default minimum value after scaling
        max=10   # Default maximum value after scaling
    )
    
    # Create pipeline
    pipeline = Pipeline(stages=[assembler, scaler])
    
    # Fit and transform
    scaler_model = pipeline.fit(modeling_dataset)
    scaled_data = scaler_model.transform(modeling_dataset)
    
    # Extract scaled values back to individual columns
    scaled_data = scaled_data.withColumn("scaled_array", vector_to_array("scaled_features"))
    
    # Replace original columns with scaled values
    for i, col_name in enumerate(minmax_columns):
        scaled_data = scaled_data.withColumn(
            col_name,
            scaled_data["scaled_array"][i]
        )
    
    # Drop temporary columns
    modeling_dataset = scaled_data.drop("features_to_scale", "scaled_features", "scaled_array")
    
    print(f"‚úÖ Scaled {len(minmax_columns)} _count columns using MinMax scaling")
else:
    print("‚ö†Ô∏è No _count columns found to scale")

print("‚úÖ Data set ready for Machine Learning!")
modeling_dataset.show(5, truncate=False)

num_rows = modeling_dataset.count()
num_cols = len(modeling_dataset.columns)
print(f"Final DataSet shape: ({num_rows}, {num_cols})")

üìä Applying MinMaxScaling to _count columns...
‚úÖ Scaled 9 _count columns using MinMax scaling
‚úÖ Data set ready for Machine Learning!


                                                                                

+------------+--------------------+-------------+------------+-------------+----------------------+-----------------+----------------------+----------------+------------------+----------------+------------------------+--------------------------+-------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+-------------------+------------------+-------------------+------------------+-------------------+--------------------+------------------+------------------+------------------+-----------------+-------------------+------------------+---------------+
|ICU_LOS_DAYS|AGE_AT_ICU_ADMISSION|GENDER_BINARY|CAME_FROM_ER|HAS_INSURANCE|ADMISSION_TYPE_ENCODED|ETHNICITY_ENCODED|MARITAL_STATUS_ENCODED|RELIGION_ENCODED|FIRST_UNIT_ENCODED|CHANGED_ICU_UNIT|Routine Vital Signs__sum|Routine Vital Signs__count|Respiratory__sum   |Respiratory__count|Labs__sum         |Labs__count       |Alarms__sum        |Alarms__count     |Neurolog

# Machine Learning

## Preparing for Machine Learning

In [None]:
print("üìä Step 1: Creating train/test split...")
train_data, test_data = modeling_dataset.randomSplit([0.9, 0.1], seed=42)

print("‚úÖ Data split completed.")
print(f"   üöÜ Training samples: {train_data.count()}")
print(f"   üß™ Test samples: {test_data.count()}")


feature_columns = [col for col in modeling_dataset.columns if col != 'ICU_LOS_DAYS']
print("Feature columns:", feature_columns)
target_column = 'ICU_LOS_DAYS'
print("Target column:", target_column)

feature_assembler = VectorAssembler(
    inputCols=feature_columns,  
    outputCol="features"     
)

print("üìä Step 2: Creating the final vectorized train/test datasets...")
train_final = feature_assembler.transform(train_data).select(
    "features", 
    target_column
).withColumnRenamed(target_column, "label")

test_final = feature_assembler.transform(test_data).select(
    "features", 
    target_column
).withColumnRenamed(target_column, "label")

train_final.cache()
test_final.cache()

print("‚úÖ Final datasets prepared:")
print(f"   üöÜ Training features shape: ({train_final.count()}, {len(feature_columns)})")
print(f"   üß™ Test features shape: ({test_final.count()}, {len(feature_columns)})")

üìä Step 1: Creating train/test split...
‚úÖ Data split completed.


                                                                                

   üöÜ Training samples: 10


                                                                                

   üß™ Test samples: 2
Feature columns: ['AGE_AT_ICU_ADMISSION', 'GENDER_BINARY', 'CAME_FROM_ER', 'HAS_INSURANCE', 'ADMISSION_TYPE_ENCODED', 'ETHNICITY_ENCODED', 'MARITAL_STATUS_ENCODED', 'RELIGION_ENCODED', 'FIRST_UNIT_ENCODED', 'CHANGED_ICU_UNIT', 'Routine Vital Signs__sum', 'Routine Vital Signs__count', 'Respiratory__sum', 'Respiratory__count', 'Labs__sum', 'Labs__count', 'Alarms__sum', 'Alarms__count', 'Neurological__sum', 'Neurological__count', 'Hemodynamics__sum', 'Hemodynamics__count', 'Hematology_sum', 'Chemistry_sum', 'Blood Gas_sum', 'Hematology_count', 'Chemistry_count', 'Blood Gas_count', 'PRIMARY_DIAGNOSIS', 'SECONDARY_DIAGNOSIS', 'TERTIARY_DIAGNOSIS', 'TOTAL_DIAGNOSES']
Target column: ICU_LOS_DAYS
üìä Step 2: Creating the final vectorized train/test datasets...
‚úÖ Final datasets prepared:


                                                                                

   üöÜ Training features shape: (11, 32)


                                                                                

   üß™ Test features shape: (1, 32)


## Training Multiple Models

In [None]:
print("üìä Step 1: Setting up evaluation metrics...")

# Create regression evaluators
rmse_evaluator = RegressionEvaluator(
    labelCol="label", 
    predictionCol="prediction", 
    metricName="rmse"
)

mae_evaluator = RegressionEvaluator(
    labelCol="label",
    predictionCol="prediction", 
    metricName="mae"
)

r2_evaluator = RegressionEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="r2"
)

print("‚úÖ Evaluation metrics configured: RMSE, MAE, R¬≤")

üìä Step 1: Setting up evaluation metrics...
‚úÖ Evaluation metrics configured: RMSE, MAE, R¬≤


### Linear Regression

In [None]:
print("\nüìà Step 2: Training Linear Regression model...")
print(f"üïê Started at: {datetime.now().strftime('%H:%M:%S')}")
start_time = time.time()

# Create Linear Regression model
lr = LinearRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=200,                    # Increased for better convergence
    regParam=0.001,                 # Lower regularization for healthcare data
    elasticNetParam=0.1,            # Slight L1 penalty for feature selection
    tol=1e-8,                       # Tighter tolerance for precision
    standardization=False,          # We're doing manual scaling
    fitIntercept=True,
    aggregationDepth=3,             # Better for distributed training
    loss="squaredError",
    solver="normal"                 # Best for small-medium datasets
)


# Train the model
print("   üîÑ Training Linear Regression...")
lr_model = lr.fit(train_final)

print("   üîÑ Linear Regression - Making predictions (test data)...")
lr_predictions = lr_model.transform(test_final)

print("   üîÑ Linear Regression - Evaluation...")
lr_rmse = rmse_evaluator.evaluate(lr_predictions)
lr_mae = mae_evaluator.evaluate(lr_predictions)
lr_r2 = r2_evaluator.evaluate(lr_predictions)

print(f"‚úÖ Linear Regression Results:")
print(f"   üìâ RMSE: {lr_rmse:.3f} days")
print(f"   üìä MAE: {lr_mae:.3f} days")
print(f"   üìà R¬≤: {lr_r2:.3f}")

end_time = time.time()
elapsed_time = end_time - start_time
print(f"üïê Completed at: {datetime.now().strftime('%H:%M:%S')}")
print(f"‚è±Ô∏è Total elapsed time: {elapsed_time:.2f} seconds")


üìà Step 2: Training Linear Regression model...
üïê Started at: 12:42:02
   üîÑ Training Linear Regression...
   üîÑ Linear Regression - Making predictions (test data)...
   üîÑ Linear Regression - Evaluation...
‚úÖ Linear Regression Results:
   üìâ RMSE: 1.166 days
   üìä MAE: 1.166 days
   üìà R¬≤: -inf
üïê Completed at: 12:42:03
‚è±Ô∏è Total elapsed time: 1.42 seconds


### Random Forest

In [None]:

print("\nüå≤ Step 3: Training Random Forest model...")
print(f"üïê Started at: {datetime.now().strftime('%H:%M:%S')}")
start_time = time.time()

# Create Random Forest model
rf = RandomForestRegressor(
    featuresCol="features",
    labelCol="label",
    numTrees=200,                   # More trees = better accuracy (if enough cores/memory)
    maxDepth=12,                    # Deeper trees capture more complexity
    minInstancesPerNode=2,          # Allows more granular splits
    subsamplingRate=0.9,            # Slightly higher sample rate for stability
    featureSubsetStrategy="sqrt",   # Good default for regression
    seed=42                         # Reproducibility
)

print("   üîÑ Training Random Forest...")
rf_model = rf.fit(train_final)

print("   üîÑ Random Forest - Making predictions (test data)...")
rf_predictions = rf_model.transform(test_final)

print("   üîÑ Random Forest - Evaluation...")
rf_rmse = rmse_evaluator.evaluate(rf_predictions)
rf_mae = mae_evaluator.evaluate(rf_predictions)
rf_r2 = r2_evaluator.evaluate(rf_predictions)

print(f"‚úÖ Random Forest Results:")
print(f"   üìâ RMSE: {rf_rmse:.3f} days")
print(f"   üìä MAE: {rf_mae:.3f} days")
print(f"   üìà R¬≤: {rf_r2:.3f}")


end_time = time.time()
elapsed_time = end_time - start_time
print(f"üïê Completed at: {datetime.now().strftime('%H:%M:%S')}")
print(f"‚è±Ô∏è Total elapsed time: {elapsed_time:.2f} seconds")



üå≤ Step 3: Training Random Forest model...
üïê Started at: 12:42:03
   üîÑ Training Random Forest...


25/06/07 12:42:04 WARN DecisionTreeMetadata: DecisionTree reducing maxBins from 32 to 11 (= number of training instances)


   üîÑ Random Forest - Making predictions (test data)...
   üîÑ Random Forest - Evaluation...
‚úÖ Random Forest Results:
   üìâ RMSE: 1.511 days
   üìä MAE: 1.511 days
   üìà R¬≤: -inf
üïê Completed at: 12:42:07
‚è±Ô∏è Total elapsed time: 3.46 seconds


## Model Predictions

In [None]:
evaluator_r2 = RegressionEvaluator(metricName="r2")

print("\nüìà Linear Regression Predictions:")
lr_display = lr_predictions.select(
    col("label").alias("Actual_LOS"),
    round(col("prediction"), 3).alias("Predicted_LOS"),
    round(abs(col("label") - col("prediction")), 3).alias("Absolute_Error"),
    round(((abs(col("label") - col("prediction")) / col("label")) * 100), 2).alias("Percent_Error")
)

lr_display.show(truncate=False)



# Random Forest Predictions
print("\nüå≤ Random Forest Predictions:")
rf_display = rf_predictions.select(
    col("label").alias("Actual_LOS"),
    round(col("prediction"), 3).alias("Predicted_LOS"),
    round(abs(col("label") - col("prediction")), 3).alias("Absolute_Error"),
    round(((abs(col("label") - col("prediction")) / col("label")) * 100), 2).alias("Percent_Error")
)

rf_display.show(truncate=False)


üìà Linear Regression Predictions:
+----------+-------------+--------------+-------------+
|Actual_LOS|Predicted_LOS|Absolute_Error|Percent_Error|
+----------+-------------+--------------+-------------+
|1.2597    |2.426        |1.166         |92.57        |
+----------+-------------+--------------+-------------+


üå≤ Random Forest Predictions:
+----------+-------------+--------------+-------------+
|Actual_LOS|Predicted_LOS|Absolute_Error|Percent_Error|
+----------+-------------+--------------+-------------+
|1.2597    |2.771        |1.511         |119.98       |
+----------+-------------+--------------+-------------+



## Model Comparison

In [None]:
print("\nüèÜ Step 5: Model Performance Comparison...")

# Create comparison summary
results_data = [
    ("Linear Regression", lr_rmse, lr_mae, lr_r2),
    ("Random Forest", rf_rmse, rf_mae, rf_r2)
]

results_df = spark.createDataFrame(results_data, ["Model", "RMSE", "MAE", "R2"])

print("üìä Model Performance Summary:")
results_df.show(truncate=False)


best_rmse_model = builtins.min(results_data, key=operator.itemgetter(1))
best_r2_model = builtins.max(results_data, key=operator.itemgetter(3))

print(f"\nü•á Best Models:")
print(f"   üéØ Lowest RMSE: {best_rmse_model[0]} ({best_rmse_model[1]:.3f} days)")
print(f"   üìà Highest R¬≤: {best_r2_model[0]} ({best_r2_model[3]:.3f})")


üèÜ Step 5: Model Performance Comparison...
üìä Model Performance Summary:
+-----------------+------------------+------------------+---------+
|Model            |RMSE              |MAE               |R2       |
+-----------------+------------------+------------------+---------+
|Linear Regression|1.1661060025268237|1.1661060025268237|-Infinity|
|Random Forest    |1.5113595333333345|1.5113595333333345|-Infinity|
+-----------------+------------------+------------------+---------+


ü•á Best Models:
   üéØ Lowest RMSE: Linear Regression (1.166 days)
   üìà Highest R¬≤: Linear Regression (-inf)
