In [None]:
# Core libraries
import os
import gc
import time
import getpass
from pathlib import Path

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Database connection
import psycopg2

# Machine learning utilities
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# File and model persistence
import joblib
import tables

# Progress bar for loops
from tqdm import tqdm

In [None]:
# Database connection and ICU admissions extraction
# This block connects to a local PostgreSQL database with MIMIC-IV
# and creates a temporary table (all_icu_stays) containing ICU admissions longer than 1 hour.
#
# Credentials come from the MIMIC_DB_USER / MIMIC_DB_PASSWORD environment variables when
# set, and are prompted for otherwise, so they are never stored in the notebook.

# Database connection parameters
DB_NAME = "mimiciv"
DB_HOST = "localhost"
DB_PORT = "5432"
DB_USER = os.environ.get("MIMIC_DB_USER") or input("Enter database username: ")
DB_PASSWORD = os.environ.get("MIMIC_DB_PASSWORD") or getpass.getpass("Enter database password: ")

# Every later cell uses `conn`, so a failed connection must stop here instead of
# surfacing later as a NameError.
try:
    conn = psycopg2.connect(
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        host=DB_HOST,
        port=DB_PORT
    )
except psycopg2.Error as e:
    raise RuntimeError(f"Connection to database '{DB_NAME}' at {DB_HOST}:{DB_PORT} failed") from e
print("Connected to the database.")

# Create a temporary table with ICU stays longer than 1 hour (3600 s).
# The table lives for this session only; dropping it first makes the cell re-runnable.
conn.rollback()  # discard any aborted transaction left on this connection
with conn.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS all_icu_stays;")
    cursor.execute("""
        CREATE TEMP TABLE all_icu_stays AS
        SELECT 
            i.subject_id,
            i.hadm_id,
            i.stay_id,
            i.intime,
            i.outtime,
            EXTRACT(EPOCH FROM (i.outtime - i.intime)) / 60 AS duration_minutes
        FROM mimiciv_icu.icustays i
        WHERE EXTRACT(EPOCH FROM (i.outtime - i.intime)) > 3600;
    """)
conn.commit()

# Confirm number of ICU admissions retrieved
df_total = pd.read_sql("SELECT COUNT(*) AS total_admissions FROM all_icu_stays;", conn)
print(f"Total ICU admissions longer than 1 hour: {df_total['total_admissions'][0]}")

In [None]:
# Definition of clinical variables grouped by sampling frequency.
# These groups help in designing the time grid and imputation strategy later on.
# Lab variables are named lab_<itemid>; the itemid tuples below preserve the original
# ordering of each group.

# High-frequency variables: typically measured every few minutes to 1 hour
vars_high_freq = ['heart_rate', 'sbp', 'dbp', 'mbp', 'resp_rate', 'spo2'] + [
    f'lab_{itemid}' for itemid in (
        53085, 51580, 52642, 51002, 52116, 51623,
        50928, 52117, 50855, 52546, 53161, 53180,
        52142, 51266, 52144, 51631, 51638, 51640,
        51647, 51643, 50975, 51292, 51290, 51291,
        52551, 51568, 51569, 51570, 51464, 51966,
    )
]

# Medium-frequency variables: measured every few hours
vars_medium_freq = ['temperature', 'glucose'] + [
    f'lab_{itemid}' for itemid in (
        50908, 50915, 50856, 50803, 50805, 50808, 50809, 50813,
    )
]

# Low-frequency variables: measured once or twice per day
vars_low_freq = [
    f'lab_{itemid}' for itemid in (
        50861, 50862, 50883, 50884, 50885, 50910,
        50924, 50963, 51003, 50889, 51214, 50878,
        50912, 51265, 50931, 50935, 51222, 51223,
        50852, 50971, 50983, 50990, 50967, 50968,
        50969, 50960, 50966, 50970, 51099, 51006,
        51274, 51275, 51196,
    )
]

# Vasopressors: medication administration indicators
vasopressor_vars = 'dopamine epinephrine norepinephrine phenylephrine vasopressin dobutamine milrinone'.split()

# Summary: one line per group with its size
for _group_label, _group_members in (
    ("High-frequency variables", vars_high_freq),
    ("Medium-frequency variables", vars_medium_freq),
    ("Low-frequency variables", vars_low_freq),
    ("Vasopressors", vasopressor_vars),
):
    print(f"{_group_label}: {len(_group_members)}")

In [None]:
# Default values (NORMAL_VALUES) are used to initialize or impute missing values
# for key vital signs and routine measurements, following standard reference values
# from the original paper and clinical guidelines.
#
# Values must be in the units of the source table (mimiciv_derived.vitalsign).
# Glucose there is charted in mg/dL, the same unit as the lab glucose defaults below
# (lab_50809 / lab_50931 = 90 mg/dL). The previous value of 5 was a mmol/L figure and
# would have imputed severe hypoglycemia.

NORMAL_VALUES = {
    'heart_rate': 70,
    'sbp': 125,
    'dbp': 75,
    'mbp': 90,
    'resp_rate': 12,
    'temperature': 37,
    'spo2': 98,
    'glucose': 90,
}

# PARAMETROS_IMPUTACAO ("imputation parameters") holds, per lab variable, a default
# value and a reference range (min, max). Only 'default' is read by the imputation code
# visible in this notebook section (through default_lab_values).
# TODO: confirm whether 'range' is consumed elsewhere, e.g. for outlier filtering.
# NOTE(review): lab_53161 and lab_53180 are queried but have no entry here, so they are
# imputed with NaN.

