# 1.sedatives_id_filtering

In [3]:
# d_items.csv.gz 로딩 (처음 한 번만)
import pandas as pd

d_items = pd.read_csv("/Users/skku_aws165/Documents/MIMIC/icu/d_items.csv.gz", compression="gzip", low_memory=False)

# 진정제 키워드 리스트
sedation_keywords = [
    "midazolam", "propofol", "dexmedetomidine", "lorazepam", 
    "fentanyl", "etomidate", "diazepam", "ketamine", "morphine"
]

# label 또는 drugname에서 진정제 필터링 (대소문자 무시)
sedatives = d_items[
    d_items['label'].str.lower().str.contains('|'.join(sedation_keywords), na=False)
]

# 결과 확인
sedatives[['itemid', 'label', 'linksto']].sort_values(by='label')


Unnamed: 0,itemid,label,linksto
1106,225150,Dexmedetomidine (Precedex),inputevents
3599,229420,Dexmedetomidine (Precedex),inputevents
291,221623,Diazepam (Valium),inputevents
2147,227212,Etomidate (Intubation),chartevents
298,221744,Fentanyl,inputevents
1646,225942,Fentanyl (Concentrate),inputevents
1668,225972,Fentanyl (Push),inputevents
296,221712,Ketamine,inputevents
2146,227211,Ketamine (Intubation),chartevents
285,221385,Lorazepam (Ativan),inputevents


In [2]:
# d_items.csv.gz 로딩 (처음 한 번만)
import pandas as pd

d_items = pd.read_csv("/Users/skku_aws165/Documents/MIMIC/icu/d_items.csv.gz", compression="gzip", low_memory=False)

# 진정제 키워드 리스트
sedation_keywords = [
    "midazolam", "propofol", "dexmedetomidine",
    "fentanyl", "remifentanil", "etomidate"
]

# label 또는 drugname에서 진정제 필터링 (대소문자 무시)
sedatives = d_items[
    d_items['label'].str.lower().str.contains('|'.join(sedation_keywords), na=False)
]

# 결과 확인
sedatives[['itemid', 'label', 'linksto']].sort_values(by='label')


Unnamed: 0,itemid,label,linksto
1106,225150,Dexmedetomidine (Precedex),inputevents
3599,229420,Dexmedetomidine (Precedex),inputevents
2147,227212,Etomidate (Intubation),chartevents
298,221744,Fentanyl,inputevents
1646,225942,Fentanyl (Concentrate),inputevents
1668,225972,Fentanyl (Push),inputevents
294,221668,Midazolam (Versed),inputevents
316,222168,Propofol,inputevents
2145,227210,Propofol (Intubation),chartevents
1802,226224,Propofol Ingredient,ingredientevents


In [8]:
# d_items.csv.gz 에서 Remifentanil 관련 label 검색 -> 라벨 없음을 확인함
import pandas as pd

df = pd.read_csv("/Users/skku_aws165/Documents/MIMIC/icu/d_items.csv.gz", compression="gzip", low_memory=False)
df[df["label"].str.lower().str.contains("remifentanil|ultiva")]


Unnamed: 0,itemid,label,abbreviation,linksto,category,unitname,param_type,lownormalvalue,highnormalvalue


# PySpark 기반 진정제 투약 시점 추출 + GCS와 매칭 (스타터 코드)

### ① 새 코호트 기준 필터링 추가

In [14]:
import glob

paths = glob.glob("/Users/skku_aws165/Documents/MIMIC/MIMIC-IV-Project/notebooks/final/new_cohort*")
print(paths)


['/Users/skku_aws165/Documents/MIMIC/MIMIC-IV-Project/notebooks/final/new_cohort.csv']


In [1]:
from pyspark.sql import SparkSession

# 이미 spark 세션이 있으면 재사용, 없으면 새로 생성
spark = SparkSession.builder.getOrCreate()

# 🔁 코호트 먼저 불러와서 stay_id 필터링 기준 만들기
cohort = spark.read.csv("/Users/skku_aws165/Documents/MIMIC/MIMIC-IV-Project/notebooks/final/new_cohort.csv", header=True, inferSchema=True)
cohort_ids = cohort.select("stay_id").distinct()


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/06 16:33:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/08/06 16:33:37 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [7]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_timestamp, expr, when, lit, max as spark_max, udf
from pyspark.sql.types import TimestampType, StructType, StructField, IntegerType
import datetime

