In [13]:
import pandas as pd
from pandas_gbq import read_gbq, to_gbq
import json

In [14]:
project_id = "cs-544-project-455917"

# Loading in initial cohort data (based on GCS data)

In [15]:
gcs_timestamp_alignment_query = """
WITH gcs_raw AS (
  SELECT
    stay_id,
    charttime,
    itemid
  FROM `physionet-data.mimiciv_3_1_icu.chartevents`
  WHERE itemid IN (220739, 223900, 223901)
    AND valuenum IS NOT NULL
)

, gcs_grouped AS (
  SELECT
    stay_id,
    charttime,
    ARRAY_AGG(DISTINCT itemid ORDER BY itemid) AS itemids,
    COUNT(DISTINCT itemid) AS n_components
  FROM gcs_raw
  GROUP BY stay_id, charttime
)

SELECT
  n_components,
  COUNT(*) AS n_rows,
  ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER (), 2) AS percent
FROM gcs_grouped
GROUP BY n_components
ORDER BY n_components DESC
"""

read_gbq(gcs_timestamp_alignment_query, project_id=project_id)

Downloading: 100%|[32m██████████[0m|


Unnamed: 0,n_components,n_rows,percent
0,3,2189488,98.72
1,2,17487,0.79
2,1,10812,0.49


In [16]:
gcs_query = """
WITH gcs_raw AS (
  SELECT
    subject_id,
    hadm_id,
    stay_id,
    charttime,
    itemid,
    valuenum
  FROM `physionet-data.mimiciv_3_1_icu.chartevents`
  WHERE itemid IN (220739, 223900, 223901)
    AND valuenum IS NOT NULL
),

gcs_pivoted AS (
  SELECT
    subject_id,
    hadm_id,
    stay_id,
    charttime,
    MAX(CASE WHEN itemid = 220739 THEN valuenum END) AS eye,
    MAX(CASE WHEN itemid = 223900 THEN valuenum END) AS verbal,
    MAX(CASE WHEN itemid = 223901 THEN valuenum END) AS motor
  FROM gcs_raw
  GROUP BY subject_id, hadm_id, stay_id, charttime
  HAVING eye IS NOT NULL AND verbal IS NOT NULL AND motor IS NOT NULL
),

gcs_summed AS (
  SELECT
    subject_id,
    hadm_id,
    stay_id,
    charttime,
    (eye + verbal + motor) AS gcs_score
  FROM gcs_pivoted
),

gcs_with_lag AS (
  SELECT
    subject_id,
    hadm_id,
    stay_id,
    charttime,
    gcs_score,
    LAG(gcs_score) OVER (PARTITION BY stay_id ORDER BY charttime) AS prev_gcs_score,
    LAG(charttime) OVER (PARTITION BY stay_id ORDER BY charttime) AS prev_time
  FROM gcs_summed
)

SELECT
  *,
  DATETIME_DIFF(charttime, prev_time, SECOND) / 60.0 AS mins_since_prev
FROM gcs_with_lag
WHERE prev_gcs_score IS NOT NULL
  AND DATETIME_DIFF(charttime, prev_time, SECOND) / 60.0 <= 1440.0
ORDER BY stay_id, charttime
"""

gcs_df = read_gbq(gcs_query, project_id=project_id)
gcs_df.head()

Downloading: 100%|[32m██████████[0m|


Unnamed: 0,subject_id,hadm_id,stay_id,charttime,gcs_score,prev_gcs_score,prev_time,mins_since_prev
0,12466550,23998182,30000153,2174-09-29 16:26:00,11.0,9.0,2174-09-29 12:45:00,221.0
1,12466550,23998182,30000153,2174-09-29 17:37:00,11.0,11.0,2174-09-29 16:26:00,71.0
2,12466550,23998182,30000153,2174-09-29 18:00:00,9.0,11.0,2174-09-29 17:37:00,23.0
3,12466550,23998182,30000153,2174-09-29 19:00:00,9.0,9.0,2174-09-29 18:00:00,60.0
4,12466550,23998182,30000153,2174-09-29 20:00:00,12.0,9.0,2174-09-29 19:00:00,60.0


### Reducing cohort based on GCS recording time differential

In [17]:
gcs_df['mins_since_prev'].value_counts().iloc[:10]

mins_since_prev
240.0    731387
120.0    330571
60.0     219254
180.0     34410
480.0     32362
720.0     26497
300.0     22262
360.0     11567
30.0       8986
210.0      8883
Name: count, dtype: int64

In [None]:
l1 = len(gcs_df)

2094207

In [19]:
gcs_df = gcs_df[(gcs_df['mins_since_prev']>=60) & (gcs_df['mins_since_prev']<=240)].copy()

In [21]:
len(gcs_df)

1648362

In [22]:
1648362 / 2094207

0.7871055726582902

### Adding GCS delta and deterioration labels

In [13]:
def gcs_deterioration_multi_map(delta):
    if delta >= 0:
        return 0
    elif delta == -1:
        return 1
    elif delta == -2:
        return 2
    else:
        return 3

gcs_df['gcs_delta'] = gcs_df['gcs_score'] - gcs_df['prev_gcs_score']
gcs_df['gcs_deterioration_binary'] = (gcs_df['gcs_delta'] <= -2).astype(int)
gcs_df['gcs_deterioration_multi'] = gcs_df['gcs_delta'].apply(gcs_deterioration_multi_map)

In [14]:
gcs_df

Unnamed: 0,subject_id,hadm_id,stay_id,charttime,gcs_score,prev_gcs_score,prev_time,mins_since_prev,gcs_delta,gcs_deterioration_binary,gcs_deterioration_multi
0,12466550,23998182,30000153,2174-09-29 16:26:00,11.0,9.0,2174-09-29 12:45:00,221.0,2.0,0,0
1,12466550,23998182,30000153,2174-09-29 17:37:00,11.0,11.0,2174-09-29 16:26:00,71.0,0.0,0,0
3,12466550,23998182,30000153,2174-09-29 19:00:00,9.0,9.0,2174-09-29 18:00:00,60.0,0.0,0,0
4,12466550,23998182,30000153,2174-09-29 20:00:00,12.0,9.0,2174-09-29 19:00:00,60.0,3.0,0,0
5,12466550,23998182,30000153,2174-09-29 21:00:00,12.0,12.0,2174-09-29 20:00:00,60.0,0.0,0,0
...,...,...,...,...,...,...,...,...,...,...,...
2094200,10826461,26489824,39999858,2167-04-29 08:00:00,15.0,15.0,2167-04-29 04:00:00,240.0,0.0,0,0
2094203,10826461,26489824,39999858,2167-04-30 00:00:00,15.0,15.0,2167-04-29 20:00:00,240.0,0.0,0,0
2094204,10826461,26489824,39999858,2167-04-30 04:00:00,15.0,15.0,2167-04-30 00:00:00,240.0,0.0,0,0
2094205,10826461,26489824,39999858,2167-04-30 07:00:00,15.0,15.0,2167-04-30 04:00:00,180.0,0.0,0,0


In [15]:
for metric in ['gcs_score', 'gcs_delta', 'gcs_deterioration_binary', 'gcs_deterioration_multi']:
    print(gcs_df[metric].value_counts())
    print()

gcs_score
15.0    611684
14.0    216588
10.0    158989
11.0    133339
3.0      95135
9.0      81856
8.0      74882
7.0      73673
13.0     72843
6.0      59656
12.0     28628
5.0      21584
4.0      19505
Name: count, dtype: int64

gcs_delta
 0.0     1302939
 1.0      107845
-1.0       94652
 2.0       28768
-2.0       21154
 3.0       18777
 4.0       18021
-3.0       11735
 5.0       10550
-4.0        8276
 7.0        6021
-5.0        4273
 6.0        4071
-6.0        2385
-7.0        2155
 8.0        2128
-8.0        1119
 12.0        796
 11.0        749
-12.0        556
-11.0        358
-9.0         343
 9.0         272
-10.0        249
 10.0        170
Name: count, dtype: int64

gcs_deterioration_binary
0    1595759
1      52603
Name: count, dtype: int64

gcs_deterioration_multi
0    1501107
1      94652
3      31449
2      21154
Name: count, dtype: int64



### Saving cohort data

In [16]:
gcs_df.to_parquet('gcs.parquet', index=False)

In [17]:
to_gbq(
    dataframe=gcs_df,
    destination_table="my_data.gcs_df",
    project_id=project_id,
    if_exists="replace"
)

100%|██████████| 1/1 [00:00<?, ?it/s]


# Getting features from chartevents and labevents based on completeness of data

In [18]:
chart_completeness_query = """
-- Step 1: Create label times and 3 hourly bins
WITH label_times AS (
  SELECT stay_id, charttime AS label_time
  FROM `cs-544-project-455917.my_data.gcs_df`
),

-- Step 2: Create 3-hour input bins for each label
input_hours AS (
  SELECT
    lt.stay_id,
    lt.label_time,
    TIMESTAMP_SUB(lt.label_time, INTERVAL offset_min MINUTE) AS hour_start,
    offset_min
  FROM label_times lt, UNNEST([240, 180, 120]) AS offset_min
),

-- Step 3: Identify itemids present in each bin
chartevents_hits AS (
  SELECT
    ih.stay_id,
    ih.label_time,
    ih.offset_min,
    ce.itemid
  FROM input_hours ih
  JOIN `physionet-data.mimiciv_3_1_icu.chartevents` ce
    ON ih.stay_id = ce.stay_id
    AND ce.charttime BETWEEN ih.hour_start AND TIMESTAMP_ADD(ih.hour_start, INTERVAL 60 MINUTE)
    AND ce.valuenum IS NOT NULL
),

-- Step 4: Count how many bins each itemid appears in per label
chartevents_bin_counts AS (
  SELECT
    stay_id,
    label_time,
    itemid,
    COUNT(DISTINCT offset_min) AS bins_covered
  FROM chartevents_hits
  GROUP BY stay_id, label_time, itemid
),

-- Step 5: Total label sample count
label_count AS (
  SELECT COUNT(*) AS total_labels FROM label_times
),

-- Step 6: Compute coverage stats per itemid
itemid_coverage AS (
  SELECT
    cbc.itemid,
    COUNT(*) AS n_total,
    SUM(CASE WHEN cbc.bins_covered >= 1 THEN 1 ELSE 0 END) AS n_bin1,
    SUM(CASE WHEN cbc.bins_covered >= 2 THEN 1 ELSE 0 END) AS n_bin2,
    SUM(CASE WHEN cbc.bins_covered = 3 THEN 1 ELSE 0 END) AS n_bin3
  FROM chartevents_bin_counts cbc
  GROUP BY cbc.itemid
),

-- Step 7: Attach percentages and labels
final_coverage AS (
  SELECT
    ic.itemid,
    d.label,
    ic.n_total,
    ic.n_bin1,
    ic.n_bin2,
    ic.n_bin3,
    ROUND(ic.n_bin1 * 100.0 / lc.total_labels, 2) AS pct_bin1,
    ROUND(ic.n_bin2 * 100.0 / lc.total_labels, 2) AS pct_bin2,
    ROUND(ic.n_bin3 * 100.0 / lc.total_labels, 2) AS pct_bin3
  FROM itemid_coverage ic
  CROSS JOIN label_count lc
  LEFT JOIN `physionet-data.mimiciv_3_1_icu.d_items` d USING (itemid)
)

SELECT *
FROM final_coverage
ORDER BY pct_bin3 DESC
"""

chart_completeness_df = read_gbq(chart_completeness_query, project_id=project_id)
chart_completeness_df.head(20)

Downloading: 100%|[32m██████████[0m|


Unnamed: 0,itemid,label,n_total,n_bin1,n_bin2,n_bin3,pct_bin1,pct_bin2,pct_bin3
0,220045,Heart Rate,1644816,1644816,1615192,1562499,99.78,97.99,94.79
1,220210,Respiratory Rate,1627892,1627892,1592339,1535286,98.76,96.6,93.14
2,220277,O2 saturation pulseoxymetry,1638851,1638851,1599469,1533487,99.42,97.03,93.03
3,220179,Non Invasive Blood Pressure systolic,1105299,1105299,1002423,840828,67.05,60.81,51.01
4,220180,Non Invasive Blood Pressure diastolic,1105127,1105127,1002319,840728,67.04,60.81,51.0
5,220181,Non Invasive Blood Pressure mean,1102942,1102942,1000393,840201,66.91,60.69,50.97
6,220052,Arterial Blood Pressure mean,622453,622453,605767,580575,37.76,36.75,35.22
7,220050,Arterial Blood Pressure systolic,620894,620894,604093,578594,37.67,36.65,35.1
8,220051,Arterial Blood Pressure diastolic,620908,620908,604087,578549,37.67,36.65,35.1
9,229321,Activity / Mobility (JH-HLM),932058,932058,755069,561482,56.54,45.81,34.06


In [19]:
lab_completeness_query = """
-- Step 1: Create label times and 3 hourly bins
WITH label_times AS (
  SELECT stay_id, hadm_id, subject_id, charttime AS label_time
  FROM `cs-544-project-455917.my_data.gcs_df`
),

-- Step 2: Create 3-hour input bins per label
input_hours AS (
  SELECT
    lt.stay_id,
    lt.hadm_id,
    lt.subject_id,
    lt.label_time,
    TIMESTAMP_SUB(lt.label_time, INTERVAL offset_min MINUTE) AS hour_start,
    offset_min
  FROM label_times lt, UNNEST([240, 180, 120]) AS offset_min
),

-- Step 3: Identify labevents itemids in each bin (joined on hadm_id)
labevents_hits AS (
  SELECT
    ih.stay_id,
    ih.hadm_id,
    ih.label_time,
    ih.offset_min,
    le.itemid
  FROM input_hours ih
  JOIN `physionet-data.mimiciv_3_1_hosp.labevents` le
    ON ih.hadm_id = le.hadm_id
    AND le.charttime BETWEEN ih.hour_start AND TIMESTAMP_ADD(ih.hour_start, INTERVAL 60 MINUTE)
    AND le.valuenum IS NOT NULL
),

-- Step 4: Count how many bins each itemid appears in per label
lab_bin_counts AS (
  SELECT
    stay_id,
    label_time,
    itemid,
    COUNT(DISTINCT offset_min) AS bins_covered
  FROM labevents_hits
  GROUP BY stay_id, label_time, itemid
),

-- Step 5: Total label sample count
label_count AS (
  SELECT COUNT(*) AS total_labels FROM label_times
),

-- Step 6: Compute coverage stats per itemid
itemid_coverage AS (
  SELECT
    lbc.itemid,
    COUNT(*) AS n_total,
    SUM(CASE WHEN lbc.bins_covered >= 1 THEN 1 ELSE 0 END) AS n_bin1,
    SUM(CASE WHEN lbc.bins_covered >= 2 THEN 1 ELSE 0 END) AS n_bin2,
    SUM(CASE WHEN lbc.bins_covered = 3 THEN 1 ELSE 0 END) AS n_bin3
  FROM lab_bin_counts lbc
  GROUP BY lbc.itemid
),

-- Step 7: Attach percentages and labels
final_coverage AS (
  SELECT
    ic.itemid,
    d.label,
    ic.n_total,
    ic.n_bin1,
    ic.n_bin2,
    ic.n_bin3,
    ROUND(ic.n_bin1 * 100.0 / lc.total_labels, 2) AS pct_bin1,
    ROUND(ic.n_bin2 * 100.0 / lc.total_labels, 2) AS pct_bin2,
    ROUND(ic.n_bin3 * 100.0 / lc.total_labels, 2) AS pct_bin3
  FROM itemid_coverage ic
  CROSS JOIN label_count lc
  LEFT JOIN `physionet-data.mimiciv_3_1_hosp.d_labitems` d USING (itemid)
)

SELECT *
FROM final_coverage
ORDER BY pct_bin3 DESC
"""

lab_completeness_df = read_gbq(lab_completeness_query, project_id=project_id)
lab_completeness_df.head(20)

Downloading: 100%|[32m██████████[0m|


Unnamed: 0,itemid,label,n_total,n_bin1,n_bin2,n_bin3,pct_bin1,pct_bin2,pct_bin3
0,50820,pH,254853,254853,33096,3314,15.46,2.01,0.2
1,50802,Base Excess,241892,241892,31675,3164,14.67,1.92,0.19
2,50821,pO2,242004,242004,31721,3175,14.68,1.92,0.19
3,50818,pCO2,241881,241881,31674,3164,14.67,1.92,0.19
4,50804,Calculated Total CO2,241879,241879,31668,3164,14.67,1.92,0.19
5,50809,Glucose,70270,70270,11244,1580,4.26,0.68,0.1
6,50817,Oxygen Saturation,66297,66297,10008,919,4.02,0.61,0.06
7,50808,Free Calcium,139077,139077,11207,765,8.44,0.68,0.05
8,50813,Lactate,148551,148551,12852,822,9.01,0.78,0.05
9,50822,"Potassium, Whole Blood",63700,63700,7599,875,3.86,0.46,0.05


In [20]:
chart_completeness_df.head(100).to_csv('chart_completeness_top_100.csv', index=False)
lab_completeness_df.head(100).to_csv('lab_completeness_top_100.csv', index=False)

In [21]:
# Filter down to high-quality features based on coverage
chart_top = chart_completeness_df.copy()
lab_top = lab_completeness_df.copy()

chart_top['source'] = 'chartevents'
lab_top['source'] = 'labevents'

# Identify chart features for different levels of modeling
chart_sequential = chart_top[chart_top['pct_bin3'] >= 35].copy()
chart_static = chart_top[(chart_top['pct_bin1'] >= 50) & (chart_top['pct_bin3'] < 35)].copy()

# Identify lab features that have at least 10% appearance in >=1 bin
lab_static = lab_top[lab_top['pct_bin1'] >= 10].copy()

# Add strategy annotations
chart_sequential['strategy'] = 'sequential'
chart_static['strategy'] = 'static'
lab_static['strategy'] = 'static'

# Combine and sort
final_features = pd.concat([chart_sequential, chart_static, lab_static], ignore_index=True)
final_features = final_features[['source', 'itemid', 'label', 'pct_bin1', 'pct_bin2', 'pct_bin3', 'strategy']]
final_features.sort_values(['strategy', 'source', 'pct_bin3'], ascending=[True, True, False], inplace=True)
final_features

Unnamed: 0,source,itemid,label,pct_bin1,pct_bin2,pct_bin3,strategy
0,chartevents,220045,Heart Rate,99.78,97.99,94.79,sequential
1,chartevents,220210,Respiratory Rate,98.76,96.60,93.14,sequential
2,chartevents,220277,O2 saturation pulseoxymetry,99.42,97.03,93.03,sequential
3,chartevents,220179,Non Invasive Blood Pressure systolic,67.05,60.81,51.01,sequential
4,chartevents,220180,Non Invasive Blood Pressure diastolic,67.04,60.81,51.00,sequential
...,...,...,...,...,...,...,...
56,labevents,51277,RDW,18.67,1.01,0.01,static
57,labevents,50868,Anion Gap,21.03,1.01,0.01,static
58,labevents,50902,Chloride,22.23,1.13,0.01,static
59,labevents,51249,MCHC,18.69,1.01,0.01,static


In [22]:
remove_labels = ['Goal Richmond-RAS Scale', 'Alarms On', 'Parameters Checked', 'ST Segment Monitoring On']
final_features = final_features[~final_features['label'].isin(remove_labels)].reset_index(drop=True)

In [23]:
final_features['strategy'].value_counts()

strategy
static        48
sequential     9
Name: count, dtype: int64

In [24]:
to_gbq(
    dataframe=final_features,
    destination_table="my_data.selected_features_table",
    project_id=project_id,
    if_exists="replace"
)

100%|██████████| 1/1 [00:00<?, ?it/s]


# Pulling data
- Demographic data (per subject)
- Sequential feature data (average per hour over 3hr period)
- Static feature data (average over 3hr period)

In [25]:
demographic_features_query = """
SELECT 
  subject_id,
  anchor_age,
  CASE
    WHEN gender = 'F' THEN 0
    ELSE 1
  END AS gender_numeric
FROM `physionet-data.mimiciv_3_1_hosp.patients`
"""

demographic_features_df = read_gbq(demographic_features_query, project_id=project_id)
demographic_features_df.head()

Downloading: 100%|[32m██████████[0m|


Unnamed: 0,subject_id,anchor_age,gender_numeric
0,10078138,18,0
1,10851602,18,0
2,10902424,18,0
3,11289691,18,0
4,11739764,18,0


In [27]:
sequential_features_query = """
WITH selected_features AS (
  SELECT itemid
  FROM `cs-544-project-455917.my_data.selected_features_table`
  WHERE strategy = 'sequential'
),

label_times AS (
  SELECT stay_id, charttime AS label_time
  FROM `cs-544-project-455917.my_data.gcs_df`
),

input_hours AS (
  SELECT
    lt.stay_id,
    lt.label_time,
    TIMESTAMP_SUB(lt.label_time, INTERVAL offset_min MINUTE) AS hour_start,
    offset_min
  FROM label_times lt, UNNEST([240, 180, 120]) AS offset_min
),

chartevents_filtered AS (
  SELECT
    ih.stay_id,
    ih.label_time,
    ih.offset_min,
    ce.itemid,
    ce.valuenum
  FROM input_hours ih
  JOIN `physionet-data.mimiciv_3_1_icu.chartevents` ce
    ON ih.stay_id = ce.stay_id
    AND ce.charttime BETWEEN ih.hour_start AND TIMESTAMP_ADD(ih.hour_start, INTERVAL 60 MINUTE)
    AND ce.valuenum IS NOT NULL
  JOIN selected_features sf ON ce.itemid = sf.itemid
),

aggregated_seq AS (
  SELECT
    stay_id,
    label_time,
    offset_min,
    itemid,
    AVG(valuenum) AS val_avg
  FROM chartevents_filtered
  GROUP BY stay_id, label_time, offset_min, itemid
)

SELECT *
FROM aggregated_seq
ORDER BY stay_id, label_time, offset_min
"""

# sequential_features_df = read_gbq(sequential_features_query, project_id=project_id)
# sequential_features_df.head()

In [28]:
static_features_query = """
WITH selected_features AS (
  SELECT itemid, source
  FROM `cs-544-project-455917.my_data.selected_features_table`
  WHERE strategy = 'static'
),

label_times AS (
  SELECT stay_id, hadm_id, charttime AS label_time
  FROM `cs-544-project-455917.my_data.gcs_df`
),

input_window AS (
  SELECT
    stay_id,
    hadm_id,
    label_time,
    TIMESTAMP_SUB(label_time, INTERVAL 240 MINUTE) AS window_start,
    TIMESTAMP_SUB(label_time, INTERVAL 60 MINUTE) AS window_end
  FROM label_times
),

chartevents_filtered AS (
  SELECT
    iw.stay_id,
    iw.label_time,
    ce.itemid,
    AVG(ce.valuenum) AS val_avg
  FROM input_window iw
  JOIN `physionet-data.mimiciv_3_1_icu.chartevents` ce
    ON iw.stay_id = ce.stay_id
    AND ce.charttime BETWEEN iw.window_start AND iw.window_end
    AND ce.valuenum IS NOT NULL
  JOIN selected_features sf ON ce.itemid = sf.itemid AND sf.source = 'chartevents'
  GROUP BY iw.stay_id, iw.label_time, ce.itemid
),

labevents_filtered AS (
  SELECT
    iw.stay_id,
    iw.label_time,
    le.itemid,
    AVG(le.valuenum) AS val_avg
  FROM input_window iw
  JOIN `physionet-data.mimiciv_3_1_hosp.labevents` le
    ON iw.hadm_id = le.hadm_id
    AND le.charttime BETWEEN iw.window_start AND iw.window_end
    AND le.valuenum IS NOT NULL
  JOIN selected_features sf ON le.itemid = sf.itemid AND sf.source = 'labevents'
  GROUP BY iw.stay_id, iw.label_time, le.itemid
)

-- Combine chart and lab static features
SELECT * FROM chartevents_filtered
UNION ALL
SELECT * FROM labevents_filtered
ORDER BY stay_id, label_time
"""

# static_features_df = read_gbq(static_features_query, project_id=project_id)
# static_features_df.head()

In [29]:
demographic_features_df.to_parquet('demographic_features.parquet', index=False)
# sequential_features_df.to_parquet('sequential_features.parquet', index=False)
# static_features_df.to_parquet('static_features.parquet', index=False)

# Interpolating data for sequential features

In [31]:
# Step 1: Convert to wide format with offset-based suffixes
sequential_features_df['feature_time'] = 'item_' + sequential_features_df['itemid'].astype(str) + '_t' + (sequential_features_df['offset_min'] // 60).astype(str)

seq_pivot = sequential_features_df.pivot_table(
    index=['stay_id', 'label_time'],
    columns='feature_time',
    values='val_avg'
).reset_index()

# Interpolation
seq_pivot.iloc[:, 2:] = seq_pivot.iloc[:, 2:].interpolate(axis=1, limit_direction='both')
seq_pivot.head()

feature_time,stay_id,label_time,item_220045_t2,item_220045_t3,item_220045_t4,item_220050_t2,item_220050_t3,item_220050_t4,item_220051_t2,item_220051_t3,...,item_220180_t4,item_220181_t2,item_220181_t3,item_220181_t4,item_220210_t2,item_220210_t3,item_220210_t4,item_220277_t2,item_220277_t3,item_220277_t4
0,30000153,2174-09-29 16:26:00,87.5,95.75,104.0,123.0,131.0,151.0,65.0,61.0,...,77.0,79.333333,81.666667,84.0,14.0,16.0,16.0,100.0,100.0,100.0
1,30000153,2174-09-29 17:37:00,83.0,87.5,98.25,109.0,123.0,131.0,55.0,65.0,...,41.6,35.2,28.8,22.4,16.0,14.0,16.0,100.0,100.0,100.0
2,30000153,2174-09-29 19:00:00,107.0,93.0,87.5,122.0,110.0,116.0,59.5,55.5,...,42.7,36.9,31.1,25.3,19.5,18.0,15.0,99.5,100.0,100.0
3,30000153,2174-09-29 20:00:00,117.0,107.0,93.0,144.0,122.0,110.0,65.5,59.5,...,40.4,35.3,30.2,25.1,20.0,19.5,18.0,97.5,99.5,100.0
4,30000153,2174-09-29 21:00:00,125.5,117.0,107.0,138.5,144.0,122.0,67.5,65.5,...,43.4,37.8,32.2,26.6,21.0,20.0,19.5,97.0,97.5,99.5


In [32]:
seq_pivot.to_parquet('sequential_features_final.parquet', index=False)

# Imputing data for static features

In [33]:
# Convert to wide format
static_features_df['feature'] = 'item_' + static_features_df['itemid'].astype(str)

static_pivot = static_features_df.pivot_table(
    index=['stay_id', 'label_time'],
    columns='feature',
    values='val_avg'
).reset_index()

# Mean imputation
for col in static_pivot.columns[2:]:
    static_pivot[col] = static_pivot[col].fillna(static_pivot[col].mean())
static_pivot.head()

feature,stay_id,label_time,item_220739,item_223761,item_223900,item_223901,item_224054,item_224055,item_224056,item_224057,...,item_51249,item_51250,item_51265,item_51274,item_51275,item_51277,item_51279,item_51301,item_51678,item_52172
0,30000153,2174-09-29 16:26:00,3.0,99.1,1.0,5.0,3.0,3.0,1.0,2.0,...,32.79426,91.655354,207.452422,16.052083,41.787498,15.575067,3.365858,12.369007,15.002034,51.773522
1,30000153,2174-09-29 17:37:00,4.0,99.3,1.0,6.0,3.0,3.0,1.0,2.0,...,34.1,96.0,173.0,13.2,25.3,13.5,3.3,17.0,15.002034,51.773522
2,30000153,2174-09-29 19:00:00,3.666667,99.5,1.0,5.666667,3.0,3.0,1.0,2.0,...,34.1,96.0,173.0,13.2,25.3,13.5,3.3,17.0,15.002034,51.773522
3,30000153,2174-09-29 20:00:00,3.5,99.5,1.0,5.5,3.0,3.0,1.0,2.0,...,32.79426,91.655354,207.452422,16.052083,41.787498,15.575067,3.365858,12.369007,15.002034,51.773522
4,30000153,2174-09-29 21:00:00,3.25,100.8,1.5,5.5,2.848898,3.370517,1.424031,2.451514,...,32.79426,91.655354,207.452422,16.052083,41.787498,15.575067,3.365858,12.369007,15.002034,51.773522


In [34]:
static_pivot.to_parquet('static_features_final.parquet', index=False)

# Merging data into final dataset

In [3]:
gcs_df = pd.read_parquet('gcs.parquet')
demographic_features_df = pd.read_parquet('demographic_features.parquet')
static_features_final_df = pd.read_parquet('static_features_final.parquet')
sequential_features_final_df = pd.read_parquet('sequential_features_final.parquet')

In [4]:
for df in [gcs_df, static_features_final_df, sequential_features_final_df]:
    print(len(df))

1648362
1648362
1645902


In [5]:
merged_df = gcs_df.merge(demographic_features_df, how='left', on='subject_id')
merged_df = merged_df.merge(static_features_final_df, how='left', left_on=['stay_id', 'charttime'], right_on=['stay_id', 'label_time'])
merged_df = merged_df.merge(sequential_features_final_df, how='left', left_on=['stay_id', 'charttime'], right_on=['stay_id', 'label_time'], suffixes=('', '_seq'))
merged_df = merged_df.dropna()
merged_df

Unnamed: 0,subject_id,hadm_id,stay_id,charttime,gcs_score,prev_gcs_score,prev_time,mins_since_prev,gcs_delta,gcs_deterioration_binary,...,item_220180_t4,item_220181_t2,item_220181_t3,item_220181_t4,item_220210_t2,item_220210_t3,item_220210_t4,item_220277_t2,item_220277_t3,item_220277_t4
0,12466550,23998182,30000153,2174-09-29 16:26:00,11.0,9.0,2174-09-29 12:45:00,221.0,2.0,0,...,77.000000,79.333333,81.666667,84.000000,14.0,16.0,16.0,100.0,100.0,100.0
1,12466550,23998182,30000153,2174-09-29 17:37:00,11.0,11.0,2174-09-29 16:26:00,71.0,0.0,0,...,41.600000,35.200000,28.800000,22.400000,16.0,14.0,16.0,100.0,100.0,100.0
2,12466550,23998182,30000153,2174-09-29 19:00:00,9.0,9.0,2174-09-29 18:00:00,60.0,0.0,0,...,42.700000,36.900000,31.100000,25.300000,19.5,18.0,15.0,99.5,100.0,100.0
3,12466550,23998182,30000153,2174-09-29 20:00:00,12.0,9.0,2174-09-29 19:00:00,60.0,3.0,0,...,40.400000,35.300000,30.200000,25.100000,20.0,19.5,18.0,97.5,99.5,100.0
4,12466550,23998182,30000153,2174-09-29 21:00:00,12.0,12.0,2174-09-29 20:00:00,60.0,0.0,0,...,43.400000,37.800000,32.200000,26.600000,21.0,20.0,19.5,97.0,97.5,99.5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1648357,10826461,26489824,39999858,2167-04-29 08:00:00,15.0,15.0,2167-04-29 04:00:00,240.0,0.0,0,...,67.000000,75.000000,72.000000,80.000000,26.5,30.5,23.5,90.0,91.5,90.0
1648358,10826461,26489824,39999858,2167-04-30 00:00:00,15.0,15.0,2167-04-29 20:00:00,240.0,0.0,0,...,52.000000,63.000000,65.000000,67.000000,20.5,17.0,24.0,96.5,95.5,93.0
1648359,10826461,26489824,39999858,2167-04-30 04:00:00,15.0,15.0,2167-04-30 00:00:00,240.0,0.0,0,...,74.333333,79.000000,59.833333,40.666667,21.5,26.5,22.0,92.0,96.0,92.5
1648360,10826461,26489824,39999858,2167-04-30 07:00:00,15.0,15.0,2167-04-30 04:00:00,180.0,0.0,0,...,65.000000,69.000000,55.666667,42.333333,29.0,30.5,24.5,94.5,95.0,92.0


In [6]:
merged_df.to_parquet('final_dataset.parquet', index=False)