PARAMETROS_IMPUTACAO = {
    # Format: variable: {'default': value, 'range': (min, max)}
    'lab_50803': {'default': 24, 'range': (22, 26)},                   # Calculated Bicarbonate, Whole Blood (mEq/L)
    'lab_50805': {'default': 1.5, 'range': (0.5, 2.0)},                # Carboxyhemoglobin, Blood (%)
    'lab_50808': {'default': 1.1, 'range': (0.9, 1.3)},                # Free Calcium, Blood (mmol/L)
    'lab_50809': {'default': 90, 'range': (70, 110)},                  # Glucose, Blood (mg/dL)
    'lab_50813': {'default': 1.0, 'range': (0.5, 2.2)},                # Lactate, Blood, Blood Gas (mmol/L)
    'lab_50852': {'default': 5.5, 'range': (4.0, 6.5)},                # % Hemoglobin A1c, Blood (%)
    'lab_50855': {'default': 14, 'range': (13, 17)},                   # Absolute Hemoglobin, Blood (g/dL)
    'lab_50856': {'default': 10, 'range': (5, 20)},                    # Acetaminophen, Blood (mg/L)
    'lab_50861': {'default': 30, 'range': (10, 50)},                   # Alanine Aminotransferase (ALT), Blood (U/L)
    'lab_50862': {'default': 4.3, 'range': (3.5, 5.5)},                # Albumin, Blood (g/dL)
    'lab_50878': {'default': 40, 'range': (10, 40)},                   # Aspartate Aminotransferase (AST), Blood (U/L)
    'lab_50883': {'default': 0.3, 'range': (0.1, 0.4)},                # Bilirubin, Direct, Blood (mg/dL)
    'lab_50884': {'default': 0.6, 'range': (0.3, 1.2)},                # Bilirubin, Indirect, Blood (mg/dL)
    'lab_50885': {'default': 1.0, 'range': (0.3, 1.9)},                # Bilirubin, Total, Blood (mg/dL)
    'lab_50889': {'default': 5.0, 'range': (3.5, 5.1)},                # C-Reactive Protein, Blood (mg/dL) -- NOTE(review): range looks copied from potassium; confirm
    'lab_50908': {'default': 0.8, 'range': (0.7, 1.3)},                # CK-MB Index, Blood (ng/mL)
    'lab_50910': {'default': 150, 'range': (50, 200)},                 # Creatine Kinase (CK), Blood (U/L)
    'lab_50912': {'default': 1.0, 'range': (0.6, 1.3)},                # Creatinine, Blood (mg/dL)
    'lab_50915': {'default': 0.5, 'range': (0.2, 0.5)},                # D-Dimer, Blood (µg/mL)
    'lab_50924': {'default': 120, 'range': (20, 300)},                 # Ferritin, Blood (ng/mL)
    'lab_50928': {'default': 100, 'range': (0, 150)},                  # Gastrin, Blood (pg/mL)
    'lab_50931': {'default': 90, 'range': (70, 110)},                  # Glucose, Blood (mg/dL)
    'lab_50935': {'default': 100, 'range': (30, 200)},                 # Haptoglobin, Blood (mg/dL)
    'lab_50960': {'default': 2.0, 'range': (1.5, 2.5)},                # Magnesium, Blood (mg/dL)
    'lab_50963': {'default': 100, 'range': (20, 400)},                 # NTproBNP, Blood (pg/mL)
    'lab_50966': {'default': 15, 'range': (5, 20)},                    # Phenobarbital, Blood (µg/mL)
    'lab_50967': {'default': 15, 'range': (5, 20)},                    # Phenytoin, Blood (µg/mL)
    'lab_50968': {'default': 10, 'range': (5, 15)},                    # Phenytoin, Free, Blood (µg/mL)
    'lab_50969': {'default': 80, 'range': (70, 90)},                   # Phenytoin, Percent Free (%)
    'lab_50970': {'default': 3.5, 'range': (2.5, 4.5)},                # Phosphate, Blood (mg/dL)
    'lab_50971': {'default': 4.5, 'range': (3.5, 5.1)},                # Potassium, Blood (mmol/L)
    'lab_50975': {'default': 5.5, 'range': (4.5, 6.5)},                # Protein Electrophoresis, Blood (g/dL)
    'lab_50983': {'default': 140, 'range': (135, 145)},                # Sodium, Blood (mmol/L)
    'lab_50990': {'default': 10, 'range': (5, 20)},                    # Theophylline, Blood (µg/mL)
    'lab_51002': {'default': 0.01, 'range': (0, 0.1)},                 # Troponin I, Blood (ng/mL)
    'lab_51003': {'default': 0.01, 'range': (0, 0.1)},                 # Troponin T, Blood (ng/mL)
    'lab_51006': {'default': 15, 'range': (5, 20)},                    # Urea Nitrogen, Blood (mg/dL)
    'lab_51099': {'default': 0.2, 'range': (0, 0.5)},                  # Protein/Creatinine Ratio, Urine
    'lab_51196': {'default': 0.5, 'range': (0, 1)},                    # D-Dimer, Blood (mg/mL)
    'lab_51214': {'default': 300, 'range': (200, 400)},                # Fibrinogen, Functional, Blood (mg/dL)
    'lab_51222': {'default': 14, 'range': (12, 16)},                   # Hemoglobin, Blood (g/dL)
    'lab_51223': {'default': 3.0, 'range': (2.5, 3.5)},                # Hemoglobin A2, Blood (%)
    'lab_51265': {'default': 200, 'range': (150, 300)},                # Platelet Count, Blood (x10^3/uL)
    'lab_51266': {'default': 1.0, 'range': (0.5, 1.5)},                # Platelet Smear, Blood (%)
    'lab_51274': {'default': 12, 'range': (10, 14)},                   # PT, Blood (seconds)
    'lab_51275': {'default': 35, 'range': (30, 40)},                   # PTT, Blood (seconds)
    'lab_51290': {'default': 0, 'range': (0, 1)},                      # Sickle Cell Preparation, Blood (binary)
    'lab_51291': {'default': 0, 'range': (0, 1)},                      # Sickle Cells, Blood (binary)
    'lab_51292': {'default': 0, 'range': (0, 1)},                      # Spherocytes, Blood (binary)
    'lab_51464': {'default': 0, 'range': (0, 1)},                      # Bilirubin, Urine (binary)
    'lab_51568': {'default': 0.2, 'range': (0, 1)},                    # Bilirubin, Neonatal, Blood (mg/dL)
    'lab_51569': {'default': 0.2, 'range': (0, 1)},                    # Bilirubin, Neonatal, Direct, Blood (mg/dL)
    'lab_51570': {'default': 0.2, 'range': (0, 1)},                    # Bilirubin, Neonatal, Indirect, Blood (mg/dL)
    'lab_51580': {'default': 0.8, 'range': (0.5, 1.5)},                # Calculated CK-MB, Blood (ng/mL)
    'lab_51623': {'default': 300, 'range': (200, 400)},                # Fibrinogen, Blood (mg/dL)
    'lab_51631': {'default': 6.0, 'range': (4.5, 7.0)},                # Glycated Hemoglobin, Blood (%)
    'lab_51638': {'default': 45, 'range': (10, 50)},                   # Hematocrit, Blood (%)
    'lab_51640': {'default': 14, 'range': (12, 16)},                   # Hemoglobin, Blood (g/dL)
    'lab_51643': {'default': 3.0, 'range': (2.5, 3.5)},                # Hemoglobin A2, Blood (%)
    'lab_51647': {'default': 3.0, 'range': (2.5, 3.5)},                # Hemoglobin S, Blood (%)
    'lab_51966': {'default': 0, 'range': (0, 1)},                      # Bilirubin, Urine (binary)
    'lab_52116': {'default': 300, 'range': (200, 400)},                # Fibrinogen, Blood (mg/dL)
    'lab_52117': {'default': 300, 'range': (200, 400)},                # Fibrinogen, Immunologic, Blood (mg/dL)
    'lab_52142': {'default': 10, 'range': (7, 13)},                    # Mean Platelet Volume, Blood (fL)
    'lab_52144': {'default': 1.0, 'range': (0.5, 1.5)},                # Methemoglobin, Blood (%)
    'lab_52546': {'default': 1.0, 'range': (0.5, 1.5)},                # Creatinine, Blood (mg/dL)
    'lab_52551': {'default': 0.5, 'range': (0.2, 1.0)},                # D-Dimer, Blood (mg/mL)
    'lab_52642': {'default': 0.01, 'range': (0, 0.1)},                 # Troponin I, Blood (ng/mL)
    'lab_53085': {'default': 4.3, 'range': (3.5, 5.5)},                # Albumin, Blood (g/dL)
}

print("Normal values and imputation parameters loaded.")

In [None]:
# Generate a unified time grid (5-minute intervals) covering both vitals and labs
def generate_timegrid(df_vitals, df_labs, freq='5min'):
    """
    Build a regular time grid spanning every observation of one stay.

    The grid starts at the earliest ``charttime`` found in either frame (floored to
    ``freq``) and ends at the latest one (ceiled to ``freq``).

    Parameters:
        df_vitals (DataFrame): vitals with a datetime ``charttime`` column (may be empty).
        df_labs (DataFrame): labs with a datetime ``charttime`` column (may be empty).
        freq (str): grid spacing, 5 minutes by default. This uses '5min' rather than the
            '5T' alias, which pandas deprecated in 2.2.

    Returns:
        DatetimeIndex: the grid, or an empty DatetimeIndex when neither frame contains
        a valid timestamp.
    """
    starts, ends = [], []
    for frame in (df_vitals, df_labs):
        if frame.empty:
            continue
        first, last = frame['charttime'].min(), frame['charttime'].max()
        # min() is NaT when every charttime is missing; such a frame has no bounds.
        if pd.notna(first):
            starts.append(first)
            ends.append(last)

    if not starts:
        return pd.DatetimeIndex([])

    t0 = min(starts).floor(freq)
    tmax = max(ends).ceil(freq)
    return pd.date_range(start=t0, end=tmax, freq=freq)


# Returns a vector filled with a default value
def value_filled_array(size, default_val, dtype=None):
    """
    Allocate an array of shape ``size`` and set every element to ``default_val``.

    Parameters:
        size: array shape (int or tuple).
        default_val: value written into every slot (cast to the array dtype).
        dtype: NumPy dtype; any falsy value selects float.
    """
    resolved_dtype = dtype if dtype else float
    filled = np.empty(size, dtype=resolved_dtype)
    filled[:] = default_val
    return filled