# 1. Spark 세션 시작
spark = SparkSession.builder.appName("GCS Sedation Flag").getOrCreate()

# 2. 코호트 로딩 및 stay_id 추출
cohort = spark.read.csv(
    "/Users/skku_aws165/Documents/MIMIC/MIMIC-IV-Project/notebooks/final/new_cohort.csv",
    header=True, inferSchema=True
)
cohort_ids = cohort.select("stay_id").distinct()

# 3. 진정제 itemid 및 약효 지속 시간 정의 (시간 단위)
sedative_itemids = [
    (221319, 4),  # Midazolam (Versed)
    (221668, 4),  # Midazolam
    (222168, 1),  # Propofol
    (221744, 2),  # Fentanyl
    (225942, 2),  # Fentanyl (Concentrate)
    (225972, 2),  # Fentanyl (Push)
    (221320, 4),  # Dexmedetomidine
    (229420, 4),  # Dexmedetomidine (Precedex)
    (225150, 4),  # Dexmedetomidine (Precedex)
    (221195, 4),  # Lorazepam
    (227212, 1)   # Etomidate
]
sedative_ids_only = [x[0] for x in sedative_itemids]

# 4. itemid + 지속시간을 PySpark DataFrame으로 변환
schema = StructType([
    StructField("itemid", IntegerType(), False),
    StructField("window_hours", IntegerType(), False)
])
sedative_windows = spark.createDataFrame(sedative_itemids, schema)

# 5. 진정제 투약 이벤트 로딩 및 필터링
inputevents = spark.read.csv(
    "/Users/skku_aws165/Documents/MIMIC/icu/inputevents.csv.gz",
    header=True, inferSchema=True
)
sedation_events = inputevents.filter(col("itemid").isin(sedative_ids_only)) \
    .join(cohort_ids, on="stay_id", how="inner") \
    .withColumn("starttime", to_timestamp("starttime")) \
    .join(sedative_windows, on="itemid", how="left")

# 6. UDF로 진정제 투약 종료 시점(endtime) 계산
def add_hours(starttime, hours):
    if starttime is None or hours is None:
        return None
    return starttime + datetime.timedelta(hours=int(hours))
add_hours_udf = udf(add_hours, TimestampType())

sedation_events = sedation_events.withColumn(
    "endtime", add_hours_udf(col("starttime"), col("window_hours"))
).select("stay_id", "starttime", "endtime")

# 7. GCS 관측치 로딩 및 필터링
chartevents = spark.read.csv(
    "/Users/skku_aws165/Documents/MIMIC/icu/chartevents.csv.gz",
    header=True, inferSchema=True
)
gcs_itemids = [220739, 223900, 223901]  # GCS Eye, Verbal, Motor
gcs = chartevents.filter(col("itemid").isin(gcs_itemids)) \
    .join(cohort_ids, on="stay_id", how="inner") \
    .withColumn("charttime", to_timestamp("charttime")) \
    .filter(col("valuenum").isNotNull()) \
    .select("stay_id", "charttime", "itemid", "valuenum")

# 8. 캐시로 성능 최적화
gcs.cache()
sedation_events.cache()

# 9. 조인 후 진정제 투약 여부 (sedated_flag) 부여
sedated_flagged = gcs.join(
    sedation_events,
    on="stay_id",
    how="left"
).withColumn(
    "time_diff",
    expr("abs(unix_timestamp(charttime) - unix_timestamp(starttime))")
).withColumn(
    "temp_flag",
    when((col("charttime") >= col("starttime")) & (col("charttime") <= col("endtime")), 1).otherwise(0)
).groupBy("stay_id", "charttime", "itemid", "valuenum").agg(
    expr("min(time_diff) as min_time_diff"),
    spark_max("temp_flag").alias("sedated_flag")  # 다중 투약 고려
)

# 10. 컬럼명 정리
sedated_flagged = sedated_flagged.withColumnRenamed("itemid", "gcs_itemid")

# 11. 결과 저장 (Parquet + 파티셔닝)
sedated_flagged.select(
    "stay_id", "charttime", "gcs_itemid", "valuenum", "sedated_flag"
).write.mode("overwrite") \
  .partitionBy("stay_id") \
  .parquet("outputs/gcs_with_sedation_flag")

