# Eligibility for mobilization: Cohort ID and Discretizing script

Author: Kaveri Chhikara

This script identifies the cohort using CLIF 2.1 tables and discretizes the dataset at an hourly level

 
                        🚨Code will break if the following requirements are not satisfied🚨  
#### Requirements

* Required table filenames should be `clif_patient`, `clif_hospitalization`, `clif_adt`, `clif_vitals`, `clif_labs`, `clif_medication_admin_continuous`, `clif_respiratory_support` 
* Within each table, the following variables and categories are required.

| Table Name | Required Variables | Required Categories |
| --- | --- | --- |
| **patient** | `patient_id`, `race_category`, `ethnicity_category`, `sex_category`, `death_dttm` | - |
| **hospitalization** | `patient_id`, `hospitalization_id`, `admission_dttm`, `discharge_dttm`, `age_at_admission` | - |
| **adt** |  `hospitalization_id`, `hospital_id`,`in_dttm`, `out_dttm`, `location_category` | - |
| **vitals** | `hospitalization_id`, `recorded_dttm`, `vital_category`, `vital_value` | heart_rate, respiratory_rate, sbp, dbp, map, spo2, weight_kg, height_cm |
| **labs** | `hospitalization_id`, `lab_result_dttm`, `lab_category`, `lab_value` | lactate, creatinine, bilirubin_total, po2_arterial, platelet_count |
| **medication_admin_continuous** | `hospitalization_id`, `admin_dttm`, `med_name`, `med_category`, `med_dose`, `med_dose_unit` | norepinephrine, epinephrine, phenylephrine, vasopressin, dopamine, angiotensin(optional), nicardipine, nitroprusside, clevidipine, cisatracurium, vecuronium, rocuronium |
| **respiratory_support** | `hospitalization_id`, `recorded_dttm`, `device_category`, `mode_category`, `tracheostomy`, `fio2_set`, `lpm_set`, `resp_rate_set`, `peep_set`, `resp_rate_obs`, `tidal_volume_set`, `pressure_control_set`, `pressure_support_set`, `peak_inspiratory_pressure_set`, `tidal_volume_obs` | - |
| **crrt_therapy** | `hospitalization_id`, `recorded_dttm` | - |


## Load Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import shutil
from datetime import datetime
import json
import pyCLIF
from datetime import timedelta
import pyarrow
import sofa_score
import waterfall
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

## import outlier json
# The parsed config is indexed by variable name further down and unpacked as a
# (min, max) pair, e.g. `min_hr, max_hr = outlier_cfg['heart_rate']`.
with open('../config/outlier_config.json', 'r', encoding='utf-8') as f:
    outlier_cfg = json.load(f)

In [None]:
## ── Output Folder Management ──
# Move the previous run's results into a single backup folder, then build an
# empty output tree for this run.
print("=== Output Folder Management ===")

output_folder = '../output'
output_old_folder = '../output_old'

if os.path.exists(output_folder):
    print(f"Existing output folder found: {output_folder}")

    # Only one generation of backups is kept: an older backup is deleted.
    if os.path.exists(output_old_folder):
        print("Removing existing output_old folder...")
        shutil.rmtree(output_old_folder)

    print(f"Renaming {output_folder} → {output_old_folder}")
    os.rename(output_folder, output_old_folder)

    # Report how much data was moved aside.
    if os.path.exists(output_old_folder):
        backup_bytes = 0
        for dirpath, _dirnames, filenames in os.walk(output_old_folder):
            for filename in filenames:
                backup_bytes += os.path.getsize(os.path.join(dirpath, filename))
        backup_size = backup_bytes / (1024 * 1024)  # bytes -> MB
        print(f"Backup created: {backup_size:.1f} MB")

# Fresh directory tree for this run.
print("Creating fresh output directory structure...")
graphs_folder = f'{output_folder}/final/graphs'
for folder in (
    output_folder,
    f'{output_folder}/final',
    f'{output_folder}/intermediate',
    graphs_folder,
):
    os.makedirs(folder, exist_ok=True)

# Empty placeholder files, one per subfolder.
for placeholder in (
    f'{output_folder}/final/final_output.txt',
    f'{output_folder}/intermediate/intermediate.txt',
):
    with open(placeholder, 'w') as f:
        pass

print("Output directory structure ready:")
print(f"   {output_folder}/")
print("   ├── final/")
print("   │   └── graphs/")
print("   └── intermediate/")

## Required columns and categories

In [None]:
# Column lists passed as `columns=` when loading each CLIF table, and the
# category values of interest used as load filters.

# respiratory_support columns requested from pyCLIF.load_data
rst_required_columns = [
    'hospitalization_id',
    'recorded_dttm',
    'device_name',
    'device_category',
    'mode_name',
    'mode_category',
    'tracheostomy',
    'fio2_set',
    'lpm_set',
    'resp_rate_set',
    'peep_set',
    'resp_rate_obs',
    'tidal_volume_set',
    'pressure_control_set',
    'pressure_support_set',
    'peak_inspiratory_pressure_set'

]

# vitals columns and the vital_category values kept when loading vitals
vitals_required_columns = [
    'hospitalization_id',
    'recorded_dttm',
    'vital_category',
    'vital_value'
]
vitals_of_interest = ['heart_rate', 'respiratory_rate', 'sbp', 'dbp', 'map', 'spo2', 'weight_kg', 'height_cm']

# labs columns / categories (presumably consumed when labs are loaded later — verify)
labs_required_columns = [
    'hospitalization_id',
    'lab_result_dttm',
    'lab_category',
    'lab_value',
    'lab_value_numeric'
]
labs_of_interest = ['lactate']

# continuous medication columns / categories (presumably consumed when meds are loaded later — verify)
meds_required_columns = [
    'hospitalization_id',
    'admin_dttm',
    'med_name',
    'med_category',
    'med_dose',
    'med_dose_unit'
]
# NOTE: category names must match exactly; a stray trailing space on
# 'rocuronium' previously made that entry unmatchable.
meds_of_interest = [
    'norepinephrine', 'epinephrine', 'phenylephrine', 'vasopressin',
    'dopamine', 'angiotensin', 'nicardipine', 'nitroprusside',
    'clevidipine', 'cisatracurium', 'vecuronium', 'rocuronium'
]

## Load data

In [None]:
# Load the three core CLIF tables in full (no column or row filters).
patient = pyCLIF.load_data('clif_patient')
hospitalization = pyCLIF.load_data('clif_hospitalization')
adt = pyCLIF.load_data('clif_adt')

# Cast the id columns to str so every table carries its ids with the same dtype.
for _table, _id_col in (
    (hospitalization, 'hospitalization_id'),
    (patient, 'patient_id'),
    (adt, 'hospitalization_id'),
):
    _table[_id_col] = _table[_id_col].astype(str)

## Duplicate check

If duplicates exist, only the first row is preserved after arranging the data by time. Please check your CLIF tables if there are duplicates. 

In [None]:
# check for duplicates
# Each call passes (table, key columns, table label) and rebinds the table to
# the helper's return value. Which duplicate row survives is decided inside
# pyCLIF.remove_duplicates — TODO confirm against the helper.
# patient table should be unique by patient id
patient = pyCLIF.remove_duplicates(patient, ['patient_id'], 'patient')
# hospitalization table should be unique by hospitalization id
hospitalization = pyCLIF.remove_duplicates(hospitalization, ['hospitalization_id'], 'hospitalization')
# adt table should be unique by hospitalization id, hospital id and in dttm
adt = pyCLIF.remove_duplicates(adt, ['hospitalization_id', 'hospital_id', 'in_dttm'], 'adt')

In [None]:
# Baseline count of distinct hospitalization_ids, taken after de-duplication
# and before any cohort filter is applied.
print(f"Total Number of unique encounters in the hospitalization table: {pyCLIF.count_unique_encounters(hospitalization, 'hospitalization_id')}")

In [None]:
# Standardize all _dttm variables to the same format
# The target timezone comes from pyCLIF.helper['timezone']; each table is
# replaced by the helper's return value.
patient = pyCLIF.convert_datetime_columns_to_site_tz(patient,  pyCLIF.helper['timezone'])
hospitalization = pyCLIF.convert_datetime_columns_to_site_tz(hospitalization, pyCLIF.helper['timezone'])
adt = pyCLIF.convert_datetime_columns_to_site_tz(adt,  pyCLIF.helper['timezone'])

## Cohort Identification

**Inclusion Criteria:**

* Adult admissions between 2018-01-01 and 2024-12-31
* Encounters receiving invasive mechanical ventilation during this period

**Exclusion criteria:**

1. Encounters that were on vent for less than 4 hours in the first 72 hours of first intubation
2. Encounters that were on trach at the time of intubation

In [None]:
# setting up a dictionary to keep track of STROBE counts
# Keys are prefixed with the step letter (A_, B_, C_, ...) that records them.
strobe_counts = {}

#### (A) Date and Age Filter

In [None]:
# STEP A: Basic Data Cleaning + Date/Age Filter
#   - Filter hospitalization for date range & adult patients
#   - Then reduce ADT to those hospitalization_ids
print("\n=== STEP A: Filter by date range & age ===\n")
# Inclusion window is 2018-01-01 through 2024-12-31 inclusive. A bare date
# string is parsed as midnight at the START of that day, so the upper bound is
# written as "< 2025-01-01"; "<= '2024-12-31'" silently dropped every admission
# that happened during Dec 31 after 00:00.
date_mask = (hospitalization['admission_dttm'] >= '2018-01-01') & \
            (hospitalization['admission_dttm'] < '2025-01-01')
age_mask = (hospitalization['age_at_admission'] >= 18)

# The MIMIC site skips the calendar filter and keeps every adult admission.
if pyCLIF.helper['site_name'].lower() == 'mimic':
    hospitalization_cohort = hospitalization[age_mask].copy()
else:
    hospitalization_cohort = hospitalization[date_mask & age_mask].copy()

strobe_counts['A_after_date_age_filter'] = hospitalization_cohort['hospitalization_id'].nunique()
print(f"Number of unique hospitalizations after date & age filter: {strobe_counts['A_after_date_age_filter']}")

In [None]:
# Get total unique hospitalizations without time filter, only age filter
# Stored under its own STROBE key so the adult total is available regardless of
# the admission-date window.
age_mask = (hospitalization['age_at_admission'] >= 18)
total_adult_hospitalizations = hospitalization[age_mask]['hospitalization_id'].nunique()
strobe_counts['A_after_age_filter'] = total_adult_hospitalizations
print(f"\nTotal number of unique adult hospitalizations (no date filter): {total_adult_hospitalizations}")

In [None]:
# Display the STROBE counts collected so far as the cell output.
strobe_counts

#### (B) Stitch hospitalizations

Combine multiple `hospitalization_ids` into a single `encounter_block` for patients who transfer between hospital campuses or return soon after discharge. Hospitalizations that have a gap of **6 hours or less** between the discharge dttm and admission dttm are put in one encounter block.

In [None]:
# Filter ADT to only those in the cohort set
# cohort_ids holds the hospitalization_ids that survived the Step A filter.
cohort_ids = hospitalization_cohort['hospitalization_id'].unique().tolist()
adt_cohort = adt[adt['hospitalization_id'].isin(cohort_ids)].copy()

In [None]:
# Check for missing values in admission and discharge dates
# Both columns are named in the stitching rule (gap between discharge_dttm and
# the next admission_dttm), so missing values are reported before stitching.
print("\nMissing values in admission_dttm:", hospitalization_cohort['admission_dttm'].isna().sum())
print("Missing values in discharge_dttm:", hospitalization_cohort['discharge_dttm'].isna().sum())

In [None]:
# STEP B: Stitch Encounters => 'encounter_block'
# Use stitch_encounters from pyCLIF with time_interval=6
# (the 6-hour discharge-to-admission gap described in the section text;
# the unit is presumably hours — confirm against pyCLIF.stitch_encounters)
print("\n=== STEP B: Stitch encounters ===\n")
stitched_cohort = pyCLIF.stitch_encounters(hospitalization_cohort, adt_cohort, time_interval=6)

In [None]:
# stitched_cohort holds patient_id, hospitalization_id, encounter_block, the
# discharge category and other ADT variables; rows repeat per location category.
# Collapse to one row per (patient_id, encounter_block) for counting.
stitched_unique = stitched_cohort[['patient_id', 'encounter_block']].drop_duplicates()

# Record the before/after counts once and print them from the stored values.
strobe_counts['B_before_stitching'] = stitched_cohort['hospitalization_id'].nunique()
strobe_counts['B_after_stitching'] = stitched_unique['encounter_block'].nunique()
strobe_counts['B_stitched_hosp_ids'] = (
    strobe_counts['B_before_stitching'] - strobe_counts['B_after_stitching']
)
print(f"Number of unique hospitalizations before stitching: {strobe_counts['B_before_stitching']}")
print(f"Number of unique encounter blocks after stitching: {strobe_counts['B_after_stitching']}")
print(f"Number of linked hospitalization ids: {strobe_counts['B_stitched_hosp_ids']}")

In [None]:
# Mapping of patient id, hospitalization id and encounter blocks
# all_ids is the id crosswalk that later steps keep narrowing as exclusions are applied.
all_ids = stitched_cohort[['patient_id', 'hospitalization_id', 'encounter_block', 'discharge_category', 'discharge_dttm']].drop_duplicates()
print("\nUnique values in each column:")
# Only the first three columns (the id columns) are summarized.
for col in all_ids.columns[:3]:
    print(f"\n{col}:")
    print(all_ids[col].nunique())

#### (C) Identify ventilator usage

Filter down to encounters that received invasive mechanical ventilation

In [None]:
# STEP C: Identify Ventilator Usage
# Respiratory support is loaded only for the hospitalization_ids still in
# all_ids; those ids map to an encounter_block for the final grouping.

print("\n=== STEP C: Load & process respiratory support => Apply Waterfall & Identify IMV usage ===\n")

# 1) Load respiratory support
resp_support_raw = pyCLIF.load_data(
    'clif_respiratory_support',
    columns=rst_required_columns,
    filters={'hospitalization_id': all_ids['hospitalization_id'].unique().tolist()}
)

# Work on a copy: lower-case the category labels and coerce the numeric
# settings (values that cannot be parsed become NaN).
resp_support = resp_support_raw.copy()
for _category_col in ('device_category', 'mode_category'):
    resp_support[_category_col] = resp_support[_category_col].str.lower()
for _numeric_col in ('lpm_set', 'resp_rate_set', 'peep_set', 'resp_rate_obs'):
    resp_support[_numeric_col] = pd.to_numeric(resp_support[_numeric_col], errors='coerce')
resp_support = resp_support.sort_values(['hospitalization_id', 'recorded_dttm'])

print("\n=== Apply outlier thresholds ===\n")
resp_support['fio2_set'] = pd.to_numeric(resp_support['fio2_set'], errors='coerce')
# Unit heuristic: an overall mean above 1 is taken to mean FiO2 was charted as
# a percentage. A NaN mean fails the comparison and falls through to the else
# branch.
fio2_mean = resp_support['fio2_set'].mean(skipna=True)
if fio2_mean and fio2_mean > 1.0:
    # Rescale only the values above 1 so fractions already in range stay put.
    _percent_rows = resp_support['fio2_set'] > 1
    resp_support.loc[_percent_rows, 'fio2_set'] = (
        resp_support.loc[_percent_rows, 'fio2_set'] / 100
    )
    print("Updated fio2_set to be between 0.21 and 1")
else:
    print("FIO2_SET mean=", fio2_mean, "is within the required range")

Respiratory Support Summary

In [None]:
# Summary statistics of each numeric respiratory setting, grouped by device.
results_list = []
group_cols = 'device_category'  # or a list like ['device_category','mode_category']
numeric_cols = ['fio2_set','peep_set','lpm_set', 'resp_rate_set', 'resp_rate_obs']

# Column layout is the same for every variable: group-by column(s), then
# 'variable', then whatever statistics the helper returns.
group_cols_list = [group_cols] if isinstance(group_cols, str) else group_cols
front_cols = group_cols_list + ['variable']

for col in numeric_cols:
    tmp = pyCLIF.create_summary_table(
        df=resp_support,
        numeric_col=col,
        group_by_cols=group_cols
    )
    tmp['variable'] = col  # tag each row with the setting it summarizes
    rest_cols = [c for c in tmp.columns if c not in front_cols]
    new_cols = front_cols + rest_cols
    tmp = tmp[new_cols]
    results_list.append(tmp)

# Stack the per-variable tables and write them out.
final_summary_resp_support = pd.concat(results_list, ignore_index=True)
final_summary_resp_support.to_csv('../output/final/summary_respiratory_support_by_device.csv', index=False)

In [None]:
# Summary statistics of each numeric respiratory setting, grouped by device and mode.
results_list = []
group_cols = ['device_category','mode_category']
numeric_cols = ['fio2_set','peep_set','lpm_set', 'resp_rate_set', 'resp_rate_obs']

# Column layout is the same for every variable: group-by columns, then
# 'variable', then whatever statistics the helper returns.
group_cols_list = [group_cols] if isinstance(group_cols, str) else group_cols
front_cols = group_cols_list + ['variable']

for col in numeric_cols:
    tmp = pyCLIF.create_summary_table(
        df=resp_support,
        numeric_col=col,
        group_by_cols=group_cols
    )
    tmp['variable'] = col  # tag each row with the setting it summarizes
    rest_cols = [c for c in tmp.columns if c not in front_cols]
    new_cols = front_cols + rest_cols
    tmp = tmp[new_cols]
    results_list.append(tmp)

# Stack the per-variable tables and write them out.
final_summary_resp_support = pd.concat(results_list, ignore_index=True)
final_summary_resp_support.to_csv('../output/final/summary_respiratory_support_by_device_mode.csv', index=False)

##### (C.1) Waterfall

In [None]:
## Identify encounters on IMV
# Rows whose device_category contains "imv" (case-insensitive; a missing
# category counts as not IMV).
imv_mask = resp_support['device_category'].str.contains("imv", case=False, na=False)

# One row per hospitalization_id that has at least one IMV record.
resp_stitched_imv_ids = resp_support.loc[imv_mask, ['hospitalization_id']].drop_duplicates()

# Keep ALL respiratory rows (IMV or not) for those hospitalizations.
_has_imv = resp_support["hospitalization_id"].isin(resp_stitched_imv_ids["hospitalization_id"])
resp_support_filtered = resp_support.loc[_has_imv].reset_index(drop=True)

# Narrow the id crosswalk to the hospitalizations that remain.
_remaining_ids = resp_support_filtered['hospitalization_id'].unique()
all_ids = all_ids[all_ids['hospitalization_id'].isin(_remaining_ids)]

In [None]:
import warnings
# Hide pandas FutureWarnings (an unconditional 'ignore' filter is already
# installed in the imports cell).
warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")

# Run the waterfall processing on the IMV hospitalizations, keyed by
# hospitalization_id. What the procedure fills or derives is defined in
# waterfall.process_resp_support_waterfall — see that module.
processed_resp_support = waterfall.process_resp_support_waterfall(resp_support_filtered, 
                                                        id_col = "hospitalization_id",
                                                        verbose = True)

# Convert datetime columns to the site timezone, then cache the result on disk.
processed_resp_support = pyCLIF.convert_datetime_columns_to_site_tz(processed_resp_support, pyCLIF.helper['timezone'])
processed_resp_support.to_parquet('../output/intermediate/processed_resp_support.parquet', index=False)
# Shortcut for reruns: load the cached table from the previous run's backup instead of recomputing.
# processed_resp_support = pd.read_parquet('../output_old/intermediate/processed_resp_support.parquet')

In [None]:
# Merge to get encounter_block for the cohort identified so far
# how='right' keeps every all_ids row, so a hospitalization with no processed
# respiratory rows appears once with missing recorded_dttm.
resp_stitched = processed_resp_support.merge(
    all_ids[['hospitalization_id','encounter_block']],
    on='hospitalization_id', how='right'
)

# Counts rows without a timestamp (includes the unmatched ids described above).
print("Missing values in recorded_dttm:", resp_stitched['recorded_dttm'].isna().sum())

In [None]:
# Apply the configured outlier bounds to each numeric respiratory setting.
# The config entry for a column is unpacked into the helper's trailing
# arguments. The return value is not used, so the helper presumably edits
# resp_stitched in place — confirm against pyCLIF.apply_outlier_thresholds.
for col in ('fio2_set', 'peep_set', 'lpm_set', 'resp_rate_set', 'resp_rate_obs'):
    pyCLIF.apply_outlier_thresholds(resp_stitched, col, *outlier_cfg[col])

In [None]:
# fill values of fio2_set if the device is nasal cannula, and lpm_set is available
# https://www.respiratorytherapyzone.com/oxygen-flow-rate-fio2/
# The conversion itself lives in pyCLIF.impute_fio2_from_nasal_cannula_flow.
resp_stitched = pyCLIF.impute_fio2_from_nasal_cannula_flow(resp_stitched)

In [None]:
# 4) Identify IMV
# Rows whose device_category contains "imv" (case-insensitive; a missing
# category counts as not IMV).
imv_mask = resp_stitched['device_category'].str.contains("imv", case=False, na=False)
resp_stitched_imv = resp_stitched[imv_mask].copy()
# Every IMV record is flagged as on-vent.
resp_stitched_imv['on_vent'] = 1

# Left join the flag back onto the full table so non-vent records are kept.
# The flag table is de-duplicated on the join keys first: if several IMV rows
# share a (hospitalization_id, recorded_dttm), merging them as-is multiplies
# the matching rows in resp_stitched and inflates later per-hour counts.
on_vent_flags = (
    resp_stitched_imv[['hospitalization_id', 'recorded_dttm', 'on_vent']]
    .drop_duplicates()
)
resp_stitched_final = resp_stitched.merge(
    on_vent_flags,
    on=['hospitalization_id', 'recorded_dttm'],
    how='left'
)

# Rows with no IMV record at that timestamp are "not on vent".
resp_stitched_final['on_vent'] = resp_stitched_final['on_vent'].fillna(0)
# Counts are taken over the full table, i.e. over every id still in the cohort.
strobe_counts['C_imv_hospitalizations'] = resp_stitched_final['hospitalization_id'].nunique()
strobe_counts['C_imv_encounter_blocks'] = resp_stitched_final['encounter_block'].nunique()

print(f"Total IMV respiratory support hospitalizations: {strobe_counts['C_imv_hospitalizations']}")
print(f"Total IMV respiratory support encounter blocks: {strobe_counts['C_imv_encounter_blocks']}")

In [None]:
# Keep only crosswalk rows whose encounter_block and hospitalization_id both
# appear in the respiratory table built above.
all_ids =  all_ids[all_ids['encounter_block'].isin(resp_stitched_final['encounter_block'].unique())]
all_ids = all_ids[all_ids['hospitalization_id'].isin(resp_stitched_final['hospitalization_id'].unique())]

In [None]:
# Distinct counts of the three id columns after the Step C filter.
for col in all_ids.columns[:3]:
    print(f"\n{col}:")
    print(all_ids[col].nunique())

#### (D) Vent start and end times 

Calculate the vent start time (first IMV record) and vent end time (last IMV record) for each hospitalization and encounter block.   
Limitation: the vent end time might not belong to the same intubation episode as the start time.

In [None]:
# STEP D: Determine Vent Start/End for Each Hospitalization and Encounter block
# Start = earliest IMV record, end = latest IMV record within the group.

print("\n=== STEP D: Determine ventilation times (start/end) at d encounter block level ===\n")

# --- hospitalization level ---
vent_start_end = (
    resp_stitched_imv
    .groupby('hospitalization_id')
    .agg(
        vent_start_time=('recorded_dttm','min'),
        vent_end_time=('recorded_dttm','max')
    )
    .reset_index()
)

# start == end means a single IMV timestamp; set those aside (they would also
# fail the "< 4 hours on vent" exclusion later).
_single_record_hosp = vent_start_end['vent_start_time'].eq(vent_start_end['vent_end_time'])
check_same_vent_start_end = vent_start_end[_single_record_hosp].copy()
vent_start_end = vent_start_end[~_single_record_hosp].copy()

strobe_counts['D_hospitalizations_with_valid_vent'] = vent_start_end['hospitalization_id'].nunique()
strobe_counts['D_hospitalizations_with_same_vent_start_end'] = check_same_vent_start_end['hospitalization_id'].nunique()
print(f"Unique hospitalizations with valid IMV start/end: {strobe_counts['D_hospitalizations_with_valid_vent']}")

# --- encounter block level ---
block_vent_times = (
    resp_stitched_imv
    .groupby('encounter_block', dropna=True)
    .agg(
        block_vent_start_dttm=('recorded_dttm','min'),
        block_vent_end_dttm=('recorded_dttm','max')
    )
    .reset_index()
)

# Same single-timestamp rule at block level.
_single_record_block = block_vent_times['block_vent_start_dttm'].eq(block_vent_times['block_vent_end_dttm'])
block_same_vent = block_vent_times[_single_record_block].copy()
block_vent_times = block_vent_times[~_single_record_block].copy()

strobe_counts['D_blocks_with_valid_vent'] = block_vent_times['encounter_block'].nunique()
strobe_counts['D_blocks_with_same_vent_start_end'] = block_same_vent['encounter_block'].nunique()
print(f"Unique encounter blocks with valid IMV start/end: {strobe_counts['D_blocks_with_valid_vent']}")

# Blocks that carry forward to the next steps.
valid_blocks_vent = block_vent_times['encounter_block'].unique()

In [None]:
# Display the STROBE counts through Step D as the cell output.
strobe_counts

In [None]:
# Filter all_ids to only keep rows where encounter_block is in valid_blocks_vent
# (blocks whose first and last IMV timestamps differ)
all_ids = all_ids[all_ids['encounter_block'].isin(valid_blocks_vent)]

In [None]:
# Distinct counts of the three id columns after the Step D filter.
for col in all_ids.columns[:3]:
    print(f"\n{col}:")
    print(all_ids[col].nunique())

#### (E) Hourly Sequence 

This section achieves the following steps:  
* Identifies the first and last recorded times for vitals for each encounter block
* These times are used to generate an hourly sequence of each patient's hospitalization journey
* Combines with hourly vent usage data from the respiratory support table
* Excludes encounters on vent for less than 4 hours in the first 72 hours
* Creates a final dataframe with the identified cohort

In [None]:
# Generate Hourly Sequence & Exclude encounter blocks with <4 Vent Hours
#  Create an hourly timeline from vent_start to last vital or outcome time for each encounter block
# We stop operating at hospitalization id level 

print("\n=== STEP E: Hourly sequence generation & < 4 hour vent exclusion BLOCK level===\n")

# 1) define the 'end_time' for the sequence from vitals or outcome.
# Only the required columns are loaded, restricted to the cohort's
# hospitalization_ids and the vital categories of interest.
vitals_cohort = pyCLIF.load_data('clif_vitals',
    columns=vitals_required_columns,
    filters={'hospitalization_id': all_ids['hospitalization_id'].unique().tolist(), 
             'vital_category': vitals_of_interest}
)
vitals_cohort = pyCLIF.convert_datetime_columns_to_site_tz(vitals_cohort, pyCLIF.helper['timezone'])
# Values that cannot be parsed as numbers become NaN.
vitals_cohort['vital_value'] = pd.to_numeric(vitals_cohort['vital_value'], errors='coerce')
# sort vitals cohort by hospitalization_id and recorded_dttm
vitals_cohort = vitals_cohort.sort_values(['hospitalization_id', 'recorded_dttm'])

In [None]:
# Replace outliers with NAs in the vitals table
# Read every (min, max) pair from the config up front, so a missing config key
# fails before any value is modified.
min_hr, max_hr = outlier_cfg['heart_rate']
min_rr, max_rr = outlier_cfg['respiratory_rate']
min_sbp, max_sbp = outlier_cfg['sbp']
min_dbp, max_dbp = outlier_cfg['dbp']
min_map, max_map = outlier_cfg['map']
min_spo2, max_spo2 = outlier_cfg['spo2']
min_weight, max_weight = outlier_cfg['weight_kg']
min_height, max_height = outlier_cfg['height_cm']

_vital_bounds = {
    'heart_rate': (min_hr, max_hr),
    'respiratory_rate': (min_rr, max_rr),
    'sbp': (min_sbp, max_sbp),
    'dbp': (min_dbp, max_dbp),
    'map': (min_map, max_map),
    'spo2': (min_spo2, max_spo2),
    'weight_kg': (min_weight, max_weight),
    'height_cm': (min_height, max_height),
}

# For each vital category, blank out values below the minimum or above the
# maximum; values exactly on a bound are kept.
for _vital_name, (_lower, _upper) in _vital_bounds.items():
    _in_category = vitals_cohort['vital_category'] == _vital_name
    _values = vitals_cohort['vital_value']
    _out_of_range = _in_category & ((_values < _lower) | (_values > _upper))
    vitals_cohort.loc[_out_of_range, 'vital_value'] = np.nan

In [None]:
# Summary statistics of vital_value per vital_category, computed after the
# outlier masking above, and written to the final output folder.
summary_vitals = pyCLIF.create_summary_table(
        df=vitals_cohort,
        numeric_col='vital_value',
        group_by_cols='vital_category'
    )
summary_vitals.to_csv('../output/final/summary_vitals_by_category.csv', index=False)

In [None]:
# Attach the id crosswalk (including encounter_block) to every vital record.
vitals_stitched = vitals_cohort.merge(all_ids, on='hospitalization_id', how='left')

# Earliest and latest vital timestamp per encounter block.
vital_bounds_block = (
    vitals_stitched
    .groupby('encounter_block', dropna=True)['recorded_dttm']
    .agg(block_first_vital_dttm='min', block_last_vital_dttm='max')
    .reset_index()
)

# 2) Keep only blocks that have both vent times and vital bounds.
final_blocks = block_vent_times.merge(vital_bounds_block, on='encounter_block', how='inner')

In [None]:
# 3) Sanity check: a block whose last vital precedes its vent start should not
# exist. Such blocks are counted and reported here, not removed; if any show
# up, check the source CLIF tables.
_last_vital_before_vent = final_blocks['block_last_vital_dttm'] < final_blocks['block_vent_start_dttm']
bad_block = final_blocks.loc[_last_vital_before_vent]
strobe_counts['E_blocks_with_vent_end_before_vital_start'] = bad_block['encounter_block'].nunique()
if bad_block.empty:
    print("There are no bad blocks! Good job CLIF-ing")
else:
    print("Warning: Some blocks have last vital < vent start:\n", len(bad_block))

In [None]:
# 4) Generate the hourly sequence at block level
def generate_hourly_sequence_block(group):
    """Build the hourly timestamp scaffold for a single encounter block.

    Meant to be used with ``groupby('encounter_block').apply``: the block id is
    read from ``group.name`` and the time bounds from the group's first row.

    Returns a DataFrame with columns ``encounter_block`` and ``recorded_dttm``
    holding one row per hourly step, starting at ``block_vent_start_dttm`` and
    not going past ``block_last_vital_dttm``. The frame is empty when the last
    vital precedes the vent start.
    """
    vent_start = group['block_vent_start_dttm'].iat[0]
    last_vital = group['block_last_vital_dttm'].iat[0]
    hours = pd.date_range(vent_start, last_vital, freq='h')
    return pd.DataFrame({'encounter_block': group.name, 'recorded_dttm': hours})

# Run the per-block generator; the DeprecationWarning raised by groupby.apply
# is silenced for this call only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    per_block_hours = final_blocks.groupby('encounter_block').apply(
        generate_hourly_sequence_block
    )
hourly_seq_block = per_block_hours.reset_index(drop=True)

# Replace the timestamp with (date, hour) keys and keep one row per block-hour
scaffold_stamps = hourly_seq_block['recorded_dttm']
hourly_seq_block['recorded_date'] = scaffold_stamps.dt.date
hourly_seq_block['recorded_hour'] = scaffold_stamps.dt.hour
hourly_seq_block = (
    hourly_seq_block
    .drop(columns=['recorded_dttm'])
    .drop_duplicates(subset=['encounter_block', 'recorded_date', 'recorded_hour'])
)

In [None]:
# 6) Combine with actual vent usage by hour:
# restrict respiratory support rows to cohort blocks and add (date, hour) keys
resp_in_cohort = resp_stitched_final['encounter_block'].isin(all_ids['encounter_block'])
resp_stitched_final = resp_stitched_final.loc[resp_in_cohort].copy()
resp_recorded = resp_stitched_final['recorded_dttm']
resp_stitched_final['recorded_date'] = resp_recorded.dt.date
resp_stitched_final['recorded_hour'] = resp_recorded.dt.hour

In [None]:
# Carry the last known tracheostomy value forward inside each encounter_block
# BEFORE the hourly aggregation, so hours without a recorded value inherit it.
print("Forward filling tracheostomy within encounter blocks...")

# Chronological order within each block is required for the forward fill
resp_stitched_final = resp_stitched_final.sort_values(['encounter_block', 'recorded_dttm'])

# Forward fill per block; anything still missing is treated as 0 (no trach)
resp_stitched_final['tracheostomy_filled'] = (
    resp_stitched_final
    .groupby('encounter_block')['tracheostomy']
    .transform(lambda values: values.ffill())
    .fillna(0)
)

# Report how many blocks carry a trach flag before vs. after filling
trach_raw = resp_stitched_final['tracheostomy'] == 1
trach_filled = resp_stitched_final['tracheostomy_filled'] == 1
before_blocks = resp_stitched_final.loc[trach_raw, 'encounter_block'].nunique()
after_blocks = resp_stitched_final.loc[trach_filled, 'encounter_block'].nunique()

print(f"Blocks with trach (before forward fill): {before_blocks}")
print(f"Blocks with trach (after forward fill): {after_blocks}")

In [None]:
# Collapse respiratory support rows to one row per block-hour:
# min/max of each setting, plus flags that are 1 if any value in the hour is 1.
hourly_vent_agg = {}
for setting in ['fio2_set', 'peep_set', 'lpm_set', 'resp_rate_obs']:
    hourly_vent_agg[f'min_{setting}'] = (setting, 'min')
    hourly_vent_agg[f'max_{setting}'] = (setting, 'max')
hourly_vent_agg['hourly_trach'] = ('tracheostomy_filled', 'max')
hourly_vent_agg['hourly_on_vent'] = ('on_vent', 'max')

hourly_vent_block = (
    resp_stitched_final
    .groupby(['encounter_block', 'recorded_date', 'recorded_hour'])
    .agg(**hourly_vent_agg)
    .reset_index()
)

In [None]:
# Sanity check: compare the blocks present in the hourly scaffold with those
# present in the hourly vent table, in both directions.
# A mismatch is possible when a patient is put on IMV in the ED and dies shortly after.
# It may still be worth exploring the trajectory of these patients.
seq_blocks = set(hourly_seq_block['encounter_block'].unique())
vent_blocks = set(hourly_vent_block['encounter_block'].unique())

blocks_in_seq_not_vent = seq_blocks.difference(vent_blocks)
blocks_in_vent_not_seq = vent_blocks.difference(seq_blocks)

n_seq_only = len(blocks_in_seq_not_vent)
print("Blocks in hourly_seq_block but not in hourly_vent_block:", n_seq_only)
if n_seq_only:
    print(sorted(blocks_in_seq_not_vent))
print("\nBlocks in hourly_vent_block but not in hourly_seq_block:", len(blocks_in_vent_not_seq))

In [None]:
## Build the block-level hourly table.
# We want every hour from hourly_seq_block (the scaffold), plus any extra hours from
# hourly_vent_block that occur after the last scaffold hour, so that all hours of
# ventilation are captured even when they fall outside the scaffold.

# Step 1: Reconstruct hourly timestamps from (recorded_date, recorded_hour)
hourly_seq_block['recorded_dttm'] = pd.to_datetime(hourly_seq_block['recorded_date']) + pd.to_timedelta(hourly_seq_block['recorded_hour'], unit='h')
hourly_vent_block['recorded_dttm'] = pd.to_datetime(hourly_vent_block['recorded_date']) + pd.to_timedelta(hourly_vent_block['recorded_hour'], unit='h')

# Step 2: Get max scaffold time per encounter
max_times = (
    hourly_seq_block.groupby('encounter_block')['recorded_dttm']
    .max().reset_index()
    .rename(columns={'recorded_dttm': 'max_seq_dttm'})
)

# Step 3: Identify extra vent rows beyond scaffold
# (blocks absent from the scaffold get a missing max_seq_dttm and never qualify)
vent_plus_max = pd.merge(hourly_vent_block, max_times, on='encounter_block', how='left')

extra_rows = vent_plus_max[
    vent_plus_max['recorded_dttm'] > vent_plus_max['max_seq_dttm']
].copy()

# Step 4: Create gap-filler rows for each encounter with extra data
gap_rows = []
one_hour = timedelta(hours=1)
for enc_id, group in extra_rows.groupby('encounter_block'):
    # max_seq_dttm was merged in Step 3 and is constant within a block, so read it
    # from the group instead of scanning max_times once per block.
    max_time = group['max_seq_dttm'].iloc[0]
    first_extra_time = group['recorded_dttm'].min()

    # Skip if there's no gap
    if first_extra_time <= max_time + one_hour:
        continue

    # Fill hourly timestamps between scaffold end and first extra.
    # Lowercase 'h' matches the alias used when the scaffold was generated;
    # the uppercase 'H' alias is deprecated in recent pandas.
    gap_times = pd.date_range(
        start=max_time + one_hour,
        end=first_extra_time - one_hour,
        freq='h'
    )

    gap_rows.extend(
        {
            'encounter_block': enc_id,
            'recorded_date': dt.date(),
            'recorded_hour': dt.hour,
            'recorded_dttm': dt
        }
        for dt in gap_times
    )

# Convert to DataFrame
gap_df = pd.DataFrame(gap_rows)

# Step 5: Add all required columns to gap_df, using NA defaults
missing_cols = set(hourly_vent_block.columns) - set(gap_df.columns)
for col in missing_cols:
    gap_df[col] = np.nan

# Ensure column order matches
gap_df = gap_df[hourly_vent_block.columns]

# Step 6: Get scaffold rows with vent info via left join
scaffold_df = pd.merge(
    hourly_seq_block.drop(columns='recorded_dttm'),
    hourly_vent_block.drop(columns='recorded_dttm'),
    on=['encounter_block', 'recorded_date', 'recorded_hour'],
    how='left'
)

gap_df = gap_df.drop(columns='recorded_dttm', errors='ignore')
extra_rows = extra_rows.drop(columns='recorded_dttm', errors='ignore')
extra_rows = extra_rows.drop(columns='max_seq_dttm', errors='ignore')
# Step 7: Combine all three
final_df_block = pd.concat([scaffold_df, gap_df, extra_rows], ignore_index=True)

# Step 8: Sort
final_df_block = final_df_block.sort_values(
    by=['encounter_block', 'recorded_date', 'recorded_hour']
).reset_index(drop=True)

# Step 9: Add time_from_vent as the row number within each block.
# NOTE(review): this equals "hours since vent start" only if each block's rows are
# consecutive hours. Only the gap between the scaffold end and the FIRST extra row is
# filled above; gaps between later extra rows are not — confirm this is acceptable.
final_df_block['time_from_vent'] = final_df_block.groupby('encounter_block').cumcount()
# Rows 0-3 are marked -1; the adjusted clock starts at row 4.
final_df_block['time_from_vent_adjusted'] = np.where(
    final_df_block['time_from_vent'] < 4, -1, final_df_block['time_from_vent'] - 4
)

# Put the key and time columns first, then everything else in existing order
cols = ['encounter_block', 'recorded_date', 'recorded_hour', 'time_from_vent', 'time_from_vent_adjusted']
cols += [col for col in final_df_block.columns if col not in cols]
final_df_block = final_df_block[cols]

print("Final shape:", final_df_block.shape)
print("Unique encounter_blocks:", final_df_block['encounter_block'].nunique())

In [None]:
# 7) Count vent hours per block in the first 72 hours after first intubation.
#  Blocks with <4 vent hours in that window are excluded downstream: patients who are
#  barely intubated cannot meaningfully be studied for early mobilization, and
#  including them could bias results.
in_first_72 = (final_df_block['time_from_vent'] >= 0) & (final_df_block['time_from_vent'] < 72)
# Explicit copy: the columns assigned below must not write into a slice of final_df_block
first_72_hours = final_df_block[in_first_72].copy()

# Forward fill the hourly flags within each block, so a block whose first rows are
# missing does not inherit the last value of the previous block.
for flag_col in ['hourly_on_vent', 'hourly_trach']:
    first_72_hours[flag_col] = first_72_hours.groupby('encounter_block')[flag_col].ffill()
vent_hours_per_block = first_72_hours.groupby('encounter_block')['hourly_on_vent'].sum()

In [None]:
# Exclude blocks with fewer than 4 IMV hours in the first 72 hours
blocks_under_4 = vent_hours_per_block.index[vent_hours_per_block < 4]
is_short_vent = final_df_block['encounter_block'].isin(blocks_under_4)
blocks_under_4_df = final_df_block[is_short_vent]   # excluded rows, kept for inspection
final_df_block = final_df_block[~is_short_vent]

# Record kept / excluded block counts for the STROBE diagram
n_excluded_short = len(blocks_under_4)
strobe_counts['G_blocks_with_vent_4_or_more'] = final_df_block['encounter_block'].nunique()
strobe_counts['G_blocks_with_vent_less_than_4'] = n_excluded_short
print(f"Unique encounter blocks with valid IMV start/end: {strobe_counts['G_blocks_with_vent_4_or_more']}")
print(f"Excluded {n_excluded_short} encounter blocks with <4 vent hours in first 72 hours of intubation.\n")

In [None]:
# 8) Exclude blocks that already have a tracheostomy at the time of intubation,
# i.e. hourly_trach == 1 on the row where time_from_vent == 0.
at_intubation = final_df_block['time_from_vent'] == 0
has_trach = final_df_block['hourly_trach'] == 1
blocks_with_trach_at_intubation = final_df_block.loc[
    at_intubation & has_trach, 'encounter_block'
].unique()
n_trach_blocks = len(blocks_with_trach_at_intubation)

print(f"Blocks with trach at intubation: {n_trach_blocks}")

# Drop every row of those blocks
keep_rows = ~final_df_block['encounter_block'].isin(blocks_with_trach_at_intubation)
final_df_block = final_df_block[keep_rows]

# STROBE bookkeeping: excluded blocks and remaining cohort size
strobe_counts['G_final_blocks_with_trach_at_intubation'] = n_trach_blocks
strobe_counts['G_final_blocks_without_trach_at_intubation'] = final_df_block['encounter_block'].nunique()

print(f"Excluded {n_trach_blocks} blocks with trach at intubation")
print(f"Final cohort size: {strobe_counts['G_final_blocks_without_trach_at_intubation']}")

In [None]:
# Display the STROBE counts collected so far (notebook cell output)
strobe_counts

In [None]:
# Keep only the ids whose block survived the exclusions above, then show the shape
ids_in_final_blocks = all_ids['encounter_block'].isin(final_df_block['encounter_block'])
all_ids = all_ids.loc[ids_in_final_blocks]
all_ids.shape

In [None]:
# Start final_df from the block-level hourly table, keeping only the hourly columns.
# The previous left-merge with all_ids was dropped: every column it added was removed
# again by the reindex, and it repeated each hourly row once per all_ids row sharing
# the encounter_block (i.e. whenever a block stitches several hospitalizations).
final_df = final_df_block.reindex(columns=[
    'encounter_block', 'recorded_date', 'recorded_hour',
    'time_from_vent', 'time_from_vent_adjusted',
    'min_fio2_set', 'max_fio2_set', 'min_peep_set', 'max_peep_set',
    'min_lpm_set', 'max_lpm_set', 'min_resp_rate_obs', 'max_resp_rate_obs',
    'hourly_trach', 'hourly_on_vent'
]).reset_index(drop=True)  # fresh 0..n-1 index, as the merge used to produce

In [None]:
# Each (block, date, hour) should appear once; count the repeated rows
key_cols = ['encounter_block', 'recorded_date', 'recorded_hour']
is_repeated_key = final_df.duplicated(subset=key_cols)
duplicates = is_repeated_key.sum()
print(f"Number of duplicate rows: {duplicates}")

In [None]:
# Re-align all_ids with the blocks present in final_df, then show the shape
ids_in_final_df = all_ids['encounter_block'].isin(final_df['encounter_block'])
all_ids = all_ids.loc[ids_in_final_df]
all_ids.shape

In [None]:
# Distinct-value count for each of the first three columns of all_ids
for col in all_ids.columns[:3]:
    print(f"\n{col}:", all_ids[col].nunique(), sep="\n")

#### (F) Add final outcome dttm

Calculate final outcome dttm for each encounter block using last vital recorded dttm and discharge disposition.   

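To get the `final_outcome_dttm`, we use the `block_last_vital_dttm`. An `is_dead` flag is set to 1 when `discharge_category` is `Expired` or `Hospice`. If `death_dttm` is earlier than `discharge_dttm`, `final_outcome_dttm` is set to `death_dttm` and `is_dead` is set to 1.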


In [None]:
# 1) Add the block-level columns from final_blocks (block_vent_start_dttm,
#    block_vent_end_dttm, block_first_vital_dttm, block_last_vital_dttm) to all_ids
#    (patient_id, hospitalization_id, encounter_block), then
# 2) add death_dttm from the patient table.
all_ids_w_outcome = (
    all_ids
    .merge(final_blocks, on='encounter_block', how='left')
    .merge(patient[['patient_id', 'death_dttm']], on='patient_id', how='left')
)

# The last recorded vital of the block is the default final outcome time
all_ids_w_outcome['final_outcome_dttm'] = all_ids_w_outcome['block_last_vital_dttm']

# is_dead = 1 when the discharge category (case-insensitive) is expired or hospice
discharge_lower = all_ids_w_outcome['discharge_category'].str.lower()
all_ids_w_outcome['is_dead'] = discharge_lower.isin(['expired', 'hospice']).astype(int)

# When death_dttm precedes discharge_dttm, death_dttm becomes the final outcome time
# and the patient is flagged as dead regardless of the discharge category
mask_death_before_discharge = all_ids_w_outcome['death_dttm'] < all_ids_w_outcome['discharge_dttm']
all_ids_w_outcome.loc[mask_death_before_discharge, 'final_outcome_dttm'] = (
    all_ids_w_outcome.loc[mask_death_before_discharge, 'death_dttm']
)
all_ids_w_outcome.loc[mask_death_before_discharge, 'is_dead'] = 1

In [None]:
# SANITY CHECK - blocks where death_dttm is before block_last_vital_dttm
## For this project, we used block_last_vital_dttm as the final_outcome_dttm to circumvent possible issues
mask_death_before_vitals = (all_ids_w_outcome['death_dttm'].notna()) & (all_ids_w_outcome['death_dttm'] < all_ids_w_outcome['block_last_vital_dttm'])
print("Number of blocks where death_dttm is before block_last_vital_dttm:", mask_death_before_vitals.sum())
print("\nExample cases:")
# .copy() so the diff_hour column below is added to an independent frame rather than
# to a slice of all_ids_w_outcome (avoids SettingWithCopyWarning)
death_before_vitals_df = all_ids_w_outcome.loc[
    mask_death_before_vitals,
    ['patient_id', 'hospitalization_id', 'encounter_block', 'death_dttm', 'block_last_vital_dttm', 'final_outcome_dttm']
].copy()

# Hours between death_dttm and block_last_vital_dttm (negative: death came first)
death_before_vitals_df['diff_hour'] = (death_before_vitals_df['death_dttm'] - death_before_vitals_df['block_last_vital_dttm']).dt.total_seconds() / 3600

In [None]:
# Distinct-value count for each of the first three columns of all_ids_w_outcome
for col in all_ids_w_outcome.columns[:3]:
    print(f"\n{col}:", all_ids_w_outcome[col].nunique(), sep="\n")

## Hourly Vitals

In [None]:
## Height and weight per encounter block, used to calculate BMI
# Keep only height/weight vitals for blocks in the outcome cohort
vitals_bmi = vitals_stitched[
    (vitals_stitched['vital_category'].isin(['weight_kg', 'height_cm'])) &
    (vitals_stitched['encounter_block'].isin(all_ids_w_outcome['encounter_block']))
].copy()

# Remove outliers
# Extract the min/max from the config
min_height, max_height = outlier_cfg['height_cm']
min_weight, max_weight = outlier_cfg['weight_kg']

# For height rows: set out-of-range to NaN
is_height = vitals_bmi['vital_category'] == 'height_cm'
height_mask_low  = is_height & (vitals_bmi['vital_value'] < min_height)
height_mask_high = is_height & (vitals_bmi['vital_value'] > max_height)
vitals_bmi.loc[height_mask_low | height_mask_high, 'vital_value'] = np.nan

# For weight rows: set out-of-range to NaN
is_weight = vitals_bmi['vital_category'] == 'weight_kg'
weight_mask_low  = is_weight & (vitals_bmi['vital_value'] < min_weight)
weight_mask_high = is_weight & (vitals_bmi['vital_value'] > max_weight)
vitals_bmi.loc[weight_mask_low | weight_mask_high, 'vital_value'] = np.nan

# Drop missing / out-of-range values BEFORE choosing the closest measurement.
# Otherwise a NaN reading can be the one kept for a block even though a valid
# reading exists, which leaves that block without height/weight/BMI.
vitals_bmi = vitals_bmi.dropna(subset=['vital_value'])

# Merge with block vent times to get the ventilation start time
vitals_bmi = vitals_bmi.merge(
    block_vent_times[['encounter_block','block_vent_start_dttm']],
    on='encounter_block',
    how='left'
)

# Signed time difference (hours) between the measurement and vent start
vitals_bmi['time_diff'] = (vitals_bmi['recorded_dttm'] - vitals_bmi['block_vent_start_dttm']).dt.total_seconds() / 3600  # in hours

# 1 if the measurement was taken at or before vent start, else 0
vitals_bmi['before_vent_start'] = (vitals_bmi['time_diff'] <= 0).astype(int)

# Absolute time difference
vitals_bmi['abs_time_diff'] = vitals_bmi['time_diff'].abs()

# Sort so that, per block and vital, measurements at/before vent start come first
# (before_vent_start descending) and, within those, the closest in time comes first
vitals_bmi = vitals_bmi.sort_values(['encounter_block', 'vital_category', 'before_vent_start', 'abs_time_diff'],
                                    ascending=[True, True, False, True])

# Keep the first (preferred) measurement for each vital_category per encounter block
vitals_bmi = vitals_bmi.drop_duplicates(subset=['encounter_block', 'vital_category'], keep='first')

# Pivot to get height and weight per encounter block
vitals_bmi_pivot = vitals_bmi.pivot(index='encounter_block',
                                    columns='vital_category',
                                    values='vital_value'
                                    ).reset_index()

# If no valid height (or weight) exists at all, the pivot lacks that column;
# add it as all-NaN so the BMI calculation and later merges still work.
for required_vital in ('height_cm', 'weight_kg'):
    if required_vital not in vitals_bmi_pivot.columns:
        vitals_bmi_pivot[required_vital] = np.nan

# BMI = weight (kg) / height (m)^2
vitals_bmi_pivot['bmi'] = vitals_bmi_pivot['weight_kg'] / ((vitals_bmi_pivot['height_cm'] / 100) ** 2)

print(f"Number of unique encounter blocks with BMI data: {vitals_bmi_pivot['encounter_block'].nunique()}")

In [None]:
# Add (date, hour) keys to the vitals, then restrict them to the outcome cohort blocks
vital_stamps = vitals_stitched['recorded_dttm']
vitals_stitched['recorded_date'] = vital_stamps.dt.date
vitals_stitched['recorded_hour'] = vital_stamps.dt.hour
print(f"Number of unique encounter blocks BEFORE stitching vitals: {vitals_stitched['encounter_block'].nunique()}")

vitals_in_cohort = vitals_stitched['encounter_block'].isin(all_ids_w_outcome['encounter_block'])
vitals_stitched = vitals_stitched[vitals_in_cohort]
n_blocks_with_vitals = vitals_stitched['encounter_block'].nunique()
print(f"Number of unique encounter blocks AFTER stitching vitals: {n_blocks_with_vitals}")
strobe_counts['final_blocks_with_vitals'] = vitals_stitched['encounter_block'].nunique()

In [None]:
# Always calculate MAP from sbp/dbp, even if MAP exists in the data.
# Existing 'map' rows are dropped first, so MAP is always the derived value.
# (The old "if 'map' not in ..." test was always true after this drop, which made
# its else-branch unreachable; the calculation now simply runs unconditionally.)
vitals_stitched = vitals_stitched[vitals_stitched['vital_category'] != 'map']

# 1) Filter for sbp & dbp
sbp_dbp = vitals_stitched[vitals_stitched['vital_category'].isin(['sbp','dbp'])].copy()

# 2) Pivot at the encounter_block + recorded_dttm level
#    (pivot_table averages repeated readings of a vital at the same timestamp)
sbp_dbp_pivot = sbp_dbp.pivot_table(
    index=['encounter_block','recorded_dttm'],
    columns='vital_category',
    values='vital_value'
).reset_index()

# 3) Drop any row missing sbp or dbp
sbp_dbp_pivot = sbp_dbp_pivot.dropna(subset=['sbp','dbp'])

# 4) Calculate MAP = (sbp + 2 * dbp) / 3
sbp_dbp_pivot['map'] = (sbp_dbp_pivot['sbp'] + 2*sbp_dbp_pivot['dbp']) / 3

# 5) Build a long-format DataFrame for map
map_vitals = sbp_dbp_pivot[['encounter_block','recorded_dttm','map']].copy()
map_vitals['vital_category'] = 'map'
map_vitals['vital_value'] = map_vitals['map']

# Also add recorded_date/hour
map_vitals['recorded_date'] = map_vitals['recorded_dttm'].dt.date
map_vitals['recorded_hour'] = map_vitals['recorded_dttm'].dt.hour

# Keep only the needed columns
map_vitals = map_vitals[[
    'encounter_block','recorded_dttm','recorded_date','recorded_hour','vital_category','vital_value'
]]

# 6) Append 'map' to the main vitals_stitched DataFrame
#    (columns not present in map_vitals are NaN on the appended rows)
vitals_stitched = pd.concat([vitals_stitched, map_vitals], ignore_index=True)
print("...map was calculated and appended to vitals_stitched.")

In [None]:
# Hourly min / max / mean of every vital at the BLOCK level
# 1) Aggregate per (encounter_block, recorded_date, recorded_hour, vital_category)
vitals_min_max = (
    vitals_stitched
    .groupby(['encounter_block', 'recorded_date', 'recorded_hour', 'vital_category'])
    .agg(
        min_val=('vital_value', 'min'),
        max_val=('vital_value', 'max'),
        avg_val=('vital_value', 'mean'),
    )
    .reset_index()
)

# 2) Go wide: one row per (encounter_block, date, hour), one column per statistic x vital
vitals_pivot = vitals_min_max.pivot_table(
    index=['encounter_block', 'recorded_date', 'recorded_hour'],
    columns='vital_category',
    values=['min_val', 'max_val', 'avg_val']
).reset_index()

# 3) Flatten the two-level column labels, e.g. ('min_val', 'sbp') -> 'min_val_sbp';
#    key columns such as ('encounter_block', '') lose their trailing underscore
flattened_names = []
for label in vitals_pivot.columns:
    if isinstance(label, tuple):
        flattened_names.append('_'.join(label).rstrip('_'))
    else:
        flattened_names.append(label)
vitals_pivot.columns = flattened_names

# 4) Shorten the statistic prefixes to 'min_', 'max_' and 'avg_'
short_prefix = {'min_val_': 'min_', 'max_val_': 'max_', 'avg_val_': 'avg_'}
rename_dict = {
    c: c.replace(long_form, short_form)
    for c in vitals_pivot.columns
    for long_form, short_form in short_prefix.items()
    if c.startswith(long_form)
}

vitals_pivot = vitals_pivot.rename(columns=rename_dict)

print("Finished creating block-level min/max/avg vitals pivot:")
vitals_pivot.columns

In [None]:
## confirm duplicates don't exist
# vitals_pivot is expected to have one row per (encounter_block, recorded_date, recorded_hour).
# NOTE(review): the third argument is 'final_df' although the frame being checked is
# vitals_pivot — presumably just a label used by pyCLIF.remove_duplicates; confirm.
# The returned frame is deleted right away, so this cell only serves as a check.
checkpoint_vitals = pyCLIF.remove_duplicates(vitals_pivot, [
    'encounter_block','recorded_date', 'recorded_hour'
], 'final_df')
del checkpoint_vitals

In [None]:
# Left-join the hourly vitals onto final_df; hours without vitals keep NaN
final_df = final_df.merge(
    vitals_pivot,
    how='left',
    on=['encounter_block', 'recorded_date', 'recorded_hour'],
)
print("\n Columns in final_df after merging with vitals:")
final_df.columns

## Hourly Meds

* Handle med dose unit conversion for all vasoactives
* Calculate NE equivalent levels using "norepinephrine", "epinephrine", "phenylephrine", "vasopressin", "dopamine",  "angiotensin"
* Create flags for "nicardipine", "nitroprusside", "clevidipine" for the red criteria under consensus criteria
* Identify encounters on paralytics (cisatracurium, vecuronium, rocuronium) and create a flag for each of these medications. These patients are not considered eligible for mobilization during any hour in which they were receiving a paralytic.


In [None]:
# Load continuous medication administrations for the cohort's hospitalizations,
# limited to the medication categories of interest
meds_filters = dict(
    hospitalization_id=all_ids['hospitalization_id'].unique().tolist(),
    med_category=meds_of_interest,
)
meds = pyCLIF.load_data(
    'clif_medication_admin_continuous',
    columns=meds_required_columns,
    filters=meds_filters,
)
# Bring in the remaining all_ids columns (e.g. encounter_block) for each row
meds = pd.merge(meds, all_ids, how='left', on='hospitalization_id')
print("Unique encounters in meds", pyCLIF.count_unique_encounters(meds))

In [None]:
# Normalise types: string ids and lower-case dose units
meds = meds.assign(
    hospitalization_id=meds['hospitalization_id'].astype(str),
    med_dose_unit=meds['med_dose_unit'].str.lower(),
)
# Datetime columns are converted to the site time zone by the pyCLIF helper
meds = pyCLIF.convert_datetime_columns_to_site_tz(meds,  pyCLIF.helper['timezone'])
# Doses that cannot be parsed as numbers become NaN
meds['med_dose'] = pd.to_numeric(meds['med_dose'], errors='coerce')
# Hourly keys derived from the administration time
admin_stamps = meds['admin_dttm']
meds['recorded_date'] = admin_stamps.dt.date
meds['recorded_hour'] = admin_stamps.dt.hour

In [None]:
# Dose summary per med_category: row count, range, quartiles and missing doses
med_category_summary_spec = {
    'total_N': ('med_category', 'size'),
    'min': ('med_dose', 'min'),
    'max': ('med_dose', 'max'),
    'first_quantile': ('med_dose', lambda x: x.quantile(0.25)),
    'second_quantile': ('med_dose', lambda x: x.quantile(0.5)),
    'third_quantile': ('med_dose', lambda x: x.quantile(0.75)),
    'missing_values': ('med_dose', lambda x: x.isna().sum()),
}
summary_meds = meds.groupby('med_category').agg(**med_category_summary_spec).reset_index()

summary_meds.to_csv('../output/final/summary_meds_by_category.csv', index=False)

In [None]:
# Dose summary per (med_category, med_dose_unit): row count, range, quartiles, missing doses
med_unit_summary_spec = {
    'total_N': ('med_category', 'size'),
    'min': ('med_dose', 'min'),
    'max': ('med_dose', 'max'),
    'first_quantile': ('med_dose', lambda x: x.quantile(0.25)),
    'second_quantile': ('med_dose', lambda x: x.quantile(0.5)),
    'third_quantile': ('med_dose', lambda x: x.quantile(0.75)),
    'missing_values': ('med_dose', lambda x: x.isna().sum()),
}
summary_meds_cat_dose = (
    meds
    .groupby(['med_category', 'med_dose_unit'])
    .agg(**med_unit_summary_spec)
    .reset_index()
)
summary_meds_cat_dose.to_csv('../output/final/summary_meds_by_category_dose_units.csv', index=False)
## Next: check the distribution of the required continuous meds

In [None]:
# Diagnostic: list the (med_category, med_dose_unit) groups without a single usable dose
print("Groups with all NaN med_dose values:")
for (med_category, med_dose_unit), group in meds.groupby(['med_category', 'med_dose_unit']):
    if group['med_dose'].notna().any():
        continue
    print(f"  {med_category} - {med_dose_unit}: {len(group)} rows, all NaN")

In [None]:
# Histogram of med_dose for every (med_category, med_dose_unit) combination,
# laid out on a grid with 4 columns and saved as a single PNG.
grouped_data = meds.groupby(['med_category', 'med_dose_unit'])

# One subplot per group; always at least one row so an empty meds table
# still produces a (blank) figure instead of raising in plt.subplots
n_plots = grouped_data.ngroups
n_cols = 4
n_rows = max(1, (n_plots + n_cols - 1) // n_cols)  # Round up to determine rows

# squeeze=False guarantees a 2-D array of axes regardless of the grid shape
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 5), squeeze=False)

# Flatten the axs array for easier indexing
axs = axs.flatten()

# Loop through each group and plot the histogram
for ax, ((med_category, med_dose_unit), group) in zip(axs, grouped_data):
    # Filter out NaN values
    valid_doses = group['med_dose'].dropna()

    if len(valid_doses) > 0:
        ax.hist(valid_doses, bins=20, alpha=0.7, label=f"N = {len(valid_doses)}")
        ax.set_title(f"{med_category} - {med_dose_unit}")
        ax.set_xlabel('Med Dose')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True)
    else:
        # Handle case where all values are NaN
        ax.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(f"{med_category} - {med_dose_unit} (No Data)")

# Hide any unused axes (slicing by n_plots also works when there are no groups)
for unused_ax in axs[n_plots:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.savefig('../output/final/graphs/meds_histograms.png')
plt.close(fig)

In [None]:
# SANITY CHECK: which dose units occur for each med_category, and are they acceptable?
med_dose_unit_check = (
    meds
    .groupby(['med_category', 'med_dose_unit'])
    .size()
    .reset_index(name='count')
)

# pyCLIF.check_dose_unit labels each (category, unit) row
med_dose_unit_check['unit_validity'] = med_dose_unit_check.apply(pyCLIF.check_dose_unit, axis=1)

# Show the combinations labelled as not acceptable
unit_not_acceptable = med_dose_unit_check['unit_validity'] == 'Not an acceptable unit'
invalid_units = med_dose_unit_check[unit_not_acceptable]
print("Invalid units. These will be dropped:\n")
print(invalid_units)

In [None]:
# ## Norepinephrine equivalent calculation
# Goradia S, Sardaneh AA, Narayan SW, Penm J, Patanwala AE. Vasopressor dose equivalence:
# A scoping review and suggested formula. J Crit Care. 2021 Feb;61:233-240. doi: 10.1016/j.jcrc.2020.11.002. Epub 2020 Nov 14. PMID: 33220576.

# Step 1: drop rows without a dose
has_dose = meds['med_dose'].notna()
meds_filtered = meds[has_dose].copy()
# Step 2: keep only rows whose unit passes pyCLIF.has_per_hour_or_min
# (i.e. units containing '/hr' or '/min')
unit_is_rate = meds_filtered['med_dose_unit'].apply(pyCLIF.has_per_hour_or_min)
meds_filtered = meds_filtered[unit_is_rate].copy()

In [None]:
# Vasoactives that enter the norepinephrine-equivalent calculation
meds_list = [
    "norepinephrine",
    "epinephrine",
    "phenylephrine",
    "vasopressin",
    "dopamine",
    "angiotensin",
]

# Convert medication doses to the required units
is_vasoactive = meds_filtered['med_category'].isin(meds_list)
ne_df = meds_filtered[is_vasoactive].copy()
# Add the block's weight_kg (from vitals_bmi_pivot) so pyCLIF.convert_dose has it on each row
block_weights = vitals_bmi_pivot[['encounter_block', 'weight_kg']]
ne_df = ne_df.merge(block_weights, on='encounter_block', how='left')
ne_df["med_dose_converted"] = ne_df.apply(pyCLIF.convert_dose, axis=1)

# Keep only rows that pyCLIF.is_dose_within_range accepts for the configured limits
dose_in_range = ne_df.apply(pyCLIF.is_dose_within_range, axis=1, args=(outlier_cfg,))
ne_df = ne_df[dose_in_range].copy()

In [None]:
# Report, for every vasoactive of interest, whether it occurs in ne_df
present_categories = set(ne_df['med_category'].unique())
for med in meds_list:
    presence = "is in" if med in present_categories else "is not in"
    print(f"{med} {presence} the dataset.")

In [None]:
# Hourly norepinephrine-equivalent (NE) dose per encounter block.
hour_keys = ['encounter_block', 'recorded_date', 'recorded_hour']
group_cols = hour_keys + ['med_category']

# Sort by administration time first, so 'first' and 'last' are the chronologically
# first/last dose of the hour rather than whatever order ne_df happens to be in.
ne_sorted = ne_df.sort_values('admin_dttm', kind='stable')
dose_agg = ne_sorted.groupby(group_cols)['med_dose_converted'].agg(['min', 'max', 'first', 'last']).reset_index()


def _pivot_dose_stat(stat):
    """Return one row per block-hour with a '<stat>_<med_category>' column per medication."""
    return (
        dose_agg
        .pivot_table(index=hour_keys, columns='med_category', values=stat)
        .add_prefix(f'{stat}_')
        .reset_index()
    )


dose_pivot_min = _pivot_dose_stat('min')
dose_pivot_max = _pivot_dose_stat('max')
dose_pivot_first = _pivot_dose_stat('first')
dose_pivot_last = _pivot_dose_stat('last')

# Combine the four statistic tables on the hourly keys
dose_pivot = pyCLIF.merge_multiple_dfs(dose_pivot_min, dose_pivot_max, dose_pivot_first, dose_pivot_last,
                              on=hour_keys,
                              how='outer')

# A medication not given in an hour contributes 0 to the equivalent
dose_pivot = dose_pivot.fillna(0)


def _ne_equivalent(doses, stat):
    """Norepinephrine equivalent built from the '<stat>_<med>' columns of *doses*.

    Weights follow the Goradia et al. (2021) formula cited in this notebook.
    Medications without a column (e.g. metaraminol, which is not in meds_list)
    contribute 0.
    NOTE(review): assumes pyCLIF.convert_dose already produced the units this
    formula expects - confirm against the helper.
    """
    def dose(med):
        return doses.get(f'{stat}_{med}', 0)

    return (
        dose('norepinephrine') +
        dose('epinephrine') +
        dose('phenylephrine') / 10 +
        dose('dopamine') / 100 +
        dose('metaraminol') / 8 +
        dose('vasopressin') * 2.5 +
        dose('angiotensin') * 10
    )


# NE equivalent of the hourly min, max, first and last doses
for stat in ('min', 'max', 'first', 'last'):
    dose_pivot[f'ne_calc_{stat}'] = _ne_equivalent(dose_pivot, stat)

# Keep only the keys and the four NE columns, one row per block-hour
ne_calc_df = dose_pivot[['encounter_block', 'recorded_date',
                         'recorded_hour',
                         'ne_calc_min', 'ne_calc_max',
                         'ne_calc_first', 'ne_calc_last']].drop_duplicates(subset=['encounter_block', 'recorded_date', 'recorded_hour'])

In [None]:
# STROBE: number of encounter blocks with at least one hourly NE-equivalent row
strobe_counts['final_blocks_with_norepi_eq'] = ne_calc_df['encounter_block'].nunique()

In [None]:
# Hourly scaffold for the blocks that received any vasoactive of interest,
# built by the pyCLIF helper from the administration timestamps
encounter_blocks_list = ne_df['encounter_block'].unique().tolist()
hourly_ne = pyCLIF.build_meds_hourly_scaffold(
    ne_df,
    id_col="encounter_block",            # grouping column
    ids=encounter_blocks_list,           # ids to keep
    timestamp_col="admin_dttm",          # administration time column
    site_tz=pyCLIF.helper['timezone'],   # site time zone from the pyCLIF config
)

In [None]:
ne_merge_keys = ['encounter_block', 'recorded_date', 'recorded_hour']

# Keep ne_calc_df ordered by block and time
ne_calc_df = ne_calc_df.sort_values(by=ne_merge_keys)
# Put the hourly NE equivalents onto the hourly scaffold; scaffold hours without
# an administration get NaN here
hourly_ne_merged = pd.merge(
    hourly_ne,
    ne_calc_df,
    on=ne_merge_keys,
    how='left'
)
# A left merge follows the row order of hourly_ne, so sort the merged frame itself:
# the forward fill below needs chronological order within each block.
hourly_ne_merged = (
    hourly_ne_merged
    .sort_values(by=ne_merge_keys, kind='stable')
    .reset_index(drop=True)
)
# Carry the last known values forward WITHIN each encounter_block, so the first
# hours of one block never inherit the last dose of the previous block.
# (Replaces fillna(method='ffill'), which is deprecated and not block-aware.)
cols_to_fill = ['ne_calc_min', 'ne_calc_max', 'ne_calc_first', 'ne_calc_last']
hourly_ne_merged[cols_to_fill] = (
    hourly_ne_merged.groupby('encounter_block')[cols_to_fill].ffill()
)

In [None]:
def add_last_ne_6h(group: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of one encounter_block's rows with the column
    `last_ne_dose_last_6_hours` added (or overwritten).

    The new column holds the `ne_calc_last` value found six ROWS earlier in
    `group`. That equals "six hours earlier" only when `group` has exactly one
    row per hour and is sorted chronologically. Rows with no value six rows
    back (e.g. the first six rows of the block) get 0.

    Parameters
    ----------
    group : pd.DataFrame
        Rows of a single encounter block; must contain `ne_calc_last`.

    Returns
    -------
    pd.DataFrame
        A new frame with the same index as `group`. The input is left
        untouched, because mutating the object handed over by
        `groupby(...).apply` is discouraged by pandas.
    """
    lagged_dose = (
        group['ne_calc_last']
        .shift(6)           # value 6 rows back
        .fillna(0)          # treat "no record" as 0
    )
    return group.assign(last_ne_dose_last_6_hours=lagged_dose)

# Run the six-row lag helper separately on every encounter block, then
# renumber the rows of the combined result from 0.
ne_by_block = hourly_ne_merged.groupby('encounter_block', group_keys=False)
hourly_ne_merged = ne_by_block.apply(add_last_ne_6h).reset_index(drop=True)
del ne_by_block

In [None]:
## confirm duplicates don't exist
# pyCLIF.remove_duplicates is defined outside this section; presumably it checks for
# (and drops) duplicate rows on the listed key columns — TODO confirm.
# NOTE(review): the third argument is 'final_df' although the frame passed in is
# hourly_ne_merged; if it is only a display label it is misleading — confirm.
checkpoint_meds = pyCLIF.remove_duplicates(hourly_ne_merged, [
    'encounter_block','recorded_date', 'recorded_hour'
], 'final_df')
# The returned frame is discarded immediately; hourly_ne_merged is not reassigned here.
del checkpoint_meds

In [None]:
# Add the hourly norepinephrine-equivalent columns to final_df, keyed on block/date/hour.
# pyCLIF.extend_hourly_dataset is defined outside this section; presumably it joins
# addon_df onto base_df on merge_cols — TODO confirm the join type.
# The shape is printed before and after so a change in row count is visible.
print("final_df shape before merging", final_df.shape)
final_df = pyCLIF.extend_hourly_dataset(
    base_df=final_df,
    addon_df=hourly_ne_merged,
    merge_cols=['encounter_block', 'recorded_date', 'recorded_hour']
)
print("final_df shape after merging", final_df.shape)

In [None]:
red_meds_list = ["nicardipine", "nitroprusside", "clevidipine"]

# Rows of `meds` that belong to any medication in red_meds_list
red_meds_df = meds.loc[meds['med_category'].isin(red_meds_list)].copy()

# One 0/1 column per medication: 1 on rows of that medication whose dose is
# present and greater than zero, 0 otherwise.
red_dose_given = (red_meds_df['med_dose'] > 0.0) & red_meds_df['med_dose'].notna()
red_flag_cols = []
for med in red_meds_list:
    flag_col = f"{med}_flag"
    is_this_med = red_meds_df['med_category'] == med
    red_meds_df[flag_col] = np.where(is_this_med & red_dose_given, 1, 0).astype(int)
    red_flag_cols.append(flag_col)

# Collapse to one row per encounter block and clock hour. Taking the maximum
# means a single positive-dose record in the hour sets that hour's flag to 1.
red_hour_keys = ['encounter_block', 'recorded_date', 'recorded_hour']
red_meds_flags = (
    red_meds_df
    .groupby(red_hour_keys)
    .agg(dict.fromkeys(red_flag_cols, 'max'))
    .reset_index()
)

# Overall flag: 1 when any of the individual medication flags is 1 in that hour
red_meds_flags['red_meds_flag'] = red_meds_flags[red_flag_cols].max(axis=1)

# Keep the keys plus the flags, one row per key, with every flag stored as int
red_output_flags = red_flag_cols + ['red_meds_flag']
red_meds_flags_final = (
    red_meds_flags[red_hour_keys + red_output_flags]
    .drop_duplicates(subset=red_hour_keys)
    .astype({col: int for col in red_output_flags})
)

In [None]:
# STROBE count: distinct encounter blocks with at least one row in the red-meds flag table.
# NOTE(review): rows whose flags are all 0 (records with zero or missing dose) are counted too — confirm intent.
strobe_counts['final_blocks_with_red_meds'] = red_meds_flags_final['encounter_block'].nunique()

In [None]:
## confirm duplicates don't exist
# pyCLIF.remove_duplicates is defined outside this section; presumably it checks for
# (and drops) duplicate rows on the listed key columns — TODO confirm.
# NOTE(review): the label passed is 'final_df' although the frame checked is
# red_meds_flags_final — confirm whether the label should name that frame.
checkpoint_red_meds = pyCLIF.remove_duplicates(red_meds_flags_final, [
    'encounter_block','recorded_date', 'recorded_hour'
], 'final_df')
# The returned frame is discarded immediately; red_meds_flags_final is not reassigned here.
del checkpoint_red_meds

In [None]:
# Add the red-meds flag columns to final_df, keyed on block/date/hour.
# pyCLIF.extend_hourly_dataset is defined outside this section; presumably it joins
# addon_df onto base_df on merge_cols — TODO confirm the join type and how
# hours without a medication record are filled.
print("final_df shape before merging", final_df.shape)
final_df = pyCLIF.extend_hourly_dataset(
    base_df=final_df,
    addon_df=red_meds_flags_final,
    merge_cols=['encounter_block', 'recorded_date', 'recorded_hour']
)
print("final_df shape after merging", final_df.shape)
print("\n Columns in final_df after merging with red_meds_flags_final columns:")
# Last expression of the cell: the notebook displays the column index
final_df.columns

In [None]:
paralytics_list = ["cisatracurium", "vecuronium", "rocuronium"]

# Rows of `meds` that belong to any medication in paralytics_list
paralytics_df = meds.loc[meds['med_category'].isin(paralytics_list)].copy()

# One 0/1 column per paralytic: 1 on rows of that medication whose dose is
# present and greater than zero, 0 otherwise.
paralytic_dose_given = (paralytics_df['med_dose'] > 0.0) & paralytics_df['med_dose'].notna()
paralytic_flag_cols = []
for med in paralytics_list:
    flag_col = f"{med}_flag"
    is_this_med = paralytics_df['med_category'] == med
    paralytics_df[flag_col] = np.where(is_this_med & paralytic_dose_given, 1, 0).astype(int)
    paralytic_flag_cols.append(flag_col)

# Collapse to one row per encounter block and clock hour. Taking the maximum
# means a single positive-dose record in the hour sets that hour's flag to 1.
paralytic_hour_keys = ['encounter_block', 'recorded_date', 'recorded_hour']
paralytics_flags = (
    paralytics_df
    .groupby(paralytic_hour_keys)
    .agg(dict.fromkeys(paralytic_flag_cols, 'max'))
    .reset_index()
)

# Overall flag: 1 when any of the individual paralytic flags is 1 in that hour
paralytics_flags['paralytics_flag'] = paralytics_flags[paralytic_flag_cols].max(axis=1)

# Keep the keys plus the flags, one row per key, with every flag stored as int
paralytic_output_flags = paralytic_flag_cols + ['paralytics_flag']
paralytics_flags_final = (
    paralytics_flags[paralytic_hour_keys + paralytic_output_flags]
    .drop_duplicates(subset=paralytic_hour_keys)
    .astype({col: int for col in paralytic_output_flags})
)

In [None]:
# STROBE count: distinct encounter blocks with at least one row in the paralytics flag table.
# NOTE(review): rows whose flags are all 0 (records with zero or missing dose) are counted too — confirm intent.
strobe_counts['final_blocks_with_paralytics'] = paralytics_flags_final['encounter_block'].nunique()

In [None]:
## confirm duplicates don't exist
# pyCLIF.remove_duplicates is defined outside this section; presumably it checks for
# (and drops) duplicate rows on the listed key columns — TODO confirm.
# NOTE(review): the label passed is 'final_df' although the frame checked is
# paralytics_flags_final — confirm whether the label should name that frame.
checkpoint_paralytics_meds = pyCLIF.remove_duplicates(paralytics_flags_final, [
    'encounter_block','recorded_date', 'recorded_hour'
], 'final_df')
# The returned frame is discarded immediately; paralytics_flags_final is not reassigned here.
del checkpoint_paralytics_meds

In [None]:
# Add the paralytics flag columns to final_df, keyed on block/date/hour.
# An earlier version used a plain left pd.merge on the same three keys; it was
# replaced by pyCLIF.extend_hourly_dataset (defined outside this section —
# TODO confirm how it differs from a left merge).

print("final_df shape before merging", final_df.shape)
final_df = pyCLIF.extend_hourly_dataset(
    base_df=final_df,
    addon_df=paralytics_flags_final,
    merge_cols=['encounter_block', 'recorded_date', 'recorded_hour']
)
print("final_df shape after merging", final_df.shape)

print("\n Columns in final_df after merging with paralytics columns:")
# Last expression of the cell: the notebook displays the column index
final_df.columns

## Hourly Labs

For each hour, keep the lactate result whose result time is closest to the start of the first intubation event.

In [None]:
# Load the labs table restricted to the cohort's hospitalizations and the lab
# categories listed in labs_of_interest.
labs_filters = dict(
    hospitalization_id=all_ids['hospitalization_id'].unique().tolist(),
    lab_category=labs_of_interest,
)
labs = pyCLIF.load_data('clif_labs', columns=labs_required_columns, filters=labs_filters)
print("unique encounters in labs", pyCLIF.count_unique_encounters(labs))

# Cast the id to string before joining on it, attach the columns of all_ids,
# then order the rows by encounter block and result time.
labs = (
    labs
    .assign(hospitalization_id=labs['hospitalization_id'].astype(str))
    .merge(all_ids, on='hospitalization_id', how='left')
    .sort_values(by=['encounter_block', 'lab_result_dttm'])
)

In [None]:
# STROBE count: distinct encounter blocks with at least one row in the loaded labs table.
# NOTE(review): the key says "lactate", but the count covers every category in
# labs_of_interest — confirm that list contains only lactate.
strobe_counts['final_blocks_with_lactate_lab'] = labs['encounter_block'].nunique()

In [None]:
# Move the datetime columns to the site time zone, make the lab value numeric
# (unparseable values become NaN), and derive the clock hour and calendar date
# of each result.
labs = pyCLIF.convert_datetime_columns_to_site_tz(labs, pyCLIF.helper['timezone'])
labs['lab_value_numeric'] = pd.to_numeric(labs['lab_value_numeric'], errors='coerce')
lab_result_time = labs['lab_result_dttm']
labs['recorded_hour'] = lab_result_time.dt.hour
labs['recorded_date'] = lab_result_time.dt.date

# Attach the ventilation times of each block and measure, in hours, how far
# every result lies from the ventilation start (signed and absolute).
lactate_df = labs.merge(block_vent_times, on='encounter_block', how='left')
hours_from_vent_start = (
    lactate_df['lab_result_dttm'] - lactate_df['block_vent_start_dttm']
).dt.total_seconds() / 3600
lactate_df['time_since_vent_start_hours'] = hours_from_vent_start
lactate_df['time_diff_hours'] = hours_from_vent_start.abs()

# Within each block/date/hour, order the results by distance from the
# ventilation start so the closest one comes first.
lab_hour_keys = ['encounter_block', 'recorded_date', 'recorded_hour']
lactate_df = lactate_df.sort_values(by=lab_hour_keys + ['time_diff_hours'])

# One row per block/date/hour. GroupBy.first() takes the first NON-NULL entry of
# each column, so a missing value on the closest result is replaced by the next
# closest non-missing value in the same hour.
closest_lactate_df = lactate_df.groupby(lab_hour_keys).first().reset_index()

# Keep the keys and the value, exposing the value under the name 'lactate'.
# NOTE(review): no lab_category filter is applied in this cell — this assumes
# labs_of_interest contains only lactate; confirm.
labs_final = closest_lactate_df[lab_hour_keys + ['lab_value_numeric']].rename(
    columns={'lab_value_numeric': 'lactate'}
)

In [None]:
# Duplicate check on the hourly lactate table before it is merged.
# pyCLIF.remove_duplicates is defined outside this section; presumably it checks for
# (and drops) duplicate rows on the listed key columns — TODO confirm.
# NOTE(review): the label passed is 'final_df' although the frame checked is
# labs_final — confirm whether the label should name that frame.
checkpoint_labs= pyCLIF.remove_duplicates(labs_final, [
    'encounter_block', 'recorded_date', 'recorded_hour'
], 'final_df')
# The returned frame is discarded immediately; labs_final is not reassigned here.
del checkpoint_labs

In [None]:
# Add the hourly lactate column to final_df, keyed on block/date/hour.
# An earlier version used a plain left pd.merge on the same three keys; it was
# replaced by pyCLIF.extend_hourly_dataset (defined outside this section —
# TODO confirm how it differs from a left merge).

print("final_df shape before merging", final_df.shape)
final_df = pyCLIF.extend_hourly_dataset(
    base_df=final_df,
    addon_df=labs_final,
    merge_cols=['encounter_block', 'recorded_date', 'recorded_hour']
)
print("final_df shape after merging", final_df.shape)

In [None]:
# Display the column names of final_df for a visual check of the merged columns
final_df.columns

## SOFA

Calculate the SOFA score over the first 24 hours after the start of the first intubation episode.

In [None]:
# Location of the CLIF tables, read from the project configuration
helper = pyCLIF.load_config()
tables_path = helper['tables_path']

# One scoring window per row of all_ids_w_outcome: from the block's ventilation
# start (start_dttm) to 24 hours later (stop_dttm).
sofa_input_df = (
    all_ids_w_outcome[['encounter_block', 'block_vent_start_dttm']]
    .rename(columns={'block_vent_start_dttm': 'start_dttm'})
    .assign(stop_dttm=lambda windows: windows['start_dttm'] + pd.Timedelta(hours=24))
)

# Distinct (encounter_block, hospitalization_id) pairs, in that column order
id_mappings = all_ids_w_outcome[['encounter_block', 'hospitalization_id']].drop_duplicates()

# sofa_score.compute_sofa is defined outside this notebook section. The ids in
# sofa_input_df are encounter blocks rather than hospitalization ids, hence
# use_hospitalization_id=False together with the id mapping.
sofa_df = sofa_score.compute_sofa(
    ids_w_dttm=sofa_input_df,
    tables_path=tables_path,
    use_hospitalization_id=False,
    id_mapping=id_mappings,
    helper_module=pyCLIF,
    output_filepath="../output/intermediate/sofa.parquet",
)

In [None]:
# Build the block-level dataset: SOFA results plus cohort ids/outcomes,
# admission details and patient demographics.
# NOTE(review): 'language_name' (patient) and 'location_name' (adt) are selected
# below — confirm both columns are loaded for these tables.
final_df_blocks = sofa_df.merge(all_ids_w_outcome, on='encounter_block', how='left')
final_df_blocks = final_df_blocks.merge(
    hospitalization[['hospitalization_id', 'admission_dttm', 'age_at_admission']],
    on='hospitalization_id', how='left'
)
final_df_blocks = final_df_blocks.merge(
    patient[['patient_id', 'race_category', 'ethnicity_category', 'sex_category', 'language_name']],
    on='patient_id', how='left'
)

# Pair every encounter block with all ADT rows of its hospitalization(s)
adt_with_blocks = pd.merge(
    all_ids_w_outcome[['encounter_block', 'block_vent_start_dttm', 'hospitalization_id']],
    adt,
    on='hospitalization_id'
)

# Absolute time between the ventilation start and the ADT in_dttm
adt_with_blocks['time_diff'] = (
    adt_with_blocks['block_vent_start_dttm'] - adt_with_blocks['in_dttm']
).abs()

# Keep, for each encounter block, the single ADT ROW with the smallest time
# difference. GroupBy.first() is deliberately not used: it returns the first
# non-null value of each column separately, so a closest row with a missing
# out_dttm or location_name would silently receive that field from a
# different ADT row. Rows with a missing time_diff sort last; rows without an
# encounter_block are dropped, as a groupby would do.
closest_adt = (
    adt_with_blocks
    .dropna(subset=['encounter_block'])
    .sort_values(['encounter_block', 'time_diff'])
    .drop_duplicates(subset='encounter_block', keep='first')
    .reset_index(drop=True)
)

# Attach the closest ADT location to the block-level dataset
final_df_blocks = final_df_blocks.merge(
    closest_adt[['encounter_block', 'location_name', 'location_category', 'in_dttm', 'out_dttm']],
    on='encounter_block',
    how='left'
)

# Last expression of the cell: the notebook displays the column index
final_df_blocks.columns

## Write analysis dataset 

In [None]:
# Persist the analysis datasets. The relative paths assume the notebook runs
# from a directory whose parent contains output/intermediate and output/final
# — NOTE(review): confirm these directories exist before this cell runs.
# Hourly dataset
final_df.to_parquet('../output/intermediate/final_df_hourly.parquet')
# Block-level dataset (SOFA, demographics, closest ADT location)
final_df_blocks.to_parquet('../output/intermediate/final_df_blocks.parquet')
# Cohort ids with outcomes
all_ids_w_outcome.to_parquet('../output/intermediate/cohort_all_ids_w_outcome.parquet')
# Convert the dictionary to a two-column (Metric, Value) DataFrame and save it as a CSV file
pd.DataFrame(list(strobe_counts.items()), columns=['Metric', 'Value']).to_csv('../output/final/strobe_counts.csv', index=False)

In [None]:
# Display the collected STROBE counts
strobe_counts

In [None]:
## STROBE diagram
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyArrowPatch

# Square canvas with the axes hidden; all coordinates below are in axes units
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis('off')

# Main cohort-flow boxes, top to bottom: label text and box centre
boxes = [
    {"text": f"All adult encounters after date filter\n(n = {strobe_counts['A_after_date_age_filter']})", "xy": (0.5, 0.9)},
    {"text": f"Linked Encounter Blocks\n(n = {strobe_counts['B_after_stitching']})", "xy": (0.5, 0.75)},
    {"text": f"Encounter blocks receiving IMV\n(n = {strobe_counts['C_imv_encounter_blocks']})", "xy": (0.5, 0.6)},
    {"text": f"Encounter blocks receiving IMV ≥ 4 hrs\n(n = {strobe_counts['G_blocks_with_vent_4_or_more']})", "xy": (0.5, 0.45)},
    {"text": f"Encounter blocks not on trach\n(n = {strobe_counts['G_final_blocks_without_trach_at_intubation']})", "xy": (0.5, 0.3)},
]

# Side boxes placed to the right, between consecutive main boxes
exclusions = [
    {"text": f"Linked hospitalizations\n(n = {strobe_counts['B_stitched_hosp_ids']})", "xy": (0.8, 0.825)},
    {"text": f"Excluded: Encounters on vent for <4 hrs\n(n = {strobe_counts['D_blocks_with_same_vent_start_end'] + strobe_counts['E_blocks_with_vent_end_before_vital_start'] + strobe_counts['G_blocks_with_vent_less_than_4']})", "xy": (0.8, 0.525)},
    {"text": f"Excluded: Encounters with Tracheostomy\n(n = {strobe_counts['G_final_blocks_with_trach_at_intubation']})", "xy": (0.8, 0.375)},
]

# Main boxes: a 0.5 x 0.1 white rectangle centred on (x, y) with its label, and
# a downward arrow from its bottom edge for every box except the last.
last_box_index = len(boxes) - 1
for i, box in enumerate(boxes):
    x, y = box["xy"]
    main_frame = Rectangle((x - 0.25, y - 0.05), 0.5, 0.1, edgecolor='black', facecolor='white')
    ax.add_patch(main_frame)
    ax.text(x, y, box["text"], ha='center', va='center', fontsize=10)
    if i == last_box_index:
        continue
    down_arrow = FancyArrowPatch((x, y - 0.05), (x, y - 0.1), arrowstyle='->', mutation_scale=15)
    ax.add_patch(down_arrow)

# Side boxes: a 0.38 x 0.08 shaded rectangle with its label (no connector lines are drawn)
for excl in exclusions:
    x, y = excl["xy"]
    side_frame = Rectangle((x - 0.20, y - 0.04), 0.38, 0.08, edgecolor='black', facecolor='#f8d7da')
    ax.add_patch(side_frame)
    ax.text(x, y, excl["text"], ha='center', va='center', fontsize=9)

# Save under a site-specific file name, then release the figure
plt.tight_layout()
plt.savefig(f'../output/final/graphs/strobe_diagram_{pyCLIF.helper["site_name"]}.png')
plt.close(fig)
print("Created STROBE diagram")