# Returns an array filled with NaNs
def empty_nan_array(size):
    """Allocate a float64 array of shape ``size`` whose every entry is NaN."""
    nan_buffer = np.empty(size, dtype=float)
    nan_buffer[:] = np.nan
    return nan_buffer


# Check if a batch has already been processed by the existence of a .done file
def is_batch_processed(batch_idx, output_dir):
    """Return True when the marker ``batch_<idx:04d>.parquet.done`` exists in ``output_dir``."""
    marker_name = f'batch_{batch_idx:04d}.parquet.done'
    marker_path = os.path.join(output_dir, marker_name)
    return os.path.exists(marker_path)


# Simple forward-fill imputation over a fixed time grid
def impute_forward_fill_simple(obs_times, obs_values, time_grid, global_fill=np.nan):
    """
    Carry the most recent non-NaN observation forward onto each grid point.

    Observations are consumed in order while ``obs_times[k] <= grid point``, so
    ``obs_times`` is expected to be sorted ascending. Grid points before the first
    non-NaN observation receive ``global_fill``.

    Parameters:
        obs_times: observation times, comparable with the grid values.
        obs_values: values aligned with ``obs_times`` (NaN entries are skipped).
        time_grid: ascending grid points.
        global_fill: value used where nothing has been observed yet.

    Returns:
        float32 array of length ``len(time_grid)``.
    """
    filled = np.full(len(time_grid), global_fill, dtype=np.float32)
    n_obs = len(obs_times)
    cursor = 0
    carried = np.nan

    for slot, grid_time in enumerate(time_grid):
        # Consume every observation at or before this grid point.
        while cursor < n_obs and obs_times[cursor] <= grid_time:
            candidate = obs_values[cursor]
            if not np.isnan(candidate):
                carried = candidate
            cursor += 1
        filled[slot] = global_fill if np.isnan(carried) else carried

    return filled


# Applies imputation to a batch using the generated time grid
def impute_batch_with_grid(df_batch, variables, impute_defaults, time_grid, time_grid_dt, stay_id_if_empty=None):
    """
    Forward-fill every variable of each stay in ``df_batch`` onto a shared time grid.

    Parameters:
        df_batch (DataFrame): rows with 'stay_id', datetime 'charttime' and one column
            per variable.
        variables (list[str]): columns to impute. A column that is absent, or never
            observed, is filled entirely with its default.
        impute_defaults (dict): per-variable fill value used before the first
            observation (NaN for variables not in the dict).
        time_grid: grid points as int64 epoch seconds, or as datetimes (for example
            the DatetimeIndex returned by generate_timegrid). Datetimes are converted
            to epoch seconds so they can be compared with observation times.
        time_grid_dt: datetime values written to the output 'charttime' column.
        stay_id_if_empty: stay_id used for the output when ``df_batch`` is empty.

    Returns:
        DataFrame with 'stay_id', 'charttime', each variable (float32) and a
        '<var>_imputed' flag that is 0 where an observation maps to that grid point
        and 1 otherwise.

    Raises:
        ValueError: if ``df_batch`` is empty and ``stay_id_if_empty`` is None.
    """
    grid_secs = np.asarray(time_grid)
    if np.issubdtype(grid_secs.dtype, np.datetime64):
        # Observation times below are int64 epoch seconds. Comparing them with a
        # datetime grid would raise TypeError, so bring the grid to the same unit.
        grid_secs = grid_secs.astype('datetime64[s]').astype(np.int64)
    n_grid = len(grid_secs)

    if df_batch.empty:
        if stay_id_if_empty is None:
            raise ValueError("Empty batch and no fallback stay_id provided.")
        result = {
            'stay_id': [stay_id_if_empty] * n_grid,
            'charttime': time_grid_dt
        }
        for var in variables:
            fill_val = impute_defaults.get(var, np.nan)
            result[var] = np.full(n_grid, fill_val, dtype=np.float32)
            result[f'{var}_imputed'] = np.ones(n_grid, dtype=int)
        return pd.DataFrame(result)

    frames = []

    for stay_id, group in df_batch.groupby('stay_id'):
        group = group.sort_values('charttime').reset_index(drop=True)
        timestamps = group['charttime'].values.astype('datetime64[s]').astype(np.int64)

        result = {
            'stay_id': [stay_id] * n_grid,
            'charttime': time_grid_dt
        }

        for var in variables:
            fill_val = impute_defaults.get(var, np.nan)
            if var not in group.columns or group[var].dropna().empty:
                result[var] = np.full(n_grid, fill_val, dtype=np.float32)
                result[f'{var}_imputed'] = np.ones(n_grid, dtype=int)
                continue

            raw_vals = group[var].values
            result[var] = impute_forward_fill_simple(timestamps, raw_vals, grid_secs, fill_val)

            # An observation is attributed to the first grid point at or after its time.
            # Grid points without an attributed observation are flagged as imputed.
            observed = np.zeros(n_grid, dtype=int)
            if n_grid:
                in_range = (timestamps >= grid_secs[0]) & (timestamps <= grid_secs[-1]) & ~np.isnan(raw_vals)
                idxs = np.searchsorted(grid_secs, timestamps[in_range])
                observed[idxs[idxs < n_grid]] = 1
            result[f'{var}_imputed'] = 1 - observed

        frames.append(pd.DataFrame(result))

    return pd.concat(frames, ignore_index=True)


# Query vitals for a batch of stay_ids from MIMIC-IV derived table
def query_vitals_batch(batch_stay_ids, conn):
    """
    Fetch vital signs from mimiciv_derived.vitalsign for a batch of ICU stays.

    Parameters:
        batch_stay_ids (iterable of int): stay_ids to query.
        conn: open psycopg2 connection.

    Returns:
        DataFrame with stay_id, charttime (datetime64), heart_rate, sbp, dbp, mbp,
        resp_rate, temperature, spo2 and glucose, ordered by stay_id then charttime.
        When ``batch_stay_ids`` is empty, no query is sent and an empty frame with the
        same columns is returned.
    """
    columns = ['stay_id', 'charttime', 'heart_rate', 'sbp', 'dbp', 'mbp',
               'resp_rate', 'temperature', 'spo2', 'glucose']
    # psycopg2 cannot adapt numpy integer scalars, so convert to plain ints.
    stay_ids = [int(sid) for sid in batch_stay_ids]
    if not stay_ids:
        # An empty list would render as `IN ()`, which PostgreSQL rejects.
        empty = pd.DataFrame(columns=columns)
        empty['charttime'] = pd.to_datetime(empty['charttime'])
        return empty

    # The ids are sent as one bound array parameter instead of being spliced into the SQL.
    query = """
    SELECT
        vs.stay_id,
        vs.charttime,
        vs.heart_rate,
        vs.sbp,
        vs.dbp,
        vs.mbp,
        vs.resp_rate,
        vs.temperature,
        vs.spo2,
        vs.glucose
    FROM mimiciv_derived.vitalsign vs
    WHERE vs.stay_id = ANY(%(stay_ids)s)
    ORDER BY vs.stay_id, vs.charttime;
    """
    print(f"Querying vitals for {len(stay_ids)} stay_ids...")
    df = pd.read_sql(query, conn, params={'stay_ids': stay_ids})
    df['charttime'] = pd.to_datetime(df['charttime'])
    print(f"Query completed. Retrieved {len(df)} rows.")
    return df

In [None]:
# Extracts and pivots laboratory test results for a batch of ICU stays.
# Lab values are reshaped into a wide format: one row per (stay_id, charttime) with columns named lab_<itemid>.

# Map each lab variable to its default value (used as the imputation fill value later)
default_lab_values = dict(
    (lab_name, lab_params['default'])
    for lab_name, lab_params in PARAMETROS_IMPUTACAO.items()
)