# 12. 간단한 검증 출력
sedated_flagged.groupBy("sedated_flag").count().show()
sedated_flagged.groupBy("gcs_itemid").agg(
    expr("count(*) as total"),
    expr("count(valuenum) as non_null")
).show()

# 13. Spark 종료
spark.stop()


25/08/06 17:17:53 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:17:53 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 17:17:53 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
25/08/06 17:17:54 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 17:17:54 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:17:54 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                     

+------------+-------+
|sedated_flag|  count|
+------------+-------+
|           1| 619690|
|           0|3161090|
+------------+-------+



                                                                                

+----------+-------+--------+
|gcs_itemid|  total|non_null|
+----------+-------+--------+
|    223900|1260658| 1260658|
|    220739|1262538| 1262538|
|    223901|1257584| 1257584|
+----------+-------+--------+



In [12]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, mean, max as spark_max, min as spark_min, stddev, expr
from pyspark.sql.window import Window
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Spark 세션 시작
spark = SparkSession.builder.appName("GCS Preprocessing").getOrCreate()

# 2. 데이터 로딩
gcs_data = spark.read.parquet("outputs/gcs_with_sedation_flag")

# 3. sedated_flag를 charttime 단위로 집계
sed_flag_summary = gcs_data.groupBy("stay_id", "charttime").agg(
    spark_max("sedated_flag").alias("sedated_flag")
)

# 4. GCS 피벗
gcs_pivoted = gcs_data.groupBy("stay_id", "charttime").pivot("gcs_itemid").agg(spark_max("valuenum")).fillna(0)

# 5. GCS 총점 계산 및 sedated_flag 추가
gcs_pivoted = gcs_pivoted.withColumn(
    "gcs_total",
    col("220739") + col("223900") + col("223901")
).join(sed_flag_summary, on=["stay_id", "charttime"], how="left")

# 6. 원본 분포 시각화 (Pandas로)
gcs_total_pd = gcs_pivoted.select("gcs_total", "sedated_flag").toPandas()
plt.figure(figsize=(10, 6))
sns.histplot(data=gcs_total_pd, x="gcs_total", hue="sedated_flag", bins=13, multiple="stack")
plt.title("GCS Total Distribution by Sedation Flag")
plt.xlabel("GCS Total")
plt.ylabel("Count")
plt.savefig("outputs/gcs_total_distribution.png")
plt.close()

# 7. 실험 1: 그대로 사용
gcs_summary_exp1 = gcs_pivoted.groupBy("stay_id").agg(
    mean("gcs_total").alias("gcs_mean"),
    spark_max("gcs_total").alias("gcs_max"),
    spark_min("gcs_total").alias("gcs_min"),
    stddev("gcs_total").alias("gcs_std"),
    mean("220739").alias("eye_mean"),
    mean("223900").alias("verbal_mean"),
    mean("223901").alias("motor_mean")
)
gcs_summary_exp1.write.mode("overwrite").parquet("outputs/exp1_gcs_summary")

# 8. 실험 2: GCS + sedation 비율
gcs_summary_exp2 = gcs_pivoted.groupBy("stay_id").agg(
    mean("gcs_total").alias("gcs_mean"),
    spark_max("gcs_total").alias("gcs_max"),
    spark_min("gcs_total").alias("gcs_min"),
    stddev("gcs_total").alias("gcs_std"),
    mean("sedated_flag").alias("sedated_ratio"),
    mean("220739").alias("eye_mean"),
    mean("223900").alias("verbal_mean"),
    mean("223901").alias("motor_mean")
)
gcs_summary_exp2.write.mode("overwrite").parquet("outputs/exp2_gcs_flagged_summary")

# 9. 실험 3: sedation 시 verbal 제거 (마스킹)
gcs_masked = gcs_data.withColumn(
    "valuenum_masked",
    when((col("sedated_flag") == 1) & (col("gcs_itemid") == 223900), None).otherwise(col("valuenum"))
)

# 10. 피벗
gcs_masked_pivot = gcs_masked.groupBy("stay_id", "charttime").pivot("gcs_itemid").agg(spark_max("valuenum_masked")).fillna(0)