# labevents itemids to extract. The order defines the order of the lab_<itemid>
# columns built from this list.
lab_itemids = [
    50861, 50862, 53085, 50908, 51580, 50883, 50884, 50885,
    50910, 50924, 50963, 52642, 51002, 51003, 50889, 52116,
    51623, 50928, 52117, 51214, 50878, 50855, 50912, 52546,
    53161, 53180, 52142, 51265, 51266, 52144, 50931, 50935,
    51631, 51638, 51640, 51222, 51223, 50856, 51647, 50852,
    51643, 50971, 50983, 50990, 50967, 50968, 50969, 50960,
    50966, 50970, 50975, 51099, 51006, 51274, 51275, 51292,
    51290, 51291, 51196, 52551, 50915, 51568, 51569, 51570,
    51464, 51966, 50803, 50805, 50808, 50809, 50813,
]

def query_labs_batch(batch_stay_ids, connection=None):
    """
    Fetch lab results for a batch of ICU stays and pivot them to wide format.

    Only stays present in the session temp table ``all_icu_stays`` (created by the
    database cell; the query previously referenced a nonexistent ``todas_utis``) and
    itemids listed in ``lab_itemids`` are returned.

    Parameters:
        batch_stay_ids (iterable of int): stay_ids to query.
        connection: psycopg2 connection; defaults to the notebook-level ``conn``.

    Returns:
        DataFrame with one row per (stay_id, charttime) and one ``lab_<itemid>`` column
        per observed item, sorted by stay_id and charttime. Several values of one item
        at the same timestamp are averaged (the pivot_table default). When there is
        nothing to return, the result is an empty frame with 'stay_id' and 'charttime'
        columns.
    """
    empty_result = pd.DataFrame({
        'stay_id': pd.Series(dtype='int64'),
        'charttime': pd.Series(dtype='datetime64[ns]'),
    })
    # psycopg2 cannot adapt numpy integer scalars, so convert to plain ints.
    stay_ids = [int(sid) for sid in batch_stay_ids]
    if not stay_ids:
        return empty_result

    db = conn if connection is None else connection

    # NOTE(review): labs are matched on subject_id + hadm_id with no time window, so
    # every lab of the hospital admission (including before intime/after outtime) is
    # attributed to each ICU stay of that admission. Confirm this is intended.
    query = """
    SELECT
        icu.stay_id,
        le.charttime,
        le.itemid,
        le.valuenum
    FROM mimiciv_hosp.labevents le
    JOIN mimiciv_icu.icustays icu 
        ON le.subject_id = icu.subject_id AND le.hadm_id = icu.hadm_id
    WHERE le.itemid = ANY(%(itemids)s)
      AND le.valuenum IS NOT NULL
      AND icu.stay_id IN (SELECT stay_id FROM all_icu_stays)
      AND icu.stay_id = ANY(%(stay_ids)s)
    ORDER BY icu.stay_id, le.charttime;
    """

    print(f"Querying lab results for {len(stay_ids)} stay_ids...")
    df_labs = pd.read_sql(
        query, db,
        params={'itemids': [int(i) for i in lab_itemids], 'stay_ids': stay_ids},
    )
    df_labs['charttime'] = pd.to_datetime(df_labs['charttime'])
    print(f"Query completed. Retrieved {len(df_labs)} rows.")

    if df_labs.empty:
        return empty_result

    # Pivot to wide format: one row per stay_id x charttime
    df_wide = df_labs.pivot_table(index=['stay_id', 'charttime'], columns='itemid', values='valuenum')
    df_wide.columns = [f'lab_{col}' for col in df_wide.columns]
    df_wide = df_wide.reset_index().sort_values(['stay_id', 'charttime']).reset_index(drop=True)

    return df_wide

In [None]:
# Destination of the merged, imputed vitals+labs batches
OUTPUT_DIR_FINAL = 'output/final_batch'
Path(OUTPUT_DIR_FINAL).mkdir(parents=True, exist_ok=True)


def is_batch_processed(batch_idx, output_dir):
    """
    Report whether ``batch_idx`` already has its ``.parquet.done`` marker in ``output_dir``.

    NOTE(review): this repeats an identical helper defined in an earlier cell; this later
    definition is the one in effect from here on.
    """
    flag = os.path.join(output_dir, 'batch_{:04d}.parquet.done'.format(batch_idx))
    return os.path.exists(flag)


def save_batch(df, batch_idx, output_dir):
    """
    Write one processed batch to ``batch_<idx:04d>.parquet``, then create its ``.done`` marker.

    The marker is created only after ``to_parquet`` returns. A crash during the write
    therefore leaves the batch unmarked, and it is reprocessed on the next run.
    """
    stem = f'batch_{batch_idx:04d}.parquet'
    df.to_parquet(os.path.join(output_dir, stem))
    with open(os.path.join(output_dir, stem + '.done'), 'w') as marker:
        marker.write('done')
    print(f"Batch {batch_idx} saved.")


def extract_stays_from_csv(path):
    """Return the distinct 'stay_id' values of a CSV file, in first-seen order, as a Python list."""
    stay_column = pd.read_csv(path)['stay_id']
    distinct_ids = stay_column.unique()
    return distinct_ids.tolist()


# Load pre-extracted stay_ids from vitals and labs.
# NOTE(review): these CSV files are not written by any cell visible here; confirm where they come from.
stays_labs = extract_stays_from_csv('output/stays_labs.csv')
stays_vitals = extract_stays_from_csv('output/stays_vitals.csv')

# Keep only stays present in both sources, in ascending stay_id order
shared_stays = set(stays_labs).intersection(stays_vitals)
stay_ids_common = sorted(shared_stays)

# Define batches of fixed size
BATCH_SIZE = 5000
batch_starts = range(0, len(stay_ids_common), BATCH_SIZE)
batches = [stay_ids_common[lo:lo + BATCH_SIZE] for lo in batch_starts]

print(f"Total common stays: {len(stay_ids_common)}")
print(f"Total batches: {len(batches)}")

# Variable names for vitals and labs (lab columns are named lab_<itemid>)
vital_vars = ['heart_rate', 'sbp', 'dbp', 'mbp', 'resp_rate', 'temperature', 'spo2', 'glucose']
lab_vars = [f'lab_{itemid}' for itemid in lab_itemids]


# Main batch loop: for each batch of stays, fetch vitals and labs, place every stay on
# its own 5-minute grid, impute, and save the merged result once per batch.
for batch_idx, stay_ids in enumerate(batches):
    if is_batch_processed(batch_idx, OUTPUT_DIR_FINAL):
        print(f"Skipping batch {batch_idx} (already processed).")
        continue

    print(f"\nProcessing batch {batch_idx} with {len(stay_ids)} stays...")

    df_vitals = query_vitals_batch(stay_ids, conn)
    df_labs = query_labs_batch(stay_ids)

    # Split each frame by stay once (one pass over the rows) instead of re-scanning the
    # whole frame for every stay.
    vitals_by_stay = {sid: rows for sid, rows in df_vitals.groupby('stay_id')}
    labs_by_stay = {sid: rows for sid, rows in df_labs.groupby('stay_id')}
    no_vitals = df_vitals.iloc[0:0]
    no_labs = df_labs.iloc[0:0]

    merged_data = []

    for sid in stay_ids:
        df_v = vitals_by_stay.get(sid, no_vitals)
        df_l = labs_by_stay.get(sid, no_labs)

        if df_v.empty and df_l.empty:
            continue

        timegrid = generate_timegrid(df_v, df_l)
        if timegrid.empty:
            continue

        # impute_batch_with_grid compares the grid with observation times in int64
        # epoch seconds, so pass the grid in that unit and keep datetimes for the output.
        timegrid_secs = timegrid.values.astype('datetime64[s]').astype(np.int64)
        timegrid_dt = timegrid

        df_v_filled = impute_batch_with_grid(df_v, vital_vars, NORMAL_VALUES, timegrid_secs, timegrid_dt, sid)
        df_l_filled = impute_batch_with_grid(df_l, lab_vars, default_lab_values, timegrid_secs, timegrid_dt, sid)

        df_merged = pd.merge(df_v_filled, df_l_filled, on=['stay_id', 'charttime'], how='outer')
        merged_data.append(df_merged)

    if merged_data:
        df_batch = pd.concat(merged_data, ignore_index=True)
        save_batch(df_batch, batch_idx, OUTPUT_DIR_FINAL)

        # Save preview CSV
        if len(df_batch) >= 1000:
            csv_path = os.path.join(OUTPUT_DIR_FINAL, f'batch_{batch_idx:04d}_sample.csv')
            df_batch.head(1000).to_csv(csv_path, index=False)
    else:
        print(f"No data in batch {batch_idx}.")