# 11. 마스킹 총점 계산
gcs_masked_pivot = gcs_masked_pivot.join(sed_flag_summary, on=["stay_id", "charttime"], how="left") \
    .withColumn(
        "gcs_partial",
        when(col("sedated_flag") == 1, col("223901"))  # sedation 시 motor만
        .otherwise(col("220739") + col("223900") + col("223901"))  # 아니면 전체
    )

# 12. 실험 3 요약 저장
gcs_summary_exp3 = gcs_masked_pivot.groupBy("stay_id").agg(
    mean("gcs_partial").alias("gcs_partial_mean"),
    spark_max("223901").alias("motor_max"),
    mean("223901").alias("motor_mean"),
    stddev("223901").alias("motor_std"),
    mean("220739").alias("eye_mean"),
    mean("223900").alias("verbal_mean")
)
gcs_summary_exp3.write.mode("overwrite").parquet("outputs/exp3_gcs_masked")

# 13. 결측 확인
gcs_summary_exp3.selectExpr(
    "count(*) as total",
    "count(gcs_partial_mean) as non_null_gcs_partial",
    "count(verbal_mean) as non_null_verbal"
).show()

# 14. Spark 종료
spark.stop()


25/08/06 17:39:08 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:39:08 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 17:39:08 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:40:22 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:40:22 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 17:40:22 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 17:40:22 WARN MemoryManager: Total allocation exceeds 95.00% 



25/08/06 17:42:57 ERROR Executor: Exception in task 4.0 in stage 89.0 (TID 38018)
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:64)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:363)
	at org.apache.spark.io.ReadAheadInputStream.<init>(ReadAheadInputStream.java:111)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:78)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.advanceToNextPage(BytesToBytesMap.java:287)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.next(BytesToBytesMap.java:315)
	at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap$1.next(UnsafeFixedWidthAggregationMap.java:175)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Sour

Py4JJavaError: An error occurred while calling o877.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 89.0 failed 1 times, most recent failure: Lost task 4.0 in stage 89.0 (TID 38018) (172.16.0.199 executor driver): java.lang.OutOfMemoryError: Java heap space
	at java.base/java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:64)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:363)
	at org.apache.spark.io.ReadAheadInputStream.<init>(ReadAheadInputStream.java:111)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:78)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.advanceToNextPage(BytesToBytesMap.java:287)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.next(BytesToBytesMap.java:315)
	at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap$1.next(UnsafeFixedWidthAggregationMap.java:175)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.processInputs(TungstenAggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.<init>(TungstenAggregationIterator.scala:369)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$1(HashAggregateExec.scala:126)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$1$adapted(HashAggregateExec.scala:100)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$Lambda$5933/0x000000f8024ae270.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2(RDD.scala:918)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2$adapted(RDD.scala:918)
	at org.apache.spark.rdd.RDD$$Lambda$3062/0x000000f801edd260.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:107)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.executor.Executor$TaskRunner$$Lambda$3017/0x000000f801ec0ba8.apply(Unknown Source)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$3(DAGScheduler.scala:2935)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2935)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2927)
	at scala.collection.immutable.List.foreach(List.scala:334)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2927)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1295)
	at scala.Option.foreach(Option.scala:437)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3207)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3141)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3130)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:50)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at java.base/java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:64)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:363)
	at org.apache.spark.io.ReadAheadInputStream.<init>(ReadAheadInputStream.java:111)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:78)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.advanceToNextPage(BytesToBytesMap.java:287)
	at org.apache.spark.unsafe.map.BytesToBytesMap$MapIterator.next(BytesToBytesMap.java:315)
	at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap$1.next(UnsafeFixedWidthAggregationMap.java:175)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.processInputs(TungstenAggregationIterator.scala:196)
	at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.<init>(TungstenAggregationIterator.scala:369)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$1(HashAggregateExec.scala:126)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$1$adapted(HashAggregateExec.scala:100)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$Lambda$5933/0x000000f8024ae270.apply(Unknown Source)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2(RDD.scala:918)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2$adapted(RDD.scala:918)
	at org.apache.spark.rdd.RDD$$Lambda$3062/0x000000f801edd260.apply(Unknown Source)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:107)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.executor.Executor$TaskRunner$$Lambda$3017/0x000000f801ec0ba8.apply(Unknown Source)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)


In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, mean, max as spark_max, min as spark_min, stddev, expr, isnan, isnull
from pyspark.sql.types import IntegerType, DoubleType
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Spark 세션 시작 (메모리 최적화)
spark = SparkSession.builder \
    .appName("GCS Preprocessing") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

# 2. 데이터 로딩 및 기본 검증
gcs_data = spark.read.parquet("outputs/gcs_with_sedation_flag")
print(f"Total records: {gcs_data.count()}")
print("Schema:")
gcs_data.printSchema()

# 3. 데이터 타입 명시적 변환 (타입 안정성 확보)
gcs_data = gcs_data.withColumn("valuenum", col("valuenum").cast(DoubleType())) \
    .withColumn("sedated_flag", col("sedated_flag").cast(IntegerType()))

# 4. sedated_flag를 charttime 단위로 집계 (캐싱 추가)
sed_flag_summary = gcs_data.groupBy("stay_id", "charttime").agg(
    spark_max("sedated_flag").alias("sedated_flag")
).cache()

# 5. GCS 피벗 (안전한 타입 처리)
gcs_pivoted = gcs_data.groupBy("stay_id", "charttime").pivot("gcs_itemid") \
    .agg(spark_max("valuenum")) \
    .fillna(0.0)  # fillna에 명시적 타입 지정

# 컬럼명 확인 및 존재 여부 체크
pivot_columns = gcs_pivoted.columns
expected_cols = ["220739", "223900", "223901"]
missing_cols = [col for col in expected_cols if col not in pivot_columns]
if missing_cols:
    print(f"Warning: Missing columns {missing_cols}")
    for col_name in missing_cols:
        gcs_pivoted = gcs_pivoted.withColumn(col_name, lit(0.0))

# 6. GCS 유효성 검증 및 이상치 마스킹 처리 (Validation & Masking)
gcs_pivoted = gcs_pivoted.withColumn(
    "gcs_total",
    col("220739") + col("223900") + col("223901")
).join(sed_flag_summary, on=["stay_id", "charttime"], how="left") \
.fillna({"sedated_flag": 0})  # null sedated_flag 처리
from pyspark.sql.functions import when

# GCS 구성 요소별 정상 범위:
# - Eye:    1~4
# - Verbal: 1~5
# - Motor:  1~6
# → 유효 범위를 벗어난 값은 None (null)으로 마스킹 처리
gcs_pivoted = gcs_pivoted.withColumn(
    "220739", when((col("220739") >= 1) & (col("220739") <= 4), col("220739")).otherwise(None)  # Eye
).withColumn(
    "223900", when((col("223900") >= 1) & (col("223900") <= 5), col("223900")).otherwise(None)  # Verbal
).withColumn(
    "223901", when((col("223901") >= 1) & (col("223901") <= 6), col("223901")).otherwise(None)  # Motor
)

# 마스킹 처리된 값들을 이용해 GCS 총점 재계산
gcs_pivoted = gcs_pivoted.withColumn(
    "gcs_total",
    col("220739") + col("223900") + col("223901")
)



# 캐싱으로 성능 최적화
gcs_pivoted.cache()

# 7. 시각화 (샘플링으로 메모리 절약)
try:
    gcs_sample = gcs_pivoted.sample(0.1).select("gcs_total", "sedated_flag").toPandas()
    plt.figure(figsize=(10, 6))
    sns.histplot(data=gcs_sample, x="gcs_total", hue="sedated_flag", bins=13, multiple="stack")
    plt.title("GCS Total Distribution by Sedation Flag (Sample)")
    plt.xlabel("GCS Total")
    plt.ylabel("Count")
    plt.savefig("outputs/gcs_total_distribution.png")
    plt.close()
    print("Visualization saved successfully")
except Exception as e:
    print(f"Visualization failed: {e}")

# 8. 실험 1: 그대로 사용
print("Running Experiment 1...")
gcs_summary_exp1 = gcs_pivoted.groupBy("stay_id").agg(
    mean("gcs_total").alias("gcs_mean"),
    spark_max("gcs_total").alias("gcs_max"),
    spark_min("gcs_total").alias("gcs_min"),
    stddev("gcs_total").alias("gcs_std"),
    mean("220739").alias("eye_mean"),
    mean("223900").alias("verbal_mean"),
    mean("223901").alias("motor_mean")
)
gcs_summary_exp1.write.mode("overwrite").parquet("outputs/exp1_gcs_summary")
print(f"Exp1 completed: {gcs_summary_exp1.count()} patients")