In [None]:
# Input: merged/imputed batches. Output: labeled batches ("falencia" = failure).
OUTPUT_DIR_MERGED = 'output/final_batch'
OUTPUT_DIR_FAILURE = 'output/falencia_batches'
Path(OUTPUT_DIR_FAILURE).mkdir(parents=True, exist_ok=True)

# Vasoactive-agent columns of mimiciv_derived.vasoactive_agent used by the circulatory
# failure definition
vasopressor_cols = 'dopamine epinephrine norepinephrine phenylephrine vasopressin dobutamine milrinone'.split()

def process_batch_if_needed(batch_idx):
    """
    Label one merged batch unless its ``.done`` marker already exists.

    Errors are reported and swallowed so the driver loop continues with the next batch.
    The marker is written only if ``batch_<idx>.parquet`` exists in OUTPUT_DIR_FAILURE
    after ``process_batch`` returns. process_batch returns early, without raising, when
    its input is missing or empty; marking such a batch as done would skip it forever.
    """
    stem = f'batch_{batch_idx:04d}.parquet'
    done_flag = os.path.join(OUTPUT_DIR_FAILURE, stem + '.done')
    output_path = os.path.join(OUTPUT_DIR_FAILURE, stem)

    if os.path.exists(done_flag):
        print(f"[SKIP] Batch {batch_idx} already processed.")
        return
    try:
        process_batch(batch_idx)
    except Exception as e:
        print(f"[ERROR] Failed to process batch {batch_idx}: {e}")
        return

    if os.path.exists(output_path):
        open(done_flag, 'a').close()
    else:
        print(f"[WARN] Batch {batch_idx} produced no output; not marking it as done.")


def process_batch(batch_idx):
    """
    Load one merged batch, label circulatory failure on every row, and save it.

    Failure rule: circulatory failure ('falencia' = 1) requires lactate (lab_50813)
    >= 2 mmol/L AND either MBP < 65 mmHg or an active vasopressor, i.e.
    ``lactate_high & (mbp_low | vasopressor)``. The previous code evaluated
    ``mbp_low | (vasopressor & lactate_high)``, which flagged every low-MBP row
    regardless of lactate.
    NOTE(review): the circEWS definition (Hyland et al., 2020) uses MAP <= 65; this
    keeps the '< 65' stated in the original docstring. Confirm which is intended.

    Outputs in OUTPUT_DIR_FAILURE:
        batch_<idx>.parquet     the labeled batch, with added 0/1 columns
                                'vasopressor_ativo' and 'falencia'. The flag uses the
                                column name that the feature and label cells read.
        batch_<idx>_sample.csv  up to 5000 leading rows plus up to 5000 positive rows.

    Returns None. Returns early, after printing, when the input file is missing or
    has no stays. Database and I/O errors propagate to the caller.
    """
    start_time = time.time()
    print(f"\n[INFO] Processing batch {batch_idx}...")

    path = os.path.join(OUTPUT_DIR_MERGED, f'batch_{batch_idx:04d}.parquet')
    if not os.path.exists(path):
        print(f"[ERROR] Batch {batch_idx} not found.")
        return

    # Load merged data with vitals and labs
    df = pd.read_parquet(path)
    df['charttime'] = pd.to_datetime(df['charttime'])
    df['stay_id'] = df['stay_id'].astype(int)

    stays = df['stay_id'].unique()
    if len(stays) == 0:
        # Also avoids sending a query with an empty stay list.
        print(f"[WARN] No data found for batch {batch_idx}.")
        return

    # Vasopressor administration intervals, split by stay once.
    df_vaso = _load_vasopressor_intervals(stays)
    vaso_by_stay = {sid: rows for sid, rows in df_vaso.groupby('stay_id')}

    # groupby(sort=False) keeps stays in first-appearance order, like iterating unique().
    results = [
        _label_stay(df_sid.copy(), vaso_by_stay.get(sid))
        for sid, df_sid in df.groupby('stay_id', sort=False)
    ]
    df_final = pd.concat(results, ignore_index=True)
    del results

    # Save final labeled batch as Parquet
    out_parquet = os.path.join(OUTPUT_DIR_FAILURE, f'batch_{batch_idx:04d}.parquet')
    df_final.to_parquet(out_parquet)

    # Save sample for inspection (merged general + positive class)
    sample_all = df_final.head(5000)
    sample_failure = df_final[df_final['falencia'] == 1].head(5000)
    sample_combined = pd.concat([sample_all, sample_failure]).drop_duplicates()
    sample_path = os.path.join(OUTPUT_DIR_FAILURE, f'batch_{batch_idx:04d}_sample.csv')
    sample_combined.to_csv(sample_path, index=False)

    print(f"[OK] Batch {batch_idx} saved with {len(df_final)} rows.")
    print(f"[SAMPLE] Saved to {sample_path}")
    print(f"[TIME] Elapsed: {time.time() - start_time:.2f} seconds")
    gc.collect()


def _load_vasopressor_intervals(stay_ids):
    """Fetch vasoactive-agent intervals (start/end plus per-agent columns) for the given stays."""
    query = f"""
    SELECT stay_id, starttime, endtime, {', '.join(vasopressor_cols)}
    FROM mimiciv_derived.vasoactive_agent
    WHERE stay_id = ANY(%(stay_ids)s);
    """
    # psycopg2 cannot adapt numpy integers, so the ids are converted to plain ints.
    df_vaso = pd.read_sql(query, conn, params={'stay_ids': [int(s) for s in stay_ids]})
    df_vaso['starttime'] = pd.to_datetime(df_vaso['starttime'])
    df_vaso['endtime'] = pd.to_datetime(df_vaso['endtime'])
    df_vaso['stay_id'] = df_vaso['stay_id'].astype(int)
    return df_vaso


def _label_stay(df_sid, df_vsid):
    """Add the 'vasopressor_ativo' and 'falencia' 0/1 flags to one stay's rows (in place) and return them."""
    df_sid['vasopressor_ativo'] = 0
    df_sid['falencia'] = 0  # circulatory failure

    # Mark rows inside any interval in which at least one agent column is populated.
    if df_vsid is not None:
        active = df_vsid[df_vsid[vasopressor_cols].notna().any(axis=1)]
        for start, end in zip(active['starttime'], active['endtime']):
            in_interval = (df_sid['charttime'] >= start) & (df_sid['charttime'] <= end)
            df_sid.loc[in_interval, 'vasopressor_ativo'] = 1

    mbp_low = df_sid['mbp'] < 65
    # If the lactate column is absent, treat lactate as not elevated. The fallback Series
    # shares this stay's index, so the boolean masks stay aligned.
    if 'lab_50813' in df_sid.columns:
        lactate = df_sid['lab_50813']
    else:
        lactate = pd.Series(np.nan, index=df_sid.index)
    lactate_high = lactate >= 2
    vaso_on = df_sid['vasopressor_ativo'] == 1

    df_sid.loc[lactate_high & (mbp_low | vaso_on), 'falencia'] = 1
    return df_sid


# Run sequentially over all merged batches (files named like batch_0000.parquet)
merged_files = sorted(
    fname for fname in os.listdir(OUTPUT_DIR_MERGED)
    if fname.startswith('batch_') and fname.endswith('.parquet')
)
batches = [int(fname.split('_')[1].split('.')[0]) for fname in merged_files]
print(f"[INFO] {len(batches)} batches found in {OUTPUT_DIR_MERGED}")