# 9. 실험 2: GCS + sedation 비율
print("Running Experiment 2...")
gcs_summary_exp2 = gcs_pivoted.groupBy("stay_id").agg(
    mean("gcs_total").alias("gcs_mean"),
    spark_max("gcs_total").alias("gcs_max"),
    spark_min("gcs_total").alias("gcs_min"),
    stddev("gcs_total").alias("gcs_std"),
    mean("sedated_flag").alias("sedated_ratio"),
    mean("220739").alias("eye_mean"),
    mean("223900").alias("verbal_mean"),
    mean("223901").alias("motor_mean")
)
gcs_summary_exp2.write.mode("overwrite").parquet("outputs/exp2_gcs_flagged_summary")
print(f"Exp2 completed: {gcs_summary_exp2.count()} patients")

# 10. 실험 3: sedation 시 verbal 마스킹 (안전한 처리)
print("Running Experiment 3...")
gcs_masked = gcs_data.withColumn(
    "valuenum_masked",
    when((col("sedated_flag") == 1) & (col("gcs_itemid") == 223900), None)
    .otherwise(col("valuenum").cast(DoubleType()))
)

# 11. 마스킹된 데이터 피벗 (타입 안정성)
gcs_masked_pivot = gcs_masked.groupBy("stay_id", "charttime").pivot("gcs_itemid") \
    .agg(spark_max("valuenum_masked")) \
    .fillna(0.0)

# 컬럼 존재 확인
if all(col in gcs_masked_pivot.columns for col in expected_cols):
    # 12. 마스킹 총점 계산
    gcs_masked_pivot = gcs_masked_pivot.join(sed_flag_summary, on=["stay_id", "charttime"], how="left") \
        .fillna({"sedated_flag": 0}) \
        .withColumn(
            "gcs_partial",
            when(col("sedated_flag") == 1, col("223901"))  # sedation 시 motor만
            .otherwise(col("220739") + col("223900") + col("223901"))  # 아니면 전체
        )
    
    # 13. 실험 3 요약 저장
    gcs_summary_exp3 = gcs_masked_pivot.groupBy("stay_id").agg(
        mean("gcs_partial").alias("gcs_partial_mean"),
        spark_max("223901").alias("motor_max"),
        mean("223901").alias("motor_mean"),
        stddev("223901").alias("motor_std"),
        mean("220739").alias("eye_mean"),
        mean("223900").alias("verbal_mean")
    )
    gcs_summary_exp3.write.mode("overwrite").parquet("outputs/exp3_gcs_masked")
    print(f"Exp3 completed: {gcs_summary_exp3.count()} patients")
    
    # 14. 안전한 결측 확인
    try:
        result = gcs_summary_exp3.agg(
            expr("count(*) as total"),
            expr("count(gcs_partial_mean) as non_null_gcs_partial"),
            expr("count(verbal_mean) as non_null_verbal")
        ).collect()[0]
        
        print(f"Results Summary:")
        print(f"- Total patients: {result['total']}")
        print(f"- Non-null GCS partial: {result['non_null_gcs_partial']}")
        print(f"- Non-null verbal: {result['non_null_verbal']}")
    except Exception as e:
        print(f"Result summary failed: {e}")
else:
    print("Error: Required columns missing in masked pivot")

# 15. 캐시 해제
gcs_pivoted.unpersist()
sed_flag_summary.unpersist()

# 16. Spark 종료
spark.stop()
print("Processing completed successfully")


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/06 20:52:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

Total records: 3780780
Schema:
root
 |-- charttime: timestamp (nullable = true)
 |-- gcs_itemid: integer (nullable = true)
 |-- valuenum: double (nullable = true)
 |-- sedated_flag: integer (nullable = true)
 |-- stay_id: integer (nullable = true)



                                                                                

Visualization saved successfully
Running Experiment 1...


25/08/06 20:54:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
25/08/06 20:54:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 20:54:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
25/08/06 20:54:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
25/08/06 20:54:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

Exp1 completed: 47606 patients
Running Experiment 2...


                                                                                

Exp2 completed: 47606 patients
Running Experiment 3...


                                                                                

Exp3 completed: 47606 patients


                                                                                

Results Summary:
- Total patients: 47606
- Non-null GCS partial: 47606
- Non-null verbal: 47606
Processing completed successfully