for batch in batches:
    process_batch_if_needed(batch)

In [None]:
# Input: labeled batches. Output: per-stay sub-batches.
INPUT_DIR = 'output/falencia_batches'
OUTPUT_DIR = 'output/falencia_batches_divided'
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

def split_batch_by_stay(batch_idx, stays_per_subbatch=1000):
    """
    Split a large batch (parquet file) into smaller chunks based on stay_id.

    All rows of one stay end up in the same sub-batch. The stays are divided into
    ceil(n_stays / stays_per_subbatch) nearly equal parts. No part exceeds
    ``stays_per_subbatch``, and no extra part is created when n_stays is an exact
    multiple; the previous ``n // k + 1`` created one extra part in that case.
    Sub-batch files left from an earlier split of the same batch are deleted first,
    so a re-split cannot leave stale files with overlapping stays.

    Parameters:
        batch_idx (int): Index of the batch file to process.
        stays_per_subbatch (int): Max number of unique stay_ids per subbatch (>= 1).

    Raises:
        ValueError: if ``stays_per_subbatch`` is smaller than 1.
    """
    if stays_per_subbatch < 1:
        raise ValueError(f"stays_per_subbatch must be >= 1, got {stays_per_subbatch}")

    input_path = os.path.join(INPUT_DIR, f'batch_{batch_idx:04d}.parquet')
    print(f"[LOAD] Reading: {input_path}")

    df = pd.read_parquet(input_path)

    stay_ids = df['stay_id'].unique()
    if len(stay_ids) == 0:
        # An empty sub-batch would later crash feature extraction (pd.concat of nothing).
        print(f"[WARN] Batch {batch_idx} has no stays; nothing to split.")
        return

    # Remove sub-batches from a previous split of this batch.
    stale_prefix = f'batch_{batch_idx:04d}_sub'
    for fname in os.listdir(OUTPUT_DIR):
        if fname.startswith(stale_prefix) and fname.endswith('.parquet'):
            os.remove(os.path.join(OUTPUT_DIR, fname))

    n_parts = -(-len(stay_ids) // stays_per_subbatch)  # ceiling division
    parts = np.array_split(stay_ids, n_parts)

    # Save each sub-batch to a separate parquet file
    for i, ids in enumerate(parts):
        df_sub = df[df['stay_id'].isin(ids)].copy()
        output_path = os.path.join(OUTPUT_DIR, f'batch_{batch_idx:04d}_sub{i:02d}.parquet')
        df_sub.to_parquet(output_path)
        print(f"[OK] Sub-batch {i:02d} saved with {len(df_sub)} rows and {len(ids)} stay_ids.")

# Find all labeled batches. '_amostra' (Portuguese for "sample") files are excluded;
# sample files are written as CSV, so the '.parquet' filter already skips them.
batch_files = sorted(
    fname for fname in os.listdir(INPUT_DIR)
    if fname.startswith('batch_') and fname.endswith('.parquet') and '_amostra' not in fname
)
batches = [int(fname.split('_')[1].split('.')[0]) for fname in batch_files]

# Split each labeled batch into sub-batches of at most 1000 stays
for batch_idx in batches:
    split_batch_by_stay(batch_idx, stays_per_subbatch=1000)

In [None]:
# Input and output directories.
# INPUT_DIR must be the directory the sub-batch splitting step writes to
# ('output/falencia_batches_divided'). The previous 'output/falencia_batches_dividido'
# is not the split step's output, so listing it would fail or find nothing.
INPUT_DIR = 'output/falencia_batches_divided'
OUTPUT_DIR = 'output/features_batches'
os.makedirs(OUTPUT_DIR, exist_ok=True)

def extract_features(df, batch_idx):
    """
    Append running (causal) statistical features for each stay and save the batch.

    Features are built for every column except identifiers, label/flag columns and
    ``*_imputed`` masks. Each is computed per stay in charttime order from the current
    and previous rows only:
        n_meas_<var>   cumulative count of non-missing values
        min_<var>      running minimum (pandas cummin: NaN on rows whose value is NaN)
        max_<var>      running maximum (pandas cummax: NaN on rows whose value is NaN)
        mean_<var>     running mean of the non-missing values (NaN before the first one)
        <var>_instab   rolling standard deviation over the last 12 rows
        <var>_intens   absolute difference from the previous row
        <var>_cumul    running sum, counting missing values as 0

    Writes ``<batch_idx>.parquet`` and a 10k-row ``<batch_idx>_checkpoint.csv`` to
    OUTPUT_DIR. An empty input is reported and nothing is written.

    Parameters:
        df (DataFrame): Input time-series data with 'stay_id' and 'charttime'.
        batch_idx (str): Identifier for the batch being processed (used for saving).
    """
    # Identifier and label/flag columns never become features. Both spellings of the
    # vasopressor flag are listed because the labeling step has written either one, and
    # vasopressor status is part of the failure label.
    excluded = {'stay_id', 'charttime', 'falencia', 'vasopressor_ativo', 'vasopressor_active'}

    df = df.sort_values(['stay_id', 'charttime']).copy()

    # Select only relevant variables (excluding identifiers, flags, and imputation columns)
    variables = [
        col for col in df.columns
        if col not in excluded and not col.endswith('_imputed')
    ]

    print(f"Processing {len(variables)} variables for feature extraction.")

    if df.empty:
        print(f"Batch {batch_idx} is empty; nothing to save.")
        return

    results = []

    for sid, group in tqdm(df.groupby('stay_id'), desc=f"Batch {batch_idx}"):
        group = group.sort_values('charttime').reset_index(drop=True)
        feature_list = []

        for var in variables:
            try:
                feature_list.append(_running_features(group[var], var))
            except Exception as e:
                print(f"Error processing variable {var} for stay_id {sid}: {e}")

        if feature_list:
            # Feature frames share the group's fresh RangeIndex, so they align row by row.
            group = pd.concat([group] + feature_list, axis=1)

        results.append(group)

    df_final = pd.concat(results, ignore_index=True)
    del results

    # Save full feature-enhanced batch
    output_path = os.path.join(OUTPUT_DIR, f'{batch_idx}.parquet')
    df_final.to_parquet(output_path)

    # Save a 10k-row checkpoint as CSV for quick inspection
    checkpoint_path = os.path.join(OUTPUT_DIR, f'{batch_idx}_checkpoint.csv')
    df_final.head(10000).to_csv(checkpoint_path, index=False)

    print(f"Saved full batch: {output_path}")
    print(f"Saved 10k-row checkpoint: {checkpoint_path}")
    del df_final
    gc.collect()


def _running_features(series, var):
    """Return the seven causal running statistics of one variable within one stay."""
    n_meas = series.notna().astype(int).cumsum()
    cumul = series.fillna(0).cumsum()
    return pd.DataFrame({
        f'n_meas_{var}': n_meas,
        f'min_{var}': series.cummin(),
        f'max_{var}': series.cummax(),
        # Running sum of observed values divided by the running count. Summing after
        # fillna(0) keeps the mean defined on rows whose own value is missing; the
        # previous cumsum().fillna(0) set the mean to 0 on those rows.
        f'mean_{var}': cumul / n_meas.replace(0, np.nan),
        # 12 rows span one hour on the 5-minute grid built upstream.
        f'{var}_instab': series.rolling(window=12, min_periods=1).std(),
        f'{var}_intens': series.diff().abs(),
        f'{var}_cumul': cumul,
    })

def process_feature_batch(filename):
    """
    Load ``INPUT_DIR/<filename>`` and run extract_features on it.

    The batch identifier passed on is the file name without directory or extension
    (e.g. 'batch_0000_sub00'). If the file does not exist, a message is printed and
    nothing else happens.

    Parameters:
        filename (str): File name of the .parquet batch to process.
    """
    source = os.path.join(INPUT_DIR, filename)
    if os.path.exists(source):
        print(f"\nProcessing feature extraction for: {filename}")
        frame = pd.read_parquet(source)
        print(f"File loaded with {len(frame)} rows.")

        extract_features(frame, os.path.splitext(os.path.basename(filename))[0])

        del frame
        gc.collect()
    else:
        print(f"File not found: {filename}")

# Entry point: process every sub-batch file in the input directory.
# (In a Jupyter kernel __name__ is normally '__main__', so this runs with the cell.)
if __name__ == '__main__':
    batch_files = sorted(
        entry for entry in os.listdir(INPUT_DIR)
        if entry.startswith('batch_') and entry.endswith('.parquet')
    )
    print(f"{len(batch_files)} batch files found for feature extraction.")

    for batch_file in batch_files:
        process_feature_batch(batch_file)

In [None]:
# Directories for input (raw feature batches) and output (filtered)
INPUT_DIR = 'output/features_batches'
OUTPUT_DIR = 'output/features_filtered'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Columns whose global fraction of missing values exceeds this threshold are dropped.
MAX_MISSING_FRACTION = 0.5
# Identifier columns are always kept, however sparse.
KEEP_ALWAYS = ['stay_id', 'charttime']

# Step 1: Calculate global missing percentage per column
print("Calculating global missing percentage...")

null_counts = {}    # column -> NaN cells across the files that contain the column
present_rows = {}   # column -> total rows of the files that contain the column
total_rows = 0

# List all input parquet files
files = sorted([
    f for f in os.listdir(INPUT_DIR)
    if f.endswith('.parquet') and f.startswith('batch_')
])

for file in files:
    df = pd.read_parquet(os.path.join(INPUT_DIR, file))
    total_rows += len(df)
    for col in df.columns:
        null_counts[col] = null_counts.get(col, 0) + int(df[col].isna().sum())
        present_rows[col] = present_rows.get(col, 0) + len(df)
    del df

# A column absent from a file is missing for every row of that file. Those rows are
# added to its missing count so sparse columns are not under-counted.
if total_rows == 0:
    print("No rows found; no columns will be removed.")
    missing_percent = {}
else:
    missing_percent = {
        col: (null_counts[col] + (total_rows - present_rows[col])) / total_rows
        for col in null_counts
    }

cols_to_drop = [
    col for col, pct in missing_percent.items()
    if pct > MAX_MISSING_FRACTION and col not in KEEP_ALWAYS
]

print(f"{len(cols_to_drop)} columns will be removed (over {MAX_MISSING_FRACTION:.0%} missing):")
for col in cols_to_drop:
    print(f"  - {col}")

# Step 2: Filter and save each file without the critical columns
print("\nSaving filtered files...")

for file in files:
    df = pd.read_parquet(os.path.join(INPUT_DIR, file))
    df_filtered = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
    df_filtered.to_parquet(os.path.join(OUTPUT_DIR, file))
    print(f"{file} saved (filtered).")
    del df, df_filtered

print("\nFiltering completed successfully.")


In [None]:
FEATURE_DIR = Path("output/features_filtered")

# Feature files match batch_*.parquet; label files (stem containing "_labels") are skipped.
parts = sorted(
    path for path in FEATURE_DIR.glob("batch_*.parquet")
    if "_labels" not in path.stem
)

def generate_labels_for_part(part_path, label_cols=('falencia', 'vasopressor_ativo')):
    """
    Extract the label columns of one feature batch and save them next to it.

    Only ``stay_id``, ``charttime`` and ``label_cols`` are read from the parquet
    file, so the (much wider) feature columns are never loaded.

    Parameters
    ----------
    part_path : pathlib.Path
        Feature batch file ``<stem>.parquet``.
    label_cols : sequence of str
        Label columns to copy into the label file.

    Returns
    -------
    pathlib.Path
        Path of the written ``<stem>_labels.parquet``. It holds ``stay_id``,
        ``rel_charttime`` (seconds since the stay's earliest ``charttime`` in
        this batch) and ``label_cols``, in the same row order as the batch.
    """
    label_cols = list(label_cols)
    df = pd.read_parquet(part_path, columns=['stay_id', 'charttime', *label_cols])

    # Ensure charttime is datetime and calculate relative time in seconds
    df['charttime'] = pd.to_datetime(df['charttime'])
    first_charttime = df.groupby('stay_id')['charttime'].transform('min')
    df['rel_charttime'] = (df['charttime'] - first_charttime).dt.total_seconds()

    # Extract columns for label set (row order is kept so labels stay aligned
    # with the feature rows of the same batch)
    df_label = df[['stay_id', 'rel_charttime', *label_cols]]

    # Save label file alongside original
    label_path = part_path.parent / f"{part_path.stem}_labels.parquet"
    df_label.to_parquet(label_path, index=False)
    return label_path

# Write one <stem>_labels.parquet per feature batch and keep the written paths
label_paths = list(map(generate_labels_for_part, parts))

In [None]:
FEATURE_DIR = Path("output/features_filtered")

# List all feature files (label files share the batch_ prefix, so skip them)
parts = sorted([
    f for f in FEATURE_DIR.glob("batch_*.parquet")
    if "_labels" not in f.stem
])
if not parts:
    raise FileNotFoundError(f"No feature batches found in {FEATURE_DIR}")

# Pair every feature file with its label file by name instead of sorting a
# second glob: "batch_1.parquet" sorts before "batch_10.parquet", but
# "batch_10_labels.parquet" sorts before "batch_1_labels.parquet", so two
# independently sorted lists can be misaligned when to_ml() zips them.
labels = [p.with_name(f"{p.stem}_labels.parquet") for p in parts]
missing_labels = [lp.name for lp in labels if not lp.exists()]
if missing_labels:
    raise FileNotFoundError(
        f"Missing label files (run the label-generation cell first): {missing_labels}"
    )

# Load one file as reference to inspect columns
example_df = pd.read_parquet(parts[0])

# Define columns to exclude from modeling
exclude_cols = ['stay_id', 'charttime', 'falencia', 'vasopressor_ativo']
exclude_cols += [col for col in example_df.columns if col.endswith('_imputed')]

# Define final set of feature columns
output_cols = [col for col in example_df.columns if col not in exclude_cols]
print(f"Total selected feature columns: {len(output_cols)}")
print(output_cols)

# Save selected feature column names
pd.Series(output_cols).to_csv("output/features_names.csv", index=False)


In [None]:
def _get_splits(stay_ids, split_path, random_seed):
    """
    Map each split name to the array of stay_ids assigned to it.

    With ``split_path`` the assignment is read from a tab-separated file with
    ``stay_id`` and ``split`` columns. Otherwise the unique ``stay_ids`` are
    split at stay level: 15% test, then 17.65% of the remainder (~15% of the
    total) as validation, leaving ~70% for training.
    """
    if split_path:
        split_df = pd.read_csv(split_path, sep='\t')
        return {
            split: split_df.loc[split_df['split'] == split, 'stay_id'].values
            for split in split_df['split'].unique()
        }
    all_ids = np.unique(stay_ids)
    train_val, test = train_test_split(all_ids, test_size=0.15, random_state=random_seed)
    train, val = train_test_split(train_val, test_size=0.1765, random_state=random_seed)
    return {'train': train, 'val': val, 'test': test}


def _patient_windows(stay_ids, offset=0):
    """
    Return an (n_stays, 3) array of ``[start, stop, stay_id]`` rows, one per stay.

    ``start``/``stop`` delimit the stay's rows and are shifted by ``offset`` so
    they index the split's concatenated HDF5 array. Raises ValueError when a
    stay's rows are not contiguous, since one window could not describe them.
    """
    values = np.asarray(stay_ids)
    if len(values) == 0:
        return np.empty((0, 3), dtype=np.int64)
    starts = np.sort(np.unique(values, return_index=True)[1])
    n_runs = 1 + np.count_nonzero(values[1:] != values[:-1])
    if len(starts) != n_runs:
        raise ValueError("Rows of each stay must be contiguous to build patient windows.")
    stops = np.append(starts[1:], len(values))
    return np.stack([starts + offset, stops + offset, values[starts]], axis=1)


def _normalization_stats(parts, output_cols, stay_ids=None):
    """
    Per-column mean and population std (ddof=0) of ``output_cols`` across ``parts``.

    NaN and +/-inf are skipped column by column (instead of dropping every row
    with any missing feature), and batches are merged with Chan et al.'s
    pairwise update so only one batch is in memory at a time. When
    ``stay_ids`` is given, only rows of those stays contribute. Columns never
    observed get mean 0; a zero or undefined std is replaced by 1 so the
    normalization never divides by zero.
    """
    n = pd.Series(0.0, index=output_cols)
    mean = pd.Series(0.0, index=output_cols)
    m2 = pd.Series(0.0, index=output_cols)

    for part_path in parts:
        df = pd.read_parquet(part_path)
        if stay_ids is not None:
            df = df[df['stay_id'].isin(stay_ids)]
        # reindex: columns absent from this batch become all-NaN (no contribution)
        x = (
            df.reindex(columns=output_cols)
            .astype(np.float64)
            .replace([np.inf, -np.inf], np.nan)
        )
        n_b = x.count().astype(np.float64)
        mean_b = x.mean().fillna(0.0)
        m2_b = ((x - mean_b) ** 2).sum()

        n_new = n + n_b
        denom = n_new.where(n_new > 0, 1.0)  # columns still unobserved: avoid 0/0
        delta = mean_b - mean
        mean = mean + delta * n_b / denom
        m2 = m2 + m2_b + delta ** 2 * n * n_b / denom
        n = n_new

    std = np.sqrt(m2 / n.where(n > 0, 1.0))
    std = std.where(std > 0, 1.0)
    return mean, std


def _h5_group(h5file, name, title):
    """Return the root-level group ``/name`` of ``h5file``, creating it on first use."""
    path = f"/{name}"
    if path in h5file:
        return h5file.get_node(path)
    return h5file.create_group("/", name, title)


def _h5_append(h5file, group, split, array, chunk_size=100_000):
    """Append the rows of 2-D ``array`` to the extendable array ``group/split``."""
    if split in group:
        # File.get_node(where, name): Group objects have no get_node method.
        earray = h5file.get_node(group, split)
    else:
        earray = h5file.create_earray(
            group, split,
            atom=tables.Atom.from_dtype(array.dtype),
            shape=(0, array.shape[1]),
            expectedrows=10**7,
        )
    for start in range(0, len(array), chunk_size):
        earray.append(array[start:start + chunk_size])


def to_ml(
    save_path,
    parts,
    labels,
    endpoint_names,
    output_cols,
    fill_string='ffill',
    split_path=None,
    random_seed=42
):
    """
    Write normalized time-series features, labels and patient windows to HDF5.

    Layout of ``save_path`` (one extendable array per split name):
      /data/<split>             float32 features, columns ordered as ``output_cols``
      /labels/<split>           float32 labels, columns ordered as ``endpoint_names``
      /labels/tasks             endpoint names (utf-8 bytes)
      /patient_windows/<split>  int32 ``[start, stop, stay_id]`` row range per stay

    Parameters
    ----------
    save_path : str or Path
        Output HDF5 file (overwritten).
    parts, labels : sequence of Path
        Feature batches and their label files, paired positionally. Each label
        file must have the same rows, in the same order, as its batch.
    endpoint_names : sequence of str
        Label columns to export; NaN becomes 0 and values are clipped to [0, 1].
    output_cols : sequence of str
        Feature columns. Columns absent from a batch are written as 0.0.
    fill_string : str
        'ffill' forward-fills normalized features within each stay before the
        remaining gaps are set to 0.0 (the normalization mean); any other value
        only zero-fills.
    split_path : str or Path, optional
        TSV with ``stay_id`` and ``split`` columns. If omitted, a seeded random
        stay-level ~70/15/15 train/val/test split is made over all batches.
    random_seed : int
        Seed for the random split.

    Raises
    ------
    ValueError
        No batches, unequal ``parts``/``labels`` lengths, a label file not
        row-aligned with its batch, or non-contiguous stay rows.
    OverflowError
        Window offsets or stay_ids that do not fit in int32.
    """
    parts, labels = list(parts), list(labels)
    endpoint_names, output_cols = list(endpoint_names), list(output_cols)
    if not parts:
        raise ValueError("No feature batches were given.")
    if len(parts) != len(labels):
        raise ValueError(f"{len(parts)} feature batches but {len(labels)} label files.")

    # ---------------------------------------------------------
    # Split setup: use the stays of *every* batch, not only the first one
    # ---------------------------------------------------------
    print("🔍 Generating data splits...")
    all_stay_ids = np.unique(np.concatenate([
        pd.read_parquet(part_path, columns=['stay_id'])['stay_id'].to_numpy()
        for part_path in parts
    ]))
    split_ids = _get_splits(all_stay_ids, split_path, random_seed)
    print(f"✅ Splits created: { {k: len(v) for k, v in split_ids.items()} }")

    # ---------------------------------------------------------
    # Normalization statistics from the training stays only, so validation and
    # test data do not leak into preprocessing (all rows if there is no 'train')
    # ---------------------------------------------------------
    print("📊 Computing global mean/std for normalization...")
    means, stds = _normalization_stats(parts, output_cols, split_ids.get('train'))
    print("✅ Normalization parameters computed.")

    # ---------------------------------------------------------
    # Process and write HDF5
    # ---------------------------------------------------------
    print("🔄 Saving to HDF5 incrementally...")
    offset = {split: 0 for split in split_ids}
    int32_max = np.iinfo(np.int32).max

    with tables.open_file(str(save_path), 'w') as f:
        for i, (part_path, label_path) in enumerate(zip(parts, labels)):
            print(f"📦 Processing batch {i+1}/{len(parts)}: {part_path.name}")

            df = pd.read_parquet(part_path)
            df_label = pd.read_parquet(label_path, columns=['stay_id', *endpoint_names])
            stay_ids = df['stay_id'].to_numpy()
            # Labels are matched to features by row position, so verify the pairing.
            if len(df_label) != len(df) or not np.array_equal(
                df_label['stay_id'].to_numpy(), stay_ids
            ):
                raise ValueError(f"{label_path.name} is not row-aligned with {part_path.name}.")

            # Normalize features (missing columns -> NaN, inf treated as missing)
            features = (
                df.reindex(columns=output_cols)
                .astype(np.float64)
                .replace([np.inf, -np.inf], np.nan)
            )
            features = (features - means) / stds
            if fill_string == 'ffill':
                features = features.groupby(stay_ids).ffill()
            features_all = features.fillna(0.0).to_numpy(dtype=np.float32)
            labels_all = (
                df_label[endpoint_names]
                .fillna(0.0)
                .astype(np.float32)
                .clip(0, 1)
                .to_numpy()
            )

            for split, ids in split_ids.items():
                mask = np.isin(stay_ids, ids)
                if not mask.any():
                    continue

                win = _patient_windows(stay_ids[mask], offset=offset[split])
                if win.max() > int32_max:
                    raise OverflowError(f"Windows of split '{split}' exceed int32.")

                _h5_append(f, _h5_group(f, 'data', 'Dataset'), split, features_all[mask])
                labels_group = _h5_group(f, 'labels', 'Labels')
                if 'tasks' not in labels_group:
                    f.create_array(
                        labels_group, 'tasks',
                        obj=[str(k).encode('utf-8') for k in endpoint_names],
                    )
                _h5_append(f, labels_group, split, labels_all[mask])
                _h5_append(
                    f, _h5_group(f, 'patient_windows', 'Windows'), split, win.astype(np.int32)
                )

                offset[split] += int(mask.sum())

            del df, df_label, features, features_all, labels_all
            gc.collect()

    print(f"✅ Finished! Saved to {save_path}")

# Build output/dataset.h5 from the filtered feature batches and their label
# files; with no split file, to_ml makes a seeded random train/val/test split.
to_ml(
    Path("output/dataset.h5"),
    parts,
    labels,
    ["falencia"],
    output_cols,
    fill_string="ffill",
    split_path=None,
    random_seed=42,
)